0%

2024-CVPR-One-Step Diffusion Distillation through Score Implicit Matching论文精读

全文翻译

摘要

尽管扩散模型在许多生成任务上表现出色,但它们需要大量的采样步骤才能生成逼真的样本。这促使社区开发有效的方法,将预训练的扩散模型蒸馏为更高效的模型,但这些方法通常仍需要少步推理,或者性能明显低于基础模型。在本文中,我们提出了分数隐式匹配(SIM),这是一种将预训练扩散模型蒸馏为单步生成器模型的新方法,同时保持与原始模型几乎相同的样本生成能力,并且无需数据——蒸馏过程不需要训练样本。该方法基于这样一个事实:尽管对于生成器模型来说,传统的基于分数的损失难以最小化,但在特定条件下,我们可以高效地计算扩散模型和生成器之间广泛类别的基于分数的散度的梯度。SIM在单步生成器方面表现出强大的实证性能:在CIFAR10数据集上,其无条件生成的FID为2.06,类条件生成的FID为1.96。此外,通过将SIM应用于领先的基于Transformer的扩散模型,我们蒸馏出用于文本到图像(T2I)生成的单步生成器,其美学分数达到6.42,与原始多步模型相比没有性能下降,明显优于其他单步生成器,包括SDXL-TURBO(5.33)、SDXL-LIGHTNING(5.34)和HYPER-SDXL(5.85)。我们将随本文发布这种适用于工业界的基于Transformer的单步T2I生成器。

1 引言

在过去的几年里,扩散模型(DMs)[20, 66, 64]在从数据合成[24, 25, 50, 51, 21, 55, 22, 30]到密度估计[31, 7],从文本到图像生成[53, 59, 2, 79, 6]、文本到3D创作[55, 73, 27, 33]、图像编辑[46, 8, 18, 1, 29, 48],以及其他领域[82, 78, 4, 84, 17, 58, 13, 72, 88, 71, 41, 77, 43, 83, 12, 10, 45, 15, 70, 54, 9]等广泛的应用中取得了显著的进展。从高层次的角度来看,扩散模型(也称为基于分数的扩散模型)使用扩散过程来破坏数据分布,然后经过训练以近似不同噪声水平下噪声数据分布的分数函数。

扩散模型具有多个优点,如训练灵活性、可扩展性以及生成高质量样本的能力,这使其成为现代AIGC模型的首选。训练完成后,学习到的分数函数可用于逆转数据破坏过程,这可以通过数值求解相关的随机微分方程来实现。这种数据生成机制通常需要多次神经网络评估,这导致了扩散模型的一个显著限制:当采样步骤数量减少时,扩散模型的生成性能会大幅下降。这一缺点限制了扩散模型的实际部署,尤其是在需要快速推理的场景中,例如在移动电话和边缘设备等计算能力有限的设备上,或在需要快速响应时间的应用中。

这一挑战促使人们提出了各种方法,旨在加快扩散模型的采样过程,同时保留其强大的生成能力。特别是蒸馏方法,专注于应用蒸馏算法,将知识从预训练的教师扩散模型转移到高效的学生生成模型,这些学生模型能够在少数生成步骤内生成高质量的样本。

一些工作从概率散度最小化的角度研究了扩散蒸馏算法。例如,Luo等人[42]、Yin等人[81]研究了最小化教师模型和单步学生模型之间KL散度的算法。Zhou等人[92]探索了使用Fisher散度进行蒸馏,取得了令人印象深刻的实证性能。尽管这些研究在理论和实证方面都为社区做出了贡献,并提供了可应用的单步生成器模型,但它们的理论是建立在特定的散度(即Kullback-Leibler散度和Fisher散度)之上的,这可能限制了蒸馏性能。目前仍然缺乏一个更通用的框架来理解和改进扩散蒸馏。

在这项工作中,我们引入了分数隐式匹配(SIM),这是一种将预训练扩散模型蒸馏为单步生成器网络,同时保持高质量生成的新框架。为此,我们针对生成器模型的(难以处理的)分数函数与原始扩散模型的分数函数之间的任意距离函数,提出了一类广泛而灵活的基于分数的散度。这项工作的关键技术见解是,尽管此类散度无法显式计算,但我们可以使用我们称为分数梯度定理的结果精确计算这些散度的梯度,从而实现散度的隐式最小化。这使我们能够基于此类散度高效地训练模型。

我们使用不同的距离函数选择来定义散度,将SIM的性能与先前的方法进行了评估。最相关的是,我们将SIM与使用基于KL散度项的Diff-Instruct(DI)[42]方法,以及Score Identity Distillation(SiD)方法[92]进行了比较。我们表明,当距离函数简单地选择为平方L₂距离时,SiD是我们方法的一个特例(尽管推导方式完全不同)。我们还通过实证表明,使用专门设计的Pseudo-Huber距离函数的SIM比L₂距离表现出更快的收敛速度和更强的超参数鲁棒性,使得所得到的方法明显优于先前的方法。

