Note: This blog post accompanies code here that demonstrates a VAE through a pure numpy implementation. You can also find a PyTorch VAE implementation here.
Generative models, in short, aim to examine and learn the characteristics of existing data, and to generate new samples that are similar to, but not exactly the same as, the dataset they examined. They have gained significant academic interest and have resulted in useful, fascinating, and perhaps disturbing applications of machine learning. With the rise of deep learning and of training techniques like stochastic gradient descent and backpropagation, Variational Autoencoders (VAEs) have risen in popularity as an unsupervised latent variable model able to model the complicated probability distributions that define the underlying data characteristics. This blog post does not aim to explain VAEs and how they work, but will rather expound on a couple of technical details of their implementation, specifically calculating and backpropagating the Kullback-Leibler (KL) divergence loss. There are plenty of great blog posts explaining VAEs. Or you can read the original paper. Or papers explaining the paper.
The KL component of the loss function usually relies on assuming a spherical Gaussian prior, \(\mathcal{N}(0, I)\), on the latent variable, and on modeling the approximate posterior \(Q(z|X)\) as a Gaussian with diagonal covariance. Though there is ongoing research on other types of priors, Gaussians are often used because the KL divergence between two Gaussians has a simple closed form.
Many VAE examples write the code plainly as:
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
But it might not be obvious where this expression comes from. It can be derived like so:
Recall that the PDF for a multivariate Gaussian is:
\[\frac{1}{\sqrt{((2\pi)^{k}\det\Sigma)}}{\rm exp}(-\frac{1}{2}(x - \mu)^T\Sigma^{-1}(x - \mu))\]Where \(x\) is a \(k\)-dimensional vector.
For two multivariate Gaussians \(P_1, P_2 \in R^n\):
\[KLD(P_1 || P_2) = E_{P_1}[\log P_1 - \log P_2]\] \[\eqalign{&= \frac{1}{2} E_{P_1}[-\log \det\Sigma_1 - (x - \mu _1)^T\Sigma_{1}^{-1}(x - \mu_1) + \log\det\Sigma_2 + (x - \mu _2)^T\Sigma_{2}^{-1}(x - \mu_2)]\cr &= \frac{1}{2}\log \frac{\det\Sigma_2}{\det\Sigma_1} + \frac{1}{2}E_{P_1}[- (x - \mu _1)^T\Sigma_{1}^{-1}(x - \mu_1) + (x - \mu _2)^T\Sigma_{2}^{-1}(x - \mu_2)] \cr &= \frac{1}{2}\log \frac{\det\Sigma_2}{\det\Sigma_1} + \frac{1}{2}E_{P_1}[-tr (\Sigma_{1}^{-1}(x - \mu_1)(x - \mu _1)^T) + tr(\Sigma_{2}^{-1}(x - \mu_2)(x - \mu _2)^T)]\cr &= \frac{1}{2}\log \frac{\det\Sigma_2}{\det\Sigma_1} + \frac{1}{2}[-tr (\Sigma_{1}^{-1}E_{P_1}[(x - \mu_1)(x - \mu _1)^T]) + tr(\Sigma_{2}^{-1}E_{P_1}[(x - \mu_2)(x - \mu _2)^T])]\cr &= \frac{1}{2}\log \frac{\det\Sigma_2}{\det\Sigma_1} + \frac{1}{2}[-tr (\Sigma_{1}^{-1}\Sigma_{1}) + tr(\Sigma_{2}^{-1}E_{P_1}[xx^T - 2x\mu^{T}_{2} + \mu_2\mu_{2}^T])]\cr &= \frac{1}{2}\log \frac{\det\Sigma_2}{\det\Sigma_1} - \frac{1}{2}n + \frac{1}{2} tr(\Sigma_{2}^{-1}(\Sigma_1 + \mu_1\mu_{1}^T - 2\mu_1\mu^{T}_{2} + \mu_2\mu_{2}^T))\cr &= \frac{1}{2}(\log \frac{\det\Sigma_2}{\det\Sigma_1} - n + tr(\Sigma_{2}^{-1}\Sigma_1) + \mu_{1}^T\Sigma_{2}^{-1}\mu_1 - 2\mu_{1}^T\Sigma_{2}^{-1}\mu_2 + \mu_{2}^T\Sigma_{2}^{-1}\mu_2)\cr &= \frac{1}{2}(\log \frac{\det\Sigma_2}{\det\Sigma_1} - n + tr(\Sigma_{2}^{-1}\Sigma_1) + (\mu_{2}-\mu_1)^T\Sigma_{2}^{-1}(\mu_{2}-\mu_1))}\]The second-to-last step uses the trace trick, \(tr(\Sigma_2^{-1}\mu\mu^T) = \mu^T\Sigma_2^{-1}\mu\), and the last step collects the resulting terms into a quadratic:
\[\mu_2^T\Sigma_2^{-1}\mu_2 - 2\mu_2^T\Sigma_{2}^{-1}\mu_1 + \mu_1^T\Sigma_{2}^{-1}\mu_1 = (\mu_2-\mu_1)^T\Sigma_{2}^{-1}(\mu_2-\mu_1)\]In our case, \(P_2\) is \(\mathcal{N}(0, I)\), which simplifies this to:
\[\frac{1}{2}(- \log\det\Sigma_1 - n + tr(\Sigma_1) + \mu_{1}^T\mu_{1})\]which is equivalent to:
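\[-\frac{1}{2}\sum_{j=1}^{J}(1 + \log(\sigma_j^2) - (\mu_j)^2 - (\sigma_j)^2)\]where \(J\) is the dimensionality of the latent vector \(z\) (so \(n = J\)), and \(\mu_j\) and \(\sigma_j^2\) are the \(j\)th mean and variance of \(P_1\). This follows because \(\Sigma_1\) is diagonal, so \(\log\det\Sigma_1 = \sum_j \log(\sigma_j^2)\) and \(tr(\Sigma_1) = \sum_j \sigma_j^2\). Note that in the implementation the encoder outputs the log-variance rather than the variance directly, and since the covariance matrix is diagonal, it is represented by a vector of its diagonal values.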
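As a quick sanity check on the derivation, the sketch below compares the general closed form for two Gaussians against the simplified sum, using NumPy. The function names `kl_gaussians` and `kl_to_standard_normal` are made up for this illustration and are not taken from the accompanying code; the test values are random.

```python
import numpy as np

def kl_gaussians(mu1, cov1, mu2, cov2):
    """KL(N(mu1, cov1) || N(mu2, cov2)) from the general closed form."""
    n = mu1.shape[0]
    cov2_inv = np.linalg.inv(cov2)
    diff = mu2 - mu1
    # slogdet returns (sign, log|det|); covariances are positive definite,
    # so the sign is +1 and can be ignored.
    _, logdet1 = np.linalg.slogdet(cov1)
    _, logdet2 = np.linalg.slogdet(cov2)
    return 0.5 * (logdet2 - logdet1 - n
                  + np.trace(cov2_inv @ cov1)
                  + diff @ cov2_inv @ diff)

def kl_to_standard_normal(mu, logvar):
    """Special case: diagonal Gaussian against N(0, I), as in the VAE loss."""
    return -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))

# Hypothetical test values: a 4-dimensional latent with random mean/log-variance.
rng = np.random.default_rng(0)
J = 4
mu = rng.normal(size=J)
logvar = rng.normal(size=J)

general = kl_gaussians(mu, np.diag(np.exp(logvar)), np.zeros(J), np.eye(J))
special = kl_to_standard_normal(mu, logvar)
print(general, special)  # the two values should agree up to rounding
```

The second function is the NumPy counterpart of the one-line PyTorch expression shown earlier, with `logvar` holding the diagonal of the log-covariance.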
Backpropagating the KL loss was a bit of a pain, because matrix calculus is confusing, especially when there is a summation involved in the equation. However, I learned a rather elegant mathematical solution which relies on properties of the Hadamard (elementwise) and Frobenius (trace) products for matrices.
Let
\[\eqalign{ B &= C\circ A &\implies B_{ij} = C_{ij}A_{ij} \cr \beta &= C:A &\implies \beta = \sum_i\sum_j C_{ij}A_{ij} \cr }\]denote the Hadamard and Frobenius products, respectively. Some useful properties follow. Since the Frobenius product is equivalent to a trace:
\[A:B={\rm tr}(A^TB)\]Cyclic and transpositional properties apply:
\[\eqalign{ A:B &= A^T:B^T \cr A:BC &= AC^T:B \cr }\]Likewise, Hadamard and Frobenius products commute with each other:
\[\eqalign{ A:B &= B:A \cr A\circ B &= B\circ A \cr A:(B\circ C) &= (A\circ B):C \cr }\]The matrix of all ones is the identity element for Hadamard products:
\[A\circ 1 = A\]Let’s say you have a two layer encoder such that:
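\[\eqalign{l &= W_1\,{\rm relu}(W_0x + b_0) + b_1\cr KLD &= \frac{1}{2}\sum_{j}(p_j - l_j + m_j^2 - 1)}\]where \(l\) is the encoder's log-variance output, \(p = \exp(l)\) is the variance, \(m\) is the mean, and \(j\) indexes the elements of \(p\), \(l\), and \(m\).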
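Before using them in a derivation, the Hadamard and Frobenius identities listed above are easy to check numerically. The following is a minimal NumPy sketch; the helper `frob` and the matrix shapes are arbitrary choices for illustration.

```python
import numpy as np

def frob(A, B):
    """Frobenius product A:B, the sum of all elementwise products."""
    return np.sum(A * B)

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 4))
B = rng.normal(size=(3, 4))
C = rng.normal(size=(3, 4))
D = rng.normal(size=(3, 5))  # D @ E has the same shape as A
E = rng.normal(size=(5, 4))

checks = {
    # A:B = tr(A^T B)
    "trace": np.isclose(frob(A, B), np.trace(A.T @ B)),
    # A:B = A^T:B^T
    "transpose": np.isclose(frob(A, B), frob(A.T, B.T)),
    # A:BC = AC^T:B (here with B -> D and C -> E)
    "cyclic": np.isclose(frob(A, D @ E), frob(A @ E.T, D)),
    # A:(B o C) = (A o B):C
    "mixed": np.isclose(frob(A, B * C), frob(A * B, C)),
    # A o 1 = A
    "ones": np.allclose(A * np.ones_like(A), A),
}
print(checks)
```

In NumPy, `*` on arrays is already the Hadamard product, so the Frobenius product reduces to a sum over an elementwise product.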
For example, to find the derivative of the K-L loss with respect to \(W_0\):
Define:
\[\eqalign{ z &= W_0x+b_0 &\implies dz=dW_0\,x\cr h &= {\rm step}(z) &\implies H={\rm Diag}(h)\cr r &= {\rm relu}(z) &\implies dr=h\circ dz = H\,dz \cr l &= W_1r + b_1 &\implies dl=W_1dr\cr p &= \exp(l) &\implies dp=p\circ dl \cr }\]Then
\[\lambda = \alpha(p-l+m\circ m-1):1\]where the scalar \(\lambda\) is the KL loss, written as the sum of the elements of the vector \(p-l+m\circ m-1\) via a Frobenius product with \(1\), the vector of all ones, and \(\alpha = \frac{1}{2}\).
\[\eqalign{ d\lambda &= \alpha(dp-dl):1 \cr &= \alpha(p\circ dl-1\circ dl):1 \cr &= \alpha(p-1):dl \cr &= \alpha(p-1):W_1dr \cr &= \alpha W_1^T(p-1):H\,dz \cr &= \alpha HW_1^T(p-1):dW_0\,x \cr &= \alpha HW_1^T(p-1)x^T:dW_0 \cr \frac{\partial\lambda}{\partial W_0} &= \alpha HW_1^T(p-1)x^T \cr }\]Note that the mean \(m\) is treated as a constant here, so no \(dm\) term appears. The derivative of ReLU (or the ramp function) is the Heaviside step function. Here, I replace the elementwise operation \(d{\rm relu}(z) = h\circ dz\) with \(H\,dz\), to turn it into a matrix multiplication, for simplification.
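The result can be checked against central finite differences. The sketch below is only an illustration under stated assumptions: the function names `kl_loss` and `kl_grad_W0`, the layer sizes, and the random weights are all hypothetical, and the mean \(m\) is held constant as in the derivation.

```python
import numpy as np

def kl_loss(W0, b0, W1, b1, x, m, alpha=0.5):
    """lambda = alpha * sum(p - l + m*m - 1), with l = W1 relu(W0 x + b0) + b1."""
    z = W0 @ x + b0
    r = np.maximum(z, 0.0)      # relu
    l = W1 @ r + b1             # log-variance
    p = np.exp(l)               # variance
    return alpha * np.sum(p - l + m * m - 1.0)

def kl_grad_W0(W0, b0, W1, b1, x, m, alpha=0.5):
    """Analytic gradient alpha * H W1^T (p - 1) x^T from the derivation."""
    z = W0 @ x + b0
    h = (z > 0).astype(z.dtype)  # Heaviside step, the derivative of relu
    p = np.exp(W1 @ (h * z) + b1)
    # Multiplying by H = Diag(h) is the same as an elementwise product with h.
    return alpha * np.outer(h * (W1.T @ (p - 1.0)), x)

# Hypothetical sizes and small random weights, chosen only for the check.
rng = np.random.default_rng(0)
d_in, d_hid, d_lat = 5, 6, 3
W0 = rng.normal(scale=0.3, size=(d_hid, d_in))
b0 = rng.normal(scale=0.3, size=d_hid)
W1 = rng.normal(scale=0.3, size=(d_lat, d_hid))
b1 = rng.normal(scale=0.3, size=d_lat)
x = rng.normal(size=d_in)
m = rng.normal(size=d_lat)       # treated as constant w.r.t. W0

analytic = kl_grad_W0(W0, b0, W1, b1, x, m)

# Central finite differences, one entry of W0 at a time.
eps = 1e-6
numeric = np.zeros_like(W0)
for i in range(d_hid):
    for j in range(d_in):
        Wp, Wm = W0.copy(), W0.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        numeric[i, j] = (kl_loss(Wp, b0, W1, b1, x, m)
                         - kl_loss(Wm, b0, W1, b1, x, m)) / (2 * eps)

max_err = np.max(np.abs(analytic - numeric))
print(max_err)  # should be tiny if the analytic gradient is right
```

If the mean were produced from the same hidden layer, its contribution to the gradient would have to be added separately; this check covers only the log-variance path derived above.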
References: