全文翻译
摘要
本文提出的“判别器引导(Discriminator Guidance)”方法旨在改进预训练扩散模型的样本生成质量。该方法引入一个判别器,对去噪样本路径的真实性进行显式监督。与生成对抗网络(GANs)不同,我们的方法无需对分数网络和判别器网络进行联合训练。相反,我们在分数网络训练完成后再训练判别器,这使得判别器的训练过程更加稳定且收敛速度更快。在样本生成阶段,我们在预训练的分数中添加一个辅助项以“欺骗”判别器。在最优判别器条件下,该辅助项能将模型分数修正为数据分数,这意味着判别器以互补的方式帮助实现更优的分数估计。通过我们的算法,在ImageNet 256x256数据集上取得了当前最优结果,FID值为1.83,召回率为0.64,与验证集数据的FID(1.68)和召回率(0.66)相当。我们已在https://github.com/alsdudrla10/DG发布相关代码。
1. 引言
近年来,扩散模型因其在图像生成(Dhariwal & Nichol, 2021; Ho et al., 2022a; Karras et al., 2022; Song et al., 2020b)、视频生成(Singer et al., 2022; Ho et al., 2022b; Voleti et al., 2022)以及文本到图像生成(Rombach et al., 2022; Ramesh et al., 2022; Saharia et al., 2022)等领域的成功而备受关注。当前最优模型已能实现接近人类水平的生成效果,但对于扩散模型的深入理解仍有很大探索空间。
生成模型领域广泛将训练良好的分数模型(Dhariwal & Nichol, 2021; Rombach et al., 2022)应用于下游任务(Meng et al., 2021; Kawar et al., 2022; Su et al., 2022; Kim et al.)。部分原因是从头训练一个新的分数模型计算成本极高。然而,随着复用预训练模型的需求增加,针对如何利用预训练分数模型提升样本质量的研究却为数不多。
为避免进一步训练分数模型可能引发的过拟合(Nichol & Dhariwal, 2021)或记忆化(Carlini et al., 2023)等问题(如图25所示),我们的方法固定预训练分数模型,并引入一个新组件在样本生成过程中提供监督。具体而言,我们提出将判别器作为预训练模型的辅助自由度。
该判别器在所有噪声尺度下对真实数据和生成数据进行分类,为样本去噪过程提供直接反馈,指示样本路径是否真实。我们通过在模型分数中添加由判别器构建的修正项来实现这一点,该修正项将样本路径导向更真实的区域(图1)。此修正项旨在使最优判别器下的模型分数与数据分数相匹配(定理1),从而通过调整模型分数帮助我们的方法找到真实的样本路径。在实验中,我们在CIFAR-10、CelebA/FFHQ 64x64和ImageNet 256x256等图像数据集上取得了新的最优性能。由于判别器训练是一个稳定且收敛迅速的极小化问题(图3),因此只需较低的计算成本就能实现显著的性能提升(表6)。我们的贡献总结如下:
- 我们提出了一种新的生成过程——判别器引导,它对给定预训练分数模型的分数进行调整。
- 我们从理论和实证上证明,判别器引导的样本比非引导样本更接近真实世界数据。
2. 预备知识
假设$p_{r}(x_{0})$为数据分布,$p_{\theta}(x_{0})$为模型分布。基于似然的潜变量模型通过最小化KL散度$D_{KL}(p_{r}(x_{0}) | p_{\theta}(x_{0}))$的上界来优化其参数,该上界由下式给出:
其中$x_{1: T}$为$T$个潜变量;$q(x_{0: T})$为推理分布,其边缘密度为$q(x_{0}):=p_{r}(x_{0})$;$p_{\theta}(x_{0: T})$为生成分布,其边缘密度为$p_{\theta}(x_{T}):=\pi(x_{T})$,其中$\pi$是便于采样的先验分布,用于生成目的。
去噪扩散概率模型(DDPM)(Ho等人,2020)通过逐步向数据变量$x_{0}$添加迭代高斯噪声来构建$x_{1: T}$,使得$q$成为非参数化的固定推理分布,即$q(x_{0: T})=p_{r}(x_{0}) \prod_{t=1}^{T} q(x_{t} | x_{t-1})$。大多数扩散模型(Okhotin等人,2023)假设生成过程为马尔可夫链,以满足$p_{\theta}(x_{0: T})=\pi(x_{T}) \prod_{t=1}^{T} p_{\theta}(x_{t-1} | x_{t})$,这种建模选择使得能够以易于处理的方式优化替代目标$D_{KL}(q(x_{0: T}) | p_{\theta}(x_{0: T}))$。
DDPM的连续时间对应模型(Song等人,2020b)用随机微分方程(SDE)描述扩散过程:
其中$t$是$[0, T]$范围内的连续扩散索引,$f(x_{t}, t)$和$g(t)$分别是漂移系数和波动率系数。我们主要在连续时间框架下描述我们的模型,这主要是为了符号简化。我们的模型适用于离散时间和连续时间两种设置。
在连续时间框架下,式(1)中的正向时间扩散过程具有唯一的反向时间扩散过程(Anderson,1982):
其中$d \bar{t}$和$\bar{w}_{t}$分别是无穷小反向时间和反向时间布朗运动。随后,连续时间生成过程变为:
其中分数网络$s_{\theta}(x_{t}, t)$的估计目标是实际数据分数$\nabla \log p_{r}^{t}(x_{t})$。这里,$p_{r}^{t}$是遵循式(1)中正向时间扩散过程的数据分布的扩散概率密度。
连续时间模型使用去噪分数匹配损失(Song & Ermon,2019)训练分数网络:
其中$\xi$是时间权重,$p_{0 t}$是从$x_{0}$到$x_{t}$的转移概率。如果$\xi(t)=g^{2}(t)$(Chen等人,2016;Song等人,2021),则该去噪分数目标与联合KL散度$D_{KL}(q(x_{0: T}) | p_{\theta}(x_{0: T}))$一致。此外,在不同的权重函数下,该目标可以等效地解释为噪声匹配损失$\int_{0}^{T} \mathbb{E}[\left|\epsilon_{\theta}-\epsilon\right|_{2}^{2}]$(Ho等人,2020)或数据重建损失$\int_{0}^{T} \mathbb{E}[\left|\hat{x}_{\theta}(x_{t})-x_{0}\right|_{2}^{2}]$(Kingma等人,2021)。
有多种方法可以提高分数训练的精度。例如,Kim等人(2022b)、Kingma & Gao(2023)、Hang等人(2023)提出使用最大扰动似然估计来更新分数网络,以提高大时间去噪精度。相反,Lai等人(2022)、Daras等人(2023)研究了数据扩散过程的不变特征,并建议在去噪分数损失中添加额外的正则化项以满足这些不变性质。另一方面,我们的工作旨在通过噪声对比估计来优化固定的模型分数,这与先前提高分数精度的尝试不同。
3. 利用判别器引导优化生成过程
3.1 预训练模型分数的修正
分数网络训练完成后,我们通过时间反转生成过程合成样本:
其中$s_{\theta_{\infty}}$表示收敛后的分数网络。如果局部最优解$\theta_{\infty}$偏离全局最优解$\theta_{_}$,则该生成过程可能与时间反转的数据过程不一致。我们在定理1中证明:若对模型分数进行调整,式(3)的生成过程可与式(2)的数据过程一致。我们将这种差距称为修正项,只要$\theta_{\infty} \neq \theta_{_}$,该修正项就不为零。
定理1:设$p_{\theta_{\infty}}$为式(3)时间反转生成过程的解。令$p_{r}^{t}$和$p_{\theta_{\infty}}^{t}$分别为从$p_{r}$和$p_{\theta_{\infty}}$出发的正向SDE $dx_{t}=f(x_{t},t)dt+g(t)dw_{t}$在时间$t$的边缘密度。若$s_{\theta_{\infty}}(x,T)=\nabla \log \pi(x)$(其中$\pi$为先验分布),且对数似然$\log p_{\theta_{\infty}}$等于其证据下界$L_{\theta_{\infty}}$,则反向SDE:
与具有调整后分数的扩散过程一致:
其中:
3.2 判别器引导
修正项$c_{\theta_{\infty}}(x_{t},t)=\nabla \log \frac{p_{r}^{t}(x_{t})}{p_{\theta_{\infty}}^{t}(x_{t})}$通常难以直接求解,因为密度比$\frac{p_{r}^{t}}{p_{\theta_{\infty}}^{t}}$无法直接获取。因此,我们通过在所有噪声水平$t$上训练判别器来估计该密度比。对于判别器训练,我们首先从式(3)的生成过程中抽取与数据样本数量相当的伪造样本,然后使用嵌入噪声的二元交叉熵(BCE)损失对真实数据和伪造数据进行分类:
其中$\lambda$为时间权重,详见算法1和附录A.4。
在最优判别器$d_{\phi_{_}}$(即最小化$L_{\phi}$的判别器)下,修正项可表示为:
因此,我们使用神经判别器$d_{\phi}$估计修正项$c_{\theta_{\infty}}$:
基于上述易于处理的修正项估计,我们定义判别器引导(DG)为:
图3显示,判别器确实能快速收敛并提升样本质量。
3.3 理论分析
尽管我们在微分方程的框架下引入了判别器引导,但本节将从数据分布与样本分布之间的统计散度角度分析该方法。具体而言,我们将式(5)的判别器引导样本分布记为$p_{\theta_{\infty},\phi}$。核心问题是:$p_{\theta_{\infty},\phi}$是否比$p_{\theta_{\infty}}$更接近数据分布$p_{r}$?
我们通过定理2回答这一问题。
定理2:若定理1的假设成立,则:
其中$E_{\theta_{\infty}}$为分数误差:
$E_{\theta_{\infty},\phi}$为判别器调整后的分数误差:
为衡量判别器训练的效果,我们利用定理2通过两个KL散度的差值计算增益:
其中$Gain(\theta_{\infty},\phi)=E_{\theta_{\infty}}-E_{\theta_{\infty},\phi}$表示分数误差与判别器调整后分数误差的差值。需要注意的是,定理2并未保证增益一定为正,但增益初始值接近零,并在判别器训练过程中逐渐增大,如表1所示。具体而言,当判别器完全失效时($d_{\phi_{b}} \equiv 0.5$),判别器梯度无信号,此时判别器调整后的分数误差$E_{\theta_{\infty},\phi_{b}}$等于分数误差$E_{\theta_{\infty}}$。因此,如图3所示,当判别器未训练时($d_{\phi_{0}} \approx 0.5$),增益近似为零。另一方面,在最优判别器$d_{\phi_{_}}$下,神经修正项$c_{\phi_{_}}$与目标修正项$c_{\theta_{\infty}}$匹配,且满足$E_{\theta_{\infty},\phi_{*}}=0$,因此随着判别器参数的更新,增益可达到最大值(见图4的示意图)。换言之,我们可以将判别器引导理解为引入额外的自由度$\phi$,将分数误差$E_{\theta_{\infty}}$重新参数化为判别器调整后的分数误差$E_{\theta_{\infty},\phi}$。因此,分数误差$E_{\theta_{\infty}}$不再通过去噪分数损失优化,而是通过式(4)的替代损失$L_{\phi}$优化重新参数化后的误差$E_{\theta_{\infty},\phi}$。
表1:判别器调整后的分数误差$E_{\theta_{\infty},\phi}$及相应增益
| 判别器 | $E_{\theta_{\infty},\phi}$ | 增益 |
|---|---|---|
| 失效判别器$d_{\phi_{b}}(\equiv 0.5)$ | $E_{\theta_{\infty}}$ | 0 |
| 最优判别器$d_{\phi_{*}}$ | 0 | $E_{\theta_{\infty}}$(最大值) |
| 未训练判别器$d_{\phi_{0}}(\approx 0.5)$ | $\approx E_{\theta_{\infty}}$ | $\approx 0$ |
| 训练后判别器$d_{\phi_{\infty}}$ | $\ll E_{\theta_{\infty}}$ | $\nearrow E_{\theta_{\infty}}$ |
3.4 最优性分析
我们进一步分析分数组件。与不稳定的GAN训练不同,我们的判别器训练过程稳定,因为训练时固定预训练分数模型。因此,当判别器达到最优后,调整后的模型分数为:
因此,样本分布$(p_{r}^{t})^{w}(p_{\theta_{\infty}}^{t})^{1-w}$平衡了数据分布与非引导分布。这一结论同样适用于条件生成场景,使得判别器引导可作为类内多样性的控制器。图5在ImageNet数据集上的实验表明,存在一个判别器引导权重的“最佳点”,能同时优化样本质量(FID)和多样性(召回率)。
3.5 与分类器引导的关联
分类器引导(CG)(Dhariwal & Nichol, 2021)是一项里程碑式的技术,它利用预训练分类器$p_{\psi_{\infty}}(c | x_{t},t)$引导样本生成。分类器引导的生成过程为:$dx_{t}=[f(x_{t},t)-g^{2}(t)(s_{\theta_{\infty}}(x_{t},t)+\nabla \log p_{\psi_{\infty}}(y | x_{t},t))] d\bar{t}+g_{t} d\overline{w}_{t}$。这相当于从$(x_{t},y)$的联合分布中采样,因为:
其中$p(y | x_{t},t)$为时间$t$的理想分类器。分类器引导为样本路径提供监督信息,判断样本是否被类别标签$y$正确分类。然而,由于分类器引导最大化分类概率$p(y | x_{t},t)$,可能导致模式坍塌。相比之下,如3.4节所述,判别器引导通过提供样本路径真实性的独特监督信息,可增强模式覆盖度。
由于从$(x_{t},y)$的联合分布中采样需要准确的分数估计,判别器引导与分类器引导可结合使用以产生协同效应。我们建议将两种引导技术结合:
其中$w_{t}^{DG}$和$w_{t}^{CG}$分别为时间相关权重。理想情况下,这两种信息能以互补方式引导样本向分类器和判别器的共同高概率区域移动。
表2:算法2包含DDPM、DDIM和EDM采样器
| 采样器 | $\gamma_{t}$ | $\eta$ | $w_{t}^{CG}$ | $w_{t}^{DG}$ |
|---|---|---|---|---|
| DDPM | 0 | 1 | 0 | 0 |
| DDIM | 0 | 0 | 0 | 0 |
| EDM | $\geq 0$ | 0 | 0 | 0 |
| ADM-G | 0 | 1 | $>0$ | 0 |
| EDM-G++ | $\geq 0$ | 0 | 0 | $>0$ |
| ADM-G++ | 0 | 1 | $>0$ | $>0$ |
算法2描述了式(6)采样过程的完整细节。该算法通过表2中的相应超参数,简化了DDPM(Ho等人,2020;Dhariwal & Nichol, 2021)、DDIM(Song等人,2020a)和EDM(Karras等人,2022)的采样器。我们的采样器在基础采样器前缀后以“G++”为后缀表示。详见附录D.1的详细采样步骤。
4. 相关工作
有一系列研究将扩散模型与GAN模型相结合。Zheng等人(2022b)、Lyu等人(2022)利用GAN生成器合成扩散数据$x_{\sigma_{mid}}$(将$x_{T}$作为生成器的输入),并通过扩散模型将$x_{\sigma_{mid}}$去噪为$x_{0}$。Xiao等人(2022)用少量序列条件GAN生成器替代数千步去噪步骤。Wang等人(2022)利用扩散概念训练GAN。相反,Jolicoeur-Martineau等人(2021)通过对抗损失训练扩散模型。然而,在上述工作中,扩散模型与GAN在训练完成后的生成过程中并未相互作用。相比之下,判别器引导中的判别器直接干预生成过程。与先前研究的另一区别是,我们的判别器训练无需任何生成器参与,因此训练过程稳定。
其他先前工作使用拒绝采样或MCMC从重新加权的模型分布中抽取样本。Azadi等人(2019)、Che等人(2020)利用拒绝采样,在似然比技巧(Gutmann & Hyvärinen, 2010)下通过判别器调整生成器的隐式分布。类似地,Turner等人(2019)利用似然比技巧进行MCMC采样。具体而言,Aneja等人(2021)和Bauer & Mnih(2019)分别在VAE中引入了聚合后验的重要性采样和拒绝采样。在扩散模型中,判别器引导通过梯度上升最大化$\frac{p_{r}^{t}}{p_{\theta_{\infty}}^{t}}$,本质上是用重要性权重$\frac{p_{r}^{t}}{p_{\theta_{\infty}}^{t}}$对模型分布$p_{\theta_{\infty}}^{t}$进行重新加权。因此,判别器引导无需样本拒绝,且可扩展至高维场景。
5. 实验
5.1 二维玩具案例
图6展示了一个易于处理的二维玩具案例的实验结果。我们训练了一个含256个神经元的4层MLP判别器直至收敛,并假设存在一个不正确的分数函数$s:=\nabla \log p_{g}^{t}$(对应合成生成分布$p_{g}$),该函数与数据分数不匹配。若没有引导,不正确的分数$s$会生成来自错误分布的样本,如图6所示。相反,$s + c_{\phi_{\infty}}$成功将$s$引导至$\nabla \log p_{r}^{t}$,更多可视化结果见图15和图16。
5.2 图像生成
我们在CIFAR-10、CelebA/FFHQ 64x64和ImageNet 256x256数据集上进行了实验。CIFAR-10和FFHQ的预训练网络来自Karras等人(2022)、Vahdat等人(2021),CelebA的预训练网络来自Kim等人(2022b),ImageNet的预训练网络来自Dhariwal & Nichol(2021)、Peebles & Xie(2022)。
判别器网络
我们采用U-Net结构的编码器作为判别器网络。对于数据空间上的扩散模型,我们附加两个嵌入噪声的U-Net编码器:预训练的ADM分类器(Dhariwal & Nichol, 2021)和一个辅助(浅层)U-Net编码器。我们将$(x_{t}, t)$输入ADM分类器,并从预训练分类器的最后一个池化层提取$x_{t}$的 latent 特征$z_{t}$。然后,我们将$(z_{t}, t)$输入辅助U-Net编码器,并通过其输出来预测真实/伪造样本。我们冻结ADM分类器的参数,默认只微调浅层U-Net编码器。值得一提的是,微调不仅节省训练成本,其性能还优于或等同于训练整个架构(Kato & Teshima, 2021)。对于LSGM-G++,我们从头训练U-Net编码器。对于DiT-XL-G++,我们训练与ADM分类器架构相同但输入维度不同的latent分类器,并微调浅层U-Net编码器作为判别器。对于类别条件生成,我们训练类别条件判别器。详细训练配置见表8。
定量分析
我们在所有数据集(包括CIFAR-10、CelebA、FFHQ和ImageNet)上均实现了新的最优FID结果。在CIFAR-10上,表3显示判别器引导在数据扩散模型(EDM)和潜在扩散模型(LSGM)中均有效。除判别器引导外,我们使用EDM和LSGM的所有超参数,因此性能提升完全来自判别器组件。在人脸数据集上,判别器引导的增益同样显著,见表4。
在ImageNet 256x256上,我们在包括FID、sFID、IS和召回率在内的多种指标上呈现了最优结果,见表5。作为参考,我们还在ImageNet 50k验证集上测量了这些指标。值得注意的是,验证集的IS和精确率与最佳模型相当。因此,一旦模型的IS和精确率达到验证集水平,优化其他指标(如FID、sFID和召回率)就变得更为重要。实验表明,借助判别器引导,我们在DiT-XL/2-G++上实现了FID和召回率的最佳性能,表明样本质量和多样性均得到显著提升。详细超参数见表8,未筛选样本见附录D.4。
定性分析
图7展示了FFHQ数据集上原始样本与其100次再生样本平均值的对比(Meng等人,2021)。若分数估计准确,当扰动噪声足够小时,平均重建图像应近似等于原始图像。为解释这一点,假设原始图像为$y$,我们通过$y + \sigma(t)\epsilon$沿固定方向$\epsilon$对其进行扰动。将该扰动数据代入Tweedie公式(Robbins, 1992;Jolicoeur-Martineau等人,2021),平均重建数据$x_{0}$为:
若$\sigma(t)$足够小,则有$p_{r}^{t}(y + \sigma(t)\epsilon) \propto p(\epsilon)$,这导致$\mathbb{E}_{\epsilon}[\nabla \log p_{r}^{t}(y + \sigma(t)\epsilon)] \approx 0$。因此,沿随机方向$\epsilon$对$y$进行重建的平均图像近似为原始数据:
综上,平均重建图像与原始图像的接近程度间接反映了估计分数的准确性,因为上述结论对数据分数成立。在图7中,当$\sigma(t) = 1$(较小值)时,调整后的分数比原始模型分数提供了更准确的估计。
图8显示,训练后的判别器能够准确区分真实数据的扩散路径和生成样本的去噪路径。相反,经判别器调整的去噪路径“欺骗”了判别器,使得调整后生成SDE(式5)的密度比曲线与数据正向SDE的密度比曲线高度接近。
图9展示了在CIFAR-10上判别器引导随采样NFE(函数评估次数)的变化效果。随着NFE减少,离散化误差主导采样误差(De Bortoli, 2022),判别器引导的增益变得次优。我们将在极低NFE采样器上适配判别器引导作为未来工作。更多消融实验见附录D.3。
图10展示了判别器训练过程中的精确率/召回率曲线。在第0轮(判别器训练前),我们观察到 vanilla DiT-XL-G的精确率/召回率分别高于/低于验证集。这是因为分类器引导生成的样本在分类器看来过于自信。然而,判别器引导显著缓解了分类器引导的精确率-召回率权衡问题。
图11展示了按噪声尺度归一化的累积目标损失。为确保公平比较,我们使用相同的权重函数$\xi(t) = \lambda(t)$评估判别器损失$L_{\phi}$和分数损失$\mathcal{L}_{\theta}$。结果表明,判别器能够捕捉估计误差,尤其是在决定样本多样性的大扩散时间尺度上。这一发现凸显了判别器引导作为补充方法解决分数匹配框架中大时间尺度估计不佳问题(Kim等人,2022b)的潜力。
5.3 图像到图像翻译
判别器引导可应用于图像到图像(I2I)翻译任务。I2I(Meng等人,2021)通过在目标域训练的分数网络对扰动的源图像进行去噪。对于判别器训练,我们首先翻译源训练图像,然后将这些翻译图像与源图像聚合为伪造数据集。通过将D指定为目标图像,我们遵循算法1训练判别器。判别器引导避免样本停留在源域或翻译域,而是引导至目标域。实证上,图12-(b)的曲线表明我们的方法缓解了真实性(贴近目标域)与忠实度(贴近源域)之间的权衡。详细内容见附录C。
6. 讨论
本节探讨判别器引导未来发展的两个可能方向。第一个方向是用Bregman散度重写我们的方法,第二个方向是探索分数网络与判别器网络的联合训练。在第一个方向中,密度比$r_{\theta_{\infty}}^{t}=\frac{p_{r}^{t}}{p_{\theta_{\infty}}^{t}}$是判别器的目标,而BCE损失$L_{\phi}$可推广到h-Bregman散度家族(Sugiyama等人,2012):
其中$r_{\phi}=\frac{d_{\phi}}{1-d_{\phi}}$。值得注意的是,BCE损失是唯一同时属于h-Bregman散度和f-散度的散度(Amari,2009)。因此,我们的方法提出了一种新的分数估计散度家族,详见附录B的更多讨论。表7呈现了关于Bregman散度的实验结果。
另一个潜在方向是判别器与分数网络的联合训练,这可能比GAN更具吸引力,因为它是一个极小-极小问题而非GAN的极小-极大问题。然而,$L_{\theta}$和$L_{\phi}$的损失函数相互独立,因此它们的联合作用会受到限制。如图13所示,判别器对分数准确性的提升效果甚微。
作为替代方案,我们可将分数损失修改为f-散度(Song等人,2021):
其中在定理1的假设下,$E_{\theta, \phi}^{f}=\frac{1}{2} \int_{0}^{T} g^{2}(t) \mathbb{E}[f^{\prime \prime}(\frac{p_{r}^{t}(x_{t})}{p_{\theta}^{t}(x_{t})}) \frac{p_{r}^{t}(x_{t})}{p_{\theta}^{t}(x_{t})} | \nabla \log p_{r}^{t}(x_{t})-s_{\theta}(x_{t}, t) |_{2}^{2}] d t$。当使用神经比率近似真实似然比时,f-散度建立了判别器损失与分数损失之间的联系。f-散度通过在高密度比$\frac{p_{r}^{t}}{p_{\theta}^{t}}$的空间域上赋予分数匹配更高权重,从而在感知合理区域实现更优的分数估计,这可能是其相对于KL散度训练的优势。我们将其留作未来工作。
7. 结论
本文通过调整分数估计优化了去噪过程。借助所提出的方法,我们能够进一步优化预训练分数模型的散度。实证结果表明,我们的方法在所有数据集上均实现了新的最优FID值。深度伪造图像是本研究可能被滥用的潜在风险之一。