Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch

ShoukanLabs/muse-maskgit-pytorch

Muse - Pytorch

Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch, originally made by Lucidrains.

We have added code that lets anyone train their own model, and we have optimized the code for low-end hardware.

Join us at Sygil.Dev's Discord Server

Install

There are two ways to install the code:

1 - You can install it directly from the repo with pip:

$ pip install git+https://github.com/Sygil-Dev/muse-maskgit-pytorch

2 - or you can clone it and then install from source:

$ git clone https://github.com/Sygil-Dev/muse-maskgit-pytorch
$ cd muse-maskgit-pytorch
$ pip install .
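
After either install route, one quick way to confirm the package is visible to Python is to look it up on the import path without actually importing it. This is a minimal sketch using only the standard library; the helper name `is_installed` is our own, and the module name `muse_maskgit_pytorch` is taken from the import statements in the Usage section.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found on the import path."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    name = "muse_maskgit_pytorch"
    print(f"{name} installed: {is_installed(name)}")
```

If this reports False, check that pip installed into the same environment that runs your scripts.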

Usage

First train your VAE - VQGanVAE

import torch
from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer

vae = VQGanVAE(
    dim = 256,
    vq_codebook_size = 512
)

# train on folder of images, as many images as possible

trainer = VQGanVAETrainer(
    vae = vae,
    image_size = 128,             # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
    folder = '/path/to/images',
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 50000
).cuda()

trainer.train()
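
The `batch_size` and `grad_accum_every` settings above combine into a larger effective batch per optimizer step. The arithmetic can be sketched as follows, assuming that `num_train_steps` counts optimizer steps and that training runs in a single process; both are assumptions about the trainer, not something this README states, and the function names are hypothetical.

```python
def effective_batch_size(batch_size: int, grad_accum_every: int) -> int:
    """Samples contributing to each optimizer step (single process assumed)."""
    return batch_size * grad_accum_every


def samples_seen(batch_size: int, grad_accum_every: int, num_train_steps: int) -> int:
    """Total samples processed, assuming one optimizer step per training step."""
    return effective_batch_size(batch_size, grad_accum_every) * num_train_steps


# Values from the VQGanVAETrainer example above.
print(effective_batch_size(4, 8))   # 32
print(samples_seen(4, 8, 50000))    # 1600000
```

On low-end hardware, lowering `batch_size` while raising `grad_accum_every` keeps the effective batch the same at the cost of more forward passes per step.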

Then pass the trained VQGanVAE and a Transformer to MaskGit

import torch
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate your vae

vae = VQGanVAE(
    dim = 256,
    vq_codebook_size = 512
).cuda()

vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE

# then you plug the vae and transformer into your MaskGit as so

# (1) create your transformer / attention network

transformer = MaskGitTransformer(
    num_tokens = 512,         # must be same as codebook size above
    seq_len = 256,            # must be equivalent to fmap_size ** 2 in vae
    dim = 512,                # model dimension
    depth = 8,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads,
    ff_mult = 4,              # feedforward expansion factor
    t5_name = 't5-small',     # name of your T5
)

# (2) pass your trained VAE and the base transformer to MaskGit

base_maskgit = MaskGit(
    vae = vae,                 # vqgan vae
    transformer = transformer, # transformer
    image_size = 256,          # image size
    cond_drop_prob = 0.25,     # conditional dropout, for classifier free guidance
).cuda()

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 256, 256).cuda()

# feed it into your maskgit instance; the forward pass returns the training loss

loss = base_maskgit(
    images,
    texts = texts
)

loss.backward()

# do this for a long time on much data
# then...

images = base_maskgit.generate(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.) # conditioning scale for classifier free guidance

images.shape # (3, 3, 256, 256)
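
The comment on `seq_len` says it must equal `fmap_size ** 2` in the VAE. A small check like the one below can catch a mismatch before training starts. It is a sketch that assumes the VAE downsamples each spatial dimension by a factor of 16; that factor is inferred from the numbers in this README (a 256 pixel image with `seq_len = 256`), so verify it against your own VAE configuration. The function name is hypothetical.

```python
def seq_len_for(image_size: int, downsample_factor: int = 16) -> int:
    """Token count for a square image, given the VAE's spatial downsampling factor."""
    if image_size % downsample_factor != 0:
        raise ValueError("image_size must be divisible by downsample_factor")
    fmap_size = image_size // downsample_factor  # side length of the token grid
    return fmap_size ** 2


print(seq_len_for(256))  # 256, matches the base transformer above
print(seq_len_for(512))  # 1024, matches the super-resolution transformer
```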

Training the super-resolution maskgit requires changing one field on MaskGit instantiation: you now need to pass in cond_image_size, the size of the lower-resolution image being conditioned on.

Optionally, you can pass in a different VAE as cond_vae for the conditioning low-resolution image. By default, the same vae is used to tokenize both the super-resolution and the low-resolution images.

import torch
import torch.nn.functional as F
from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer

# first instantiate your VQGan VAE

vae = VQGanVAE(
    dim = 256,
    vq_codebook_size = 512
).cuda()

vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE

# then you plug the VqGan VAE into your MaskGit as so

# (1) create your transformer / attention network

transformer = MaskGitTransformer(
    num_tokens = 512,         # must be same as codebook size above
    seq_len = 1024,           # must be equivalent to fmap_size ** 2 in vae
    dim = 512,                # model dimension
    depth = 2,                # depth
    dim_head = 64,            # attention head dimension
    heads = 8,                # attention heads,
    ff_mult = 4,              # feedforward expansion factor
    t5_name = 't5-small',     # name of your T5
)

# (2) pass your trained VAE and the base transformer to MaskGit

superres_maskgit = MaskGit(
    vae = vae,
    transformer = transformer,
    cond_drop_prob = 0.25,
    image_size = 512,                     # larger image size
    cond_image_size = 256,                # conditioning image size <- this must be set
).cuda()

# ready your training text and images

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]

images = torch.randn(4, 3, 512, 512).cuda()

# feed it into your maskgit instance; the forward pass returns the training loss

loss = superres_maskgit(
    images,
    texts = texts
)

loss.backward()

# do this for a long time on much data
# then...

images = superres_maskgit.generate(
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles',
        'waking up to a psychedelic landscape'
    ],
    cond_images = F.interpolate(images, 256),  # conditioning images must be passed in for generating from superres
    cond_scale = 3.
)

images.shape # (4, 3, 512, 512)
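
The `cond_drop_prob` and `cond_scale` arguments used above relate to classifier-free guidance: during training the text conditioning is dropped for a fraction of samples, and at generation time the conditional and unconditional predictions are combined. The sketch below shows the commonly used combination rule on plain Python lists. It illustrates the general technique rather than this library's internals, and `apply_guidance` is a hypothetical name.

```python
from typing import List


def apply_guidance(cond_logits: List[float], null_logits: List[float], cond_scale: float) -> List[float]:
    """Push predictions away from the unconditional ones by a factor of cond_scale."""
    return [n + (c - n) * cond_scale for c, n in zip(cond_logits, null_logits)]


cond = [2.0, 0.0]   # logits predicted with the text prompt
null = [1.0, 1.0]   # logits predicted with the prompt dropped

print(apply_guidance(cond, null, 1.0))  # [2.0, 0.0], a scale of 1 leaves the conditional logits unchanged
print(apply_guidance(cond, null, 3.0))  # [4.0, -2.0], a larger scale exaggerates the difference
```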

All together now

from muse_maskgit_pytorch import Muse

base_maskgit.load('./path/to/base.pt')

superres_maskgit.load('./path/to/superres.pt')

# pass in the trained base_maskgit and superres_maskgit from above

muse = Muse(
    base = base_maskgit,
    superres = superres_maskgit
)

images = muse([
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles',
    'waking up to a psychedelic landscape'
])

images # List[PIL.Image.Image]
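
Conceptually, the combined pipeline runs the base model first and then feeds its output to the super-resolution model as the conditioning image, mirroring the `cond_images` argument shown earlier. The sketch below shows that control flow with stand-in functions that return description strings instead of images. All names here are hypothetical, and the real `Muse` class may handle additional details.

```python
from typing import Callable, List


def fake_base_generate(texts: List[str]) -> List[str]:
    """Stand-in for the base model: one low-resolution result per prompt."""
    return [f"256x256 image for '{t}'" for t in texts]


def fake_superres_generate(texts: List[str], cond_images: List[str]) -> List[str]:
    """Stand-in for the super-resolution model, conditioned on the base output."""
    return [f"512x512 image upsampled from ({c})" for c in cond_images]


def two_stage_generate(
    texts: List[str],
    base_generate: Callable[[List[str]], List[str]],
    superres_generate: Callable[[List[str], List[str]], List[str]],
) -> List[str]:
    """Run the base stage, then pass its output to the super-resolution stage."""
    low_res = base_generate(texts)
    return superres_generate(texts, low_res)


results = two_stage_generate(
    ["a whale breaching from afar"], fake_base_generate, fake_superres_generate
)
print(results[0])
```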

Training

Training should be done in 4 stages.

  1. Train the base VAE (replace dataset_name with your Hugging Face dataset):

    accelerate launch train_muse_vae.py --dataset_name="Isamu136/big-animal-dataset"
    
  2. Once the base VAE has trained for long enough, move its latest checkpoint to a new location. Then run:

    accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --vae_path=path_to_vae_checkpoint
    

    Alternatively, if you want to use a pretrained autoencoder, download one from here and extract it. The command below uses vqgan_imagenet_f16_1024; change the paths accordingly.

    accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="elephant"
    

    Or, if you want to train on cifar10, try:

    accelerate launch train_muse_maskgit.py --dataset_name="cifar10" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="0" --image_column="img" --caption_column="label"
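
When switching between datasets or checkpoints, it can help to assemble these commands programmatically rather than editing long lines by hand. The sketch below builds an argument list from keyword options using only the standard library. The helper `build_launch_command` is hypothetical, and it only uses flags that appear in the commands above.

```python
import shlex
from typing import List


def build_launch_command(script: str, **options: str) -> List[str]:
    """Build an `accelerate launch` argument list from --name=value options."""
    cmd = ["accelerate", "launch", script]
    for name, value in options.items():
        cmd.append(f"--{name}={value}")
    return cmd


cmd = build_launch_command(
    "train_muse_maskgit.py",
    dataset_name="Isamu136/big-animal-dataset",
    vae_path="path_to_vae_checkpoint",
)

# Print a shell-safe version of the command for inspection.
print(shlex.join(cmd))
```

The resulting list can be passed to `subprocess.run(cmd, check=True)` once the paths point at real files.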

Checkpoints and Pretrained Models

We do not currently have a usable pretrained model for Muse, but we are training one with the resources we have available. For more information, check the Sygil Muse repository on HuggingFace, where we upload checkpoints from the tests we have performed and where we will upload the final weights once we have something everyone can use.

Appreciation

  • Lucidrains for the original Muse-Maskgit-Pytorch implementation.

  • The ShoukanLabs team for contributing so much to improving the code and adding new features.

  • StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.

  • 🤗 Huggingface for the transformers and accelerate libraries, both of which are wonderful

Todo

  • test end-to-end

  • separate cond_images_or_ids, it is not done right

  • add training code for vae

  • add optional self-conditioning on embeddings

  • combine with token critic paper, already implemented at Phenaki

  • hook up accelerate training code for maskgit

  • train a base model

Citations

@inproceedings{Chang2023MuseTG,
    title   = {Muse: Text-To-Image Generation via Masked Generative Transformers},
    author  = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},
    year    = {2023}
}
@article{Chen2022AnalogBG,
    title   = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
    author  = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.04202}
}
@misc{jabri2022scalable,
    title   = {Scalable Adaptive Computation for Iterative Generation},
    author  = {Allan Jabri and David Fleet and Ting Chen},
    year    = {2022},
    eprint  = {2212.11972},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@article{Lezama2022ImprovedMI,
    title   = {Improved Masked Image Generation with Token-Critic},
    author  = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
    title   = {SCRIPT: Self-Critic PreTraining of Transformers},
    author  = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
    booktitle = {North American Chapter of the Association for Computational Linguistics},
    year    = {2021}
}
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
