扩散模型因其出色的图像生成能力而备受瞩目,目前已成为流行文本到图像模型(如DALL-E、Stable Diffusion和Midjourney)中的主要架构。然而,扩散模型的应用不仅限于图像生成。近期,Meta、普林斯顿大学和德克萨斯大学奥斯汀分校的研究人员发表了一篇新论文,指出扩散模型在强化学习领域也具有重要价值。
这篇论文介绍了一种利用基于扩散的世界模型(DWM)来训练强化学习代理的技术。DWM通过预测环境未来几步的状态,有效提升了当前基于模型的强化学习系统的性能。
强化学习分为无模型和基于模型两种。无模型强化学习直接从与环境的交互中学习策略或价值函数,而不预测环境未来的状态。而基于模型的强化学习则通过模拟环境来训练策略。这种方法的优势在于,它需要的真实环境数据样本较少,对于自动驾驶汽车和机器人等需要大量实验数据的应用场景尤为适用。然而,基于模型的强化学习高度依赖于世界模型的准确性。如果模型不准确,基于模型的强化学习系统的表现可能会比无模型系统更差。
传统的世界模型使用单步动态预测,即仅基于当前状态和动作预测奖励和下一个状态。然而,当进行长期规划时,这种方法可能会因为小错误的累积而导致长期预测的不准确。为了解决这个问题,DWM采用了一次性预测多个未来步骤的方法。这种方法可以有效减少长期预测中的错误,从而提高基于模型的强化学习算法的性能。
扩散模型的工作原理是通过逆转一个逐渐添加噪声的过程来生成数据。在训练过程中,模型逐渐向数据添加噪声层,然后尝试逆转这个过程并预测原始数据。通过重复这个过程并添加更多的噪声层,模型学会了从纯噪声中创建详细的数据。条件扩散模型则通过使模型的输出与特定输入相关联,增加了一层控制。这使得模型可以根据给定的输入(如与图像相关的标题)生成相应的数据。
尽管扩散模型在图像生成方面表现出色,但它们也可以应用于其他数据类型。在强化学习领域,DWM使用扩散模型来预测未来多个步骤的状态和奖励。这种方法使得代理能够更准确地模拟环境,并制定相应的策略。
DWM框架包括两个训练阶段。在第一阶段,扩散模型在从环境中收集的一系列轨迹上进行训练。通过这个过程,DWM学习了一个强大的世界模型,能够一次性预测多个步骤。这使得它在长期模拟中比其他基于模型的方法更稳定。在第二阶段,使用actor-critic算法和扩散世界模型进行离线强化学习策略的训练。这种方法消除了训练期间需要在线交互的需求,提高了训练速度并降低了成本和风险。
为了验证DWM的有效性,研究人员将其与基于模型和无模型的强化学习系统进行了比较。实验结果表明,DWM与单步世界模型相比,性能提高了44%。此外,当将单步世界模型应用于无模型强化学习算法时,通常会降低性能。然而,当与DWM结合时,无模型强化学习系统的性能超过了原始版本。
这项研究展示了扩散模型在非生成性任务中的更广泛应用趋势。在过去的一年里,由于生成模型的进步,机器人研究取得了重大突破。语言模型正在帮助缩小自然语言命令和机器人运动命令之间的差距。同时,变换器还帮助研究人员将从不同形态和设置中收集的数据整合在一起,并训练出可以推广到不同机器人和任务的模型。随着这些技术的不断发展,我们有望看到更多创新的强化学习应用问世。