最后,我们表明,相对于该领域过去在CIFAR10图像生成和文本到图像生成方面的工作,SIM在绝对性能上取得了非常强的实证结果。在CIFAR10数据集上,SIM展示了单步生成性能,无条件生成的Frechet Inception Distance(FID)为2.06,类条件生成的FID为1.96。更定性地说,蒸馏一个领先的基于扩散Transformer的[52]文本到图像扩散模型,得到了一个能力极强的单步文本到图像生成器,我们表明其在生成性能方面与教师扩散模型几乎没有损失。特别是,通过将SIM应用于PixelArt-α[6],蒸馏出的单步生成器达到了6.42的出色美学分数,与原始多步扩散模型相比没有性能下降。这显著优于其他单步文本到图像生成器,包括SDXL-TURBO[63](5.33)、SDXL-LIGHTNING[34](5.34)和HYPER-SDXL[56](5.85)。这一结果不仅标志着单步文本到图像生成的新方向,还激发了对其他领域(如视频生成)中基于扩散Transformer的AIGC模型进行蒸馏的进一步研究。

2 扩散模型

在本节中,我们介绍关于扩散模型和扩散蒸馏的预备知识和符号表示。假设我们从潜在分布 $q_d(x)$ 中观察数据,生成式建模的目标是训练模型以生成新样本 $x \sim q_d(x)$。扩散模型的前向扩散过程将任意初始分布 $q_0 = q_d$ 转换为某个简单的噪声分布,其表达式为:

其中 $F$ 是预定义的漂移函数,$G(t)$ 是预定义的标量值扩散系数,$w_t$ 表示独立的维纳过程。连续索引的分数网络 $s_\varphi(x, t)$ 用于近似前向扩散过程(2.1)的边缘分数函数。分数网络的学习通过最小化加权去噪分数匹配目标来实现 [69, 66],即:

这里的加权函数 $\lambda(t)$ 控制不同时间水平的学习重要性,$q_t(x_t | x_0)$ 表示前向扩散(2.1)的条件转移。训练完成后,分数网络 $s_\varphi(x_t, t) \approx \nabla_{x_t} \log q_t(x_t)$ 能够很好地近似扩散数据分布的边缘分数函数。扩散模型的高质量样本可以通过模拟由学习到的分数网络实现的随机微分方程来生成 [66]。然而,随机微分方程的模拟明显慢于其他模型(如单步生成器模型)的模拟。

3 分数隐式匹配

在本节中,我们介绍分数隐式匹配(Score Implicit Matching,SIM),这是一种专为基于分数的扩散模型单步蒸馏设计的通用方法。我们首先介绍问题设置和符号表示,然后引入一类通用的基于分数的概率散度,并展示如何使用SIM来最小化所述散度。最后,我们讨论该方法的具体选择(如距离函数的选择),并探究其对蒸馏效果的影响。

3.1 问题设置

我们的起点是一个由分数函数定义的预训练扩散模型:

其中,$q_t(x_t)$ 是根据公式(2.1)在时间 $t$ 扩散的潜在分布。我们假设预训练扩散模型能充分近似数据分布,因此是我们方法的唯一考虑对象。

目标学生模型是单步生成器网络 $g_\theta$,它可以将初始随机噪声 $z \sim p_z$ 转换为样本 $x = g_\theta(z)$,该网络由参数 $\theta$ parameterized。令 $p_{\theta, 0}$ 表示学生模型的数据分布,$p_{\theta, t}$ 表示学生模型在相同扩散过程(2.1)下的边缘扩散数据分布。学生分布隐式诱导出分数函数:

对其进行评估通常需要训练一个替代分数网络,如后文所述。

3.2 通用基于分数的散度

单步扩散蒸馏的目标是让学生分布 $p_{\theta, 0}$ 匹配数据分布 $q_0$。为此,我们提出在所有扩散时间水平上匹配扩散边缘分布 $p_{\theta, t}$ 和 $q_t$。我们可以通过以下通用的基于分数的散度来定义这一目标:假设 $d: \mathbb{R}^d \to \mathbb{R}$ 是一个标量值的适当距离函数(即满足 $d(x) \geq 0$ 且当且仅当 $x=0$ 时 $d(x)=0$)。给定一个采样分布 $\pi_t$(其分布支撑集大于 $p_t$ 和 $q_t$),我们可以正式定义时间积分分数散度为:

其中,$p_t$ 和 $q_t$ 分别表示以 $p$ 和 $q$ 为初始值的扩散过程(2.1)在时间 $t$ 的边缘密度,$w(t)$ 是积分加权函数。显然,当且仅当所有边缘分数函数一致时,$\mathcal{D}^{[0, T]}(p, q) = 0$,这意味着 $p_0(x_t) = q_0(x_t)$ 几乎处处成立(关于 $\pi_0$)。

3.3 分数隐式匹配

基于上述动机,我们希望最小化 $p_\theta$ 和 $q$ 之间的积分分数散度,以训练学生模型,即:

其中假设分布 $\pi_t$ 不依赖于参数 $\theta$,例如 $\psi_t(x_t) = p_{sg[\theta]}(x_t)$($sg[\theta]$ 表示截断 $\theta$ 参数依赖的停止梯度算子)。对 $\theta$ 求梯度可得:

其中 $d’$ 表示 $d$ 对输入的导数,即 $\nabla_y d(y)$。不幸的是,由于分数函数难以处理,直接计算 $\frac{\partial}{\partial \theta} s_{p_{\theta, t}(x_t)}$ 是不可能的,这使得直接方法不切实际。

幸运的是,本文的一个关键发现是:如果我们将采样分布选择为扩散隐式分布,即 $\pi_t = p_{sg[\theta]}$(其中 $sg[\theta]$ 表示截断 $\theta$ 参数依赖的停止梯度算子),则损失函数(3.4)及其难以处理的梯度(3.5)可以通过一个梯度等价的损失高效最小化。这依赖于定理3.1:

