VAE for MNIST Image Classification

  1. VAE for MNIST Image Classification
  2. Breakdown of the Math
  3. Cheating with the 'conditional' VAE

In this section, we go through some practical code examples of variational autoencoders (VAEs) applied to MNIST images.

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.metrics.cluster import normalized_mutual_info_score

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    
def plot_reconstruction(model, n=24):
    x,_ = next(iter(data_loader))
    x = x[:n,:,:,:].to(device)
    try:
        # some models defined later in the notebook return an extra output (log_p)
        out, _, _, log_p = model(x.view(-1, image_size))
    except ValueError:
        # plain VAE: forward returns (reconstruction, mu, log_var)
        out, _, _ = model(x.view(-1, image_size))
    x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
    out_grid = torchvision.utils.make_grid(x_concat).cpu().data
    show(out_grid)

def plot_generation(model, n=24):
    with torch.no_grad():
        z = torch.randn(n, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)

    out_grid = torchvision.utils.make_grid(out).cpu()
    show(out_grid)

def plot_conditional_generation(model, n=8, fix_number=None):
    with torch.no_grad():
        matrix = np.zeros((n,n_classes))
        matrix[:,0] = 1

        if fix_number is None:
            final = matrix[:]
            for i in range(1,n_classes):
                final = np.vstack((final,np.roll(matrix,i)))
            #z = torch.randn(8*n_classes, z_dim).to(device)
            z = torch.randn(n, z_dim)
            z = z.repeat(n_classes,1).to(device)
            y_onehot = torch.tensor(final).type(torch.FloatTensor).to(device)
            out = model.decode(z,y_onehot).view(-1, 1, 28, 28)
        else:
            z = torch.randn(n, z_dim).to(device)
            y_onehot = torch.tensor(np.roll(matrix, fix_number)).type(torch.FloatTensor).to(device)
            out = model.decode(z,y_onehot).view(-1, 1, 28, 28)

    out_grid = torchvision.utils.make_grid(out).cpu()
    show(out_grid)
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create a directory if not exists
sample_dir = 'samples'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
data_dir = 'data'
# MNIST dataset
dataset = torchvision.datasets.MNIST(root=data_dir,
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=128, 
                                          shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=transforms.ToTensor()),
    batch_size=10, shuffle=False)

Breakdown of the Math

Consider a latent variable model with a data variable $x \in \mathcal{X}$ and a latent variable $z \in \mathcal{Z}$, $p(z,x) = p(z)p_\theta(x|z)$. Given the data $x_1,\dots,x_n$, we want to train the model by maximizing the marginal log-likelihood:

$$
\mathcal{L} = \mathbf{E}_{p_d(x)}\left[\log p_\theta(x)\right] = \mathbf{E}_{p_d(x)}\left[\log \int_{\mathcal{Z}} p_\theta(x|z)\,p(z)\,dz\right],
$$

where $p_d$ denotes the empirical distribution of $X$: $p_d(x) = \frac{1}{n}\sum_{i=1}^n \delta_{x_i}(x)$.

To avoid the (often) difficult computation of the integral above, the idea behind variational methods is to instead maximize a lower bound to the log-likelihood:

$$
\mathcal{L} \geq L\big(p_\theta(x|z), q(z|x)\big) = \mathbf{E}_{p_d(x)}\Big[\mathbf{E}_{q(z|x)}\left[\log p_\theta(x|z)\right] - \mathrm{KL}\big(q(z|x)\,\|\,p(z)\big)\Big].
$$

Any choice of $q(z|x)$ gives a valid lower bound. Variational autoencoders replace the variational posterior $q(z|x)$ by an inference network $q_\phi(z|x)$ that is trained together with $p_\theta(x|z)$ to jointly maximize $L(p_\theta, q_\phi)$.

The variational posterior $q_\phi(z|x)$ is also called the encoder, and the generative model $p_\theta(x|z)$ the decoder or generator.

The first term, $\mathbf{E}_{q(z|x)}\left[\log p_\theta(x|z)\right]$, is the negative reconstruction error. Indeed, under a Gaussian assumption, i.e. $p_\theta(x|z) = \mathcal{N}(\mu_\theta(z), I)$, the term $\log p_\theta(x|z)$ reduces, up to an additive constant, to $-\frac{1}{2}\|x - \mu_\theta(z)\|^2$, which is why the squared reconstruction error is often used in practice. The term $\mathrm{KL}\big(q(z|x)\,\|\,p(z)\big)$ can be seen as a regularization term, matching the variational posterior $q_\phi(z|x)$ to the prior $p(z) = \mathcal{N}(0, I)$.
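
In the implementation below, the encoder outputs the mean $\mu$ and the log-variance $\log \sigma^2$ of a diagonal Gaussian $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$. In that case the KL term has a closed form (see Appendix B of the VAE paper), which is exactly what the training loop below computes:

$$
\mathrm{KL}\big(\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))\,\|\,\mathcal{N}(0, I)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right),
$$

where $d$ is the dimension of the latent space.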

Variational autoencoders were introduced by Kingma and Welling (2013); see also (Doersch, 2016) for a tutorial.

There are various examples of VAEs in PyTorch available here or here. The code below is taken from the latter source.


[figure: vae]


# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
learning_rate = 1e-3

# VAE model
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)    # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)         # mean of q(z|x)
        self.fc3 = nn.Linear(h_dim, z_dim)         # log-variance of q(z|x)
        self.fc4 = nn.Linear(z_dim, h_dim)         # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)    # decoder output (pixel probabilities)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

For the reconstruction loss, we take binary cross-entropy instead of MSE. The code below is still from the PyTorch tutorial (with minor modifications to avoid warnings!).
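
If one instead wanted the Gaussian (squared-error) reconstruction term discussed in the math section, the binary cross-entropy line inside the training loop could be swapped for something like the following sketch:

# hypothetical alternative: squared-error reconstruction loss (Gaussian decoder assumption)
reconst_loss = F.mse_loss(x_reconst, x, reduction='sum')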

# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence between Gaussians, see Appendix B in VAE paper or (Doersch, 2016):
        # https://arxiv.org/abs/1606.05908
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item()/len(x), kl_div.item()/len(x)))
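
The save_image utility imported at the top and the samples directory created earlier are not used in this minimal loop; one possible way to put them to use, sketched below, is to write a grid of generated digits to disk once training is done:

# sketch: save a grid of 64 generated digits to the 'samples' directory
with torch.no_grad():
    z = torch.randn(64, z_dim).to(device)
    out = model.decode(z).view(-1, 1, 28, 28)
    save_image(out, os.path.join(sample_dir, 'sampled.png'))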

Let's see how our network reconstructs a batch of digits. We display pairs of original digits and their reconstructed versions.

plot_reconstruction(model)

Let's see now how our network generates new samples.

plot_generation(model)

Not great, but we did not train our network for long... That being said, we have no control over the generated digits. In the rest of this notebook, we explore ways to generate zeroes, ones, twos and so on.

As a by-product, we show how our VAE will allow us to do clustering thanks to the Gumbel VAE described below. But before that, we start by cheating a little bit...

Cheating with the 'conditional' VAE

We will first use the labels here (like what we did in the course with the conditional GAN). The idea is to slightly modify the architecture above by feeding a one-hot version of the label to the decoder, in addition to the code computed by the encoder.

First, code a function transforming a label into its one-hot encoding. This function will be used in the training loop (not in the architecture of the neural network!).

n_classes = 10
def l_2_onehot(labels,nb_digits=n_classes):
    # take labels (from the dataloader) and return labels onehot-encoded
    #
    # your code here
    #
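
One possible way to complete this function, as a sketch (it assumes labels is the LongTensor of digit labels returned by the MNIST data loader):

n_classes = 10
def l_2_onehot(labels, nb_digits=n_classes):
    # labels: LongTensor of shape (batch_size,)
    # returns a FloatTensor of shape (batch_size, nb_digits)
    return F.one_hot(labels, num_classes=nb_digits).float()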

You can test it on a batch.

(x,labels) = next(iter(data_loader))
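
For instance, with the sketch above:

print(labels[:4])
print(l_2_onehot(labels)[:4])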

Now modify the architecture of the VAE so that the decoder takes as input the random code concatenated with the one-hot encoding of the label.

n_classes = 10

class VAE_Cond(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20, n_classes = 10):
        super(VAE_Cond, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        #
        # your code here
        #
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, l_onehot):
        #
        # your code here / use torch.cat 
        #      
    
    def forward(self, x, l_onehot):
        #
        # your code here / use F.gumbel_softmax
        #
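
A possible completion of this class, given as a sketch: the decoder input size becomes z_dim + n_classes, decode concatenates the latent code with the one-hot label (via torch.cat), and forward simply passes the provided one-hot label to the decoder. Since the true labels are available here, F.gumbel_softmax is not needed yet; it only becomes relevant in the label-free variant discussed later.

class VAE_Cond(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20, n_classes=10):
        super(VAE_Cond, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)              # mean of q(z|x)
        self.fc3 = nn.Linear(h_dim, z_dim)              # log-variance of q(z|x)
        self.fc4 = nn.Linear(z_dim + n_classes, h_dim)  # decoder also sees the one-hot label
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, l_onehot):
        # concatenate the latent code with the one-hot label before decoding
        h = F.relu(self.fc4(torch.cat([z, l_onehot], dim=1)))
        return torch.sigmoid(self.fc5(h))

    def forward(self, x, l_onehot):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z, l_onehot)
        return x_reconst, mu, log_var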

Test your new model on a batch:

model_C = VAE_Cond().to(device)
x = x.to(device).view(-1, image_size)
l_onehot = l_2_onehot(labels)
l_onehot = l_onehot.to(device)
model_C(x, l_onehot)

Now you can modify the training loop of your network. The beta parameter lets you scale the KL term in the loss, as explained in the $\beta$-VAE paper (see formula (4) in the paper).
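
In the notation of the math section, the loss optimized below corresponds to the negative of the $\beta$-weighted bound

$$
L_\beta(p_\theta, q_\phi) = \mathbf{E}_{q_\phi(z|x)}\left[\log p_\theta(x|z)\right] - \beta\,\mathrm{KL}\big(q_\phi(z|x)\,\|\,p(z)\big),
$$

which reduces to the usual lower bound when $\beta = 1$.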

def train_C(model, data_loader=data_loader,num_epochs=num_epochs, beta=10., verbose=True):
    nmi_scores = []
    model.train(True)
    for epoch in range(num_epochs):
        for i, (x, labels) in enumerate(data_loader):
            # Forward pass
            x = x.to(device).view(-1, image_size)
            #
            # your code here
            #
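            # One possible completion, as a sketch (it assumes the l_2_onehot helper defined above):
            l_onehot = l_2_onehot(labels).to(device)
            x_reconst, mu, log_var = model(x, l_onehot)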
            
            
            reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
            kl_div =  - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            
            # Backprop and optimize
            loss = reconst_loss + beta*kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if verbose:
                if (i+1) % 10 == 0:
                    print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                           .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item()/len(x),
                                   kl_div.item()/len(x)))

model_C = VAE_Cond().to(device)
optimizer = torch.optim.Adam(model_C.parameters(), lr=learning_rate)
train_C(model_C,num_epochs=15,verbose=True)

plot_conditional_generation(model_C, n=8)

Here you should get nice results. Now we will avoid the use of the labels...