Model:
airesearch/wav2vec2-large-xlsr-53-th
Fine-tuning wav2vec2-large-xlsr-53 on Thai
We fine-tune wav2vec2-large-xlsr-53 on Thai examples from Common Voice Corpus 7.0, following Fine-tuning Wav2Vec2 for English ASR. The notebooks and scripts can be found in vistec-ai/wav2vec2-large-xlsr-53-th; the pretrained model and processor can be found at airesearch/wav2vec2-large-xlsr-53-th.
We add the syllable tokenizer syllable_tokenize, the word tokenizer word_tokenize (PyThaiNLP), and the deepcut tokenizer to eval.py from robust-speech-event (a sketch of how such a tokenizer switch can work follows the results table below):
```
> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
```
| | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER |
|---|---|---|---|---|
| Only Tokenization | 0.9524% | 2.5316% | 1.2346% | 0.1623% |
| Cleaning rules and Tokenization | TBD | TBD | TBD | TBD |
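As promised above, here is a minimal sketch of one way the --thai_tokenizer switch can re-segment hypothesis and reference strings before scoring. It uses the PyThaiNLP 2.3.1 API but is not the actual eval.py implementation; the helper name tokenize_for_eval is hypothetical.

```python
from pythainlp.tokenize import word_tokenize, syllable_tokenize

def tokenize_for_eval(text: str, tokenizer: str = "newmm") -> str:
    """Hypothetical helper: re-segment Thai text so WER/SER/CER are
    computed on the chosen token unit (a sketch, not the real eval.py)."""
    if tokenizer == "newmm":       # PyThaiNLP dictionary-based word tokenizer
        tokens = word_tokenize(text, engine="newmm")
    elif tokenizer == "deepcut":   # neural word tokenizer (requires deepcut)
        tokens = word_tokenize(text, engine="deepcut")
    elif tokenizer == "syllable":  # syllable units, for syllable error rate
        tokens = syllable_tokenize(text)
    elif tokenizer == "cer":       # character units, for character error rate
        tokens = list(text.replace(" ", ""))
    else:
        raise ValueError(f"Unknown tokenizer: {tokenizer}")
    return " ".join(t for t in tokens if t.strip())
```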
```python
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")

# function to resample to 16_000
def speech_file_to_array_fn(batch,
                            text_col="sentence",
                            fname_col="path",
                            resampling_to=16000):
    speech_array, sampling_rate = torchaudio.load(batch[fname_col])
    resampler = torchaudio.transforms.Resample(sampling_rate, resampling_to)
    batch["speech"] = resampler(speech_array)[0].numpy()
    batch["sampling_rate"] = resampling_to
    batch["target_text"] = batch[text_col]
    return batch

# get 2 examples as sample input
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)

# infer
with torch.no_grad():
    logits = model(inputs.input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])

>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
```
[Common Voice Corpus 7.0](https://commonvoice.mozilla.org/en/datasets) contains 5GB of validated Thai data with a total duration of 255 hours. We pre-tokenize with pythainlp.tokenize.word_tokenize. We preprocess the data with the cleaning rules in notebooks/cv-preprocess.ipynb, then deduplicate and split via ekapolc/Thai_commonvoice_split in order to avoid the data leakage that a random split after cleaning would cause in Common Voice Corpus 7.0, while preserving the majority of the data for the training set. The dataset loading script is scripts/th_common_voice_70.py. You can use it together with train_cleand.tsv, validation_cleaned.tsv and test_cleaned.tsv to obtain the same splits as ours. The resulting dataset is as follows:
```
DatasetDict({
    train: Dataset({
        features: ['path', 'sentence'],
        num_rows: 86586
    })
    test: Dataset({
        features: ['path', 'sentence'],
        num_rows: 2502
    })
    validation: Dataset({
        features: ['path', 'sentence'],
        num_rows: 3027
    })
})
```
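A minimal loading sketch, assuming the loading script resolves the cleaned tsv splits itself; the exact call signature the script expects is an assumption, not documented here:

```python
from datasets import load_dataset

# Sketch: load the same splits with the repo's loading script; it is assumed
# to pick up train_cleand.tsv, validation_cleaned.tsv and test_cleaned.tsv.
datasets = load_dataset("scripts/th_common_voice_70.py")
train_dataset, test_dataset = datasets["train"], datasets["test"]
```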
We fine-tune on a single V100 GPU with the configuration below and choose the checkpoint with the lowest validation loss. The fine-tuning script is scripts/wav2vec2_finetune.py:
```python
from transformers import Wav2Vec2ForCTC, TrainingArguments

# create model
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-xlsr-53",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)
model.freeze_feature_extractor()

training_args = TrainingArguments(
    output_dir="../data/wav2vec2-large-xlsr-53-thai",
    group_by_length=True,
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=16,
    metric_for_best_model="wer",
    evaluation_strategy="steps",
    eval_steps=1000,
    logging_strategy="steps",
    logging_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    num_train_epochs=100,
    fp16=True,
    learning_rate=1e-4,
    warmup_steps=1000,
    save_total_limit=3,
    report_to="tensorboard"
)
```
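To show how these pieces connect, here is a minimal Trainer sketch, assuming the data collator, compute_metrics function and processed datasets are prepared as in the Wav2Vec2 fine-tuning tutorial this work is based on; it is an illustration, not the verbatim contents of scripts/wav2vec2_finetune.py:

```python
from transformers import Trainer

# Sketch: wire the model and arguments into a Trainer. data_collator,
# compute_metrics, train_dataset and eval_dataset are assumed to be
# defined as in the referenced fine-tuning tutorial.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor.feature_extractor,
)
trainer.train()
```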
We evaluate on the test set using WER computed on words tokenized by PyThaiNLP 2.3.1 and deepcut, plus CER. We also measure performance when spell correction using TNC n-grams is applied. The evaluation code can be found in notebooks/wav2vec2_finetuning_tutorial.ipynb. The benchmark is performed on the test-unique split; a sketch of the tokenized metric computation follows the table below.
| | WER PyThaiNLP 2.3.1 | WER deepcut | CER |
|---|---|---|---|
| Kaldi from scratch | 23.04 | | 7.57 |
| Ours without spell correction | 13.634024 | 8.152052 | 2.813019 |
| Ours with spell correction | 17.996397 | 14.167975 | 5.225761 |
| Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340 |
| Microsoft Bing Speech API※ | 12.578819 | 9.620991 | 5.016620 |
| Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562 |
| NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027 |
※ APIs are not fine-tuned with Common Voice 7.0 data
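For illustration, a minimal sketch of the tokenized metric computation described above, assuming predictions and references are lists of raw (unsegmented) Thai strings; this is not the notebook's exact code:

```python
from datasets import load_metric
from pythainlp.tokenize import word_tokenize

# Sketch: WER on PyThaiNLP-tokenized text, CER on the raw strings.
# `predictions` and `references` are assumed lists of Thai transcriptions.
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

preds_tok = [" ".join(word_tokenize(p, engine="newmm")) for p in predictions]
refs_tok = [" ".join(word_tokenize(r, engine="newmm")) for r in references]

print("WER (newmm):", wer_metric.compute(predictions=preds_tok, references=refs_tok))
print("CER:", cer_metric.compute(predictions=predictions, references=references))
```

Swapping engine="newmm" for engine="deepcut" gives the deepcut-tokenized WER column of the table above.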