定理3.1(分数散度梯度定理):若分布 $p_{\theta, t}$ 满足某些温和的正则条件,则对于任意分数函数 $s_{q_t}(.)$,对所有参数 $\theta$ 成立:

这里的关键观察是:我们用右侧更易计算的分数函数评估替换了左侧分数函数的难处理梯度,后者可以通过一个单独的近似网络更轻松地完成。该定理可通过分数投影恒等式 [69, 92] 证明,该恒等式最初用于桥接去噪分数匹配和去噪自编码器。然而,证明定理3.1的关键在于通过适当停止定理中所示的梯度,合理选择 $\theta$ 参数的(非)依赖性。详细证明见附录A.1。

现在,我们可以给出用于训练隐式生成器 $g_\theta$ 的目标函数。(3.6)的直接结果是,梯度(3.5)可以通过最小化一个可处理的损失函数来实现:

其中 $y_t := s_{p_{sg[\theta], t}}(x_t) - s_{q_t}(x_t)$。根据定理3.1,这个替代损失与原始损失的梯度相同,且无需访问分数网络的梯度。

在实践中,我们可以使用另一个在线扩散模型 $s_\psi(x_t, t)$ 逐点近似生成器模型的分数函数 $s_{p_{sg[\theta], t}}(x_t)$,这与之前的工作(如Luo等人 [42]、Zhou等人 [92] 和Yin等人 [81])一致。我们将最小化(3.7)中目标函数 $\mathcal{L}_{SIM}(\theta)$ 的蒸馏方法称为分数隐式匹配(SIM),因为学习过程隐式地将隐式学生模型的难处理边缘分数函数 $s_{p_{\theta, t}}(.)$ 与预训练扩散模型的显式分数函数 $s_{q_t}(.)$ 进行匹配。

SIM的完整算法如算法1所示,该算法通过两个交替阶段训练学生模型:学习边缘分数函数 $s_\psi$,以及使用梯度(3.7)更新生成器模型。前一阶段遵循标准的扩散模型学习流程,即最小化去噪分数匹配损失函数(2.2),仅略微调整为从生成器生成样本。得到的 $s_\psi(x_t, t)$ 为 $s_{p_{sg[\theta], t}}(x_t)$ 提供了良好的逐点估计。后一阶段通过最小化损失函数(3.7)更新生成器参数 $\theta$,其中两个所需函数由预训练扩散模型 $s_{q_t}(x_t)$ 和学习的扩散模型 $s_\psi(x_t, t)$ 提供。

3.4 分数隐式匹配的实例

前一节介绍了SIM算法,但未选择特定的距离函数 $d(.)$。这里我们讨论不同的选择及其对蒸馏过程的影响,并表明在SIM框架中,SiD可视为一个特例。

距离函数 $d(.)$ 的设计选择:显然,不同的距离函数 $d(.)$ 会导致不同的蒸馏算法。最自然的选择可能是简单的平方距离,即 $d(y_t) = | y_t |_2^2$,其导数项为 $d’(y_t) = 2y_t$。事实上,这种损失函数重新得到了SiD [92] 中研究的delta损失,其中作者通过实验发现该损失函数效果良好(尽管推导方式截然不同)。因此,SiD实际上是SIM的一个特例,尽管SiD的推导并未暗示如何使用其他损失函数。二次形式的直接推广是 $\alpha$-范数的 $\alpha$ 次幂($\alpha > 1$ 且为偶数),此时距离函数为 $d(y_t) = \alpha y_t^{(\alpha-1)}$,对应的损失函数总结在附录A.3的表4中。

Pseudo-Huber距离函数:不同于幂范数,我们引入带Pseudo-Huber距离函数的SIM,其定义为 $d(y) := \sqrt{| y_t |_2^2 + c^2} - c$,其中 $c$ 是预定义的正常数。对应的蒸馏目标为:

除非另有说明,本文其余部分将Pseudo-Huber距离作为默认选择。由于篇幅限制,我们在表4中总结了不同距离函数的选择及其对应的损失函数和推导,并在附录A.3中进行了更多讨论。特别地,与SiD(表4中的 $L^2$ 情况)不同,在SIM中使用Pseudo-Huber距离时,我们观察到向量 $y_t$ 通过除以向量的平方根自然自适应归一化。这种归一化可以稳定训练损失,从而实现鲁棒且快速收敛的蒸馏过程。在4.1节中,我们通过实验展示了三个优势:对大学习率的鲁棒性、快速收敛性和改进的性能。

3.5 相关工作

扩散蒸馏 [40] 是一个旨在利用教师扩散模型降低生成成本的研究领域,主要包括三种蒸馏方法:

  1. 轨迹蒸馏:该方法训练学生模型以更少的步骤模拟扩散模型的生成过程。直接蒸馏([38, 14])和渐进蒸馏([60, 47])变体从噪声输入预测更少噪声的数据;基于一致性的方法([67, 28, 65, 35, 16])最小化自一致性度量,这些方法需要真实数据样本进行训练。
  2. 分布匹配:专注于使学生的生成分布与教师扩散模型的分布对齐。其中包括需要真实数据来蒸馏扩散模型的对抗训练方法([75, 76]),以及另一类重要方法——尝试最小化KL散度([81])(如Diff-Instruct (DI) [44, 81])和Fisher散度(如Score Identity Distillation (SiD) [92]),通常无需真实数据。尽管SIM从SiD和DI中获得启发,但与它们的差距显著:SIM不仅提供了坚实的数学基础(可能有助于深入理解扩散蒸馏),还提供了使用不同距离函数的灵活性,当使用特定的Pseudo-Huber距离时,可实现强大的实证性能。
  3. 其他方法:算子学习([85])和ReFlow([36])为蒸馏提供了替代见解。此外,许多工作致力于将扩散蒸馏扩展到单步文本到图像生成等领域[39, 49, 68, 81, 91]。

