我们通过使用航点来引导,提高了扩散策略生成的轨迹的效率和质量。基础的扩散策略是一种创建多样化计划以解决任务的有效方法,但这些轨迹可能不可行或效率低下。我们通过引入从收集到的数据中学习到的航点作为引导,证明了可以在Maze2D模拟中提高生成轨迹的效率。我们相信,通过我们的方法,既可以生成更好的轨迹,也能够在不同的环境配置中重复使用扩散模型。
强化学习
强化学习是一种训练方法,用于训练模型在环境中交互以完成目标。在最简单的方法中,这可以通过直接与环境交互、收集经验并更新智能体的“策略”(即智能体在可选动作之间做出选择的模型)来实现。这被称为“在线强化学习”,并且存在一些问题。在线方法的局限性在于,数据收集的速度受到智能体探索环境速度的限制。此外,由于具有非最优策略的智能体必须自己与环境交互以收集数据,因此要达到专家级能力所需的样本数量非常高。离线方法试图通过从预先收集的专家数据中学习策略来解决这些扩展性限制。传统的强化学习方法针对固定、预定义的任务进行优化,并且以环境的状态空间为条件。目标条件强化学习(GCRL)允许智能体开发灵活的策略,这些策略可以在不同的目标状态下动态调整其行为。通过将目标视为额外的输入参数,这些学习系统可以在多个任务之间进行泛化。
离线强化学习方法
最流行的离线强化学习方法涉及学习状态“价值”的估计,即遵循特定策略时将获得的折扣奖励。这允许策略就长期目标而言,评估采取一个动作相对于另一个动作的优势。在离线强化学习的背景下,这隐式地将数据中看到的轨迹拼接在一起,使策略能够超越单纯复制数据中看到的策略。然而,基于价值的方法训练起来很困难,需要大量的超参数调优以确保训练稳定。
相比之下,通过监督学习的强化学习(RvS)方法基于目标条件行为克隆,直接使用监督学习复制数据中演示的策略。其优点是这些方法消除了价值估计的需求,使训练更加直接。历史上,这些方法的缺点是它们在处理拼接问题时遇到困难。RvS中的拼接问题出现在智能体需要连接或“拼接”数据集中次优行为片段以形成全局最优策略时。例如,当从所有涉及死胡同的示例中学习导航迷宫时,必须将这些示例的子轨迹拼接在一起以创建有效解决方案,以限制训练所需的数据量。这反映了在数据中不同轨迹的成功部分之间弥合差距的挑战,特别是当数据集包含次优、不完整或稀疏的演示时。当基于基础行为克隆的策略到达数据可用性低的状态空间部分时,它们很可能会失败。这导致RvS方法的性能落后于基于价值的方法。
使用航点拼接生成的子轨迹
航点变换器
航点变换器是一种RvS方法,它引入了一个模块来生成中间目标和更稳定的代理奖励,以引导策略朝向目标。然后,他们使用从数据演示中学习到达生成的子目标的变换器,以获得遵循航点到达目标的策略。这解决了其他RvS方法中发现的拼接问题,并且论文特别演示了如何使用航点使航点变换器能够用比其他方法更少的样本穿越状态空间区域。
扩散策略
最近,扩散策略作为一种新方法崭露头角,用于学习复杂分布并生成多样化的样本。最值得注意的是,它们已被证明能够以可控的方式生成新颖的图像,同时也可以用类似的可控方式生成长远规划。与其他方法使用的迭代规划不同,扩散策略避免了误差累积问题。这使得它们成为目标条件强化学习(GCRL)的一个有吸引力的选择。然而,单独的扩散模型往往会生成不可行的规划。
基于扩散的子轨迹拼接——SSD
基于扩散的子轨迹拼接展示了一种方法,旨在解决扩散策略在解决目标条件强化学习(GCRL)问题时遇到的限制。受其他离线强化学习方法启发,SSD使用学习到的价值估计器来条件化负责在设定时间范围内生成轨迹的扩散模型。该模型学习如何将这些较短的子轨迹连接起来以达到指定目标。通过交替训练价值估计器和扩散策略来训练整个流程。价值估计器是基于Transformer的U-Net架构。
在扩散策略中使用航点
如前所述,直接从数据中生成的航点不会遭受训练价值估计器时遇到的稳定性问题。训练一个模型来生成航点既快速又简单。此外,该模型不需要与策略联合训练。因此,我们提议用类似于航点变换器论文中提出的航点生成器来替换SSD中的动作-价值评价器。
Push-T环境
我们在流行的Push-T环境中进行了实验。遵循与航点变换器相同的思路,我们从Push-T数据集中训练了一个基于多层感知器(MLP)的航点预测器。训练过程很简单,因为我们可以选择演示序列中的任意时间步,并使用未来状态作为监督。给定当前观测值[末端执行器.x, 末端执行器.y, T.x, T.y, T.angle],网络使用均方误差(MSE)损失预测K步之后的状态[T.x, T.y, T.angle]。
以下是一些预测T的航点的示例。我们展示了两个成功案例和一个失败案例。
扩散策略网络的扩展非常简单。原始网络通过FiLM(特征线性调制)条件化,基于过去的观测值进行条件化。我们可以简单地将目标条件附加到条件变量中,并训练扩散网络。然后,在执行阶段,我们可以查询我们的航点扩散器以获取中间目标条件,并通过扩散生成动作序列。
我们用相同的初始条件初始化环境,并让普通扩散和航点扩散都生成动作序列。两种策略都在相同的数据集上进行训练,观测范围为2,预测范围为16。在执行阶段,普通扩散坚持原始方案,仅在重新规划前执行8个动作。相比之下,我们的航点扩散在重新规划前执行所有16个动作。
我们可以看到,航点扩散策略能够比普通扩散策略更快地将T形物体移动到期望的位置。
这一观察结果得到了奖励图的支持。在初始步骤中,航点扩散的奖励增加速度比普通扩散更快。虽然它们的最大奖励相当,但航点扩散有时表现出较低的终点准确性。我们推测,航点预测网络可以引导T形物体朝正确方向移动,但在预测T形物体的精确目标位置时存在困难。因此,当T形物体接近目标位置时,航点网络可能产生质量相似甚至更低的目标状态。
Maze2D环境以及与SSD的比较
我们将我们的方法与SSD进行比较,SSD是另一种在Maze2D环境中拼接扩散生成的子轨迹的方法,它使用受其他离线强化学习方法启发的动作-价值评价器。我们保持其扩散策略架构不变,仅将其动作-价值评价器模型替换为我们的航点生成器模型,并相应地调整训练过程,从交替训练价值估计器和扩散策略转变为先完全训练航点生成器,然后冻结它,并在训练扩散策略时使用它。
我们在简单的maze-umaze-v1和maze-medium-v1环境上比较了定性结果。UMaze提供了100万个样本,而Medium提供了200万个样本。我们使用SSD的SampleEveryControl控制器从每个策略中采样动作。该控制器在每个时间步直接从扩散策略中采样每个动作,以便进行更直接的比较。
我们训练了一个简单的3层多层感知器(MLP)([4+2x256],[256x256],[256x256],[256x256],[256x2]),其任务是预测轨迹中当前状态K步之后的位置(其中K远小于SSD扩散策略使用的有限范围),使用均方误差损失。网络接收代理在任何特定状态下将获得的观测值([x, y, v_x, v_y])和总体目标位置([x,y])的连接作为输入。K的选择是一个超参数,但我们发现较短的K(在我们的实验中为8步)比较长的K选择表现更好。这个网络替换了SSD中的动作价值评价器,并且在训练扩散策略之前进行训练和冻结,类似于航点变换器的做法。
Maze2D-UMaze-v1
Maze2D-Medium-v1
我们能够在Maze2D环境中匹配SSD的成功案例,而无需使用更大的基于Transformer的动作价值评价网络以及复杂的训练过程。相反,通过一个简单的多层感知器(MLP),其参数仅为价值网络的一小部分,并且可以独立于扩散模型进行训练,我们就能够实现类似的结果。我们还观察到,在更复杂的环境中,通过使用直接学习到的航点,策略能够更直接地穿越迷宫的两个部分。这使我们推测,通过使用航点,扩散策略更容易穿越数据样本有限的区域,并且使我们的方法能够生成更有效的轨迹。
结论
我们在小规模示例中展示了,可以使用直接从收集的专家数据中训练的简单航点生成器来制作目标条件化的扩散策略。这种方法比SSD中的方法训练起来要简单得多,并且需要更少的参数,因为航点模型不需要与扩散策略联合训练。这可以启用一些用例,比如快速创建的航点网络库可以与单个扩散模型一起使用。由于扩散模型在训练中代表了最大的成本,因此重用该组件将是有利的,并且可能使未见环境的测试时新策略学习成为可能。