My notes and derivations for SMLDs

1. Derivation of weighted score matching loss

We derive the score matching equation assuming Gaussian perturbations, and compute it as an expectation over different noise scales \(\sigma_i\):

\begin{align} \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \Big[ \| \st(\xxtilde, \sigma_i) - \nabla_{\xxtilde} \log q_{\sigma_i}(\xxtilde|\xx) \|^{2} \Big], \tag{1} \end{align}

In practice, we use a weighted form of Eqn. (1) where we introduce a function \(\lambda(\sigma) > 0\) to weight different noise scales:

\begin{align} \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) - \nabla_{\xxtilde} \log q_{\sigma_i}(\xxtilde|\xx) \|^{2} \Big]. \tag{2} \end{align}

Since \(q_{\sigma}\) is a Gaussian distribution with standard deviation \(\sigma\), let us simplify further:

\begin{align} \mathcal{L}(\theta) & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\xxtilde - \xx}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3a} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\xx + \sigma_i\epsilon - \xx}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3b} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\sigma_i\epsilon}{\sigma_i^2} \|_{2}^{2} \Big] \tag{3c} \\ & = \mathbb{E}_{\sigma_i}\mathbb{E}_{q_{\sigma_i}(\xxtilde|\xx)q(\xx)} \lambda(\sigma_i) \Big[ \| \st(\xxtilde, \sigma_i) + \frac{\epsilon}{\sigma_i} \|_{2}^{2} \Big] \tag{3d} \end{align}
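As a quick sanity check of the step from Eqn. (2) to Eqn. (3a) and on to (3d), the following minimal sketch compares the analytic Gaussian score \(-(\xxtilde - \xx)/\sigma^2\) against a finite-difference gradient of the log-density, and against \(-\epsilon/\sigma\). The function names and test values are my own illustrative choices and do not come from any of the papers.

```python
import math

def log_gaussian_kernel(x_tilde, x, sigma):
    """log q_sigma(x_tilde | x) for an isotropic Gaussian centred at x."""
    p = len(x)
    sq_dist = sum((a - b) ** 2 for a, b in zip(x_tilde, x))
    return -sq_dist / (2 * sigma ** 2) - p * math.log(sigma * math.sqrt(2 * math.pi))

def analytic_score(x_tilde, x, sigma):
    """Closed-form gradient of the log-density with respect to x_tilde."""
    return [-(a - b) / sigma ** 2 for a, b in zip(x_tilde, x)]

def numeric_score(x_tilde, x, sigma, h=1e-5):
    """Central finite-difference gradient with respect to x_tilde."""
    grad = []
    for j in range(len(x_tilde)):
        up = list(x_tilde)
        down = list(x_tilde)
        up[j] += h
        down[j] -= h
        diff = log_gaussian_kernel(up, x, sigma) - log_gaussian_kernel(down, x, sigma)
        grad.append(diff / (2 * h))
    return grad

# Illustrative values (assumptions, chosen arbitrarily).
x = [0.5, -1.0, 2.0]
sigma = 0.7
eps = [0.3, -1.2, 0.8]
x_tilde = [xi + sigma * ei for xi, ei in zip(x, eps)]

analytic = analytic_score(x_tilde, x, sigma)
numeric = numeric_score(x_tilde, x, sigma)
eps_form = [-e / sigma for e in eps]  # the form that appears in Eqn. (3d)
```

All three lists should agree up to floating-point and finite-difference error, which is why the target inside the norm can be written as \(+\epsilon/\sigma_i\) after moving the minus sign.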

In song2019generative, the authors observed that at optimality the norm of the predicted score \(\| \st(\cdot; \sigma_i) \|_{2}\) is proportional to \(1 / \sigma_i\). Because they wanted the magnitude of the loss to be the same for every \(\sigma_i\) (so that all noise scales are weighted equally), they proposed \(\lambda(\sigma) = \sigma^2\), which gives \(\| \sigma_i \st(\cdot; \sigma_i) \|_{2} \propto 1\) for any noise scale \(i\). We can show this via the following derivation, starting from the term inside the expectation of Eqn. (3d) for a single sample and including the conventional factor of \(\frac{1}{2}\):

\begin{align} \ell(\xx, \sigma_i; \theta) & = \sigma_i^2 \frac{1}{2}\Big\| \st(\xx + \epsilon\sigma_i, \sigma_i) + \frac{\epsilon}{\sigma_i} \Big\|^{2}_{2} \tag{4a} \\ & = \sigma_i^2 \frac{1}{2} \sum_{j} \Big[ \st(\xx + \epsilon\sigma_i, \sigma_i)^2 + \frac{2\epsilon}{\sigma_i} \st(\xx + \epsilon\sigma_i, \sigma_i) + \frac{\epsilon^2}{\sigma_i^2}\Big]_{j} \tag{4b} \\ & = \frac{1}{2} \sum_{j} \Big[ \sigma_i^2 \st(\xx + \epsilon\sigma_i, \sigma_i)^2 + 2\epsilon\sigma_i \st(\xx + \epsilon\sigma_i, \sigma_i) + \epsilon^2\Big]_{j} \tag{4c} \\ & = \frac{1}{2}\Big\| \sigma_i \st(\xx + \epsilon\sigma_i, \sigma_i) + \epsilon \Big\|^{2}_{2}, \tag{4d} \end{align}

where we note that \(\epsilon \sim \mathcal{N}(0,1)\) and \(\| \sigma_i \st(\cdot; \sigma_i) \|_{2}^{2} \propto 1\), and so the magnitude of the above is independent of \(\sigma_i\).

1.1. Improved techniques for SMLDs

If we use the score predictor formulation in song2020improved then we can simplify this equation further by defining \(\st(\xx, \sigma_i) := \st(\xx) / \sigma_i\):

\begin{align} \ell(\xx, \sigma_i; \theta) = \frac{1}{2}\Big\| \st(\xx+\epsilon\sigma_i) + \epsilon \Big\|^{2}_{2}. \tag{4e} \end{align}
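To see the simplification numerically, here is a minimal sketch in which `base_net` is a hypothetical stand-in for the unconditional network \(\st(\xx)\) (any deterministic function works for this check; it is not the architecture from song2020improved). It compares the \(\sigma_i^2\)-weighted loss computed with \(\st(\xx, \sigma_i) := \st(\xx)/\sigma_i\) against the simplified form above.

```python
import math
import random

def base_net(x):
    """Hypothetical stand-in for an unconditional network s_theta(x)."""
    return [math.tanh(v) - 0.5 * v for v in x]

def score(x, sigma):
    """Noise-conditional score defined as base_net(x) / sigma."""
    return [v / sigma for v in base_net(x)]

def weighted_loss(x, eps, sigma):
    """sigma^2 * 1/2 * || s(x + eps*sigma, sigma) + eps/sigma ||^2."""
    x_tilde = [xi + ei * sigma for xi, ei in zip(x, eps)]
    s = score(x_tilde, sigma)
    return sigma ** 2 * 0.5 * sum((sj + ej / sigma) ** 2 for sj, ej in zip(s, eps))

def simplified_loss(x, eps, sigma):
    """1/2 * || base_net(x + eps*sigma) + eps ||^2."""
    x_tilde = [xi + ei * sigma for xi, ei in zip(x, eps)]
    out = base_net(x_tilde)
    return 0.5 * sum((oj + ej) ** 2 for oj, ej in zip(out, eps))

rng = random.Random(0)
x = [rng.gauss(0.0, 1.0) for _ in range(4)]
eps = [rng.gauss(0.0, 1.0) for _ in range(4)]

# Compare both forms at a few arbitrary noise scales.
pairs = [(weighted_loss(x, eps, s), simplified_loss(x, eps, s)) for s in (0.01, 1.0, 50.0)]
```

Each pair should match up to floating-point error, so under this parameterisation the network output is regressed directly onto \(-\epsilon\).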

1.2. SMLDs via stochastic differential equations

In song2020score, \(t \in [0,1]\) is now a continuous random variable which indexes time, rather than a discrete index in \(\{1, \dots, T\}\). The authors also unify DDPMs and SMLDs, so some notation from the former is used as well, e.g. \(q(\xx_0, \dots, \xx_T)\) to denote the forward process. The weighting function they propose is:

\begin{align} \lambda(t) \propto 1 / \mathbb{E}\big[ \| \nabla_{\xx_t} \log q_{\sigma_t}(\xx_t | \xx_0) \|_{2}^{2} \big], \tag{5} \end{align}

where \(t \in \mathbb{Z}\) and I am using DDPM-style notation here to be consistent with how it's written in the original paper. From this it looks like you can just simplify Eqn. (5) by substituting in the actual score:

\begin{align} \lambda(\sigma_t) & \propto \frac{1}{ \mathbb{E}_{\xx_t, \xx_0}\big[ \| \frac{\xx_t - \xx_0}{\sigma_t^2} \|_{2}^{2} \big] } \tag{5a} \\ & = \frac{1}{ \mathbb{E}_{\epsilon}\big[ \| \frac{\epsilon}{\sigma_t} \|_{2}^{2} \big] } \tag{5b} \\ & = \frac{1}{ \frac{1}{\sigma_t^2} \mathbb{E}_{\epsilon} \| \epsilon \|_{2}^{2} } \tag{5c} \\ & = \frac{1}{ \frac{1}{\sigma_t^2} p \, \mathbb{E}_{\epsilon} \epsilon_j^{2} } \tag{5d} \\ & = \frac{\sigma_t^2}{p} \tag{5e} \end{align}

I am hoping that my derivation from (5d) to (5e) is correct. Each component is distributed as \(\epsilon_j \sim \mathcal{N}(0,1)\), and the squared norm sums \(\epsilon_j^2\) over \(j \in \{1, \dots, p\}\). Each \(\epsilon_j^2\) follows a chi-squared distribution with one degree of freedom, whose mean is \(\text{Var}(\epsilon_j) = 1\), so I am substituting \(\mathbb{E}_{\epsilon} \epsilon_j^2\) with \(1\) and multiplying by \(p\), which is a constant and does not depend on \(t\). Anyway, the key thing is that the \(\sigma_t^2\) term ends up in the numerator.
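The step in question can also be checked empirically. The sketch below is a Monte Carlo estimate of \(\mathbb{E}_{\epsilon} \| \epsilon \|_2^2\) for standard normal noise in \(p\) dimensions; the dimension, sample count, and noise scale are arbitrary values I picked for illustration.

```python
import random

def mean_squared_norm(p, n_samples, rng):
    """Monte Carlo estimate of E||eps||^2 for eps with p i.i.d. N(0, 1) components."""
    total = 0.0
    for _ in range(n_samples):
        total += sum(rng.gauss(0.0, 1.0) ** 2 for _ in range(p))
    return total / n_samples

rng = random.Random(0)
p = 8          # illustrative dimensionality
sigma_t = 3.0  # illustrative noise scale

estimate = mean_squared_norm(p, 20000, rng)

# E|| eps / sigma_t ||^2 = E||eps||^2 / sigma_t^2, so its reciprocal is about sigma_t^2 / p.
lam = 1.0 / (estimate / sigma_t ** 2)
closed_form = sigma_t ** 2 / p
```

The estimate should land close to \(p\), and `lam` close to \(\sigma_t^2 / p\), which supports replacing the expectation with the constant \(p\).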

As a side note, since \(t \in [0,1]\) for song2020score, in their Colab notebook they use the following weighting function, which is just the variance of the \(t\)-step forward process:

\begin{align} \lambda(t) = \frac{1}{2 \log \sigma}\big( \sigma^{2t} - 1 \big). \end{align}
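A minimal sketch of this weighting function follows. As I read the formula, \(\sigma\) here is a single fixed hyperparameter rather than one of the per-scale \(\sigma_i\) from Section 1; the function name and the value `25.0` are assumptions used only for illustration.

```python
import math

def marginal_variance(t, sigma):
    """lambda(t) = (sigma^(2t) - 1) / (2 log sigma), for t in [0, 1] and sigma > 1."""
    return (sigma ** (2.0 * t) - 1.0) / (2.0 * math.log(sigma))

sigma = 25.0  # illustrative value (assumption)
weights = [marginal_variance(t, sigma) for t in (0.0, 0.25, 0.5, 1.0)]
```

The weight is zero at \(t = 0\) (no noise has been added yet) and grows monotonically with \(t\), mirroring the role \(\sigma_i^2\) plays in the discrete setting.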

1.3. SDE variants

This paper is a lot to take in and proposes a ton of new stuff, but it does unify DDPMs and NCSNs and shows that they are different discretisations of an SDE. I am unable to offer many thoughts since I don't have much experience with SDEs. My summary of the paper is:

  • DDPMs -> variance preserving (VP) SDE (since for DDPMs \(p(x_T)\) is set to be unit variance).
  • NCSNs -> variance exploding (VE) SDE (since we define a noise schedule and the last one \(\sigma_{\text{max}}\) can be arbitrarily large).
  • Proposal of predictor corrector sampler, quote from paper: "PC samplers generalize the original sampling methods of SMLD and DDPM: SMLD uses an identity function as the predictor and annealed Langevin dynamics as the corrector, while the DDPM uses ancestral sampling as the predictor and identity as the corrector."
  • Table 3 CIFAR10 results: NCSN++, VE, has the best FID
  • Table 1 has way too much information to parse, I want to maybe convert this to a barplot or something. But it does look like PC sampling makes a big difference.
  • Use the same architecture as Ho et al (2020), so it is not the "divide by sigma" formulation in "improved techniques for NCSNs" where you define…

2. Understanding the implementation in official code

The official implementation of the loss is not entirely clear at first glance, but some algebra shows that it is equivalent to Eqn. (4d).

The official code is here, and the score matching loss is implemented as follows:

# def anneal_dsm_score_estimation(scorenet, samples, sigmas, labels=None, anneal_power=2., hook=None):
# ...
perturbed_samples = samples + noise
target = - 1 / (used_sigmas ** 2) * noise
scores = scorenet(perturbed_samples, labels)
target = target.view(target.shape[0], -1)
scores = scores.view(scores.shape[0], -1)
loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1) * used_sigmas.squeeze() ** anneal_power
# ...

When anneal_power=2 (the default), then \(\lambda(\sigma_i) = \sigma_i^2\) (see the original paper for what \(\lambda\) is). For a given triplet \((\xx, \sigma_i, \epsilon)\), the loss is:

\begin{align} \text{loss}_{\xx, \sigma_i, \epsilon} & = \sigma_i^2 \frac{1}{2}\Big\| \st(\xx + \epsilon\sigma_i, i) - (\frac{-1}{\sigma_i^2} \epsilon\sigma_i) \Big\|^{2}_2 \\ & = \sigma_i^2 \frac{1}{2}\Big\| \st(\xx + \epsilon\sigma_i, i) + \frac{\epsilon}{\sigma_i} \Big\|^{2}_{2} \ \ \text{(simplify)}\\ & = \sigma_i^2 \frac{1}{2} \sum_{j} \Big[ \st(\xx + \epsilon\sigma_i, i)^2 + \frac{2\epsilon}{\sigma_i} \st(\xx + \epsilon\sigma_i, i) + \frac{\epsilon^2}{\sigma_i^2}\Big]_{j} \ \ \text{(expand quadratic)} \\ & = \frac{1}{2} \sum_{j} \Big[ \sigma_i^2 \st(\xx + \epsilon\sigma_i, i)^2 + 2\epsilon\sigma_i \st(\xx + \epsilon\sigma_i, i) + \epsilon^2\Big]_{j} \ \ \text{(distribute $\sigma_i^2$)} \\ & = \frac{1}{2}\Big\| \sigma_i \st(\xx+\epsilon\sigma_i, i) + \epsilon \Big\|^{2}_{2}. \ \ \text{(re-factorise quadratic)} \end{align}

…and we're done! Of course, if you want, you can use the noise conditioning simplification from Section 1.1 to recover the simplified loss given there.
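The equivalence can also be checked numerically without PyTorch. The sketch below mirrors the arithmetic of the excerpt for a single flattened sample, under the assumption (as in the derivation above) that `noise` equals \(\epsilon\sigma_i\); `score_out` is an arbitrary vector standing in for the output of `scorenet`, and both function names are my own.

```python
import math
import random

def official_style_loss(score_out, noise, sigma, anneal_power=2.0):
    """Per-sample loss in the form used by the excerpt, where noise = sigma * eps."""
    target = [-n / sigma ** 2 for n in noise]
    sq_err = sum((s - t) ** 2 for s, t in zip(score_out, target))
    return 0.5 * sq_err * sigma ** anneal_power

def refactored_loss(score_out, eps, sigma):
    """Final line of the derivation: 1/2 * || sigma * s + eps ||^2."""
    return 0.5 * sum((sigma * s + e) ** 2 for s, e in zip(score_out, eps))

rng = random.Random(1)
sigma = 0.3  # illustrative noise scale
eps = [rng.gauss(0.0, 1.0) for _ in range(5)]
noise = [sigma * e for e in eps]
score_out = [rng.gauss(0.0, 1.0) for _ in range(5)]  # stand-in for scorenet output

lhs = official_style_loss(score_out, noise, sigma)
rhs = refactored_loss(score_out, eps, sigma)
```

With `anneal_power=2.0` the two values agree up to floating-point error; with any other exponent the \(\sigma_i^2\) factor no longer cancels and the refactored form does not apply.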

3. References

  • song2019generative Song, Y., & Ermon, S. (2019). Generative modeling by estimating gradients of the data distribution. Advances in Neural Information Processing Systems, 32.
  • song2020improved Song, Y., & Ermon, S. (2020). Improved techniques for training score-based generative models. Advances in Neural Information Processing Systems, 33, 12438–12448.
  • song2020score Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2020). Score-based generative modeling through stochastic differential equations. arXiv preprint arXiv:2011.13456.