探索Medusa和多标记预测指南

2024年08月06日 由 alex 发表 97 0

互联网竞争异常激烈。研究表明,如果网页加载时间超过 5 秒,客户就会离开网页。这给大多数大型语言模型(LLM)带来了挑战,因为它们无疑是最慢的程序之一。虽然定制硬件可以大大加快 LLM 的运行速度,但目前在这种硬件上运行 LLM 的成本很高。如果我们能找到充分利用标准硬件的方法,就能大幅提升 LLM 的客户体验。


MEDUSA:具有多个解码头的简单 LLM 推理加速框架 "论文的作者提出了一种架构变革,在现有硬件上运行时可实现 2x-3x 的速度提升。


让我们深入了解一下!


推测解码

推测解码(Speculative Decoding)是作为一种加快 LLM 推断速度的方法引入的。你看,LLM 是自回归的,这意味着我们利用刚刚预测的输出标记来帮助预测我们想要的下一个标记。通常情况下,我们一次预测一个标记(或神经网络每次前向传递一个标记)。然而,由于下一个标记的注意力模式与上一个标记的注意力模式非常相似,因此我们在重复大部分相同的计算,并没有获得多少新信息。


投机性解码是指,我们不是对一个标记进行一次前向传递,而是在一次前向传递后尝试找到尽可能多的标记。一般来说,这有三个步骤:

(1) 生成候选人

(2) 处理候选人

(3) 接受某些候选码


Medusa 是一种推测式解码,因此它的步骤可以直接映射到这些步骤。Medusa 将解码头添加到模型的最后一层,作为 (1) 的实现。树形关注是它处理 (2) 候选码的方法。最后,Medusa 使用剔除采样或典型接受方案来实现 (3)。让我们逐一详细介绍。


解码头和Medusa

解码头采用模型前向传递产生的隐藏状态的内部表示,然后创建与词汇表中不同标记相对应的概率。从本质上讲,它是将模型学到的东西转换成概率,从而决定下一个标记是什么。


5


Medusa 通过在模型的最后一个隐藏层添加多个解码头,调整了典型变换器的结构。这样,它就能在前向传递时预测不止一个标记。我们每增加一个解码头,就能多预测一个标记。因此,如果有 3 个Medusa,就能预测前向传递的第一个标记,之后再用Medusa预测 3 个标记。在论文中,作者建议使用 5 个,因为他们认为这样可以在速度和质量之间取得最佳平衡。


为了实现这一目标,论文作者为 Medusa 提出了以下解码头:


6


通过这个等式,我们可以得出标记 t 来自第 k 个头部的概率。我们首先使用通过训练Medusa头找到的权重 W1,再乘以标记 t 的内部状态。我们将内部状态作为跳转连接的一部分进行第二次添加,这样就不会在 SiLU 的线性激活过程中丢失信息,从而提高模型的性能。然后,我们将总和乘以我们为头部训练的第二组权重 W2,并通过 softmax 运行该乘积,从而得到我们的概率。


树的注意力

第一个Medusa根据前向传递给出了它们应该考虑的模型概率,但后面的Medusa需要根据前面Medusa的选择来确定它们应该选择什么标记。


当然,先前的Medusa提出的选项越多(超参数 Sk),未来的Medusa需要考虑的选项就越多。例如,当我们只考虑首脑 1 的前两名候选人(s1=2)和首脑 2 的前三名候选人(s2=3)时,我们需要计算 6 种不同的情况。


由于这种扩展,我们希望尽可能同时生成和验证这些候选方案。


7


上面的矩阵显示了我们如何通过树状注意力在同一批次中进行所有这些计算。与典型的因果自关注不同,只有来自同一延续的标记才会被视为与关注模式相关。正如矩阵所示,在有限的空间内,我们可以将所有候选项放入一个批次,并同时对它们运行注意力。


这里的挑战在于,每个预测只需要考虑直接在其后面的候选标记。换句话说,如果我们从头 1 中选择了 "It",而我们正在评估下一个标记应该是哪个,那么我们就不希望 "I "的注意力模式被用于其他标记。


作者通过使用掩码来避免将无关标记的数据传递到注意力计算中,从而避免了这种干扰。通过使用这种掩码,他们可以在计算注意力模式时提高内存效率,然后在解码头中使用该信息生成后续的标记候选。


虽然上面的矩阵显示我们对每个预测的考虑都是一样的,但如果我们对每个预测都有一个概率,我们就可以根据它们成为最佳选择的可能性来区别对待。下面的树形图就是一个很好的例子。


8


在上图中,有 4 个Medusa头像,每个头像都有多个候选人。不过,并不是每个预测都会被计算出来。我们根据预测正确的概率在树上添加节点。在这里,树的权重主要向左倾斜,这表明预测的概率越高,显示的可能性就越大。简而言之,我们在这里所做的就是只将我们认为有合理可能性成为最佳选择的预测加载到树的注意力中。


典型的接受方案与拒绝抽样

现在我们到了最后阶段,即决定使用哪种预测(如果有的话)。正如我们一开始所说的,模型是自动回归的,因此如果我们预测了前向传递的下 5 个标记,我们就可以简单地将这下 5 个标记放入下一轮的模型中,从而享受推理速度的提升。不过,我们只有在获得高质量预测时才会这么做。如何确定这一点呢?


