diff --git a/Makefile b/Makefile index 5d04b03a84..60cb27abde 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,7 @@ slow_tests_deepspeed: test_installs slow_tests_diffusers: test_installs python -m pip install git+https://github.com/huggingface/diffusers.git python -m pytest tests/test_diffusers.py -v -s -k "test_no_" + python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion" # Run text-generation non-regression tests slow_tests_text_generation_example: test_installs diff --git a/docs/source/tutorials/stable_diffusion.mdx b/docs/source/tutorials/stable_diffusion.mdx index 724db5b7f5..c662005a5f 100644 --- a/docs/source/tutorials/stable_diffusion.mdx +++ b/docs/source/tutorials/stable_diffusion.mdx @@ -105,33 +105,13 @@ There are two different checkpoints for Stable Diffusion 2: -## Tips - -To accelerate your Stable Diffusion pipeline, you can run it in full *bfloat16* precision. -This will also save memory. -You just need to pass `torch_dtype=torch.bfloat16` to `from_pretrained` when instantiating your pipeline. -Here is how to do it: - -```py -import torch - -pipeline = GaudiStableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - scheduler=scheduler, - use_habana=True, - use_hpu_graphs=True, - gaudi_config="Habana/stable-diffusion", - torch_dtype=torch.bfloat16 -) -``` - -# Super-resolution +## Super-resolution The Stable Diffusion upscaler diffusion model was created by the researchers and engineers from CompVis, Stability AI, and LAION. It is used to enhance the resolution of input images by a factor of 4. See [here](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/upscale) for more information. -## How to upscale low resolution images? +### How to upscale low resolution images? To generate RGB and depth images with Stable Diffusion Upscale on Gaudi, you need to instantiate two instances: - A pipeline with [`GaudiStableDiffusionUpscalePipeline`](../package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiStableDiffusionUpscalePipeline). @@ -172,4 +152,32 @@ pipeline = GaudiStableDiffusionUpscalePipeline.from_pretrained( upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] upscaled_image.save("upsampled_cat.png") -``` \ No newline at end of file +``` + + +## Tips + +To accelerate your Stable Diffusion pipeline, you can run it in full *bfloat16* precision. +This will also save memory. +You just need to pass `torch_dtype=torch.bfloat16` to `from_pretrained` when instantiating your pipeline. +Here is how to do it: + +```py +import torch + +pipeline = GaudiStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + scheduler=scheduler, + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", + torch_dtype=torch.bfloat16 +) +``` + + +## Textual Inversion Fine-Tuning + +[Textual Inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples. + +You can find [here](https://github.com/huggingface/optimum-habana/blob/main/examples/stable-diffusion/textual_inversion.py) an example script that implements this training method. diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 51b3897bca..21b407f3e8 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -115,3 +115,98 @@ python text_to_image_generation.py \ > - use [the latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) for generating improved results > - use [the pano checkpoint](https://huggingface.co/Intel/ldm3d-pano) to generate panoramic view + +## Textual Inversion + +[Textual Inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples. +The `textual_inversion.py` script shows how to implement the training procedure on Habana Gaudi. + + +### Cat toy example + +Let's get our dataset. For this example, we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example . + +Let's first download it locally: + +```py +from huggingface_hub import snapshot_download + +local_dir = "./cat" +snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes") +``` + +This will be our training data. +Now we can launch the training using: + +```bash +python textual_inversion.py \ + --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \ + --train_data_dir ./cat \ + --learnable_property object \ + --placeholder_token "" \ + --initializer_token toy \ + --resolution 512 \ + --train_batch_size 4 \ + --max_train_steps 3000 \ + --learning_rate 5.0e-04 \ + --scale_lr \ + --lr_scheduler constant \ + --lr_warmup_steps 0 \ + --output_dir /tmp/textual_inversion_cat \ + --save_as_full_pipeline \ + --gaudi_config_name Habana/stable-diffusion \ + --throughput_warmup_steps 3 +``` + +> Change `--resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model. + +> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, *e.g.* `""`. However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable parameters. This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to a number larger than one, *e.g.*: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case. + + +### Multi-card Run + +You can run this fine-tuning script in a distributed fashion as follows: +```bash +python ../gaudi_spawn.py --use_mpi --world_size 8 textual_inversion.py \ + --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \ + --train_data_dir ./cat \ + --learnable_property object \ + --placeholder_token '""' \ + --initializer_token toy \ + --resolution 512 \ + --train_batch_size 4 \ + --max_train_steps 375 \ + --learning_rate 5.0e-04 \ + --scale_lr \ + --lr_scheduler constant \ + --lr_warmup_steps 0 \ + --output_dir /tmp/textual_inversion_cat \ + --save_as_full_pipeline \ + --gaudi_config_name Habana/stable-diffusion \ + --throughput_warmup_steps 3 +``` + + +### Inference + +Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt. + +```python +import torch +from optimum.habana.diffusers import GaudiStableDiffusionPipeline + +model_id = "path-to-your-trained-model" +pipe = GaudiStableDiffusionPipeline.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", +) + +prompt = "A backpack" + +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("cat-backpack.png") +``` diff --git a/examples/stable-diffusion/textual_inversion.py b/examples/stable-diffusion/textual_inversion.py new file mode 100644 index 0000000000..9f81d78885 --- /dev/null +++ b/examples/stable-diffusion/textual_inversion.py @@ -0,0 +1,1011 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import json +import logging +import math +import os +import random +import shutil +import time +import warnings +from pathlib import Path + +import diffusers +import numpy as np +import PIL +import safetensors +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from huggingface_hub import create_repo, upload_folder + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline +from optimum.habana.utils import set_seed + + +if is_wandb_available(): + import wandb + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.23.0") + +logger = get_logger(__name__) + + +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- textual_inversion +inference: true +--- + """ + model_card = f""" +# Textual inversion text2image fine-tuning - {repo_id} +These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=unet, + vae=vae, + safety_checker=None, + revision=args.revision, + variant=args.variant, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=args.gaudi_config_name, + ) + pipeline.scheduler = GaudiDDIMScheduler.from_config(pipeline.scheduler.config) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + return images + + +def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True): + logger.info("Saving embeddings") + learned_embeds = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] + ) + learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + + if safe_serialization: + safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"}) + else: + torch.save(learned_embeds_dict, save_path) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save learned_embeds.bin every X updates steps.", + ) + parser.add_argument( + "--save_as_full_pipeline", + action="store_true", + help="Save the complete stable diffusion pipeline.", + ) + parser.add_argument( + "--num_vectors", + type=int, + default=1, + help="How many textual inversion vectors shall be used to learn the concept.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." + ) + parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--bf16", + action="store_true", + default=False, + help=("Whether to use bf16 mixed precision."), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help=( + "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--no_safe_serialization", + action="store_true", + help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--throughput_warmup_steps", + type=int, + default=0, + help=( + "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the" + " first N steps will not be considered in the calculation of the throughput. This is especially useful in" + " lazy mode." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + self.data_root = data_root + self.tokenizer = tokenizer + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision="bf16" if gaudi_config.use_torch_autocast or args.bf16 else "no", + log_with=args.report_to, + project_config=accelerator_project_config, + force_autocast=gaudi_config.use_torch_autocast or args.bf16, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + import habana_frameworks.torch.core as htcore + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Load scheduler and models + noise_scheduler = GaudiDDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ).to(accelerator.device) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # Add the placeholder token in tokenizer + placeholder_tokens = [args.placeholder_token] + + if args.num_vectors < 1: + raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}") + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, args.num_vectors): + additional_tokens.append(f"{args.placeholder_token}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != args.num_vectors: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for token_id in placeholder_token_ids: + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # Freeze vae and unet + vae.requires_grad_(False) + unet.requires_grad_(False) + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + + if args.gradient_checkpointing: + # Keep unet in train mode if we are using gradient checkpointing to save memory. + # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. + unet.train() + text_encoder.gradient_checkpointing_enable() + unet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_cls = FusedAdamW + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer=tokenizer, + size=args.resolution, + placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))), + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + if args.validation_epochs is not None: + warnings.warn( + f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." + " Deprecated validation_epochs in favor of `validation_steps`" + f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", + FutureWarning, + stacklevel=2, + ) + args.validation_steps = args.validation_epochs * len(train_dataset) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + ) + + text_encoder.train() + # Prepare everything with our `accelerator`. + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if gaudi_config.use_torch_autocast or args.bf16: + weight_dtype = torch.bfloat16 + + # Move vae and unet to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # keep original embeddings as reference + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() + + t0 = None + + for epoch in range(first_epoch, args.num_train_epochs): + text_encoder.train() + for step, batch in enumerate(train_dataloader): + if t0 is None and global_step == args.throughput_warmup_steps: + t0 = time.perf_counter() + + with accelerator.accumulate(text_encoder): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + htcore.mark_step() + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False + + with torch.no_grad(): + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + images = [] + progress_bar.update(1) + global_step += 1 + if global_step % args.save_steps == 0: + weight_name = ( + f"learned_embeds-steps-{global_step}.bin" + if args.no_safe_serialization + else f"learned_embeds-steps-{global_step}.safetensors" + ) + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + images = log_validation( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + duration = time.perf_counter() - t0 + throughput = args.max_train_steps * total_batch_size / duration + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"Throughput = {throughput} samples/s") + logger.info(f"Train runtime = {duration} seconds") + metrics = { + "train_samples_per_second": throughput, + "train_runtime": duration, + } + with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file: + json.dump(metrics, file) + if args.push_to_hub and not args.save_as_full_pipeline: + logger.warning("Enabling full model saving because --push_to_hub=True was specified.") + save_full_model = True + else: + save_full_model = args.save_as_full_pipeline + if save_full_model: + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=noise_scheduler, + ) + pipeline.save_pretrained(args.output_dir) + # Save the newly trained embeddings + weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index 3da8ab93b3..b47b75a3c4 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -23,6 +23,7 @@ from collections import OrderedDict from contextlib import contextmanager from dataclasses import make_dataclass +from types import MethodType import torch from accelerate import Accelerator @@ -47,6 +48,7 @@ ProjectConfiguration, RNGType, check_os_kernel, + convert_outputs_to_fp32, is_deepspeed_available, parse_choice_from_env, ) @@ -98,6 +100,7 @@ def __init__( kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: GaudiDynamoBackend | str | None = None, distribution_strategy: str = None, + force_autocast: bool = False, ): self.trackers = [] if project_config is not None: @@ -248,6 +251,8 @@ def __init__( self._distribution_strategy = distribution_strategy + self.force_autocast = force_autocast + check_os_kernel() @property @@ -325,20 +330,18 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e elif device_placement and not self.verify_device_map(model): model = model.to(self.device) - # The following block is commented because forward+backward+loss is already wrapped with autocast in Trainer - # if self.native_amp: - # model._original_forward = model.forward - # model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward - # if self.mixed_precision == "bf16": - # new_forward = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16)( - # model_forward_func - # ) + # The following block is executed only when force_autocast is True + # because forward+backward+loss is already wrapped with autocast in Trainer + if self.native_amp and self.force_autocast: + model._original_forward = model.forward + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + new_forward = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16)(model_forward_func) - # if hasattr(model.forward, "__func__"): - # model.forward = MethodType(new_forward, model) - # model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) - # else: - # model.forward = convert_outputs_to_fp32(new_forward) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) # FP8 is not supported on Gaudi2 yet # elif self.mixed_precision == "fp8": # if not has_transformer_engine_layers(model): diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index 1944bcfdfc..b8626d34f0 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -41,9 +41,13 @@ def __init__(self, cpu: bool = False, **kwargs): self.device = torch.device(env_device) if env_device is not None else None self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE") - if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + # initialize_distributed_hpu is already called in the __init__ of + # habana_frameworks.torch.distributed.hccl + # It is necessary so that the env variable LOCAL_RANK is set before the + # conditional statement right below + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: world_size, rank, local_rank = initialize_distributed_hpu() self.backend = kwargs.pop("backend", "hccl") diff --git a/optimum/habana/diffusers/schedulers/scheduling_ddim.py b/optimum/habana/diffusers/schedulers/scheduling_ddim.py index 2f3a6bc08a..440c15268a 100644 --- a/optimum/habana/diffusers/schedulers/scheduling_ddim.py +++ b/optimum/habana/diffusers/schedulers/scheduling_ddim.py @@ -184,8 +184,6 @@ def step( Args: model_output (`torch.FloatTensor`): The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. eta (`float`): @@ -293,25 +291,26 @@ def step( return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - # def add_noise( - # self, - # original_samples: torch.FloatTensor, - # noise: torch.FloatTensor, - # timesteps: torch.IntTensor, - # ) -> torch.FloatTensor: - # # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - # timesteps = timesteps.to(original_samples.device) - - # sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - # sqrt_alpha_prod = sqrt_alpha_prod.flatten() - # while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - # sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - # sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - # sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - # while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - # sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - # noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - # return noisy_samples + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod has same device and dtype as original_samples + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 21a72b6f71..aa57fdc9a6 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import re +import subprocess import tempfile from io import BytesIO from pathlib import Path @@ -24,6 +27,7 @@ import requests import torch from diffusers import AutoencoderKL, UNet2DConditionModel +from huggingface_hub import snapshot_download from parameterized import parameterized from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -47,6 +51,9 @@ THROUGHPUT_BASELINE_BF16 = 0.301 THROUGHPUT_BASELINE_AUTOCAST = 0.108 +TEXTUAL_INVERSION_THROUGHPUT = 58.16156989437878 +TEXTUAL_INVERSION_RUNTIME = 206.32180358597543 + class GaudiPipelineUtilsTester(TestCase): """ @@ -730,3 +737,74 @@ def test_no_generation_regression_upscale(self): ) self.assertEqual(upscaled_image.shape, (512, 512, 3)) self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3) + + @slow + def test_textual_inversion(self): + path_to_script = ( + Path(os.path.dirname(__file__)).parent / "examples" / "stable-diffusion" / "textual_inversion.py" + ) + + with tempfile.TemporaryDirectory() as data_dir: + snapshot_download( + "diffusers/cat_toy_example", local_dir=data_dir, repo_type="dataset", ignore_patterns=".gitattributes" + ) + with tempfile.TemporaryDirectory() as run_dir: + cmd_line = [ + "python3", + f"{path_to_script.parent.parent / 'gaudi_spawn.py'}", + "--use_mpi", + "--world_size", + "8", + f"{path_to_script}", + "--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5", + f"--train_data_dir {data_dir}", + '--learnable_property "object"', + '--placeholder_token ""', + '--initializer_token "toy"', + "--resolution 512", + "--train_batch_size 4", + "--max_train_steps 375", + "--learning_rate 5.0e-04", + "--scale_lr", + '--lr_scheduler "constant"', + "--lr_warmup_steps 0", + f"--output_dir {run_dir}", + "--save_as_full_pipeline", + "--gaudi_config_name Habana/stable-diffusion", + "--throughput_warmup_steps 3", + "--seed 27", + ] + pattern = re.compile(r"([\"\'].+?[\"\'])|\s") + cmd_line = [x for y in cmd_line for x in re.split(pattern, y) if x] + + # Run textual inversion + p = subprocess.Popen(cmd_line) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + + # Assess throughput + with open(Path(run_dir) / "speed_metrics.json") as fp: + results = json.load(fp) + self.assertGreaterEqual(results["train_samples_per_second"], 0.95 * TEXTUAL_INVERSION_THROUGHPUT) + self.assertLessEqual(results["train_runtime"], 1.05 * TEXTUAL_INVERSION_RUNTIME) + + # Assess generated image + pipe = GaudiStableDiffusionPipeline.from_pretrained( + run_dir, + torch_dtype=torch.bfloat16, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=GaudiConfig(use_habana_mixed_precision=False), + ) + prompt = "A backpack" + set_seed(27) + image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, output_type="np").images[0] + + # TODO: see how to generate images in a reproducible way + # expected_slice = np.array( + # [0.57421875, 0.5703125, 0.58203125, 0.58203125, 0.578125, 0.5859375, 0.578125, 0.57421875, 0.56640625] + # ) + self.assertEqual(image.shape, (512, 512, 3)) + # self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3)