
EugenHotaj/pytorch-generative


pytorch-generative

pytorch-generative is a Python library which makes generative modeling in PyTorch easier by providing:

  • high quality reference implementations of SOTA generative models
  • useful abstractions of common building blocks found in the literature
  • utilities for training, debugging, and working with Google Colab
  • integration with TensorBoard for easy metrics visualization

To get started, see the sections below.

Installation

To install pytorch-generative, clone the repository and install the requirements:

git clone https://www.github.com/EugenHotaj/pytorch-generative
cd pytorch-generative
pip install -r requirements.txt

After installation, run the tests to sanity check that everything works:

python -m unittest discover

Reproducing Results

All our models implement a reproduce function containing all the hyperparameters needed to reproduce the results listed in the Supported Algorithms section. This makes it easy to reproduce any result with our training script, for example:

python train.py --model image_gpt --logdir /tmp/run --use-cuda

Training metrics will periodically be logged to TensorBoard for easy visualization. To view these metrics, launch a local TensorBoard server:

tensorboard --logdir /tmp/run

To run a model on a different dataset, with different hyperparameters, etc., modify its reproduce function and rerun the commands above.

Google Colab

To use pytorch-generative in Google Colab, clone the repository and move the pytorch_generative package into the top-level directory:

!git clone https://www.github.com/EugenHotaj/pytorch-generative
!mv pytorch-generative/pytorch_generative pytorch_generative

You can then import pytorch-generative like any other library:

import pytorch_generative as pg
from pytorch_generative import models
...

Example - ImageGPT

Supported models are implemented as PyTorch Modules and are easy to use:

from pytorch_generative import models

... # Data loading code.

model = models.ImageGPT(in_channels=1, out_channels=1, in_size=28)
model(batch)
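Autoregressive models such as ImageGPT predict each pixel conditioned only on the pixels before it in raster order. A minimal sketch of how such per-pixel outputs can be used to sample a binary image one pixel at a time is below. It assumes a single binary channel, so each output is a Bernoulli logit; toy_logits is a hypothetical stand-in for model(batch) and is not part of pytorch-generative.

```python
import math
import random


def toy_logits(image):
    """Hypothetical stand-in for model(batch): one Bernoulli logit per pixel.

    Pixel (r, c) prefers to copy the pixel to its left, which is allowed
    because that pixel precedes it in raster order.
    """
    size = len(image)
    logits = [[0.0] * size for _ in range(size)]
    for r in range(size):
        for c in range(1, size):
            logits[r][c] = 2.0 if image[r][c - 1] == 1 else -2.0
    return logits


def sample(logits_fn, size, rng):
    """Samples a size x size binary image pixel by pixel in raster order."""
    image = [[0] * size for _ in range(size)]
    for r in range(size):
        for c in range(size):
            # Re-run the model on the partially generated image; causal masking
            # means pixel (r, c) only depends on pixels sampled before it.
            logit = logits_fn(image)[r][c]
            p_one = 1.0 / (1.0 + math.exp(-logit))
            image[r][c] = 1 if rng.random() < p_one else 0
    return image


img = sample(toy_logits, size=4, rng=random.Random(0))
for row in img:
    print(row)
```

Re-running the full model once per pixel is the simplest correct approach; it costs one forward pass per pixel, which is why sampling from these models is slow.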

Alternatively, the lower-level building blocks in pytorch_generative.nn can be used to write models from scratch. We show how to implement a convolutional ImageGPT model below:

import torch
from torch import nn

from pytorch_generative import nn as pg_nn


class TransformerBlock(nn.Module):
  """An ImageGPT Transformer block."""

  def __init__(self, 
               n_channels, 
               n_attention_heads):
    """Initializes a new TransformerBlock instance.
    
    Args:
      n_channels: The number of input and output channels.
      n_attention_heads: The number of attention heads to use.
    """
    super().__init__()
    self._ln1 = pg_nn.NCHWLayerNorm(n_channels)
    self._ln2 = pg_nn.NCHWLayerNorm(n_channels)
    self._attn = pg_nn.CausalAttention(
        in_channels=n_channels,
        embed_channels=n_channels,
        out_channels=n_channels,
        n_heads=n_attention_heads,
        mask_center=False)
    self._out = nn.Sequential(
        nn.Conv2d(
            in_channels=n_channels, 
            out_channels=4*n_channels, 
            kernel_size=1),
        nn.GELU(),
        nn.Conv2d(
            in_channels=4*n_channels, 
            out_channels=n_channels, 
            kernel_size=1))

  def forward(self, x):
    x = x + self._attn(self._ln1(x))
    return x + self._out(self._ln2(x))


class ImageGPT(nn.Module):
  """The ImageGPT Model."""
  
  def __init__(self,       
               in_channels,
               out_channels,
               in_size,
               n_transformer_blocks=8,
               n_attention_heads=4,
               n_embedding_channels=16):
    """Initializes a new ImageGPT instance.
    
    Args:
      in_channels: The number of input channels.
      out_channels: The number of output channels.
      in_size: Size of the input images. Used to create positional encodings.
      n_transformer_blocks: Number of TransformerBlocks to use.
      n_attention_heads: Number of attention heads to use.
      n_embedding_channels: Number of attention embedding channels to use.
    """
    super().__init__()
    self._pos = nn.Parameter(torch.zeros(1, in_channels, in_size, in_size))
    self._input = pg_nn.CausalConv2d(
        mask_center=True,
        in_channels=in_channels,
        out_channels=n_embedding_channels,
        kernel_size=3,
        padding=1)
    self._transformer = nn.Sequential(
        *[TransformerBlock(n_channels=n_embedding_channels,
                         n_attention_heads=n_attention_heads)
          for _ in range(n_transformer_blocks)])
    self._ln = pg_nn.NCHWLayerNorm(n_embedding_channels)
    self._out = nn.Conv2d(in_channels=n_embedding_channels,
                          out_channels=out_channels,
                          kernel_size=1)

  def forward(self, x):
    x = self._input(x + self._pos)
    x = self._transformer(x)
    x = self._ln(x)
    return self._out(x)
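The mask_center flag in the model above controls whether a pixel may see itself. The input CausalConv2d uses mask_center=True so a pixel's prediction never depends on its own value; later layers can keep the center position because their features already only carry information about earlier pixels. Below is a minimal sketch of the standard PixelCNN-style convolution mask and the raster-order attention mask. We assume, without having checked the library source, that CausalConv2d and CausalAttention build equivalent masks; conv_mask and attention_mask are hypothetical helpers.

```python
def conv_mask(kernel_size, mask_center):
    """Returns a kernel_size x kernel_size 0/1 mask for a causal convolution.

    Rows above the center are visible, the center row is visible left of the
    center, and rows below are hidden. mask_center=True also hides the center.
    """
    k = kernel_size // 2
    mask = [[0] * kernel_size for _ in range(kernel_size)]
    for r in range(kernel_size):
        for c in range(kernel_size):
            if r < k or (r == k and c < k):
                mask[r][c] = 1
            elif r == k and c == k and not mask_center:
                mask[r][c] = 1
    return mask


def attention_mask(n_positions, mask_center):
    """mask[i][j] == 1 iff position i may attend to position j (raster order).

    With mask_center=True, position 0 can attend to nothing, so a real
    implementation has to handle that row specially.
    """
    return [[1 if (j < i or (j == i and not mask_center)) else 0
             for j in range(n_positions)]
            for i in range(n_positions)]


for row in conv_mask(3, mask_center=True):
    print(row)
# [1, 1, 1]
# [1, 0, 0]
# [0, 0, 0]
```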

Supported Algorithms

pytorch-generative supports the following algorithms.

We train likelihood-based models on dynamically binarized MNIST and report the negative log-likelihood (in nats) in the tables below.
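Dynamic binarization resamples every pixel from a Bernoulli distribution whose mean is the original grayscale intensity each time an image is loaded, instead of fixing one binarized copy of the dataset. For a model that outputs one Bernoulli logit per pixel, the negative log-likelihood of an image in nats is the sum of the per-pixel binary cross-entropies. A minimal stdlib sketch of both, assuming intensities already scaled to [0, 1]; the function names are hypothetical.

```python
import math
import random


def dynamically_binarize(intensities, rng):
    """Samples x_i ~ Bernoulli(intensity_i); call it anew every time a batch is drawn."""
    return [1 if rng.random() < p else 0 for p in intensities]


def bernoulli_nll_nats(logits, pixels):
    """Sum over pixels of -log p(x_i), where p(x_i = 1) = sigmoid(logit_i)."""
    nll = 0.0
    for logit, x in zip(logits, pixels):
        # Numerically stable binary cross-entropy with logits:
        # max(l, 0) - l * x + log(1 + exp(-|l|)).
        nll += max(logit, 0.0) - logit * x + math.log1p(math.exp(-abs(logit)))
    return nll


rng = random.Random(0)
image = [0.0, 0.2, 0.9, 1.0]  # toy grayscale intensities
print(dynamically_binarize(image, rng))
print(bernoulli_nll_nats([0.0] * 784, [0] * 784))  # 784 * log(2) ≈ 543.43 nats
```

A model that always outputs logit 0 assigns probability 1/2 to every pixel, so 784 * log(2) ≈ 543 nats is a useful sanity baseline against the numbers in the tables.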

Autoregressive Models

| Algorithm | Binarized MNIST NLL (nats) | Links |
|---|---|---|
| PixelSNAIL | 78.61 | Code, Paper |
| ImageGPT | 79.17 | Code, Paper |
| Gated PixelCNN | 81.50 | Code, Paper |
| PixelCNN | 81.45 | Code, Paper |
| MADE | 84.87 | Code, Paper |
| NADE | 85.65 | Code, Paper |
| FVSBN | 96.58 | Code, Paper |

Variational Autoencoders

NOTE: The results below are the (variational) upper bound on the negative log-likelihood (or, equivalently, the lower bound on the log-likelihood).

| Algorithm | Binarized MNIST NLL (nats) | Links |
|---|---|---|
| VD-VAE | ≤ 80.72 | Code, Paper |
| VAE | ≤ 86.77 | Code, Paper |
| BetaVAE | N/A | Code, Paper |
| VQ-VAE | N/A | Code, Paper |
| VQ-VAE-2 | N/A | Code, Paper |
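Why the VAE numbers are upper bounds: a VAE is trained on the evidence lower bound (ELBO), so the reported figure is -ELBO, which can only overestimate the true negative log-likelihood. The toy example below uses a made-up two-state discrete latent variable (not a real VAE) to check this numerically: -ELBO is at least the exact NLL for any approximate posterior q, and the gap closes when q equals the true posterior p(z | x).

```python
import math

# Toy generative model for one fixed observation x (numbers are made up).
prior = [0.7, 0.3]       # p(z) for z in {0, 1}
likelihood = [0.1, 0.8]  # p(x | z)


def exact_nll():
    """-log p(x) = -log sum_z p(z) p(x | z)."""
    return -math.log(sum(pz * px for pz, px in zip(prior, likelihood)))


def negative_elbo(q):
    """-E_q[log p(z) + log p(x | z) - log q(z)] for an approximate posterior q."""
    return -sum(qz * (math.log(pz) + math.log(px) - math.log(qz))
                for qz, pz, px in zip(q, prior, likelihood) if qz > 0)


evidence = sum(pz * px for pz, px in zip(prior, likelihood))
posterior = [pz * px / evidence for pz, px in zip(prior, likelihood)]

print(exact_nll())                # -log(0.31) ≈ 1.171
print(negative_elbo([0.5, 0.5]))  # larger: an upper bound on the NLL
print(negative_elbo(posterior))   # equals the NLL up to rounding
```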

Normalizing Flows

NOTE: Bits per dimension (bits/dim) can be calculated as (nll / 784 + log(256)) / log(2), where nll is the per-image negative log-likelihood in nats, 784 = 28 × 28 is the number of MNIST pixels, log(256) accounts for dequantizing the 256 possible pixel values, and dividing by log(2) converts from the natural log to base 2.

| Algorithm | MNIST (bits/dim) | Links |
|---|---|---|
| NICE | 4.34 | Code, Paper |
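The bits/dim conversion from the note above, as a small Python function. The constant and function names are our own; nll_nats is assumed to be measured on pixel values scaled to [0, 1], which is what the log(256) term corrects for.

```python
import math

MNIST_DIMS = 28 * 28  # 784 pixels


def bits_per_dim(nll_nats, n_dims=MNIST_DIMS):
    """Converts a per-image NLL in nats (on [0, 1]-scaled data) to bits/dim.

    Adding log(256) per dimension undoes the rescaling of dequantized pixel
    values from [0, 256) to [0, 1]; dividing by log(2) converts nats to bits.
    """
    return (nll_nats / n_dims + math.log(256)) / math.log(2)


print(bits_per_dim(0.0))  # ≈ 8: a uniform density on [0, 1]^784
```

As a sanity check, a uniform density on [0, 1]^784 has an NLL of 0 nats, which maps to 8 bits/dim, the cost of encoding each of 256 pixel values uniformly.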

Miscellaneous

| Algorithm | Links |
|---|---|
| Mixture Models | Code, Wiki |
| Kernel Density Estimators | Code, Wiki |
| Neural Style Transfer | Code, Blog, Paper |
| Compositional Pattern Producing Networks | Code, Wiki |