Add model class for running transformer with precomputed text latents (
coryMosaicML authored Nov 22, 2024
1 parent ba8ca02 commit cb14024
Showing 4 changed files with 647 additions and 5 deletions.
160 changes: 159 additions & 1 deletion diffusion/models/models.py
@@ -21,7 +21,7 @@
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
from diffusion.models.t2i_transformer import ComposerPrecomputedTextLatentsToImageMMDiT, ComposerTextToImageMMDiT
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.models.transformer import DiffusionTransformer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
@@ -1008,6 +1008,164 @@ def text_to_image_transformer(
    return model


def precomputed_text_latents_to_image_transformer(
    vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
    autoencoder_path: Optional[str] = None,
    autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
    include_text_encoders: bool = False,
    text_encoder_dtype: str = 'bfloat16',
    cache_dir: str = '/tmp/hf_files',
    num_layers: int = 28,
    max_image_side: int = 1280,
    conditioning_features: int = 768,
    conditioning_max_sequence_length: int = 512 + 77,
    num_register_tokens: int = 0,
    patch_size: int = 2,
    latent_mean: Union[float, Tuple, str] = 0.0,
    latent_std: Union[float, Tuple, str] = 7.67754318618,
    timestep_mean: float = 0.0,
    timestep_std: float = 1.0,
    timestep_shift: float = 1.0,
    image_key: str = 'image',
    t5_latent_key: str = 'T5_LATENTS',
    t5_mask_key: str = 'T5_ATTENTION_MASK',
    clip_latent_key: str = 'CLIP_LATENTS',
    clip_mask_key: str = 'CLIP_ATTENTION_MASK',
    clip_pooled_key: str = 'CLIP_POOLED',
    pretrained: bool = False,
):
"""Text to image transformer training setup.
Args:
vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'.
autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified,
will use the vae from `model_name`. Default `None`.
include_text_encoders (bool): Whether to include text encoders in the model. Should only do this for running
inference. Default: `False`.
text_encoder_dtype (str): The dtype to use for the text encoder. One of [`float32`, `float16`, `bfloat16`].
Default: `bfloat16`.
cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`.
autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by
this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`.
max_image_side (int): Maximum side length of the image. Default: `1280`.
conditioning_features (int): Number of features in the conditioning transformer. Default: `768`.
conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`.
num_register_tokens (int): Number of additional register tokens to use. Default: `0`.
patch_size (int): Patch size for the transformer. Default: `2`.
latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `0.0`.
latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.13025`.
timestep_mean (float): The mean of the timesteps. Default: `0.0`.
timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
timestep_shift (float): The shift of the timesteps. Default: `1.0`.
image_key (str): The key for the image in the batch. Default: `image`.
t5_latent_key (str): The key to use for the T5 latents in the precomputed latents. Default: `'T5_LATENTS'`.
t5_mask_key (str): The key to use for the T5 attention mask in the precomputed latents. Default: `'T5_ATTENTION_MASK'`.
clip_latent_key (str): The key to use for the CLIP latents in the precomputed latents. Default: `'CLIP_LATENTS'`.
clip_mask_key (str): The key to use for the CLIP attention mask in the precomputed latents. Default: `'CLIP_ATTENTION_MASK'`.
clip_pooled_key (str): The key to use for the CLIP pooled in the precomputed latents. Default: `'CLIP_POOLED'`.
pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False.
"""
    latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

    precision = torch.float16
    # Make the autoencoder
    if autoencoder_path is None:
        if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
            raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
        downsample_factor = 8
        autoencoder_channels = 4
        # Use the pretrained vae
        try:
            vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision)
        except:  # for handling SDXL vae fp16 fixed checkpoint
            vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision)
    else:
        # Use a custom autoencoder
        vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision)
        if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
            raise ValueError(
                'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
        if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
            assert isinstance(latent_statistics, dict)
            latent_mean = tuple(latent_statistics['latent_channel_means'])
        if isinstance(latent_std, str) and latent_std == 'latent_statistics':
            assert isinstance(latent_statistics, dict)
            latent_std = tuple(latent_statistics['latent_channel_stds'])
        downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
        autoencoder_channels = vae.config['latent_channels']
    assert isinstance(vae, torch.nn.Module)
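    # Broadcast scalar latent statistics to one value per autoencoder latent channel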
    if isinstance(latent_mean, float):
        latent_mean = (latent_mean,) * autoencoder_channels
    if isinstance(latent_std, float):
        latent_std = (latent_std,) * autoencoder_channels
    assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)
    # Figure out the maximum input sequence length
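    # e.g. with the defaults (max_image_side=1280, downsample_factor=8, patch_size=2) this is ceil(1280 / 16) = 80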
    input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))
    # Make the transformer model
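    # With the default num_layers=28: num_features = 64 * 28 = 1792 and num_heads = 28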
    transformer = DiffusionTransformer(num_features=64 * num_layers,
                                       num_heads=num_layers,
                                       num_layers=num_layers,
                                       input_features=autoencoder_channels * (patch_size**2),
                                       input_max_sequence_length=input_max_sequence_length,
                                       input_dimension=2,
                                       conditioning_features=64 * num_layers,
                                       conditioning_max_sequence_length=conditioning_max_sequence_length,
                                       conditioning_dimension=1,
                                       expansion_factor=4,
                                       num_register_tokens=num_register_tokens)

    # Optionally load the tokenizers and text encoders
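    # (only needed for prompt-based inference; training consumes the precomputed T5/CLIP latents from the batch)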
    t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None
    if include_text_encoders:
        dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
        dtype = dtype_map[text_encoder_dtype]
        t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True)
        clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                       subfolder='tokenizer',
                                                       cache_dir=cache_dir,
                                                       local_files_only=False)
        t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl',
                                               torch_dtype=dtype,
                                               cache_dir=cache_dir,
                                               local_files_only=False).encoder.eval()
        clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0',
                                                     subfolder='text_encoder',
                                                     torch_dtype=dtype,
                                                     cache_dir=cache_dir,
                                                     local_files_only=False).cuda().eval()

    # Make the composer model
    model = ComposerPrecomputedTextLatentsToImageMMDiT(model=transformer,
                                                       autoencoder=vae,
                                                       t5_tokenizer=t5_tokenizer,
                                                       t5_encoder=t5_encoder,
                                                       clip_tokenizer=clip_tokenizer,
                                                       clip_encoder=clip_encoder,
                                                       latent_mean=latent_mean,
                                                       latent_std=latent_std,
                                                       patch_size=patch_size,
                                                       downsample_factor=downsample_factor,
                                                       latent_channels=autoencoder_channels,
                                                       timestep_mean=timestep_mean,
                                                       timestep_std=timestep_std,
                                                       timestep_shift=timestep_shift,
                                                       image_key=image_key,
                                                       t5_latent_key=t5_latent_key,
                                                       t5_mask_key=t5_mask_key,
                                                       clip_latent_key=clip_latent_key,
                                                       clip_mask_key=clip_mask_key,
                                                       clip_pooled_key=clip_pooled_key)

    if torch.cuda.is_available():
        model = DeviceGPU().module_to_device(model)
    return model


def build_autoencoder(input_channels: int = 3,
                      output_channels: int = 3,
                      hidden_channels: int = 128,

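For reference, a minimal sketch of how the new factory function could be invoked. The import path and keyword arguments are taken from the diff above; the surrounding trainer/dataloader wiring, and the reading of `512 + 77` as the T5 plus CLIP latent lengths, are assumptions rather than anything shown in this commit.

from diffusion.models.models import precomputed_text_latents_to_image_transformer

# Build the Composer model with the pretrained SDXL fp16-fix VAE and the default
# transformer size (num_layers=28 -> 1792 features, 28 heads). Text encoders are
# omitted because training reads precomputed T5/CLIP latents from the batch.
model = precomputed_text_latents_to_image_transformer(
    vae_model_name='madebyollin/sdxl-vae-fp16-fix',
    num_layers=28,
    conditioning_max_sequence_length=512 + 77,  # assumed: 512 T5 tokens + 77 CLIP tokens
    include_text_encoders=False,
    t5_latent_key='T5_LATENTS',
    t5_mask_key='T5_ATTENTION_MASK',
    clip_latent_key='CLIP_LATENTS',
    clip_mask_key='CLIP_ATTENTION_MASK',
    clip_pooled_key='CLIP_POOLED',
)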