Whisper is an open-source, Transformer-based ASR model released by OpenAI. In my case, the model was fine-tuned on a dataset of voice recordings from people with speech impairments.
I tried the following options for CPU inference: the built-in HuggingFace pipeline, ONNX Runtime, OpenVino Runtime, and a direct PyTorch implementation.
Summary
Here are the final results:
Inference with the HuggingFace pipeline
Since our model was trained with the transformers library and stored on the HuggingFace Hub, the first and most straightforward option is to use the built-in pipeline.
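On its own, the built-in pipeline boils down to a few lines; here is a minimal sketch using the stock openai/whisper-medium checkpoint (not my fine-tuned model), just to show the shape of the API:
from transformers import pipeline

# Plain ASR pipeline: no PEFT adapter, no forced language/task, no error handling
asr = pipeline("automatic-speech-recognition", model="openai/whisper-medium")
print(asr("audio.wav")["text"])  # "audio.wav" is a placeholder path
The service class below wraps the same idea, but loads our PEFT adapter, forces the language/task tokens, and exposes an async transcribe() method: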
import os
import asyncio
import torch
from huggingface_hub import login
from peft import PeftConfig, PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor, WhisperTokenizer, AutomaticSpeechRecognitionPipeline, PreTrainedModel, logging as transformers_log
from src import log
from src.utils import utils
class WhisperService:
_initialized = False
def __init__(self, language='en'):
if not WhisperService._initialized:
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
transformers_log.set_verbosity_error()
self.model_name = utils.MODEL_NAME
self.language = language
self.task = utils.TASK
try:
# Initialize model and related components
log.info("Starting Whisper service...")
self.peft_config = self.generate_model_config()
self.model = self.get_whisper_model_from_hf(self.peft_config)
self.tokenizer = self.create_tokenizer(self.peft_config)
self.processor = self.create_processor(self.peft_config)
self.pipeline_asr, self.forced_decoder_ids = self.create_whisper_pipeline(
self.model, self.tokenizer, self.processor
)
WhisperService._initialized = True
log.info("Whisper service started with success!")
except Exception as e:
log.error(f"Error during Whisper service init: {str(e)}")
raise
def generate_model_config(self) -> PeftConfig:
"""
"""
try:
login(token=os.environ['API_TOKEN'])
config = PeftConfig.from_pretrained(self.model_name)
log.info("Model config generated")
return config
except Exception as e:
log.error(f"Error during model config generation: {str(e)}")
raise
def get_whisper_model_from_hf(self, peft_config: PeftConfig) -> PeftModel:
"""
"""
try:
model = WhisperForConditionalGeneration.from_pretrained(
peft_config.base_model_name_or_path
)
# Check if GPU is available
if torch.cuda.is_available():
log.info("Model loaded on GPU")
else:
log.info("Model loaded on CPU")
model = PeftModel.from_pretrained(model, self.model_name)
log.info("Whisper model configured with PeftModel")
return model
except Exception as e:
log.error(f"Error during Whisper model loading: {str(e)}")
raise
def create_processor(self, peft_config: PeftConfig) -> WhisperProcessor:
"""
"""
try:
processor = WhisperProcessor.from_pretrained(
peft_config.base_model_name_or_path,
language=self.language,
task=self.task
)
log.info("WhisperProcessor created")
return processor
except Exception as e:
log.error(f"Error during WhisperProcessor creation: {str(e)}")
raise
def create_tokenizer(self, peft_config: PeftConfig) -> WhisperTokenizer:
"""
"""
try:
tokenizer = WhisperTokenizer.from_pretrained(
peft_config.base_model_name_or_path,
language=self.language,
task=self.task
)
log.info("WhisperTokenizer created")
return tokenizer
except Exception as e:
log.error(f"Error during WhisperTokenizer creation: {str(e)}")
raise
def create_whisper_pipeline(self, model: PreTrainedModel, tokenizer: WhisperTokenizer,
processor: WhisperProcessor) -> tuple:
"""
"""
try:
feature_extractor = processor.feature_extractor
pipe_lora = AutomaticSpeechRecognitionPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor
)
forced_decoder_ids = processor.get_decoder_prompt_ids(language=self.language, task=self.task)
log.info("Pipeline created")
return pipe_lora, forced_decoder_ids
except Exception as e:
log.error(f"Error during Pipeline creation: {str(e)}")
raise
async def transcribe(self, audio_path: str) -> str:
"""
"""
try:
loop = asyncio.get_event_loop()
log.info(f"Transcribing the following file audio: {audio_path}")
with torch.cuda.amp.autocast():
text = await loop.run_in_executor(
None,
lambda:
self.pipeline_asr(audio_path, generate_kwargs={"forced_decoder_ids": self.forced_decoder_ids},
max_new_tokens=255)["text"]
)
log.info("Transcription completed!")
return text
except Exception as e:
log.error(f"Error during transcription: {str(e)}")
raise
We fetch the model from the HuggingFace Hub (utils.MODEL_NAME is the HuggingFace model identifier, e.g. "miosipof/asr_EN_medium_v1").
The pipeline itself is created with the following code:
pipe_lora = AutomaticSpeechRecognitionPipeline(
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor
)
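A minimal usage sketch of the service (assuming a hypothetical audio.wav file and that the API_TOKEN environment variable is set for the HuggingFace login):
import asyncio

async def main():
    service = WhisperService(language="en")
    text = await service.transcribe("audio.wav")  # hypothetical audio file path
    print(text)

asyncio.run(main())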
ONNX Runtime
Converting the model to ONNX format
Let's import a few libraries:
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnx
import numpy as np
import onnxruntime as ort
import torch
import torchaudio
import whisper
from transformers import WhisperProcessor
Next, we use the transformers optimum library and its CLI to convert the model from HuggingFace to ONNX format:
pip install optimum[exporters]
optimum-cli export onnx --model local_path --task automatic-speech-recognition local_model_folder/
This takes the original model stored at local_path and creates a set of ONNX files in local_model_folder/.
Let's set up an ONNX session:
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
session_options.intra_op_num_threads = 44
session_options.inter_op_num_threads = 16
We handle the encoder and the decoder separately:
sess_encoder = ort.InferenceSession("./path_to/encoder_q.onnx", sess_options=session_options)
sess_decoder = ort.InferenceSession("./path_to/decoder_q.onnx", sess_options=session_options)
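If you are unsure about the input/output names of the exported graphs, they can be read off the sessions directly. This is a quick sanity check for the "input_features", "last_hidden_state", "input_ids" and "logits" names used in the decoding loop below (exact names may vary with your optimum version):
# Inspect the ONNX graphs' expected inputs and outputs
print([i.name for i in sess_encoder.get_inputs()])   # e.g. ['input_features']
print([o.name for o in sess_encoder.get_outputs()])  # e.g. ['last_hidden_state']
print([i.name for i in sess_decoder.get_inputs()])   # e.g. ['input_ids', 'encoder_hidden_states']
print([o.name for o in sess_decoder.get_outputs()])  # e.g. ['logits', ...]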
To improve performance, we define a model quantization function and then apply it to both the encoder and the decoder:
def quantize_onnx_model(onnx_model_path, quantized_model_path):
onnx_opt_model = onnx.load(onnx_model_path)
quantize_dynamic(onnx_model_path,
quantized_model_path,
weight_type=QuantType.QUInt8)  # changed QInt8 to QUInt8
quantize_onnx_model("./path_to/encoder.onnx","./path_to/encoder_q.onnx")
quantize_onnx_model("./path_to/decoder.onnx","./path_to/decoder_q.onnx")
Inference with the ONNX model
Let's initialize the processor and the tokenizer:
processor = WhisperProcessor.from_pretrained("./path_to/q_whisper_onnx")
# tokenizer = processor.tokenizer
tokenizer = whisper.decoding.get_tokenizer(
True,  # multilingual=True: the fine-tuned model is based on the multilingual whisper-medium checkpoint
task="transcribe",
language="en",
)
Here is an audio preprocessing routine (similar to Whisper's log_mel_spectrogram() function) that converts a .wav file into a log-Mel spectrogram array:
def preprocessing_torchaudio(audio_path):
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
mel = processor.feature_extractor(waveform[0], sampling_rate=16000).input_features
return torch.tensor(mel, dtype=torch.float32)
For a sample .wav file, the audio array x_mel is then:
x_mel = preprocessing_torchaudio("./path_to/audio.wav")
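It is worth sanity-checking the result: the Whisper feature extractor pads or truncates audio to a 30-second window, so for whisper-medium x_mel should come out with shape (1, 80, 3000), i.e. one batch item, 80 mel bins, 3000 frames:
print(x_mel.shape)  # expected: torch.Size([1, 80, 3000])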
Finally, here is a custom loop for sequence encoding and decoding with our quantized ONNX model:
max_tokens = 448
out_encoder, = sess_encoder.run(["last_hidden_state"], {"input_features": x_mel.numpy()})
next_token = tokenizer.sot
# next_token = "<|startoftranscript|>"
x_tokens = torch.tensor([[tokenizer.sot]], dtype=torch.int64)  # start the decoded sequence with the <|startoftranscript|> token
while x_tokens.shape[1] <= max_tokens and next_token != tokenizer.eot:
out_decoder, = sess_decoder.run(
["logits"],
{
"input_ids": x_tokens.numpy(),
"encoder_hidden_states": out_encoder,
},
)
next_token = out_decoder[0, -1].argmax()
next_token = torch.tensor(next_token)
print(next_token,next_token.shape,x_tokens.shape)
x_tokens = torch.concat(
[x_tokens, next_token.reshape(1, 1)],
axis=1,
)
print(tokenizer.decode(x_tokens[0].tolist()))
I left this code in its rough form because ONNX inference consistently performed much worse than inference through OpenVino or PyTorch, possibly because the ONNX format was originally developed for convolutional networks and may not be the best fit for optimizing transformers.
OpenVino Runtime
Implementing inference with OpenVino is much simpler.
First, import the necessary libraries:
import os
from pathlib import Path
from transformers import WhisperProcessor, logging as transformers_log
from optimum.intel.openvino import OVModelForSpeechSeq2Seq
import torchaudio
import torch
import numpy as np
import time
from src import log
from src.utils import utils
import asyncio
Converting the model to OpenVino format
We use the transformers optimum library to export our HuggingFace model to OpenVino format (you can replace openai/whisper-medium with your own model or any other Whisper model hosted on the HuggingFace Hub):
pip install optimum[openvino,nncf]
optimum-cli export openvino --model openai/whisper-medium --weight-format int8 asr_openvino_int8
Note that we used int8 quantization for the export. I also tried int4 quantization, but in my case it noticeably degraded transcription quality.
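For reference, the int4 attempt was the same export with a different weight format (assuming your optimum-intel version supports it; I do not recommend it for this model):
optimum-cli export openvino --model openai/whisper-medium --weight-format int4 asr_openvino_int4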
Here is the method we will use to load the OpenVino model:
def get_openvino_model(self):
ov_config = {"CACHE_DIR": ""}
self.model = OVModelForSpeechSeq2Seq.from_pretrained(self.ov_model_name, ov_config=ov_config, compile=False)
log.info("OpenVino model loaded from " + str(self.ov_model_name))
try:
ov_model_path = Path("src/model/" + self.model_name.replace("/", "_"))
ov_config = {"CACHE_DIR": ""}
if not ov_model_path.exists():
self.model = OVModelForSpeechSeq2Seq.from_pretrained(
self.model_name,
ov_config=ov_config,
export=True,
compile=False,
load_in_8bit=False,
)
self.model.half()
self.model.save_pretrained(ov_model_path)
log.info("HF model converted to OpenVino and saved in " + str(ov_model_path))
else:
self.model = OVModelForSpeechSeq2Seq.from_pretrained(ov_model_path, ov_config=ov_config, compile=False)
log.info("OpenVino model loaded from " + str(ov_model_path))
except Exception as e:
log.error(f"Error during OpenVino model loading: {str(e)}")
raise
return self.model
Here, self.ov_model_name is the asr_openvino_int8 folder we created with the optimum CLI command above (plus its path). I used a not-so-elegant self.model_name.replace("/", "_") call to turn the HuggingFace model ID into a local folder name.
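For example, with the model ID mentioned earlier:
print("miosipof/asr_EN_medium_v1".replace("/", "_"))  # -> miosipof_asr_EN_medium_v1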
Next, the OpenVino model needs to be compiled, since it will be run directly by the OpenVino runtime:
def compile_openvino_model(self):
"""
"""
try:
if torch.cuda.is_available():
log.info("Model loaded on GPU")
self.device = "GPU"
else:
log.info("Model loaded on CPU")
self.device = "CPU"
self.model.to(self.device)
self.model.compile()
log.info("OpenVino model compiled successfully")
except Exception as e:
log.error(f"Error during OpenVino model compilation: {str(e)}")
raise
return self.model
Inference with the OpenVino model
Now we define two helper methods: one to create the Whisper processor used for encoding (which takes a negligible amount of time compared to the forward pass), and one for audio preprocessing:
def create_processor(self):
"""
"""
try:
processor = WhisperProcessor.from_pretrained(
self.model_name,
language=self.language,
task=self.task
)
log.info("WhisperProcessor created")
return processor
except Exception as e:
log.error(f"Error during WhisperProcessor creation: {str(e)}")
raise
def preprocess_audio(self, waveform):
"""
"""
# compute log-Mel input features from input audio array
audio_features = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features[0]
audio_features = torch.tensor(np.array([audio_features]))
return audio_features
Finally, we define the pipeline, i.e. an async function for transcription, similar to the HuggingFace pipeline implementation:
def openvino_pipeline(self,audio_path):
print("1 - starting audio load:", time.time())
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
print("2 - starting preprocessing:", time.time())
audio_features = self.preprocess_audio(waveform)
print("3 - starting forward pass:", time.time())
predicted_ids = self.model.generate(audio_features, max_new_tokens=224)
print("4 - starting decoding:", time.time())
transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
return transcription[0]
async def transcribe(self, audio_path: str) -> str:
"""
"""
try:
loop = asyncio.get_event_loop()
log.info(f"Transcribing the following file audio: {audio_path}")
print("0 - starting the loop:",time.time())
text = await loop.run_in_executor(
None,
lambda: self.openvino_pipeline(audio_path)
)
print("5 - all done:", time.time())
log.info("Transcription completed!")
return text
except Exception as e:
log.error(f"Error during transcription: {str(e)}")
raise
Here is the full code of the OpenVino inference class:
class OpenVinoService:
_initialized = False
def __init__(self, language='en'):
if not OpenVinoService._initialized:
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
transformers_log.set_verbosity_error()
self.model_name = utils.MERGED_MODEL_NAME
self.ov_model_name = utils.OV_MODEL
self.language = language
self.task = utils.TASK
self.device = "CPU"
self.sr = utils.SAMPLING_RATE
try:
# Initialize model and related components
log.info("Starting OpenVino service...")
self.model = self.get_openvino_model()
self.compile_openvino_model()
self.processor = self.create_processor()
OpenVinoService._initialized = True
log.info("OpenVino service started with success!")
except Exception as e:
log.error(f"Error during OpenVino service init: {str(e)}")
raise
def get_openvino_model(self):
"""
"""
ov_config = {"CACHE_DIR": ""}
self.model = OVModelForSpeechSeq2Seq.from_pretrained(self.ov_model_name, ov_config=ov_config, compile=False)
log.info("OpenVino model loaded from " + str(self.ov_model_name))
try:
ov_model_path = Path("src/model/" + self.model_name.replace("/", "_"))
ov_config = {"CACHE_DIR": ""}
if not ov_model_path.exists():
self.model = OVModelForSpeechSeq2Seq.from_pretrained(
self.model_name,
ov_config=ov_config,
export=True,
compile=False,
load_in_8bit=False,
)
self.model.half()
self.model.save_pretrained(ov_model_path)
log.info("HF model converted to OpenVino and saved in " + str(ov_model_path))
else:
self.model = OVModelForSpeechSeq2Seq.from_pretrained(ov_model_path, ov_config=ov_config, compile=False)
log.info("OpenVino model loaded from " + str(ov_model_path))
except Exception as e:
log.error(f"Error during OpenVino model loading: {str(e)}")
raise
return self.model
def compile_openvino_model(self):
"""
"""
try:
if torch.cuda.is_available():
log.info("Model loaded on GPU")
self.device = "GPU"
else:
log.info("Model loaded on CPU")
self.device = "CPU"
self.model.to(self.device)
self.model.compile()
log.info("OpenVino model compiled successfully")
except Exception as e:
log.error(f"Error during OpenVino model compilation: {str(e)}")
raise
return self.model
def create_processor(self):
"""
"""
try:
processor = WhisperProcessor.from_pretrained(
self.model_name,
language=self.language,
task=self.task
)
log.info("WhisperProcessor created")
return processor
except Exception as e:
log.error(f"Error during WhisperProcessor creation: {str(e)}")
raise
def preprocess_audio(self, waveform):
"""
"""
# compute log-Mel input features from input audio array
audio_features = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features[0]
audio_features = torch.tensor(np.array([audio_features]))
return audio_features
def openvino_pipeline(self,audio_path):
print("1 - starting audio load:", time.time())
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
print("2 - starting preprocessing:", time.time())
audio_features = self.preprocess_audio(waveform)
print("3 - starting forward pass:", time.time())
predicted_ids = self.model.generate(audio_features, max_new_tokens=224)
print("4 - starting decoding:", time.time())
transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)
return transcription[0]
async def transcribe(self, audio_path: str) -> str:
"""
"""
try:
loop = asyncio.get_event_loop()
log.info(f"Transcribing the following file audio: {audio_path}")
print("0 - starting the loop:",time.time())
text = await loop.run_in_executor(
None,
lambda: self.openvino_pipeline(audio_path)
)
print("5 - all done:", time.time())
log.info("Transcription completed!")
return text
except Exception as e:
log.error(f"Error during transcription: {str(e)}")
raise
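A minimal usage sketch of the class (assuming the utils constants point at valid models and a hypothetical audio.wav file):
import asyncio

async def main():
    service = OpenVinoService(language="en")
    text = await service.transcribe("audio.wav")  # hypothetical audio file path
    print(text)

asyncio.run(main())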
PyTorch inference
Running inference through a direct PyTorch implementation of Whisper involves several steps.
Let's start by fetching the model from the HuggingFace Hub:
def get_hf_model(self):
"""
"""
try:
merged_model = WhisperForConditionalGeneration.from_pretrained(self.model_name)
pt_model_name = os.path.basename(self.model_name) + ".pth"
pt_dir_name = os.path.join("assets","pt_models")
self.pretrained_model_path = os.path.join(pt_dir_name, pt_model_name)
if not os.path.exists(pt_dir_name):
os.makedirs(pt_dir_name)
log.info(f"Directory {pt_dir_name} created and will be used to store PyTorch models")
else:
log.info(f"Directory {pt_dir_name} exists, using it to save PyTorch model")
torch.save(merged_model.state_dict(), self.pretrained_model_path)
log.info(f"HF model saved to {self.pretrained_model_path} in PyTorch format for conversion")
except Exception as e:
log.error(f"Error during HuggingFace model loading: {str(e)}")
raise
Here, self.model_name is my model ID on HuggingFace (note that it must be the full merged model, not just an adapter).
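If you only have the LoRA adapter on the Hub, one way to produce the merged checkpoint beforehand is peft's merge_and_unload (a sketch; "your-user/whisper-adapter" and the output path are hypothetical):
from peft import PeftConfig, PeftModel
from transformers import WhisperForConditionalGeneration

adapter_id = "your-user/whisper-adapter"  # hypothetical adapter repo
peft_config = PeftConfig.from_pretrained(adapter_id)
base = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path)
merged = PeftModel.from_pretrained(base, adapter_id).merge_and_unload()
merged.save_pretrained("whisper-merged")  # or push it to the Hub and point utils.MERGED_MODEL_NAME at it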
Converting the model from HuggingFace to PyTorch
The layer names used in the transformers implementation of Whisper differ from those used in the original OpenAI repository.
The mapping function from HuggingFace names to OpenAI names looks like this:
def map_hf_to_pt(self,pretrained_weights):
def rename_key(key):
new_key = key
for k, v in self.mapping:
new_key = new_key.replace(k, v)
return new_key
# Rename the keys in the state_dict
updated_weights = {rename_key(k): v for k, v in pretrained_weights.items()}
updated_weights.pop('proj_out.weight', None)
return updated_weights
Here, self.mapping is the mapping list:
self.mapping = [ ('model.', ''),
('decoder.layers', 'decoder.blocks'),
('encoder.layers', 'encoder.blocks'),
('encoder.embed_positions.weight', 'encoder.positional_embedding'),
('self_attn.k_proj', 'attn.key'),
('self_attn.q_proj', 'attn.query'),
('self_attn.v_proj', 'attn.value'),
('self_attn.out_proj', 'attn.out'),
('self_attn_layer_norm', 'attn_ln'),
('final_layer_norm', 'mlp_ln'),
('fc1', 'mlp.0'),
('fc2', 'mlp.2'),
('encoder_attn.k_proj','cross_attn.key'),
('encoder_attn.v_proj','cross_attn.value'),
('encoder_attn.q_proj','cross_attn.query'),
('encoder_attn.out_proj','cross_attn.out'),
('encoder_attn_layer_norm','cross_attn_ln'),
('decoder.embed_positions.weight','decoder.positional_embedding'),
('decoder.embed_tokens','decoder.token_embedding'),
('encoder.layer_norm','encoder.ln_post'),
('decoder.layer_norm','decoder.ln'),
]
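To make the mapping concrete, here is what it does to a single attention weight (a standalone sketch that uses only the relevant subset of the list above):
mapping = [
    ('model.', ''),
    ('encoder.layers', 'encoder.blocks'),
    ('self_attn.k_proj', 'attn.key'),
]

def rename_key(key):
    for old, new in mapping:
        key = key.replace(old, new)
    return key

print(rename_key('model.encoder.layers.0.self_attn.k_proj.weight'))
# -> encoder.blocks.0.attn.key.weight (the name the openai/whisper state dict expects)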
Now we apply this mapping to the base Whisper model, using the pretrained weights of the model we downloaded from the HuggingFace Hub:
def set_pt_model(self):
model = whisper.load_model("medium")
log.info("Whisper base model loaded")
pretrained_model = torch.load(self.pretrained_model_path)
log.info(f"Whisper pretrained model loaded from {self.pretrained_model_path}")
# Extract state_dict if the loaded model is not already a state_dict
if hasattr(pretrained_model, "state_dict"):
pretrained_weights = pretrained_model.state_dict() # extract the state dict
else:
pretrained_weights = pretrained_model # it's already a state_dict
#######################################################################
updated_weights = self.map_hf_to_pt(pretrained_weights)
model.load_state_dict(updated_weights, strict=True)
log.info(f"Model weights mapped from HuggingFace model to PyTorch")
######################################################################
model.to(self.device)
model.requires_grad_(False)
model.eval()
log.info("Whisper PyTorch model loaded on " + str(self.device))
return model
Inference with PyTorch
We are almost ready. Next, we define the Whisper processor and the encoding function:
def create_processor(self):
"""
"""
try:
processor = WhisperProcessor.from_pretrained(
self.model_name,
language=self.language,
task=self.task
)
log.info("WhisperProcessor created")
return processor
except Exception as e:
log.error(f"Error during WhisperProcessor creation: {str(e)}")
raise
def preprocess_audio(self, waveform):
"""
"""
# compute log-Mel input features from input audio array
mel = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features
return torch.tensor(mel, dtype=torch.float32)
Finally, we define the pipeline and the transcription function:
def inference_pipeline(self,audio_path):
log.info("1 - Starting audio load:")
# waveform, sample_rate = librosa.load(audio_path, sr=self.sr)
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
log.info("2 - starting preprocessing:")
audio_features = self.preprocess_audio(waveform)
log.info("3 - Starting forward pass:")
with torch.no_grad():
result = whisper.decode(
self.model,
audio_features,
options=whisper.DecodingOptions(
fp16=False,
language="it",
without_timestamps=True,
suppress_blank=False,
suppress_tokens=[],
),
)
return result[0].text
async def transcribe(self, audio_path: str) -> str:
"""
"""
try:
loop = asyncio.get_event_loop()
log.info(f"Transcribing the following file audio: {audio_path}")
log.info("Transcription started...")
text = await loop.run_in_executor(
None,
lambda: self.inference_pipeline(audio_path)
)
log.info("Transcription completed!")
return text
except Exception as e:
log.error(f"Error during transcription: {str(e)}")
raise
Here is the full code of the PyTorch inference class. Note the torch.set_num_threads(num_threads) call during initialization: it sets the number of CPU threads used for inference, which has a large impact on performance:
import os
from src import log
from src.utils import utils
import asyncio
import whisper
from whisper import DecodingResult
from transformers import WhisperForConditionalGeneration, WhisperProcessor, logging as transformers_log
from huggingface_hub import hf_hub_download, login
import torch
import torchaudio
import torch.quantization
class InferenceService:
_initialized = False
def __init__(self, language='it', num_threads=1, quantization=True, device = "cpu"):
try:
login(token=os.environ['API_TOKEN'])
log.info("HuggingFace login successful")
except Exception as e:
log.error(f"Error during HuggingFace login: {str(e)}")
raise
if not InferenceService._initialized:
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
transformers_log.set_verbosity_error()
self.model_name = utils.MERGED_MODEL_NAME
self.language = language
self.pytorch_converted_model_source = utils.PRETRAINED_MODEL_PTH
self.pytorch_converted_model_filename = utils.PRETRAINED_MODEL_FILENAME
self.task = utils.TASK
self.device = device
self.sr = utils.SAMPLING_RATE
self.mapping = utils.HF_PT_MAPPING
try:
# Initialize model and related components
log.info("Starting PyTorch Inference service...")
try:
self.pretrained_model_path = hf_hub_download(repo_id=self.pytorch_converted_model_source,
filename=self.pytorch_converted_model_filename)
log.info(f"Whisper pretrained model downloaded to {self.pretrained_model_path}")
except Exception as e:
log.info(f"Unable to download the PyTorch model: {str(e)} - switching to model from HF for conversion")
self.get_hf_model()
self.model = self.set_pt_model()
if quantization:
self.model = torch.quantization.quantize_dynamic(self.model,
{torch.nn.Linear},
dtype=torch.qint8)
self.model = self.model.cpu()
self.processor = self.create_processor()
InferenceService._initialized = True
log.info("PyTorch Inference service started with success!")
except Exception as e:
log.error(f"Error during PyTorch Inference service init: {str(e)}")
raise
torch.set_num_threads(num_threads)
log.info(f"Number of threads set to {num_threads} for PyTorch calculations")
def get_hf_model(self):
"""
"""
try:
merged_model = WhisperForConditionalGeneration.from_pretrained(self.model_name)
pt_model_name = os.path.basename(self.model_name) + ".pth"
pt_dir_name = os.path.join("assets","pt_models")
self.pretrained_model_path = os.path.join(pt_dir_name, pt_model_name)
if not os.path.exists(pt_dir_name):
os.makedirs(pt_dir_name)
log.info(f"Directory {pt_dir_name} created and will be used to store PyTorch models")
else:
log.info(f"Directory {pt_dir_name} exists, using it to save PyTorch model")
torch.save(merged_model.state_dict(), self.pretrained_model_path)
log.info(f"HF model saved to {self.pretrained_model_path} in PyTorch format for conversion")
except Exception as e:
log.error(f"Error during HuggingFace model loading: {str(e)}")
raise
return 1
def map_hf_to_pt(self,pretrained_weights):
def rename_key(key):
new_key = key
for k, v in self.mapping:
new_key = new_key.replace(k, v)
return new_key
# Rename the keys in the state_dict
updated_weights = {rename_key(k): v for k, v in pretrained_weights.items()}
updated_weights.pop('proj_out.weight', None)
return updated_weights
def set_pt_model(self):
model = whisper.load_model("medium")
log.info("Whisper base model loaded")
pretrained_model = torch.load(self.pretrained_model_path)
log.info(f"Whisper pretrained model loaded from {self.pretrained_model_path}")
# Extract state_dict if the loaded model is not already a state_dict
if hasattr(pretrained_model, "state_dict"):
pretrained_weights = pretrained_model.state_dict() # extract the state dict
else:
pretrained_weights = pretrained_model # it's already a state_dict
#######################################################################
updated_weights = self.map_hf_to_pt(pretrained_weights)
model.load_state_dict(updated_weights, strict=True)
log.info(f"Model weights mapped from HuggingFace model to PyTorch")
# Activate to save converted model and/or its weights
# torch.save(model, 'src/model/whisper_pretrained_converted.pth')
# torch.save(updated_weights, 'src/model/whisper_pretrained_converted_weights.pth')
######################################################################
model.to(self.device)
model.requires_grad_(False)
model.eval()
log.info("Whisper PyTorch model loaded on " + str(self.device))
return model
def create_processor(self):
"""
"""
try:
processor = WhisperProcessor.from_pretrained(
self.model_name,
language=self.language,
task=self.task
)
log.info("WhisperProcessor created")
return processor
except Exception as e:
log.error(f"Error during WhisperProcessor creation: {str(e)}")
raise
def preprocess_audio(self, waveform):
"""
"""
# compute log-Mel input features from input audio array
mel = self.processor.feature_extractor(waveform, sampling_rate=self.sr).input_features
return torch.tensor(mel, dtype=torch.float32)
def inference_pipeline(self,audio_path):
log.info("1 - Starting audio load:")
# waveform, sample_rate = librosa.load(audio_path, sr=self.sr)
waveform, sample_rate = torchaudio.load(audio_path)
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sr)(waveform)[0]
log.info("2 - starting preprocessing:")
audio_features = self.preprocess_audio(waveform)
log.info("3 - Starting forward pass:")
with torch.no_grad():
result = whisper.decode(
self.model,
audio_features,
options=whisper.DecodingOptions(
fp16=False,
language="it",
without_timestamps=True,
suppress_blank=False,
suppress_tokens=[],
),
)
return result[0].text
async def transcribe(self, audio_path: str) -> str:
"""
"""
try:
loop = asyncio.get_event_loop()
log.info(f"Transcribing the following file audio: {audio_path}")
log.info("Transcription started...")
text = await loop.run_in_executor(
None,
lambda: self.inference_pipeline(audio_path)
)
log.info("Transcription completed!")
return text
except Exception as e:
log.error(f"Error during transcription: {str(e)}")
raise
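Finally, a minimal usage sketch (assuming the utils constants and API_TOKEN are configured, a hypothetical audio.wav file, and num_threads tuned to your CPU):
import asyncio

async def main():
    service = InferenceService(language="it", num_threads=4, quantization=True)
    text = await service.transcribe("audio.wav")  # hypothetical audio file path
    print(text)

asyncio.run(main())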