英文

Wav2vec 2.0大型VoxRex瑞典语(C)

免责声明:这是一个正在进行中的项目。请参阅 VoxRex 获取更多详细信息。

更新于2022年01月10日:已更新为VoxRex-C版本。

更新于2022年05月16日:论文编号为 here

通过使用瑞典广播、NST和Common Voice数据对KBs VoxRex large 模型进行了微调。在没有语言模型的情况下进行评估,结果如下:NST + Common Voice测试集(总句子的2%)的识别错误率(WER)为2.5%。Common Voice测试集的WER为8.49%,使用4元语言模型则为7.37%。

使用此模型时,请确保输入的语音采样率为16kHz。

性能*

*图表显示的是不包括额外的20k步Common Voice微调的性能。

训练

该模型在NST + CommonVoice上进行了120000次更新的微调,然后在仅使用CommonVoice进行了额外的20000次更新的微调。额外对CommonVoice的微调在NST+CommonVoice测试集上略微降低了性能,但在CommonVoice测试集上提高了性能。它似乎总体表现较好[citation needed]。

使用方法

可以直接使用该模型(无需语言模型),方法如下:

import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
test_dataset = load_dataset("common_voice", "sv-SE", split="test[:2%]").
processor = Wav2Vec2Processor.from_pretrained("KBLab/wav2vec2-large-voxrex-swedish")
model = Wav2Vec2ForCTC.from_pretrained("KBLab/wav2vec2-large-voxrex-swedish")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the aduio files as arrays
def speech_file_to_array_fn(batch):
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = resampler(speech_array).squeeze().numpy()
    return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])