🎉 Introduction to Variational Autoencoders (VAEs) with PyTorch

Image credit: Unsplash
pip install pytorch-pretrained-biggan
pip install Pillow libsixel-python
import torch
import random

import numpy as np
import matplotlib.pyplot as plt

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, transforms
pip install matplotlib tqdm
pip install scipy
from pytorch_pretrained_biggan import one_hot_from_names
from tqdm.notebook import tqdm, trange
# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

import ipywidgets as widgets
from ipywidgets import FloatSlider, IntSlider, HBox, Layout, VBox
from ipywidgets import interactive_output, Dropdown

%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")
!pip3 install vibecheck datatops

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_dl",
            "user_key": "f379rz8y",
        },
    ).render()


feedback_prefix = "W2D4_T1"
# Helper functions

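# Use the GPU when one is available; fall back to the CPU otherwise
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")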

def image_moments(image_batches, n_batches=None):
    """
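    Compute the mean and covariance of all pixels over batches of images

    Args:
        image_batches: iterable
            Iterable of image tensors, each of shape [batch, ...]
        n_batches: int, optional
            Number of batches (used only for the progress bar)

    Returns:
        m1: torch.tensor
            Per-pixel mean, shape [p], where p is the number of pixels
        cov: torch.tensor
            Pixel covariance matrix, shape [p, p]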
    """
    m1, m2 = torch.zeros((), device=DEVICE), torch.zeros((), device=DEVICE)
    n = 0
    for im in tqdm(image_batches, total=n_batches, leave=False, desc="Computing pixel mean and covariance"):
        im = im.to(DEVICE)
        b = im.size()[0]
        im = im.view(b, -1)
        m1 = m1 + im.sum(dim=0)
        m2 = m2 + (im.view(b, -1, 1) * im.view(b, 1, -1)).sum(dim=0)
        n += b
    m1, m2 = m1/n, m2/n
    cov = m2 - m1.view(-1, 1) * m1.view(1, -1)
    return m1.cpu(), cov.cpu()
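image_moments never stores the whole dataset: it accumulates the sums of x and of x x^T batch by batch and only forms the covariance at the end, via Cov[x] = E[x x^T] - E[x] E[x]^T. A minimal CPU-only sketch of the same accumulation, checked against `torch.cov`; the name `streaming_moments` and the toy data are illustrative assumptions, and `im.T @ im` is swapped in as an equivalent, cheaper way to sum the per-image outer products.

```python
import torch

def streaming_moments(batches):
    """Accumulate E[x] and E[x x^T] over batches, then Cov = E[x x^T] - E[x] E[x]^T."""
    m1, m2, n = 0.0, 0.0, 0
    for im in batches:
        b = im.size(0)
        im = im.reshape(b, -1)        # flatten each image to a vector of p pixels
        m1 = m1 + im.sum(dim=0)
        m2 = m2 + im.T @ im           # same as summing the per-image outer products
        n += b
    m1, m2 = m1 / n, m2 / n
    return m1, m2 - torch.outer(m1, m1)

torch.manual_seed(0)
data_demo = torch.rand(100, 1, 3, 3, dtype=torch.float64)     # 100 tiny 3x3 "images"
mean_demo, cov_demo = streaming_moments(data_demo.split(32))  # batches of 32, 32, 32 and 4
flat = data_demo.reshape(100, -1)
ref_cov = torch.cov(flat.T, correction=0)                     # population covariance
print(torch.allclose(mean_demo, flat.mean(dim=0)), torch.allclose(cov_demo, ref_cov))  # True True
```

The batch sizes do not need to be equal, because the sums are divided by the total count n only once, after the loop.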
def interpolate(A, B, num_interps):
    """
    Linearly interpolate between two arrays A and B.
    In this notebook it is used to interpolate both between the class
    vectors of the two selected categories and between their latent vectors.

    Args:
        A: np.ndarray
            Start point of the interpolation
        B: np.ndarray
            End point of the interpolation (same shape as A)
        num_interps: int
            Number of interpolation steps, including A and B

    Returns:
        np.ndarray of shape (num_interps, *A.shape)
    """
    if A.shape != B.shape:
        raise ValueError("A and B must have the same shape")
    alphas = np.linspace(0, 1, num_interps)
    return np.array([(1-a) * A + a * B for a in alphas])
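Each interpolation step is stacked along a new leading axis, so the result has shape (num_interps, *A.shape), with the first step equal to A and the last equal to B. A small check with toy vectors; `lerp` is a hypothetical copy of `interpolate` so the cell runs on its own, and the values are arbitrary.

```python
import numpy as np

def lerp(A, B, num_interps):
    # Same linear interpolation as `interpolate`: (1 - a) * A + a * B for a in [0, 1]
    alphas = np.linspace(0, 1, num_interps)
    return np.array([(1 - a) * A + a * B for a in alphas])

z_start = np.zeros((1, 4))
z_end = np.ones((1, 4))
steps = lerp(z_start, z_end, 5)
print(steps.shape)        # (5, 1, 4)
print(steps[:, 0, 0])     # [0.   0.25 0.5  0.75 1.  ]
```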
def kl_q_p(zs, phi):
    """
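    Monte Carlo estimate of KL(q||p), where q(z|phi) = N(mu, sigma^2 I)
    and p(z) = N(0, I).

    Args:
        zs: torch.tensor
            Samples z drawn from q, shape [b, n, k].
        phi: torch.tensor
            Parameters of q, shape [b, k+1]: the first k columns are mu,
            the last column is log(sigma).

    Returns:
        torch.tensor
            KL divergence estimate, averaged over the batch and samples.
    """
    b, n, k = zs.size()
    mu_q, log_sig_q = phi[:, :-1], phi[:, -1]

    # log p(z) for p(z) = N(0, I); the -0.5*log(2*pi) constant is dropped
    # because the same constant appears in log q and cancels in the difference
    log_p = -0.5 * (zs**2)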

    # Compute log probabilities for q(z|phi)
    log_q = -0.5 * ((zs - mu_q.view(b, 1, k))**2 / log_sig_q.exp().view(b, 1, 1)**2) - log_sig_q.view(b, 1, -1)

    # Compute KL divergence
    return (log_q - log_p).sum(dim=-1).mean()
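For an isotropic Gaussian q = N(mu, sigma^2 I) in k dimensions and p = N(0, I), the KL divergence also has a closed form, KL = 0.5 (k sigma^2 + ||mu||^2 - k) - k log sigma, so the sample-based estimate in kl_q_p can be sanity-checked against it. A minimal sketch, assuming the same phi layout as kl_q_p (first k columns mu, last column log sigma); the function names and parameter values are illustrative, and the Monte Carlo part re-implements the kl_q_p formula so the cell runs on its own.

```python
import torch

def kl_closed_form(phi):
    """Exact KL(q || p) for q = N(mu, sigma^2 I), p = N(0, I), averaged over the batch."""
    mu, log_sig = phi[:, :-1], phi[:, -1]
    k = mu.size(1)
    kl = 0.5 * (k * torch.exp(2 * log_sig) + (mu ** 2).sum(dim=1) - k) - k * log_sig
    return kl.mean()

def kl_monte_carlo(phi, n=20000):
    """Sample-based estimate using the same log-density terms as kl_q_p."""
    b = phi.size(0)
    mu, log_sig = phi[:, :-1], phi[:, -1]
    k = mu.size(1)
    sig = log_sig.exp().view(b, 1, 1)
    z = mu.view(b, 1, k) + sig * torch.randn(b, n, k)   # z ~ q via z = mu + sigma * eps
    log_p = -0.5 * z ** 2
    log_q = -0.5 * ((z - mu.view(b, 1, k)) / sig) ** 2 - log_sig.view(b, 1, 1)
    return (log_q - log_p).sum(dim=-1).mean()

torch.manual_seed(0)
phi_kl = torch.tensor([[0.5, -1.0, 0.2], [0.0, 0.0, -0.3]])   # b = 2, k = 2
exact = kl_closed_form(phi_kl)                                # about 0.4328
approx = kl_monte_carlo(phi_kl)
print(f"closed form: {exact.item():.4f}   Monte Carlo: {approx.item():.4f}")
```

The two numbers should agree to within Monte Carlo noise; with fewer samples per row, as in a training loop, the estimate is noisier but still unbiased.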
def log_p_x(x, mu_xs, sig_x):
    """
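    Given [batch, ...] input x and [batch, n, ...] reconstructions, compute the
    pixel-wise Gaussian log probability, summed over pixel dimensions and averaged
    over the batch and the n samples. The constant -0.5*log(2*pi) per pixel is omitted.

    Args:
        x: torch.tensor
            Input data, shape [b, ...]
        mu_xs: torch.tensor
            Means of the n reconstructions, shape [b, n, ...]
        sig_x: torch.tensor
            Standard deviation of the pixel noise (not its log)

    Returns:
        Mean log probability over batch and samples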
    """
    b, n = mu_xs.size()[:2]
    # Flatten out pixels and add a singleton dimension [1] so that x will be
    # implicitly expanded when combined with mu_xs
    x = x.reshape(b, 1, -1)
    _, _, p = x.size()
    squared_error = (x - mu_xs.view(b, n, -1))**2 / (2 * sig_x**2)

    # Size of squared_error is [b, n, p]. log prob is by definition sum over [p].
    # Expected value requires mean over [n].
    # Handling different size batches requires mean over [b].
    return -(squared_error + torch.log(sig_x)).sum(dim=2).mean(dim=(0, 1))
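Because log_p_x drops the normalizing constant, its value differs from the full Gaussian log density by exactly 0.5 log(2 pi) per pixel. A minimal sketch that checks this against `torch.distributions.Normal`; `log_p_x_copy` re-implements the function above so the cell runs on its own, and the tensor shapes and sigma value are arbitrary assumptions.

```python
import math
import torch
from torch.distributions import Normal

def log_p_x_copy(x, mu_xs, sig_x):
    # Same computation as log_p_x: sum over pixels, mean over batch and samples
    b, n = mu_xs.shape[:2]
    x = x.reshape(b, 1, -1)
    sq = (x - mu_xs.reshape(b, n, -1)) ** 2 / (2 * sig_x ** 2)
    return -(sq + torch.log(sig_x)).sum(dim=2).mean(dim=(0, 1))

torch.manual_seed(0)
b, n, c, h, w = 3, 5, 1, 4, 4
x_lp = torch.rand(b, c, h, w)
mu_lp = torch.rand(b, n, c, h, w)
sig_lp = torch.tensor(0.3)

ours = log_p_x_copy(x_lp, mu_lp, sig_lp)
# Full Gaussian log density, summed over (c, h, w) and averaged over (b, n)
ref = Normal(mu_lp, sig_lp).log_prob(x_lp.unsqueeze(1)).sum(dim=(2, 3, 4)).mean()
n_pix = c * h * w
print((ours - ref).item(), 0.5 * math.log(2 * math.pi) * n_pix)
```

Dropping the constant does not change gradients, which is why it is safe to omit during training.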
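def pca_encoder_decoder(mu, cov, k):
    """
    Build PCA encode/decode functions from the data mean mu [p] and covariance cov [p, p].
    pca_encode maps [b, p] data to [b, k] codes with unit variance (whitened);
    pca_decode maps codes back to data space.
    """
    mu = mu.view(1, -1)
    # Randomized rank-k SVD; for a symmetric PSD cov, u and v hold the top-k eigenvectors
    u, s, v = torch.svd_lowrank(cov, q=k)
    W_encode = v / torch.sqrt(s)  # project onto the top-k components, rescale to unit variance
    W_decode = u * torch.sqrt(s)  # undo the rescaling, map back to data space

    def pca_encode(x):
        return (x.view(-1, mu.numel()) - mu) @ W_encode

    def pca_decode(h):
        return (h @ W_decode.T) + mu

    return pca_encode, pca_decode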
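Dividing by sqrt(s) when encoding makes the k latent codes uncorrelated with unit variance, and multiplying by sqrt(s) when decoding undoes that, so data that truly lies in a k-dimensional subspace is reconstructed exactly. A minimal sketch on synthetic data, assuming float64 tensors for tight tolerances; `pca_codec_demo` is a hypothetical copy of `pca_encoder_decoder` so the cell runs on its own, and the data sizes are arbitrary.

```python
import torch

def pca_codec_demo(mu, cov, k):
    """Hypothetical copy of pca_encoder_decoder, so this cell runs on its own."""
    mu = mu.view(1, -1)
    u, s, v = torch.svd_lowrank(cov, q=k)
    W_encode = v / torch.sqrt(s)   # project onto the top-k components, rescale to unit variance
    W_decode = u * torch.sqrt(s)   # undo the rescaling, map back to data space

    def encode(x):
        return (x.view(-1, mu.numel()) - mu) @ W_encode

    def decode(h):
        return h @ W_decode.T + mu

    return encode, decode

torch.manual_seed(0)
n_pts, d_dim, k_dim = 5000, 5, 2
A_mix = torch.randn(d_dim, k_dim, dtype=torch.float64)
offset = torch.randn(d_dim, dtype=torch.float64)
# Synthetic data lying exactly in a k_dim-dimensional affine subspace of R^d_dim
x_pca = torch.randn(n_pts, k_dim, dtype=torch.float64) @ A_mix.T + offset

mu_pca = x_pca.mean(dim=0)
cov_pca = torch.cov(x_pca.T, correction=0)   # population covariance, as in image_moments
encode, decode = pca_codec_demo(mu_pca, cov_pca, k_dim)

h_pca = encode(x_pca)
x_hat = decode(h_pca)
print("max reconstruction error:", (x_pca - x_hat).abs().max().item())
print("latent covariance:\n", torch.cov(h_pca.T, correction=0))
```

On real images the data is not exactly low-rank, so the reconstruction error is non-zero and shrinks as k grows.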
def cout(x, layer):
    """
    Unnecessarily complicated but complete way to
    calculate the output depth, height and width size for a Conv2D layer

    Args:
        x: tuple
            Input size (depth, height, width)
        layer: nn.Conv2d
            The Conv2D layer

    Returns:
        Tuple of out-depth/out-height and out-width 
        Output shape as given in [Ref]
        Ref: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    """
    assert isinstance(layer, nn.Conv2d)
    p = layer.padding if isinstance(layer.padding, tuple) else (layer.padding, )
    k = layer.kernel_size if isinstance(layer.kernel_size, tuple) else (layer.kernel_size, )
    d = layer.dilation if isinstance(layer.dilation, tuple) else (layer.dilation, )
    s = layer.stride if isinstance(layer.stride, tuple) else (layer.stride, )
    in_depth, in_height, in_width = x
    out_depth = layer.out_channels
    out_height = 1 + (in_height + 2 * p[0] - (k[0] - 1) * d[0] - 1) // s[0]
    out_width = 1 + (in_width + 2 * p[-1] - (k[-1] - 1) * d[-1] - 1) // s[-1]
    return (out_depth, out_height, out_width)
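The cout formula is the output-size rule from the `nn.Conv2d` documentation. For a 28x28 input with kernel 5, stride 2 and padding 1 it gives 1 + floor((28 + 2 - 4 - 1) / 2) = 13, so `cout((1, 28, 28), layer)` for such a layer should return (8, 13, 13) when the layer has 8 output channels. A minimal sketch that checks the per-axis formula against an actual forward pass; `conv_out_len` and the layer settings are illustrative assumptions.

```python
import torch
import torch.nn as nn

def conv_out_len(n, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, using the formula from the nn.Conv2d docs."""
    return 1 + (n + 2 * padding - dilation * (kernel - 1) - 1) // stride

demo_conv = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=2, padding=1)
x_conv = torch.zeros(1, 1, 28, 28)   # a single MNIST-sized grayscale image
side = conv_out_len(28, kernel=5, stride=2, padding=1)
predicted = (demo_conv.out_channels, side, side)
actual = tuple(demo_conv(x_conv).shape[1:])
print(predicted, actual)   # (8, 13, 13) (8, 13, 13)
```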
    
# @title Plotting functions

def plot_gen_samples_ppca(therm1, therm2, therm_data_sim):
    """
    Plotting generated samples

    Args:
        therm1: list
            Thermometer 1
        therm2: list
            Thermometer 2
        therm_data_sim: list
            Samples simulated (drawn) from the pPCA model: [thermometer 1 values, thermometer 2 values]

    Returns:
        Nothing
    """
    plt.plot(therm1, therm2, '.', c='c', label='training data')
    plt.plot(therm_data_sim[0], therm_data_sim[1], '.', c='m', label='"generated" data')
    plt.axis('equal')
    plt.xlabel(r'Thermometer 1 ($^\circ$C)')
    plt.ylabel(r'Thermometer 2 ($^\circ$C)')
    plt.legend()
    plt.show()


def plot_linear_ae(lin_losses):
    """
    Plotting linear autoencoder

    Args:
        lin_losses: list
            Per-batch MSE losses recorded while training the linear autoencoder

    Returns:
        Nothing
    """
    plt.figure()
    plt.plot(lin_losses)
    plt.ylim([0, 2 * torch.as_tensor(lin_losses).median()])
    plt.xlabel('Training batch')
    plt.ylabel('MSE Loss')
    plt.show()


def plot_conv_ae(lin_losses, conv_losses):
    """
    Plotting convolutional autoencoder

    Args:
        lin_losses: list
            Per-batch MSE losses recorded while training the linear autoencoder
        conv_losses: list
            Per-batch MSE losses recorded while training the convolutional autoencoder

    Returns:
        Nothing
    """
    plt.figure()
    plt.plot(lin_losses)
    plt.plot(conv_losses)
    plt.legend(['Lin AE', 'Conv AE'])
    plt.xlabel('Training batch')
    plt.ylabel('MSE Loss')
    plt.ylim([0,
              2 * max(torch.as_tensor(conv_losses).median(),
                      torch.as_tensor(lin_losses).median())])
    plt.show()


def plot_images(images, h=3, w=3, plt_title=''):
    """
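    Helper function to plot a grid of images

    Args:
        images: torch.tensor
            Images, shape [N, C, H, W] with N >= h * w
        h: int
            Number of rows in the grid
        w: int
            Number of columns in the grid
        plt_title: string
            Plot title

    Returns:
        Nothing
    """
    # figsize is (width, height), i.e. (columns, rows)
    plt.figure(figsize=(w * 2, h * 2))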
    plt.suptitle(plt_title, y=1.03)
    for i in range(h * w):
        plt.subplot(h, w, i + 1)
        plot_torch_image(images[i])
    plt.axis('off')
    plt.show()


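def plot_phi(phi, num=4, zs=None):
    """
    Scatter samples z ~ q(z|phi) and draw the 2-sigma circle of each q

    Args:
        phi: torch.tensor
            Parameters of q, shape [b, 3]; each row is (mu_1, mu_2, log_sigma)
        num: int
            Number of panels (batch entries) to plot
        zs: torch.tensor, optional
            Samples from q, shape [b, n, 2]. If None, the notebook-level
            variable `zs` is used, as in the original version of this function.
    """
    if zs is None:
        zs = globals()["zs"]
    plt.figure(figsize=(12, 3))
    for i in range(num):
        plt.subplot(1, num, i + 1)
        plt.scatter(zs[i, :, 0], zs[i, :, 1], marker='.')
        th = torch.linspace(0, 2 * np.pi, 100)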
        x, y = torch.cos(th), torch.sin(th)
        # Draw 2-sigma contours
        plt.plot(
            2 * x * phi[i, 2].exp().item() + phi[i, 0].item(),
            2 * y * phi[i, 2].exp().item() + phi[i, 1].item()
        )
        plt.xlim(-5, 5)
        plt.ylim(-5, 5)
        plt.grid()
        plt.axis('equal')
    plt.suptitle('If rsample() is correct, then most but not all points should lie in the circles')
    plt.show()


def plot_torch_image(image, ax=None):
    """
    Helper function to plot torch image

    Args:
        image: torch.tensor
            Image
        ax: plt object
            If None, plt.gca()

    Returns:
        Nothing
    """
    ax = ax if ax is not None else plt.gca()
    c, h, w = image.size()
    if c == 1:
        cm = 'gray'
    else:
        cm = None

    # Torch images have shape (channels, height, width)
    # but matplotlib expects
    # (height, width, channels) or just
    # (height, width) when grayscale
    im_plt = torch.clip(image.detach().cpu().permute(1, 2, 0).squeeze(), 0.0, 1.0)
    ax.imshow(im_plt, cmap=cm)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
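The plot_phi title says that if rsample() is correct, most but not all points should lie in the circles. For a 2-D isotropic Gaussian the squared distance from the mean, in units of sigma^2, follows a chi-squared distribution with 2 degrees of freedom, so the expected fraction inside the 2-sigma circle is 1 - exp(-2), about 0.865. A minimal sketch that checks this with reparameterized samples z = mu + sigma * eps; `rsample_isotropic` is a hypothetical stand-in, not the notebook's rsample(), and it assumes each row of phi is (mu_1, mu_2, log_sigma) as plot_phi reads it.

```python
import math
import torch

def rsample_isotropic(phi, n):
    """Hypothetical helper: n reparameterized samples per row of phi = (mu_1, mu_2, log_sigma)."""
    b = phi.size(0)
    mu, log_sig = phi[:, :-1], phi[:, -1]
    eps = torch.randn(b, n, mu.size(1))                          # eps ~ N(0, I)
    return mu.unsqueeze(1) + log_sig.exp().view(b, 1, 1) * eps   # z = mu + sigma * eps

torch.manual_seed(0)
phi_demo = torch.tensor([[1.0, -2.0, 0.0], [0.0, 0.5, -0.5]])
z_demo = rsample_isotropic(phi_demo, n=10000)

# Distance of each sample from its mean, in units of sigma
dist = (z_demo - phi_demo[:, None, :2]).norm(dim=-1) / phi_demo[:, None, 2].exp()
frac_inside = (dist < 2).float().mean().item()
print(f"empirical fraction inside 2 sigma: {frac_inside:.3f}")
print(f"theory, 1 - exp(-2):               {1 - math.exp(-2):.3f}")
```

The variable names end in `_demo` so that this cell does not overwrite the notebook-level `zs` and `phi` that plot_phi may read.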
# @title Set random seed

# @markdown Executing `set_seed(seed=seed)` sets the random seed.

# For deep learning, it's critical to set the random seed so that students can 
# have a baseline to compare their results to expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in exercises to ensure reproducibility.
import random
import torch
import numpy as np

def set_seed(seed=None, seed_torch=True):
    """
    Function that controls randomness. NumPy and random modules must be imported.

    Args:
        seed : int, optional
            A non-negative integer that defines the random state. Default is `None`.
        seed_torch : bool, optional
            If `True`, sets the random seed for PyTorch tensors. Default is `True`.

    Returns:
        None
    """
    if seed is None:
        seed = np.random.choice(2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    if seed_torch:
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    print(f'Random seed {seed} has been set.')

# In case that `DataLoader` is used
def seed_worker(worker_id):
    """
    Re-seed NumPy and `random` inside each DataLoader worker so that
    multi-process data loading is reproducible.

    Args:
        worker_id : int
            ID of subprocess to seed. `0` means that data will be loaded in 
            the main process.
            Refer to: https://pytorch.org/docs/stable/data.html#data-loading-randomness

    Returns:
        None
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
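seed_worker is meant to be passed to a DataLoader as `worker_init_fn`, together with a seeded `torch.Generator` that fixes the shuffling order, following the pattern in the PyTorch reproducibility notes. A minimal sketch with a toy dataset; `make_loader` is a hypothetical helper, seed_worker is repeated so the cell runs on its own, and `num_workers=0` keeps the example single-process (with more workers, seed_worker runs once in each worker).

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Copy of the helper above so this cell runs on its own
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(seed, num_workers=0):
    """Hypothetical helper: a shuffling DataLoader whose order is fixed by `seed`."""
    g = torch.Generator()
    g.manual_seed(seed)
    dataset = TensorDataset(torch.arange(10))
    return DataLoader(dataset, batch_size=4, shuffle=True,
                      num_workers=num_workers,   # with > 0 workers, seed_worker runs once per worker
                      worker_init_fn=seed_worker, generator=g)

order_1 = [batch[0].tolist() for batch in make_loader(2021)]
order_2 = [batch[0].tolist() for batch in make_loader(2021)]
print(order_1 == order_2)   # True
```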
def set_device():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        print("WARNING: Not running with GPU")
    else:
        print("GPU is enabled")
    return device
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()
Random seed 2021 has been set.
GPU is enabled
# @title Download `wordnet` dataset

"""
NLTK Download Instructions:
To download `wordnet` using NLTK, run the following:

import nltk
nltk.download('wordnet')
"""

import os
import requests
import zipfile

# Set the environment variable for NLTK data directory
os.environ['NLTK_DATA'] = 'nltk_data/'

# Filenames and corresponding download URLs
fnames = ['wordnet.zip', 'omw-1.4.zip']
urls = ['https://osf.io/ekjxy/download', 'https://osf.io/kuwep/download']

# Download and extract each file
for fname, url in zip(fnames, urls):
    # Download the file
    response = requests.get(url, allow_redirects=True)
    response.raise_for_status()  # stop early if the download failed
    with open(fname, 'wb') as file:
        file.write(response.content)
    
    # Extract the contents of the zip file
    with zipfile.ZipFile(fname, 'r') as zip_ref:
        zip_ref.extractall('nltk_data/corpora')

print("WordNet dataset downloaded and extracted successfully.")
WordNet dataset downloaded and extracted successfully.
# @title Submit your feedback
content_review(f"{feedback_prefix}_Generative_Modeling_Video")
# @title Download BigGAN (a generative model) and a few standard image datasets

# Initially, the model could be downloaded directly via:
# biggan_model = BigGAN.from_pretrained('biggan-deep-256')

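import os
import requests
import torch

# URL for the BigGAN model
url = "https://osf.io/3yvhw/download"
# Filename for the downloaded model
fname = "biggan_deep_256"

# Download the BigGAN model only if it is not already on disk
if not os.path.exists(fname):
    response = requests.get(url, allow_redirects=True)
    response.raise_for_status()
    with open(fname, 'wb') as file:
        file.write(response.content)

# The file stores a whole pickled model (not just a state_dict), so
# weights_only=False is required; only do this for files you trust.
biggan_model = torch.load(fname, map_location="cpu", weights_only=False)

print("BigGAN model downloaded and loaded successfully.")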
BigGAN model downloaded and loaded successfully.
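Loading with `weights_only=False` unpickles arbitrary Python objects, which is why PyTorch warns about it for files you do not fully control. When you save your own models, a safer pattern is to store only the `state_dict` and load it with `weights_only=True`. A minimal sketch with a small `nn.Linear` stand-in (file name and layer sizes are arbitrary assumptions); it does not apply directly to the BigGAN file above, which holds a whole pickled model.

```python
import os
import tempfile
import torch
import torch.nn as nn

demo_model = nn.Linear(4, 2)                         # small stand-in, not BigGAN
demo_path = os.path.join(tempfile.mkdtemp(), "linear_state.pt")
torch.save(demo_model.state_dict(), demo_path)       # save tensors only, not the pickled module

demo_restored = nn.Linear(4, 2)                      # rebuild the architecture in code...
demo_state = torch.load(demo_path, map_location="cpu", weights_only=True)
demo_restored.load_state_dict(demo_state)            # ...then copy the weights in

x_demo = torch.randn(3, 4)
print(torch.equal(demo_model(x_demo), demo_restored(x_demo)))   # True
```

The trade-off is that the model class must be available in code when loading, whereas a fully pickled model carries its structure with it.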
!pip install nltk --quiet

# @title BigGAN Image Generator (updates may take a few seconds, please be patient)

from scipy.stats import truncnorm
import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import FloatSlider, Dropdown, VBox, interactive_output, Layout

def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1.0, seed=None):
    """
    Create a truncated noise vector.

    Args:
        batch_size (int): Number of samples to generate.
        dim_z (int): Dimension of the latent vector z.
        truncation (float): Truncation value to control the range of z.
        seed (int, optional): Seed for the random generator.

    Returns:
        np.ndarray: Array of shape (batch_size, dim_z) with truncated noise values.
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values


def sample_from_biggan(category, z_magnitude):
    """
    Sample images from BigGAN Image Generator.

    Args:
        category (str): Selected category label.
        z_magnitude (float): Magnitude of variation vector (truncation value).

    Returns:
        None
    """
    truncation = z_magnitude
    z = truncated_noise_sample(truncation=truncation, batch_size=4)
    y = one_hot_from_names(category, batch_size=4)

    # Convert to tensors
    z = torch.from_numpy(z).float()
    y = torch.from_numpy(y)

    # Move tensors and model to the device selected at the top of the notebook
    device = DEVICE
    z = z.to(device)
    y = y.to(device)
    biggan_model.to(device)

    # Generate images
    with torch.no_grad():
        output = biggan_model(z, y, truncation)

    # Process the output to [0, 1] range
    output = output.to('cpu')
    output = torch.clip(((output.detach().clone() + 1) / 2.0), 0, 1)

    # Plot the generated images
    fig, axes = plt.subplots(2, 2, figsize=(8, 8))
    axes = axes.flatten()
    for i in range(4):
        axes[i].imshow(output[i].squeeze().permute(1, 2, 0))
        axes[i].axis('off')
    plt.show()


# UI Widgets for interactive control
z_slider = FloatSlider(
    min=0.1, max=1.0, step=0.1, value=0.1,
    continuous_update=False,
    description='Truncation Value',
    style={'description_width': '100px'},
    layout=Layout(width='440px')
)

category_dropdown = Dropdown(
    options=['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee',
             'acoustic guitar', 'coffee mug', 'minibus', 'monitor'],
    value="German shepherd",
    description="Category:"
)

widgets_ui = VBox([category_dropdown, z_slider])

# Bind the UI widgets to the function
widgets_out = interactive_output(
    sample_from_biggan,
    {
        'z_magnitude': z_slider,
        'category': category_dropdown
    }
)

# Display the widgets and output
display(widgets_ui, widgets_out)

Interpolating images with BigGAN

# @markdown BigGAN Interpolation Widget (the updates may take a few seconds)

def interpolate_biggan(category_A,
                       category_B):
    """
    Interpolate between two BigGAN categories and plot the generated images.

    Args:
        category_A: string
            Name of the starting ImageNet category
        category_B: string
            Name of the ending ImageNet category

    Returns:
        Nothing
    """
    num_interps = 16

    # category_A = 'jellyfish' #@param ['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee', 'acoustic guitar', 'coffee mug', 'minibus', 'monitor']
    # z_magnitude_A = 0 #@param {type:"slider", min:-10, max:10, step:1}

    # category_B = 'German shepherd' #@param ['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee', 'acoustic guitar', 'coffee mug', 'minibus', 'monitor']
    # z_magnitude_B = 0 #@param {type:"slider", min:-10, max:10, step:1}


    def interpolate_and_shape(A, B, num_interps):
        """
        Linearly interpolate between A and B and reshape the result into a batch.

        This is applied both to the class vectors (interpolating the
        probability of each selected category) and to the latent noise
        vectors.

        Args:
            A: np.ndarray
                Starting array of shape (1, dim), e.g. a noise or one-hot class vector
            B: np.ndarray
                Ending array with the same shape as A
            num_interps: int
                Number of interpolation steps (images) to produce

        Returns:
            Interpolated np.ndarray of shape (num_interps, dim)
        """
        interps = interpolate(A, B, num_interps)
        return (interps.transpose(1, 0, *range(2, len(interps.shape))).reshape(num_interps, *interps.shape[2:]))

    # unit_vector = np.ones((1, 128))/np.sqrt(128)
    # z_A = z_magnitude_A * unit_vector
    # z_B = z_magnitude_B * unit_vector
    truncation = .4
    z_A = truncated_noise_sample(truncation=truncation, batch_size=1)
    z_B = truncated_noise_sample(truncation=truncation, batch_size=1)
    y_A = one_hot_from_names(category_A, batch_size=1)
    y_B = one_hot_from_names(category_B, batch_size=1)

    z_interp = interpolate_and_shape(z_A, z_B, num_interps)
    y_interp = interpolate_and_shape(y_A, y_B, num_interps)

    # Convert to tensor
    z_interp = torch.from_numpy(z_interp).float()
    y_interp = torch.from_numpy(y_interp).float()

    # Move to GPU
    z_interp = z_interp.to(DEVICE)
    y_interp = y_interp.to(DEVICE)
    biggan_model.to(DEVICE)

    with torch.no_grad():
        # Use the same truncation value that was used to sample the noise
        output = biggan_model(z_interp, y_interp, truncation)

    # Back to CPU
    output = output.to('cpu')

    # The output layer of BigGAN is a tanh, so the generated images lie in
    # [-1, 1]. Rescale them to [0, 1] for plotting; clipping only guards
    # against small numerical overshoots.
    output = torch.clip(((output.detach().clone() + 1) / 2.0), 0, 1)

    # Make grid and show generated samples
    output_grid = torchvision.utils.make_grid(output,
                                              nrow=min(4, output.shape[0]),
                                              padding=5)
    plt.axis('off')
    plt.imshow(output_grid.permute(1, 2, 0))
    plt.show()
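`interpolate_and_shape` relies on an `interpolate` helper that is not shown in this section. As a rough sketch of the same idea, assuming plain linear interpolation with both endpoints included (the notebook's helper may space its steps differently), the hypothetical `lerp_batch` below returns one row per interpolation step. Applied to one-hot class vectors it produces the class mixtures that BigGAN is conditioned on.

```python
import numpy as np


def lerp_batch(a, b, num_interps):
    """Hypothetical helper: `num_interps` evenly spaced points from `a` to `b` (both included), one per row."""
    a = np.asarray(a, dtype=np.float32).ravel()
    b = np.asarray(b, dtype=np.float32).ravel()
    alphas = np.linspace(0.0, 1.0, num_interps, dtype=np.float32)[:, None]  # (num_interps, 1)
    return (1.0 - alphas) * a + alphas * b  # broadcasts to (num_interps, dim)


# Interpolating two one-hot class vectors gives a sequence of class mixtures
y_a = np.eye(5, dtype=np.float32)[0]
y_b = np.eye(5, dtype=np.float32)[3]
mix = lerp_batch(y_a, y_b, 5)
print(mix[2])  # midpoint: weight 0.5 on class 0 and 0.5 on class 3
```

For Gaussian latent vectors, spherical interpolation is a common alternative; this sketch keeps to the linear scheme described in the docstring.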


# z_A_slider = IntSlider(min=-10, max=10, step=1, value=0,
#                        continuous_update=False, description='Z Magnitude A',
#                        layout=Layout(width='440px'),
#                        style={'description_width': 'initial'})

# z_B_slider = IntSlider(min=-10, max=10, step=1, value=0,
#                        continuous_update=False, description='Z Magnitude B',
#                        layout=Layout(width='440px'),
#                        style={'description_width': 'initial'})

category_A_dropdown = Dropdown(
    options=['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee',
             'acoustic guitar', 'coffee mug', 'minibus', 'monitor'],
    value="German shepherd",
    description="Category A: "
)

category_B_dropdown = Dropdown(
    options=['tench', 'magpie', 'jellyfish', 'German shepherd', 'bee',
             'acoustic guitar', 'coffee mug', 'minibus', 'monitor'],
    value="jellyfish",
    description="Category B: "
)



widgets_ui = VBox([HBox([category_A_dropdown]),
                   HBox([category_B_dropdown])])

widgets_out = interactive_output(interpolate_biggan,
                                 {'category_A': category_A_dropdown,
                                  # 'z_magnitude_A': z_A_slider,
                                  # 'z_magnitude_B': z_B_slider,
                                  'category_B': category_B_dropdown})

display(widgets_ui, widgets_out)

Latent Variable Models

def generate_data(n_samples, mean_of_temps, cov_of_temps, seed):
    """
    Generate random data, normally distributed

    Args:
        n_samples : int
            The number of samples to be generated
        mean_of_temps : numpy.ndarray
            1D array with the mean temperatures, shape (K,)
        cov_of_temps : numpy.ndarray
            2D array with the covariance, KxK
        seed : int
            Set random seed for the pseudo random generator

    Returns:
        therm1 : numpy.ndarray
            Thermometer 1
        therm2 : numpy.ndarray
            Thermometer 2
    """

    np.random.seed(seed)
    therm1, therm2 = np.random.multivariate_normal(mean_of_temps,
                                                   cov_of_temps,
                                                   n_samples).T
    return therm1, therm2


n_samples = 2000
mean_of_temps = np.array([25, 25])
cov_of_temps = np.array([[10, 5], [5, 10]])
therm1, therm2 = generate_data(n_samples, mean_of_temps, cov_of_temps, seed=SEED)

plt.plot(therm1, therm2, '.')
plt.axis('equal')
plt.xlabel('Thermometer 1 ($^\\circ$C)')
plt.ylabel('Thermometer 2 ($^\\circ$C)')
plt.show()


# @markdown Add first PC axes to the plot

plt.plot(therm1, therm2, '.')
plt.axis('equal')
plt.xlabel('Thermometer 1 ($^\\circ$C)')
plt.ylabel('Thermometer 2 ($^\\circ$C)')
plt.plot([plt.axis()[0], plt.axis()[1]],
         [plt.axis()[0], plt.axis()[1]])
plt.show()


# Project Data onto the principal component axes.
# We could have "learned" this from the data by applying PCA,
# but we "know" the value from the problem definition.
pc_axes = np.array([1.0, 1.0]) / np.sqrt(2.0)

# Thermometers data
therm_data = np.array([therm1, therm2])

# Zero center the data
therm_data_mean = np.mean(therm_data, 1)
therm_data_center = np.outer(therm_data_mean, np.ones(therm_data.shape[1]))
therm_data_zero_centered = therm_data - therm_data_center

# Calculate the variance of the projection on the PC axes
pc_projection = np.matmul(pc_axes, therm_data_zero_centered)
pc_axes_variance = np.var(pc_projection)

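# Calculate the residual variance (variance not accounted for by projection on the PC axes).
# Use the mean *squared* residual norm; squaring the mean of the norms would underestimate it.
residuals = therm_data_zero_centered - np.outer(pc_axes, pc_projection)
sensor_noise_std = np.sqrt(np.mean(np.sum(residuals ** 2, axis=0)))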
sensor_noise_var = sensor_noise_std ** 2
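The comment above notes that the PC axis could have been learned with PCA instead of being fixed from the problem definition. A minimal sketch of that, assuming freshly regenerated data with the same mean and covariance (NumPy's `default_rng` with an arbitrary seed, so the exact numbers differ from the notebook's): the eigenvector of the sample covariance with the largest eigenvalue should come out close to [1, 1]/√2. For the true covariance [[10, 5], [5, 10]] the eigenvalues are 5 and 15.

```python
import numpy as np

# Regenerate thermometer-like data with the same mean and covariance (arbitrary seed)
rng = np.random.default_rng(2021)
data = rng.multivariate_normal([25, 25], [[10, 5], [5, 10]], size=2000).T  # shape (2, N)

# PCA: eigen-decompose the sample covariance of the zero-centered data
centered = data - data.mean(axis=1, keepdims=True)
sample_cov = centered @ centered.T / (centered.shape[1] - 1)
eigvals, eigvecs = np.linalg.eigh(sample_cov)  # eigenvalues in ascending order

# Eigenvector with the largest eigenvalue; flip its arbitrary sign into the first quadrant
learned_pc = eigvecs[:, -1] * np.sign(eigvecs[0, -1])

print(learned_pc)  # close to [0.707, 0.707]
print(eigvals)     # close to [5, 15]
```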
def gen_from_pPCA(noise_var, data_mean, pc_axes, pc_variance):
    """
    Generate samples from pPCA

    Args:
        noise_var: float
            Sensor noise variance
        data_mean: np.ndarray
            Thermometer data mean, shape (2,)
        pc_axes: np.ndarray
            Principal component axis (unit vector), shape (2,)
        pc_variance: float
            The variance of the projection on the PC axes

    Returns:
        therm_data_sim: np.ndarray
            `n_samples` samples drawn from the pPCA model, shape (2, n_samples)
    """
    # Number of samples to simulate (the thermometer data above has 2000 readings)
    n_samples = 1000

    # Randomly sample from z (latent space value)
    z = np.random.normal(0.0, np.sqrt(pc_variance), n_samples)

    # Sensor noise covariance matrix (Σ = noise_var * I)
    epsilon_cov = [[noise_var, 0.0], [0.0, noise_var]]

    # Data mean reshaped for the generation
    sim_mean = np.outer(data_mean, np.ones(n_samples))

    # Draw `n_samples` from `np.random.multivariate_normal`
    rand_eps = np.random.multivariate_normal([0.0, 0.0], epsilon_cov, n_samples)
    rand_eps = rand_eps.T

    # Generate (simulate, draw) `n_samples` from pPCA model
    therm_data_sim = sim_mean + np.outer(pc_axes, z) + rand_eps

    return therm_data_sim


## Simulate data from the pPCA model and compare it with the thermometer data
therm_data_sim = gen_from_pPCA(sensor_noise_var, therm_data_mean, pc_axes, pc_axes_variance)
plot_gen_samples_ppca(therm1, therm2, therm_data_sim)
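Under the pPCA model sampled by `gen_from_pPCA`, a sample is x = μ + w z + ε with z ~ N(0, pc_variance) and ε ~ N(0, noise_var·I), so the implied covariance is pc_variance·w wᵀ + noise_var·I. The sketch below (a hypothetical helper with illustrative values, not fitted to the thermometer data) computes that matrix and checks it against a Monte-Carlo simulation of the same generative process.

```python
import numpy as np


def ppca_covariance(pc_axis, pc_variance, noise_var):
    """Covariance implied by x = mu + w*z + eps, z ~ N(0, pc_variance), eps ~ N(0, noise_var*I)."""
    w = np.asarray(pc_axis, dtype=float)
    return pc_variance * np.outer(w, w) + noise_var * np.eye(w.size)


# Illustrative values only
w = np.array([1.0, 1.0]) / np.sqrt(2.0)
pc_var, noise_var = 14.0, 1.0
analytic = ppca_covariance(w, pc_var, noise_var)  # [[8, 7], [7, 8]] up to rounding

# Monte-Carlo check: simulate the (zero-mean) generative process directly
rng = np.random.default_rng(0)
n = 200_000
z = rng.normal(0.0, np.sqrt(pc_var), n)
eps = rng.normal(0.0, np.sqrt(noise_var), size=(2, n))
samples = np.outer(w, z) + eps  # shape (2, n)
empirical = np.cov(samples)

print(np.round(analytic, 3))
print(np.round(empirical, 2))
```

One consequence worth noting: the variance of simulated data along the PC axis is wᵀCw = pc_variance + noise_var, slightly more than the projected variance of the data the model was matched to.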


Autoencoders

# @markdown Download MNIST and CIFAR10 datasets
import os
import tarfile

import requests

fname = 'MNIST.tar.gz'
name = 'mnist'
url = 'https://osf.io/y2fj6/download'

if not os.path.exists(name):
    print('\nDownloading MNIST dataset...')
    r = requests.get(url, allow_redirects=True)
    r.raise_for_status()  # fail loudly instead of saving an HTML error page
    with open(fname, 'wb') as fh:
        fh.write(r.content)
    print('\nDownloading MNIST completed!\n')

if not os.path.exists(name):
    with tarfile.open(fname) as tar:
        # filter='data' rejects unsafe archive members (Python 3.12+ and recent patch releases)
        tar.extractall(name, filter='data')
    os.remove(fname)  # delete the archive only after the tar file has been closed
else:
    print('MNIST dataset has been downloaded.\n')

fname = 'cifar-10-python.tar.gz'
name = 'cifar10'
url = 'https://osf.io/jbpme/download'

if not os.path.exists(name):
    print('\nDownloading CIFAR10 dataset...')
    r = requests.get(url, allow_redirects=True)
    r.raise_for_status()
    with open(fname, 'wb') as fh:
        fh.write(r.content)
    print('\nDownloading CIFAR10 completed!')

if not os.path.exists(name):
    with tarfile.open(fname) as tar:
        tar.extractall(name, filter='data')
    os.remove(fname)
else:
    print('CIFAR10 dataset has been downloaded.')
MNIST dataset has been downloaded.

CIFAR10 dataset has been downloaded.
# @markdown Load MNIST and CIFAR10 image datasets
# See https://pytorch.org/docs/stable/torchvision/datasets.html

# MNIST
mnist = datasets.MNIST('./mnist/',
                       train=True,
                       transform=transforms.ToTensor(),
                       download=False)
mnist_val = datasets.MNIST('./mnist/',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)

# CIFAR 10
cifar10 = datasets.CIFAR10('./cifar10/',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=False)
cifar10_val = datasets.CIFAR10('./cifar10/',
                               train=False,
                               transform=transforms.ToTensor(),
                               download=False)
# @markdown Execute this cell to enable helper function `get_data`

def get_data(name='mnist'):
    """
    Get data

    Args:
        name: string
            Name of the dataset

    Returns:
        my_dataset: dataset instance
            Instance of dataset
        my_dataset_name: string
            Name of the dataset
        my_dataset_shape: tuple
            Shape of dataset
        my_dataset_size: int
            Size of dataset
        my_valset: torch.utils.data.Dataset
            Validation dataset
    """
    if name == 'mnist':
        my_dataset_name = "MNIST"
        my_dataset = mnist
        my_valset = mnist_val
        my_dataset_shape = (1, 28, 28)
        my_dataset_size = 28 * 28
    elif name == 'cifar10':
        my_dataset_name = "CIFAR10"
        my_dataset = cifar10
        my_valset = cifar10_val
        my_dataset_shape = (3, 32, 32)
        my_dataset_size = 3 * 32 * 32
    else:
        raise ValueError(f"Unknown dataset {name!r}; expected 'mnist' or 'cifar10'")

    return my_dataset, my_dataset_name, my_dataset_shape, my_dataset_size, my_valset
# @markdown #### Run to define the `train_autoencoder` function.
# @markdown Feel free to inspect the training function if the time allows.

# @markdown `train_autoencoder(autoencoder, dataset, device, epochs=20, batch_size=250, seed=0)`

def train_autoencoder(autoencoder, dataset, device, epochs=20, batch_size=250, seed=0):
    """
    Function to train autoencoder

    Args:
        autoencoder: nn.Module
            Autoencoder instance
        dataset: torch.utils.data.Dataset
            Training dataset
        device: string
            GPU if available, CPU otherwise
        epochs: int
            Number of epochs [default: 20]
        batch_size: int
            Batch size [default: 250]
        seed: int
            Seed for reproducible shuffling [default: 0]

    Returns:
        mse_loss: torch.tensor
            MSE loss of every training batch, on the CPU
    """
    autoencoder.to(device)
    optim = torch.optim.Adam(autoencoder.parameters(),
                             lr=1e-3,
                             weight_decay=1e-5)
    loss_fn = nn.MSELoss()
    g_seed = torch.Generator()
    g_seed.manual_seed(seed)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=2,
                        worker_init_fn=seed_worker,
                        generator=g_seed)

    mse_loss = torch.zeros(epochs * len(loader), device=device)
    i = 0
    for epoch in trange(epochs, desc='Epoch'):
        for im_batch, _ in loader:
            im_batch = im_batch.to(device)
            optim.zero_grad()
            reconstruction = autoencoder(im_batch)
            # Loss calculation; use the actual batch size so that a smaller
            # final batch does not break the reshape
            n = im_batch.size(0)
            loss = loss_fn(reconstruction.view(n, -1),
                           target=im_batch.view(n, -1))
            loss.backward()
            optim.step()

            mse_loss[i] = loss.detach()
            i += 1
    # After training completes,
    # make sure the model is on CPU so we can easily
    # do more visualizations and demos.
    autoencoder.to('cpu')
    return mse_loss.cpu()
class LinearAutoEncoder(nn.Module):
    def __init__(self, x_dim, h_dim):
        super().__init__()
        self.enc_lin = nn.Linear(x_dim, h_dim)
        self.dec_lin = nn.Linear(h_dim, x_dim)

    def encode(self, x):
        return self.enc_lin(x)

    def decode(self, h):
        return self.dec_lin(h)

    def forward(self, x):
        flat_x = x.view(x.size()[0], -1)
        h = self.encode(flat_x)
        return self.decode(h).view(x.size())
dataset_name = 'mnist'  # This can be mnist or cifar10
train_set, dataset_name, data_shape, data_size, valid_set = get_data(name=dataset_name)

# Pick your own K
K = 20
set_seed(seed=SEED)
## Train the linear autoencoder and plot its training loss
lin_ae = LinearAutoEncoder(data_size, K)
lin_losses = train_autoencoder(lin_ae, train_set, device=DEVICE, seed=SEED)
plot_linear_ae(lin_losses)
Random seed 2021 has been set.




# PCA requires finding the top K eigenvectors of the data covariance. Start by
# finding the mean and covariance of the pixels in our dataset
g_seed = torch.Generator()
g_seed.manual_seed(SEED)

loader = DataLoader(train_set,
                    batch_size=32,
                    pin_memory=True,
                    num_workers=2,
                    worker_init_fn=seed_worker,
                    generator=g_seed)
mu, cov = image_moments((im for im, _ in loader), n_batches=len(train_set) // 32)
pca_encode, pca_decode = pca_encoder_decoder(mu, cov, K)
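`pca_encoder_decoder` itself is not shown in this section. A hypothetical re-implementation, assuming `mu` is the flattened pixel mean of shape (D,) and `cov` the D×D pixel covariance: encode by projecting the centered image onto the top-K eigenvectors, decode by mapping the K coordinates back and adding the mean. At its optimum, a linear autoencoder trained with MSE is expected to span the same subspace, although its learned basis need not be orthonormal.

```python
import torch


def pca_encoder_decoder_sketch(mu, cov, k):
    """Hypothetical PCA encoder/decoder from a pixel mean `mu` (D,) and covariance `cov` (D, D)."""
    _, eigvecs = torch.linalg.eigh(cov)  # eigenvalues ascending, orthonormal eigenvector columns
    w = eigvecs[:, -k:].flip(-1)         # (D, k): top-k directions, largest variance first

    def encode(x):
        # Center the flattened image and project it onto the top-k directions
        return (x.reshape(-1) - mu) @ w  # (k,)

    def decode(h):
        # Map the k coordinates back to pixel space and restore the mean
        return h @ w.T + mu              # (D,); reshape to the image size afterwards

    return encode, decode


# Toy usage on random correlated data with D = 6
torch.manual_seed(0)
X = torch.randn(200, 6, dtype=torch.float64) @ torch.randn(6, 6, dtype=torch.float64)
enc, dec = pca_encoder_decoder_sketch(X.mean(dim=0), torch.cov(X.T), 2)
print(enc(X[0]).shape, dec(enc(X[0])).shape)  # torch.Size([2]) torch.Size([6])
```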
# @markdown Visualize the reconstructions $\mathbf{x}'$, run this code a few times to see different examples.

n_plot = 7
plt.figure(figsize=(10, 4.5))
for i in range(n_plot):
    idx = torch.randint(len(train_set), size=())
    image, _ = train_set[idx]
    # Get reconstructed image from autoencoder
    with torch.no_grad():
        reconstruction = lin_ae(image.unsqueeze(0)).reshape(image.size())

    # Get reconstruction from PCA dimensionality reduction
    h_pca = pca_encode(image)
    recon_pca = pca_decode(h_pca).reshape(image.size())

    plt.subplot(3, n_plot, i + 1)
    plot_torch_image(image)
    if i == 0:
        plt.ylabel('Original\nImage')

    plt.subplot(3, n_plot, i + 1 + n_plot)
    plot_torch_image(reconstruction)
    if i == 0:
        plt.ylabel(f'Lin AE\n(K={K})')

    plt.subplot(3, n_plot, i + 1 + 2 * n_plot)
    plot_torch_image(recon_pca)
    if i == 0:
        plt.ylabel(f'PCA\n(K={K})')
plt.show()


Non-linear convolutional autoencoder

class BiasLayer(nn.Module):
    def __init__(self, shape):
        super(BiasLayer, self).__init__()
        init_bias = torch.zeros(shape)
        self.bias = nn.Parameter(init_bias, requires_grad=True)

    def forward(self, x):
        return x + self.bias
dummy_image = torch.rand(data_shape).unsqueeze(0)
in_channels = data_shape[0]
out_channels = 7

dummy_conv = nn.Conv2d(in_channels=in_channels,
                       out_channels=out_channels,
                       kernel_size=5)

dummy_deconv = nn.ConvTranspose2d(in_channels=out_channels,
                                  out_channels=in_channels,
                                  kernel_size=5)

print(f'Size of image is {dummy_image.shape}')
print(f'Size of Conv2D(image) {dummy_conv(dummy_image).shape}')
print(f'Size of ConvTranspose2D(Conv2D(image)) {dummy_deconv(dummy_conv(dummy_image)).shape}')
Size of image is torch.Size([1, 1, 28, 28])
Size of Conv2D(image) torch.Size([1, 7, 24, 24])
Size of ConvTranspose2D(Conv2D(image)) torch.Size([1, 1, 28, 28])
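The shapes printed above follow PyTorch's documented output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d`. A small sketch with hypothetical helper names (`conv2d_out`, `conv_transpose2d_out`; the notebook's own `cout` helper is not shown here) reproduces them, and also the halving and doubling produced by the strided 4×4 layers used in the VQ-VAE section.

```python
import torch
import torch.nn as nn


def conv2d_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv_transpose2d_out(size, kernel, stride=1, padding=0, dilation=1, output_padding=0):
    """Output length of nn.ConvTranspose2d along one spatial dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# Unpadded 5x5 kernels, as in the dummy layers above: 28 -> 24 -> 28
h = conv2d_out(28, 5)
print(h, conv_transpose2d_out(h, 5))  # 24 28

# Strided 4x4 kernels with padding 1 halve / double the size: 28 -> 14 -> 28
x = torch.zeros(1, 1, 28, 28)
down = nn.Conv2d(1, 1, 4, stride=2, padding=1)(x)
up = nn.ConvTranspose2d(1, 1, 4, stride=2, padding=1)(down)
print(tuple(down.shape), tuple(up.shape))  # (1, 1, 14, 14) (1, 1, 28, 28)
```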
class ConvAutoEncoder(nn.Module):
    def __init__(self, x_dim, h_dim, n_filters=32, filter_size=5):
        super().__init__()
        channels = x_dim[0]  # x_dim is (channels, height, width)

        # encoder
        self.enc_bias = BiasLayer(x_dim)

        self.enc_conv_1 = nn.Conv2d(channels, n_filters, filter_size)
        conv1_shape = cout(x_dim, self.enc_conv_1)

        self.enc_conv_2 = nn.Conv2d(n_filters, n_filters, filter_size)
        conv2_shape = cout(conv1_shape, self.enc_conv_2)

        self.enc_flatten = nn.Flatten()
        flat_after_conv = conv2_shape[0] * conv2_shape[1] * conv2_shape[2]

        self.enc_lin = nn.Linear(flat_after_conv, h_dim)


        # decoder
        self.dec_lin = nn.Linear(h_dim, flat_after_conv)

        self.dec_unflatten = nn.Unflatten(dim=-1, unflattened_size=conv2_shape)

        self.dec_deconv_1 = nn.ConvTranspose2d(n_filters, n_filters, filter_size)
        self.dec_deconv_2 = nn.ConvTranspose2d(n_filters, channels, filter_size)

        self.dec_bias = BiasLayer(x_dim)

    def encode(self, x):
        s = self.enc_bias(x)
        s = F.relu(self.enc_conv_1(s))
        s = F.relu(self.enc_conv_2(s))
        s = self.enc_flatten(s)
        h = self.enc_lin(s)
        return h

    def decode(self, h):
        s = F.relu(self.dec_lin(h))
        s = self.dec_unflatten(s)
        s = F.relu(self.dec_deconv_1(s))
        s = self.dec_deconv_2(s)
        x_prime = self.dec_bias(s)
        return x_prime

    def forward(self, x):
        h = self.encode(x)
        x_prime = self.decode(h)
        return x_prime
set_seed(seed=SEED)
## Train the convolutional autoencoder and compare its loss with the linear one
trained_conv_AE = ConvAutoEncoder(data_shape, K)
assert trained_conv_AE.encode(train_set[0][0].unsqueeze(0)).numel() == K, "Encoder output size should be K!"
conv_losses = train_autoencoder(trained_conv_AE, train_set, device=DEVICE, seed=SEED)
with plt.xkcd():
    plot_conv_ae(lin_losses, conv_losses)
Random seed 2021 has been set.




# @markdown Visualize the linear and nonlinear AE outputs
if lin_ae.enc_lin.out_features != trained_conv_AE.enc_lin.out_features:
    raise ValueError('ERROR: your linear and convolutional autoencoders have different values of K')

n_plot = 7
plt.figure(figsize=(10, 4.5))
for i in range(n_plot):
    idx = torch.randint(len(train_set), size=())
    image, _ = train_set[idx]
    with torch.no_grad():
        # Get reconstructed image from linear autoencoder
        lin_recon = lin_ae(image.unsqueeze(0))[0]

        # Get reconstruction from deep (nonlinear) autoencoder
        nonlin_recon = trained_conv_AE(image.unsqueeze(0))[0]

    plt.subplot(3, n_plot, i + 1)
    plot_torch_image(image)
    if i == 0:
        plt.ylabel('Original\nImage')

    plt.subplot(3, n_plot, i + 1 + n_plot)
    plot_torch_image(lin_recon)
    if i == 0:
        plt.ylabel(f'Lin AE\n(K={K})')

    plt.subplot(3, n_plot, i + 1 + 2 * n_plot)
    plot_torch_image(nonlin_recon)
    if i == 0:
        plt.ylabel(f'NonLin AE\n(K={K})')
plt.show()


# @markdown Train a VAE for MNIST while watching the video.
# (Note: this VAE has a 2D latent space. If you are feeling ambitious,
# edit the code and modify the latent space dimensionality and see what happens.)

K_VAE = 2


class ConvVAE(nn.Module):
    """
    Convolutional Variational Autoencoder
    """

    def __init__(self, K, num_filters=32, filter_size=5):
        """
        Initialize parameters of ConvVAE

        Args:
            K: int
                Bottleneck dimensionality
            num_filters: int
                Number of filters [default: 32]
            filter_size: int
                Filter size [default: 5]

        Returns:
            Nothing
        """
        super(ConvVAE, self).__init__()

        # With padding=0, calculate pixel reduction per Conv2D/ConvTranspose2D layer
        filter_reduction = 2 * (filter_size // 2)

        # Calculate shape after two Conv2D layers
        self.shape_after_conv = (
            num_filters,
            data_shape[1] - 2 * filter_reduction,
            data_shape[2] - 2 * filter_reduction
        )
        flat_size_after_conv = (
            self.shape_after_conv[0]
            * self.shape_after_conv[1]
            * self.shape_after_conv[2]
        )

        # Define the encoder (recognition model)
        self.q_bias = BiasLayer(data_shape)
        self.q_conv_1 = nn.Conv2d(data_shape[0], num_filters, filter_size)
        self.q_conv_2 = nn.Conv2d(num_filters, num_filters, filter_size)
        self.q_flatten = nn.Flatten()
        # K for mu, 1 for log_sig
        self.q_fc_phi = nn.Linear(flat_size_after_conv, K + 1)

        # Define the decoder (generative model)
        self.p_fc_upsample = nn.Linear(K, flat_size_after_conv)
        self.p_unflatten = nn.Unflatten(-1, self.shape_after_conv)
        self.p_deconv_1 = nn.ConvTranspose2d(num_filters, num_filters, filter_size)
        self.p_deconv_2 = nn.ConvTranspose2d(num_filters, data_shape[0], filter_size)
        self.p_bias = BiasLayer(data_shape)

        # Parameter for learning scalar sigma_x for all pixels
        self.log_sig_x = nn.Parameter(torch.zeros(()))

    def infer(self, x):
        """
        Map x to phi for latent space sampling

        Args:
            x: torch.tensor
                Input features

        Returns:
            phi: torch.tensor
                Latent space parameters
        """
        s = self.q_bias(x)
        s = F.relu(self.q_conv_1(s))
        s = F.relu(self.q_conv_2(s))
        flat_s = s.view(s.size()[0], -1)
        phi = self.q_fc_phi(flat_s)
        return phi

    def generate(self, zs):
        """
        Generate images from latent samples

        Args:
            zs: torch.tensor
                Latent samples

        Returns:
            mu_xs: torch.tensor
                Generated images
        """
        b, n, k = zs.size()
        s = zs.view(b * n, -1)
        s = F.relu(self.p_fc_upsample(s)).view((b * n,) + self.shape_after_conv)
        s = F.relu(self.p_deconv_1(s))
        s = self.p_deconv_2(s)
        s = self.p_bias(s)
        mu_xs = s.view(b, n, -1)
        return mu_xs

    def decode(self, zs):
        """
        Decoder for compatibility with conv-AE code

        Args:
            zs: torch.tensor
                Latent samples

        Returns:
            torch.tensor
                Decoded images
        """
        return self.generate(zs.unsqueeze(0))

    def forward(self, x):
        """
        Forward pass

        Args:
            x: torch.tensor
                Input image

        Returns:
            torch.tensor
                Reconstructed image
        """
        phi = self.infer(x)
        zs = rsample(phi, 1)
        return self.generate(zs).view(x.size())

    def elbo(self, x, n=1):
        """
        Compute the Evidence Lower Bound (ELBO)

        Args:
            x: torch.tensor
                Input image
            n: int
                Number of latent samples

        Returns:
            torch.tensor
                ELBO value
        """
        phi = self.infer(x)
        zs = rsample(phi, n)
        mu_xs = self.generate(zs)
        return log_p_x(x, mu_xs, self.log_sig_x.exp()) - kl_q_p(zs, phi)


def rsample(phi, n_samples):
    """
    Reparameterization trick for latent sampling

    Args:
        phi: torch.tensor
            Latent space parameters
        n_samples: int
            Number of samples

    Returns:
        torch.tensor
            Latent samples
    """
    b, kplus1 = phi.size()
    k = kplus1 - 1
    mu, sig = phi[:, :-1], phi[:, -1].exp()
    eps = torch.randn(b, n_samples, k, device=phi.device)
    return eps * sig.view(b, 1, 1) + mu.view(b, 1, k)
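`ConvVAE.elbo` subtracts `kl_q_p(zs, phi)`, a helper not shown in this section; since it receives the sampled `zs`, it may be a Monte-Carlo estimate. For this particular posterior (mean μ ∈ ℝ^K and one shared log σ, laid out in `phi` as in `rsample`) against a standard normal prior, the KL divergence also has a closed form, KL = ½(‖μ‖² + Kσ² − K − 2K log σ). A minimal sketch under that assumption, using the hypothetical name `kl_isotropic_gaussian`:

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_isotropic_gaussian(phi):
    """Closed-form KL( N(mu, sigma^2 I_K) || N(0, I_K) ) for each row of `phi`.

    Assumes the ConvVAE layout: phi[:, :K] is mu and phi[:, K] is log(sigma),
    one sigma shared by all K latent dimensions.
    """
    mu, log_sig = phi[:, :-1], phi[:, -1]
    k = mu.size(1)
    return 0.5 * (mu.pow(2).sum(dim=1) + k * torch.exp(2 * log_sig) - k - 2 * k * log_sig)


# Cross-check against torch.distributions on random parameters (batch of 4, K = 2)
torch.manual_seed(0)
phi = torch.randn(4, 3)
mu, sig = phi[:, :-1], phi[:, -1].exp()
q = Normal(mu, sig.unsqueeze(1).expand_as(mu))
prior = Normal(torch.zeros_like(mu), torch.ones_like(mu))
reference = kl_divergence(q, prior).sum(dim=1)
print(kl_isotropic_gaussian(phi), reference)  # the two should agree
```

Using the closed form removes sampling noise from the KL term; the reconstruction term still needs the samples drawn by `rsample`.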


def train_vae(vae, dataset, epochs=10, n_samples=1000):
    """
    Train a Variational Autoencoder (VAE)

    Args:
        vae: nn.Module
            VAE model
        dataset: torch.utils.data.Dataset
            Dataset
        epochs: int
            Number of training epochs [default: 10]
        n_samples: int
            Currently unused: each ELBO estimate draws one latent sample per image

    Returns:
        elbo_vals: list
            ELBO values during training
    """
    opt = torch.optim.Adam(vae.parameters(), lr=1e-3, weight_decay=0)
    elbo_vals = []
    vae.to(DEVICE)
    vae.train()
    loader = DataLoader(dataset, batch_size=250, shuffle=True, pin_memory=True)

    for epoch in trange(epochs, desc='Epochs'):
        for im, _ in tqdm(loader, total=len(dataset) // 250, desc='Batches', leave=False):
            im = im.to(DEVICE)
            opt.zero_grad()
            loss = -vae.elbo(im)
            loss.backward()
            opt.step()
            elbo_vals.append(-loss.item())

    vae.to('cpu')
    vae.eval()
    return elbo_vals


trained_conv_VarAE = ConvVAE(K=K_VAE)
elbo_vals = train_vae(trained_conv_VarAE, train_set, n_samples=10000)

print(f'Learned sigma_x is {torch.exp(trained_conv_VarAE.log_sig_x).item()}')

# Plot the training curve of the ELBO
plt.figure()
plt.plot(elbo_vals)
plt.xlabel('Batch #')
plt.ylabel('ELBO')
plt.show()

Learned sigma_x is 0.21018077433109283


## Generate images by decoding random latent vectors
def generate_images(autoencoder, K, n_images=1):
    """Decode `n_images` latent vectors z ~ N(0, I_K) with `autoencoder.decode`."""
    output_shape = (n_images, ) + data_shape
    with torch.no_grad():
        # Sample z from a unit Gaussian, pass through autoencoder.decode()
        z = torch.randn(n_images, K)
        x = autoencoder.decode(z)

    return x.reshape(output_shape)

set_seed(seed=SEED)
## Compare samples generated by the conv-AE and the conv-VAE
images = generate_images(trained_conv_AE, K, n_images=9)
plot_images(images, plt_title='Images Generated from the Conv-AE')
images = generate_images(trained_conv_VarAE, K_VAE, n_images=9)
plot_images(images, plt_title='Images Generated from a Conv-Variational-AE')
Random seed 2021 has been set.


VQ-VAE

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super(VectorQuantizer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Initialize embedding vectors
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, inputs):
        # Flatten input to (batch_size * H * W, embedding_dim)
        input_shape = inputs.shape
        flat_inputs = inputs.reshape(-1, self.embedding_dim)

        # Compute distances and find nearest embedding
        distances = (flat_inputs.unsqueeze(1) - self.embeddings.weight).pow(2).sum(2)
        indices = distances.argmin(1)
        quantized = self.embeddings(indices).reshape(input_shape)

        # Losses
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = inputs + (quantized - inputs).detach()  # Straight-through estimator
        return quantized, loss, indices
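`VectorQuantizer.forward` finds the nearest codes by materializing an (N, K, D) difference tensor. Below is a sketch of an equivalent lookup with `torch.cdist`, which returns the (N, K) distance matrix directly, together with a small check of the straight-through estimator: the forward value equals the chosen code, while the gradient reaches the encoder output unchanged. The sizes and the helper name `nearest_codes` are illustrative assumptions; on large inputs `cdist` may switch to a matrix-multiplication path whose rounding could in principle flip near-ties.

```python
import torch


def nearest_codes(flat_inputs, codebook):
    """Index of the nearest codebook row for each input row (Euclidean distance)."""
    return torch.cdist(flat_inputs, codebook).argmin(dim=1)


torch.manual_seed(0)
inputs = torch.randn(10, 4, requires_grad=True)  # 10 encoder vectors of dimension 4
codebook = torch.randn(8, 4)                      # 8 hypothetical codebook entries

idx = nearest_codes(inputs, codebook)
# Same selection as the broadcasting version in VectorQuantizer.forward
ref = (inputs.unsqueeze(1) - codebook).pow(2).sum(2).argmin(1)

# Straight-through estimator: forward value is the code, gradient is the identity
quantized = codebook[idx]
st = inputs + (quantized - inputs).detach()
st.sum().backward()
print(inputs.grad.unique())  # tensor([1.])
```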
class VQVAE(nn.Module):
    def __init__(self, input_dim, embedding_dim, num_embeddings):
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, embedding_dim, 1)
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, input_dim, 1),
            nn.Tanh()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        quantized, vq_loss, _ = self.quantizer(encoded.permute(0, 2, 3, 1))
        quantized = quantized.permute(0, 3, 1, 2)
        decoded = self.decoder(quantized)
        return decoded, vq_loss
# Load MNIST, normalized to [-1, 1] to match the decoder's Tanh output range
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Visualize a sample image
sample_image, _ = next(iter(train_loader))
plt.imshow(sample_image[0][0], cmap="gray")
plt.title("Sample Image from MNIST")
plt.show()


def train_vqvae(vqvae, dataloader, optimizer, epochs=10):
    vqvae.train()
    for epoch in range(epochs):
        epoch_loss = 0
        for x, _ in dataloader:
            x = x.to(device)
            optimizer.zero_grad()

            reconstruction, vq_loss = vqvae(x)
            recon_loss = F.mse_loss(reconstruction, x)
            loss = recon_loss + vq_loss
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / len(dataloader)}")
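The total loss minimized in `train_vqvae` combines the reconstruction error with the two terms computed inside `VectorQuantizer`. Written out, with sg[·] denoting stop-gradient (`.detach()` in the code), z_e(x) the encoder output, e the selected codebook vector and x̂ the reconstruction:

$$
\mathcal{L} = \lVert x - \hat{x} \rVert_2^2 + \lVert \mathrm{sg}[z_e(x)] - e \rVert_2^2 + \beta \, \lVert z_e(x) - \mathrm{sg}[e] \rVert_2^2
$$

In the code each squared norm is a mean over elements (`F.mse_loss`): the first term is `recon_loss`, the second is `q_latent_loss` (it moves the codebook toward the encoder outputs), the third is `e_latent_loss` (it commits the encoder to its chosen codes), and β is `commitment_cost = 0.25`.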
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model parameters
input_dim = 1  # MNIST has 1 channel
embedding_dim = 16  # Dimension of each codebook (embedding) vector
num_embeddings = 256  # Number of codebook entries

# Initialize model and optimizer
vqvae = VQVAE(input_dim, embedding_dim, num_embeddings).to(device)
optimizer = torch.optim.Adam(vqvae.parameters(), lr=1e-3)

# Train the model
train_vqvae(vqvae, train_loader, optimizer, epochs=20)
Epoch [1/20], Loss: 1.2864014311179297
Epoch [2/20], Loss: 0.09627232292313566
Epoch [3/20], Loss: 0.07010942528338067
Epoch [4/20], Loss: 0.06327186301271163
Epoch [5/20], Loss: 0.060750904606222345
Epoch [6/20], Loss: 0.06032293511113759
Epoch [7/20], Loss: 0.059614512918473306
Epoch [8/20], Loss: 0.05700735386405418
Epoch [9/20], Loss: 0.05542420087131991
Epoch [10/20], Loss: 0.05478271394014867
Epoch [11/20], Loss: 0.05404220243443304
Epoch [12/20], Loss: 0.05355198081654272
Epoch [13/20], Loss: 0.053218447188260964
Epoch [14/20], Loss: 0.052963705963766906
Epoch [15/20], Loss: 0.05270474087566074
Epoch [16/20], Loss: 0.05204789565681522
Epoch [17/20], Loss: 0.051554257947721206
Epoch [18/20], Loss: 0.05113089540953448
Epoch [19/20], Loss: 0.05075763985276349
Epoch [20/20], Loss: 0.05034408119044451
def visualize_reconstruction(vqvae, dataloader):
    vqvae.eval()
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            reconstruction, _ = vqvae(x)
            plt.figure(figsize=(8, 4))
            
            # Original
            plt.subplot(1, 2, 1)
            plt.title("Original")
            plt.imshow(x[0][0].cpu().numpy(), cmap="gray")
            
            # Reconstructed
            plt.subplot(1, 2, 2)
            plt.title("Reconstructed")
            plt.imshow(reconstruction[0][0].cpu().numpy(), cmap="gray")
            
            plt.show()
            break

visualize_reconstruction(vqvae, train_loader)


def generate_images(vqvae, num_images=5):
    vqvae.eval()
    with torch.no_grad():
        # Sample one code index per 7x7 latent position, uniformly at random (no learned prior)
        indices = torch.randint(0, vqvae.quantizer.num_embeddings, (num_images, 7, 7)).to(device)
        quantized = vqvae.quantizer.embeddings(indices).permute(0, 3, 1, 2)  # Shape: (N, C, H, W)
        generated_images = vqvae.decoder(quantized).cpu()

        plt.figure(figsize=(10, 5))
        for i in range(num_images):
            plt.subplot(1, num_images, i + 1)
            plt.imshow(generated_images[i][0], cmap="gray")
            plt.axis("off")
        plt.show()

generate_images(vqvae, num_images=5)

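`generate_images` draws every index uniformly from the whole codebook, which ignores both which codes the encoder actually uses and how they co-occur spatially, so the samples are not expected to resemble digits. The original VQ-VAE instead fits an autoregressive prior (PixelCNN) over the index grid. A much cruder middle ground is sketched below in NumPy: sample each position independently from the code frequencies observed at that position. It assumes you have already collected a `(N, 7, 7)` integer array of code indices from the trained encoder; `fit_positional_prior` and `sample_positional_prior` are hypothetical helpers, and random indices stand in for real ones.

```python
import numpy as np

def fit_positional_prior(indices, num_embeddings, alpha=1e-3):
    """Per-position categorical distribution over code indices.

    indices: int array of shape (N, H, W) collected from the trained encoder.
    alpha:   small pseudo-count so unseen codes keep a non-zero probability.
    Returns probs of shape (H, W, K); each (h, w) row sums to 1.
    """
    _, h, w = indices.shape
    counts = np.zeros((h, w, num_embeddings))
    for i in range(h):
        for j in range(w):
            counts[i, j] = np.bincount(indices[:, i, j], minlength=num_embeddings)
    counts += alpha
    return counts / counts.sum(axis=-1, keepdims=True)

def sample_positional_prior(probs, num_samples, rng):
    """Draw (num_samples, H, W) index grids; positions are sampled independently."""
    h, w, k = probs.shape
    out = np.empty((num_samples, h, w), dtype=np.int64)
    for i in range(h):
        for j in range(w):
            out[:, i, j] = rng.choice(k, size=num_samples, p=probs[i, j])
    return out

rng = np.random.default_rng(0)
fake_indices = rng.integers(0, 256, size=(1000, 7, 7))  # stand-in for real encoder indices
probs = fit_positional_prior(fake_indices, num_embeddings=256)
grid = sample_positional_prior(probs, num_samples=5, rng=rng)
print(grid.shape)  # prints (5, 7, 7)
```

Such a grid could replace the `torch.randint` call via `torch.from_numpy(grid).to(device)`. Because it still ignores dependencies between neighbouring positions, any improvement over uniform sampling is likely to be modest.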

## On CIFAR-10
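The CIFAR-10 encoders below stack two `Conv2d` layers with `kernel_size=4, stride=2, padding=1` and a final `kernel_size=3, stride=1, padding=1` layer, and the decoders mirror them with `ConvTranspose2d`. With dilation 1, PyTorch's documented output sizes reduce to floor((H + 2p - k) / s) + 1 for `Conv2d` and (H - 1)·s - 2p + k + output_padding for `ConvTranspose2d`. A pure-Python sketch (hypothetical helper names) that traces the spatial size through both stacks:

```python
def conv2d_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d along one spatial axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride, padding, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial axis (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# (kernel, stride, padding) for each layer, as in the CIFAR-10 models in this post
ENCODER = [(4, 2, 1), (4, 2, 1), (3, 1, 1)]
DECODER = [(4, 2, 1), (4, 2, 1), (3, 1, 1)]

def trace(size, layers, fn):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for k, s, p in layers:
        sizes.append(fn(sizes[-1], k, s, p))
    return sizes

enc = trace(32, ENCODER, conv2d_out)
dec = trace(enc[-1], DECODER, conv_transpose2d_out)
print(enc)  # prints [32, 16, 8, 8]
print(dec)  # prints [8, 16, 32, 32]
print(trace(28, ENCODER[:2], conv2d_out))  # prints [28, 14, 7]
```

The last line shows that a 28x28 MNIST image passed through the same two stride-2 layers ends at 7x7, which matches the index grid used by `generate_images`, assuming the MNIST encoder (only partly shown above) uses that layout.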

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define the VectorQuantizer
class VectorQuantizer(nn.Module):
    def __init__(self, n_e, e_dim, beta):
        """
        Inputs:
        - n_e: Number of embeddings
        - e_dim: Embedding dimensionality
        - beta: Commitment cost for loss
        """
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.embedding = nn.Embedding(n_e, e_dim)
        self.embedding.weight.data.uniform_(-1 / n_e, 1 / n_e)

    def forward(self, z):
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # Compute distances
        d = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        # Find closest embeddings
        indices = torch.argmin(d, dim=1).unsqueeze(1)
        one_hot = torch.zeros(indices.shape[0], self.n_e, device=z.device)
        one_hot.scatter_(1, indices, 1)
        z_q = torch.matmul(one_hot, self.embedding.weight).view(z.shape)

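        # Codebook and commitment losses.
        # Note: the .detach() calls are swapped relative to the VQ-VAE paper. As written,
        # loss_codebook updates the encoder (weight 1) and loss_commitment updates the
        # codebook (weight beta). The "Improved" version below corrects this.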
        loss_codebook = torch.mean((z_q.detach() - z) ** 2)
        loss_commitment = self.beta * torch.mean((z_q - z.detach()) ** 2)
        loss = loss_codebook + loss_commitment

        # Straight-through estimator
        z_q = z + (z_q - z).detach()

        # Compute perplexity
        encoding_prob = torch.mean(one_hot, dim=0)
        perplexity = torch.exp(-torch.sum(encoding_prob * torch.log(encoding_prob + 1e-10)))

        return z_q.permute(0, 3, 1, 2).contiguous(), loss, perplexity

# Define the VQ-VAE model
class VQVAE(nn.Module):
    def __init__(self, input_dim, embedding_dim, num_embeddings, beta=0.25):
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, embedding_dim, 3, stride=1, padding=1),
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, beta)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, input_dim, 3, stride=1, padding=1),
        )

    def forward(self, x):
        z_e = self.encoder(x)
        z_q, vq_loss, perplexity = self.quantizer(z_e)
        x_recon = self.decoder(z_q)
        return x_recon, vq_loss, perplexity

# Load CIFAR-10 dataset
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
test_data = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True, pin_memory=True)
test_loader = DataLoader(test_data, batch_size=128, shuffle=False, pin_memory=True)

# Training function
def train_vqvae(model, dataloader, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss, total_vq_loss, total_perplexity = 0, 0, 0
        for x, _ in dataloader:
            x = x.to(device)
            optimizer.zero_grad()

            # Forward pass
            x_recon, vq_loss, perplexity = model(x)
            recon_loss = F.mse_loss(x_recon, x)
            loss = recon_loss + vq_loss

            # Backward pass
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_vq_loss += vq_loss.item()
            total_perplexity += perplexity.item()

        print(
            f"Epoch [{epoch+1}/{epochs}], "
            f"Loss: {total_loss / len(dataloader):.4f}, "
            f"VQ Loss: {total_vq_loss / len(dataloader):.4f}, "
            f"Perplexity: {total_perplexity / len(dataloader):.4f}"
        )

# Instantiate and train the model
input_dim = 3
embedding_dim = 128
num_embeddings = 1024
beta = 0.25

vqvae = VQVAE(input_dim, embedding_dim, num_embeddings, beta).to(device)
optimizer = torch.optim.Adam(vqvae.parameters(), lr=1e-3)

train_vqvae(vqvae, train_loader, optimizer, epochs=30)

# Visualization function
def visualize_reconstruction(model, dataloader, num_images=5):
    model.eval()
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            x_recon, _, _ = model(x)
            break

    x = x.cpu().numpy().transpose(0, 2, 3, 1)
    x_recon = x_recon.cpu().numpy().transpose(0, 2, 3, 1)

    plt.figure(figsize=(15, 5))
    for i in range(num_images):
        # Original image
        plt.subplot(2, num_images, i + 1)
        plt.imshow((x[i] * 255).astype(np.uint8))
        plt.axis("off")
        # Reconstructed image
        plt.subplot(2, num_images, i + 1 + num_images)
        # Clip first: this decoder has no output activation, so values can leave [0, 1]
        plt.imshow((np.clip(x_recon[i], 0, 1) * 255).astype(np.uint8))
        plt.axis("off")
    plt.show()

# Visualize reconstructions
visualize_reconstruction(vqvae, test_loader)
Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/30], Loss: 0.0889, VQ Loss: 0.0601, Perplexity: 12.7753
Epoch [2/30], Loss: 0.0124, VQ Loss: 0.0013, Perplexity: 45.3308
Epoch [3/30], Loss: 0.0092, VQ Loss: 0.0008, Perplexity: 81.9231
Epoch [4/30], Loss: 0.0077, VQ Loss: 0.0006, Perplexity: 111.4947
Epoch [5/30], Loss: 0.0069, VQ Loss: 0.0004, Perplexity: 132.6666
Epoch [6/30], Loss: 0.0066, VQ Loss: 0.0002, Perplexity: 147.0971
Epoch [7/30], Loss: 0.0064, VQ Loss: 0.0002, Perplexity: 159.4096
Epoch [8/30], Loss: 0.0064, VQ Loss: 0.0002, Perplexity: 168.2473
Epoch [9/30], Loss: 0.0063, VQ Loss: 0.0002, Perplexity: 174.2945
Epoch [10/30], Loss: 0.0062, VQ Loss: 0.0002, Perplexity: 177.6917
Epoch [11/30], Loss: 0.0062, VQ Loss: 0.0002, Perplexity: 180.2934
Epoch [12/30], Loss: 0.0062, VQ Loss: 0.0002, Perplexity: 180.4589
Epoch [13/30], Loss: 0.0063, VQ Loss: 0.0002, Perplexity: 179.2105
Epoch [14/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 179.5843
Epoch [15/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 179.5041
Epoch [16/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 178.6539
Epoch [17/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 178.6032
Epoch [18/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 178.3136
Epoch [19/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 178.1179
Epoch [20/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 178.3588
Epoch [21/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 177.8884
Epoch [22/30], Loss: 0.0060, VQ Loss: 0.0002, Perplexity: 178.1503
Epoch [23/30], Loss: 0.0061, VQ Loss: 0.0002, Perplexity: 177.8433
Epoch [24/30], Loss: 0.0060, VQ Loss: 0.0002, Perplexity: 177.4329
Epoch [25/30], Loss: 0.0060, VQ Loss: 0.0002, Perplexity: 177.5672
Epoch [26/30], Loss: 0.0060, VQ Loss: 0.0002, Perplexity: 177.9211
Epoch [27/30], Loss: 0.0060, VQ Loss: 0.0003, Perplexity: 177.7941
Epoch [28/30], Loss: 0.0060, VQ Loss: 0.0003, Perplexity: 177.4976
Epoch [29/30], Loss: 0.0060, VQ Loss: 0.0003, Perplexity: 177.8465
Epoch [30/30], Loss: 0.0060, VQ Loss: 0.0003, Perplexity: 177.4218
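The perplexity printed during training is the exponential of the entropy of the average one-hot code assignment within a batch, i.e. an effective number of codes in use: it equals K when K codes are used equally often and 1 when every vector maps to the same code. By that measure this run settles at roughly 177 to 180 of its 1024 codes per batch. A NumPy re-implementation of the same formula for intuition; `codebook_perplexity` is an illustrative helper.

```python
import numpy as np

def codebook_perplexity(indices, num_embeddings):
    """exp(entropy) of the empirical code-usage distribution, as in the quantizers in this post."""
    one_hot = np.eye(num_embeddings)[indices]    # (N, K) one-hot code assignments
    avg_probs = one_hot.mean(axis=0)             # fraction of vectors assigned to each code
    entropy = -np.sum(avg_probs * np.log(avg_probs + 1e-10))
    return float(np.exp(entropy))

# 100 vectors spread evenly over 4 codes -> effective usage of 4 codes
print(round(codebook_perplexity(np.array([0, 1, 2, 3] * 25), 1024), 3))  # prints 4.0
# every vector mapped to code 0 -> effective usage of 1 code
print(round(codebook_perplexity(np.zeros(100, dtype=int), 1024), 3))      # prints 1.0
```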

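The first CIFAR-10 decoder ends in a plain `ConvTranspose2d`, so reconstructions can overshoot [0, 1] slightly, and NumPy does not define the result of casting out-of-range floats to `uint8`, which can show up as speckled pixels. A minimal sketch of a safer conversion that clips before scaling and rounds instead of truncating; `to_uint8_image` is a hypothetical helper. Alternatively, `plt.imshow` accepts float RGB arrays directly and clips them to [0, 1] with a warning, and the improved model below sidesteps the issue with a final `nn.Sigmoid()`.

```python
import numpy as np

def to_uint8_image(img):
    """Convert a float image expected in [0, 1], shape (H, W, C), to uint8 for plt.imshow."""
    img = np.clip(img, 0.0, 1.0)                    # remove slight over/undershoot from the decoder
    return np.round(img * 255).astype(np.uint8)     # round to the nearest level instead of truncating

recon = np.array([[[-0.05, 0.5, 1.03]]])            # hypothetical 1x1 RGB pixel
print(to_uint8_image(recon).tolist())  # prints [[[0, 128, 255]]]
```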

## Improved VQ-VAE on CIFAR-10
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the Vector Quantizer
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, beta):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

    def forward(self, z):
        # Flatten spatial dimensions and compute distances
        z_flattened = z.permute(0, 2, 3, 1).contiguous().view(-1, self.embedding_dim)

        distances = (
            torch.sum(z_flattened**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        # Find closest embeddings
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.size(0), self.num_embeddings, device=z.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantized vectors
        quantized = torch.matmul(encodings, self.embedding.weight).view(z.shape[0], z.shape[2], z.shape[3], self.embedding_dim)
        quantized = quantized.permute(0, 3, 1, 2).contiguous()  # Back to (batch, embedding_dim, height, width)

        # Corrected Codebook Loss: Fixed z (encoder output), updates codebook
        codebook_loss = torch.mean((z.detach() - quantized) ** 2)

        # Corrected Commitment Loss: Fixed quantized, updates encoder
        commitment_loss = self.beta * torch.mean((z - quantized.detach()) ** 2)

        # Total VQ Loss
        vq_loss = codebook_loss + commitment_loss

        # Straight-through estimator for quantized vectors
        quantized = z + (quantized - z).detach()

        # Perplexity
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return vq_loss, quantized, perplexity


# Define VQ-VAE
class VQVAE(nn.Module):
    def __init__(self, input_dim, embedding_dim, num_embeddings):
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(input_dim, 128, kernel_size=4, stride=2, padding=1),  # Downsample 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # Downsample 16x16 -> 8x8
            nn.ReLU(),
            nn.Conv2d(256, embedding_dim, kernel_size=3, stride=1, padding=1),  # Keep 8x8
        )
        self.quantizer = VectorQuantizer(num_embeddings, embedding_dim, beta=0.25)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedding_dim, 256, kernel_size=4, stride=2, padding=1),  # Upsample 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # Upsample 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(128, input_dim, kernel_size=3, stride=1, padding=1),  # Keep 32x32
            nn.Sigmoid(),  # Output in [0, 1]
        )

    def forward(self, x):
        encoded = self.encoder(x)
        vq_loss, quantized, perplexity = self.quantizer(encoded)
        reconstructed = self.decoder(quantized)
        return reconstructed, vq_loss, perplexity


# Training Function
def train_vqvae(model, dataloader, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        total_vq_loss = 0
        total_perplexity = 0
        for x, _ in dataloader:
            x = x.to(device)
            optimizer.zero_grad()
            reconstruction, vq_loss, perplexity = model(x)
            recon_loss = F.mse_loss(reconstruction, x)
            loss = recon_loss + vq_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_vq_loss += vq_loss.item()
            total_perplexity += perplexity.item()

        print(
            f"Epoch [{epoch + 1}/{epochs}], "
            f"Loss: {total_loss / len(dataloader):.4f}, "
            f"VQ Loss: {total_vq_loss / len(dataloader):.4f}, "
            f"Perplexity: {total_perplexity / len(dataloader):.4f}"
        )


# Load CIFAR-10 Dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# Initialize and Train VQ-VAE
input_dim = 3
embedding_dim = 64
num_embeddings = 512
vqvae = VQVAE(input_dim, embedding_dim, num_embeddings).to(device)
optimizer = torch.optim.Adam(vqvae.parameters(), lr=1e-3)

train_vqvae(vqvae, train_loader, optimizer, epochs=50)
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/50], Loss: 0.0659, VQ Loss: 0.0505, Perplexity: 28.9195
Epoch [2/50], Loss: 0.0124, VQ Loss: 0.0062, Perplexity: 71.2350
Epoch [3/50], Loss: 0.0118, VQ Loss: 0.0063, Perplexity: 86.9711
Epoch [4/50], Loss: 0.0117, VQ Loss: 0.0065, Perplexity: 92.3206
Epoch [5/50], Loss: 0.0117, VQ Loss: 0.0065, Perplexity: 94.5041
Epoch [6/50], Loss: 0.0117, VQ Loss: 0.0066, Perplexity: 96.1371
Epoch [7/50], Loss: 0.0117, VQ Loss: 0.0066, Perplexity: 96.0314
Epoch [8/50], Loss: 0.0116, VQ Loss: 0.0066, Perplexity: 96.5451
Epoch [9/50], Loss: 0.0115, VQ Loss: 0.0066, Perplexity: 98.4605
Epoch [10/50], Loss: 0.0115, VQ Loss: 0.0065, Perplexity: 98.5317
Epoch [11/50], Loss: 0.0115, VQ Loss: 0.0066, Perplexity: 98.4657
Epoch [12/50], Loss: 0.0115, VQ Loss: 0.0066, Perplexity: 98.1501
Epoch [13/50], Loss: 0.0114, VQ Loss: 0.0066, Perplexity: 98.8067
Epoch [14/50], Loss: 0.0114, VQ Loss: 0.0066, Perplexity: 98.9675
Epoch [15/50], Loss: 0.0114, VQ Loss: 0.0066, Perplexity: 99.4856
Epoch [16/50], Loss: 0.0114, VQ Loss: 0.0065, Perplexity: 99.6860
Epoch [17/50], Loss: 0.0114, VQ Loss: 0.0066, Perplexity: 99.8110
Epoch [18/50], Loss: 0.0114, VQ Loss: 0.0066, Perplexity: 99.8063
Epoch [19/50], Loss: 0.0113, VQ Loss: 0.0065, Perplexity: 101.4925
Epoch [20/50], Loss: 0.0113, VQ Loss: 0.0065, Perplexity: 101.5239
Epoch [21/50], Loss: 0.0113, VQ Loss: 0.0066, Perplexity: 101.4563
Epoch [22/50], Loss: 0.0113, VQ Loss: 0.0066, Perplexity: 101.4828
Epoch [23/50], Loss: 0.0113, VQ Loss: 0.0066, Perplexity: 102.3055
Epoch [24/50], Loss: 0.0113, VQ Loss: 0.0066, Perplexity: 102.2386
Epoch [25/50], Loss: 0.0113, VQ Loss: 0.0066, Perplexity: 102.8372
Epoch [26/50], Loss: 0.0112, VQ Loss: 0.0065, Perplexity: 104.6217
Epoch [27/50], Loss: 0.0112, VQ Loss: 0.0065, Perplexity: 104.6040
Epoch [28/50], Loss: 0.0112, VQ Loss: 0.0065, Perplexity: 105.4087
Epoch [29/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 105.5115
Epoch [30/50], Loss: 0.0112, VQ Loss: 0.0065, Perplexity: 105.6360
Epoch [31/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 106.3530
Epoch [32/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 106.4763
Epoch [33/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.1349
Epoch [34/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.1645
Epoch [35/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 107.1390
Epoch [36/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.2421
Epoch [37/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.1798
Epoch [38/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 107.2573
Epoch [39/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.2876
Epoch [40/50], Loss: 0.0111, VQ Loss: 0.0065, Perplexity: 107.3423
Epoch [41/50], Loss: 0.0111, VQ Loss: 0.0064, Perplexity: 107.1727
Epoch [42/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 107.5736
Epoch [43/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 107.9773
Epoch [44/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 108.0870
Epoch [45/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 108.1319
Epoch [46/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 108.1236
Epoch [47/50], Loss: 0.0110, VQ Loss: 0.0064, Perplexity: 108.8626
Epoch [48/50], Loss: 0.0109, VQ Loss: 0.0064, Perplexity: 109.7341
Epoch [49/50], Loss: 0.0109, VQ Loss: 0.0064, Perplexity: 109.9291
Epoch [50/50], Loss: 0.0109, VQ Loss: 0.0064, Perplexity: 109.6656
# Visualization Function
def visualize_reconstruction(model, dataloader, num_examples=4):
    model.eval()
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            reconstruction, _, _ = model(x)
            plt.figure(figsize=(12, 6))
            for i in range(num_examples):
                # Original
                plt.subplot(2, num_examples, i + 1)
                plt.imshow(x[i].permute(1, 2, 0).cpu().numpy())
                plt.title("Original")
                plt.axis("off")
                # Reconstructed
                plt.subplot(2, num_examples, i + 1 + num_examples)
                plt.imshow(reconstruction[i].permute(1, 2, 0).cpu().numpy())
                plt.title("Reconstructed")
                plt.axis("off")
            break
    plt.show()

# Visualize Reconstructions
visualize_reconstruction(vqvae, test_loader, num_examples=4)


Benjin Zhu
Ph.D. Candidate

A 3rd-year Ph.D. candidate at MMLab, CUHK.