We implement, in PyTorch, a simple conditional diffusion model as described in Denoising Diffusion Probabilistic Models. In preparing this repository, we were inspired by the course How Diffusion Models Work and the repository minDiffusion. For training, we use the MNIST, FashionMNIST, and Sprite (see FrootsnVeggies and kyrise) datasets.
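For reference, the DDPM forward (noising) process and its simplified training objective are summarized below. This is the standard formulation from the DDPM paper. We assume that the conditioning enters only through a class label c passed to the noise predictor ε_θ. This README does not say how the label is embedded.

```latex
\begin{aligned}
\bar\alpha_t &= \prod_{s=1}^{t} (1 - \beta_s), \\
q(x_t \mid x_0) &= \mathcal{N}\!\left(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1 - \bar\alpha_t) I\right), \\
\mathcal{L}_{\text{simple}} &= \mathbb{E}_{x_0, c, t, \epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon,\ t,\ c\right)\right\rVert^2\right], \quad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
```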
- Install Conda, if not already installed.
- Clone the repository:
git clone https://github.com/byrkbrk/diffusion-model.git
- In the diffusion-model directory, for macOS, run:
conda env create -f diffusion-env_macos.yaml
For Linux or Windows, run:
conda env create -f diffusion-env_linux_or_windows.yaml
- Activate the environment:
conda activate diffusion-env
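The choice between the two environment files above can be scripted. The sketch below is a hypothetical helper and is not part of the repository. It uses only the two file names given in this README. It treats `platform.system()` returning "Darwin" as macOS and uses the Linux/Windows file for every other OS.

```python
import platform
from typing import Optional

# Environment file names taken from this README.
MACOS_ENV_FILE = "diffusion-env_macos.yaml"
OTHER_ENV_FILE = "diffusion-env_linux_or_windows.yaml"


def conda_env_file(system_name: Optional[str] = None) -> str:
    """Hypothetical helper: pick the Conda environment file for an OS.

    `system_name` follows `platform.system()` conventions ("Darwin",
    "Linux", "Windows"); when omitted, the current OS is detected.
    """
    if system_name is None:
        system_name = platform.system()
    return MACOS_ENV_FILE if system_name == "Darwin" else OTHER_ENV_FILE


if __name__ == "__main__":
    # Print the command to run from inside the diffusion-model directory.
    print(f"conda env create -f {conda_env_file()}")
```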
To train the model from scratch on the MNIST dataset, run:
python3 train.py --dataset-name mnist
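The following dependency-free sketch shows roughly what one conditional DDPM training step computes: the closed-form forward noising, plus one Monte Carlo sample of the simplified loss. Several details are assumptions rather than the repository's settings:

- The schedule is the linear one from the DDPM paper (T = 1000, β from 1e-4 to 0.02).
- Timesteps are 0-indexed.
- Images are flat lists of floats.
- `ddpm_loss` and `q_sample` are hypothetical names. The actual train.py works on PyTorch tensors with a learned network.

```python
import math
import random
from typing import Callable, List, Sequence

# Assumed linear beta schedule (DDPM paper defaults); the repo may differ.
T = 1000
BETAS = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
ALPHA_BARS: List[float] = []
_prod = 1.0
for beta in BETAS:
    _prod *= 1.0 - beta          # cumulative product of (1 - beta_s)
    ALPHA_BARS.append(_prod)

# eps_theta(x_t, t, c): predicts the added noise given the class label c.
NoisePredictor = Callable[[List[float], int, int], List[float]]


def q_sample(x0: Sequence[float], t: int, noise: Sequence[float]) -> List[float]:
    """Closed-form forward process: x_t = sqrt(abar_t) x0 + sqrt(1 - abar_t) eps."""
    a = math.sqrt(ALPHA_BARS[t])
    b = math.sqrt(1.0 - ALPHA_BARS[t])
    return [a * x + b * e for x, e in zip(x0, noise)]


def ddpm_loss(model: NoisePredictor, x0: Sequence[float], label: int,
              rng: random.Random) -> float:
    """One Monte Carlo sample of the simplified DDPM loss for a labelled image."""
    t = rng.randrange(T)                         # uniform random timestep
    noise = [rng.gauss(0.0, 1.0) for _ in x0]    # eps ~ N(0, I)
    xt = q_sample(x0, t, noise)
    pred = model(xt, t, label)                   # eps_theta(x_t, t, c)
    # Mean squared error between true and predicted noise.
    return sum((p - e) ** 2 for p, e in zip(pred, noise)) / len(noise)
```

In training, this loss would be averaged over a batch and minimized with respect to the network's parameters. A predictor that always outputs zeros gives a loss close to 1, because it simply equals the mean of ε².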
To sample from our pretrained checkpoint, run:
python3 sample.py pretrained_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20
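Sampling in DDPM runs the reverse process from pure noise down to t = 0 (Algorithm 2 of the paper). The sketch below is dependency-free and conditions every step on a class label. It makes these assumptions:

- It uses the same linear schedule as the DDPM paper.
- It uses the σ_t² = β_t variance choice.
- `sample` is a hypothetical function operating on flat lists of floats.

It is not the repository's sample.py, which loads a trained PyTorch checkpoint.

```python
import math
import random
from typing import Callable, List

# Assumed linear beta schedule (DDPM paper defaults); the repo may differ.
T = 1000
BETAS = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
ALPHAS = [1.0 - b for b in BETAS]
ALPHA_BARS: List[float] = []
_prod = 1.0
for a in ALPHAS:
    _prod *= a
    ALPHA_BARS.append(_prod)


def sample(model: Callable[[List[float], int, int], List[float]],
           dim: int, label: int, rng: random.Random) -> List[float]:
    """Ancestral sampling (DDPM Algorithm 2) conditioned on a class label."""
    x = [rng.gauss(0.0, 1.0) for _ in range(dim)]   # x_T ~ N(0, I)
    for t in reversed(range(T)):
        eps = model(x, t, label)                    # predicted noise
        coef = BETAS[t] / math.sqrt(1.0 - ALPHA_BARS[t])
        mean = [(xi - coef * ei) / math.sqrt(ALPHAS[t]) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(BETAS[t])             # sigma_t^2 = beta_t
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean                                # no noise on the final step
    return x
```

Saving intermediate values of `x` at a few timesteps would give the frames of a denoising animation. This is one plausible way to build a GIF, but we have not confirmed it against the repository.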
Results (JPEG and GIF files) are saved in the generated-images directory. In the generated grid, every two rows correspond to one class label (20 rows in total, 10 classes).
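The grid layout follows from simple arithmetic. The number of rows is `--n-samples` divided by `--n-images-per-row`, and dividing the rows evenly among the classes gives the rows per class. The helpers below are hypothetical and check this arithmetic for the MNIST and Fashion-MNIST settings (400 / 20 → 20 rows, 2 per class) and the Sprite setting (225 / 15 → 15 rows, 3 per class). `row_labels` assumes that labels 0 to n_classes − 1 fill contiguous blocks of rows in order.

```python
from typing import List, Tuple


def grid_layout(n_samples: int, n_images_per_row: int, n_classes: int) -> Tuple[int, int]:
    """Return (total rows, rows per class) for a sample grid."""
    if n_samples % n_images_per_row:
        raise ValueError("n_samples must fill complete rows")
    n_rows = n_samples // n_images_per_row
    if n_rows % n_classes:
        raise ValueError("rows must split evenly across classes")
    return n_rows, n_rows // n_classes


def row_labels(n_samples: int, n_images_per_row: int, n_classes: int) -> List[int]:
    """Class label of each grid row, assuming contiguous row blocks per label."""
    n_rows, rows_per_class = grid_layout(n_samples, n_images_per_row, n_classes)
    return [row // rows_per_class for row in range(n_rows)]


print(grid_layout(400, 20, 10))      # (20, 2)
print(grid_layout(225, 15, 5))       # (15, 3)
print(row_labels(400, 20, 10)[:4])   # [0, 0, 1, 1]
```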
To train the model from scratch on the Fashion-MNIST dataset, run:
python3 train.py --dataset-name fashion_mnist
To sample from our pretrained checkpoint, run:
python3 sample.py pretrained_fashion_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20
Results (JPEG and GIF files) are saved in the generated-images directory. In the generated grid, every two rows correspond to one class label (20 rows in total, 10 classes).
To train the model from scratch on the Sprite dataset, run:
python3 train.py --dataset-name sprite
To sample from our pretrained checkpoint, run:
python3 sample.py pretrained_sprite_checkpoint_49.pth --n-samples 225 --n-images-per-row 15
Results (JPEG and GIF files) are saved in the generated-images directory. In the generated grid, every three rows correspond to one class label (15 rows in total, 5 classes).