4 实验

4.1 单步CIFAR10生成

实验设置:在本实验中,我们将SIM应用于CIFAR10[32]数据集,将预训练的EDM[25]扩散模型蒸馏为单步生成器模型。我们遵循与DI[42]和SiD[92]相同的设置,将扩散模型蒸馏为单步生成器,具体细节见附录B.2。我们参考SiD[92]的高质量代码库,通过在我们的设备上严格遵循其配置来复现结果,同时也在相同实验设置下重新实现了DI。

性能表现:我们通过Frechet Inception Distance(FID)[19]评估训练生成器的性能,FID值越低越好。我们参考[42]中的评估协议进行比较。表1和表2总结了CIFAR10数据集上生成模型的FID。我们在与SIM相同的计算环境和评估协议下复现了SiD和DI,以进行公平比较。表的上半部分模型与EDM模型架构或扩散模型不同,而下半部分模型与教师EDM扩散模型架构完全相同,因此可直接比较。

方法 NFE(↓) FID(↓)
与EDM模型架构不同
DDPM [20] 1000 3.17
DD-GAN(T=2) [75] 2 4.08
KD [38] 1 9.36
TDPM [89] 1 8.91
DFNO [87] 1 4.12
3-REFLOW(+DISTILL) [36] 1 5.21
STYLEGAN2-ADA [23] 1 2.92
STYLEGAN2-ADA+DI [42] 1 2.71
与EDM[25]模型架构相同
EDM [25] 35 1.97
EDM [25] 15 5.62
PD [60] 2 5.13
CD [67] 2 2.93
GET [14] 1 6.91
CT [67] 1 8.70
ICT-DEEP [65] 2 2.24
DIFF-INSTRUCT [42] 1 4.53
DMD [81] 1 3.77
CTM [28] 1 1.98
CTM[28] 2 1.87
SID(α=1.0) [92] 1 1.92
SIDα=1.2 1 2.02
DI † 1 3.70
SID †(α=1.0) 1 2.20
SIM(OURS) 1 2.06

表1:CIFAR10无条件样本质量。†表示我们复现的方法。

方法 NFE(↓) FID(↓)
与EDM模型架构不同
BIG GAN [3] 1 14.73
BIG GAN+TUNE [3] 1 8.47
STYLE GAN2 [24] 1 6.96
MULTI HINGE [26] 1 6.40
FQ-GAN [86] 1 5.59
STYLE GAN2-ADA [23] 1 2.42
STYLE GAN2-ADA+DI [42] 1 2.27
STYLE GAN2 + SMART [74] 1 2.06
STYLE GAN-XL [62] 1 1.85
与EDM[25]模型架构相同
EDM [25] 35 1.82
EDM [25] 20 2.54
EDM [25] 10 15.56
EDM [25] 1 314.81
GET [14] 1 6.25
DIFF-INSTRUCT [42] 1 4.19
DMD(W.O.REG) [81] 1 5.58
DMD(W.O.KL) [81] 1 3.82
DMD [81] 1 2.66
CTM [28] 1 1.73
CTM[28] 2 1.63
SID(α=1.0) [92] 1 1.93
SIDα=1.2 1 1.71
SID †(α=1.0) 1 2.34
SIM(OURS) 1 1.96

表2:CIFAR10数据集类条件样本质量。†表示我们复现的方法。

如表1所示,在CIFAR10无条件生成任务中,所提出的SIM仅用单步生成就实现了2.06的FID,在相同评估设置下优于SiD和DI,性能与CTM相当,且SiD的官方实现尚未发布。对于表2中的CIFAR10类条件生成,SIM达到1.96的FID,表现处于顶级模型之列。

SIM蒸馏的T2I生成器优于其他工业级模型:CIFAR-10生成任务相对简单,仅在有限容量的扩散模型和简单数据集上进行。我们将在文本到图像生成任务中对顶级基于Transformer的扩散模型进行蒸馏实验,展示单步模型的能力。在此之前,我们先深入了解SIM在CIFAR-10上相比SiD和DI的优势——对大学习率的鲁棒性和更快的收敛速度,这将为蒸馏方法如何扩展到具有更大神经网络的复杂任务提供启示。

对大学习率的鲁棒性:我们在相同设置下应用SIM、SiD和DI,从EDM蒸馏CIFAR10无条件生成任务,学习率为1e-4,并在图2中绘制FID和Inception Score[61]。DI和SiD即使在训练早期也不稳定,而SIM即使在大学习率下也能稳定收敛。潜在原因是SIM自然地对损失目标进行归一化,使其在训练过程中规模不会突然变化。这在训练大型模型时使SIM区别于SiD,因为训练现代大型模型成本高昂,研究人员在预算内很少有机会调整超参数。

