Overview of Variational Auto-Encoders
Auto-encoders are unsupervised models that seek to learn a low-dimensional representation of a particular distribution. They do this by taking a high-dimensional input \(x\), mapping it to a low-dimensional vector (the encoder), and then mapping that vector back to a high-dimensional \(x'\) (the decoder) that is hopefully similar to \(x\). This is achieved by applying stochastic gradient descent to the encoder parameters \(\phi\) and decoder parameters \(\theta\) with the aim of minimizing the distance between \(x\) and \(x'\).
Variational auto-encoders are extensions of auto-encoders that seek to learn distributions. In particular, we do the following:
- We define our encoder as computing the parameters of a distribution describing the latent variable \(z\) conditional on \(x\): \(E_\phi = q_\phi(z | x)\). Here, we must make some high-level assumptions about the distribution so that the problem does not become too complicated. For example, we may assume \(z | x\) should be normally distributed and have \(E_\phi = q_\phi(z | x) = \mathcal{N}(\mu_{\phi}(x), \sigma_{\phi}^2(x))\), where \(\mu_{\phi}\) and \(\sigma_\phi^2\) are the functions we actually learn.
- We pass \(x\) through our encoder to get our conditional distribution and generate a point in the low-dimensional space by sampling \(z\) with \(z\sim q_{\phi}(z | x)\).
- Finally, we compute \(\hat x = D_{\theta}(z)\), which then becomes our output. Note that if we assume \(x \mid z \sim \mathcal{N}(\hat x, \textbf{I})\), then \(\log p_{\theta}(x | z) = -\frac{1}{2}||x-D_{\theta}(z)||^2 + \text{const}\), so the reconstruction error enters directly into estimating the evidence lower bound. The evidence lower bound equals \(\log p_\theta(x)\) minus the KL divergence between \(q_\phi(z | x)\) and \(p_\theta(z | x)\), so it improves when \(p_\theta(x)\) is large and that KL divergence is small. (A minimal code sketch of this pipeline appears after the list.)
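As a concrete illustration of this pipeline, here is a minimal PyTorch sketch. The layer sizes, the diagonal-Gaussian encoder, the standard normal prior on \(z\), and the names (`VAE`, `neg_elbo`) are illustrative assumptions rather than anything prescribed above.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Minimal VAE: encoder E_phi outputs (mu, log sigma^2); decoder D_theta maps z back to x-space."""
    def __init__(self, x_dim=784, z_dim=20, h_dim=200):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)       # mu_phi(x)
        self.enc_logvar = nn.Linear(h_dim, z_dim)   # log sigma_phi^2(x)
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, x_dim))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)        # sample z ~ q_phi(z | x) (reparametrized; see later sections)
        x_hat = self.dec(z)                         # x_hat = D_theta(z)
        return x_hat, mu, logvar

def neg_elbo(x, x_hat, mu, logvar):
    # Reconstruction term: -log p_theta(x | z) up to a constant, under x | z ~ N(D_theta(z), I).
    recon = 0.5 * ((x - x_hat) ** 2).sum(dim=1)
    # KL(q_phi(z | x) || N(0, I)) in closed form for diagonal Gaussians (standard normal prior assumed).
    kl = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum(dim=1)
    return (recon + kl).mean()

vae = VAE()
x = torch.rand(32, 784)
loss = neg_elbo(x, *vae(x))
loss.backward()  # gradients flow to both phi (encoder) and theta (decoder)
```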
Handling Stochasticity in Variational Auto-Encoders
The Difficulty with Stochastic Backpropagation
Performing stochastic gradient descent on a variational auto-encoder requires that one differentiate the sampling step \(z\sim q_\phi(z | x)\) with respect to \(\phi\). In particular, we seek to compute how the expectation of our random sample \(z\) changes as \(\phi\) changes:
\begin{flalign} \nabla_{\phi} E_{z\sim q_{\phi} | x}[z]&=\nabla_{\phi}\int_{-\infty}^{\infty} q_{\phi}(z | x)\cdot z \, dz = \int_{-\infty}^{\infty} z \nabla_{\phi} q_{\phi}(z | x)\, dz \end{flalign}However, in practice computing this integral at every step of gradient descent is impractical. It will typically not admit an analytic solution in terms of \(\phi\), and once the gradient is moved inside the integral the integrand \(z \nabla_{\phi} q_{\phi}(z | x)\) is no longer weighted by a probability density, so there is no obvious way to reliably approximate it by sampling (sampling values of a black-box function between negative and positive infinity doesn't make sense). Fortunately, there are ways around it.
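To see the difficulty concretely in an autodiff framework: a plain sampling operation gives autograd no path from the sample back to \(\phi\). The short PyTorch check below uses made-up scalar stand-ins for \(\mu_\phi(x)\) and \(\sigma_\phi(x)\).

```python
import torch

# Pretend these are the encoder outputs mu_phi(x) and sigma_phi(x) for one x.
mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.0, requires_grad=True)

z = torch.distributions.Normal(mu, sigma).sample()  # draws z ~ q_phi(z | x)
print(z.requires_grad)  # False: the sample is a constant as far as autograd is concerned
# z.backward()  # would fail; there is no gradient path from z back to mu or sigma
```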
Log Derivative Trick
The simplest way around the integral is to use the chain rule to rewrite \(\nabla_{\phi} q_{\phi}(z | x)\) as \(q_{\phi}(z | x) \nabla_{\phi} \log(q_{\phi}(z | x))\). This substitution is justified by the chain rule[1] and allows the integral to be written as an expectation, given that \(q_{\phi}(z | x)\) is a probability density:
\begin{flalign} \nabla_{\phi} E_{z\sim q_{\phi} | x}[z] &= \int_{-\infty}^{\infty} z \nabla_{\phi} q_{\phi}(z | x)\, dz = \int_{-\infty}^{\infty} z\cdot q_{\phi}(z | x) \nabla_{\phi} \log(q_{\phi}(z | x))\, dz\\ &= E_{z\sim q_{\phi} | x}[z\nabla_{\phi} \log(q_{\phi}(z | x))] \end{flalign}Because of this, we can move forward with our sample \(z\) during stochastic gradient descent and use \(z\nabla_{\phi} \log(q_{\phi}(z | x))\) as an estimate of \(\nabla_{\phi} E_{z\sim q_{\phi} | x}[z]\) during backpropagation (see note [2] for the Gaussian case worked out). Note that the distribution of the output of the auto-encoder \(D_\theta (z)\) is defined by a pushforward measure of \(q_{\phi}\). One issue with this method is that its gradient estimates usually have high variance, which makes convergence difficult.
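Here is a minimal sketch of this estimator in PyTorch, reusing hypothetical scalar stand-ins for \(\mu_\phi(x)\) and \(\sigma_\phi(x)\): differentiating the surrogate quantity \(z \log q_\phi(z | x)\), with \(z\) held fixed, yields exactly \(z\nabla_{\phi} \log(q_{\phi}(z | x))\).

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)    # stand-in for mu_phi(x)
sigma = torch.tensor(1.0, requires_grad=True) # stand-in for sigma_phi(x)
q = torch.distributions.Normal(mu, sigma)

z = q.sample()                           # z ~ q_phi(z | x), no gradient attached to the sample
surrogate = z.detach() * q.log_prob(z)   # differentiating this gives z * grad_phi log q_phi(z | x)
surrogate.backward()

print(mu.grad, sigma.grad)  # single-sample estimate of grad_phi E[z]
```

Averaging this estimate over many samples approximates \(\nabla_{\phi} E_{z\sim q_{\phi} | x}[z]\), but individual single-sample estimates can vary widely, which is the variance issue mentioned above.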
Reparametrization Trick
An alternative approach to the log derivative trick is the reparametrization trick. In this case, we define a random variable \(\varepsilon\) with a fixed distribution \(p_\varepsilon\) that does not depend on \(\phi\), together with a transformation function \(T_\phi(x, \varepsilon)\). This allows us to construct another random variable \(Z_\phi\) with \(Z_\phi | x = T_{\phi}(x, \varepsilon)\), chosen so that \(Z_\phi | x\) has the distribution \(q_\phi(z | x)\), which we then sample from. For example, when \(q_\phi(z | x) = \mathcal{N}(\mu_\phi(x), \sigma_\phi^2(x))\) we can take \(\varepsilon \sim \mathcal{N}(0, 1)\) and \(T_\phi(x, \varepsilon) = \mu_\phi(x) + \sigma_\phi(x)\,\varepsilon\). This makes it easy to estimate the gradient:
\begin{flalign} \nabla_{\phi} E_{z\sim q_{\phi} | x}[z] = \nabla_{\phi} E[Z_\phi | x] = \nabla_{\phi} E_{\varepsilon\sim p_\varepsilon}[T_{\phi}(x, \varepsilon)] = E_{\varepsilon\sim p_\varepsilon}[ \nabla_{\phi} T_{\phi}(x, \varepsilon)] \end{flalign}Thus, when we sample \(\varepsilon\), we can estimate the gradient as \(\nabla_{\phi} T_{\phi}(x, \varepsilon)\), assuming \(T_\phi\) is differentiable in \(\phi\). This estimator typically has lower variance than the log derivative trick.
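A sketch of the same gradient computed with the reparametrization trick, under the Gaussian assumption above where \(T_\phi(x, \varepsilon) = \mu_\phi(x) + \sigma_\phi(x)\,\varepsilon\) and \(\varepsilon \sim \mathcal{N}(0, 1)\); the scalar values are again made up.

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)    # stand-in for mu_phi(x)
sigma = torch.tensor(1.0, requires_grad=True) # stand-in for sigma_phi(x)

eps = torch.randn(())           # eps ~ p_eps = N(0, 1), independent of phi
z = mu + sigma * eps            # z = T_phi(x, eps); differentiable in mu and sigma
z.backward()

print(mu.grad, sigma.grad)      # grad_phi T_phi(x, eps): here exactly (1, eps)
# torch.distributions.Normal(mu, sigma).rsample() implements the same idea.
```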
Gumbel Softmax
In some cases, where one is modeling data that exists in discrete classes rather than across a continuous spectrum, it makes sense for \(q_\phi | x\) to describe a discrete distribution. The log derivative trick and variants of it work for this purpose, but it's not obvious how to use reparametrization. One approach is to let \(Z = T_\phi(x, \varepsilon) = \text{argmax}_i\, (h_{\phi, i}(x) + \varepsilon_i)\) where each \(\varepsilon_i\) is an independent draw from a standard Gumbel distribution. In this case
$$ p_\phi(Z=i | x) = \frac{e^{h_{\phi, i}(x)}}{\sum_{j=1}^n e^{h_{\phi, j}(x)}} $$where \(n\) is the number of classes. This almost allows one to backpropagate through a discrete distribution, but \(\text{argmax}\) is not differentiable. However, it can be replaced with a differentiable softmax function with a temperature parameter that controls how sharply the softmax approximates the argmax. In this case, \(Z\) will correspond to samples in continuous space that are close to corresponding samples from the true discrete distribution but are not exactly the same.
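A sketch of Gumbel-softmax sampling in PyTorch. The logits stand in for \(h_{\phi, i}(x)\), and the temperature `tau` is an illustrative choice; `values` is a hypothetical downstream quantity used only to show that gradients flow.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 0.2, -0.5], requires_grad=True)  # stand-in for h_{phi, i}(x), n = 3 classes
tau = 0.5  # temperature: smaller -> closer to one-hot, but noisier gradients

# Standard Gumbel noise: g = -log(-log(U)), U ~ Uniform(0, 1).
u = torch.rand_like(logits)
g = -torch.log(-torch.log(u))

z_hard = torch.argmax(logits + g)               # exact discrete sample, but not differentiable
z_soft = F.softmax((logits + g) / tau, dim=-1)  # relaxed, differentiable sample on the simplex

values = torch.tensor([1.0, 2.0, 3.0])  # some made-up quantity that depends on the class
loss = (z_soft * values).sum()
loss.backward()
print(logits.grad)  # gradients reach the logits through the relaxed sample
# torch.nn.functional.gumbel_softmax(logits, tau=tau) provides the same relaxation.
```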
Notes
[1] The expression including \(\log\) is produced through the chain rule. Define \(f_z(\phi) = q_{\phi}(z | x)\) and observe:
$$f_z(\phi)\frac{d\log(f_z(\phi))}{d\phi} = f_z(\phi)\left.\frac{d \log(x)}{dx}\frac{d f_z(\phi)}{d\phi} \right |_{x=f_z(\phi)} = f_z(\phi)\frac{1}{f_z(\phi)}\frac{d f_z(\phi)}{d\phi} = \frac{d f_z(\phi)}{d\phi} $$[2] When \(q_{\phi}(z | x)=\mathcal{N}(\mu_{\phi}(x), \sigma_{\phi}^2(x))\), one can apply the log derivative trick directly. Using \(\log q_{\phi}(z | x) = -\frac{1}{2}\left(\frac{z-\mu_{\phi}(x)}{\sigma_{\phi}(x)}\right)^2 - \log \sigma_{\phi}(x) + \text{const}\):
\begin{flalign} \nabla_{\phi} E_{z\sim q_{\phi} | x}[z] &= E_{z\sim q_{\phi} | x}\left[z\nabla_{\phi}\left(-\frac{1}{2}\left(\frac{z-\mu_{\phi}(x)}{\sigma_{\phi}(x)}\right)^2 - \log \sigma_{\phi}(x)\right) \right] \\ &= E_{z\sim q_{\phi} | x}\left[z\,\nabla_\phi (\mu_\phi (x), \sigma_\phi(x)) \cdot \left(\frac{z-\mu_\phi(x)}{\sigma_\phi^{2}(x)},\ \frac{(z-\mu_\phi(x))^2}{\sigma_\phi^{3}(x)} - \frac{1}{\sigma_\phi(x)}\right)\right] \\ &= \nabla_\phi (\mu_\phi (x), \sigma_\phi(x))\cdot E_{z\sim q_{\phi} | x}\left[z\left(\frac{z-\mu_\phi(x)}{\sigma_\phi^{2}(x)},\ \frac{(z-\mu_\phi(x))^2}{\sigma_\phi^{3}(x)} - \frac{1}{\sigma_\phi(x)}\right)\right] = \nabla_\phi \mu_\phi(x) \end{flalign}The final equality follows from \(E[z(z-\mu_\phi(x))] = \sigma_\phi^2(x)\), \(E[z(z-\mu_\phi(x))^2] = \mu_\phi(x)\sigma_\phi^2(x)\), and \(E[z] = \mu_\phi(x)\), and it matches what we expect, since \(E_{z\sim q_{\phi} | x}[z] = \mu_\phi(x)\).
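As a quick numerical check of this footnote, a Monte Carlo average of the two coefficient terms (with made-up values \(\mu_\phi(x) = 0.5\), \(\sigma_\phi(x) = 1.3\)) should come out to approximately \(1\) and \(0\), confirming that the estimator averages to \(\nabla_\phi \mu_\phi(x)\).

```python
import torch

torch.manual_seed(0)
mu, sigma = 0.5, 1.3                            # made-up values of mu_phi(x), sigma_phi(x)
z = torch.normal(mu, sigma, size=(1_000_000,))  # samples z ~ q_phi(z | x)

# Coefficients multiplying grad_phi mu_phi(x) and grad_phi sigma_phi(x) in the expectation above.
coef_mu = (z * (z - mu) / sigma**2).mean()
coef_sigma = (z * ((z - mu)**2 / sigma**3 - 1 / sigma)).mean()

print(coef_mu.item(), coef_sigma.item())  # approximately 1 and 0
```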