diff --git a/examples/research_projects/instructpix2pix_lora/README.md b/examples/research_projects/instructpix2pix_lora/README.md
index cfcd98926c07..25f7931b47d4 100644
--- a/examples/research_projects/instructpix2pix_lora/README.md
+++ b/examples/research_projects/instructpix2pix_lora/README.md
@@ -2,6 +2,34 @@
 This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
 This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support to add LoRA layers to the UNet model.
 
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+
 ## Training script example
 
 ```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
 export DATASET_ID="instruction-tuning-sd/cartoonization"
 export OUTPUT_DIR="instructPix2Pix-cartoonization"
 
-accelerate launch finetune_instruct_pix2pix.py \
+accelerate launch train_instruct_pix2pix_lora.py \
   --pretrained_model_name_or_path=$MODEL_ID \
   --dataset_name=$DATASET_ID \
   --enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
   --rank=4 \
   --output_dir=$OUTPUT_DIR \
   --report_to=wandb \
-  --push_to_hub
+  --push_to_hub \
+  --original_image_column="original_image" \
+  --edited_image_column="cartoonized_image" \
+  --edit_prompt_column="edit_prompt"
 ```
 
 ## Inference
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 997d448fa281..fcb927c680a0 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -14,7 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+"""
+    Script to fine-tune Stable Diffusion for LoRA InstructPix2Pix.
+ Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py +""" import argparse import logging @@ -30,6 +33,7 @@ import PIL import requests import torch +import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint import transformers @@ -39,21 +43,28 @@ from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, deprecate, is_wandb_available +from diffusers.training_utils import EMAModel, cast_training_params +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.26.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") @@ -63,6 +74,92 @@ WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"] +def save_model_card( + repo_id: str, + images: list = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, +): + img_str = "" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + model_description = f""" +# LoRA text2image fine-tuning - {repo_id} +These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n +{img_str} +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "stable-diffusion", + "stable-diffusion-diffusers", + "text-to-image", + "instruct-pix2pix", + "diffusers", + "diffusers-training", + "lora", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def log_validation( + pipeline, + args, + accelerator, + generator, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." 
+ ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + original_image = download_image(args.val_image_url) + edited_images = [] + if torch.backends.mps.is_available(): + autocast_ctx = nullcontext() + else: + autocast_ctx = torch.autocast(accelerator.device.type) + + with autocast_ctx: + for _ in range(args.num_validation_images): + edited_images.append( + pipeline( + args.validation_prompt, + image=original_image, + num_inference_steps=20, + image_guidance_scale=1.5, + guidance_scale=7, + generator=generator, + ).images[0] + ) + + for tracker in accelerator.trackers: + if tracker.name == "wandb": + wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) + for edited_image in edited_images: + wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt) + tracker.log({"validation": wandb_table}) + + return edited_images + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.") parser.add_argument( @@ -417,11 +514,6 @@ def main(): generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - if args.report_to == "wandb": - if not is_wandb_available(): - raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb - # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -467,49 +559,58 @@ def main(): args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) + # InstructPix2Pix uses an additional image for conditioning. To accommodate that, + # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is + # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized + # from the pre-trained checkpoints. For the extra channels added to the first layer, they are + # initialized to zero. + logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.") + in_channels = 8 + out_channels = unet.conv_in.out_channels + unet.register_to_config(in_channels=in_channels) + + with torch.no_grad(): + new_conv_in = nn.Conv2d( + in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding + ) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight) + unet.conv_in = new_conv_in + # Freeze vae, text_encoder and unet vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) # referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. 
- attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) + # Freeze the unet parameters before adding adapters + unet.requires_grad_(False) - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Add adapter and make sure the trainable params are in float32. + unet.add_adapter(unet_lora_config) + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) # Create EMA for the unet. if args.use_ema: @@ -528,6 +629,13 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + trainable_params = filter(lambda p: p.requires_grad, unet.parameters()) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -540,7 +648,8 @@ def save_model_hook(models, weights, output_dir): model.save_pretrained(os.path.join(output_dir, "unet")) # make sure to pop weight so that corresponding model is not saved again - weights.pop() + if weights: + weights.pop() def load_model_hook(models, input_dir): if args.use_ema: @@ -589,9 +698,9 @@ def load_model_hook(models, input_dir): else: optimizer_cls = torch.optim.AdamW - # train on only unet_lora_parameters + # train on only lora_layers optimizer = optimizer_cls( - unet_lora_parameters, + trainable_params, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, @@ -730,22 +839,27 @@ def collate_fn(examples): ) # Scheduler and math around the number of training steps. 
- overrode_max_train_steps = False - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes if args.max_train_steps is None: - args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch - overrode_max_train_steps = True + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, - num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, ) # Prepare everything with our `accelerator`. - unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler ) if args.use_ema: @@ -765,8 +879,14 @@ def collate_fn(examples): # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - if overrode_max_train_steps: + if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -885,7 +1005,7 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # Gather the losses across all processes for logging (if we use distributed training). 
@@ -895,7 +1015,7 @@ def collate_fn(examples): # Backpropagate accelerator.backward(loss) if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm) + accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() @@ -903,7 +1023,7 @@ def collate_fn(examples): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: if args.use_ema: - ema_unet.step(unet_lora_parameters) + ema_unet.step(trainable_params) progress_bar.update(1) global_step += 1 accelerator.log({"train_loss": train_loss}, step=global_step) @@ -933,6 +1053,16 @@ def collate_fn(examples): save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) + unwrapped_unet = unwrap_model(unet) + unet_lora_state_dict = convert_state_dict_to_diffusers( + get_peft_model_state_dict(unwrapped_unet) + ) + + StableDiffusionInstructPix2PixPipeline.save_lora_weights( + save_directory=save_path, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) logger.info(f"Saved state to {save_path}") logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -959,45 +1089,22 @@ def collate_fn(examples): # The models need unwrapping because for compatibility in distributed training mode. pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), + unet=unwrap_model(unet), + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) # run inference - original_image = download_image(args.val_image_url) - edited_images = [] - if torch.backends.mps.is_available(): - autocast_ctx = nullcontext() - else: - autocast_ctx = torch.autocast(accelerator.device.type) - - with autocast_ctx: - for _ in range(args.num_validation_images): - edited_images.append( - pipeline( - args.validation_prompt, - image=original_image, - num_inference_steps=20, - image_guidance_scale=1.5, - guidance_scale=7, - generator=generator, - ).images[0] - ) + log_validation( + pipeline, + args, + accelerator, + generator, + ) - for tracker in accelerator.trackers: - if tracker.name == "wandb": - wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) - for edited_image in edited_images: - wandb_table.add_data( - wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt - ) - tracker.log({"validation": wandb_table}) if args.use_ema: # Switch back to the original UNet parameters. ema_unet.restore(unet.parameters()) @@ -1008,22 +1115,47 @@ def collate_fn(examples): # Create the pipeline using the trained modules and save it. 
accelerator.wait_for_everyone() if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) if args.use_ema: ema_unet.copy_to(unet.parameters()) + # store only LORA layers + unet = unet.to(torch.float32) + + unwrapped_unet = unwrap_model(unet) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet)) + StableDiffusionInstructPix2PixPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_state_dict, + safe_serialization=True, + ) + pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - vae=accelerator.unwrap_model(vae), - unet=unet, + text_encoder=unwrap_model(text_encoder), + vae=unwrap_model(vae), + unet=unwrap_model(unet), revision=args.revision, variant=args.variant, ) - # store only LORA layers - unet.save_attn_procs(args.output_dir) + pipeline.load_lora_weights(args.output_dir) + + images = None + if (args.val_image_url is not None) and (args.validation_prompt is not None): + images = log_validation( + pipeline, + args, + accelerator, + generator, + ) if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, @@ -1031,31 +1163,6 @@ def collate_fn(examples): ignore_patterns=["step_*", "epoch_*"], ) - if args.validation_prompt is not None: - edited_images = [] - pipeline = pipeline.to(accelerator.device) - with torch.autocast(str(accelerator.device).replace(":0", "")): - for _ in range(args.num_validation_images): - edited_images.append( - pipeline( - args.validation_prompt, - image=original_image, - num_inference_steps=20, - image_guidance_scale=1.5, - guidance_scale=7, - generator=generator, - ).images[0] - ) - - for tracker in accelerator.trackers: - if tracker.name == "wandb": - wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) - for edited_image in edited_images: - wandb_table.add_data( - wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt - ) - tracker.log({"test": wandb_table}) - accelerator.end_training()
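
The README's `## Inference` section (unchanged above, so not shown in this diff) consumes the LoRA weights that `StableDiffusionInstructPix2PixPipeline.save_lora_weights` writes to `$OUTPUT_DIR`. Below is a minimal sketch of that inference flow, not part of the patch itself: the checkpoint directory matches the README's `$OUTPUT_DIR`, the generation settings mirror the script's validation loop (20 steps, `image_guidance_scale=1.5`, `guidance_scale=7`), and the input image URL and edit prompt are placeholders.

```python
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Base InstructPix2Pix checkpoint used for training in the README example.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

# LoRA layers written by `save_lora_weights` (local $OUTPUT_DIR or a Hub repo id).
pipeline.load_lora_weights("instructPix2Pix-cartoonization")

# Placeholder input image; replace with the image you want to edit.
image = load_image("https://example.com/original_image.png").convert("RGB")

edited_image = pipeline(
    "Cartoonize the following image",  # placeholder edit prompt
    image=image,
    num_inference_steps=20,            # same settings as the script's validation loop
    image_guidance_scale=1.5,
    guidance_scale=7,
    generator=torch.manual_seed(42),
).images[0]
edited_image.save("edited_image.png")
```

`load_lora_weights` accepts either a local directory or a Hub repo id, which is also why the training script above can reload the adapters it just saved into a fresh pipeline before its final validation pass.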