快速收敛:SIM的第二个优势是比SiD收敛更快。为证明这一点,我们在CIFAR10无条件生成上遵循与SiD相同的设置。如图2所示,在所有配置下,SIM在相同训练迭代次数下始终表现出更好的FID和Inception Score。由于篇幅限制,更多细节见附录B.2。

CIFAR10生成实验表明,SIM是一种强大、鲁棒且收敛迅速的单步扩散蒸馏算法。然而,SIM的能力不仅限于CIFAR-10基准测试。在4.2节中,我们将SIM应用于蒸馏基于0.6B DiT[52]的文本到图像扩散模型,获得最先进的基于Transformer的单步生成器。

4.2 基于Transformer的单步文本到图像生成器

实验设置:近年来,基于Transformer的文本到X生成模型在图像生成(如Stable Diffusion V3[11])和视频生成(如Sora[5])等领域备受关注。在本节中,我们将SIM应用于蒸馏近期备受关注的开源基于DiT的扩散模型之一:0.6B PixelArt-α模型[6],其基于DiT模型[52]构建,在定量评估指标和主观用户研究方面均成为最先进的单步生成器。

实验设置与评估指标:单步蒸馏的目标是将扩散模型加速为单步生成,同时保持甚至超越教师扩散模型的性能。为验证我们的单步模型与扩散模型之间的性能差距,我们比较四个定量指标:美学分数、PickScore、图像奖励和用户研究比较分数。在SAM-LLaVA-Caption10M(原始PixelArt-α模型训练的数据集之一)上,我们比较SIM单步模型(称为SIM-DiT-600M)和使用14步DPM-Solver[37]的PixelArt-α模型,评估数据内性能差距。我们还在广泛使用的COCO-2017验证集上,将SIM-DiT-600M和PixelArt-α与其他少步模型(如LCM[39]、TCD[90]、PeReflow[80]和Hyper-SD[56]系列)进行比较,并参考Hyper-SD的评估协议计算评估指标。表3总结了所有模型的评估性能。对于针对PixArt-α和SIM-DiT-600M的人类偏好研究,我们从SAM Caption数据集随机选择17个提示,用PixArt-α和SIM-DiT-600M生成图像,然后让参与研究的用户根据图像质量和与提示的一致性选择偏好,图1展示了用户研究案例的可视化结果,其中难以区分PixArt-α和SIM-DiT-600M生成的图像。

模型 步数 类型 参数 美学分数 图像奖励 Pick分数 用户偏好 蒸馏成本
SD15-BASE [57] 25 UNET 860M 5.26 0.18 0.217
SD15-LCM [39] 4 UNET 860M 5.66 -0.37 0.212 8 A100×4天
SD15-TCD [90] 4 UNET 860M 5.45 -0.15 0.214 8 A800×5.8天
PERFLOW [80] 4 UNET 860M 5.64 -0.35 0.208 M GPU×N天
HYPER-SD15[56] 1 UNET 860M 5.79 0.29 0.215 32 A100×N天
SDXL-BASE [57] 25 UNET 2.6B 5.54 0.87 0.229
SDXL-LCM [39] 4 UNET 2.6B 5.42 0.48 0.224 8 A100×4天
SDXL-TCD [90] 4 UNET 2.6B 5.42 0.67 0.226 8 A800×5.8天
SDXL-LIGHTNING [34] 4 UNET 2.6B 5.63 0.72 0.229 64 A100×N天
HYPER-SDXL[56] 4 UNET 2.6B 5.74 0.93 0.232 32 A100×N天
SDXL-TURBO [63] 1 UNET 2.6B 5.33 0.78 0.228 M GPU×N天
SDXL-LIGHTNING [34] 1 UNET 2.6B 5.34 0.54 0.223 64 A100×N天
HYPER-SDXL[56] 1 UNET 2.6B 5.85 1.19 0.231 32 A100×N天
PIXART-α [6] 30 DiT 610M 5.97 0.82 0.226
SIM-DiT-600M 1 DiT 610M 6.42 0.67 0.223 4 A100×2天
PIXART-α ∗ [6] 30 DiT 610M 5.93 0.53 0.223 54.88%
SIM-DiT-600M ∗ 1 DiT 610M 5.91 0.44 0.223 45.12% 4 A100×2天

表3:在COCO-2017验证集上与前沿文本到图像模型的定量比较。用户偏好是我们的用户研究中SIM-DiT-600M相对于20步PixelArt-α的胜率。∗表示在SAM-LLaVA-Caption10M数据集上评估的结果,SIM-DiT-600M指从PixelArt-α-600M蒸馏的SIM生成器,不包括T5文本编码器。蒸馏成本M GPU×N天表示模型未报告成本。

近乎无损的单步蒸馏:令人惊讶的是,SIM-DiT-600M与教师扩散模型相比几乎没有性能损失。例如,在表3的SAM Caption数据集上,SIM-DiT-600M恢复了PixelArt-α模型99.6%的美学分数和100%的PickScore,但图像奖励略低,这可能通过更多训练计算进一步优化。与领先的少步文本到图像模型(如SDXL-Turbo、SDXL-lightning和Hyper-SDXL)相比,SIM-DiT-600M以显著优势展现出主导的美学分数,同时具有不错的图像奖励和Pick分数。

