英文

Czech wav2vec2-xls-r-300m-cs-250

这个模型是基于 facebook/wav2vec2-xls-r-300m 在 common_voice 8.0 数据集以及其他以下列出的数据集上进行微调的版本。

它在评估集上达到以下结果:

  • 损失:0.1271
  • 词错误率(WER):0.1475
  • 字符错误率(CER):0.0329

使用语言模型的 eval.py 脚本的结果为:

  • 词错误率(WER):0.07274312090176113
  • 字符错误率(CER):0.021207369275558875

模型描述

使用数据集 Common Voice 在 Czech 上对 facebook/wav2vec2-large-xlsr-53 进行了微调。在使用此模型时,请确保语音输入采样率为16kHz。

可以直接使用该模型(无需语言模型),如下所示:

import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

test_dataset = load_dataset("mozilla-foundation/common_voice_8_0", "cs", split="test[:2%]")

processor = Wav2Vec2Processor.from_pretrained("comodoro/wav2vec2-xls-r-300m-cs-250")
model = Wav2Vec2ForCTC.from_pretrained("comodoro/wav2vec2-xls-r-300m-cs-250")

resampler = torchaudio.transforms.Resample(48_000, 16_000)

# Preprocessing the datasets.
# We need to read the aduio files as arrays
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits

predicted_ids = torch.argmax(logits, dim=-1)

print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset[:2]["sentence"])

评估

可以使用附带的 eval.py 脚本进行模型评估:

python eval.py --model_id comodoro/wav2vec2-xls-r-300m-cs-250 --dataset mozilla-foundation/common-voice_8_0 --split test --config cs

训练和评估数据

用于训练的数据集包括 Common Voice 8.0 的 train 和 validation 数据集,以及以下数据集:

训练超参数

训练过程中使用了以下超参数:

  • 学习率:0.0001
  • 训练批次大小:32
  • 评估批次大小:8
  • 种子:42
  • 优化器:Adam,参数为betas=(0.9,0.999)和epsilon=1e-08
  • lr_scheduler_type:线性
  • lr_scheduler_warmup_steps:800
  • 训练轮数:5
  • 混合精度训练:Native AMP

训练结果

Training Loss Epoch Step Validation Loss Wer Cer
3.4203 0.16 800 3.3148 1.0 1.0
2.8151 0.32 1600 0.8508 0.8938 0.2345
0.9411 0.48 2400 0.3335 0.3723 0.0847
0.7408 0.64 3200 0.2573 0.2840 0.0642
0.6516 0.8 4000 0.2365 0.2581 0.0595
0.6242 0.96 4800 0.2039 0.2433 0.0541
0.5754 1.12 5600 0.1832 0.2156 0.0482
0.5626 1.28 6400 0.1827 0.2091 0.0463
0.5342 1.44 7200 0.1744 0.2033 0.0468
0.4965 1.6 8000 0.1705 0.1963 0.0444
0.5047 1.76 8800 0.1604 0.1889 0.0422
0.4814 1.92 9600 0.1604 0.1827 0.0411
0.4471 2.09 10400 0.1566 0.1822 0.0406
0.4509 2.25 11200 0.1619 0.1853 0.0432
0.4415 2.41 12000 0.1513 0.1764 0.0397
0.4313 2.57 12800 0.1515 0.1739 0.0392
0.4163 2.73 13600 0.1445 0.1695 0.0377
0.4142 2.89 14400 0.1478 0.1699 0.0385
0.4184 3.05 15200 0.1430 0.1669 0.0376
0.3886 3.21 16000 0.1433 0.1644 0.0374
0.3795 3.37 16800 0.1426 0.1648 0.0373
0.3859 3.53 17600 0.1357 0.1604 0.0361
0.3762 3.69 18400 0.1344 0.1558 0.0349
0.384 3.85 19200 0.1379 0.1576 0.0359
0.3762 4.01 20000 0.1344 0.1539 0.0346
0.3559 4.17 20800 0.1339 0.1525 0.0351
0.3683 4.33 21600 0.1315 0.1518 0.0342
0.3572 4.49 22400 0.1307 0.1507 0.0342
0.3494 4.65 23200 0.1294 0.1491 0.0335
0.3476 4.81 24000 0.1287 0.1491 0.0336
0.3475 4.97 24800 0.1271 0.1475 0.0329

框架版本

  • Transformers 4.16.2
  • Pytorch 1.10.1+cu102
  • Datasets 1.18.3
  • Tokenizers 0.11.0