Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

OpenAI攻克扩散模型短板,清华校友路橙、宋飏合作最新论文

多项改进实现规模空前的连续时间一致性模型。

扩散模型很成功,但也有一块重大短板:采样速度非常慢,生成一个样本往往需要执行成百上千步采样。为此,研究社区已经提出了多种扩展蒸馏(diffusion distillation)技术,包括直接蒸馏、对抗蒸馏、渐进式蒸馏和变分分数蒸馏(VSD)。但是,这些方法也有自己的问题,包括成本高、复杂性高、多样性有限等。

一致性模型(CM)在解决这些问题方面具有巨大的优势。这又进一步分为离散时间 CM 和连续时间 CM。其中离散时间 CM 会引入离散化误差,并且需要仔细调度时间步长网格,这可能会导致样本质量不佳。而连续时间 CM 虽可避免这些问题,但也会有训练不稳定的问题。

近日,OpenAI 的研究科学家路橙(Cheng Lu)与战略探索团队负责人宋飏(Yang Song)发布了一篇研究论文,提出了一些可简化、稳定化和扩展连续时间一致性模型的技术。值得一提的是,这两位作者都是清华校友,师从朱军教授,在扩散概率模型领域做出过代表性工作。
图片
  • 论文标题:Simplifying, Stabilizing & Scaling Continuous-Time Consistency Models
  • 论文地址:https://arxiv.org/pdf/2410.11081v1

