大型语言模型(LLM)受限于在“语言空间”中进行推理,它们通常通过思维链(Chain-of-Thought, CoT)来表达推理过程,以解决复杂的推理问题。然而,语言空间并不总是推理的最佳选择。
为了探索在不受限制的潜在空间中而非使用自然语言进行LLM推理的潜力,本文引入了一种新范式Coconut(连续思维链)。它涉及对传统CoT过程的一个简单修改:不是使用语言模型的头部和嵌入层在隐藏状态和语言标记之间进行映射,而是如下图所示,Coconut直接将最后一个隐藏状态(一个连续的思维)作为下一个标记的输入嵌入。
Coconut:连续思维链
i) 方法概述
在此方法中,LLM(大型语言模型)在“语言模式”和“潜在模式”之间切换,如上图所示。
在语言模式下,模型作为标准语言模型运行,自回归地生成下一个标记(token)。在潜在模式下,它直接利用最后一个隐藏状态作为下一个输入嵌入。这个最后的隐藏状态代表了当前的推理状态,被称为“连续思维”。
特殊标记<bot>和<eot>分别用于标记潜在思维模式的开始和结束。例如,我们假设潜在推理发生在位置i和j之间,即xi = <bot>且xj = <eot>。
当模型处于潜在模式时(i < t < j),我们使用前一个标记的最后一个隐藏状态来替换输入嵌入。
在潜在模式结束后(t ≥ j),输入会恢复为使用标记嵌入,即Et = [e(x1), e(x2), …, e(xi), hi, hi+1, …, hj−1, e(xj), …, e(xt)]。
ii) 训练流程
利用语言思维链(CoT)数据,通过实施多阶段训练课程来监督连续思维。
如下图所示,在初始阶段,模型会在常规的思维链实例上进行训练。
在后续阶段,即第 k 阶段,CoT 中的前 k 个推理步骤被 k×c 个连续想法1取代,其中 c 是一个超参数,控制取代单个语言推理步骤的潜在想法的数量。
当训练阶段切换时,我们还会重置优化器状态。
我们插入<bot>和<eot>标记来封装连续的想法。
在训练过程中,我们掩盖了问题和潜在思维的损失。需要注意的是,目标不是鼓励连续思维压缩被移除的语言思维,而是促进对未来推理的预测。
iii) 训练细节
提出的连续思维是完全可微的,并且允许进行反向传播。
当在当前训练阶段安排了n个潜在思维时,会执行n+1次前向传递,每次传递都会计算一个新的潜在思维,并最终进行一次额外的前向传递以获得剩余文本序列的损失。
虽然我们可以使用键值(KV)缓存来节省重复计算,但多次前向传递的顺序性质对并行性构成了挑战。
iv) 推理过程
Coconut的推理过程与标准语言模型解码类似,不同之处在于在潜在模式下,我们直接将最后一个隐藏状态作为下一个输入嵌入。
由于我们关注的是问题解决场景,因此在问题标记之后立即插入<bot>标记。
对于<eot>,我们考虑了两种潜在策略:a) 在潜在思维上训练一个二分类器,使模型能够自主决定何时终止潜在推理;b) 总是将潜在思维填充到固定长度。我们发现这两种方法的效果相当。
实验
i) 实验设置
使用预训练的GPT-2作为所有实验的基础模型。
a) 数学推理:
b) 逻辑推理:
ii) 基线和Coconut的变体
a) 使用的基线
b) Coconut的变体
iii) 结果
下表显示了GSM8l、ProntoQA和ProsQA三个数据集上的结果。更高的准确率表明更强的推理能力,而生成更少的标记则表明更高的效率。
在GSM8k的实验中,我们发现Coconut的表现优于采用类似策略训练的其他架构,尤其是超越了最新的基线模型iCoT。
下图展示了调整超参数c的实验结果,该参数控制一个语言推理步骤对应的潜在思维数量。
随着我们将c从0增加到1再到2,模型的性能稳步提升,这表明在潜在空间中可以观察到类似于思维链(CoT)的效应。
在其他两个合成任务中,我们发现Coconut的变体(无思维或暂停作为思维)以及iCoT基线也取得了令人印象深刻的准确性。
Coconut、其变体以及iCoT在ProsQA上显著增强了推理能力,这表明在处理需要大量规划的任务时,潜在空间推理提供了明显的优势。
理解Coconut中的潜在推理
i) 实验设置
a) 方法
Coconut的设计允许我们通过手动设置在推理期间<eot>标记的位置来控制潜在思维的数量。
在我们的实验中,我们测试了Coconut在ProsQA上的变体,其中k的取值范围为{0, 1, 2, 3, 4, 5, 6}。请注意,所有这些变体在推理时有所不同,但它们共享相同的模型权重。
b) 评估指标
我们采用了两套评估指标,其中一套基于最终答案的正确性,而不考虑推理过程。
为了进行更细致的分析,我们定义了另一套针对推理过程的评估指标。假设我们有一个完整的语言推理链,它指定了图中的一条路径,我们可以将其分类为:
(1)正确路径:输出是到达正确答案的最短路径之一。
(2)较长路径:一个有效路径,能正确回答问题但比最短路径长。
(3)幻觉:路径包含不存在的边或断开。
(4)错误目标:图中的有效路径,但目标节点不是所询问的节点。
在无CoT和具有较大k值的Coconut中,模型可能只输出最终答案而不包含任何部分路径,这属于(5)正确标签或(6)错误标签。
ii) 潜在推理与语言推理之间的插值
下图展示了在ProsQA上不同推理方法的比较分析。
随着使用连续思维进行的推理增多(即k值增大),最终答案的准确性(图左)以及正确推理过程的比例(图右中的“正确标签”和“正确路径”)均有所提高。
下图展示了ProsQA的一个案例研究。使用CoT训练的模型在陷入死胡同后幻想出了一条边(每个yumpus都是一个rempus)。Coconut(k=1)输出的路径以一个不相关的节点结束。而Coconut(k=2)则正确地解决了问题。
iii) 解释潜在搜索树
基于连续思维能够编码多个潜在下一步的直觉,潜在推理可以被解释为一种搜索树,而不仅仅是推理“链”。
下图左半部分描绘了所有可能的分支。同样地,在第二步中,前沿节点将是Alex的孙子节点(下图右半部分)。
与标准的广度优先搜索(BFS)均匀探索所有前沿节点不同,该模型展示出了优先处理有希望的节点同时剪枝掉较不相关节点的能力。
这种概率分布可以被视为模型的隐式价值函数,用于估计每个节点到达目标的潜力。如图所示,“lempus”、“zhorpus”、“grimpus”和“sterpus”的价值分别为0.33、0.16、0.32和0.01。
下图分析了模型在第一次和第二次思维中的潜在推理并行性。对于第一次思维(左图),计算了前1、前2和前3候选节点的累积价值,并将其与测试集中各自的百分位数进行了对比和绘制。
三条线之间明显的差距表明,在这个阶段,模型的推理路径保持了显著的多样性,这表明了对替代可能性的广泛探索。相比之下,第二次思维(右图)显示了这些差距的缩小。
结论
我们提出了Coconut,这是一种在连续潜在空间中进行推理的新范式。
通过大量的实验,我们证明了Coconut显著增强了大型语言模型(LLM)的推理能力。