除了顶尖性能,SIM-DiT-600M的训练成本也低得惊人。我们的最佳模型使用4个A100-80G GPU训练2天(无数据),而表3中的其他模型需要数百个A100 GPU天。我们在表3中总结了蒸馏成本,表明SIM是一种具有惊人扩展能力的超高效蒸馏方法。我们认为这种效率来自SIM的两个特性:首先,SIM是无数据的,使蒸馏过程无需真实图像数据;其次,Pseudo-Huber距离函数(3.3)的使用自适应地归一化损失函数,使其对超参数具有鲁棒性且训练稳定。

定性比较:图3将SIM-DiT-600M与其他领先的少步文本到图像生成模型进行了定性比较。显然,SIM-DiT-600M生成的图像美学性能高于其他模型,这与表3中SIM-DiT-600M达到高美学分数的定量结果一致。定量和定性结果均表明SIM-DiT-600M是性能最佳的单步文本到图像生成器,更多定性评估见补充材料。

单步SIM-DiT模型的失败案例:尽管SIM-DiT单步模型表现出色,但不可避免存在局限性。例如,我们发现0.6B的SIM-DiT单步模型有时难以生成高质量的微小人脸和正确的手臂手指,还可能生成物体数量错误或不完全符合提示的内容。我们认为扩大模型规模和教师扩散模型将有助于解决这些问题,失败案例的可视化见图4。

5 结论与未来工作

本文提出了一种新颖的扩散蒸馏方法——分数隐式匹配(SIM),该方法能够以无数据的方式将预训练的多步扩散模型转换为单步生成器。本文所介绍的理论基础和实用算法,使得单步生成器能够在各种领域和大规模应用中以更经济的方式部署,同时不影响基础生成模型的性能。

尽管如此,SIM仍存在一些局限性,需要进一步研究:首先,随着其他强大的预训练生成模型(如流匹配模型)的不断涌现,值得探索是否有可能将SIM的应用扩展到更广泛的生成模型家族。其次,尽管无数据是SIM的一个重要特性,但在SIM中引入新数据可以进一步提升教师模型生成失败图像的质量,这一潜在优势尚未被探索,我们希望这能简化大型生成模型的训练过程。

致谢

Zhengyang Geng 得到了博世人工智能中心的资助。Zico Kolter 衷心感谢博世对实验室的资助。

我们感谢 NeurIPS 2024 的审稿人及 AC/SAC/PC 成员提出的建设性建议。同时感谢 Diff-Instruct 和 Score-identity Distillation 的作者们为高质量扩散蒸馏 Python 代码所做的巨大贡献,也感谢 PixelArt-α 的作者们公开其基于 DiT 的扩散模型。

A 理论部分

A.1 定理3.1的证明

定理3.1的证明基于所谓的分数投影恒等式,该恒等式最初由Vincent[69]提出,用于连接去噪分数匹配和去噪自编码器。后来,Zhou等人[92]将该恒等式应用于推导基于Fisher散度的蒸馏方法。感谢Zhou等人[92]的努力,我们在此重述分数投影恒等式而不加以证明。读者可以参考Zhou等人[92]以获取分数投影恒等式的完整证明。

定理A.1(分数投影恒等式):设$u(\cdot, \theta)$是一个向量值函数,使用定理3.1的符号,在温和条件下,以下恒等式成立:

接下来,我们开始证明定理3.1。

证明:我们证明一个更一般的结果。设$u(\cdot)$是一个向量值函数,所谓的分数投影恒等式[92,69]成立:

对恒等式(A.1)两边关于$\theta$求梯度,我们有:

因此,我们得到以下恒等式:

该恒等式对于任意函数$u(\cdot, \theta)$和参数$\theta$都成立。如果我们令

那么我们形式上有:

A.2 分数隐式匹配的PyTorch风格伪代码

在本节中,我们给出算法1的PyTorch风格伪代码,使用Pseudo-Huber距离函数。关于CIFAR10与EDM模型的详细算法,请参见算法2。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.optim as optim

# 初始化生成器G
G = Generator()
# 加载教师扩散模型
Sd = DiffusionModel().load('/path_to_ckpt').eval().requires_grad_(False)
Sg = copy.deepcopy(Sd) # 用教师扩散模型初始化在线扩散模型

# 定义优化器
opt_G = optim.Adam(G.parameters(), lr=0.001, betas=(0.0, 0.999))
opt_Sg = optim.Adam(Sg.parameters(), lr=0.001, betas=(0.0, 0.999))

# 训练循环
while True:
# 更新Sg
Sg.train().requires_grad_(True)
G.eval().requires_grad_(False)

# 循环2次以更新Sg
for _ in range(2):
z = torch.randn((2000, 2)).to(device)
with torch.no_grad():
fake_x = G(z)

t = torch.from_numpy(np.random.choice(np.arange(1, Sd.T), size=fake_x.shape[0], replace=True)).to(device).long()
fake_xt, t, noise, sigma_t, g2_t = Sd(fake_x, t=t, return_t=True)
sigma_t = sigma_t.view(-1, 1).to(device)
g2_t = g2_t.to(device)
score = Sg(torch.cat([fake_xt, t.view(-1, 1) / Sd.T], -1)) / sigma_t

batch_sg_loss = score + noise / sigma_t
batch_sg_loss = (g2_t * batch_sg_loss.square().sum(-1)).mean() * Sd.T

optimizer_Sg.zero_grad()
batch_sg_loss.backward()
optimizer_Sg.step()

# 更新G
Sg.eval().requires_grad_(False)
G.train().requires_grad_(True)

