DistilBERT:一种轻量级的BERT变体,实现模型压缩

2023年10月08日 由 neo 发表 893 0

近年来,大型语言模型的发展呈现出爆炸式的增长。BERT成为了最受欢迎和高效的模型之一,能够以高精度解决广泛的自然语言处理(NLP)任务。在BERT之后,还出现了一系列其他的模型,也展示了出色的结果。

1_2qXbvSTvQOLc2J9Qsmv_og

一个容易观察到的明显趋势是,随着时间的推移,大型语言模型(LLM)倾向于变得更加复杂,通过指数增加它们训练所需的参数和数据。深度学习的研究表明,这样的技术通常会带来更好的结果。不幸的是,机器学习领域已经面临了几个关于LLM的问题,而可扩展性已经成为有效训练、存储和使用它们的主要障碍。

考虑到这个问题,一些特别的技术已经被设计出来用于压缩LLM。压缩算法的目标是减少训练时间、降低内存消耗或加速模型推理。在实践中使用最常见的三种压缩技术如下:

  • 知识蒸馏涉及训练一个较小的模型,试图表示一个较大模型的行为。
  • 量化是减少存储表示模型权重的数字所需内存的过程。
  • 剪枝是指丢弃最不重要的模型权重。

在本文中,我们将了解应用于BERT的蒸馏机制,它产生了一个新的模型叫做DistilBERT。顺便说一下,下面讨论的技术也可以应用于其他NLP模型。

蒸馏基础

蒸馏的目标是创建一个能够模仿一个较大模型的较小模型。在实践中,这意味着如果一个较大模型预测了某件事情,那么一个较小模型就应该做出类似的预测。

为了实现这一目标,需要预先训练一个更大的模型(在我们的例子中是BERT)。那么就需要选择更小的模型的架构。为了增加成功模仿的可能性,通常建议较小的模型具有与较大模型相似的架构,但参数数量较少。最后,较小的模型从较大的模型对特定数据集做出的预测中学习。为了实现这一目标,选择合适的损失函数至关重要,这将有助于较小的模型更好地学习。

在蒸馏符号中,较大模型被称为教师,而较小模型被称为学生。
通常,在预训练阶段应用蒸馏过程,但也可以在微调阶段应用。

DistilBERT

DistilBERT从BERT学习,并通过使用由三个部分组成的损失函数来更新它的权重:

  • 掩码语言建模(MLM)损失
  • 蒸馏损失
  • 相似性损失

接下来,我们将讨论这些损失组件,并理解它们各自的必要性。然而,在深入探讨之前,有必要理解一个重要概念,即softmax激活函数中温度(temperature)概念。温度概念被用于DistilBERT的损失函数中。

Softmax温度

我们经常观察到softmax变换作为神经网络的最后一层。Softmax将所有模型输出归一化,使它们加起来等于1,并且可以被解释为概率。

存在一个softmax公式,其中模型的所有输出都除以一个温度参数T:

1_hUUDeP-t1Pa1FOSsTsuKxQ

Softmax 温度公式。pᵢ 和 zᵢ 分别是第 i 个对象的模型输出和归一化概率。T是温度参数。

温度T控制输出分布的平滑度:

  • 如果 T > 1,那么分布变得更加平滑。
  • 如果 T = 1,那么分布和正常的softmax输出一样。
  • 如果 T < 1,那么分布变得更加粗糙。

为了让事情更清楚,让我们看一个例子。考虑一个有5个标签的分类任务,其中一个神经网络产生了5个值,表示输入对象属于相应类别的置信度。用不同的 T 值应用softmax会得到不同的输出分布。

1_ecAqeSiXXn2XzKbRQP_rGg

基于温度T生成不同概率分布的神经网络示例

温度越高,概率分布越平滑。

1_vQQK-J9LOczMGzpwv8tG1g

基于不同温度T值的logits(1到5的自然数)的Softmax变换。随着温度升高,softmax值彼此变得更加一致。

损失函数

掩码语言建模损失

与教师模型(BERT)类似,在预训练期间,学生模型(DistilBERT)通过为掩码语言建模任务做出预测来学习语言。在为某个词元产生预测后,将预测的概率分布与教师模型的独热编码概率分布进行比较。

one-hot编码分布指定一种概率分布,其中最可能的标记的概率设置为1,所有其他标记的概率设置为0。

