My notes and derivations for SMLDs
1. Derivation of weighted score matching loss
We derive the score matching equation assuming Gaussian perturbations, and compute it as an expectation over different noise scales \(\sigma_i\):
\begin{align} \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \Big[ \| \st(\xxtilde, \sigma_i) - \nabla_{\xxtilde} \log q_{\sigma_i}(\xxtilde|\xx) \|^{2} \Big], \tag{1} \end{align}
In practice, we use a weighted form of Eqn. (1) where we introduce a function \(\lambda(\sigma) > 0\) to weight different noise scales:
\begin{align} \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) - \nabla_{\xxtilde} \log q_{\sigma_i}(\xxtilde|\xx) \|^{2} \Big]. \tag{2} \end{align}
Since \(q_{\sigma}\) is a Gaussian distribution with standard deviation \(\sigma\), let us simplify further:
\begin{align} \mathcal{L}(\theta) & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\xxtilde - \xx}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3a} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\xx + \sigma_i\epsilon - \xx}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3b} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\sigma_i\epsilon}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3c} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\epsilon}{\sigma_i} \|_{2}^{2} \Big] \tag{3d} \end{align}
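As a sanity check on the step from (3a) to (3d), here is a minimal sketch in plain Python that evaluates both forms for a single sample and confirms they agree. The names dsm_loss_3a, dsm_loss_3d and toy_score are hypothetical. toy_score is only a stand-in for \(\st\): it is the exact score of the perturbed distribution under the assumption, made purely for this example, that the data is standard normal.

```python
import math
import random


def toy_score(x_tilde, sigma):
    # Hypothetical stand-in for the score network: if x ~ N(0, I), then
    # x_tilde = x + sigma * eps ~ N(0, (1 + sigma^2) I), whose score is below.
    return [-v / (1.0 + sigma ** 2) for v in x_tilde]


def dsm_loss_3a(score, x, x_tilde, sigma, lam):
    # Per-sample loss in the form of Eqn. (3a): uses (x_tilde - x) / sigma^2.
    s = score(x_tilde, sigma)
    sq_norm = sum((s_j + (xt_j - x_j) / sigma ** 2) ** 2
                  for s_j, xt_j, x_j in zip(s, x_tilde, x))
    return lam(sigma) * sq_norm


def dsm_loss_3d(score, x_tilde, eps, sigma, lam):
    # Per-sample loss in the form of Eqn. (3d): uses eps / sigma.
    s = score(x_tilde, sigma)
    sq_norm = sum((s_j + e_j / sigma) ** 2 for s_j, e_j in zip(s, eps))
    return lam(sigma) * sq_norm


rng = random.Random(0)
sigma = 0.5
x = [rng.gauss(0.0, 1.0) for _ in range(4)]      # one data point
eps = [rng.gauss(0.0, 1.0) for _ in range(4)]    # standard normal noise
x_tilde = [x_j + sigma * e_j for x_j, e_j in zip(x, eps)]


def lam(s):
    # Weighting function lambda(sigma) = sigma^2.
    return s ** 2


loss_a = dsm_loss_3a(toy_score, x, x_tilde, sigma, lam)
loss_d = dsm_loss_3d(toy_score, x_tilde, eps, sigma, lam)
print(loss_a, loss_d)
```

The two values should agree up to floating point error, since \((\xxtilde - \xx) / \sigma_i^2 = \epsilon / \sigma_i\) whenever \(\xxtilde = \xx + \sigma_i \epsilon\).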
In song2019generative, the authors observed that at optimality the norm of the predicted score \(\| \st(\cdot; \sigma_i) \|_{2}\) is proportional to \(1 / \sigma_i\). Because they intended for the magnitude of this norm to be the same for any \(\sigma_i\) (to give all noise scales equal weighting), they proposed using \(\lambda(\sigma) = \sigma^2\) so that \(\| \sigma_i \st(\cdot; \sigma_i) \|_{2} \propto 1\) for any noise scale \(i\). We can show this via the following derivation, starting with Eqn. (3d):
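A reconstruction of that step, assuming \(\lambda(\sigma_i) = \sigma_i^2\) and the reparameterisation \(\xxtilde = \xx + \sigma_i \epsilon\) from (3b); the \(\sigma_i^2\) weight is simply pulled inside the squared norm:
\begin{align} \mathcal{L}(\theta) & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \sigma_i^2 \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\epsilon}{\sigma_i} \|_{2}^{2} \Big] \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \Big[ \| \sigma_i \st(\xxtilde, \sigma_i) + \epsilon \|_{2}^{2} \Big] \end{align}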
where we note that \(\epsilon \sim \mathcal{N}(0, I)\) and \(\| \sigma_i \st(\cdot; \sigma_i) \|_{2}^{2} \propto 1\), so the magnitude of the above is independent of \(\sigma_i\).
1.1. Improved techniques for SMLDs
If we use the score predictor formulation in song2020improved, then we can simplify this equation further by defining \(\st(\xx, \sigma_i) := \st(\xx) / \sigma_i\):
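Written out under that definition, and assuming the \(\lambda(\sigma_i) = \sigma_i^2\) weighting applied to (3d), the noise scale cancels out of the score term entirely. Here \(\st(\xxtilde)\) denotes the unconditional network:
\begin{align} \mathcal{L}(\theta) & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \sigma_i^2 \Big[ \| \frac{\st(\xxtilde)}{\sigma_i} + \frac{\epsilon}{\sigma_i} \|_{2}^{2} \Big] \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \Big[ \| \st(\xxtilde) + \epsilon \|_{2}^{2} \Big] \end{align}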
1.2. SMLDs via stochastic differential equations
In this paper \(t \in [0,1]\) is now a continuous random variable which indexes time, rather than discrete \(\{1, \dots, T\}\). They also unify DDPMs and SMLDs and so some notation from the former is also used, i.e. \(q(\xx_0, \dots, \xx_T)\) to denote the forward process.
\begin{align} \lambda(t) \propto 1 / \mathbb{E}\big[ \| \nabla_{\xx_t} \log q_{\sigma_t}(\xx_t | \xx_0) \|_{2}^{2} \big], \tag{5} \end{align}
where \(t \in \mathbb{Z}\) and I am using DDPM-style notation here to be consistent with how it's written in the original paper. From this it looks like you can just simplify Eqn. (5) by substituting in the actual score:
\begin{align} \lambda(\sigma_t) & \propto \frac{1}{ \mathbb{E}_{\xx_t, \xx_0}\big[ \| \frac{\xx_t - \xx_0}{\sigma_t^2} \|_{2}^{2} \big] } \tag{5a} \\ & = \frac{1}{ \mathbb{E}_{\epsilon}\big[ \| \frac{\epsilon}{\sigma_t} \|_{2}^{2} \big] } \tag{5b} \\ & = \frac{1}{ \frac{1}{\sigma_t^2} \mathbb{E}_{\epsilon} \| \epsilon \|_{2}^{2} } \tag{5c} \\ & = \frac{1}{ \frac{1}{\sigma_t^2} p \, \mathbb{E}_{\epsilon_j} \epsilon_j^{2} } \tag{5d} \\ & = \frac{\sigma_t^2}{p} \tag{5e} \end{align}
I am hoping that my derivation from (5d) to (5e) is correct. Here \(\epsilon \sim \mathcal{N}(0, I)\) is \(p\)-dimensional, so the squared norm is \(\| \epsilon \|_{2}^{2} = \sum_{j=1}^{p} \epsilon_j^2\). Each \(\epsilon_j^2\) follows a Chi-squared distribution with one degree of freedom, whose mean is \(\text{Var}(\epsilon_j) = 1\), so I am just substituting \(\mathbb{E}_{\epsilon_j} \epsilon_j^2\) with \(1\) and multiplying by \(p\), which is just a constant and does not depend on \(t\). Anyway, the key thing is that the \(\sigma_t^2\) term ends up in the numerator.
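The step \(\mathbb{E}_{\epsilon} \| \epsilon \|_{2}^{2} = p\) can also be checked numerically. Below is a rough Monte Carlo sketch in plain Python; estimate_lambda is a hypothetical helper, and the dimension, sample count and noise scales are arbitrary choices for illustration. It estimates \(1 / \mathbb{E}_{\epsilon} \| \epsilon / \sigma_t \|_{2}^{2}\) and compares it with \(\sigma_t^2 / p\).

```python
import random


def estimate_lambda(sigma, p, n_samples=20000, seed=0):
    # Monte Carlo estimate of 1 / E[ || eps / sigma ||_2^2 ] for eps ~ N(0, I_p).
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        total += sum((rng.gauss(0.0, 1.0) / sigma) ** 2 for _ in range(p))
    return 1.0 / (total / n_samples)


p = 8
for sigma in (0.1, 1.0, 10.0):
    estimate = estimate_lambda(sigma, p)
    closed_form = sigma ** 2 / p
    # The two columns should agree to within Monte Carlo error.
    print(sigma, estimate, closed_form)
```

Since \(p\) is the same for every noise scale, dropping it leaves \(\lambda(\sigma_t) \propto \sigma_t^2\).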
As a side note, since \(t \in [0,1]\) for song2020score, in their Colab notebook they use the following weighting function, which is just the variance of the \(t\)-step forward process:
1.3. SDE variants
This paper is a lot to take in and proposes a ton of new stuff, but it does unify DDPMs and NCSNs and shows that they are different discretisations of an SDE. I am unable to offer many thoughts since I don't have much experience with SDEs. My summary of the paper is:
- DDPMs -> variance preserving (VP) SDE (since for DDPMs \(p(x_T)\) is set to be unit variance).
- NCSNs -> variance exploding (VE) SDE (since we define a noise schedule and the last one \(\sigma_{\text{max}}\) can be arbitrarily large).
- Proposal of predictor corrector sampler, quote from paper: "PC samplers generalize the original sampling methods of SMLD and DDPM: SMLD uses an identity function as the predictor and annealed Langevin dynamics as the corrector, while the DDPM uses ancestral sampling as the predictor and identity as the corrector."
- Table 3 CIFAR10 results: NCSN++, VE, has the best FID.
- Table 1 has way too much information to parse, I want to maybe convert this to a barplot or something. But it does look like PC sampling makes a big difference.
- They use the same architecture as Ho et al. (2020), so it is not the "divide by sigma" formulation in "improved techniques for NCSNs" where you define…
2. Understanding the implementation in official code
The official implementation of the loss is not entirely clear at first glance, but a little algebra shows that it is equivalent to equation (1).
The official code is here, and the score matching loss is implemented as follows:
```python
def anneal_dsm_score_estimation(scorenet, samples, sigmas, labels=None, anneal_power=2., hook=None):
    # ...
    perturbed_samples = samples + noise
    target = - 1 / (used_sigmas ** 2) * noise
    scores = scorenet(perturbed_samples, labels)
    target = target.view(target.shape[0], -1)
    scores = scores.view(scores.shape[0], -1)
    loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1) * used_sigmas.squeeze() ** anneal_power
    # ...
```
When anneal_power=2 (the default), \(\lambda(\sigma_i) = \sigma_i^2\) (see the original paper for what \(\lambda\) is). For a given triplet \((\xx, \sigma_i, \epsilon)\), the loss is:
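A reconstruction of that per-sample loss, read off the code above. It assumes, as in (3b), that noise \(= \sigma_i \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\) (the line that defines noise is elided in the excerpt), that used_sigmas is \(\sigma_i\), and that scores is \(\st(\xxtilde, \sigma_i)\). The symbol \(\ell\) is introduced here only for illustration. The result matches the bracketed term of (3d) weighted by \(\sigma_i^2\), up to the constant factor \(1/2\):
\begin{align} \ell(\xx, \sigma_i, \epsilon) & = \frac{1}{2} \sigma_i^2 \Big\| \st(\xxtilde, \sigma_i) - \Big( -\frac{\sigma_i \epsilon}{\sigma_i^2} \Big) \Big\|_{2}^{2} \\ & = \frac{1}{2} \sigma_i^2 \Big\| \st(\xxtilde, \sigma_i) + \frac{\epsilon}{\sigma_i} \Big\|_{2}^{2} \\ & = \frac{1}{2} \big\| \sigma_i \st(\xxtilde, \sigma_i) + \epsilon \big\|_{2}^{2} \end{align}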
…and we're done! Of course, if you want, you can use the noise conditioning simplification to obtain equation (2) again.
3. References
song2019generative
Song, Y., & Ermon, S. (2019). Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32.
song2020improved
Song, Y., & Ermon, S. (2020). Improved techniques for training score-based generative models. Advances in Neural Information Processing Systems, 33, 12438–12448.
song2020score
Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.