One-step Diffusion Models with \(f\)-Divergence Distribution Matching

We present a general \(f\)-divergence minimization framework to distill multi-step diffusion models into a single-step generator, with state-of-the-art one-step performance.
Inference latency of the teacher diffusion model and \(f\)-distill. FID scores are evaluated on ImageNet-64 (left) and MS-COCO (right). Notably, \(f\)-distill also outperforms the teacher model in terms of FID score: by applying \(f\)-distillation to the teacher model and incorporating a GAN objective on real data, it achieves FID scores that surpass the teacher's.
Comparison between 50-step teacher (SD 1.5) and one-step student (\(f\)-distill), on more samples given each prompt.
Comparison between 50-step teacher (SD 1.5) and one-step student (\(f\)-distill)
Sampling from diffusion models involves a slow iterative process that hinders their practical deployment, especially for interactive applications. To accelerate generation, recent approaches distill a multi-step diffusion model into a single-step student generator via variational score distillation, which matches the distribution of samples generated by the student to the teacher's distribution. However, these approaches use the reverse Kullback–Leibler (KL) divergence for distribution matching, which is known to be mode-seeking. In this paper, we generalize the distribution matching approach using a novel \(f\)-divergence minimization framework, termed \(f\)-distill, that covers different divergences with different trade-offs in terms of mode coverage and training variance. We derive the gradient of the \(f\)-divergence between the teacher and student distributions and show that it is expressed as the product of their score differences and a weighting function determined by their density ratio. This weighting function naturally emphasizes samples with higher density in the teacher distribution when using a less mode-seeking divergence. We observe that the popular variational score distillation approach using the reverse-KL divergence is a special case within our framework. Empirically, we demonstrate that alternative \(f\)-divergences, such as forward-KL and Jensen-Shannon divergences, outperform the current best variational score distillation methods across image generation tasks. In particular, when using Jensen-Shannon divergence, \(f\)-distill achieves current state-of-the-art one-step generation performance on ImageNet-64 and zero-shot text-to-image generation on MS-COCO.
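To make the mode-seeking remark concrete, here is a small toy sketch of our own (not an experiment from the paper): a bimodal one-dimensional "teacher" is fit with a single Gaussian "student" by grid search, once under forward KL and once under reverse KL. The mixture, the grids, and the function names are illustrative assumptions.

```python
import numpy as np

# Toy illustration (our own assumption, not an experiment from the paper):
# a bimodal "teacher" p = 0.5 N(-3, 1) + 0.5 N(3, 1) is fit by a single-Gaussian
# "student" q = N(m, s^2), by grid search under either KL direction.

X = np.linspace(-15.0, 15.0, 6001)
DX = X[1] - X[0]


def log_normal(x, mean, std):
    """Log density of N(mean, std^2); broadcasts over x, mean, and std."""
    return -0.5 * ((x - mean) / std) ** 2 - np.log(std * np.sqrt(2.0 * np.pi))


# Teacher log density, computed stably with logaddexp.
LOG_P = np.logaddexp(np.log(0.5) + log_normal(X, -3.0, 1.0),
                     np.log(0.5) + log_normal(X, 3.0, 1.0))
P = np.exp(LOG_P)


def fit_gaussian(direction, means, stds):
    """Return the grid (m, s) minimizing KL(p||q) ('forward') or KL(q||p) ('reverse')."""
    best = (np.inf, None, None)
    for m in means:
        log_q = log_normal(X[None, :], m, stds[:, None])  # shape (len(stds), len(X))
        if direction == "forward":
            kl = np.sum(P * (LOG_P - log_q), axis=1) * DX
        else:
            kl = np.sum(np.exp(log_q) * (log_q - LOG_P), axis=1) * DX
        i = int(np.argmin(kl))
        if kl[i] < best[0]:
            best = (float(kl[i]), float(m), float(stds[i]))
    return best[1], best[2]


means = np.linspace(-5.0, 5.0, 101)
stds = np.linspace(0.5, 4.0, 71)
m_fwd, s_fwd = fit_gaussian("forward", means, stds)
m_rev, s_rev = fit_gaussian("reverse", means, stds)
print(f"forward KL: m={m_fwd:.2f}, s={s_fwd:.2f}")  # spreads over both modes
print(f"reverse KL: m={m_rev:.2f}, s={s_rev:.2f}")  # collapses onto one mode
```

Forward KL should land near the moment-matched Gaussian (mean 0, standard deviation \(\sqrt{10}\)), covering both modes, while reverse KL locks onto one of the two modes with a standard deviation near 1.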
To distill a one-step student generator \(G_\theta\) from the multi-step teacher diffusion models, we aim to match the student distribution, denoted by \(q\), with the teacher distribution \(p\). We perform this distribution matching by minimizing their \(f\)-divergence:
\[D_f(p||q) = \int q(\mathbf{x}) f\left(\frac{p(\mathbf{x})}{q(\mathbf{x})}\right) d\mathbf{x}\]where \(f\) is a convex function that satisfies \(f(1) = 0\). Since the student distribution \(q\) is the push-forward measure induced by the one-step generator \(G_\theta\), it depends implicitly on the generator's parameters \(\theta\). Because of this implicit dependency, directly calculating the gradient of the \(f\)-divergence \(D_f(p||q)\) with respect to \(\theta\) is challenging. We establish an analytical expression for the gradient of \(D_f(p_t||q_t), \forall t \ge 0 \), where \(p_t\) is the perturbed distribution obtained by applying the diffusion forward process to the teacher's distribution \(p\), i.e., \(p_t = p * \mathcal{N}( \mathbf{0},\sigma^2(t)\mathbf{I})\), and \(q_t\) is defined likewise from the student distribution \(q\):
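As a concrete check, the sketch below evaluates \(D_f(p_t||q_t)\) in the standard form \(\int q_t(\mathbf{x}) f(p_t(\mathbf{x})/q_t(\mathbf{x}))\,d\mathbf{x}\) by numerical quadrature for two one-dimensional Gaussians, for which the perturbation \(p_t = p * \mathcal{N}(0, \sigma^2)\) simply adds \(\sigma^2\) to the variance. The Gaussians, the grid, and the generator functions (reverse KL \(f(r) = -\log r\), forward KL \(f(r) = r\log r\), and a Jensen-Shannon-type \(f\)) are illustrative assumptions.

```python
import numpy as np

# Minimal 1D sketch (our own toy setup): p and q are Gaussians, so the noise-perturbed
# p_t = p * N(0, sigma^2) and q_t are Gaussians whose variances grow by sigma^2.
F_FUNCS = {
    "reverse_kl": lambda r: -np.log(r),       # D_f = KL(q || p)
    "forward_kl": lambda r: r * np.log(r),    # D_f = KL(p || q)
    "jensen_shannon": lambda r: r * np.log(r) - (r + 1.0) * np.log((r + 1.0) / 2.0),
}


def log_normal(x, mean, var):
    """Log density of N(mean, var)."""
    return -0.5 * (x - mean) ** 2 / var - 0.5 * np.log(2.0 * np.pi * var)


def f_divergence(f, p_mean, p_var, q_mean, q_var, sigma=0.0, lo=-12.0, hi=12.0, n=24001):
    """D_f(p_t || q_t) = integral of q_t(x) f(p_t(x) / q_t(x)) dx, via a Riemann sum on a grid."""
    x = np.linspace(lo, hi, n)
    dx = x[1] - x[0]
    log_p = log_normal(x, p_mean, p_var + sigma ** 2)  # p_t = p * N(0, sigma^2)
    log_q = log_normal(x, q_mean, q_var + sigma ** 2)  # q_t = q * N(0, sigma^2)
    ratio = np.exp(log_p - log_q)                      # density ratio p_t / q_t
    return float(np.sum(np.exp(log_q) * f(ratio)) * dx)


for name, f in F_FUNCS.items():
    clean = f_divergence(f, 0.0, 1.0, 1.0, 1.0)
    noisy = f_divergence(f, 0.0, 1.0, 1.0, 1.0, sigma=1.0)
    print(f"{name:15s} sigma=0: {clean:.4f}   sigma=1: {noisy:.4f}")
```

For \(p = \mathcal{N}(0,1)\) and \(q = \mathcal{N}(1,1)\), both KL directions equal 0.5, and noise with \(\sigma = 1\) reduces this to 0.25, since the means are unchanged while each variance doubles.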
\[\nabla_\theta D_f(p_t||q_t) = -\mathbb{E}_{\mathbf{z}, \epsilon}\left[f''\left(\frac{p_t(\mathbf{x})}{q_t(\mathbf{x})}\right)\left(\frac{p_t(\mathbf{x})}{q_t(\mathbf{x})}\right)^2 \left(\underbrace{\nabla_\mathbf{x} \log p_t(\mathbf{x})}_{\textrm{teacher score}} - \underbrace{\nabla_\mathbf{x} \log q_t(\mathbf{x})}_{\textrm{fake score}} \right) \nabla_\theta G_\theta(\mathbf{z})\right]\]where \(\mathbf{z} \sim p(\mathbf{z}), \epsilon \sim \mathcal{N}( \mathbf{0}, \mathbf{I})\) and \( \mathbf{x} = G_\theta(\mathbf{z})+\sigma(t)\epsilon \). This gradient is the score difference between the teacher's and student's distributions, weighted by a time-dependent factor \(f''\left({p_t(\mathbf{x})}/{q_t(\mathbf{x})}\right)\left({p_t(\mathbf{x})}/{q_t(\mathbf{x})}\right)^2\) that is determined by both the chosen \(f\)-divergence and the density ratio. Crucially, every term in the gradient is tractable, which enables distribution matching through general \(f\)-divergence minimization. In practice, the score of the student distribution \(\nabla_\mathbf{x} \log q_{t}(\mathbf{x})\) is approximated by an online diffusion model (the fake score), and the density ratio \({p_t(\mathbf{x})}/{q_t(\mathbf{x})}\) in the weighting function is readily available from the discriminator in the auxiliary GAN objective.
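The gradient expression can be checked on a toy problem where every term is known in closed form. In this sketch (our own setup, not the paper's), the teacher is \(\mathcal{N}(\mu, 1)\), the generator is the hypothetical \(G_\theta(z) = \theta + z\), and exact Gaussian scores and density ratios stand in for the fake-score network and the discriminator. A Monte Carlo estimate of the weighted score difference is compared with a finite-difference derivative of \(D_f(p_t||q_t) = \int q_t f(p_t/q_t)\).

```python
import numpy as np

# Toy setup (hypothetical, chosen so every term is available in closed form):
#   teacher p = N(mu, 1); generator G_theta(z) = theta + z with z ~ N(0, 1), so q = N(theta, 1);
#   perturbed p_t = N(mu, v), q_t = N(theta, v) with v = 1 + sigma^2.
# In f-distill the fake score comes from an online diffusion model and the density ratio
# from a discriminator; here the exact Gaussian quantities are used instead.
WEIGHTS = {  # h(r) = f''(r) * r**2
    "reverse_kl": lambda r: np.ones_like(r),
    "forward_kl": lambda r: r,
    "jensen_shannon": lambda r: r / (r + 1.0),
}
F_FUNCS = {  # matching generator functions f(r)
    "reverse_kl": lambda r: -np.log(r),
    "forward_kl": lambda r: r * np.log(r),
    "jensen_shannon": lambda r: r * np.log(r) - (r + 1.0) * np.log((r + 1.0) / 2.0),
}


def log_normal(x, mean, var):
    """Log density of N(mean, var)."""
    return -0.5 * (x - mean) ** 2 / var - 0.5 * np.log(2.0 * np.pi * var)


def fdistill_grad_mc(h, theta, mu, sigma, n=200_000, seed=0):
    """Monte Carlo estimate of -E[h(r_t(x)) (teacher score - fake score) dG/dtheta]."""
    rng = np.random.default_rng(seed)
    v = 1.0 + sigma ** 2
    z = rng.standard_normal(n)
    eps = rng.standard_normal(n)
    x = theta + z + sigma * eps                                  # x = G_theta(z) + sigma(t) * eps
    r = np.exp(log_normal(x, mu, v) - log_normal(x, theta, v))   # r_t(x) = p_t(x) / q_t(x)
    teacher_score = -(x - mu) / v                                # grad_x log p_t(x)
    fake_score = -(x - theta) / v                                # grad_x log q_t(x)
    dG_dtheta = 1.0                                              # derivative of theta + z
    return float(-np.mean(h(r) * (teacher_score - fake_score) * dG_dtheta))


def fd_grad(f, theta, mu, sigma, step=1e-4):
    """Central finite difference in theta of D_f(p_t || q_t), computed by quadrature."""
    v = 1.0 + sigma ** 2
    x = np.linspace(-12.0, 12.0, 24001)
    dx = x[1] - x[0]

    def divergence(th):
        log_p, log_q = log_normal(x, mu, v), log_normal(x, th, v)
        return np.sum(np.exp(log_q) * f(np.exp(log_p - log_q))) * dx

    return float((divergence(theta + step) - divergence(theta - step)) / (2.0 * step))


theta, mu, sigma = 1.0, 0.0, 0.5
for name in WEIGHTS:
    print(f"{name:15s} MC: {fdistill_grad_mc(WEIGHTS[name], theta, mu, sigma):+.4f}"
          f"   finite diff: {fd_grad(F_FUNCS[name], theta, mu, sigma):+.4f}")
```

For reverse KL the weight is 1 and the estimator reduces to the plain score difference used in variational score distillation; for forward KL and Jensen-Shannon the same score difference is reweighted by \(h(r)\). In this example the two columns should agree up to Monte Carlo error.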
The gradient update for the one-step student in \(f\)-distill. The gradient is a product of the difference between the teacher score and fake score, and a weighting function determined by the chosen \(f\)-divergence and density ratio. The density ratio is readily available from the discriminator in the auxiliary GAN objective.
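The claim that the density ratio comes from the discriminator rests on a standard fact: with balanced classes, the Bayes-optimal classifier between \(p\) and \(q\) outputs \(D(\mathbf{x}) = p(\mathbf{x})/(p(\mathbf{x}) + q(\mathbf{x}))\), so its logit equals \(\log r(\mathbf{x})\). The sketch below illustrates this in 1D with logistic regression on quadratic features standing in for the discriminator network. The labeling convention (teacher-side samples as 1, student samples as 0), the Gaussians, and all function names are assumptions for illustration, not the paper's implementation.

```python
import numpy as np

# Sketch under stated assumptions: balanced classes, teacher-side samples labeled 1,
# student samples labeled 0, and logistic regression on features [1, x, x^2] as a
# stand-in for the discriminator. At the optimum, logit(x) = log p(x) - log q(x).


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def fit_logistic(x, y, iters=25):
    """Logistic regression fitted by Newton's method (IRLS) on quadratic features."""
    phi = np.stack([np.ones_like(x), x, x ** 2], axis=1)
    beta = np.zeros(3)
    for _ in range(iters):
        prob = sigmoid(phi @ beta)
        w = prob * (1.0 - prob)
        hessian = phi.T @ (phi * w[:, None])
        beta += np.linalg.solve(hessian, phi.T @ (y - prob))
    return beta


def log_ratio(beta, x):
    """Discriminator logit, read as an estimate of log r(x) = log p(x) - log q(x)."""
    return beta[0] + beta[1] * x + beta[2] * x ** 2


rng = np.random.default_rng(0)
n = 50_000
x_teacher = rng.normal(0.0, 1.0, n)   # samples from p = N(0, 1)
x_student = rng.normal(1.0, 1.0, n)   # samples from q = N(1, 1)
x = np.concatenate([x_teacher, x_student])
y = np.concatenate([np.ones(n), np.zeros(n)])
beta = fit_logistic(x, y)

grid = np.array([-1.0, 0.0, 0.5, 1.0, 2.0])
true_log_ratio = 0.5 - grid           # log N(x; 0, 1) - log N(x; 1, 1)
r_hat = np.exp(log_ratio(beta, grid))
# With the Jensen-Shannon weight h(r) = r / (r + 1), the weight equals sigmoid(logit) = D(x).
js_weight = r_hat / (r_hat + 1.0)
print(np.round(log_ratio(beta, grid), 3))
print(np.round(true_log_ratio, 3))
```

With this normalization of the Jensen-Shannon \(f\), the weight is exactly the classifier's output probability, so the logit never needs to be exponentiated in that case.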
For notational convenience, let \(h(r) := f''(r)r^2\) denote the weighting function, and let \(r_t(\mathbf{x}) := p_t(\mathbf{x})/q_t(\mathbf{x})\) denote the density ratio at time \(t\). Here, we compare three properties of different divergences in the \(f\)-divergence family, in the context of diffusion distillation:
Comparison of different \(f\)-divergences as a function of the likelihood ratio \(r :=p(\mathbf{x})/q(\mathbf{x})\)
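The weighting function \(h(r) = f''(r)r^2\) can be tabulated directly from each generator \(f\). The sketch below lists a few standard members of the family, checks each closed-form \(h\) against a numerical second derivative of \(f\), and prints \(h\) at several density ratios. The normalizations of \(f\), and the inclusion of squared Hellinger and Pearson \(\chi^2\), are our own choices for illustration; scaling \(f\) by a constant scales \(h\) by the same constant.

```python
import numpy as np

# Generator functions f and their weights h(r) = f''(r) * r**2.
# The normalizations are our own choice (scaling f by c scales h by c).
DIVERGENCES = {
    "reverse_kl": (lambda r: -np.log(r), lambda r: np.ones_like(r)),
    "forward_kl": (lambda r: r * np.log(r), lambda r: r),
    "jensen_shannon": (lambda r: r * np.log(r) - (r + 1.0) * np.log((r + 1.0) / 2.0),
                       lambda r: r / (r + 1.0)),
    "sq_hellinger": (lambda r: (np.sqrt(r) - 1.0) ** 2, lambda r: 0.5 * np.sqrt(r)),
    "pearson_chi2": (lambda r: (r - 1.0) ** 2, lambda r: 2.0 * r ** 2),
}


def numeric_h(f, r, rel_step=1e-3):
    """f''(r) * r^2 via a central second difference, as a check on the closed forms."""
    e = rel_step * r
    return (f(r + e) - 2.0 * f(r) + f(r - e)) / e ** 2 * r ** 2


r = np.array([0.1, 0.5, 1.0, 2.0, 10.0])  # density ratio p(x) / q(x)
for name, (f, h) in DIVERGENCES.items():
    assert np.allclose(numeric_h(f, r), h(r), rtol=1e-4), name
    print(f"{name:15s}", np.round(h(r), 3))
```

Reading the rows: the reverse-KL weight is constant, consistent with variational score distillation being the special case \(h \equiv 1\); forward KL weights samples in proportion to \(r\), emphasizing regions where the teacher density exceeds the student's, but without an upper bound; and the Jensen-Shannon weight \(r/(r+1)\) also grows with \(r\) while staying below 1.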
We validate \(f\)-distill on class-conditioned ImageNet (\(64 \times 64\)) and text-to-image tasks (\(512 \times 512\)). In practice, we use EDM as the teacher model on ImageNet-64, and Stable Diffusion v1.5 as the teacher model on text-to-image tasks. When using Jensen-Shannon divergence, \(f\)-distill achieves new state-of-the-art single-step performance on these benchmarks: an FID score of 1.16 on ImageNet-64 and an FID score of 7.42 on zero-shot MS-COCO. For reference, the teachers' FID scores are 2.35 and 8.59 on the two datasets, respectively.
Quantitative comparison of \(f\)-distill with state-of-the-art multi-step and one-step generation methods, on ImageNet-64 (right) and MS-COCO (left)
One-step Diffusion Models with \(f\)-Divergence Distribution Matching Yilun Xu, Weili Nie, Arash Vahdat
Paper
@inproceedings{xu2024onestep,
title={One-step Diffusion Models with f-Divergence Distribution Matching},
author={Xu, Yilun and Nie, Weili and Vahdat, Arash},
booktitle={},
year={2024}}