From 223a63cd827453490d5f082d16bce5ccc366fedb Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Thu, 21 Dec 2023 12:23:08 +0100 Subject: [PATCH 01/14] Add Textual Inversion fine-tuning script (#243) --- Makefile | 1 + docs/source/tutorials/stable_diffusion.mdx | 54 +- examples/stable-diffusion/README.md | 95 ++ .../stable-diffusion/textual_inversion.py | 1011 +++++++++++++++++ optimum/habana/accelerate/accelerator.py | 29 +- optimum/habana/accelerate/state.py | 8 +- .../diffusers/schedulers/scheduling_ddim.py | 47 +- tests/test_diffusers.py | 78 ++ 8 files changed, 1261 insertions(+), 62 deletions(-) create mode 100644 examples/stable-diffusion/textual_inversion.py diff --git a/Makefile b/Makefile index 5d04b03a84..60cb27abde 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,7 @@ slow_tests_deepspeed: test_installs slow_tests_diffusers: test_installs python -m pip install git+https://github.com/huggingface/diffusers.git python -m pytest tests/test_diffusers.py -v -s -k "test_no_" + python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion" # Run text-generation non-regression tests slow_tests_text_generation_example: test_installs diff --git a/docs/source/tutorials/stable_diffusion.mdx b/docs/source/tutorials/stable_diffusion.mdx index 724db5b7f5..c662005a5f 100644 --- a/docs/source/tutorials/stable_diffusion.mdx +++ b/docs/source/tutorials/stable_diffusion.mdx @@ -105,33 +105,13 @@ There are two different checkpoints for Stable Diffusion 2: -## Tips - -To accelerate your Stable Diffusion pipeline, you can run it in full *bfloat16* precision. -This will also save memory. -You just need to pass `torch_dtype=torch.bfloat16` to `from_pretrained` when instantiating your pipeline. -Here is how to do it: - -```py -import torch - -pipeline = GaudiStableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - scheduler=scheduler, - use_habana=True, - use_hpu_graphs=True, - gaudi_config="Habana/stable-diffusion", - torch_dtype=torch.bfloat16 -) -``` - -# Super-resolution +## Super-resolution The Stable Diffusion upscaler diffusion model was created by the researchers and engineers from CompVis, Stability AI, and LAION. It is used to enhance the resolution of input images by a factor of 4. See [here](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/upscale) for more information. -## How to upscale low resolution images? +### How to upscale low resolution images? To generate RGB and depth images with Stable Diffusion Upscale on Gaudi, you need to instantiate two instances: - A pipeline with [`GaudiStableDiffusionUpscalePipeline`](../package_reference/stable_diffusion_pipeline#optimum.habana.diffusers.GaudiStableDiffusionUpscalePipeline). @@ -172,4 +152,32 @@ pipeline = GaudiStableDiffusionUpscalePipeline.from_pretrained( upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] upscaled_image.save("upsampled_cat.png") -``` \ No newline at end of file +``` + + +## Tips + +To accelerate your Stable Diffusion pipeline, you can run it in full *bfloat16* precision. +This will also save memory. +You just need to pass `torch_dtype=torch.bfloat16` to `from_pretrained` when instantiating your pipeline. +Here is how to do it: + +```py +import torch + +pipeline = GaudiStableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + scheduler=scheduler, + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", + torch_dtype=torch.bfloat16 +) +``` + + +## Textual Inversion Fine-Tuning + +[Textual Inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples. + +You can find [here](https://github.com/huggingface/optimum-habana/blob/main/examples/stable-diffusion/textual_inversion.py) an example script that implements this training method. diff --git a/examples/stable-diffusion/README.md b/examples/stable-diffusion/README.md index 51b3897bca..21b407f3e8 100644 --- a/examples/stable-diffusion/README.md +++ b/examples/stable-diffusion/README.md @@ -115,3 +115,98 @@ python text_to_image_generation.py \ > - use [the latest checkpoint](https://huggingface.co/Intel/ldm3d-4c) for generating improved results > - use [the pano checkpoint](https://huggingface.co/Intel/ldm3d-pano) to generate panoramic view + +## Textual Inversion + +[Textual Inversion](https://arxiv.org/abs/2208.01618) is a method to personalize text2image models like Stable Diffusion on your own images using just 3-5 examples. +The `textual_inversion.py` script shows how to implement the training procedure on Habana Gaudi. + + +### Cat toy example + +Let's get our dataset. For this example, we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example . + +Let's first download it locally: + +```py +from huggingface_hub import snapshot_download + +local_dir = "./cat" +snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes") +``` + +This will be our training data. +Now we can launch the training using: + +```bash +python textual_inversion.py \ + --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \ + --train_data_dir ./cat \ + --learnable_property object \ + --placeholder_token "" \ + --initializer_token toy \ + --resolution 512 \ + --train_batch_size 4 \ + --max_train_steps 3000 \ + --learning_rate 5.0e-04 \ + --scale_lr \ + --lr_scheduler constant \ + --lr_warmup_steps 0 \ + --output_dir /tmp/textual_inversion_cat \ + --save_as_full_pipeline \ + --gaudi_config_name Habana/stable-diffusion \ + --throughput_warmup_steps 3 +``` + +> Change `--resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model. + +> As described in [the official paper](https://arxiv.org/abs/2208.01618), only one embedding vector is used for the placeholder token, *e.g.* `""`. However, one can also add multiple embedding vectors for the placeholder token to increase the number of fine-tuneable parameters. This can help the model to learn more complex details. To use multiple embedding vectors, you can define `--num_vectors` to a number larger than one, *e.g.*: `--num_vectors 5`. The saved textual inversion vectors will then be larger in size compared to the default case. + + +### Multi-card Run + +You can run this fine-tuning script in a distributed fashion as follows: +```bash +python ../gaudi_spawn.py --use_mpi --world_size 8 textual_inversion.py \ + --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \ + --train_data_dir ./cat \ + --learnable_property object \ + --placeholder_token '""' \ + --initializer_token toy \ + --resolution 512 \ + --train_batch_size 4 \ + --max_train_steps 375 \ + --learning_rate 5.0e-04 \ + --scale_lr \ + --lr_scheduler constant \ + --lr_warmup_steps 0 \ + --output_dir /tmp/textual_inversion_cat \ + --save_as_full_pipeline \ + --gaudi_config_name Habana/stable-diffusion \ + --throughput_warmup_steps 3 +``` + + +### Inference + +Once you have trained a model as described right above, inference can be done simply using the `GaudiStableDiffusionPipeline`. Make sure to include the `placeholder_token` in your prompt. + +```python +import torch +from optimum.habana.diffusers import GaudiStableDiffusionPipeline + +model_id = "path-to-your-trained-model" +pipe = GaudiStableDiffusionPipeline.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + use_habana=True, + use_hpu_graphs=True, + gaudi_config="Habana/stable-diffusion", +) + +prompt = "A backpack" + +image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("cat-backpack.png") +``` diff --git a/examples/stable-diffusion/textual_inversion.py b/examples/stable-diffusion/textual_inversion.py new file mode 100644 index 0000000000..9f81d78885 --- /dev/null +++ b/examples/stable-diffusion/textual_inversion.py @@ -0,0 +1,1011 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import json +import logging +import math +import os +import random +import shutil +import time +import warnings +from pathlib import Path + +import diffusers +import numpy as np +import PIL +import safetensors +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from huggingface_hub import create_repo, upload_folder + +# TODO: remove and import from diffusers.utils when the new version of diffusers is released +from packaging import version +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline +from optimum.habana.utils import set_seed + + +if is_wandb_available(): + import wandb + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } +# ------------------------------------------------------------------------------ + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.23.0") + +logger = get_logger(__name__) + + +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- textual_inversion +inference: true +--- + """ + model_card = f""" +# Textual inversion text2image fine-tuning - {repo_id} +These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=unet, + vae=vae, + safety_checker=None, + revision=args.revision, + variant=args.variant, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=args.gaudi_config_name, + ) + pipeline.scheduler = GaudiDDIMScheduler.from_config(pipeline.scheduler.config) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + return images + + +def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True): + logger.info("Saving embeddings") + learned_embeds = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] + ) + learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} + + if safe_serialization: + safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"}) + else: + torch.save(learned_embeds_dict, save_path) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save learned_embeds.bin every X updates steps.", + ) + parser.add_argument( + "--save_as_full_pipeline", + action="store_true", + help="Save the complete stable diffusion pipeline.", + ) + parser.add_argument( + "--num_vectors", + type=int, + default=1, + help="How many textual inversion vectors shall be used to learn the concept.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." + ) + parser.add_argument( + "--placeholder_token", + type=str, + default=None, + required=True, + help="A token to use as a placeholder for the concept.", + ) + parser.add_argument( + "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word." + ) + parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution." + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--bf16", + action="store_true", + default=False, + help=("Whether to use bf16 mixed precision."), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=None, + help=( + "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--no_safe_serialization", + action="store_true", + help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--throughput_warmup_steps", + type=int, + default=0, + help=( + "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the" + " first N steps will not be considered in the calculation of the throughput. This is especially useful in" + " lazy mode." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.train_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +imagenet_templates_small = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +imagenet_style_templates_small = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + + +class TextualInversionDataset(Dataset): + def __init__( + self, + data_root, + tokenizer, + learnable_property="object", # [object, style] + size=512, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + center_crop=False, + ): + self.data_root = data_root + self.tokenizer = tokenizer + self.learnable_property = learnable_property + self.size = size + self.placeholder_token = placeholder_token + self.center_crop = center_crop + self.flip_p = flip_p + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + self.num_images = len(self.image_paths) + self._length = self.num_images + + if set == "train": + self._length = self.num_images * repeats + + self.interpolation = { + "linear": PIL_INTERPOLATION["linear"], + "bilinear": PIL_INTERPOLATION["bilinear"], + "bicubic": PIL_INTERPOLATION["bicubic"], + "lanczos": PIL_INTERPOLATION["lanczos"], + }[interpolation] + + self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small + self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + placeholder_string = self.placeholder_token + text = random.choice(self.templates).format(placeholder_string) + + example["input_ids"] = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + ( + h, + w, + ) = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip_transform(image) + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + + example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) + return example + + +def main(): + args = parse_args() + logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision="bf16" if gaudi_config.use_torch_autocast or args.bf16 else "no", + log_with=args.report_to, + project_config=accelerator_project_config, + force_autocast=gaudi_config.use_torch_autocast or args.bf16, + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + import habana_frameworks.torch.core as htcore + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load tokenizer + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") + + # Load scheduler and models + noise_scheduler = GaudiDDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ).to(accelerator.device) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # Add the placeholder token in tokenizer + placeholder_tokens = [args.placeholder_token] + + if args.num_vectors < 1: + raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}") + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, args.num_vectors): + additional_tokens.append(f"{args.placeholder_token}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != args.num_vectors: + raise ValueError( + f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + # Convert the initializer_token, placeholder_token to ids + token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) + + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Initialise the newly added placeholder token with the embeddings of the initializer token + token_embeds = text_encoder.get_input_embeddings().weight.data + with torch.no_grad(): + for token_id in placeholder_token_ids: + token_embeds[token_id] = token_embeds[initializer_token_id].clone() + + # Freeze vae and unet + vae.requires_grad_(False) + unet.requires_grad_(False) + # Freeze all parameters except for the token embeddings in text encoder + text_encoder.text_model.encoder.requires_grad_(False) + text_encoder.text_model.final_layer_norm.requires_grad_(False) + text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) + + if args.gradient_checkpointing: + # Keep unet in train mode if we are using gradient checkpointing to save memory. + # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. + unet.train() + text_encoder.gradient_checkpointing_enable() + unet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_cls = FusedAdamW + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = TextualInversionDataset( + data_root=args.train_data_dir, + tokenizer=tokenizer, + size=args.resolution, + placeholder_token=(" ".join(tokenizer.convert_ids_to_tokens(placeholder_token_ids))), + repeats=args.repeats, + learnable_property=args.learnable_property, + center_crop=args.center_crop, + set="train", + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers + ) + if args.validation_epochs is not None: + warnings.warn( + f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." + " Deprecated validation_epochs in favor of `validation_steps`" + f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", + FutureWarning, + stacklevel=2, + ) + args.validation_steps = args.validation_epochs * len(train_dataset) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + ) + + text_encoder.train() + # Prepare everything with our `accelerator`. + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if gaudi_config.use_torch_autocast or args.bf16: + weight_dtype = torch.bfloat16 + + # Move vae and unet to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("textual_inversion", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + # keep original embeddings as reference + orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() + + t0 = None + + for epoch in range(first_epoch, args.num_train_epochs): + text_encoder.train() + for step, batch in enumerate(train_dataloader): + if t0 is None and global_step == args.throughput_warmup_steps: + t0 = time.perf_counter() + + with accelerator.accumulate(text_encoder): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() + latents = latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + htcore.mark_step() + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Let's make sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False + + with torch.no_grad(): + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + images = [] + progress_bar.update(1) + global_step += 1 + if global_step % args.save_steps == 0: + weight_name = ( + f"learned_embeds-steps-{global_step}.bin" + if args.no_safe_serialization + else f"learned_embeds-steps-{global_step}.safetensors" + ) + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + images = log_validation( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + duration = time.perf_counter() - t0 + throughput = args.max_train_steps * total_batch_size / duration + + # Create the pipeline using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + logger.info(f"Throughput = {throughput} samples/s") + logger.info(f"Train runtime = {duration} seconds") + metrics = { + "train_samples_per_second": throughput, + "train_runtime": duration, + } + with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file: + json.dump(metrics, file) + if args.push_to_hub and not args.save_as_full_pipeline: + logger.warning("Enabling full model saving because --push_to_hub=True was specified.") + save_full_model = True + else: + save_full_model = args.save_as_full_pipeline + if save_full_model: + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=noise_scheduler, + ) + pipeline.save_pretrained(args.output_dir) + # Save the newly trained embeddings + weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors" + save_path = os.path.join(args.output_dir, weight_name) + save_progress( + text_encoder, + placeholder_token_ids, + accelerator, + args, + save_path, + safe_serialization=not args.no_safe_serialization, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index 3da8ab93b3..b47b75a3c4 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -23,6 +23,7 @@ from collections import OrderedDict from contextlib import contextmanager from dataclasses import make_dataclass +from types import MethodType import torch from accelerate import Accelerator @@ -47,6 +48,7 @@ ProjectConfiguration, RNGType, check_os_kernel, + convert_outputs_to_fp32, is_deepspeed_available, parse_choice_from_env, ) @@ -98,6 +100,7 @@ def __init__( kwargs_handlers: list[KwargsHandler] | None = None, dynamo_backend: GaudiDynamoBackend | str | None = None, distribution_strategy: str = None, + force_autocast: bool = False, ): self.trackers = [] if project_config is not None: @@ -248,6 +251,8 @@ def __init__( self._distribution_strategy = distribution_strategy + self.force_autocast = force_autocast + check_os_kernel() @property @@ -325,20 +330,18 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e elif device_placement and not self.verify_device_map(model): model = model.to(self.device) - # The following block is commented because forward+backward+loss is already wrapped with autocast in Trainer - # if self.native_amp: - # model._original_forward = model.forward - # model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward - # if self.mixed_precision == "bf16": - # new_forward = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16)( - # model_forward_func - # ) + # The following block is executed only when force_autocast is True + # because forward+backward+loss is already wrapped with autocast in Trainer + if self.native_amp and self.force_autocast: + model._original_forward = model.forward + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + new_forward = torch.autocast(device_type=self.state.device.type, dtype=torch.bfloat16)(model_forward_func) - # if hasattr(model.forward, "__func__"): - # model.forward = MethodType(new_forward, model) - # model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) - # else: - # model.forward = convert_outputs_to_fp32(new_forward) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) # FP8 is not supported on Gaudi2 yet # elif self.mixed_precision == "fp8": # if not has_transformer_engine_layers(model): diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index 1944bcfdfc..b8626d34f0 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -41,9 +41,13 @@ def __init__(self, cpu: bool = False, **kwargs): self.device = torch.device(env_device) if env_device is not None else None self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE") - if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: - from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + # initialize_distributed_hpu is already called in the __init__ of + # habana_frameworks.torch.distributed.hccl + # It is necessary so that the env variable LOCAL_RANK is set before the + # conditional statement right below + from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu + if int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: world_size, rank, local_rank = initialize_distributed_hpu() self.backend = kwargs.pop("backend", "hccl") diff --git a/optimum/habana/diffusers/schedulers/scheduling_ddim.py b/optimum/habana/diffusers/schedulers/scheduling_ddim.py index 2f3a6bc08a..440c15268a 100644 --- a/optimum/habana/diffusers/schedulers/scheduling_ddim.py +++ b/optimum/habana/diffusers/schedulers/scheduling_ddim.py @@ -184,8 +184,6 @@ def step( Args: model_output (`torch.FloatTensor`): The direct output from learned diffusion model. - timestep (`float`): - The current discrete timestep in the diffusion chain. sample (`torch.FloatTensor`): A current instance of a sample created by the diffusion process. eta (`float`): @@ -293,25 +291,26 @@ def step( return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) - # def add_noise( - # self, - # original_samples: torch.FloatTensor, - # noise: torch.FloatTensor, - # timesteps: torch.IntTensor, - # ) -> torch.FloatTensor: - # # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - # self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - # timesteps = timesteps.to(original_samples.device) - - # sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - # sqrt_alpha_prod = sqrt_alpha_prod.flatten() - # while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - # sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - # sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - # sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - # while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - # sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - # noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - # return noisy_samples + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod has same device and dtype as original_samples + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index 21a72b6f71..aa57fdc9a6 100644 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -14,7 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os +import re +import subprocess import tempfile from io import BytesIO from pathlib import Path @@ -24,6 +27,7 @@ import requests import torch from diffusers import AutoencoderKL, UNet2DConditionModel +from huggingface_hub import snapshot_download from parameterized import parameterized from PIL import Image from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -47,6 +51,9 @@ THROUGHPUT_BASELINE_BF16 = 0.301 THROUGHPUT_BASELINE_AUTOCAST = 0.108 +TEXTUAL_INVERSION_THROUGHPUT = 58.16156989437878 +TEXTUAL_INVERSION_RUNTIME = 206.32180358597543 + class GaudiPipelineUtilsTester(TestCase): """ @@ -730,3 +737,74 @@ def test_no_generation_regression_upscale(self): ) self.assertEqual(upscaled_image.shape, (512, 512, 3)) self.assertLess(np.abs(expected_slice - upscaled_image[-3:, -3:, -1].flatten()).max(), 5e-3) + + @slow + def test_textual_inversion(self): + path_to_script = ( + Path(os.path.dirname(__file__)).parent / "examples" / "stable-diffusion" / "textual_inversion.py" + ) + + with tempfile.TemporaryDirectory() as data_dir: + snapshot_download( + "diffusers/cat_toy_example", local_dir=data_dir, repo_type="dataset", ignore_patterns=".gitattributes" + ) + with tempfile.TemporaryDirectory() as run_dir: + cmd_line = [ + "python3", + f"{path_to_script.parent.parent / 'gaudi_spawn.py'}", + "--use_mpi", + "--world_size", + "8", + f"{path_to_script}", + "--pretrained_model_name_or_path runwayml/stable-diffusion-v1-5", + f"--train_data_dir {data_dir}", + '--learnable_property "object"', + '--placeholder_token ""', + '--initializer_token "toy"', + "--resolution 512", + "--train_batch_size 4", + "--max_train_steps 375", + "--learning_rate 5.0e-04", + "--scale_lr", + '--lr_scheduler "constant"', + "--lr_warmup_steps 0", + f"--output_dir {run_dir}", + "--save_as_full_pipeline", + "--gaudi_config_name Habana/stable-diffusion", + "--throughput_warmup_steps 3", + "--seed 27", + ] + pattern = re.compile(r"([\"\'].+?[\"\'])|\s") + cmd_line = [x for y in cmd_line for x in re.split(pattern, y) if x] + + # Run textual inversion + p = subprocess.Popen(cmd_line) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + + # Assess throughput + with open(Path(run_dir) / "speed_metrics.json") as fp: + results = json.load(fp) + self.assertGreaterEqual(results["train_samples_per_second"], 0.95 * TEXTUAL_INVERSION_THROUGHPUT) + self.assertLessEqual(results["train_runtime"], 1.05 * TEXTUAL_INVERSION_RUNTIME) + + # Assess generated image + pipe = GaudiStableDiffusionPipeline.from_pretrained( + run_dir, + torch_dtype=torch.bfloat16, + use_habana=True, + use_hpu_graphs=True, + gaudi_config=GaudiConfig(use_habana_mixed_precision=False), + ) + prompt = "A backpack" + set_seed(27) + image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5, output_type="np").images[0] + + # TODO: see how to generate images in a reproducible way + # expected_slice = np.array( + # [0.57421875, 0.5703125, 0.58203125, 0.58203125, 0.578125, 0.5859375, 0.578125, 0.57421875, 0.56640625] + # ) + self.assertEqual(image.shape, (512, 512, 3)) + # self.assertLess(np.abs(expected_slice - image[-3:, -3:, -1].flatten()).max(), 5e-3) From cba7f00f21171a962cd9223698b31794ed20fc38 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Thu, 21 Dec 2023 14:34:39 -0800 Subject: [PATCH 02/14] Fix for Falcon error from PR #587 (#608) * Fix for Falcon error from PR #587 * Reformatted --- optimum/habana/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index 702aea456b..59b27ac984 100644 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -340,7 +340,9 @@ def check_habana_frameworks_version(req_version): """ Checks if the installed version of `habana_frameworks` is equal to `req_version`. """ - return get_habana_frameworks_version() == version.parse(req_version) + return (get_habana_frameworks_version().major == version.parse(req_version).major) and ( + get_habana_frameworks_version().minor == version.parse(req_version).minor + ) def get_device_name(): From 5f772acf44aeb586fc53887449963f0b985ab2a8 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 22 Dec 2023 23:58:34 +0100 Subject: [PATCH 03/14] Add inheritance in Diffusers pipelines (#611) --- .../diffusers/pipelines/pipeline_utils.py | 2 +- .../pipeline_stable_diffusion.py | 530 +----------------- .../pipeline_stable_diffusion_ldm3d.py | 436 +------------- .../pipeline_stable_diffusion_upscale.py | 502 +---------------- 4 files changed, 59 insertions(+), 1411 deletions(-) diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index 83a3a04116..a2d9139101 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -85,7 +85,7 @@ def __init__( gaudi_config: Union[str, GaudiConfig] = None, bf16_full_eval: bool = False, ): - super().__init__() + DiffusionPipeline.__init__(self) self.use_habana = use_habana if self.use_habana: diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 2440d5b69e..6c5e8fecbd 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import time from dataclasses import dataclass from math import ceil @@ -22,15 +21,11 @@ import numpy as np import PIL import torch -from diffusers.configuration_utils import FrozenDict -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, scale_lora_layers, unscale_lora_layers -from packaging import version +from diffusers.utils import BaseOutput, deprecate from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from optimum.utils import logging @@ -50,25 +45,9 @@ class GaudiStableDiffusionPipelineOutput(BaseOutput): throughput: float -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): +class GaudiStableDiffusionPipeline(GaudiDiffusionPipeline, StableDiffusionPipeline): """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class GaudiStableDiffusionPipeline( - GaudiDiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): - """ - Extends the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline) class: + Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L73 - Generation is performed by batches - Two `mark_step()` were added to add support for lazy mode - Added support for HPU graphs @@ -103,11 +82,6 @@ class GaudiStableDiffusionPipeline( This will be faster and save memory compared to fp32/mixed precision but can harm generated images. """ - model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - def __init__( self, vae: AutoencoderKL, @@ -123,7 +97,8 @@ def __init__( gaudi_config: Union[str, GaudiConfig] = None, bf16_full_eval: bool = False, ): - super().__init__( + GaudiDiffusionPipeline.__init__( + self, use_habana, use_hpu_graphs, gaudi_config, @@ -134,410 +109,20 @@ def __init__( if bf16_full_eval: unet.conv_in.float() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, + StableDiffusionPipeline.__init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) self.to(self._device) - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * num_prompts - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif num_prompts != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has {len(negative_prompt)} elements, but `prompt`:" - f" {prompt} has {num_prompts}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(num_prompts * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - def decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != num_images: @@ -567,87 +152,6 @@ def prepare_latents(self, num_images, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): - """ - See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 - - Args: - timesteps (`torch.Tensor`): - generate embedding vectors at these timesteps - embedding_dim (`int`, *optional*, defaults to 512): - dimension of the embeddings to generate - dtype: - data type of the generated embeddings - - Returns: - `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` - """ - assert len(w.shape) == 1 - w = w * 1000.0 - - half_dim = embedding_dim // 2 - emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) - emb = w.to(dtype)[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1)) - assert emb.shape == (w.shape[0], embedding_dim) - return emb - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - @property - def clip_skip(self): - return self._clip_skip - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - @classmethod def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, negative_prompt_embeds): # Use torch.split to generate num_batches batches of size batch_size diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py index d45f0eaf91..3bcd156d1f 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import time -import warnings from dataclasses import dataclass from math import ceil from typing import Any, Callable, Dict, List, Optional, Union @@ -23,13 +21,11 @@ import numpy as np import PIL import torch -from diffusers.image_processor import VaeImageProcessorLDM3D -from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines import StableDiffusionLDM3DPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, scale_lora_layers, unscale_lora_layers +from diffusers.utils import BaseOutput from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from optimum.utils import logging @@ -37,6 +33,7 @@ from ....transformers.gaudi_configuration import GaudiConfig from ....utils import speed_metrics from ..pipeline_utils import GaudiDiffusionPipeline +from .pipeline_stable_diffusion import GaudiStableDiffusionPipeline logger = logging.get_logger(__name__) @@ -50,11 +47,9 @@ class GaudiStableDiffusionLDM3DPipelineOutput(BaseOutput): nsfw_content_detected: Optional[List[bool]] -class GaudiStableDiffusionLDM3DPipeline( - GaudiDiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): +class GaudiStableDiffusionLDM3DPipeline(GaudiDiffusionPipeline, StableDiffusionLDM3DPipeline): """ - Extends the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline) class: + Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py#L84 - Generation is performed by batches - Two `mark_step()` were added to add support for lazy mode - Added support for HPU graphs @@ -90,10 +85,6 @@ class GaudiStableDiffusionLDM3DPipeline( This will be faster and save memory compared to fp32/mixed precision but can harm generated images. """ - model_cpu_offload_seq = "text_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor"] - _exclude_from_cpu_offload = ["safety_checker"] - def __init__( self, vae: AutoencoderKL, @@ -109,7 +100,8 @@ def __init__( gaudi_config: Union[str, GaudiConfig] = None, bf16_full_eval: bool = False, ): - super().__init__( + GaudiDiffusionPipeline.__init__( + self, use_habana, use_hpu_graphs, gaudi_config, @@ -120,365 +112,20 @@ def __init__( if bf16_full_eval: unet.conv_in.float() - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, + StableDiffusionLDM3DPipeline.__init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) self.to(self._device) - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * num_prompts - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif num_prompts != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has {len(negative_prompt)} elements, but `prompt`:" - f" {prompt} has {num_prompts}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(num_prompts * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - rgb_feature_extractor_input = feature_extractor_input[0] - safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - def prepare_latents(self, num_images, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (num_images, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != num_images: @@ -508,49 +155,6 @@ def prepare_latents(self, num_images, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - @classmethod - def _split_inputs_into_batches(cls, batch_size, latents, prompt_embeds, negative_prompt_embeds): - # Use torch.split to generate num_batches batches of size batch_size - latents_batches = list(torch.split(latents, batch_size)) - prompt_embeds_batches = list(torch.split(prompt_embeds, batch_size)) - if negative_prompt_embeds is not None: - negative_prompt_embeds_batches = list(torch.split(negative_prompt_embeds, batch_size)) - - # If the last batch has less samples than batch_size, pad it with dummy samples - num_dummy_samples = 0 - if latents_batches[-1].shape[0] < batch_size: - num_dummy_samples = batch_size - latents_batches[-1].shape[0] - # Pad latents_batches - sequence_to_stack = (latents_batches[-1],) + tuple( - torch.zeros_like(latents_batches[-1][0][None, :]) for _ in range(num_dummy_samples) - ) - latents_batches[-1] = torch.vstack(sequence_to_stack) - # Pad prompt_embeds_batches - sequence_to_stack = (prompt_embeds_batches[-1],) + tuple( - torch.zeros_like(prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) - ) - prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) - # Pad negative_prompt_embeds_batches if necessary - if negative_prompt_embeds is not None: - sequence_to_stack = (negative_prompt_embeds_batches[-1],) + tuple( - torch.zeros_like(negative_prompt_embeds_batches[-1][0][None, :]) for _ in range(num_dummy_samples) - ) - negative_prompt_embeds_batches[-1] = torch.vstack(sequence_to_stack) - - # Stack batches in the same tensor - latents_batches = torch.stack(latents_batches) - if negative_prompt_embeds is not None: - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - for i, (negative_prompt_embeds_batch, prompt_embeds_batch) in enumerate( - zip(negative_prompt_embeds_batches, prompt_embeds_batches[:]) - ): - prompt_embeds_batches[i] = torch.cat([negative_prompt_embeds_batch, prompt_embeds_batch]) - prompt_embeds_batches = torch.stack(prompt_embeds_batches) - - return latents_batches, prompt_embeds_batches, num_dummy_samples - @torch.no_grad() def __call__( self, @@ -699,7 +303,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Split into batches (HPU-specific step) - latents_batches, text_embeddings_batches, num_dummy_samples = self._split_inputs_into_batches( + ( + latents_batches, + text_embeddings_batches, + num_dummy_samples, + ) = GaudiStableDiffusionPipeline._split_inputs_into_batches( batch_size, latents, prompt_embeds, diff --git a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index e96caa0c66..e04873a5d3 100644 --- a/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/optimum/habana/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import time -import warnings from dataclasses import dataclass from math import ceil from typing import Any, Callable, Dict, List, Optional, Union @@ -23,18 +21,10 @@ import numpy as np import PIL import torch -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.attention_processor import ( - AttnProcessor2_0, - LoRAAttnProcessor2_0, - LoRAXFormersAttnProcessor, - XFormersAttnProcessor, -) -from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines import StableDiffusionUpscalePipeline from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, scale_lora_layers, unscale_lora_layers +from diffusers.utils import BaseOutput from diffusers.utils.torch_utils import randn_tensor from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer @@ -45,9 +35,6 @@ from ..pipeline_utils import GaudiDiffusionPipeline -# from diffusers.models.lora import adjust_lora_scale_text_encoder - - logger = logging.get_logger(__name__) PipelineImageInput = Union[ @@ -62,39 +49,11 @@ class GaudiStableDiffusionPipelineOutput(BaseOutput): throughput: float -def preprocess(image): - warnings.warn( - "The preprocess method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor.preprocess instead", - FutureWarning, - ) - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - w, h = image[0].size - w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class GaudiStableDiffusionUpscalePipeline( - GaudiDiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): +class GaudiStableDiffusionUpscalePipeline(GaudiDiffusionPipeline, StableDiffusionUpscalePipeline): """ Pipeline for text-guided image super-resolution using Stable Diffusion 2. - Extends the [`StableDiffusionUpscalePipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion#diffusers.StableDiffusionUpscalePipeline) class: + Adapted from: https://github.com/huggingface/diffusers/blob/v0.23.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py#L70 - Generation is performed by batches - Two `mark_step()` were added to add support for lazy mode - Added support for HPU graphs @@ -133,8 +92,6 @@ class GaudiStableDiffusionUpscalePipeline( This will be faster and save memory compared to fp32/mixed precision but can harm generated images. """ - _optional_components = ["safety_checker", "feature_extractor"] - def __init__( self, vae: AutoencoderKL, @@ -145,416 +102,35 @@ def __init__( scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[Any] = None, feature_extractor: Optional[CLIPImageProcessor] = None, + watermarker: Optional[Any] = None, + max_noise_level: int = 350, use_habana: bool = False, use_hpu_graphs: bool = False, gaudi_config: Union[str, GaudiConfig] = None, bf16_full_eval: bool = False, - watermarker: Optional[Any] = None, - max_noise_level: int = 350, ): - super().__init__(use_habana, use_hpu_graphs, gaudi_config, bf16_full_eval) + GaudiDiffusionPipeline.__init__(self, use_habana, use_hpu_graphs, gaudi_config, bf16_full_eval) # Workaround for Synapse 1.11 for full bf16 if bf16_full_eval: unet.conv_in.float() - if hasattr( - vae, "config" - ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate - is_vae_scaling_factor_set_to_0_08333 = ( - hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 - ) - if not is_vae_scaling_factor_set_to_0_08333: - deprecation_message = ( - "The configuration file of the vae does not contain `scaling_factor` or it is set to" - f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" - " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" - " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" - " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" - " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" - ) - deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) - vae.register_to_config(scaling_factor=0.08333) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - low_res_scheduler=low_res_scheduler, - scheduler=scheduler, - safety_checker=safety_checker, - watermarker=watermarker, - feature_extractor=feature_extractor, + StableDiffusionUpscalePipeline.__init__( + self, + vae, + text_encoder, + tokenizer, + unet, + low_res_scheduler, + scheduler, + safety_checker, + feature_extractor, + watermarker, + max_noise_level, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") - self.register_to_config(max_noise_level=max_noise_level) self.to(self._device) - @property - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - num_prompts = 1 - elif prompt is not None and isinstance(prompt, list): - num_prompts = len(prompt) - else: - num_prompts = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. - prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) - - if self.text_encoder is not None: - prompt_embeds_dtype = self.text_encoder.dtype - elif self.unet is not None: - prompt_embeds_dtype = self.unet.dtype - else: - prompt_embeds_dtype = prompt_embeds.dtype - - prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * num_prompts - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif num_prompts != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has {len(negative_prompt)} elements, but `prompt`:" - f" {prompt} has {num_prompts}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(num_prompts * num_images_per_prompt, seq_len, -1) - - if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder, lora_scale) - - return prompt_embeds, negative_prompt_embeds - - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, nsfw_detected, watermark_detected = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype=dtype) - ) - else: - nsfw_detected = None - watermark_detected = None - - if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None: - self.unet_offload_hook.offload() - - return image, nsfw_detected, watermark_detected - - def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - noise_level, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - ): - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - if ( - not isinstance(image, torch.Tensor) - and not isinstance(image, PIL.Image.Image) - and not isinstance(image, np.ndarray) - and not isinstance(image, list) - ): - raise ValueError( - f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" - ) - - # verify batch size of prompt and image are same if image is a list or tensor or numpy array - if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if isinstance(image, list): - image_batch_size = len(image) - else: - image_batch_size = image.shape[0] - if batch_size != image_batch_size: - raise ValueError( - f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." - " Please make sure that passed `prompt` matches the batch size of `image`." - ) - - # check noise level - if noise_level > self.config.max_noise_level: - raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height, width) if latents is None: @@ -578,46 +154,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents - def upcast_vae(self): - dtype = self.vae.dtype - self.vae.to(dtype=torch.float32) - use_torch_2_0_or_xformers = isinstance( - self.vae.decoder.mid_block.attentions[0].processor, - (AttnProcessor2_0, XFormersAttnProcessor, LoRAXFormersAttnProcessor, LoRAAttnProcessor2_0), - ) - # if xformers or torch_2_0 is used attention block does not need - # to be in float32 which can save lots of memory - if use_torch_2_0_or_xformers: - self.vae.post_quant_conv.to(dtype) - self.vae.decoder.conv_in.to(dtype) - self.vae.decoder.mid_block.to(dtype) - - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - @classmethod def _split_inputs_into_batches(cls, batch_size, latents, text_embeddings, uncond_embeddings, image, noise_level): # Use torch.split to generate num_batches batches of size batch_size From 79f6de3bca391104a14b9a0670924a86836bf5ae Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 26 Dec 2023 07:59:12 +0800 Subject: [PATCH 04/14] add DPO and SFT of TRL support in Gaudi and example (#601) * add DPO and SFT of TRL support in Gaudi and example Signed-off-by: Wang, Yi A * upgrade SFTTrainer/DPO trainer and stack_llama_2 example to v0.7.6 Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A --- examples/trl/stack_llama_2/README.md | 72 +++ examples/trl/stack_llama_2/dpo_llama2.py | 231 ++++++++++ .../trl/stack_llama_2/merge_peft_adapter.py | 50 ++ examples/trl/stack_llama_2/requirements.txt | 5 + examples/trl/stack_llama_2/sft_llama2.py | 168 +++++++ optimum/habana/trl/__init__.py | 2 + optimum/habana/trl/trainer/__init__.py | 21 + optimum/habana/trl/trainer/dpo_trainer.py | 426 ++++++++++++++++++ optimum/habana/trl/trainer/sft_trainer.py | 244 ++++++++++ 9 files changed, 1219 insertions(+) create mode 100644 examples/trl/stack_llama_2/README.md create mode 100644 examples/trl/stack_llama_2/dpo_llama2.py create mode 100644 examples/trl/stack_llama_2/merge_peft_adapter.py create mode 100644 examples/trl/stack_llama_2/requirements.txt create mode 100644 examples/trl/stack_llama_2/sft_llama2.py create mode 100644 optimum/habana/trl/__init__.py create mode 100644 optimum/habana/trl/trainer/__init__.py create mode 100644 optimum/habana/trl/trainer/dpo_trainer.py create mode 100644 optimum/habana/trl/trainer/sft_trainer.py diff --git a/examples/trl/stack_llama_2/README.md b/examples/trl/stack_llama_2/README.md new file mode 100644 index 0000000000..12b7e4da80 --- /dev/null +++ b/examples/trl/stack_llama_2/README.md @@ -0,0 +1,72 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi sft_llama2.py \ + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16 \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to=none \ + --use_habana \ + --use_lazy_mode + ``` +2. Run the DPO trainer using the model saved by the previous step: + ``` + python ../../gaudi_spawn.py --world_size 8 --use_mpi dpo_llama2.py \ + --model_name_or_path="sft/final_merged_checkpoint" \ + --output_dir="dpo" \ + --report_to=none + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, +) + +model.generate(...) +``` diff --git a/examples/trl/stack_llama_2/dpo_llama2.py b/examples/trl/stack_llama_2/dpo_llama2.py new file mode 100644 index 0000000000..2b102e1825 --- /dev/null +++ b/examples/trl/stack_llama_2/dpo_llama2.py @@ -0,0 +1,231 @@ +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py, enable it for Gaudi2 +from dataclasses import dataclass, field +from typing import Dict, Optional + +import torch +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiDPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + tokenizer_name_or_path: Optional[str] = field( + default="meta-llama/Llama-2-7b-hf", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=1, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=False, metadata={"help": "whether to use gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + + # instrumentation + sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"}) + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + sanity_check: bool = False, + cache_dir: str = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': List[str], + 'chosen': List[str], + 'rejected': List[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + ) + original_columns = dataset.column_names + + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def return_prompt_and_responses(samples) -> Dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + # 1. initialize training arguments: + training_args = GaudiTrainingArguments( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + evaluation_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + use_habana=True, + use_lazy_mode=True, + use_hpu_graphs_for_training=True, + use_hpu_graphs_for_inference=True, + ) + # 2. load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + model_ref = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + ) + model_ref.config.use_cache = False + tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + + # 3. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check) + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + # 4. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True) + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + # 5. initialize the DPO trainer + dpo_trainer = GaudiDPOTrainer( + model, + model_ref, + gaudi_config=gaudi_config, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + + # 7. save + dpo_trainer.save_model(script_args.output_dir) diff --git a/examples/trl/stack_llama_2/merge_peft_adapter.py b/examples/trl/stack_llama_2/merge_peft_adapter.py new file mode 100644 index 0000000000..8913fc62a4 --- /dev/null +++ b/examples/trl/stack_llama_2/merge_peft_adapter.py @@ -0,0 +1,50 @@ +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py. +# only difference is removal of model.push_to_hub +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +# model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/examples/trl/stack_llama_2/requirements.txt b/examples/trl/stack_llama_2/requirements.txt new file mode 100644 index 0000000000..c980a4b30c --- /dev/null +++ b/examples/trl/stack_llama_2/requirements.txt @@ -0,0 +1,5 @@ +trl == 0.7.6 +peft == 0.6.2 +datasets +wandb +tyro diff --git a/examples/trl/stack_llama_2/sft_llama2.py b/examples/trl/stack_llama_2/sft_llama2.py new file mode 100644 index 0000000000..1ebff0df14 --- /dev/null +++ b/examples/trl/stack_llama_2/sft_llama2.py @@ -0,0 +1,168 @@ +# Fine-Tune Llama2-7b on SE paired dataset +# copy from https://github.com/huggingface/trl/blob/v0.7.6/examples/research_projects/stack_llama_2/scripts/sft_llama2.py, enable it for Gaudi2 +import logging +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +import transformers +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser +from transformers.trainer_utils import is_main_process +from trl.trainer import ConstantLengthDataset + +from optimum.habana import GaudiConfig, GaudiTrainingArguments +from optimum.habana.trl import GaudiSFTTrainer + + +logger = logging.getLogger(__name__) + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"}) + + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, GaudiTrainingArguments)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) + +if training_args.group_by_length and script_args.packing: + raise ValueError("Cannot use both packing and group by length") + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=None) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + use_auth_token=True, +) +base_model.config.use_cache = False + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +log_level = training_args.get_process_log_level() +logger.setLevel(log_level) +transformers.utils.logging.set_verbosity(log_level) +transformers.utils.logging.enable_default_handler() +transformers.utils.logging.enable_explicit_format() + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args) + +gaudi_config = GaudiConfig() +gaudi_config.use_fused_adam = True +gaudi_config.use_fused_clip_norm = True + +trainer = GaudiSFTTrainer( + model=base_model, + gaudi_config=gaudi_config, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + packing=script_args.packing, + max_seq_length=None, + tokenizer=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(training_args.output_dir) + +# Free memory for merging weights +del base_model +with training_args.main_process_first(desc="merge peft model"): + if is_main_process(training_args.local_rank): + model = AutoPeftModelForCausalLM.from_pretrained(training_args.output_dir, torch_dtype=torch.bfloat16) + model = model.merge_and_unload() + + output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") + model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/optimum/habana/trl/__init__.py b/optimum/habana/trl/__init__.py new file mode 100644 index 0000000000..e80fac8b8a --- /dev/null +++ b/optimum/habana/trl/__init__.py @@ -0,0 +1,2 @@ +from .trainer.dpo_trainer import GaudiDPOTrainer +from .trainer.sft_trainer import GaudiSFTTrainer diff --git a/optimum/habana/trl/trainer/__init__.py b/optimum/habana/trl/trainer/__init__.py new file mode 100644 index 0000000000..13bf554fd7 --- /dev/null +++ b/optimum/habana/trl/trainer/__init__.py @@ -0,0 +1,21 @@ +# flake8: noqa + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# There is a circular import in the PPOTrainer if we let isort sort these +# isort: on + +from .sft_trainer import GaudiSFTTrainer +from .dpo_trainer import GaudiDPOTrainer diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000..e5cfea0cd3 --- /dev/null +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -0,0 +1,426 @@ +# DPO Authors: Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, and Chelsea Finn 2023 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import warnings +from collections import defaultdict +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn as nn +from accelerate.utils import is_deepspeed_available +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from trl import DPOTrainer, create_reference_model +from trl.import_utils import is_peft_available, is_wandb_available +from trl.trainer.utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + pad_to_length, + peft_module_casting_to_bf16, +) + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + pass + +if is_deepspeed_available(): + pass + + +class GaudiDPOTrainer(DPOTrainer, GaudiTrainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + beta: float = 0.1, + label_smoothing: float = 0, + loss_type: Literal["sigmoid", "hinge", "ipo", "kto"] = "sigmoid", + args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, + data_collator: Optional[DataCollator] = None, + label_pad_token_id: int = -100, + padding_value: int = None, + truncation_mode: str = "keep_end", + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + max_length: Optional[int] = None, + max_prompt_length: Optional[int] = None, + max_target_length: Optional[int] = None, + peft_config: Optional[Dict] = None, + is_encoder_decoder: Optional[bool] = None, + disable_dropout: bool = True, + generate_during_eval: bool = False, + compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None, + precompute_ref_log_probs: bool = False, + model_init_kwargs: Optional[Dict] = None, + ref_model_init_kwargs: Optional[Dict] = None, + ): + """ + Copied from DPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L127 + The only differences are: + - add new args gaudi_config + - use graph for ref_model + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the DPOTrainer. But your model is already instantiated.") + + if ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the DPOTrainer. But your ref_model is already instantiated." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + warnings.warn( + "You passed a ref model_id to the DPOTrainer. This will automatically create an " + "`AutoModelForCausalLM`" + ) + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16: + peft_module_casting_to_bf16(model) + + # For models that use gradient_checkpoiting, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif getattr(args, "gradient_checkpointing", False): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if generate_during_eval and not is_wandb_available(): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases to be installed." + " Please install `wandb` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if data_collator is None: + if tokenizer is None: + raise ValueError( + "max_length or a tokenizer must be specified when using the default DPODataCollatorWithPadding" + ) + if max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the DPOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + + if max_target_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_target_length` in the DPOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_target_length = 128 + + data_collator = DPODataCollatorWithPadding( + pad_token_id=tokenizer.pad_token_id, + label_pad_token_id=label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + if disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = generate_during_eval + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value if padding_value is not None else tokenizer.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = truncation_mode + self.max_target_length = max_target_length + self.tokenizer = tokenizer + self.precompute_ref_log_probs = precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0: + warnings.warn( + "You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter." + ) + + self.beta = beta + self.label_smoothing = label_smoothing + self.loss_type = loss_type + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row) + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = self._prepare_deepspeed(self.ref_model) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + from habana_frameworks.torch.hpu import wrap_in_hpu_graph # use graph for ref_model + + ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_model = wrap_in_hpu_graph(ref_model) + + @staticmethod + def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + padded_max_length: int = 0, + ) -> Dict[str, torch.LongTensor]: + """ + Copied from DPOTrainer.concatenated_inputs: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L701 + - pad to self.max_length in Gaudi2 + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + if padded_max_length != 0: # pad to max_length in Gaudi + max_length = padded_max_length + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Copied from DPOTrainer.concatenated_forward: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/dpo_trainer.py#L866 + - pad to self.max_length in Gaudi2 + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + padded_max_length=self.max_length, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "labels": concatenated_batch["concatenated_labels"], + "decoder_input_ids": concatenated_batch.pop("concatenated_decoder_input_ids", None), + } + if self.is_encoder_decoder + else {} + ) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + **model_kwargs, + ).logits + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000..49b2525f4c --- /dev/null +++ b/optimum/habana/trl/trainer/sft_trainer.py @@ -0,0 +1,244 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import dataclasses +import inspect +import warnings +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + DataCollator, + DataCollatorForLanguageModeling, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from trl import SFTTrainer +from trl.import_utils import is_peft_available +from trl.trainer.utils import ( + DataCollatorForCompletionOnlyLM, + peft_module_casting_to_bf16, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + +from optimum.habana import GaudiConfig, GaudiTrainer, GaudiTrainingArguments + + +class GaudiSFTTrainer(SFTTrainer, GaudiTrainer): + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + args: GaudiTrainingArguments = None, + gaudi_config: GaudiConfig = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + dataset_text_field: Optional[str] = None, + packing: Optional[bool] = False, + formatting_func: Optional[Callable] = None, + max_seq_length: Optional[int] = None, + infinite: Optional[bool] = None, + num_of_sequences: Optional[int] = 1024, + chars_per_token: Optional[float] = 3.6, + dataset_num_proc: Optional[int] = None, + dataset_batch_size: int = 1000, + neftune_noise_alpha: Optional[float] = None, + model_init_kwargs: Optional[Dict] = None, + dataset_kwargs: Optional[Dict] = None, + ): + """ + Copied from SFTTrainer.__init__: https://github.com/huggingface/trl/blob/v0.7.6/trl/trainer/sft_trainer.py#L120 + The only differences are: + - add new args gaudi_config + - use GaudiTrainer instead of Trainer + - cast peft model to bf16. + """ + if model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the SFTTrainer. But your model is already instantiated.") + + if infinite is not None: + warnings.warn( + "The `infinite` argument is deprecated and will be removed in a future version of TRL. Use `TrainingArguments.max_steps` or `TrainingArguments.num_train_epochs` instead to control training length." + ) + + if isinstance(model, str): + warnings.warn( + "You passed a model_id to the SFTTrainer. This will automatically create an " + "`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you." + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if packing and data_collator is not None and isinstance(data_collator, DataCollatorForCompletionOnlyLM): + raise ValueError( + "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." + ) + + if is_peft_available() and peft_config is not None: + if not isinstance(peft_config, PeftConfig): + raise ValueError( + "If you want to use the PeftModel, you need to pass a PeftConfig object to the SFTTrainer." + f" and you passed a {type(peft_config)}." + ) + + if not isinstance(model, PeftModel): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {} + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + preprare_model_kwargs = { + "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False) + } + + if _support_gc_kwargs: + preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **preprare_model_kwargs) + + if args is not None: + args = dataclasses.replace(args, gradient_checkpointing=False) + elif getattr(args, "gradient_checkpointing", False) and ( + "use_reentrant" not in gradient_checkpointing_kwargs + or gradient_checkpointing_kwargs["use_reentrant"] + ): + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + model = get_peft_model(model, peft_config) + if args.bf16: + peft_module_casting_to_bf16(model) + + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + if max_seq_length is None: + # to overcome some issues with broken tokenizers + max_seq_length = min(tokenizer.model_max_length, 1024) + + warnings.warn( + f"You didn't pass a `max_seq_length` argument to the SFTTrainer, this will default to {max_seq_length}" + ) + + self.dataset_num_proc = dataset_num_proc + self.dataset_batch_size = dataset_batch_size + + self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha") + + if neftune_noise_alpha is not None and self._trainer_supports_neftune: + args.neftune_noise_alpha = neftune_noise_alpha + warnings.warn( + "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `TrainingArguments`." + ) + # self.neftune_noise_alpha is done at Trainer level + elif not self._trainer_supports_neftune: + self.neftune_noise_alpha = neftune_noise_alpha + + if not packing: + if dataset_text_field is None and formatting_func is None: + raise ValueError( + "You passed `packing=False` to the SFTTrainer, but you didn't pass a `dataset_text_field` or `formatting_func` argument." + ) + + if data_collator is None: + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + if dataset_kwargs is None: + dataset_kwargs = {} + if train_dataset is not None: + train_dataset = self._prepare_dataset( + train_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + **dataset_kwargs, + ) + if eval_dataset is not None: + _multiple = isinstance(eval_dataset, dict) + _eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset} + for _eval_dataset_name, _eval_dataset in _eval_datasets.items(): + _eval_datasets[_eval_dataset_name] = self._prepare_dataset( + _eval_dataset, + tokenizer, + packing, + dataset_text_field, + max_seq_length, + formatting_func, + num_of_sequences, + chars_per_token, + **dataset_kwargs, + ) + if not _multiple: + eval_dataset = _eval_datasets["singleton"] + + if tokenizer.padding_side is not None and tokenizer.padding_side != "right": + warnings.warn( + "You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to " + "overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code." + ) + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + if self.args.max_steps > 0 and packing: + warnings.warn( + "You passed `packing=True` to the SFTTrainer, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached." + ) + self.train_dataset.infinite = True + elif self.args.max_steps == -1 and packing: + self.train_dataset.infinite = False From cba23b36d2c99e5514349a81a9d0ce3b1146d32a Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 27 Dec 2023 21:31:20 -0800 Subject: [PATCH 05/14] Falcon graph compilation error fix for when bs>1 (#607) --- .../models/falcon/modeling_falcon.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index 2f11870450..b6b4887407 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -54,8 +54,25 @@ def gaudi_falcon_rotary_embedding_forward(self, query, key, seq_len, position_id """ cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype) + query_expansion_factor = int(query.shape[0] / cos.shape[0]) + if query_expansion_factor > 1: + query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) + query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) + else: + query_cos, query_sin = cos, sin + + key_expansion_factor = int(key.shape[0] / cos.shape[0]) + if key_expansion_factor > 1: + if key_expansion_factor != query_expansion_factor: + key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) + key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0) + else: + key_cos, key_sin = query_cos, query_sin + else: + key_cos, key_sin = cos, sin + if FusedRoPE: - return FusedRoPE.apply(query, cos, sin, 0), FusedRoPE.apply(key, cos, sin, 0) + return FusedRoPE.apply(query, query_cos, query_sin, 0), FusedRoPE.apply(key, key_cos, key_sin, 0) else: return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin) From a5d7dec7bd0a54f5c2aaf9bfcc1d3c882b6187d5 Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 29 Dec 2023 16:07:29 +0100 Subject: [PATCH 06/14] Temporary fix for Diffusers CI (#618) --- Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/Makefile b/Makefile index 60cb27abde..2a45b8bfd4 100644 --- a/Makefile +++ b/Makefile @@ -53,7 +53,6 @@ slow_tests_deepspeed: test_installs python -m pytest tests/test_examples.py -v -s -k "deepspeed" slow_tests_diffusers: test_installs - python -m pip install git+https://github.com/huggingface/diffusers.git python -m pytest tests/test_diffusers.py -v -s -k "test_no_" python -m pytest tests/test_diffusers.py -v -s -k "test_textual_inversion" From 0152688830c6f410fd51d5b22c37652d915004db Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Fri, 29 Dec 2023 23:24:37 +0800 Subject: [PATCH 07/14] Fix crash if gaudi_config is not passed to GaudiTrainer (#613) --- optimum/habana/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 0ff64b560a..13fbcf1799 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -214,7 +214,7 @@ def __init__( ) if self.use_hpu_amp and "LOWER_LIST" not in os.environ: - gaudi_config.declare_autocast_bf16_fp32_ops() + self.gaudi_config.declare_autocast_bf16_fp32_ops() if self.args.use_lazy_mode: try: From 6a1521bc9c09d21aa346c42c20a2513f7ccf6b48 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Fri, 29 Dec 2023 21:02:51 +0530 Subject: [PATCH 08/14] Update generation config to enable flash attention for inference (#609) --- examples/text-generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 225ae7c5ad..5c03de7dc6 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -328,6 +328,7 @@ def setup_generation_config(args, model, tokenizer): if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 generation_config.kv_cache_fp8 = args.kv_cache_fp8 + generation_config.use_flash_attention = args.use_flash_attention return generation_config From b0421c1a227aa54e390d711e1a828d10394a92e4 Mon Sep 17 00:00:00 2001 From: Sun Choi Date: Wed, 3 Jan 2024 01:55:37 -0800 Subject: [PATCH 09/14] Avoid falcon perf drop from PR#607 when BS=1 (#620) --- optimum/habana/transformers/models/falcon/modeling_falcon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py index b6b4887407..ebb141fa5d 100644 --- a/optimum/habana/transformers/models/falcon/modeling_falcon.py +++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py @@ -55,14 +55,14 @@ def gaudi_falcon_rotary_embedding_forward(self, query, key, seq_len, position_id cos, sin = self.cos_sin(seq_len, past_key_values_length, position_ids, query.device, query.dtype) query_expansion_factor = int(query.shape[0] / cos.shape[0]) - if query_expansion_factor > 1: + if query_expansion_factor > 1 and cos.shape[0] > 1: query_cos = torch.repeat_interleave(cos, query_expansion_factor, dim=0) query_sin = torch.repeat_interleave(sin, query_expansion_factor, dim=0) else: query_cos, query_sin = cos, sin key_expansion_factor = int(key.shape[0] / cos.shape[0]) - if key_expansion_factor > 1: + if key_expansion_factor > 1 and cos.shape[0] > 1: if key_expansion_factor != query_expansion_factor: key_cos = torch.repeat_interleave(cos, key_expansion_factor, dim=0) key_sin = torch.repeat_interleave(sin, key_expansion_factor, dim=0) From dd02a7b375c906adf2d2aa9ddc616e1a2bfe790c Mon Sep 17 00:00:00 2001 From: Bhargav Date: Wed, 3 Jan 2024 23:41:00 +0530 Subject: [PATCH 10/14] Adding support for bf16_full_eval (#610) --- examples/summarization/README.md | 3 ++- optimum/habana/transformers/trainer.py | 21 +++++++++++++++++++ optimum/habana/transformers/training_args.py | 3 --- ...test_encoder_decoder_text_summarization.py | 6 ++++-- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/examples/summarization/README.md b/examples/summarization/README.md index 8ebed989ed..3bdf382c72 100644 --- a/examples/summarization/README.md +++ b/examples/summarization/README.md @@ -227,7 +227,8 @@ python run_summarization.py \ --gaudi_config_name Habana/t5 \ --ignore_pad_token_for_loss False \ --pad_to_max_length \ - --bf16 + --bf16 \ + --bf16_full_eval ``` You can run inference with BART on the CNN-DailyMail dataset on 1 Gaudi card with the following command: diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 13fbcf1799..c393733f22 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -428,6 +428,11 @@ def train( self.is_in_train = True + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + if "model_path" in kwargs: resume_from_checkpoint = kwargs.pop("model_path") warnings.warn( @@ -1510,6 +1515,14 @@ def evaluation_loop( ) self.already_wrapped_for_hpu_graphs = True + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + batch_size = self.args.eval_batch_size logger.info(f"***** Running {description} *****") @@ -1903,6 +1916,14 @@ def prediction_loop( ) self.already_wrapped_for_hpu_graphs = True + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) logger.info(f"***** Running {description} *****") diff --git a/optimum/habana/transformers/training_args.py b/optimum/habana/transformers/training_args.py index 740e682f75..2cc6917c0d 100644 --- a/optimum/habana/transformers/training_args.py +++ b/optimum/habana/transformers/training_args.py @@ -54,7 +54,6 @@ # List of arguments that are not supported by optimum-habana UNSUPPORTED_ARGUMENTS = [ - "bf16_full_eval", "fp16", "fp16_backend", "fp16_full_eval", @@ -314,8 +313,6 @@ def __post_init__(self): raise ValueError("must be using hpu graphs to set max_hpu_graphs.") # Raise errors for arguments that are not supported by optimum-habana - if self.bf16_full_eval: - raise ValueError("--bf16_full_eval is not supported by optimum-habana.") if self.fp16 or self.fp16_full_eval: raise ValueError( "--fp16, --fp16_backend, --fp16_full_eval and --fp16_opt_level are not" diff --git a/tests/test_encoder_decoder_text_summarization.py b/tests/test_encoder_decoder_text_summarization.py index 506ae0e04e..197cc37e7a 100644 --- a/tests/test_encoder_decoder_text_summarization.py +++ b/tests/test_encoder_decoder_text_summarization.py @@ -15,7 +15,7 @@ MODELS_TO_TEST = { "bf16": [ ("facebook/bart-large-cnn", "Habana/bart", 4.691, 26.0688, 2, 1), - ("t5-3b", "Habana/t5", 2.28, 21.56, 2, 1), + ("t5-3b", "Habana/t5", 2.88, 21.56, 2, 1), ], } else: @@ -23,7 +23,7 @@ MODELS_TO_TEST = { "bf16": [ ("facebook/bart-large-cnn", "Habana/bart", 2.588, 26.0688, 2, 1), - ("t5-3b", "Habana/t5", 0.585, 21.72, 2, 1), + ("t5-3b", "Habana/t5", 0.98, 21.56, 2, 1), ], } @@ -76,6 +76,8 @@ def _test_text_summarization( if not deepspeed: command.append("--bf16") + if model_name == "t5-3b": + command.append("--bf16_full_eval") with TemporaryDirectory() as tmp_dir: command.append(f"--output_dir {tmp_dir}") From b72d8eaa91b3526c31be3bffc979fbb7eefd2a65 Mon Sep 17 00:00:00 2001 From: Puneesh Khanna Date: Thu, 4 Jan 2024 02:23:24 +0530 Subject: [PATCH 11/14] Enable fused rmsnorm in bf16 for llama (#621) --- optimum/habana/transformers/models/llama/modeling_llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index c24084e5b4..d0e217c38d 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -60,9 +60,8 @@ def gaudi_llama_rmsnorm_forward(self, hidden_states): - override RMSNorm with Habana fused RMSNorm """ if hidden_states.device.type == "hpu" and FusedRMSNorm: - orig_dtype = hidden_states.dtype - hidden_states = FusedRMSNorm.apply(hidden_states.float(), self.weight.float(), self.variance_epsilon) - return hidden_states.to(orig_dtype) + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) From e419599b68f7269586eed0ee9ca00c4aa4d5dc28 Mon Sep 17 00:00:00 2001 From: Siddhant Jagtap <91691786+sjagtap1803@users.noreply.github.com> Date: Thu, 4 Jan 2024 11:55:24 -0600 Subject: [PATCH 12/14] Text-Generation Pipeline Example (#526) --- examples/text-generation/README.md | 4 + examples/text-generation/run_generation.py | 2 + .../text-generation-pipeline/README.md | 127 ++++++++++++++++++ .../text-generation-pipeline/pipeline.py | 47 +++++++ .../text-generation-pipeline/run_pipeline.py | 52 +++++++ 5 files changed, 232 insertions(+) create mode 100644 examples/text-generation/text-generation-pipeline/README.md create mode 100644 examples/text-generation/text-generation-pipeline/pipeline.py create mode 100644 examples/text-generation/text-generation-pipeline/run_pipeline.py diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index 63ee4d6a33..cafe16d897 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -287,3 +287,7 @@ deepspeed --num_gpus 8 run_lm_eval.py \ --tasks winogrande \ -o eval.json ``` + +## Text-Generation Pipeline + +A Transformers-like pipeline is defined and provided [here](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation/text-generation-pipeline). It is optimized for Gaudi and can be called to generate text in your scripts. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 445794048f..ab308e7023 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -231,6 +231,8 @@ def setup_parser(parser): action="store_true", help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) + parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") + parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") args = parser.parse_args() diff --git a/examples/text-generation/text-generation-pipeline/README.md b/examples/text-generation/text-generation-pipeline/README.md new file mode 100644 index 0000000000..2fc93a6ca2 --- /dev/null +++ b/examples/text-generation/text-generation-pipeline/README.md @@ -0,0 +1,127 @@ + + +# Text-Generation Pipeline + +The text-generation pipeline can be used to perform text-generation by providing single or muliple prompts as input. + +## Requirements + +Update `PYTHONPATH` as follows. +```bash +export OPTIMUM_HABANA_PATH=/path/to/optimum-habana +export PYTHONPATH=${PYTHONPATH}:${OPTIMUM_HABANA_PATH}/examples/text-generation +``` + +If you plan to use [DeepSpeed-inference](https://docs.habana.ai/en/latest/PyTorch/DeepSpeed/Inference_Using_DeepSpeed.html), you should install DeepSpeed as follows: +```bash +pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.13.0 +``` + +## Usage + +To run generation with DeepSpeed-inference, you must launch the script as follows: + +```bash +python ../../gaudi_spawn.py --use_deepspeed --world_size number_of_devices run_pipeline.py ARGS +``` + +Without DeepSpeed-inference, you can run the script with: + +```bash +python run_pipeline.py ARGS +``` + +The list of all possible arguments can be obtained running: +```bash +python run_pipeline.py --help +``` + + +### Single and multiple prompts + +If you want to generate a sequence of text from a prompt of your choice, you should use the `--prompt` argument. +For example: +``` +python run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--max_new_tokens 100 \ +--do_sample \ +--prompt "Here is my prompt" +``` + +If you want to provide several prompts as inputs, here is how to do it: +``` +python run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--max_new_tokens 100 \ +--do_sample \ +--prompt "Hello world" "How are you?" +``` + +If you want to perform generation on default prompts, do not pass the `--prompt` argument. +``` +python run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--max_new_tokens 100 \ +--do_sample +``` + +If you want to change the temperature and top_p values, make sure to include the `--do_sample` argument. Here is a sample command. +``` +python run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-7b-hf \ +--use_hpu_graphs \ +--use_kv_cache \ +--max_new_tokens 100 \ +--do_sample \ +--temperature 0.5 \ +--top_p 0.95 \ +--prompt "Hello world" "How are you?" +``` + +### Multi-card runs + +To run a large model such as Llama-2-70b via DeepSpeed, run the following command. +``` +python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--max_new_tokens 100 \ +--bf16 \ +--use_hpu_graphs \ +--use_kv_cache \ +--prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time" +``` + +To change the temperature and top_p values, run the following command. +``` +python ../../gaudi_spawn.py --use_deepspeed --world_size 8 run_pipeline.py \ +--model_name_or_path meta-llama/Llama-2-70b-hf \ +--max_new_tokens 100 \ +--bf16 \ +--use_hpu_graphs \ +--use_kv_cache \ +--do_sample \ +--temperature 0.5 \ +--top_p 0.95 \ +--prompt "Hello world" "How are you?" "Here is my prompt" "Once upon a time" +``` diff --git a/examples/text-generation/text-generation-pipeline/pipeline.py b/examples/text-generation/text-generation-pipeline/pipeline.py new file mode 100644 index 0000000000..0c2905a731 --- /dev/null +++ b/examples/text-generation/text-generation-pipeline/pipeline.py @@ -0,0 +1,47 @@ +import torch +from transformers import TextGenerationPipeline +from utils import initialize_model + + +class GaudiTextGenerationPipeline(TextGenerationPipeline): + def __init__(self, args, logger): + self.model, self.tokenizer, self.generation_config = initialize_model(args, logger) + + self.device = args.device + + if args.do_sample: + self.generation_config.temperature = args.temperature + self.generation_config.top_p = args.top_p + + self.max_padding_length = args.max_input_tokens if args.max_input_tokens > 0 else 100 + self.use_hpu_graphs = args.use_hpu_graphs + self.profiling_steps = args.profiling_steps + self.profiling_warmup_steps = args.profiling_warmup_steps + + import habana_frameworks.torch.hpu as torch_hpu + + logger.info("Graph compilation...") + for _ in range(3): + self("Here is my prompt") + torch_hpu.synchronize() + + def __call__(self, prompt: str): + model_inputs = self.tokenizer.encode_plus( + prompt, return_tensors="pt", max_length=self.max_padding_length, padding="max_length", truncation=True + ) + + for t in model_inputs: + if torch.is_tensor(model_inputs[t]): + model_inputs[t] = model_inputs[t].to(self.device) + + output = self.model.generate( + **model_inputs, + generation_config=self.generation_config, + lazy_mode=True, + hpu_graphs=self.use_hpu_graphs, + profiling_steps=self.profiling_steps, + profiling_warmup_steps=self.profiling_warmup_steps, + ).cpu() + + output_text = self.tokenizer.decode(output[0], skip_special_tokens=True) + return output_text diff --git a/examples/text-generation/text-generation-pipeline/run_pipeline.py b/examples/text-generation/text-generation-pipeline/run_pipeline.py new file mode 100644 index 0000000000..03bbaa6e91 --- /dev/null +++ b/examples/text-generation/text-generation-pipeline/run_pipeline.py @@ -0,0 +1,52 @@ +import argparse +import logging +import time + +from pipeline import GaudiTextGenerationPipeline +from run_generation import setup_parser + + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, +) +logger = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + args = setup_parser(parser) + args.num_return_sequences = 1 + + if args.prompt: + input_sentences = args.prompt + else: + input_sentences = [ + "DeepSpeed is a machine learning framework", + "He is working on", + "He has a", + "He got all", + "Everyone is happy and I can", + "The new movie that got Oscar this year", + "In the far far distance from our galaxy,", + "Peace is the only way", + ] + + logger.info("Initializing text-generation pipeline...") + pipe = GaudiTextGenerationPipeline(args, logger) + + logger.info("Running inference...") + for input_sentence in input_sentences: + print(f"Prompt: {input_sentence}") + t0 = time.perf_counter() + output = pipe(input_sentence) + duration = time.perf_counter() - t0 + throughput = args.max_new_tokens / duration + print(f"Generated Text: {repr(output)}") + print(f"Inference Duration: {duration} seconds") + print(f"Throughput: {throughput} tokens/second") + + +if __name__ == "__main__": + main() From 8fb43a454a60c71bf4b8285a2ea1c18b38571f7a Mon Sep 17 00:00:00 2001 From: regisss <15324346+regisss@users.noreply.github.com> Date: Fri, 5 Jan 2024 10:26:27 +0100 Subject: [PATCH 13/14] Update CI diff file (#624) --- tests/example_diff/run_generation.txt | 31 +++++++++++++++------------ 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/tests/example_diff/run_generation.txt b/tests/example_diff/run_generation.txt index 094d6f4a44..a29c040b55 100644 --- a/tests/example_diff/run_generation.txt +++ b/tests/example_diff/run_generation.txt @@ -498,25 +498,28 @@ < help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", --- > help="Whether to enable Habana Flash Attention, provided that the model supports it.", -335,336d233 +335,336c234,235 < parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") < args = parser.parse_args() -338,339c235 +--- +> parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") +> parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") +338,339c237 < # Initialize the distributed state. < distributed_state = PartialState(cpu=args.use_cpu) --- > args = parser.parse_args() -341c237,238 +341c239,240 < logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}") --- > if not args.use_hpu_graphs: > args.limit_hpu_graphs = False -343,344c240 +343,344c242 < if args.seed is not None: < set_seed(args.seed) --- > return args -346,373d241 +346,373d243 < # Initialize the model and tokenizer < try: < args.model_type = args.model_type.lower() @@ -545,7 +548,7 @@ < if requires_preprocessing: < prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) < preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) -375,378c243,246 +375,378c245,248 < if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: < tokenizer_kwargs = {"add_space_before_punct_symbol": True} < else: @@ -555,7 +558,7 @@ > parser = argparse.ArgumentParser() > args = setup_parser(parser) > model, tokenizer, generation_config = initialize_model(args, logger) -380,386c248 +380,386c250 < encoded_prompt = tokenizer.encode( < preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs < ) @@ -565,7 +568,7 @@ < encoded_prompt = encoded_prompt.to(distributed_state.device) --- > import habana_frameworks.torch.hpu as torch_hpu -388,389c250,393 +388,389c252,395 < if encoded_prompt.size()[-1] == 0: < input_ids = None --- @@ -713,7 +716,7 @@ > print(f"Graph compilation duration = {compilation_duration} seconds") > print(separator) > print() -391c395,412 +391c397,414 < input_ids = encoded_prompt --- > # Downloading and loading a dataset from the hub. @@ -734,7 +737,7 @@ > .shuffle() > .select(range(args.dataset_max_samples if args.dataset_max_samples > 0 else (raw_dataset[split]).num_rows)) > ) -393,399c414,421 +393,399c416,423 < if args.jit: < jit_input_texts = ["enable jit"] < jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) @@ -751,7 +754,7 @@ > logger.info( > f"No column name was given so automatically choosing '{column_name}' for prompts. If you would like to use another column of the dataset, you can set the argument `--column_name`." > ) -401,439c423,443 +401,439c425,445 < sig = inspect.signature(model.__call__) < jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) < traced_model = torch.jit.trace(model, jit_inputs, strict=False) @@ -813,7 +816,7 @@ > preprocess_function, > batched=True, > desc="Running tokenizer on dataset", -440a445,522 +440a447,524 > # After tokenization, we can remove the column of interest > raw_dataset = raw_dataset.remove_columns([column_name]) > raw_dataset.set_format(type="torch") @@ -892,13 +895,13 @@ > ) > print(separator) > t_end = time.time() -442,443c524,525 +442,443c526,527 < generated_sequences.append(total_sequence) < print(total_sequence) --- > throughput = total_new_tokens_generated / duration > # Print Stats -445c527,539 +445c529,541 < return generated_sequences --- > stats = f"Throughput (including tokenization) = {throughput} tokens/second" From bf9ab7ded9f23c87d718705d47bc14f5884ab266 Mon Sep 17 00:00:00 2001 From: Barak Goldberg Date: Thu, 7 Dec 2023 14:43:07 +0200 Subject: [PATCH 14/14] enable HQT Change-Id: I5f952e2f8d2f9db6d6be41d4069d8f5a4e21dfa9 --- ...xabs_hw_weights_pcs_maxabs_pow2_quant.json | 10 +++ .../quantization_config/maxabs_measure.json | 9 ++ .../maxabs_pcq_measure.json | 8 ++ .../quantization_config/maxabs_quant.json | 10 +++ .../quantization_config/shape_measure.json | 8 ++ .../quantization_config/unit_scale_quant.json | 10 +++ .../without_scale_quant.json | 10 +++ examples/text-generation/run_generation.py | 11 ++- examples/text-generation/run_lm_eval.py | 13 ++- examples/text-generation/utils.py | 19 ++-- .../models/llama/modeling_llama.py | 89 +++++++++++-------- 11 files changed, 154 insertions(+), 43 deletions(-) create mode 100644 examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json create mode 100644 examples/text-generation/quantization_config/maxabs_measure.json create mode 100644 examples/text-generation/quantization_config/maxabs_pcq_measure.json create mode 100644 examples/text-generation/quantization_config/maxabs_quant.json create mode 100644 examples/text-generation/quantization_config/shape_measure.json create mode 100644 examples/text-generation/quantization_config/unit_scale_quant.json create mode 100644 examples/text-generation/quantization_config/without_scale_quant.json diff --git a/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json b/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json new file mode 100644 index 0000000000..bce5e3a102 --- /dev/null +++ b/examples/text-generation/quantization_config/act_maxabs_hw_weights_pcs_maxabs_pow2_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs_per_channel", + "scale_method": "ACT_MAXABS_HW_WEIGHTS_PCS_MAXABS_POW2", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./llama_output", + "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_config/maxabs_measure.json b/examples/text-generation/quantization_config/maxabs_measure.json new file mode 100644 index 0000000000..507cc003be --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_measure.json @@ -0,0 +1,9 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": []}, + "dump_stats_path": "./llama_output/7b_measure", + "dump_stats_xlsx_path": "./llama_output/7b_measure/7b_fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/text-generation/quantization_config/maxabs_pcq_measure.json b/examples/text-generation/quantization_config/maxabs_pcq_measure.json new file mode 100644 index 0000000000..f4ebf6cbb3 --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_pcq_measure.json @@ -0,0 +1,8 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "maxabs_per_channel", + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./llama_output", + "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_config/maxabs_quant.json b/examples/text-generation/quantization_config/maxabs_quant.json new file mode 100644 index 0000000000..2f9683ba63 --- /dev/null +++ b/examples/text-generation/quantization_config/maxabs_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "maxabs_hw", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./llama_output/7b_measure", + "dump_stats_xlsx_path": "./llama_output/7b_measure/7b_fp8stats.xlsx" +} \ No newline at end of file diff --git a/examples/text-generation/quantization_config/shape_measure.json b/examples/text-generation/quantization_config/shape_measure.json new file mode 100644 index 0000000000..3204b12d12 --- /dev/null +++ b/examples/text-generation/quantization_config/shape_measure.json @@ -0,0 +1,8 @@ +{ + "method": "HOOKS", + "mode": "MEASURE", + "observer": "shape", + "blacklist": {"types": [], "names": ["lm_head"]}, + "dump_stats_path": "./llama_output", + "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_config/unit_scale_quant.json b/examples/text-generation/quantization_config/unit_scale_quant.json new file mode 100644 index 0000000000..f32daa1d5c --- /dev/null +++ b/examples/text-generation/quantization_config/unit_scale_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "unit_scale", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": []}, + "dump_stats_path": "./llama_output", + "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx" +} diff --git a/examples/text-generation/quantization_config/without_scale_quant.json b/examples/text-generation/quantization_config/without_scale_quant.json new file mode 100644 index 0000000000..fec5d0aeb6 --- /dev/null +++ b/examples/text-generation/quantization_config/without_scale_quant.json @@ -0,0 +1,10 @@ +{ + "method": "HOOKS", + "mode": "QUANTIZE", + "observer": "maxabs", + "scale_method": "without_scale", + "whitelist": {"types": [], "names": []}, + "blacklist": {"types": [], "names": []}, + "dump_stats_path": "./llama_output", + "dump_stats_xlsx_path": "./run_outputs/fp8stats.xlsx" +} diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index ab308e7023..ecfbe15dca 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -22,6 +22,7 @@ import json import logging import math +import os import time from itertools import cycle from pathlib import Path @@ -223,7 +224,7 @@ def setup_parser(parser): parser.add_argument( "--kv_cache_fp8", action="store_true", - help="Store kv-cache in float8 when kv-cache is used", + help="Store kv-cache in float8 when kv-cache is used. Can't use this argument together with QUANT_CONFIG env var", ) parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8") parser.add_argument( @@ -239,6 +240,11 @@ def setup_parser(parser): if not args.use_hpu_graphs: args.limit_hpu_graphs = False + args.quant_config = os.getenv("QUANT_CONFIG", "") + if args.quant_config and args.kv_cache_fp8: + # can't use both quant_config and kv_cache_fp8, since quant_config may trigger kv cache quantization + # with habana quantization toolkit + raise parser.error("Can't use QUANT_CONFIG env var with kv_cache_fp8 argument") return args @@ -539,6 +545,9 @@ def generate_dataset(batch): if prompt_length > 0: print(f"Graph compilation duration = {compilation_duration} seconds") print(separator) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(model) if __name__ == "__main__": diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index f16c126af2..e59393d71b 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -78,6 +78,7 @@ def __init__(self, tokenizer, model, args, options): if self.model.config.model_type == "llama": self.model_inputs.update( { + "reuse_cache" : self.options.reuse_cache, "attn_softmax_bf16": self.options.attn_softmax_bf16, } ) @@ -125,10 +126,17 @@ def find_bucket(self, length): return [b for b in self.buckets if b >= length][0] def _model_call(self, inps): - seq_length = inps.shape[-1] + bs, seq_length = inps.shape padding_length = 0 if self.options.static_shapes: bucket_length = self.find_bucket(seq_length) + if self.options.use_cache and self.options.reuse_cache: + self.model.allocate_kv_cache( + bs, + bucket_length + 1, + bucket_length, + False, + ) padding_length = bucket_length - seq_length inps = F.pad(inps, (0, padding_length), value=self.model.config.pad_token_id) logits = self.model(inps.to(self._device), **self.model_inputs)["logits"].cpu() @@ -164,6 +172,9 @@ def main(): print("{:35} = {} GB".format(k[:-5].replace("_", " ").capitalize(), v)) json.dump(results, open(args.output_file, "w"), indent=2) print(json.dumps(results, indent=2)) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.finish_measurements(model) if __name__ == "__main__": diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 5c03de7dc6..f7e94d82e5 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -96,7 +96,7 @@ def setup_distributed(args): args.global_rank = int(os.getenv("RANK", "0")) -def setup_quantization(model): +def setup_quantization(args, model): import habana_frameworks.torch.core as htcore from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const from habana_frameworks.torch.hpu import hpu @@ -104,8 +104,8 @@ def setup_quantization(model): print("Initializing inference with quantization") _mark_params_as_const(model) _check_params_as_const(model) - - hpu.enable_quantization() + if not args.quant_config: + hpu.enable_quantization() htcore.hpu_initialize(model) return model @@ -114,6 +114,8 @@ def setup_env(args): # Will error if the minimal version of Transformers is not installed. Remove at your own risks. check_min_version("4.34.0") check_optimum_habana_min_version("1.9.0.dev0") + #TODO: SW-167588 - WA for memory issue in hqt prep_model + os.environ.setdefault('EXPERIMENTAL_WEIGHT_SHARING', 'FALSE') if args.global_rank == 0: os.environ.setdefault("GRAPH_VISUALIZATION", "true") @@ -158,6 +160,9 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = peft_model(args, model_dtype, logger, **model_kwargs) else: model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) model = model.eval().to(args.device) if args.use_hpu_graphs: @@ -178,7 +183,7 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): logger.info("DeepSpeed is enabled.") deepspeed.init_distributed(dist_backend="hccl") - config = AutoConfig.from_pretrained(args.model_name_or_path, **model_kwargs) + config = AutoConfig.from_pretrained(args.model_name_or_path, torch_dtype=model_dtype, **model_kwargs) load_to_meta = model_on_meta(config) if load_to_meta: @@ -227,6 +232,10 @@ def setup_distributed_model(args, model_dtype, model_kwargs, logger): model = model.module if model.config.model_type == "llama": patch_scoped_linear_all_reduce(model) + + if args.quant_config: + import habana_quantization_toolkit + habana_quantization_toolkit.prep_model(model) return model @@ -359,7 +368,7 @@ def initialize_model(args, logger): tokenizer, model = setup_tokenizer(args, model) generation_config = setup_generation_config(args, model, tokenizer) if args.fp8: - model = setup_quantization(model) + model = setup_quantization(args, model) init_end = time.perf_counter() logger.info(f"Args: {args}") logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}") diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index d0e217c38d..136da16e84 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -39,7 +39,9 @@ def update(prev, cur, dim, idx, inp_seq_len): orig_cur = cur - cur = cur.to(dtype=prev.dtype) + if prev.dtype == torch.float8_e4m3fn: + from habana_frameworks.torch.hpex.kernels.Fp8Ops import cast_to_fp8_v2 + cur = cast_to_fp8_v2(cur, None, False, False, prev.dtype)[0] if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: # Initialize prev[:, :, :inp_seq_len, :].copy_(cur) @@ -92,6 +94,32 @@ def __init__(self): def forward(self, x, y): return torch.matmul(x, y) +class KVCache(torch.nn.Module): + + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, kv_cache_fp8, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + if kv_cache_fp8: + dtype = torch.float8_e4m3fn + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return update(self.cache, cur, dim, idx, self.inp_seq_len) class GaudiLlamaAttention(LlamaAttention): def __init__(self, config: LlamaConfig): @@ -99,28 +127,19 @@ def __init__(self, config: LlamaConfig): self.matmul_qk = Matmul() self.matmul_av = Matmul() - self.past_key = None - self.past_value = None + self.k_cache = KVCache() + self.v_cache = KVCache() self.inp_seq_len = -1 + self.register_buffer("norm_factor", torch.tensor(1.0 / math.sqrt(self.head_dim)), persistent=False) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len, kv_cache_fp8): - key_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - value_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) - if self.past_key is None or self.past_key.shape != key_shape: - self.inp_seq_len = inp_seq_len - device = self.k_proj.weight.device - dtype = self.k_proj.weight.dtype - if kv_cache_fp8: - dtype = torch.float8_e4m3fn - self.past_key = torch.zeros(key_shape, dtype=dtype, device=device) - self.past_value = torch.zeros(value_shape, dtype=dtype, device=device) - else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" - self.past_key.fill_(0) - self.past_value.fill_(0) + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, kv_cache_fp8, dtype, device, cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings @@ -135,14 +154,14 @@ def reorder(self, tensor, beam_idx, dim_a, dim_b): tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): - if self.past_key is None: + if self.k_cache.cache is None: return (None, None) - head_dim = self.past_key.size(-1) - seq_length = self.past_key.size(-2) - self.reorder(self.past_key, beam_idx, seq_length, head_dim) - self.reorder(self.past_value, beam_idx, seq_length, head_dim) - return (self.past_key.shape, self.past_value.shape) + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) def pre_attn_forward( self, @@ -212,17 +231,15 @@ def pre_attn_forward( if past_key_value is not None or reuse_cache: # reuse k, v, self_attention if reuse_cache: - past_key = self.past_key - past_value = self.past_value + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) else: - past_key = past_key_value[0] - past_value = past_key_value[1] - key_states = update(past_key, key_states, 2, token_idx, self.inp_seq_len) - value_states = update(past_value, value_states, 2, token_idx, self.inp_seq_len) + key_states = update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) if use_cache: if reuse_cache: - past_key_value = (self.past_key.shape, self.past_value.shape) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: past_key_value = (key_states.contiguous(), value_states.contiguous()) else: @@ -289,11 +306,11 @@ def pre_attn_forward( return attn_output, attn_weights, past_key_value def attention_all_reduce(self, attn_output): - if self.o_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.o_proj, "all_reduce"): self.o_proj.all_reduce(attn_output) def post_attn_forward(self, attn_output): - if self.o_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.o_proj, "post_all_reduce"): self.o_proj.post_all_reduce(attn_output) return attn_output @@ -322,13 +339,13 @@ def pre_mlp_forward(self, x): return output def mlp_all_reduce(self, x): - if self.down_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.down_proj, "all_reduce"): self.down_proj.all_reduce(x) def post_mlp_forward(self, x): if self.config.pretraining_tp > 1: return x - if self.down_proj.__class__ is ScopedLinearAllReduce: + if hasattr(self.down_proj, "post_all_reduce"): return self.down_proj.post_all_reduce(x) return x