对于较长上下文的 LLM 推理,现有方法(包括选择性标记保留和基于窗口的注意)可以提高效率,但可能会丢弃未来文本生成所需的重要标记。在这篇论文 中,基于近端标记比远端标记更重要的观察,作者提出了 POD(近端标记优于远端标记)方法。它通过在相似的层之间共享远端标记的注意来分配更少的资源,通过减少不太重要的标记的内存和计算负载(而不是丢弃它们)来提高 LLM 效率而不会丢失标记。
它解决了两个挑战:
主要观察
核心动机是,不太重要的令牌应该被分配更少的资源,而不是被完全丢弃。这引发了两个挑战:1)对于一个令牌来说,重要的令牌分布在哪里?以及2)如何为不太重要的令牌优化内存和计算?本文试图通过两个关键观察来解决这两个挑战:
观察1:邻近令牌(初始令牌+最近令牌)比远距离令牌更重要。
下图显示了不同窗口大小下相同预测的比例:窗口注意力与密集注意力。
上图表明,即使只关注256个邻近令牌,模型在相同输入序列的情况下,也有80%的预测与关注所有令牌的模型完全相同,这一结果支持了我们的观察。
观察2:连续层之间的注意力分数相似
下图展示了Llama3–8B-32K模型中第14个注意力头在不同层之间的注意力分数相似性。
下图显示,红色框内各层之间的注意力分数表现出很强的相似性。
POD(邻近令牌优先于远距离令牌)
基于上述观察,作者提出了POD(邻近令牌优先于远距离令牌)方法,以优化解码阶段的推理效率。该方法专为远距离令牌共享层间注意力分数,而根据上述两个观察结果,对邻近令牌则保持不变。
此方法主要包括三个主要阶段:
方法论
如上所述,该方法包括三个主要步骤,如下图所示:
i) 离线层间注意力共享探索
a) 注意力分数计算
对于输入到模型M中的N个样本si = (x1, x2, ..., xn)(i=1,2,...,N),对于每个样本,收集最后q个(1 ≤ q ≤ n)令牌对其对应的前序令牌的注意力分数。
其中,L, H ∈ N+ 分别表示模型中的层数和注意力头的数量,Sℓ,hi ∈ Rq×n 表示在第h个注意力头的第ℓ层收集到的注意力分数。
b) 注意力相似性评估
对于任意两个不同的层ℓa和ℓb(1 ≤ ℓa, ℓb ≤ L 且 ℓa ≠ ℓb),第h个头在它们之间的注意力相似性定义为所有N个样本中最后q个令牌的平均Jensen-Shannon(JS)散度。
其中,Sℓ,hi,j 表示 Sℓ,hi 的第 j 行,且 0 ≤ simh(·, ·) ≤ 1。
c) 层分组
在计算了层间的头级注意力相似性之后,我们将连续的相似层分组为头级的块以做准备。
分组策略基于这样一个想法:同一个块内的任意两层都应该足够相似。
当 sim(ℓa, ℓb) ≥ δ 时,我们认为 ℓa 和 ℓb 是相似的,其中 0 ≤ δ ≤ 1 是一个超参数。
我们采用了一种自底向上的贪婪算法,迭代地将连续的相似层合并为块,具体算法如下:
ii) 轻量级训练适应
在每个块内应用注意力共享,并对大型语言模型(LLM)进行后训练。
a) 每个块内的注意力共享
其中,Qℓ,i 表示 Qℓ 的第 i 行,Kℓ,[a,b] 表示 Kℓ 的从第 i 行到第 j 行(包含边界)的行。
b) 邻近和远距离令牌注意力输出的聚合
一种无参数的门控机制可以通过(某种方式)整合对邻近和远距离令牌的注意力。
iii) 高效推理
使用后训练的大型语言模型(LLM)进行高效推理。
a) KV缓存内存占用优化
如前文所述,对于远距离令牌,查询和键状态在同一块内的各层之间是共享的。在推理过程中,不需要缓存查询状态,因为它们不会被重复使用。
这里提到的方法将减少KV缓存中键状态的内存消耗。如上图所示,在解码过程中共享注意力分数的层只保留一次远距离令牌;例如,只有第1层保留x2和x3的键状态,而第2层和第3层则不保留。
b) 远距离令牌的计算优化
经验证据表明,在许多情况下,预测下一个令牌可以有效地完成,而无需关注远距离令牌。
如上图所示,对于块内非最低层的层,我们可以预先评估gℓ,i的值。如果gℓ,i ≥ τ(0 ≤ τ ≤ 1是一个超参数),则可以省略对远距离令牌的注意力计算,从而减少远距离令牌的计算量。
实验
i) 性能评估
a) 在大量数据中寻找特定信息
下图展示了不同方法的搜索结果。
结果显示,当目标信息(即“针”)不在StreamingLLM和H2O预设的窗口内时,这两种方法都会失败。相比之下,我们的方法避免了令牌丢失,其表现与密集模型相似,并且能够定位到几乎所有的目标信息。
b) 长上下文基准测试
下表展示了不同方法在两种著名的长上下文基准测试上的评估结果。
POD的表现优于基于令牌驱逐的方法,这表明我们避免丢失令牌的方法确实有效。
只需少量的后训练数据,POD就能击败基于经典层共享的方法CLA,这表明我们的模型在适应现有大型语言模型(LLM)方面具有优势。
无论是PoD还是基于令牌选择的方法,都能达到与标准密集模型相当的性能。
此外,POD与基于令牌选择的方法是正交的,将它们结合起来可以进一步减少KV缓存的大小,同时保持模型性能。
ii) 效率评估
a) 内存占用
下表展示了内存消耗的结果。
结果显示,POD在不同输入文本长度下,最大批处理量增加了30%以上,这与我们理论上KV缓存节省率35%非常接近,表明POD有效地减少了内存使用。
b) 远距离令牌的计算量
下图展示了在LEval上计算节省率与性能损失之间的比率以及τ值的关系。
我们观察到,随着τ的减小,忽略远距离令牌的计算变得更容易,从而带来更大的计算节省,但也会伴随一些性能损失。然而,当τ<0.7时,性能下降的速度会放缓,而计算节省则变得更加明显。当τ=0.7时,计算成本降低了25%,而性能仅下降了5%。
iii) 分析
a) 扩展到更长上下文和其他大型语言模型(LLM)
下表展示了在InfiniteBench下不同上下文大小下5个子任务的评估结果。
与基于令牌驱逐的方法相比,POD方法导致的性能下降较小。然而,一个显著的差异是,基于令牌选择的方法似乎在较长上下文场景中难以保持模型性能。
b) POD中两个关键超参数与模型性能之间的关系:邻近令牌的数量和KV缓存节省率
下图显示,随着邻近令牌数量的增加,POD的性能稳步提升。
下图显示,随着节省率的增加,性能有所下降。为了平衡性能和效率,我们将KV缓存压缩至35%,在使用相同训练数据的情况下,保持了与LLaMA3–8B-32K相当的性能。
结论
基于观察到近距离令牌比远距离令牌更重要的现象,我们提出了POD方法,该方法通过为远距离令牌在相似层之间共享注意力来分配更少的资源。
评估结果显示,POD可以在不牺牲模型性能的情况下节省35%的KV缓存。