模型:
theblackcat102/mt0-chat-large
这是谷歌的Meena模型的修剪版本,它通过类似于Google的 Meena chatbot 的方式进行修剪。Meena有一个发展演变的Transformer编码器块和13个发展演变的Transformer解码器块,如下图所示。编码器负责处理对话上下文以帮助Meena理解先前对话中已经说过的内容。然后解码器使用这些信息来形成实际的回应。通过调整超参数,我们发现更强大的解码器是提高对话质量的关键。
这个修剪版本的模型是为了作为Meena类似架构的预训练版本,用于将来研究这种架构设计思想。
下面是模型修剪的方式:
将分词器通过多语种对话语料库的列表,只保留使用的一组标记ID。然后重新映射分词器和嵌入以削减约40%的嵌入权重。需要注意的是,字节级标记仍然保留,因此该分词器仍然可以处理未见字符。
使用修剪的mt0模型的简化版本,将编码器层减少到仅4层(使编码器-解码器比例约为1:6)。理想情况下,我们只想保留2层编码器,但我发现这在后期阶段太弱。
因为新的编码器具有与旧编码器不同的输出嵌入,所以我们需要使用原始编码器作为教师进行重新训练。在这个项目中,我们简单地将原始编码器的输出特征作为潜在的ground truth,而新的较小编码器的任务是通过MAE损失来拟合ground truth潜在。
所有模型都在损失曲线稳定并且不再改进的时候停止训练。
没有重新初始化阶段:
input :what is one plus one?</s> trimmed output : Extendeds input :你想知道關於我的什麼?</s> trimmed output : 你个朋友 input :こんにちは!お元気</s> trimmed output : !- -_n_wip-------------D2
重新初始化阶段:
input :what is one plus one?</s> trimmed output : hundred input :你想知道關於我的什麼?</s> trimmed output : 你們?我對嗎?道不那做的對說對這做了受關沒? input :こんにちは!お元気</s> trimmed output : !,
请注意,由于大约30%的权重被修剪掉,无法达到原始模型的性能。