Self-contained text-to-image latent diffusion using a Transformer core in PyTorch.
Below are some random examples (at 256 px resolution) from a 100MM-parameter model trained from scratch for 260k iterations (about 32 hours on 1 A100):
- "a photo of a cat" → "an anime drawing of a super saiyan cat, artstation"
- "a cute great gray owl" → "starry night by van gogh"
Note that the model has not converged yet and could use more training.
By upsampling the positional encoding, the model can also generate 512 or 1024 px images with minimal fine-tuning. See below for some examples of the model fine-tuned on 100k extra 512 px images and 30k 1024 px images for about 2 hours on an A100. The images do sometimes lack global coherence at 1024 px; more to come here:
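A minimal sketch of what upsampling a learned positional embedding could look like, assuming it is stored as a (1, H*W, D) tensor over a square token grid. The function name upsample_pos_embed and the choice of bilinear interpolation are assumptions for illustration, not the repo's code. At 256 px the 32x32 latent gives a 16x16 token grid (patch size 2); at 512 px it becomes 32x32 tokens.

```python
import torch
import torch.nn.functional as F


def upsample_pos_embed(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize a (1, H*W, D) positional embedding to (1, new_grid**2, D).

    Hypothetical helper; assumes a square grid stored in row-major order.
    """
    _, seq_len, dim = pos_embed.shape
    old_grid = int(seq_len ** 0.5)
    assert old_grid * old_grid == seq_len, "expected a square token grid"
    # (1, H*W, D) -> (1, D, H, W) so the grid can be resized like an image
    grid = pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear", align_corners=False)
    # back to (1, new_H*new_W, D)
    return grid.permute(0, 2, 3, 1).reshape(1, new_grid * new_grid, dim)


pos_embed_256 = torch.randn(1, 16 * 16, 768)  # 256 px: 16x16 tokens
pos_embed_512 = upsample_pos_embed(pos_embed_256, new_grid=32)  # 512 px: 32x32 tokens
print(pos_embed_512.shape)  # torch.Size([1, 1024, 768])
```

The interpolated embedding is only a starting point, which is consistent with the short fine-tuning at the higher resolution described above.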
The main goal of this repo is to build an accessible diffusion model in PyTorch that is:
- fast (close to real-time generation)
- small (~100MM parameters)
- reasonably good (of course not SOTA)
- trainable in a reasonable amount of time on a single GPU (under 50 hours on an A100 or equivalent)
- built on a simple, self-contained codebase (model + train loop is ~400 lines of PyTorch with few dependencies)
- trained on ~1 million images, favoring data quality over quantity, with code provided for downloading and processing the data
The code is written in pure PyTorch with as few dependencies as possible.
- transformer_blocks.py - basic transformer building blocks relevant to the transformer denoiser
- denoiser.py - the architecture of the denoiser transformer
- train.py - the train loop. It uses accelerate, so training can scale to multiple GPUs if needed.
- diffusion.py - class to generate images from noise using reverse diffusion. Short (~60 lines) and self-contained.
- data.py - data utils to download images/text and process the features needed by the diffusion model.
If you have your own dataset of URLs + captions, the process to train a model on the data consists of two steps:
- Use train.download_and_process_data to obtain the latent and text encodings as numpy files. See here for a notebook example that downloads and processes 2000 images from this HuggingFace dataset.
- Use the train.main function in an accelerate notebook_launcher. See here for a Colab notebook that trains a model on 100k images from scratch. Note that this notebook downloads already-preprocessed latents and embeddings from here, but you could just use whatever .npy files you saved in step 1.
To install the package and dependencies, run:

```
pip install git+https://github.com/apapiu/transformer_latent_diffusion.git
```

Dependencies:
- PyTorch, numpy, einops for model building
- wandb, tqdm for logging + progress bars
- accelerate for the train loop and multi-GPU support
- img2dataset, webdataset, torchvision for data downloading and image processing
- diffusers, clip for the pretrained VAE and CLIP text model
Example usage for inference:

```python
from tld.configs import LTDConfig, DenoiserConfig, TrainConfig
from tld.diffusion import DiffusionTransformer

denoiser_cfg = DenoiserConfig(n_channels=4)  # configure your model here
cfg = LTDConfig(denoiser_cfg=denoiser_cfg)

diffusion_transformer = DiffusionTransformer(cfg)

out = diffusion_transformer.generate_image_from_text(prompt="a cute cat")
```
Example usage for training:

```python
from tld.train import main
from tld.configs import ModelConfig, DataConfig, TrainConfig

data_config = DataConfig(
    latent_path="latents.npy", text_emb_path="text_emb.npy", val_path="val_emb.npy"
)

model_cfg = ModelConfig(
    data_config=data_config,
    train_config=TrainConfig(n_epoch=100, save_model=False, compile=False, use_wandb=False),
)

main(model_cfg)

# OR, in a notebook, to run the training process on 2 GPUs
# (notebook_launcher expects the function's arguments as a tuple):
# from accelerate import notebook_launcher
# notebook_launcher(main, (model_cfg,), num_processes=2)
```
The tests in test_diffuser.py are a good place to start understanding the code. You can run all tests with pytest -s.
I have some GitHub Actions workflows configured to run tests, check linting, and build some Docker images. If you're just exploring the code, you can comment these out or delete the .github/workflows folder.
Configs are in tld/configs.py in the form of dataclasses. The default values can always be overridden. For example, DenoiserConfig(n_layers=16) keeps all defaults except for n_layers. You can also save your configs as JSON and load them like so: DenoiserConfig(**json.load(file)).
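The override and JSON round-trip pattern can be sketched with a plain dataclass. ToyDenoiserConfig is a hypothetical stand-in: n_channels and n_layers appear in this README, while embed_dim is an assumed field name; the real fields are in tld/configs.py.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyDenoiserConfig:
    """Hypothetical stand-in for DenoiserConfig."""

    n_channels: int = 4
    n_layers: int = 12
    embed_dim: int = 768  # field name is an assumption


# Override a single field; every other field keeps its default.
cfg = ToyDenoiserConfig(n_layers=16)

# Save as JSON, then load back by unpacking the dict into the constructor.
# With a file handle: json.dump(asdict(cfg), f) and ToyDenoiserConfig(**json.load(f)).
text = json.dumps(asdict(cfg))
loaded = ToyDenoiserConfig(**json.loads(text))
print(loaded)  # ToyDenoiserConfig(n_channels=4, n_layers=16, embed_dim=768)
```

asdict also recurses into nested dataclasses, but loading a nested config back would require converting the inner dicts to their dataclasses as well, since ** unpacking only handles the top level.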
I try to speed up training and inference as much as possible by:
- using mixed precision for training + SDPA
- precomputing all latent and text embeddings
- using float16 precision for inference
- using SDPA (scaled dot-product attention) to get FlashAttention-2 kernels, plus torch.compile() on PyTorch 2.0+
- using a highly performant sampler (DPM-Solver++(2M)) that gets good results in ~15 steps
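For reference, here is the SDPA call on random tensors, checked against the textbook attention formula softmax(QK^T / sqrt(d))V. The shapes (12 heads of 64 dims, 256 tokens) are assumptions chosen to match a 768-dim model, not taken from the repo code. On CPU this runs a generic kernel; on a GPU with float16/bfloat16 inputs the same call can dispatch to a FlashAttention kernel.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 12, 256, 64  # illustrative sizes
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Fused attention: PyTorch selects a backend (flash, memory-efficient or math)
# based on device, dtype and input shapes.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Reference implementation of the same computation.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
out_ref = scores.softmax(dim=-1) @ v
```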
The time to generate a batch of 36 images (15 iterations) on a:
- T4: ~3.5 seconds
- A100: ~0.6 seconds

In fact, on an A100 the VAE becomes the bottleneck even though it is only used once.
More examples generated with the 100MM model:
I also fine-tuned an outpainting model on top of the original 101MM model. I had to modify the original input conv2d patch layer to take 8 channels and initialize the parameters of the extra mask channels to zero. The rest of the architecture remained the same.
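A minimal sketch of widening an input conv from 4 to 8 channels with zero-initialized weights for the new channels, so the widened layer initially behaves exactly like the pretrained one. The helper expand_input_channels, the patch conv shape (kernel 2, stride 2, 768 outputs) and the contents of the extra channels are assumptions for illustration.

```python
import torch
import torch.nn as nn


def expand_input_channels(conv: nn.Conv2d, new_in_channels: int) -> nn.Conv2d:
    """Hypothetical helper: copy of `conv` that accepts extra input channels.

    Weights for the original channels are copied; weights for the new
    channels start at zero, so the layer initially ignores them.
    """
    new_conv = nn.Conv2d(
        new_in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        new_conv.weight.zero_()
        new_conv.weight[:, : conv.in_channels] = conv.weight
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


# Assumed patch-embedding conv: 4 latent channels, patch size 2, 768-dim tokens.
patch_conv = nn.Conv2d(4, 768, kernel_size=2, stride=2)
outpaint_conv = expand_input_channels(patch_conv, new_in_channels=8)

latent = torch.randn(1, 4, 32, 32)
mask_channels = torch.randn(1, 4, 32, 32)  # extra conditioning channels (contents assumed)
with torch.no_grad():
    before = patch_conv(latent)
    after = outpaint_conv(torch.cat([latent, mask_channels], dim=1))
```

Because the new weights start at zero, fine-tuning begins from the pretrained model's behavior and gradually learns to use the mask channels.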
Below I apply the outpainting model repeatedly to generate a somewhat consistent scenery based on the prompt "a cyberpunk marketplace":
In data.py, I have some helper functions to process images and captions. The flow is as follows:
- Use img2dataset to download images from a dataframe containing URLs and captions.
- Use CLIP to encode the prompts and the VAE to encode images to latents on a webdataset data generator.
- Save the latents and text embeddings for future training.
There are two advantages to this approach. First, VAE encoding is somewhat expensive, so doing it every epoch would increase training time. Second, we can discard the images after processing. For 3*256*256 images, the latent dimension is 4*32*32, so every latent is around 4 KB (when quantized to uint8; see here). This means that 1 million latents will be "only" 4 GB in size, which is easy to handle even in RAM. Storing the raw images would have taken 48x more space.
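The storage arithmetic above, checked in code, plus a minimal sketch of uint8 quantization. The quantization scheme (clip to a hypothetical [-5, 5] range, then scale linearly to 0..255) is an assumption for illustration; the scheme referenced in the text may differ.

```python
import numpy as np

# Sizes in bytes, assuming 1 byte per value (uint8) for both latents and raw pixels.
image_bytes = 3 * 256 * 256  # 196,608 bytes per RGB image
latent_bytes = 4 * 32 * 32   # 4,096 bytes per latent (~4 KB)
n_images = 1_000_000

print(latent_bytes * n_images / 1e9)  # 4.096 (GB)
print(image_bytes // latent_bytes)    # 48


def quantize(latents, lo=-5.0, hi=5.0):
    """Map float latents in [lo, hi] to uint8 (clipping range is hypothetical)."""
    scaled = (np.clip(latents, lo, hi) - lo) / (hi - lo) * 255
    return np.round(scaled).astype(np.uint8)


def dequantize(q, lo=-5.0, hi=5.0):
    """Inverse of quantize, up to rounding error of half a quantization step."""
    return q.astype(np.float32) / 255 * (hi - lo) + lo


rng = np.random.default_rng(0)
latents = rng.normal(size=(8, 4, 32, 32)).astype(np.float32)
q = quantize(latents)
max_err = np.abs(dequantize(q) - np.clip(latents, -5.0, 5.0)).max()
```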
See here for the denoiser class.
The denoiser model is a Transformer-based model built on the architecture in DiT and PixArt-Alpha, albeit with quite a few modifications and simplifications. Using a Transformer as the denoiser is different from most diffusion models, which use a CNN-based U-Net as the denoising backbone. I decided to use a Transformer for a few reasons. One was that I just wanted to experiment and learn how to build and train Transformers from the ground up. Secondly, Transformers are fast both to train and to do inference on, and they are likely to benefit most from future performance advances (both in hardware and in software).
Transformers are not natively built for spatial data, and at first I found a lot of the outputs to be very "patchy". To remedy that, I added a depth-wise convolution in the FFN layer of the transformer (this was introduced in the LocalViT paper). This allows the model to mix pixels that are close to each other with very little added compute cost.
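A minimal sketch of an FFN with a depth-wise 3x3 convolution between the two linear layers, in the spirit of LocalViT. The class name ConvFFN, the 4x hidden ratio and the activation placement are assumptions; the repo's block may differ (for example in where norms or activations sit).

```python
import torch
import torch.nn as nn


class ConvFFN(nn.Module):
    """Hypothetical FFN: Linear -> GELU -> depth-wise 3x3 conv -> GELU -> Linear.

    Tokens are reshaped back to their 2D grid so the conv can mix each token
    with its spatial neighbours.
    """

    def __init__(self, embed_dim: int, mlp_ratio: int = 4):
        super().__init__()
        hidden = embed_dim * mlp_ratio
        self.fc1 = nn.Linear(embed_dim, hidden)
        # groups=hidden makes the conv depth-wise: one 3x3 filter per channel
        self.dwconv = nn.Conv2d(hidden, hidden, kernel_size=3, padding=1, groups=hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _ = x.shape
        side = int(seq_len ** 0.5)  # assumes a square token grid
        h = self.act(self.fc1(x))
        # (B, N, C) -> (B, C, H, W) for the conv, then back to (B, N, C)
        h = h.transpose(1, 2).reshape(batch, -1, side, side)
        h = self.act(self.dwconv(h))
        h = h.flatten(2).transpose(1, 2)
        return self.fc2(h)


ffn = ConvFFN(embed_dim=768)
tokens = torch.randn(2, 256, 768)  # 16x16 grid of tokens
print(ffn(tokens).shape)  # torch.Size([2, 256, 768])
```

The depth-wise conv adds only 9 weights (plus a bias) per hidden channel, which is why the extra compute is small compared to the two linear layers.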
The image latent inputs are 4*32*32, and we use a patch size of 2 to build 256 flattened 4*2*2=16-dimensional input "pixels". These are then projected into the embedding dimension and fed through the transformer blocks.
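The patching step can be sketched with plain reshape/permute (einops' rearrange would do the same in one line). The helper name patchify and the (C, p, p) ordering inside each flattened patch are assumptions; a Linear layer on the flattened patches is equivalent to a conv with kernel size and stride equal to the patch size.

```python
import torch
import torch.nn as nn


def patchify(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """Hypothetical helper: (B, C, H, W) -> (B, (H/p)*(W/p), C*p*p) flattened patches."""
    b, c, h, w = latents.shape
    p = patch_size
    x = latents.reshape(b, c, h // p, p, w // p, p)
    # bring the two patch-grid axes forward, keep (C, p, p) together per patch
    x = x.permute(0, 2, 4, 1, 3, 5)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


latents = torch.randn(8, 4, 32, 32)
patches = patchify(latents)  # (8, 256, 16)
embed = nn.Linear(16, 768)   # project each 16-dim "pixel" to the embedding dim
tokens = embed(patches)      # (8, 256, 768)
```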
The text and noise conditioning is very simple: we concatenate a pooled CLIP text embedding (ViT-L/14, 768-dimensional) and the sinusoidal noise embedding and feed them as input to the cross-attention layer in each transformer block. No unpooled CLIP embeddings are used.
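A minimal sketch of this conditioning, reading "concatenate" as stacking the projected noise embedding and the projected text embedding as two tokens along the sequence axis (concatenating along the feature axis into one token is another possible reading). The sinusoidal frequencies, the x1000 scaling of the noise level, the projection layers and the 12 heads are all assumptions.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(noise_level: torch.Tensor, dim: int) -> torch.Tensor:
    """Map noise levels of shape (B,) to (B, dim) sin/cos features."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    angles = noise_level[:, None] * freqs[None, :] * 1000.0  # scaling: assumption
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


embed_dim, batch = 768, 4
text_emb = torch.randn(batch, 768)  # pooled CLIP ViT-L/14 embedding
noise_level = torch.rand(batch)

text_token = nn.Linear(768, embed_dim)(text_emb)
noise_token = nn.Linear(embed_dim, embed_dim)(sinusoidal_embedding(noise_level, embed_dim))
cond = torch.stack([noise_token, text_token], dim=1)  # (B, 2, D) conditioning tokens

# Image tokens attend to the two conditioning tokens.
image_tokens = torch.randn(batch, 256, embed_dim)
cross_attn = nn.MultiheadAttention(embed_dim, num_heads=12, batch_first=True)
out, _ = cross_attn(query=image_tokens, key=cond, value=cond)
print(cond.shape, out.shape)  # torch.Size([4, 2, 768]) torch.Size([4, 256, 768])
```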
The base model has 101MM parameters, 12 layers, and an embedding dimension of 768. I train it with a batch size of 256 on an A100 with a learning rate of 3e-4 and 1000 warmup steps. Due to computational constraints, I did not do any ablations for this configuration.
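A minimal sketch of a 1000-step linear warmup to 3e-4 using LambdaLR. The optimizer (AdamW), the linear shape of the ramp and the constant learning rate after warmup are assumptions; the text above only gives the peak learning rate and the number of warmup steps.

```python
import torch


def warmup_factor(step: int, warmup_steps: int = 1000) -> float:
    """Linear warmup: LR multiplier ramps from 1/warmup_steps up to 1, then stays at 1."""
    return min(1.0, (step + 1) / warmup_steps)


model = torch.nn.Linear(768, 768)  # stand-in for the denoiser
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # optimizer choice: assumption
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

for step in range(3):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()

# After 3 scheduler steps the LR is 3e-4 * 4 / 1000.
print(optimizer.param_groups[0]["lr"])
```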
We train a denoising transformer that takes the following three inputs:
- noise_level (sampled from 0 to 1 with more values concentrated close to 0; I use a beta distribution)
- Image latent (x) corrupted with a level of random noise
  - For a given noise_level between 0 and 1, the corruption is as follows: x_noisy = x*(1-noise_level) + eps*noise_level, where eps ~ np.random.normal(0, 1)
- CLIP embeddings of a text prompt
  - You can think of this as a numerical representation of a text prompt.
  - We use the pooled text embedding here (768-dimensional for ViT-L/14)
The output is a prediction of the denoised image latent; call it f(x_noisy).

The model is trained to minimize the mean squared error |f(x_noisy) - x|^2 between the prediction and the actual image latent (you can also use absolute error here). Note that I don't reparameterize the loss in terms of the noise here, to keep things simple.
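One training step can be sketched as follows: sample noise levels from a beta distribution skewed toward 0, corrupt the latents with the formula above, and regress the prediction directly onto x. The Beta(1.0, 2.5) parameters and the identity stand-in for the denoiser are hypothetical; only the corruption formula and the x-prediction MSE loss come from the text.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def corrupt(x, noise_level, eps):
    """x_noisy = x * (1 - noise_level) + eps * noise_level, per sample."""
    nl = noise_level.view(-1, 1, 1, 1)  # broadcast over (C, H, W)
    return x * (1 - nl) + eps * nl


def denoiser(x_noisy, noise_level, text_emb):
    """Hypothetical stand-in for the transformer; it should predict the clean x."""
    return x_noisy


batch = 8
x = torch.randn(batch, 4, 32, 32)   # clean image latents
text_emb = torch.randn(batch, 768)  # pooled CLIP embeddings

# Beta(a, b) with a < b puts more mass near 0; (1.0, 2.5) is a hypothetical choice.
noise_level = torch.distributions.Beta(1.0, 2.5).sample((batch,))
eps = torch.randn_like(x)
x_noisy = corrupt(x, noise_level, eps)

pred = denoiser(x_noisy, noise_level, text_emb)
loss = F.mse_loss(pred, x)  # target is x itself, not the noise eps
# In a real loop: loss.backward(); optimizer.step(); optimizer.zero_grad()
```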
Using this model, we then iteratively generate an image from random noise as follows:

```python
for i in range(len(self.noise_levels) - 1):
    curr_noise, next_noise = self.noise_levels[i], self.noise_levels[i + 1]

    # Predict original denoised image:
    x0_pred = predict_x_zero(new_img, label, curr_noise)

    # New image at next_noise level is a weighted average of old image and predicted x0:
    new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
```
The predict_x_zero method uses classifier-free guidance by combining the conditional and unconditional predictions: x0_pred = class_guidance * x0_pred_conditional + (1 - class_guidance) * x0_pred_unconditional
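A self-contained sketch of the sampling loop with classifier-free guidance. The guidance scale of 4.5, the zero vector used as the "null" label for the unconditional branch and the linear 15-step schedule are assumptions. The usage part plugs in an "oracle" model that always predicts the same latent, which shows that the final step (next_noise = 0) returns the predicted x0.

```python
import torch

torch.manual_seed(0)


def predict_x_zero(model, x, label, noise_level, class_guidance=4.5):
    """Classifier-free guidance on the x0 prediction (guidance scale: assumption)."""
    x0_cond = model(x, label, noise_level)
    # Null label for the unconditional branch: zeros here, an assumption.
    x0_uncond = model(x, torch.zeros_like(label), noise_level)
    return class_guidance * x0_cond + (1 - class_guidance) * x0_uncond


@torch.no_grad()
def sample(model, label, noise_levels, shape):
    new_img = torch.randn(shape)  # start from pure noise (noise level 1)
    for curr_noise, next_noise in zip(noise_levels[:-1], noise_levels[1:]):
        x0_pred = predict_x_zero(model, new_img, label, curr_noise)
        # weighted average of the predicted x0 and the current image
        new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise
    return new_img


# Usage with an "oracle" model that always predicts the same clean latent.
target = torch.randn(1, 4, 32, 32)


def oracle(x, label, noise_level):
    return target.expand_as(x)


noise_levels = torch.linspace(1.0, 0.0, 16).tolist()  # 15 steps, linear schedule assumed
out = sample(oracle, torch.randn(1, 768), noise_levels, (1, 4, 32, 32))
```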
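A bit of math: the approach above falls within the VDM parametrization (see Section 3.1 in Kingma et al.):

$$z_t = \alpha_t x + \sigma_t \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

where $z_t$ is the noisy version of the latent $x$ at time $t$. The corruption used here corresponds to $\sigma_t$ = noise_level and $\alpha_t = 1 - \sigma_t$.

Generally, $\alpha_t$ is chosen as $\sqrt{1 - \sigma_t^2}$ so that the process is variance preserving; with $\alpha_t = 1 - \sigma_t$ we instead interpolate linearly between the image and pure noise.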
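One way to see where the weighted average in the sampling loop comes from, assuming the linear corruption x_noisy = x*(1-noise_level) + eps*noise_level: recover the noise implied by the x0 prediction at the current level $\sigma_c$ (curr_noise), then re-noise to the next level $\sigma_n$ (next_noise) with that same noise. This is a deterministic, DDIM-style step; $z_c$ is the current image and $\hat x_0$ is x0_pred.

$$
\begin{aligned}
\hat\epsilon &= \frac{z_c - (1-\sigma_c)\,\hat x_0}{\sigma_c} \\
z_n &= (1-\sigma_n)\,\hat x_0 + \sigma_n\,\hat\epsilon \\
    &= (1-\sigma_n)\,\hat x_0 + \frac{\sigma_n}{\sigma_c}\bigl(z_c - (1-\sigma_c)\,\hat x_0\bigr) \\
    &= \frac{(\sigma_c-\sigma_n)\,\hat x_0 + \sigma_n\, z_c}{\sigma_c}
\end{aligned}
$$

The last line is exactly the update new_img = ((curr_noise - next_noise) * x0_pred + next_noise * new_img) / curr_noise, and at $\sigma_n = 0$ it returns $\hat x_0$.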
TODOs:
- [ ] how to speed up generation even more (LCMs?)
- [ ] add script to compute FID
- better config in the train file
- faster sampling - DDPM