全文翻译
摘要
扩散模型作为一类新型生成模型近年来备受关注。尽管取得了成功,但这类模型存在一个显著缺陷——采样速度缓慢,需要进行数百甚至数千次函数评估(NFE)。为此,研究人员探索了无学习(learning-free)和有学习(learning-based)两类采样策略来加速采样过程。无学习采样基于扩散常微分方程(ODE)的公式表述,采用各种常微分方程求解器。然而,该方法在准确追踪真实采样轨迹方面面临挑战,尤其是在函数评估次数较少的情况下。相反,基于知识蒸馏等有学习采样方法需要大量额外训练,限制了其实用性。为克服这些局限性,我们提出了蒸馏型常微分方程求解器(D-ODE求解器),这是一种基于常微分方程求解器公式表述的简洁蒸馏方法。该方法无缝融合了无学习采样和有学习采样的优势。
D-ODE求解器通过对现有常微分方程求解器进行单一参数调整构建而成。此外,我们利用知识蒸馏技术,从大步数常微分方程求解器中提取知识,优化小步数D-ODE求解器,并在一批样本上完成这一过程。综合实验表明,与现有常微分方程求解器(包括DDIM、PNDM、DPM-Solver、DEIS和EDM)相比,D-ODE求解器性能更优,尤其在函数评估次数较少的场景中表现突出。值得注意的是,与以往蒸馏技术相比,我们的方法计算开销可忽略不计,便于与现有采样器快速集成。定性分析表明,D-ODE求解器不仅能提升图像质量,还能忠实遵循目标常微分方程轨迹。
1. 引言
扩散模型[15, 44, 46]作为一种极具吸引力的生成模型框架近年来逐渐兴起,在众多应用中展现出最先进的性能,例如图像生成[6, 47]、文本生成[2, 17]、音频生成[27, 32]、3D形状生成[5, 30]、视频合成[12, 56]以及图生成[35, 50]。
尽管扩散模型在生成高质量样本以及缓解模式崩溃[43, 59]等问题上表现出色,但它们的采样过程通常需要大量的网络评估,导致该过程缓慢且计算密集[54]。近期研究聚焦于优化采样过程,旨在在不牺牲样本质量的前提下提高效率[19, 42, 45]。值得注意的是,针对扩散模型中采样效率提升的方法主要分为两类:无学习采样和有学习采样[55]。
无学习采样可应用于预训练扩散模型,无需额外训练,且通常涉及随机微分方程(SDE)或常微分方程(ODE)的高效求解器[47]。典型示例包括DDIM[45],它采用非马尔可夫过程实现加速采样;以及PNDM[24],其引入伪数值方法在给定数据流形上求解微分方程。EDM[19]利用Heun二阶方法,相比朴素的欧拉方法[47]展现出更优的采样质量。最近,DPM-Solver[25]和DEIS[58]利用扩散ODE的半线性结构,并采用指数积分器的数值方法。这些ODE求解器旨在沿着ODE采样轨迹中数据分布密度较高的区域准确估计得分函数[24, 60]。然而,ODE求解器的采样路径可能偏离真实轨迹,尤其是在去噪步数较少时,会导致得分函数出现显著的拟合误差[23, 48, 54]。
相反,有学习采样需要额外训练以优化特定目标,例如知识蒸馏[42, 48]和优化离散化[33, 52]。例如,渐进式蒸馏[42]通过迭代方式将预训练扩散模型蒸馏到需要更少采样步数的学生模型中。最近,Song等人[48]提出了一致性模型,该模型经过训练可在同一ODE轨迹上预测一致的输出。尽管基于蒸馏的技术能够在几步之内完成生成,但需要大量训练来使扩散模型的去噪网络适应每个数据集、采样器和网络。
为解决这些挑战,我们提出了一种用于扩散模型的新型蒸馏方法,称为蒸馏型ODE求解器(D-ODE求解器),该方法利用现有ODE求解器中固有的采样动态特性。D-ODE求解器弥合了无学习采样和有学习采样之间的差距,同时减轻了相关问题。我们的方法基于以下观察:去噪网络的输出(即去噪输出)在相邻时间步之间表现出高度相关性。
D-ODE求解器通过向现有ODE求解器引入单个额外参数,将当前去噪网络输出与先前输出进行线性组合。这使得在每个时间步t能够更准确地估计去噪输出。对于高阶求解器(例如PNDM、DEIS和DPM-Solver),我们对其高阶估计进行线性组合,以利用它们逼近真实得分函数的能力。通过最小化小步数D-ODE求解器(学生)与大步数ODE求解器(教师)的输出差异,为每个数据集优化额外参数。一旦确定最优参数,D-ODE求解器可在采样过程中跨批次重用,同时保持去噪网络冻结。值得注意的是,D-ODE求解器持续改进了先前ODE求解器(包括一阶和高阶方法)的FID值,显著减少了蒸馏的计算时间。我们的主要贡献总结如下:
- 我们引入了D-ODE求解器,通过简洁的公式表述将知识从大步数ODE求解器转移到小步数ODE求解器中。
- D-ODE求解器减少了对预训练去噪网络进行大量参数更新的需求,显著缩短了知识蒸馏时间。
- 在定量研究中,我们的新型采样器在多个图像生成基准测试的FID分数方面优于最先进的ODE求解器。
2. 背景
前向和反向扩散过程
前向过程$ \{x_t \in \mathbb{R}^D\}_{t \in [0, T]}$ 始于从数据分布$ p_{\text{data}}(x)$ 中抽取的$ x_0$ ,并在时间步$ T>0$ 时演化至$ x_T$ 。给定$ x_0$ ,$ x_t$ 的分布可表示为:
其中$ \alpha_t \in \mathbb{R}$ 和$ \sigma_t \in \mathbb{R}$ 决定扩散模型的噪声调度,信噪比(SNR)$ \alpha_t^2 / \sigma_t^2$ 随$ t$ 的推进严格递减[21]。这确保了在实际应用中,$ x_T$ 的分布$ q_T(x_T)$ 近似于纯高斯噪声。
扩散模型的反向过程通过去噪网络迭代去除噪声来近似。从$ x_T$ 开始,反向过程的转移定义如下[15]:
其中$ \theta$ 表示去噪网络中的可训练参数,$ \mu_\theta(x_t, t)$ 和$ \sum_\theta(x_t, t)$ 是由去噪网络$ \theta$ 估计的高斯均值和方差。
SDE和ODE公式表述
Song等人[47]利用随机微分方程(SDE)表述前向扩散过程,以实现与式(1)相同的转移分布。给定$ x_0 \sim p_{\text{data}}(x)$ ,从时间步0到$ T$ 的前向扩散过程新定义为:
其中$ w_t \in \mathbb{R}^D$ 是标准维纳过程,$ f(t)$ 和$ g(t)$ 是关于$ \alpha_t$ 和$ \sigma_t$ 的函数。Song等人[47]还基于Anderson[1]引入了反向时间SDE,当给定$ x_T \sim q_T(x_T)$ 时,该SDE从时间步$ T$ 演化至0:
其中$ \overline{w}_t$ 是反向时间中的标准维纳过程,$ \nabla_x \log q_t(x_t)$ 被称为得分函数[18]。可以忽略维纳过程引入的随机性,从而在反向过程中定义扩散常微分方程(ODE),其对应于对SDE的平均求解。从$ x_T \sim q_T(x_T)$ 开始,从时间步$ T$ 到0的概率流ODE推进如下:
概率流ODE的公式表述为使用各种ODE求解器加速基于扩散的采样过程提供了可能[19, 24, 25, 58]。
去噪得分匹配
为在采样过程中求解式(5),必须估计得分函数$ \nabla_x \log q_t(x_t)$ 。Ho等人[15]提出使用噪声预测网络$ \epsilon_\theta$ 估计得分函数,使得当$ x_t = \alpha_t x + \sigma_t \epsilon$ 时,$ \nabla_x \log q_t(x_t) = -\epsilon_\theta(x_t, t) / \sigma_t$ 。噪声预测网络$ \epsilon_\theta$ 使用$ L_2$ 范数进行训练,训练样本来自$ p_{\text{data}}$ :
这里,高斯噪声按照噪声调度$ (\alpha_t, \sigma_t)$ 添加到数据中,噪声预测网络$ \epsilon_\theta$ 从含噪样本中预测添加的噪声$ \epsilon$ 。
或者,得分函数也可以使用数据预测网络$ x_\theta$ 而非$ \epsilon_\theta$ 来表示,即$ \nabla_x \log q_t(x_t) = (x_\theta(x_t, t) - x_t) / \sigma_t^2$ 。数据预测网络$ x_\theta$ 使用以下$ L_2$ 范数进行训练:
值得注意的是,估计原始数据在理论上等同于学习预测噪声$ \epsilon$ [15, 29]。尽管一些研究认为预测噪声在经验上能产生更高质量的样本[15, 41],但Karras等人[19]最近使用数据预测网络取得了最先进的性能。在本文中,我们使用噪声预测网络和数据预测网络进行了综合实验。在本文的其余部分,我们用$ D_\theta$ 表示扩散模型的去噪网络,它可以是噪声预测网络或数据预测网络。
3. 所提方法
本研究旨在弥合有学习采样与无学习采样之间的差距,同时利用两种方法的优势。我们充分利用ODE求解器的采样动态特性,同时通过简洁高效的知识蒸馏提升样本质量。本节首先阐述关于去噪网络输出(即去噪输出)间高度相关性的基本观察,这一观察启发了D-ODE求解器的设计。随后,我们详细介绍如何将知识从ODE求解器迁移至D-ODE求解器。
3.1 去噪输出间的相关性
ODE求解器通常通过利用去噪网络的输出历史来优化采样过程,从而能够省略许多中间步骤。因此,理解去噪输出之间的关系对于开发D-ODE求解器至关重要。我们的目标是设计新型ODE求解器,在充分利用采样动态特性优势的同时,将优化自由度降至最低。
图2展示了基于1000步DDIM[45]运行过程中所有去噪输出间余弦相似度计算得到的热力图。我们观察到,在噪声预测模型和数据预测模型中,相邻时间步的去噪输出都表现出高度相关性,余弦相似度接近1。这一观察表明,去噪输出包含冗余和重复的信息,这使得我们可以跳过大多数时间步的去噪网络评估。例如,去噪输出的历史可以被组合起来,以更好地表示下一个输出,从而有效减少准确采样所需的步数。这一理念在大多数ODE求解器中得到了应用,这些求解器的设计基于求解微分方程的理论原理[19, 24, 25, 52, 58]。
3.2 D-ODE求解器的公式表述
如图1所示,扩散模型中的每个去噪步骤通常包含两个部分:(1)去噪网络;(2)ODE求解器。给定时间步$ t$ 的估计含噪样本,去噪网络生成去噪输出,随后ODE求解器利用该去噪输出和时间步$ t$ 的含噪样本生成下一个样本。虽然高阶ODE求解器也会利用去噪输出的历史,但为简洁起见,此处我们省略相关符号。这一过程会反复迭代,直到扩散模型得到估计的原始样本。
我们现在引入一种具有简洁参数化的D-ODE求解器,用于从ODE求解器中提取知识。我们首先概述一种基本方法,将时间步$ t$ 的新去噪输出估计为当前和先前去噪输出的线性组合:
其中$ \lambda_k$ 是每个去噪输出$ d_k$ 的权重参数。通过精心选择$ \lambda_k$ ,我们期望新的去噪输出能够更好地逼近式(5)中ODE的目标得分函数,从而提升样本质量。一些高阶ODE求解器[24, 25, 58]采用了与式(8)类似的公式表述,其权重参数通过数学方法确定。
式(8)存在一个挑战:新去噪输出$ O_t$ 的值可能会因权重$ \{\lambda_k\}_{k=t}^{T}$ 的取值而变得不稳定和易变。在ODE求解器中,通过数值计算得到的权重不太可能出现这种不稳定性,但当权重通过知识蒸馏进行优化时,收敛性无法得到保证。为了生成高质量样本,采样过程必须遵循扩散模型训练所依据的真实ODE轨迹[24, 48]。换句话说,对于数据目标流形之外的样本,去噪网络可能无法产生可靠的输出[23, 34, 54]。
为避免这一问题,式(8)应受到约束,以使其遵循原始ODE轨迹。因此,新的去噪输出$ O_t$ 可定义如下:
此外,我们通过经验发现,仅使用前一个时间步的去噪输出就足以从教师采样中提取知识(详见补充材料)。因此,我们得到了D-ODE求解器的式(10)。值得注意的是,新去噪输出$ O_t$ 的均值近似于原始去噪输出的均值,因为在足够大的批次中,样本关于时间步$ t$ 和$ t+1$ 的均值不会发生显著变化(例如,$ \mathbb{E}_{x \sim p_{\text{data}}}[O_t] \approx \mathbb{E}_{x \sim p_{\text{data}}}[d_t]$ )。这是D-ODE求解器的一个关键特性,因为我们的目标是保持在ODE的原始采样轨迹上。
对于DDIM[45],只需将$ d_t$ 替换为$ O_t$ 即可构建D-DDIM:
其中$ (\alpha_t, \sigma_t)$ 表示预定义的噪声调度。$ \lambda_t$ 随后通过知识蒸馏进行优化。
与高阶ODE求解器的比较
高阶采样方法会利用去噪输出的历史。由于这些方法比一阶方法(如DDIM)能更好地逼近ODE的目标得分函数,我们在它们的近似基础上应用式(10)来构建D-ODE求解器。换句话说,式(10)中的$ d_t$ 被替换为每种方法的高阶近似。通过这种方式,我们可以涉及更多时间步来获取$ O_t$ ,同时通过为每个数据集适配的额外参数$ \lambda_t$ 克服ODE求解器的瓶颈。与高阶ODE求解器不同,D-ODE求解器配备了参数$ \lambda_t$ ,该参数通过知识蒸馏针对特定数据集进行优化,以进一步降低得分函数的拟合误差。补充材料包含了D-ODE求解器的具体应用以及不同的公式表述。
3.3 D-ODE求解器的知识蒸馏
在图1中,教师采样过程从时间步$ Ct$ 的含噪样本$ \hat{x}_{Ct}^{(t)}$ 开始,经过$ C$ 个去噪步骤,生成时间步$ C(t-1)$ 的样本$ \hat{x}_{C(t-1)}^{(t)}$ 。同时,学生采样过程从时间步$ t$ 的含噪样本$ \hat{x}_t^{(s)}$ 开始,经过一个去噪步骤后得到时间步$ t-1$ 的样本$ \hat{x}_{t-1}^{(s)}$ 。为了优化D-ODE求解器$ S_d$ 中的$ \lambda_t$ ,首先对一个批次执行教师采样,保存中间样本$ \{\hat{x}_k^{(t)}\}_{k=C(t-1)}^{Ct}$ 作为目标。同时执行学生采样,得到中间样本$ \{\hat{x}_k^{(s)}\}_{k=t-1}^{t}$ 作为预测结果。随后,通过最小化批次$ B$ 上目标与预测结果之间的差异来确定$ \lambda_t^{\ast}$ :
其中$ d_t^{(s)} = D_\theta(\hat{x}_t^{(s)}, t)$ 。对学生采样的每个时间步$ t$ 求解上述方程,得到一组最优的$ \lambda_t$ 值(例如,$ \lambda^{\ast} = \{\lambda_0^{\ast}, \lambda_1^{\ast}, \ldots, \lambda_{T-1}^{\ast}\}$ )。值得注意的是,$ \lambda^{\ast}$ 仅使用一个批次的样本来估计,这一过程通常只需几分钟的CPU时间,且之后可重用于其他批次。
算法1概述了D-ODE求解器的整体采样流程。当生成$ N$ 个样本时,通常将$ N$ 划分为$ M$ 个批次,并按顺序对每个包含$ |B| = N/M$ 个样本的批次$ B$ 执行采样过程(第3行)。对于第一个批次,使用去噪网络$ D_\theta$ 和ODE求解器$ S$ 执行$ CT$ 步教师采样,以获得中间输出作为目标样本(第7行)。随后,使用D-ODE求解器$ S_d(\lambda)$ 执行$ T$ 步学生采样(第8行)。此时,通过求解式(14)为每个时间步估计并保存$ \lambda^{\ast}$ (第9行)。从第二个批次开始,可以使用冻结的去噪网络$ D_\theta$ 和D-ODE求解器$ S_d(\lambda^{\ast})$ 进行采样(第11行)。需要注意的是,学生采样仅需$ T$ 步就能生成样本,其质量与教师采样经过$ CT$ 步生成的样本相近。
4. 实验
在本节中,我们对D-ODE求解器与现有ODE求解器在多种图像生成基准测试集上进行了全面评估,涉及不同分辨率的数据集,包括CIFAR-10(32×32)、CelebA(64×64)、ImageNet(64×64和128×128)、FFHQ(64×64)以及LSUN卧室(256×256)。我们的实验涵盖了噪声预测模型和数据预测模型,每种模型都涉及不同的ODE求解器集合。采用Fréchet Inception距离(FID)[13]作为评估指标,按照Lu等人[25]的方案,在不同数量的去噪函数评估次数(NFE)下,使用50K生成样本进行测量。报告的FID分数是三次独立实验(使用不同随机种子)的平均值。
对于ODE求解器的蒸馏,我们将尺度参数设置为$ C=10$ ,批次大小设置为$ |B|=100$ ,但LSUN卧室数据集除外,由于GPU内存限制,其批次大小采用25。需要注意的是,除非另有明确说明,DDIM作为主要的教师采样方法来指导学生采样。选择DDIM是因为某些ODE求解器在采样过程中采用多步方法,难以将其中间输出设置为蒸馏的目标,而DDIM每个去噪步骤生成一个中间输出,便于建立DDIM目标与学生预测之间的匹配对。关于D-ODE求解器的具体应用以及尺度$ C$ 和批次大小$ |B|$ 的消融研究详见补充材料。
4.1 噪声预测模型
我们将D-ODE求解器应用于噪声预测模型中使用的离散时间ODE求解器,包括DDIM[45]、iPNDM[58]、DPM-Solver[25]和DEIS[58]。对于DPM-Solver和DEIS,我们选择了三阶方法。虽然这些ODE求解器主要在NFE大于10的情况下进行评估,但我们也在极小NFE(如2或3)的场景中进行了实验,以评估D-ODE求解器在采样初始阶段的性能。
图3表明,D-ODE求解器优于ODE求解器,在大多数NFE下实现了更低的FID。在图3a和图3d中,当NFE超过5时,D-DDIM性能优于DDIM,并且随着NFE的增加,逐渐收敛到与DDIM相近的FID分数。需要注意的是,小NFE(2或5)的DDIM无法生成有意义的图像,这一点在D-DDIM的性能中也有所体现。iPNDM是一种利用先前去噪输出的高阶方法,除了在2次NFE的情况下,通过D-ODE求解器公式表述持续获得性能提升。这种改进在DPM-Solver3和DEIS3等高阶方法中尤为显著。具体而言,D-DPM-Solver3有效缓解了多步方法在极小NFE下的不稳定性,性能显著优于DPM-Solver3。尽管DEIS3已经通过高阶近似对当前去噪输出进行了精确表示,但图3显示,D-DEIS3通过针对每个数据集进行知识蒸馏优化的参数$ \lambda$ ,可以进一步改进这种近似。在补充材料中,我们还展示了将D-ODE求解器应用于DPM-Solver++[26]也是有效的。
4.2 数据预测模型
在数据预测模型的实验中,我们遵循Karras等人[19]提出的配置。我们将D-ODE求解器应用于基于该配置重新构建的DDIM以及采用Heun二阶方法的EDM[19]。虽然Karras等人[19]在其论文中也重新实现了基于欧拉方法的采样器,但由于EDM表现出更优的FID分数,我们未将这些采样器纳入实验。
图4表明,D-ODE求解器优于ODE求解器,尤其在较小NFE的情况下。例如,25次NFE的D-DDIM生成的样本在FID方面可与250次NFE的DDIM相媲美,实现了约10倍的加速。随着NFE的增加,ODE求解器和D-ODE求解器的FID分数逐渐收敛。由于学生采样的性能与教师采样密切相关,因此在较大NFE下,学生采样和教师采样的FID分数相近是合理的。此外,值得注意的是,在NFE约为2时,DDIM的性能偶尔略优于D-DDIM。这一观察表明,2步DDIM可能没有足够的能力有效从教师采样中蒸馏知识,特别是当DDIM已经生成含噪图像(FID分数超过250)时。
5. 分析
本节包含采样过程的可视化结果和定性分析。我们首先采用Liu等人[24]的方法进行可视化分析,旨在探究采样过程的全局和局部特征。随后,我们对ODE求解器和D-ODE求解器生成的图像进行对比分析。
5.1 采样轨迹的可视化
为便于解释高维数据,我们借鉴Liu等人[24]的分析框架,采用两种不同的度量指标:作为全局特征的范数变化,以及作为局部特征的特定像素值变化。作为参考,我们纳入了1000步DDIM的范数,因为它遵循目标数据流形。
在图5的上半部分,D-ODE求解器的范数变化轨迹与ODE求解器的范数轨迹高度接近。这一观察表明,D-ODE求解器始终处于数据的高密度区域,对ODE轨迹的影响极小。这与我们设计D-ODE求解器的目标一致,如3.2节所述,我们确保新的去噪输出与ODE求解器的去噪输出均值相匹配。
在图5的下半部分,我们从图像中随机选择了两个像素,并展示了它们的数值变化,以1000步DDIM的结果作为目标参考。显然,D-ODE求解器的像素值比ODE求解器的像素值更接近目标轨迹。结果表明,小步数的D-ODE求解器能够生成与大步数DDIM相当的高质量局部特征样本。这也凸显了数据特定参数λ的重要性,它能进一步降低得分函数的拟合误差。总之,D-ODE求解器通过引导像素向目标值靠近,同时忠实于原始数据流形,实现了高质量的图像生成。
5.2 定性分析
在图6中,我们展示了使用在ImageNet和FFHQ数据集上训练的数据预测模型,通过ODE求解器和D-ODE求解器生成的图像对比结果。总体而言,我们的方法在图像质量上优于ODE求解器,尤其在NFE较小时表现更为明显。DDIM往往生成边界模糊的图像,而D-DDIM生成的图像更清晰,色彩对比度更突出。EDM在NFE小于5时生成的图像噪声高且存在伪影,导致FID分数超过250;相比之下,即使在5次NFE下,D-EDM也能生成相对清晰的物体。更多分析图和定性结果详见补充材料。
6. 结论
在本研究中,我们提出了D-ODE求解器,这是一种创新性的扩散模型蒸馏方法,其核心是利用现有ODE求解器的原理。D-ODE求解器通过向ODE求解器引入单个参数构建而成,能够高效地将大步数教师采样的知识蒸馏到小步数学生采样中,且只需极少的额外训练。我们的实验表明,D-ODE求解器能够有效提升最先进ODE求解器的FID分数,尤其在小函数评估次数(NFE)场景中表现突出。可视化分析揭示了我们方法在全局和局部特征上的改进,证明图像质量得到了显著提升。
尽管在大NFE情况下,改进幅度往往较小或有限,最终会收敛到教师采样过程的FID分数,但D-ODE求解器仍然是一个极具吸引力的选择,因为它能以可忽略的额外计算成本提升样本质量。其适用性广泛,可应用于各种采样器、数据集和网络。然而,对于高分辨率图像生成,D-ODE求解器的单参数特性可能不够用。未来研究的一个有趣方向是探索通过图像网格划分或 latent 空间操作[39]引入局部特定参数。
扩散模型ODE求解器的小步数蒸馏补充材料
7. 生成模型的三难困境
正如文献[54]所阐述,生成模型面临着一个由三个关键要素构成的三难困境:
- 高质量样本:生成模型应具备生成高质量样本的能力。
- 模式覆盖与样本多样性:生成模型应能实现模式覆盖,确保生成的样本具有多样性,涵盖数据分布中的各种模式。
- 快速采样:高效的生成模型应能快速生成样本。
例如,生成对抗网络(GANs)[4,9]在仅需一次网络评估的情况下就能生成高质量样本,但GANs往往难以生成多样化样本,导致模式覆盖性较差[43,59]。相反,变分自编码器(VAEs)[22]和归一化流(Normalizing Flows)[7]在保证模式覆盖方面设计更为完善,但可能存在样本质量不高的问题。近年来,扩散模型作为一类新型生成模型崭露头角,其生成的高质量样本可与GANs相媲美[6,41],同时还能提供丰富多样的样本。然而,传统扩散模型的采样过程通常需要数百至数千次网络评估,在实际应用中计算成本高昂。扩散模型采样过程的主要瓶颈与去噪网络的评估次数密切相关。因此,众多研究工作探索了在保持生成样本质量的前提下,通过跳过或优化采样步骤来加速采样过程的技术。这些技术大致可分为两类:如主论文引言部分所述的有学习采样方法和无学习采样方法[55]。
8. 噪声预测网络与数据预测网络
去噪网络的输出应通过参数化来估计反向时间ODE的得分函数。得分函数表示数据分布对数的梯度,指示着更可能存在且噪声更少的数据方向。一种直接的参数化方法是直接估计原始数据,在这种情况下,得分函数通过计算给定当前噪声水平下指向原始数据的梯度来估计:
另一种方法是间接设计去噪网络来预测噪声$ \epsilon$ ,噪声$ \epsilon$ 代表注入原始样本中的残余信号。在这种情况下,得分函数可计算为:
尽管噪声预测网络$ \epsilon_{\theta}$ 和数据预测网络$ x_{\theta}$ 在理论上是等价的[19,21,29],但它们在采样过程中表现出不同的特性。
噪声预测网络
噪声预测网络最初可能会在真实噪声与预测噪声之间引入显著差异[3]。由于采样从高噪声样本开始,去噪网络缺乏足够信息来准确预测噪声[15]。此外,每个时间步所需的修正幅度相对较小,需要多个时间步才能纠正此类偏差[29]。
数据预测网络
已知数据预测网络在采样初期具有更高的准确性,而噪声预测网络在后期更具优势。预测数据有助于去噪网络理解目标样本的全局结构[29]。经验证据表明,在采样过程开始时,预测数据与真实数据较为接近[10,37]。然而,在后期阶段,当大部分结构已经形成,只需去除少量噪声伪影时,精细细节往往难以恢复[3]。本质上,早期数据预测提供的信息在采样后期效果会减弱。
我们的实验
数据预测网络和噪声预测网络之间的差异在主论文中去噪输出相关性的图表中也很明显。在采样初期,数据预测的相关性高于后期;而噪声预测在后期的相关性则高于初期。对于噪声估计而言,样本在最后几个时间步仍残留少量噪声,导致样本发生细微变化且方差较大。总之,在采样后期,每个时间步修改的细节各不相同。
另一方面,数据估计器很难从初始噪声样本中预测原始样本。但随着采样进行,样本噪声减少,其预测在后期会更加一致。这一观察结果与Benny和Wolf[3]的分析一致,即数据估计器的方差随着采样步骤的增加逐渐减小,而$ \epsilon$ 估计器的方差在采样最后阶段突然增大。
9. 扩散模型中的知识蒸馏
知识蒸馏[14]最初被提出用于将知识从较大的模型(教师模型)转移到较小的模型(学生模型),学生模型通过训练来模仿教师模型的输出。这一概念可应用于基于扩散的采样过程,将多个时间步(教师)合并为单个时间步(学生),以加快生成速度。
Luhman和Luhman[28]通过最小化一步学生采样器输出与多步DDIM采样器输出之间的差异,将知识蒸馏直接应用于扩散模型。因此,学生模型通过训练模仿教师模型的输出,并以预训练去噪网络为初始值,以继承教师模型的知识。
随后,渐进式蒸馏[42]提出了一种迭代方法,训练学生网络合并教师网络的两个采样时间步,直到实现一步采样以模仿整个采样过程。这使得学生网络能够逐步学习教师的采样过程,因为学习预测两步采样的输出比学习预测多步采样的输出更容易。给定预训练去噪网络$ \theta$ 作为教师模型,Salimans和Ho[42]首先训练学生网络$ \theta’$ 预测教师网络两步采样的输出。然后,学生网络$ \theta’$ 成为新的教师模型,再训练参数为$ \theta’’$ 的新学生网络合并新教师网络$ \theta’$ 的两个采样时间步,直到总时间步减少到一步。学生模型与教师模型采用相同的深度神经网络进行参数化和初始化,并且渐进式蒸馏基于DDIM采样器进行研究。
Meng等人[31]将渐进式蒸馏扩展到无分类器扩散引导的场景,实现了文本到图像生成、类别条件生成、图像到图像转换和图像修复的单步或少步生成。他们采用两阶段方法,首先训练学生模型匹配条件模型和无条件模型的合并输出,然后通过将学生模型设为新教师模型来应用渐进式蒸馏。大多数配置与Salimans和Ho[42]相同,主要使用DDIM采样器。
最近,Song等人[48]提出了一类新的生成模型——一致性模型,该模型利用概率流ODE轨迹上的一致性特性。它们经过训练,能从同一ODE轨迹上的任意点预测原始样本。训练过程中使用目标网络和在线网络,在线网络通过优化生成与目标网络相同的输出,而目标网络通过指数移动平均进行更新。一致性模型本质上能够单步或少步生成样本,还能以零样本方式实现图像修复、上色和超分辨率等任务。它们既可以独立训练,也可以通过蒸馏训练,分别称为一致性训练和一致性蒸馏。在本文中,我们关注一致性蒸馏,并将其与我们的蒸馏方法进行比较。
然而,这些蒸馏方法通常需要大量训练以适应不同的预训练模型、数据集和ODE求解器,这限制了它们的实际应用。在本文中,我们提出专门优化新参数化的ODE求解器(D-ODE求解器)。这种方法能有效将大步数采样过程蒸馏为新的小步数过程,同时保持预训练去噪网络固定。由于我们的方法不需要更新去噪网络的参数,蒸馏过程仅需几分钟的CPU时间即可完成。
10. D-ODE求解器的实现细节
在本节中,我们详细解释我们关注的ODE求解器及其在D-ODE求解器框架中的应用。根据扩散时间步的性质,我们将ODE求解器分为两类:离散型和连续型。离散时间ODE求解器包括DDIM、PNDM、DPM-Solver和DEIS,我们基于Lu等人[25]的代码构建;连续时间ODE求解器包括基于Karras等人[19]工作重新实现的DDIM和EDM。
10.1 噪声预测网络中的D-ODE求解器
DDIM[45]是基于DDPM[15]的非马尔可夫扩散过程构建的,它使用隐式模型定义确定性生成过程。给定时间步$ t$ 的估计样本$ \hat{x}_t$ ,DDIM采样过程表示为:
其中$ d_t = D_\theta(\hat{x}_t, t)$ ,$ D_\theta$ 为去噪网络。这里,$ (\alpha_t, \sigma_t)$ 表示预定义的噪声调度,去噪网络被参数化为噪声预测网络$ \epsilon_\theta$ 。D-ODE求解器定义的新去噪输出$ O_t$ 如主论文中的符号表示为:
然后,我们只需将采样过程中的去噪输出$ d_t$ 替换为新的去噪输出$ O_t$ :
上述方程定义了D-DDIM采样过程,其中$ \lambda_t$ 通过知识蒸馏进行优化。在无法获取前一个去噪输出的情况下(例如,在时间步$ T$ ),我们使用给定的噪声样本定义新的去噪输出$ O_t$ ,因此在初始时间步$ T$ 有$ O_t = d_T + \lambda_T(d_T - x_T)$ 。理论上,$ x_T$ 和$ d_T$ 都遵循正态分布$ N(0, \sigma_t^2 I)$ ,这确保了$ O_t$ 的均值与原始去噪输出的均值一致。可以预期,$ (d_T - x_T)$ 在某种程度上包含指向真实$ x_{T-1}$ 的方向信息,这在实际中确实提高了FID分数。因此,我们也将这种采样方法应用于其他基于噪声预测网络的D-ODE求解器。
PNDM[24]基于数据流形上的伪数值方法,其构建基于经典数值方法可能偏离数据高密度区域这一观察。PNDM将DDIM作为简单情况包含在内,并通过高阶方法超越DDIM。然而,PNDM在前3步需要12次NFE,难以与使用固定NFE的其他方法进行比较。因此,我们选择iPNDM[58],它无需初始预热步骤,在保持伪数值采样过程的同时性能优于PNDM。iPNDM采用多个去噪输出的线性组合来表示当前去噪输出,同时遵循DDIM的采样更新路径,如下所示:
其中$ \hat{d}_t$ 通过三个先前的去噪输出(即$ d_{t+1}$ 、$ d_{t+2}$ 和$ d_{t+3}$ )近似,然后应用于DDIM采样过程。因此,前三个去噪输出应独立定义如下:
利用iPNDM定义的这些新去噪输出$ \hat{d}_t^{(p)}$ (三步后$ p=3$ ),我们构建D-iPNDM的采样过程,其中新去噪输出$ O_t$ 可定义为:
然后,式(21)中的$ \hat{d}_t^{(p)}$ 被替换为$ O_t$ ,这导致与式(19)相同的更新规则,但$ O_t$ 的表述不同。
DPM-Solver[25]通过求解ODE线性部分的精确表述,并使用指数积分器近似神经网络的加权积分[16],利用概率流ODE的半线性结构。DPM-Solver提供一阶、二阶和三阶方法,其中一阶变体对应DDIM。对于单步方法,DPM-Solver策略性地使用这些不同阶数的方法划分总采样步骤。例如,DPM-Solver2(二阶DPM-Solver)被使用5次,生成包含10次去噪步骤的样本,其中DPM-Solver2内去噪网络被评估两次。为实现15次去噪步骤,DPM-Solver2被应用7次,最后一次去噪步骤应用DPM-Solver1(或DDIM)。
在本节中,我们深入探讨D-DPM-Solver2的表述,DPM-Solver3和DPM-Solver++[26]的应用遵循类似方法。首先,我们将$ \tau_t = \log(\alpha_t / \sigma_t)$ 表示为信噪比(SNR)的对数,$ \tau_t$ 是随$ t$ 增加而严格递减的函数。因此,我们可以建立从$ \tau$ 到$ t$ 的逆函数映射,记为$ t_\tau(\cdot): \mathbb{R} \to \mathbb{R}$ 。现在,我们可以概述DPM-Solver2的步骤如下:
在这些方程中,$ h_t = \tau_{t-1} - \tau_t$ ,$ \hat{x}_{t-\frac{1}{2}}$ 表示时间步$ t-1$ 和$ t$ 之间的中间输出。由于DPM-Solver2采用两阶段去噪步骤,我们必须定义两个去噪输出$ O_t$ 和$ O_{t-\frac{1}{2}}$ ,以通过知识蒸馏优化$ \lambda_t$ 和$ \lambda_{t-\frac{1}{2}}$ 来构建D-DPM-Solver2:
这些新去噪输出随后应用于式(27)和式(28),定义D-DPM-Solver2的采样过程:
与DPM-Solver类似,DEIS[58]采用指数积分器利用反向时间扩散过程的半线性结构。具体而言,他们提出使用高阶多项式近似ODE中的非线性项,如下所示:
其中$ \{C_{t j}\}_{j=0}^r$ 通过加权积分数值确定,以近似真实ODE轨迹。DEIS基于用于估计$ C_{t j}$ 的数值方法提供多种变体,在我们的实验中,我们选择tAB-DEIS,因为它在变体中表现出最有前景的结果。此外,Zhang和Chen[58]探索了不同$ r \in \{1,2,3\}$ 值的DEIS,其中较大的$ r$ 值通常能更好地近似目标得分函数。值得注意的是,DDIM可视为$ r=0$ 时tAB-DEIS的特例。
参考式(34),我们定义新去噪输出$ O_t$ 和D-DEIS的采样过程如下:
10.2 数据预测网络中的D-ODE求解器
在我们的研究中,我们在连续设置下使用数据预测网络的参数化重新实现了DDIM[45]。我们遵循Karras等人[19]概述的配置。这种改进的DDIM的采样过程定义如下:
其中$ d_t = D_\theta(\hat{x}_t, t)$ ,$ s_t$ 近似得分函数,指向数据的高密度区域。去噪网络被参数化为数据预测网络$ x_\theta$ ,去噪步骤在式(38)中基于$ (\sigma_t - \sigma_{t-1})$ 测量的噪声水平差异进行。
与噪声预测网络中的D-DDIM类似,D-DDIM的新去噪输出$ O_t$ 定义为:
然后,$ O_t$ 代替$ d_t$ 纳入DDIM的采样过程:
Karras等人[19]引入基于Heun二阶方法的EDM采样器,该方法在CIFAR-10和ImageNet64上实现了最先进的FID分数。他们利用新颖的ODE表述、参数选择和改进的神经架构。EDM采样过程如下所示:
其中$ d_{t-1}’ = D_\theta(\hat{x}_{t-1}’, t-1)$ 。EDM的第一阶段(式(42)和式(43))等同于DDIM,然后在第二阶段(式(44)和式(45))通过线性组合两个估计$ s_t$ 和$ s_t’$ 更准确地估计得分函数。值得注意的是,18步EDM采样对应35次NFE,因为一步EDM涉及两次网络评估,且最后一步不计算式(44)和式(45)。
为构建D-EDM的采样过程,我们定义两个去噪输出:
因此,D-EDM的采样步骤描述如下:
10.3 D-ODE求解器的多种解释
D-ODE求解器中的新去噪输出$ O_t$ 基于去噪输出高度相关这一观察构建,且保持与原始输出相同的均值至关重要。我们将去噪输出的定义重写如下:
上述表述可以解释为在当前和先前去噪输出之间进行插值(或外推),以估计准确的得分函数。因此,D-ODE求解器可视为通过知识蒸馏优化$ \lambda_t$ ,动态对去噪输出进行插值(或外推)的过程。类似地,Zhang等人[57]提出对原始数据$ \hat{x}_t$ 的当前和先前估计进行外推。他们认为,通过优化真实均值估计,两个预测之间的外推包含指向目标数据的有用信息。尽管准确外推需要通过网格搜索进行参数调整,但他们证明了对各种ODE求解器的FID有改进。
另一种解释基于Permenter和Yuan[36]的工作,他们在特定假设下将去噪过程与应用于欧几里得距离函数的梯度下降相匹配。他们利用向真实数据分布投影的定义重新解释扩散模型,并通过最小化相邻时间步之间$ \epsilon$ 预测的误差提出新采样器。他们的采样器对应通过网格搜索选择$ \lambda_t = 1$ 的D-DDIM,且性能优于DDIM和PNDM。
最后一种解释是,D-ODE求解器加速样本生成收敛的方式类似于动量加速随机梯度下降(SGD)中的优化[49]。正如带动量的SGD利用先前梯度历史加速神经网络中的参数更新一样,D-ODE求解器利用先前去噪输出来加速采样收敛。一个有趣的未来方向可以探索机器学习模型中使用的先进优化器[8,20,40]是否能有效应用于扩散模型。
10.4 D-ODE求解器的多种表述
为进一步验证D-ODE求解器的有效性,我们基于DDIM探索了D-ODE求解器的不同表述。例如,我们可以分别为两个相邻去噪输出估计参数,而不是优化单个参数$ \lambda_t$ ,我们将其命名为D-DDIM-Sep。D-DDIM-Sep对应主论文中式(8),其中$ T = t+1$ 。主论文中式(8)表示为D-DDIM-All,其中利用所有先前去噪输出来估计新的去噪输出。此外,我们还包括主论文中式(10)所示的D-DDIM和等同于主论文中式(9)且$ T = t+2$ 的D-DDIM-2。为便于比较,所有方法明确呈现如下,其中$ d_t = D_\theta(\hat{x}_t, t)$ :
我们在CIFAR-10上用不同NFE测试了上述五种表述,而蒸馏和采样的所有其他配置保持不变。如表2所示,D-DDIM优于所有其他表述,而其他表述如D-DDIM-Sep、D-DDIM-All和D-DDIM-2甚至比DDIM的FID分数更差。D-DDIM-Sep和D-DDIM-All的FID分数尤其高,这可以解释为采样过程未能适当收敛以生成逼真样本。正如我们在主论文中指出的,独立估计的参数可能偏离ODE求解器的目标轨迹。这是因为式(54)和式(55)中通过蒸馏确定的$ \lambda$ 集可能在没有任何约束的情况下不稳定,且可能无法反映不同批次间的一般采样规则。D-DDIM-2也没有改善DDIM的FID分数。一个可能的原因是在一个批次上优化的参数可能不适用于其他批次。由于这两个参数仅在一个批次上优化,像D-DDIM-2这样的去噪预测精细估计可能对所有批次都不有效。
此外,我们参考主论文图5展示了图7中范数变化的比较。虽然D-DDIM-All-10和D-DDIM-Sep-10最初似乎遵循目标轨迹(即DDIM-1000),但它们最终严重偏离目标轨迹或原始ODE轨迹(即D-DDIM-10),这与表2中的高FID分数一致。如3.2节所述,这是由于式(55)固有的不稳定性。
11. 实验细节
模型架构
对于噪声预测模型,我们遵循Ho等人[15]和Dhariwal与Nichol[6]的架构和配置,使用他们的预训练模型。具体而言,我们在CIFAR-10和CelebA 64×64的实验中采用DDPM[15]中的模型架构和配置。对于ImageNet 128×128和LSUN Bedroom 256×256,我们使用Dhariwal和Nichol[6]中的相应网络架构。在数据预测模型的实验中,我们使用Karras等人[19]的配置和预训练模型,用于CIFAR-10、FFHQ 64×64和ImageNet 64×64。
蒸馏配置
如主论文算法所述,我们首先执行$ CT$ 步教师采样以设置目标样本,然后执行$ T$ 步学生采样,使学生输出与教师目标匹配。对于大多数D-ODE求解器,我们使用DDIM采样作为教师采样方法,因为它每个去噪步骤生成一个去噪输出,能够实现目标与预测之间的一对一匹配。对于iPNDM和DEIS,我们分别使用它们自身作为蒸馏的教师方法(例如,$ CT$ 步的DEIS作为教师,$ T$ 步的D-DEIS作为学生)。尽管它们使用先前去噪输出的线性组合来估计当前去噪预测,但采样动态与DDIM相同。因此,教师目标和学生预测可以轻松匹配。
此外,学生采样按顺序执行以优化D-ODE求解器中的$ \lambda$ 。换句话说,首先通过蒸馏估计$ \lambda_t$ ,然后在学生采样期间,使用时间步$ t$ 优化后的D-ODE求解器生成时间步$ t+1$ 的下一个样本。这种方法有助于稳定采样过程,因为$ \lambda_{t+1}$ 基于使用$ \lambda_t^*$ 的D-ODE求解器先前生成的样本进行估计。因此,它可以通过精确估计的$ \lambda$ 减轻暴露偏差[34,38]。
采样细节
为简单起见,我们对所有ODE求解器采用均匀划分的时间步。我们生成50K样本,并报告三次不同种子运行后计算的平均FID分数。所有实验均使用GPU进行,包括NVIDIA TITAN Xp、Nvidia V100和Nvidia A100。我们固定尺度$ C=10$ 和批次大小$ |B|=100$ ,但LSUN Bedroom除外,由于内存限制,其批次大小为25。这两个参数的消融研究在第12节中呈现。
每个ODE求解器需要做出若干设计选择。PNDM在前3步需要12次NFE,难以与使用固定NFE的其他方法进行比较。因此,我们采用iPNDM[58],它无需初始预热步骤且性能优于PNDM。DEIS提供多种版本的ODE求解器,其中我们选择tAB-DEIS,它在他们的实验中表现出最佳FID分数。DPM-Solver使用自适应步长组合不同阶数的求解器。为简单起见,我们选择单步DPM-Solver,它依次使用DPM-Solver1、DPM-Solver2和DPM-Solver3来构成总时间步。虽然EDM在设计上允许随机采样,但我们采用确定性采样以获得教师采样生成的明确目标样本。
12. 消融研究
我们对D-ODE求解器蒸馏的两个关键参数进行了消融研究:尺度$ C$ 和批次大小$ |B|$ 。尺度$ C$ 决定教师采样的步数,教师采样的去噪步数是学生采样的$ C$ 倍。较大的尺度$ C$ 会使教师采样生成更好的目标样本,且可视为增加蒸馏过程中教师的指导强度。选择合适的批次大小$ |B|$ 也至关重要,因为最优$ \lambda$ 在单个批次$ B$ 上估计,然后重用于其他批次。因此,批次大小应足够大以涵盖数据集中样本的不同模式,而过大的批次大小可能无法放入GPU内存。
我们在表3a中使用CIFAR-10上训练的噪声预测模型测试了各种尺度。随着尺度增加,不同NFE值下的FID分数持续改善。较大的尺度$ C$ 使学生采样在准确的教师目标强烈指导下,FID更低。然而,随着NFE增加,指导尺度的影响减弱。这是合理的,因为学生采样的性能在很大程度上取决于教师采样的性能,且教师的FID分数最终会收敛到某个值。由于离散时间步的最大数量为1000,50次NFE下的尺度20和30生成的样本由相同的教师采样指导。
在表3b中,不同批次大小的D-ODE求解器也表现出明显趋势。随着批次大小增加,FID分数和方差均趋于减小。在NFE值较大时,FID分数和方差收敛到某个点。由于随着NFE增加,蒸馏的效果减弱,即使小批次大小也会导致低方差。我们为大多数数据集选择批次大小100,它足以捕捉数据集的固有多样性,并与更小的批次大小相比降低方差。
13. 更多比较
在本节中,我们进一步比较D-ODE求解器与先前的有学习(知识蒸馏)和无学习方法。图8a展示了CIFAR-10上不同NFE的FID分数,包括一致性蒸馏(CD)[48](可执行单步或少步采样)和渐进式蒸馏(PD)[42](允许采样步数为几何序列,例如1、2、4、…、1024)。D-EDM至少需要两步才能利用先前的去噪输出。
总体而言,CD在单步生成的FID方面优于其他方法。然而,需要注意的是,这种比较未考虑训练时间。例如,Song等人[48]报告CIFAR-10上的一致性模型训练使用了8个Nvidia A100 GPU。另一方面,仅用单个A100 GPU在30步内生成50K样本不到30分钟,就能达到与一致性模型相近的样本质量。虽然CD和PD对于拥有充足计算资源的从业者而言是有吸引力的选择,因为它们能够实现单步生成,但D-ODE求解器的主要优势在于能够以最小修改和快速优化增强现有基于ODE求解器的采样器。
最近,Zhang等人[57]引入前瞻扩散模型,通过使用先前数据预测优化均值估计来提高现有ODE求解器的FID分数。他们通过外推先前的初始数据预测来近似目标数据。与D-ODE求解器不同,前瞻模型需要通过网格搜索选择参数$ \lambda$ ,实验中默认设置$ \lambda=0.1$ 。遵循他们的配置,我们在表4中比较DDIM的前瞻扩散模型(称为LA-DDIM)与我们的D-DDIM。表中显示,除50次NFE外,D-DDIM优于LA-DDIM。
受LA-DDIM启发,我们还实验了在D-DDIM中将$ \lambda_t$ 固定为常数$ \lambda$ ,并通过网格搜索对其进行优化。我们将这种改进方法称为Fixed-D-DDIM。在图8b和图8c中,我们在CIFAR-10上使用10步采样器对$ \lambda$ 进行网格搜索。此外,我们提供DDIM和D-DDIM的FID分数作为参考(虚线)。尽管LA-DDIM进行了网格搜索,但仍无法达到D-DDIM的FID。另一方面,Fixed-D-DDIM通过充分的网格搜索能够达到与D-DDIM相同的FID。这表明利用去噪输出比依赖初始数据预测更有效。此外,Fixed-D-DDIM在25和50次NFE下进一步优于D-DDIM的性能,表明有可能找到更优的$ \lambda$ 值以降低FID。未来的研究方向可以探索有效确定$ \lambda$ 的各种方法。需要强调的是,LA-DDIM和Fixed-D-DDIM的FID随所选$ \lambda$ 变化。然而,D-DDIM相对于其他方法的优势在于它不依赖网格搜索,且采样时间与DDIM相当。
14. DPM-Solver++的更多实验
在DPM-Solver[25]的基础上,DPM-Solver++[26]解决了先前求解扩散ODE的多步方法中的不稳定性,并采用阈值方法将解约束在原始数据范围内。与10.1节中解释的D-DPM-Solver表述类似,我们应用新的去噪输出来替换原始去噪输出。图9表明,将D-ODE求解器应用于DPM-Solver++可通过蒸馏进一步提高图像质量。
此外,我们在图10中展示了噪声预测模型在CIFAR-10、CelebA64和ImageNet128上的额外实验结果。
15. 分析图和定性结果
在图11中,展示了更多与主论文图5类似的分析图,涉及不同像素。我们还在图12、图13、图14和图15中展示了更多定性结果。