与大多数语言模型一样,计算预测分布和真实分布之间的交叉熵损失,并通过反向传播更新学生模型的权重。

1_dA_2ufSWp94r8P0pDip6Mw

掩码语言建模损失计算示例

蒸馏损失

实际上,只使用学生损失来训练学生模型是可能的。但是,在很多情况下,这可能不够。只使用学生损失的一个常见问题在于它的softmax变换中温度T被设为1。在实践中,T=1的结果分布往往是这样的形式,其中一个可能的标签有非常接近1的高概率,而其他所有标签的概率都很低,接近0。

这种情况与对于特定输入有两个或更多有效标签的情况不太一致:T=1的softmax层很可能会排除除了一个之外的所有有效标签,并使概率分布接近独热编码分布。这会导致学生模型可能学到的潜在有用信息丢失,从而降低其多样性。

这就是为什么论文的作者引入了蒸馏损失,其中softmax概率是用一个T> 1的温度计算的,使得概率能够平滑地对齐,从而考虑到学生的多个可能答案。

在蒸馏损失中,学生和教师应用相同的温度 T。删除了教师分布的One-hot编码。

1_f_q37Zbzl4dtuog0EREVfA

蒸馏损失计算示例

可以使用KL散度损失来代替交叉熵损失。

相似度损失

研究人员还表示,在隐藏状态嵌入之间添加余弦相似性损失是有益的。

1_4vtYheWWCb34hOkXzUwFlA

余弦损失公式

这样,学生不仅可以正确地再现屏蔽标记,还可以构建与教师相似的嵌入。它还为在模型的两个空间中保持嵌入之间的相同关系打开了大门。

1_PVy7quRRXpyzHmDyzfeUlw

相似度损失计算示例

三重损失

最后,计算所有三个损失函数的线性组合之和,这定义了DistilBERT中的损失函数。根据损失值,对学生模型进行反向传播以更新其权重。

1_Ex0w0yZZYXVneyWgulsIOw

DistillBERT损失函数

有趣的事实是,在三个损失组件中,掩码语言建模损失对模型性能的影响最小。蒸馏损失和相似性损失有更大的影响。

推理

DistilBERT的推理过程与训练阶段完全一样。唯一的微妙之处是,softmax温度T被设为 1。这样做是为了得到接近BERT计算出来的概率。

架构

总体而言,DistilBERT使用与BERT相同的架构,除了以下几点变化:

1、DistilBERT只有BERT层的一半。模型中的每一层都是通过取两个BERT层中的一个来初始化的。

2、去掉了词元类型嵌入。

3、去掉了用于分类任务的[CLS]词元的隐藏状态上的全连接层。

4、为了获得更强大的性能,作者使用了RoBERTa提出的最佳想法:

·使用动态掩码

·去掉了下一句预测目标

·在更大的批次上进行训练

·应用了梯度累积技术来优化梯度计算

DistilBERT中最后一层隐藏层的大小(768)与BERT相同。作者报告说,减少它并没有带来显著的计算效率方面的改进。根据他们的说法,减少总层数有更大的影响。

数据

DistilBERT在与BERT相同的数据语料上进行训练,其中包括BooksCorpus(800M词)和英文维基百科(2500M词)。

BERT与DistilBERT比较 BERT和DistilBERT在几个最流行的基准测试上进行了关键性能参数的比较。以下是需要记住的事实:

1、在推理过程中,DistilBERT比BERT快60%。

2、DistilBERT比BERT少了44M个参数,总体上比BERT小40%。

3、DistilBERT保留了97%的BERT性能。

1_MRpbO6N-V3Z-4xtPxLyi6g

BERT与DistilBERT比较(在 GLUE 数据集上)

结论

DistilBERT在BERT的发展中迈出了一大步,它可以显着压缩模型,同时在各种NLP任务上实现可比较的性能。除此之外,DistilBERT的重量仅为207MB,使得在容量有限的设备上的集成更加容易。知识蒸馏并不是唯一适用的技术:DistilBERT可以通过量化或剪枝算法进一步压缩。

文章来源:https://towardsdatascience.com/distilbert-11c8810d29fc
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
热门职位
Maluuba
20000~40000/月
Cisco
25000~30000/月 深圳市
PilotAILabs
30000~60000/年 深圳市
写评论取消
回复取消