z = torch.randn((2000, 2)).to(device)
fake_x = G(z)

t = torch.from_numpy(np.random.choice(np.arange(1, diffusion.T), size=fake_x.shape[0], replace=True)).to(device).long()
fake_xt, t, noise, sigma_t, g2_t = diffusion(fake_x, t=t, return_t=True)
sigma_t = sigma_t.view(-1, 1).to(device)
g2_t = g2_t.to(device)

score_true = Sd(torch.cat([fake_xt, t.view(-1, 1) / diffusion.T], -1)) / sigma_t
score_fake = Sg(torch.cat([fake_xt, t.view(-1, 1) / diffusion.T], -1)) / sigma_t

score_diff = score_true - score_fake

offset_coeff = denoise_diff / torch.sqrt(denoise_diff.square().sum([1, 2, 3], keepdims=True) + self.phuber_c ** 2)
weight = 1.0

batch_g_loss = weight * offset_coeff * (fake_denoise - images)
batch_g_loss = batch_g_loss.sum([1, 2, 3]).mean()

optimizer_G.zero_grad()
batch_g_loss.backward()
optimizer_G.step()

A.3 不同距离函数下的SIM实例

在3.3节中,我们讨论了将幂范数作为距离函数。对于其他选择,如Huber距离,其定义为:

对于其他距离函数选择,如$L_1$范数和带幂范数的指数函数,我们将其列在表4中。

B 实证部分

B.1 人类偏好研究答案

图1中人类偏好研究的答案如下:

  • 第一行中间的图像由单步SIM-DiT-600M生成;
  • 第二行最左侧的图像由单步SIM-DiT-600M生成;
  • 第三行最左侧的图像由单步SIM-DiT-600M生成。

B.2 CIFAR10数据集实验细节

我们遵循SiD和DI在CIFAR10上的实验设置。首先简要介绍EDM模型[25]。

EDM模型依赖于如下扩散过程:

前向过程(B.1)的样本可通过向生成器函数的输出添加随机噪声生成,即$x_t = x_0 + t\epsilon$,其中$\epsilon \sim N(0, I)$是高斯向量。EDM模型还将扩散模型的分数匹配目标重新表述为去噪回归目标,表达式为:

其中$d_\psi(\cdot)$是一个去噪器网络,试图通过输入噪声样本预测干净样本。最小化损失(B.2)可得到训练好的去噪器,它与边缘分数函数有简单关系:

在这种表述下,我们实际上有用于实验的预训练去噪器模型。因此,后续部分将使用EDM符号。

单步生成器的构建:设$d_\theta(\cdot)$为预训练EDM去噪器模型。由于EDM模型的去噪器表述,我们构建的生成器与预训练EDM去噪器具有相同架构,并带有预先选择的索引$t^_$,表达式为:

我们使用教师EDM去噪器模型的相同参数初始化生成器。

时间索引分布:训练EDM扩散模型和生成器时,需要随机选择时间$t$以近似损失函数(B.2)的积分。EDM模型训练扩散(去噪器)模型时,$t$的默认分布为对数正态分布,即:

以及加权函数:

在我们的算法中,更新在线扩散(去噪器)模型时遵循与EDM模型相同的设置。

在SiD中,他们提出使用一种特殊的离散时间分布,表达式为:

他们提出从以下分布中均匀选择$t$:

我们在图2中将这种时间分布称为Karr分布,因为这种调度最初是在Karras的EDM工作中为采样提出的。

然而,在实践中,我们发现Karr分布(B.8)实证效果并不好。相反,我们发现使用修改后的对数正态时间分布更新SIM的生成器时,效果比Karr分布更好。我们的SIM时间分布表达式为:

加权函数:如前所述,更新去噪器模型时,我们使用与EDM相同的$\lambda_{EDM}(t)$(B.7)加权函数。更新生成器时,SiD使用一种特殊设计的加权函数,表达式为:

符号sg表示停止梯度,$C$是数据维度。他们声称这种加权函数有助于稳定训练。然而,在我们的实验中,由于SIM本身已经对损失进行了归一化(见第4节),我们没有使用这种特定的加权函数,而是将所有时间的加权函数都设为1。我们在图2中将SiD的加权函数称为sidwgt,将我们的加权函数称为nowgt。

在图2中,我们比较了使用不同时间分布和加权函数的SiD和SIM。发现SIM+nowgt+对数正态时间分布的性能明显更好,因此我们的最终实验采用这种配置。表5记录了我们在CIFAR10 EDM蒸馏上使用SIM的详细配置。

超参数 CIFAR-10(无条件) CIFAR-10(有条件)
DM $s_\psi$ 生成器 $g_\theta$ DM $s_\psi$ 生成器 $g_\theta$
学习率 1e-5 1e-5 1e-5 1e-5
批大小 256 256 256 256
$\sigma(t^*)$ 2.5 2.5 2.5 2.5
Adam $\beta_0$ 0.0 0.0 0.0 0.0
Adam $\beta_1$ 0.999 0.999 0.999 0.999
时间分布 $p_{EDM}(t)$(B.5) $p_{SIM}(t)$(B.9) $p_{EDM}(t)$(B.5) $p_{SIM}(t)$(B.9)
加权 $\lambda_{EDM}(t)$(B.7) 1 $\lambda_{EDM}(t)$(B.7) 1
损失函数 (B.2) (B.13) (B.2) (B.13)
GPU数量 4 A100-40G 4 A100-40G 4 A100-40G 4 A100-40G