其中一种方法是拒绝采样法,即我们有一个单独的模型来确定下一个标记是否足够好。当然,这种方法完全取决于其他模型的质量。如果它足够好,那么这个方法就很有效!但要注意的是,为了保持低延迟,你会希望另一个模型运行得相当快,这是很难与高质量相平衡的。


因此,作者提出了典型的验收方案来进行判断。由于所有的预测都是概率,我们可以用它们来设定一个阈值,超过这个阈值我们就接受一个标记。下式展示了我们如何做到这一点:


9


这里的关键是,我们将使用原始模型在这些标记上生成的概率来确定预测是否有效。p 代表原始模型的概率分布,而 ϵ 和 δ 则是阈值,用于确定何时概率高到足以将其纳入模型响应。这里要说明的是,高概率代币会流过,但概率较低但来自概率分布的代币也会流过,在概率分布中,大部分概率都较低。


此外,当我们调整温度时,这一功能也会导致重要的行为。一般来说,用户会提高 LLM 的温度,从而给出更有创意的回答。因此,当温度设置为零时,典型接受会确保只有前向传递预测的第一个标记通过,从而得到最一致的结果。然而,随着温度的升高,LLM 的概率分布会发生变化,从而使我们有更多的预测可能达到被接受的阈值。这不仅能更快地得到结果,而且往往还能得到更有创意的结果。


自我蒸馏

建议,在创建Medusa模型时,我们不要从头开始训练,而是采用高质量的基础模型(我们称之为模型的骨干部分),并在其基础上添加Medusa头。一旦我们对其进行了微调,使其能够理解新的头部,速度就会提高,而不会造成重大的性能损失。


首先,他们使用 ShareGPT 数据集来寻找人们期望与 LLM 进行的高质量交互。他们从数据集中提取了所有提示,然后通过骨干模型运行这些提示,以获得用于微调的基本事实。


虽然这种方法在微调Medusa头部时效果很好,但在微调整个新模型时效果并不理想。


这种退化意味着,地面实况信息不足以重新训练模型并保持高性能。因此,他们重新编写了损失函数,将概率分布作为基本真相。这就需要像下面这样重新制定损失函数。


10


简单解释一下,我们使用库尔贝-莱伯勒发散(KL)来测量标记的原始概率分布与新概率分布之间的差异。


不过,这种方法要求我们同时维护原始模型和新模型的概率,这既消耗存储空间,又消耗内存。


训练Medusa

现在我们有了数据,就可以开始微调了!


正如我们所看到的,Medusa 需要为模型添加额外参数才能运行,我们必须对其进行训练。为了减少所需的计算量(从而降低训练成本),为 Medusa 引入了两种微调形式: Medusa-1 和 Medusa-2。


Medusa-1

Medusa-1 包括冻结模型中除 Medusa 头以外的所有权重。通过只在 Medusa 头中运行梯度,我们不必担心降低原始模型的性能(性能保持不变),而且可以提高 Medusa 头的性能。下面的损失函数显示了它们如何将正确的地面实况标记与正确的Medusa头像相匹配。


11


Medusa-1 只关注额外的 Medusa 权重,这意味着它比 Medusa-2(我们稍后将深入讨论)更具成本效益。对于对训练价格敏感的人来说,建议使用量化骨干模型来进一步降低内存需求,同时使用量化低秩适应(QLoRA)微调方法来进一步降低成本。


Medusa-2

虽然 Medusa-1 更具成本效益,但当我们更新模型中的所有权重以考虑我们添加的新 Medusa 磁头时,仍能获得最佳性能。有趣的是,这并不像简单地进行 LoRA 那样直接,而是将梯度传递给所有权重(而不仅仅是 Medusa 权重)。


相反,首先运行了 Medusa-1,使 Medusa 权重达到合理的性能。然后,他们为 Medusa 权重和骨干模型权重选择了不同的学习率。从逻辑上讲,这样做是因为骨干模型的权重很可能已经接近其需要达到的水平,而 Medusa 模型的权重应该有更大的变化。最后,他们为骨干模型添加了损失函数(表示为 Llm),并将Medusa-1 损失函数缩放为一个值 λ0。λ的作用是平衡损失,这样我们就不会仅仅因为Medusa头而计算出过大的损失值。


12


结论


13


使用 Medusa 可以显著提高速度。从上图中我们可以看到,将 Vicuna(一种流行的开源 LLM)的速度提高了两到三倍。


无论是在互联网上还是在设备上,速度都至关重要。随着越来越多的公司推动创建本地 LLM,像 Medusa 这样的方法对于在有限的硬件上获得极快的速度似乎至关重要。看看像 Phi-3 这样的小型模型能提高多少速度将是非常有趣的事情(Phi-3 发布时在 A16 Bionic iPhone 芯片上的运行速度为每秒 12 个 token,更多信息请参阅我的博文)。对于开发者来说,这可能会为在本地运行多种不同类型的开源模型打开一扇门--即使这些模型最初并不是为像 Phi-3 这样的快速推理而设计的。



文章来源:https://towardsdatascience.com/exploring-medusa-and-multi-token-prediction-de7f8312e4a7
欢迎关注ATYUN官方公众号
商务合作及内容投稿请联系邮箱:bd@atyun.com
评论 登录
写评论取消
回复取消