他们的贡献包括:

  • TrigFlow,一个将 EDM(arXiv:2206.00364)与流匹配(Flow Matching)统一起来的公式,其能极大简化扩展模型、相关的概率流 ODE 和一致性模型(CM。
  • 在此基础上,他们分析了一致性模型训练不稳定的根本原因,并提出了一种完整的缓解方案。他们的方法包括改进网络架构中的时间调节和自适应分组归一化。
  • 此外,他们还重新构建了连续时间 CM 的训练目标,其中整合了关键项的自适应加权和归一化以及渐进退火,以实现稳定且可扩展的训练。

简化连续时间一致性模型

作为前提,这里先给出离散时间和连续时间一致性模型的公式:

离散时间 CM:
图片
连续时间 CM:
图片
图片
此前的一致性模型采用了 EDM 中的模型参数化和扩散过程。具体来说,一致性模型会被参数化以下形式:
图片
其中,F 是一个神经网络,θ 是其参数;c_skip、c_out、c_in 都是固定的系数,用以确保在所有时间步骤上初始化时扩散目标的方差相等;c_noise 是对 t 的一个变换运算,以便更好地实现时间调节。

由于在 EDM 扩散过程中,方差会爆炸式增长,也就意味着 x_t = x_0 + tz_t,基于此可以推导出下面三式:
图片
虽然这些系数对于训练效率很重要,但由于它们与 t 和 σ_d 之间存在复杂的算术关系,因此会使得对一致性模型的理论分析变得复杂。

为了简化 EDM 及随之的一致性模型,他们提出了 TrigFlow。这种扩散模型形式保留了 EDM 性质,但满足 c_skip (t) = cos (t)、c_out (t) = sin (t)、c_in (t) ≡ 1/σ_d。

TrigFlow 是流匹配(也称为随机插值或整流)和 v 预测参数化的一种特例。它与之前一些研究团队提出的三角插值非常相似,但经过修改从而纳入了对数据分布 p_d 的标准差 σ_d 的考量。

由于 TrigFlow 是流匹配的一个特例,同时满足 EDM 原理,因此其集两者之长,同时还让扩散过程、扩散模型参数化、PF-ODE、扩散训练目标和一致性模型参数化全都变得更简单了。
图片
让连续时间一致性模型变得稳定

连续时间 CM 的训练一直都高度不稳定。因此,它们的表现一直不及之前研究中的离散时间 CM。

为了解决这个问题,该团队在 TrigFlow 框架的基础上,引入了几项基于理论研究的改进措施,其中重点关注的是参数化、网络架构和训练目标。
 
参数化和网络架构

连续时间 CM 的训练的关键是 (2) 式,其取决于正切函数图片,而在 TrigFlow 框架下,该正切函数由以下公式给出:
图片
其中 图片表示 PF-ODE,其要么在一致性蒸馏中使用预训练的扩散模型估计得出,要么就在一致性训练中使用从噪声和干净样本计算得到的无偏估计器估计得出。

为了让训练过程稳定下来,必须确保 (6) 式中的正切函数在不同时间步骤中保持稳定。该团队在实践中发现 σ_dF_θ、PF-ODE 和噪声样本 x_t 能保持相对稳定。进一步分析后,他们发现图片通常经过良好调节。因此不稳定的来源是时间导数 图片,而其可被分解为:
图片
其中,emb (・) 是指时间嵌入,通常以位置嵌入或傅里叶嵌入的形式出现在扩散模型和 CM 的相关文献中。

该团队稳定 (7) 中每个元素的方法包括恒等时间变换(c_noise (t) = t)、位置时间嵌入和自适应双重归一化。详见原论文。

图 4 可视化地展示了在 CIFAR-10 上训练 CM 时稳定时间导数的情况。研究表明,这些改进可在不损害扩散模型训练的前提下稳定 CM 的训练动态。
图片
训练目标

使用 TrigFlow 和前述的优化技术,(2) 式中连续时间 CM 训练的梯度就会变为:
图片
 之后,该团队又使用了另外一些技术来显式地控制该梯度,以提升稳定性,其中包括正切归一化、自适应加权、扩散微调和正切预热。详见原论文。

有了这些技术,离散时间和连续时间 CM 训练的稳定性都能得到显著改善。

该团队在相同的设置下训练了连续时间 CM 和离散时间 CM。如图 5 (c) 所示,增加离散时间 CM 中的离散化步骤数 N 可提高样本质量,原因是这样做可减少离散化误差来;但一旦 N 变得太大(N > 1024 之后),样本质量就会降低,这是因为会出现数值精度问题。
图片
相较之下,在所有 N 值上,连续时间 CM 的表现都显著优于离散时间 CM。这能为我们提供选择连续时间 CM 的强有力依据。

该团队将他们的模型称为 sCM,其中 s 代表 simple、stable、scalable,即简单、稳定和可扩展。下面是 sCM 训练的详细伪代码
图片
扩展连续时间一致性模型

在这部分,研究者通过在各种具有挑战性的数据集上训练大规模 sCM 来测试上述内容中提出的所有改进措施。

大规模模型中的正切计算

训练大规模扩散模型的常见设置包括使用半精度(FP16)和 Flash Attention。由于训练连续时间 CM 需要精确计算正切 图片,我们需要提高数值精度,同时支持高效记忆注意力计算,详情如下。

计算图片时,需要计算 图片,这可通过图片与输入向量 图片和正切向量 图片的雅可比向量积(JVP)高效获得。然而根据经验,当 t 接近 0 或 π/2 时,正切可能会在中间层溢出。为了提高数值精度,研究者认为应该重新安排正切的计算。

具体来说,由于公式 (8) 中的目标函数包含图片,而 图片图片成正比,因此可以这样计算 JVP:
图片
图片的 JVP,输入为图片,正切为图片这种重新排列大大缓解了中间层的溢出问题,使 FP16 的训练更加稳定。

Flash Attention 被广泛用于大规模模型训练中的注意力计算,既能节省 GPU 内存,又能加快训练速度。然而,Flash Attention 并不计算雅可比向量积(JVP)。为了填补这一空白,研究者提出了一种类似的算法,它能以 Flash Attention 的风格在一次前向传递中高效计算 softmax 自注意力及其 JVP,从而显著减少注意力层中 JVP 计算所需的 GPU 内存用量。

实验
 
sCM 的训练计算。研究者在所有数据集上使用与教师扩散模型相同的批大小。sCD 每次训练迭代的有效计算量大约是教师模型的两倍。他们观察到,sCD 的两步采样质量收敛很快,只用了不到教师模型 20% 的训练计算量,就获得了与教师扩散模型相当的结果。在实践中,只需使用 sCD 进行 20k 次微调迭代,就能获得高质量的样本。
 
基准。在表 1 和表 2 中,研究者通过 FID 和函数评估次数(NFE)的基准,将本文结果与之前的方法进行了比较。首先,sCM 优于之前所有不依赖与其他网络联合训练的几步式方法,与对抗训练取得的最佳结果相当,甚至超越。值得注意的是,sCD-XXL 在 ImageNet 512×512 上的一步 FID 超过了 StyleGAN-XL 和 VAR。此外,sCD-XXL 的两步 FID 性能优于除扩散模型外的所有生成模型,可与需要 63 个连续步骤的最佳扩散模型相媲美。其次,两步式 sCM 模型将与教师扩散模型的 FID 差距显著缩小到 10% 以内。此外,sCT 在较小的扩展上更有效,但在较大扩展上的方差会增大,而 sCD 在小型扩展和大型扩展上都表现出一致的性能。
图片
图片
Scaling 研究。如图 6 所示,首先,随着模型 FLOPs 的增加,sCT 和 sCD 的样本质量都有所提高,这表明这两种方法都能从 Scaling 中获益。其次,与 sCD 相比,sCT 在较小分辨率下的计算效率更高,但在较大分辨率下的效率较低。第三,对于给定的数据集,sCD 的 Scaling 是可预测的,在不同大小的模型中,FID 的相对差异保持一致。这表明,sCD 的 FID 下降速度与教师扩散模型相同,因此,sCD 与教师扩散模型一样具有可扩展性。随着教师扩散模型的 FID 随规模的扩大而减小,sCD 与教师模型之间 FID 的绝对差异也随之减小。最后,FID 的相对差异随着采样步骤的增加而减小,两步式 sCD 的采样质量与教师扩散模型相当。
图片
与 VSD 的对比。如图 7 所示,研究者对比了 sCD、VSD、sCD 和 VSD 的组合等(通过简单地将两种损失相加)并观察到,VSD 具有与扩散模型中应用大 guidance scale 类似的人工效应:它提高了保真度(表现为更高的精确度分数),同时降低了多样性(表现为更低的召回分数)。这种效应随着 guidance scale 的增加而变得更加明显,最终导致严重的模式崩溃。相比之下,两步式 sCD 的精确度和召回分数与教师扩散模型相当,因此 FID 分数比 VSD 更高。

更多研究细节,可参考原论文。
工程一致性模型
相关数据
朱军人物

朱军,清华大学计算机系长聘副教授、卡内基梅隆大学兼职教授。2001 到 2009 年获清华大学计算机学士和博士学位,之后在卡内基梅隆大学做博士后,2011 年回清华任教。主要从事人工智能基础理论、高效算法及相关应用研究,在国际重要期刊与会议发表学术论文百余篇。担任人工智能顶级杂志 IEEE TPAMI 和 AI 的编委、《自动化学报》编委,担任机器学习国际大会 ICML2014 地区联合主席, ICML (2014-2018)、NIPS (2013, 2015, 2018)、UAI (2014-2018)、IJCAI(2015,2017)、AAAI(2016-2018)等国际会议的领域主席。获 CCF 自然科学一等奖、CCF 青年科学家奖、国家优秀青年基金、中创软件人才奖、北京市优秀青年人才奖等,入选国家「万人计划」青年拔尖人才、MIT TR35 中国区先锋者、IEEE Intelligent Systems 杂志评选的「AI's 10 to Watch」(人工智能青年十杰)、及清华大学 221 基础研究人才计划。

调度技术

调度在计算机中是分配工作所需资源的方法。资源可以指虚拟的计算资源,如线程、进程或数据流;也可以指硬件资源,如处理器、网络连接或扩展卡。 进行调度工作的程序叫做调度器。调度器通常的实现使得所有计算资源都处于忙碌状态,允许多位用户有效地同时共享系统资源,或达到指定的服务质量。 see planning for more details

自注意力技术

自注意力(Self-attention),有时也称为内部注意力,它是一种涉及单序列不同位置的注意力机制,并能计算序列的表征。自注意力在多种任务中都有非常成功的应用,例如阅读理解、摘要概括、文字蕴含和语句表征等。自注意力这种在序列内部执行 Attention 的方法可以视为搜索序列内部的隐藏关系,这种内部关系对于翻译以及序列任务的性能非常重要。

基准技术

一种简单的模型或启发法,用作比较模型效果时的参考点。基准有助于模型开发者针对特定问题量化最低预期效果。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

收敛技术

在数学,计算机科学和逻辑学中,收敛指的是不同的变换序列在有限的时间内达到一个结论(变换终止),并且得出的结论是独立于达到它的路径(他们是融合的)。 通俗来说,收敛通常是指在训练期间达到的一种状态,即经过一定次数的迭代之后,训练损失和验证损失在每次迭代中的变化都非常小或根本没有变化。也就是说,如果采用当前数据进行额外的训练将无法改进模型,模型即达到收敛状态。在深度学习中,损失值有时会在最终下降之前的多次迭代中保持不变或几乎保持不变,暂时形成收敛的假象。

伪代码技术

伪代码,又称为虚拟代码,是高层次描述算法的一种方法。它不是一种现实存在的编程语言;它可能综合使用多种编程语言的语法、保留字,甚至会用到自然语言。 它以编程语言的书写形式指明算法的职能。相比于程序语言它更类似自然语言。它是半形式化、不标准的语言。

导数技术

导数(Derivative)是微积分中的重要基础概念。当函数y=f(x)的自变量x在一点x_0上产生一个增量Δx时,函数输出值的增量Δy与自变量增量Δx的比值在Δx趋于0时的极限a如果存在,a即为在x0处的导数,记作f'(x_0) 或 df(x_0)/dx。

神经网络技术

(人工)神经网络是一种起源于 20 世纪 50 年代的监督式机器学习模型,那时候研究者构想了「感知器(perceptron)」的想法。这一领域的研究者通常被称为「联结主义者(Connectionist)」,因为这种模型模拟了人脑的功能。神经网络模型通常是通过反向传播算法应用梯度下降训练的。目前神经网络有两大主要类型,它们都是前馈神经网络:卷积神经网络(CNN)和循环神经网络(RNN),其中 RNN 又包含长短期记忆(LSTM)、门控循环单元(GRU)等等。深度学习是一种主要应用于神经网络帮助其取得更好结果的技术。尽管神经网络主要用于监督学习,但也有一些为无监督学习设计的变体,比如自动编码器和生成对抗网络(GAN)。

插值技术

数学的数值分析领域中,内插或称插值(英语:interpolation)是一种通过已知的、离散的数据点,在范围内推求新数据点的过程或方法。求解科学和工程的问题时,通常有许多数据点借由采样、实验等方法获得,这些数据可能代表了有限个数值函数,其中自变量的值。而根据这些数据,我们往往希望得到一个连续的函数(也就是曲线);或者更密集的离散方程与已知数据互相吻合,这个过程叫做拟合。

目标函数技术

目标函数f(x)就是用设计变量来表示的所追求的目标形式,所以目标函数就是设计变量的函数,是一个标量。从工程意义讲,目标函数是系统的性能标准,比如,一个结构的最轻重量、最低造价、最合理形式;一件产品的最短生产时间、最小能量消耗;一个实验的最佳配方等等,建立目标函数的过程就是寻找设计变量与目标的关系的过程,目标函数和设计变量的关系可用曲线、曲面或超曲面表示。

对抗训练技术

对抗训练涉及两个模型的联合训练:一个模型是生成器,学习生成假样本,目标是骗过另一个模型;这另一个模型是判别器,通过对比真实数据学习判别生成器生成样本的真伪,目标是不要被骗。一般而言,两者的目标函数是相反的。

生成模型技术

在概率统计理论中, 生成模型是指能够随机生成观测数据的模型,尤其是在给定某些隐含参数的条件下。 它给观测值和标注数据序列指定一个联合概率分布。 在机器学习中,生成模型可以用来直接对数据建模(例如根据某个变量的概率密度函数进行数据采样),也可以用来建立变量间的条件概率分布。

算术技术

算术(英语:arithmetic)是数学最古老且最简单的一个分支,几乎被每个人使用着,从日常生活上简单的算数到高深的科学及工商业计算都会用到。一般而言,算术这一词指的是记录数字某些运算基本性质的数学分支。

推荐文章
暂无评论
暂无评论~