表5:CIFAR10 EDM蒸馏中SIM使用的超参数

在最优设置和EDM表述下,我们可以在算法2中以EDM风格重写我们的算法。

B.3 文本到图像蒸馏实验细节

在文本到图像蒸馏部分,为了与CIFAR10上的实验保持一致,我们用EDM表述重写PixArt-α模型:

这里,遵循EDM中的iDDPM+DDIM预处理,PixArt-α用$F_\theta$表示,$x_c$是带有标准差为$t$的噪声的图像数据,对于其余参数如$C_1$和$C_2$,我们保持不变以匹配EDM中的定义。与原始模型不同,我们只保留了该模型输出的图像通道。由于我们在EDM中采用了iDDPM+DDIM的预处理,每个σ值传入模型后会被四舍五入到最接近的1000个区间。对于PixArt-α中使用的实际值,beta_start设为0.0001,beta_end设为0.02。因此,根据EDM的表述,我们的噪声分布范围是[0.01, 156.6155],这将用于截断我们采样的$t$。我们的单步生成器表述为:

这里遵循SiD,$t^* = 2.5$且$z \sim N(0, (t^_)^2 I)$,我们在实践中观察到,较大的$t^_$值会导致模型收敛更快,但对于完整的模型训练过程,收敛速度的差异可以忽略不计,对最终结果的影响也很小。

算法2:用于蒸馏EDM教师的带Pseudo-Huber距离的SIM(PyTorch风格)
输入:预训练EDM去噪器$d_{q_t}(.)$、生成器$g_\theta$、先验分布$p_z$、在线EDM去噪器$d_\psi(.)$;可微距离函数$d(.)$和前向扩散(2.1)。
while 未收敛 do
// 冻结$\theta$,更新$\psi$:
// $t \sim p_{SIM}(t)$,$x_t = x_0 + t\epsilon$,$\epsilon \sim N(0, I)$
$L(\psi) = \lambda_{EDM}(t) \times | d_\psi(x_t, t) - x_0 |_2^2$
$x_0 = g_\theta(z).detach()$,$z \sim p_z$
$t \sim p_{EDM}(t)$,$x_t = x_0 + t\epsilon$,$\epsilon \sim N(0, I)$
$L(\psi).backward()$;更新$\psi$
$x_0 = g_\theta(z)$,$z \sim p_z$
// 冻结$\psi$,更新$\theta$:

,其中$y_t := d_\psi(x_t, t) - d_{q_t}(x_t)$ (B.13)
$L(\theta).backward()$;更新$\theta$
end
return $\theta$,$\psi$。

我们使用了SAM-LLaVA-Caption10M数据集,该数据集包含由LLaVA模型在SAM数据集上生成的提示。这些提示为图像提供了详细描述,从而为我们的蒸馏实验提供了具有挑战性的样本集。

本节所有实验均在4个A100-40G GPU上进行,采用bfloat16精度,使用PixArt-XL-2-512x512模型版本,并采用相同的超参数。两个优化器都使用Adam,学习率为5e-6,betas=[0, 0.999]。此外,为了实现1024的批大小,我们采用了梯度检查点,并将梯度累积设为8。最后,关于训练噪声分布,我们没有遵循原始的iDDPM调度,而是从均值为-2.0、标准差为2.0的对数正态分布中采样σ,我们对两个优化步骤使用相同的噪声分布,并将两个损失权重设为常数1。我们最好的模型在SAM Caption数据集上训练了约16k次迭代,相当于不到2个epoch。这个训练过程在4个A100-40G GPU上花费了大约2天时间。

我们还测试了不同噪声分布对蒸馏过程的影响。当噪声分布高度集中在较小值附近时,我们观察到生成的样本出现过暗现象。另一方面,当我们使用稍大的噪声分布时,发现生成样本的结构往往不稳定。

B.4 人类偏好研究说明

我们的用户研究主要关注蒸馏模型和教师模型的输出比较。每张图像都经过严格的人工审核,以确保调查参与者的安全。我们通过问卷进行研究,向用户展示由蒸馏模型和教师模型生成的两张随机排序的图像,让他们选择与文本描述最匹配且图像质量更高的样本。最后,我们将收集到的对蒸馏模型和教师模型的投票作为用户偏好的指标。用于进行这些评估的问卷网站如图5所示。

具体来说,我们随机选择了17个提示词,使用学生模型和教师模型生成512x512分辨率的图像。为了便于比较,我们将两张图像随机排序并排展示。在问卷中,除了生成的图像外,我们还提供了完整的提示词供参考。最终,我们总共收集了约30份调查回复。

B.5 CIFAR10上的生成样本

B.6 CIFAR10无条件生成的FID收敛

B.7 图3的提示词

  • 图3第一行提示词:撒哈拉沙漠中一株带着笑脸的小仙人掌。
  • 图3第二行提示词:一张翡翠绿和金色的法贝热彩蛋图像,16k分辨率,细节丰富,产品摄影,在ArtStation上流行,焦点清晰,工作室照片,复杂细节,背景较暗,完美光线,完美构图,清晰特征,Miki Asai微距摄影,特写,超细节,在ArtStation上流行,焦点清晰,工作室照片,复杂细节,细节丰富,由Greg Rutkowski创作。
  • 图3第三行提示词:婴儿在雪地里玩玩具。