生成式人工智能(AI)模型的核心目标是根据大型数据集中的模式,创造出逼真且高质量的数据,涵盖图像、音频和视频等多种形式。这些模型具备模仿复杂数据分布的能力,从而生成与原始样本相似的合成内容。其中,扩散模型作为一种广受欢迎的生成模型,通过逐步反转加入噪声的序列,成功实现了图像和视频的生成,直至达到高保真输出。然而,扩散模型的一个显著缺点是采样过程繁琐,通常需要几十甚至上百个步骤才能完成,这导致了高昂的计算资源和时间成本。在需要快速采样或大规模生成样本的场景中,如实时应用或大规模部署,这一缺点尤为突出。
扩散模型在采样过程中面临的主要挑战是计算负担,这源于系统地反转噪声序列所需的高昂计算成本,以及在离散化时间间隔时引入的误差。为解决这一问题,连续时间扩散模型应运而生,它们无需时间间隔,从而减少了采样误差。然而,由于训练期间存在的不稳定性,连续时间模型尚未得到广泛应用。这种不稳定性使得在大规模或复杂数据集上训练这些模型变得困难,进而阻碍了它们在计算效率至关重要的领域的采纳和发展。
为了提升扩散模型的效率,研究人员近期开发了一系列方法,包括直接蒸馏、对抗性蒸馏、渐进蒸馏和变分得分蒸馏(VSD)。这些方法在加速采样过程或提高样本质量方面展现出潜力,但同时也面临着实用上的挑战,如高计算开销、复杂的训练设置和可扩展性限制。例如,直接蒸馏需要从头开始训练,增加了时间和资源成本;对抗性蒸馏在使用生成对抗网络(GAN)架构时,常因输出的稳定性和一致性问题而面临挑战;渐进蒸馏和VSD虽然对短步骤模型有效,但通常会产生多样性有限或细节较少的样本,尤其是在高引导水平时。
OpenAI的一个研究团队引入了一种名为TrigFlow的新框架,旨在简化、稳定和扩展连续时间一致性模型(CM)。该框架专门针对训练连续时间模型中的不稳定问题,通过改进模型参数化、网络架构和训练目标,简化了训练流程。TrigFlow通过统一扩散模型和一致性模型,识别并减轻了不稳定的主要原因,使模型能够可靠地处理连续时间任务。即使在扩展到ImageNet等大型数据集时,该模型也能以最低的计算成本实现高质量采样。通过使用TrigFlow,团队成功训练了一个拥有15亿参数的模型,该模型通过两步采样过程就达到了较高的质量评分,且计算成本低于现有的扩散方法。
TrigFlow的核心在于对数学上的重新定义,简化了采样过程中使用的概率流常微分方程(ODE)。这一改进结合了自适应组归一化和使用自适应加权的更新目标函数,有助于稳定训练过程,使模型能够连续运行而不受离散化误差的影响。此外,TrigFlow在网络架构中的时间条件方法减少了对复杂计算的依赖,使得模型的扩展变得可行。重构后的训练目标逐步退火模型中的关键项,使其能更快稳定并达到前所未有的规模。
被命名为“sCM”(简单、稳定、可扩展一致性模型)的模型展示了与最先进扩散模型相媲美的结果。例如,在CIFAR-10数据集上,sCM的Fréchet Inception Distance(FID)达到了2.06;在ImageNet 64×64上达到了1.48;在ImageNet 512×512上达到了1.88。这些成绩显著缩小了与最佳扩散模型之间的差距,而sCM仅使用了两步采样。与需要更多步骤的先前方法相比,两步模型在FID上提升了近10%,标志着采样效率的显著提升。TrigFlow框架在模型可扩展性和计算效率方面取得了重要进展。
这项研究提供了几个关键启示。首先,通过精心构建的连续时间模型,可以解决传统扩散模型的计算低效率和局限性。其次,通过实施TrigFlow,研究人员稳定了连续时间CM,并将其扩展到更大的数据集和参数规模,几乎无需在计算上做出妥协。
研究的关键启示包括:
总之,这项研究在生成模型训练中取得了关键性进步。通过TrigFlow框架,研究团队解决了稳定性、可扩展性和采样效率的问题。OpenAI团队的TrigFlow架构和sCM模型有效地克服了连续时间一致性模型的主要挑战,提供了一个稳定且可扩展的解决方案。该方案在性能和质量上能与最佳扩散模型相媲美,同时显著降低了计算要求。