diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index e69a418f305..37ca392e4e1 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -37,6 +37,8 @@
title: CPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
+ - local: alignprop_trainer
+ title: AlignProp Trainer
- local: orpo_trainer
title: ORPO Trainer
- local: iterative_sft_trainer
diff --git a/docs/source/alignprop_trainer.mdx b/docs/source/alignprop_trainer.mdx
new file mode 100644
index 00000000000..f1c508f529e
--- /dev/null
+++ b/docs/source/alignprop_trainer.mdx
@@ -0,0 +1,91 @@
+# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
+
+## The why
+
+If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
+AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
+
+

+
+
+## Getting started with `examples/scripts/alignprop.py`
+
+The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
+
+**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
+
+Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
+
+```batch
+python alignprop.py --hf_user_access_token
+```
+
+To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
+
+The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
+
+- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
+- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
+
+## Setting up the image logging hook function
+
+Expect the function to be given a dictionary with keys
+```python
+['image', 'prompt', 'prompt_metadata', 'rewards']
+
+```
+and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
+You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
+
+### Key terms
+
+- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
+- `prompt` : The prompt is the text that is used to generate the image
+- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
+- `image` : The image generated by the Stable Diffusion model
+
+Example code for logging sampled images with `wandb` is given below.
+
+```python
+# for logging these images to wandb
+
+def image_outputs_hook(image_data, global_step, accelerate_logger):
+ # For the sake of this example, we only care about the last batch
+ # hence we extract the last element of the list
+ result = {}
+ images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
+ for i, image in enumerate(images):
+ pil = Image.fromarray(
+ (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
+ )
+ pil = pil.resize((256, 256))
+ result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
+ accelerate_logger.log_images(
+ result,
+ step=global_step,
+ )
+
+```
+
+### Using the finetuned model
+
+Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
+
+```python
+from diffusers import StableDiffusionPipeline
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+pipeline.to("cuda")
+
+pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
+
+prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
+results = pipeline(prompts)
+
+for prompt, image in zip(prompts,results.images):
+ image.save(f"dump/{prompt}.png")
+```
+
+## Credits
+
+This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
+ by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://arxiv.org/abs/2310.03739).
diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py
new file mode 100644
index 00000000000..f482c49da83
--- /dev/null
+++ b/examples/scripts/alignprop.py
@@ -0,0 +1,129 @@
+# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Total Batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps)
+Feel free to reduce batch size or increasing truncated_rand_backprop_min to a higher value to reduce memory usage.
+
+CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/scripts/alignprop.py \
+ --num_epochs=20 \
+ --train_gradient_accumulation_steps=4 \
+ --sample_num_steps=50 \
+ --train_batch_size=8 \
+ --tracker_project_name="stable_diffusion_training" \
+ --log_with="wandb"
+
+"""
+from dataclasses import dataclass, field
+
+import numpy as np
+from transformers import HfArgumentParser
+
+from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline
+from trl.models.auxiliary_modules import aesthetic_scorer
+
+
+@dataclass
+class ScriptArguments:
+ pretrained_model: str = field(
+ default="runwayml/stable-diffusion-v1-5", metadata={"help": "the pretrained model to use"}
+ )
+ pretrained_revision: str = field(default="main", metadata={"help": "the pretrained model revision to use"})
+ hf_hub_model_id: str = field(
+ default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to"}
+ )
+ hf_hub_aesthetic_model_id: str = field(
+ default="trl-lib/ddpo-aesthetic-predictor",
+ metadata={"help": "HuggingFace model ID for aesthetic scorer model weights"},
+ )
+ hf_hub_aesthetic_model_filename: str = field(
+ default="aesthetic-model.pth",
+ metadata={"help": "HuggingFace model filename for aesthetic scorer model weights"},
+ )
+ use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
+
+
+# list of example prompts to feed stable diffusion
+animals = [
+ "cat",
+ "dog",
+ "horse",
+ "monkey",
+ "rabbit",
+ "zebra",
+ "spider",
+ "bird",
+ "sheep",
+ "deer",
+ "cow",
+ "goat",
+ "lion",
+ "frog",
+ "chicken",
+ "duck",
+ "goose",
+ "bee",
+ "pig",
+ "turkey",
+ "fly",
+ "llama",
+ "camel",
+ "bat",
+ "gorilla",
+ "hedgehog",
+ "kangaroo",
+]
+
+
+def prompt_fn():
+ return np.random.choice(animals), {}
+
+
+def image_outputs_logger(image_pair_data, global_step, accelerate_logger):
+ # For the sake of this example, we will only log the last batch of images
+ # and associated data
+ result = {}
+ images, prompts, _ = [image_pair_data["images"], image_pair_data["prompts"], image_pair_data["rewards"]]
+ for i, image in enumerate(images[:4]):
+ prompt = prompts[i]
+ result[f"{prompt}"] = image.unsqueeze(0).float()
+ accelerate_logger.log_images(
+ result,
+ step=global_step,
+ )
+
+
+if __name__ == "__main__":
+ parser = HfArgumentParser((ScriptArguments, AlignPropConfig))
+ args, alignprop_config = parser.parse_args_into_dataclasses()
+ alignprop_config.project_kwargs = {
+ "logging_dir": "./logs",
+ "automatic_checkpoint_naming": True,
+ "total_limit": 5,
+ "project_dir": "./save",
+ }
+
+ pipeline = DefaultDDPOStableDiffusionPipeline(
+ args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=args.use_lora
+ )
+ trainer = AlignPropTrainer(
+ alignprop_config,
+ aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
+ prompt_fn,
+ pipeline,
+ image_samples_hook=image_outputs_logger,
+ )
+
+ trainer.train()
+
+ trainer.push_to_hub(args.hf_hub_model_id)
diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py
new file mode 100644
index 00000000000..7faff69da3c
--- /dev/null
+++ b/tests/test_alignprop_trainer.py
@@ -0,0 +1,109 @@
+# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import unittest
+
+import torch
+
+from trl import is_diffusers_available, is_peft_available
+
+from .testing_utils import require_diffusers
+
+
+if is_diffusers_available() and is_peft_available():
+ from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline
+
+
+def scorer_function(images, prompts, metadata):
+ return torch.randn(1) * 3.0, {}
+
+
+def prompt_function():
+ return ("cabbages", {})
+
+
+@require_diffusers
+class AlignPropTrainerTester(unittest.TestCase):
+ """
+ Test the AlignPropTrainer class.
+ """
+
+ def setUp(self):
+ self.alignprop_config = AlignPropConfig(
+ num_epochs=2,
+ train_gradient_accumulation_steps=1,
+ train_batch_size=2,
+ truncated_backprop_rand=False,
+ mixed_precision=None,
+ save_freq=1000000,
+ )
+ pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
+ pretrained_revision = "main"
+
+ pipeline = DefaultDDPOStableDiffusionPipeline(
+ pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False
+ )
+
+ self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline)
+
+ return super().setUp()
+
+ def tearDown(self) -> None:
+ gc.collect()
+
+ def test_generate_samples(self):
+ output_pairs = self.trainer._generate_samples(2, with_grad=True)
+ assert len(output_pairs.keys()) == 3
+ assert len(output_pairs["images"]) == 2
+
+ def test_calculate_loss(self):
+ sample = self.trainer._generate_samples(2)
+
+ images = sample["images"]
+ prompts = sample["prompts"]
+
+ assert images.shape == (2, 3, 128, 128)
+ assert len(prompts) == 2
+
+ rewards = self.trainer.compute_rewards(sample)
+ loss = self.trainer.calculate_loss(rewards)
+
+ assert torch.isfinite(loss.cpu())
+
+
+@require_diffusers
+class AlignPropTrainerWithLoRATester(AlignPropTrainerTester):
+ """
+ Test the AlignPropTrainer class.
+ """
+
+ def setUp(self):
+ self.alignprop_config = AlignPropConfig(
+ num_epochs=2,
+ train_gradient_accumulation_steps=1,
+ mixed_precision=None,
+ truncated_backprop_rand=False,
+ save_freq=1000000,
+ )
+
+ pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
+ pretrained_revision = "main"
+
+ pipeline = DefaultDDPOStableDiffusionPipeline(
+ pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True
+ )
+
+ self.trainer = AlignPropTrainer(self.alignprop_config, scorer_function, prompt_function, pipeline)
+
+ return super().setUp()
diff --git a/trl/__init__.py b/trl/__init__.py
index 6b33eca27fc..fcd91595a45 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -39,6 +39,8 @@
"DPOTrainer",
"CPOConfig",
"CPOTrainer",
+ "AlignPropConfig",
+ "AlignPropTrainer",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
@@ -105,6 +107,8 @@
DPOTrainer,
CPOConfig,
CPOTrainer,
+ AlignPropConfig,
+ AlignPropTrainer,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,
diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py
new file mode 100644
index 00000000000..ed1f9b75074
--- /dev/null
+++ b/trl/models/auxiliary_modules.py
@@ -0,0 +1,97 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import torch
+import torch.nn as nn
+import torchvision
+from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError
+from transformers import CLIPModel
+
+from trl.import_utils import is_npu_available, is_xpu_available
+
+
+class MLP(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.layers = nn.Sequential(
+ nn.Linear(768, 1024),
+ nn.Dropout(0.2),
+ nn.Linear(1024, 128),
+ nn.Dropout(0.2),
+ nn.Linear(128, 64),
+ nn.Dropout(0.1),
+ nn.Linear(64, 16),
+ nn.Linear(16, 1),
+ )
+
+ def forward(self, embed):
+ return self.layers(embed)
+
+
+class AestheticScorer(torch.nn.Module):
+ """
+ This model attempts to predict the aesthetic score of an image. The aesthetic score
+ is a numerical approximation of how much a specific image is liked by humans on average.
+ This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
+ """
+
+ def __init__(self, *, dtype, model_id, model_filename):
+ super().__init__()
+ self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+ self.normalize = torchvision.transforms.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
+ )
+ self.target_size = 224
+ self.mlp = MLP()
+ try:
+ cached_path = hf_hub_download(model_id, model_filename)
+ except EntryNotFoundError:
+ cached_path = os.path.join(model_id, model_filename)
+ state_dict = torch.load(cached_path, map_location=torch.device("cpu"))
+ self.mlp.load_state_dict(state_dict)
+ self.dtype = dtype
+ self.eval()
+
+ def __call__(self, images):
+ device = next(self.parameters()).device
+ images = torchvision.transforms.Resize(self.target_size)(images)
+ images = self.normalize(images).to(self.dtype).to(device)
+ embed = self.clip.get_image_features(pixel_values=images)
+ # normalize embedding
+ embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
+ reward = self.mlp(embed).squeeze(1)
+ return reward
+
+
+def aesthetic_scorer(hub_model_id, model_filename):
+ scorer = AestheticScorer(
+ model_id=hub_model_id,
+ model_filename=model_filename,
+ dtype=torch.float32,
+ )
+ if is_npu_available():
+ scorer = scorer.npu()
+ elif is_xpu_available():
+ scorer = scorer.xpu()
+ else:
+ scorer = scorer.cuda()
+
+ def _fn(images, prompts, metadata):
+ images = (images).clamp(0, 1)
+ scores = scorer(images)
+ return scores, {}
+
+ return _fn
diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py
index bbca699fff9..44d3e700347 100644
--- a/trl/models/modeling_sd_base.py
+++ b/trl/models/modeling_sd_base.py
@@ -14,12 +14,14 @@
import contextlib
import os
+import random
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
+import torch.utils.checkpoint as checkpoint
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
@@ -233,7 +235,6 @@ def scheduler_step(
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
-
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# to prevent OOB on gather
@@ -528,6 +529,259 @@ def pipeline_step(
return DDPOPipelineOutput(image, all_latents, all_log_probs)
+def pipeline_step_with_grad(
+ pipeline,
+ prompt: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ truncated_backprop: bool = True,
+ truncated_backprop_rand: bool = True,
+ gradient_checkpoint: bool = True,
+ truncated_backprop_timestep: int = 49,
+ truncated_rand_backprop_minmax: tuple = (0, 50),
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+):
+ r"""
+ Function to get RGB image with gradients attached to the model weights. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor): The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to pipeline.unet.config.sample_size * pipeline.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ truncated_backprop (`bool`, *optional*, defaults to True):
+ Truncated Backpropation to fixed timesteps, helps prevent collapse during diffusion reward training as shown in AlignProp (https://arxiv.org/abs/2310.03739).
+ truncated_backprop_rand (`bool`, *optional*, defaults to True):
+ Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps, this helps prevent collapse during diffusion reward training as shown in AlignProp (https://arxiv.org/abs/2310.03739).
+ Enabling truncated_backprop_rand allows adapting earlier timesteps in diffusion while not resulting in a collapse.
+ gradient_checkpoint (`bool`, *optional*, defaults to True):
+ Adds gradient checkpointing to Unet forward pass. Reduces GPU memory consumption while slightly increasing the training time.
+ truncated_backprop_timestep (`int`, *optional*, defaults to 49):
+ Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage and reduces the chances of collapse.
+ While a lower value, allows more semantic changes in the diffusion generations, as the earlier diffusion timesteps are getting updated.
+ However it also increases the chances of collapse.
+ truncated_rand_backprop_minmax (`Tuple`, *optional*, defaults to (0,50)):
+ Range for randomized backprop. Here the value at 0 index indicates the earlier diffusion timestep to update (closer to noise), while the value
+ at index 1 indicates the later diffusion timestep to update.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `pipeline.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+
+ Examples:
+
+ Returns:
+ `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities
+ """
+ # 0. Default height and width to unet
+ height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
+ width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
+
+ with torch.no_grad():
+ # 1. Check inputs. Raise error if not correct
+ pipeline.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = pipeline._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+ )
+ prompt_embeds = pipeline._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = pipeline.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = pipeline.unet.config.in_channels
+ latents = pipeline.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
+ all_latents = [latents]
+ all_log_probs = []
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ if gradient_checkpoint:
+ noise_pred = checkpoint.checkpoint(
+ pipeline.unet,
+ latent_model_input,
+ t,
+ prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ use_reentrant=False,
+ )[0]
+ else:
+ noise_pred = pipeline.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # truncating backpropagation is critical for preventing overoptimization (https://arxiv.org/abs/2304.05977).
+ if truncated_backprop:
+ # Randomized truncation randomizes the truncation process (https://arxiv.org/abs/2310.03739)
+ # the range of truncation is defined by truncated_rand_backprop_minmax
+ # Setting truncated_rand_backprop_minmax[0] to be low will allow the model to update earlier timesteps in the diffusion chain, while setitng it high will reduce the memory usage.
+ if truncated_backprop_rand:
+ rand_timestep = random.randint(
+ truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1]
+ )
+ if i < rand_timestep:
+ noise_pred = noise_pred.detach()
+ else:
+ # fixed truncation process
+ if i < truncated_backprop_timestep:
+ noise_pred = noise_pred.detach()
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ scheduler_output = scheduler_step(pipeline.scheduler, noise_pred, t, latents, eta)
+ latents = scheduler_output.latents
+ log_prob = scheduler_output.log_probs
+
+ all_latents.append(latents)
+ all_log_probs.append(log_prob)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
+ image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = pipeline.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload last model to CPU
+ if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
+ pipeline.final_offload_hook.offload()
+
+ return DDPOPipelineOutput(image, all_latents, all_log_probs)
+
+
class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
@@ -563,6 +817,9 @@ def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
return pipeline_step(self.sd_pipeline, *args, **kwargs)
+ def rgb_with_grad(self, *args, **kwargs) -> DDPOPipelineOutput:
+ return pipeline_step_with_grad(self.sd_pipeline, *args, **kwargs)
+
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index c291ff31ec4..8d43e7761e0 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -33,6 +33,8 @@
"dpo_trainer": ["DPOTrainer"],
"cpo_config": ["CPOConfig"],
"cpo_trainer": ["CPOTrainer"],
+ "alignprop_config": ["AlignPropConfig"],
+ "alignprop_trainer": ["AlignPropTrainer"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"kto_config": ["KTOConfig"],
"kto_trainer": ["KTOTrainer"],
@@ -81,6 +83,7 @@
from .iterative_sft_trainer import IterativeSFTTrainer
from .cpo_config import CPOConfig
from .cpo_trainer import CPOTrainer
+ from .alignprop_config import AlignPropConfig
from .kto_config import KTOConfig
from .kto_trainer import KTOTrainer
from .model_config import ModelConfig
diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py
new file mode 100644
index 00000000000..7bd4cd32bde
--- /dev/null
+++ b/trl/trainer/alignprop_config.py
@@ -0,0 +1,104 @@
+import os
+import sys
+import warnings
+from dataclasses import dataclass, field
+from typing import Literal, Optional
+
+from ..core import flatten_dict
+from ..import_utils import is_bitsandbytes_available, is_torchvision_available
+
+
+@dataclass
+class AlignPropConfig:
+ """
+ Configuration class for AlignPropTrainer
+ """
+
+ # common parameters
+ exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
+ """the name of this experiment (by default is the file name without the extension name)"""
+ run_name: Optional[str] = ""
+ """Run name for wandb logging and checkpoint saving."""
+ seed: int = 0
+ """Seed value for random generations"""
+ log_with: Optional[Literal["wandb", "tensorboard"]] = None
+ """Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"""
+ log_image_freq = 1
+ """Logging Frequency for images"""
+ tracker_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the tracker (e.g. wandb_project)"""
+ accelerator_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the accelerator"""
+ project_kwargs: dict = field(default_factory=dict)
+ """Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
+ tracker_project_name: str = "trl"
+ """Name of project to use for tracking"""
+ logdir: str = "logs"
+ """Top-level logging directory for checkpoint saving."""
+
+ # hyperparameters
+ num_epochs: int = 100
+ """Number of epochs to train."""
+ save_freq: int = 1
+ """Number of epochs between saving model checkpoints."""
+ num_checkpoint_limit: int = 5
+ """Number of checkpoints to keep before overwriting old ones."""
+ mixed_precision: str = "fp16"
+ """Mixed precision training."""
+ allow_tf32: bool = True
+ """Allow tf32 on Ampere GPUs."""
+ resume_from: Optional[str] = ""
+ """Resume training from a checkpoint."""
+ sample_num_steps: int = 50
+ """Number of sampler inference steps."""
+ sample_eta: float = 1.0
+ """Eta parameter for the DDIM sampler."""
+ sample_guidance_scale: float = 5.0
+ """Classifier-free guidance weight."""
+ train_batch_size: int = 1
+ """Batch size (per GPU!) to use for training."""
+ train_use_8bit_adam: bool = False
+ """Whether to use the 8bit Adam optimizer from bitsandbytes."""
+ train_learning_rate: float = 1e-3
+ """Learning rate."""
+ train_adam_beta1: float = 0.9
+ """Adam beta1."""
+ train_adam_beta2: float = 0.999
+ """Adam beta2."""
+ train_adam_weight_decay: float = 1e-4
+ """Adam weight decay."""
+ train_adam_epsilon: float = 1e-8
+ """Adam epsilon."""
+ train_gradient_accumulation_steps: int = 1
+ """Number of gradient accumulation steps."""
+ train_max_grad_norm: float = 1.0
+ """Maximum gradient norm for gradient clipping."""
+ negative_prompts: Optional[str] = ""
+ """Comma-separated list of prompts to use as negative examples."""
+ truncated_backprop_rand: bool = True
+ """Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps"""
+ truncated_backprop_timestep: int = 49
+ """Absolute timestep to which the gradients are being backpropagated. If truncated_backprop_rand is False"""
+ truncated_rand_backprop_minmax: tuple = (0, 50)
+ """Range of diffusion timesteps for randomized truncated backprop."""
+
+ def to_dict(self):
+ output_dict = {}
+ for key, value in self.__dict__.items():
+ output_dict[key] = value
+ return flatten_dict(output_dict)
+
+ def __post_init__(self):
+ if self.log_with not in ["wandb", "tensorboard"]:
+ warnings.warn(
+ "Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'."
+ )
+
+ if self.log_with == "wandb" and not is_torchvision_available():
+ warnings.warn("Wandb image logging requires torchvision to be installed")
+
+ if self.train_use_8bit_adam and not is_bitsandbytes_available():
+ raise ImportError(
+ "You need to install bitsandbytes to use 8bit Adam. "
+ "You can install it with `pip install bitsandbytes`."
+ )
diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py
new file mode 100644
index 00000000000..9024a410d86
--- /dev/null
+++ b/trl/trainer/alignprop_trainer.py
@@ -0,0 +1,422 @@
+# Copyright 2023 AlignProp-pytorch authors (Mihir Prabhudesai), metric-space, The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from collections import defaultdict
+from typing import Any, Callable, Optional, Tuple
+from warnings import warn
+
+import torch
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import whoami
+
+from ..models import DDPOStableDiffusionPipeline
+from . import AlignPropConfig, BaseTrainer
+
+
+logger = get_logger(__name__)
+
+
+MODEL_CARD_TEMPLATE = """---
+license: apache-2.0
+tags:
+- trl
+- alignprop
+- diffusers
+- reinforcement-learning
+- text-to-image
+- stable-diffusion
+---
+
+# {model_name}
+
+This is a pipeline that finetunes a diffusion model with reward backpropagation while using randomized truncation (https://arxiv.org/abs/2310.03739). The model can be used for image generation conditioned with text.
+
+"""
+
+
+class AlignPropTrainer(BaseTrainer):
+ """
+ The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
+ Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
+ As of now only Stable Diffusion based pipelines are supported
+
+ Attributes:
+ **config** (`AlignPropConfig`) -- Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more
+ details.
+ **reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
+ **prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
+ **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
+ **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
+ """
+
+ _tag_names = ["trl", "alignprop"]
+
+ def __init__(
+ self,
+ config: AlignPropConfig,
+ reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
+ prompt_function: Callable[[], Tuple[str, Any]],
+ sd_pipeline: DDPOStableDiffusionPipeline,
+ image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
+ ):
+ if image_samples_hook is None:
+ warn("No image_samples_hook provided; no images will be logged")
+
+ self.prompt_fn = prompt_function
+ self.reward_fn = reward_function
+ self.config = config
+ self.image_samples_callback = image_samples_hook
+
+ accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
+
+ if self.config.resume_from:
+ self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
+ if "checkpoint_" not in os.path.basename(self.config.resume_from):
+ # get the most recent checkpoint in this directory
+ checkpoints = list(
+ filter(
+ lambda x: "checkpoint_" in x,
+ os.listdir(self.config.resume_from),
+ )
+ )
+ if len(checkpoints) == 0:
+ raise ValueError(f"No checkpoints found in {self.config.resume_from}")
+ checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
+ self.config.resume_from = os.path.join(
+ self.config.resume_from,
+ f"checkpoint_{checkpoint_numbers[-1]}",
+ )
+
+ accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
+
+ self.accelerator = Accelerator(
+ log_with=self.config.log_with,
+ mixed_precision=self.config.mixed_precision,
+ project_config=accelerator_project_config,
+ # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
+ # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
+ # the total number of optimizer steps to accumulate across.
+ gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
+ **self.config.accelerator_kwargs,
+ )
+
+ is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
+
+ if self.accelerator.is_main_process:
+ self.accelerator.init_trackers(
+ self.config.tracker_project_name,
+ config=dict(alignprop_trainer_config=config.to_dict())
+ if not is_using_tensorboard
+ else config.to_dict(),
+ init_kwargs=self.config.tracker_kwargs,
+ )
+
+ logger.info(f"\n{config}")
+
+ set_seed(self.config.seed, device_specific=True)
+
+ self.sd_pipeline = sd_pipeline
+
+ self.sd_pipeline.set_progress_bar_config(
+ position=1,
+ disable=not self.accelerator.is_local_main_process,
+ leave=False,
+ desc="Timestep",
+ dynamic_ncols=True,
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ if self.accelerator.mixed_precision == "fp16":
+ inference_dtype = torch.float16
+ elif self.accelerator.mixed_precision == "bf16":
+ inference_dtype = torch.bfloat16
+ else:
+ inference_dtype = torch.float32
+
+ self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
+ self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
+ self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
+
+ trainable_layers = self.sd_pipeline.get_trainable_layers()
+
+ self.accelerator.register_save_state_pre_hook(self._save_model_hook)
+ self.accelerator.register_load_state_pre_hook(self._load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if self.config.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ self.optimizer = self._setup_optimizer(
+ trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
+ )
+
+ self.neg_prompt_embed = self.sd_pipeline.text_encoder(
+ self.sd_pipeline.tokenizer(
+ [""] if self.config.negative_prompts is None else self.config.negative_prompts,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
+ ).input_ids.to(self.accelerator.device)
+ )[0]
+
+ # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
+ # more memory
+ self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
+
+ if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
+ unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+ self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ else:
+ self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
+
+ if config.resume_from:
+ logger.info(f"Resuming from {config.resume_from}")
+ self.accelerator.load_state(config.resume_from)
+ self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
+ else:
+ self.first_epoch = 0
+
+ def compute_rewards(self, prompt_image_pairs):
+ reward, reward_metadata = self.reward_fn(
+ prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
+ )
+ return reward
+
+ def step(self, epoch: int, global_step: int):
+ """
+ Perform a single step of training.
+
+ Args:
+ epoch (int): The current epoch.
+ global_step (int): The current global step.
+
+ Side Effects:
+ - Model weights are updated
+ - Logs the statistics to the accelerator trackers.
+ - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
+
+ Returns:
+ global_step (int): The updated global step.
+
+ """
+ info = defaultdict(list)
+
+ self.sd_pipeline.unet.train()
+
+ for _ in range(self.config.train_gradient_accumulation_steps):
+ with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
+ prompt_image_pairs = self._generate_samples(
+ batch_size=self.config.train_batch_size,
+ )
+
+ rewards = self.compute_rewards(prompt_image_pairs)
+
+ prompt_image_pairs["rewards"] = rewards
+
+ rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
+
+ loss = self.calculate_loss(rewards)
+
+ self.accelerator.backward(loss)
+
+ if self.accelerator.sync_gradients:
+ self.accelerator.clip_grad_norm_(
+ self.trainable_layers.parameters()
+ if not isinstance(self.trainable_layers, list)
+ else self.trainable_layers,
+ self.config.train_max_grad_norm,
+ )
+
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ info["reward_mean"].append(rewards_vis.mean())
+ info["reward_std"].append(rewards_vis.std())
+ info["loss"].append(loss.item())
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if self.accelerator.sync_gradients:
+ # log training-related stuff
+ info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
+ info = self.accelerator.reduce(info, reduction="mean")
+ info.update({"epoch": epoch})
+ self.accelerator.log(info, step=global_step)
+ global_step += 1
+ info = defaultdict(list)
+ else:
+ raise ValueError(
+ "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
+ )
+ # Logs generated images
+ if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
+ self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
+
+ if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
+ self.accelerator.save_state()
+
+ return global_step
+
+ def calculate_loss(self, rewards):
+ """
+ Calculate the loss for a batch of an unpacked sample
+
+ Args:
+ rewards (torch.Tensor):
+ Differentiable reward scalars for each generated image, shape: [batch_size]
+ Returns:
+ loss (torch.Tensor)
+ (all of these are of shape (1,))
+ """
+ # Loss is specific to Aesthetic Reward function used in AlignProp (https://arxiv.org/pdf/2310.03739.pdf)
+ loss = 10.0 - (rewards).mean()
+ return loss
+
+ def loss(
+ self,
+ advantages: torch.Tensor,
+ clip_range: float,
+ ratio: torch.Tensor,
+ ):
+ unclipped_loss = -advantages * ratio
+ clipped_loss = -advantages * torch.clamp(
+ ratio,
+ 1.0 - clip_range,
+ 1.0 + clip_range,
+ )
+ return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
+
+ def _setup_optimizer(self, trainable_layers_parameters):
+ if self.config.train_use_8bit_adam:
+ import bitsandbytes
+
+ optimizer_cls = bitsandbytes.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ return optimizer_cls(
+ trainable_layers_parameters,
+ lr=self.config.train_learning_rate,
+ betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
+ weight_decay=self.config.train_adam_weight_decay,
+ eps=self.config.train_adam_epsilon,
+ )
+
+ def _save_model_hook(self, models, weights, output_dir):
+ self.sd_pipeline.save_checkpoint(models, weights, output_dir)
+ weights.pop() # ensures that accelerate doesn't try to handle saving of the model
+
+ def _load_model_hook(self, models, input_dir):
+ self.sd_pipeline.load_checkpoint(models, input_dir)
+ models.pop() # ensures that accelerate doesn't try to handle loading of the model
+
+ def _generate_samples(self, batch_size, with_grad=True, prompts=None):
+ """
+ Generate samples from the model
+
+ Args:
+ batch_size (int): Batch size to use for sampling
+ with_grad (bool): Whether the generated RGBs should have gradients attached to it.
+
+ Returns:
+ prompt_image_pairs (Dict[Any])
+ """
+ prompt_image_pairs = {}
+
+ sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
+
+ if prompts is None:
+ prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
+ else:
+ prompt_metadata = [{} for _ in range(batch_size)]
+
+ prompt_ids = self.sd_pipeline.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding="max_length",
+ truncation=True,
+ max_length=self.sd_pipeline.tokenizer.model_max_length,
+ ).input_ids.to(self.accelerator.device)
+
+ prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
+
+ if with_grad:
+ sd_output = self.sd_pipeline.rgb_with_grad(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=sample_neg_prompt_embeds,
+ num_inference_steps=self.config.sample_num_steps,
+ guidance_scale=self.config.sample_guidance_scale,
+ eta=self.config.sample_eta,
+ truncated_backprop_rand=self.config.truncated_backprop_rand,
+ truncated_backprop_timestep=self.config.truncated_backprop_timestep,
+ truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
+ output_type="pt",
+ )
+ else:
+ sd_output = self.sd_pipeline(
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=sample_neg_prompt_embeds,
+ num_inference_steps=self.config.sample_num_steps,
+ guidance_scale=self.config.sample_guidance_scale,
+ eta=self.config.sample_eta,
+ output_type="pt",
+ )
+
+ images = sd_output.images
+
+ prompt_image_pairs["images"] = images
+ prompt_image_pairs["prompts"] = prompts
+ prompt_image_pairs["prompt_metadata"] = prompt_metadata
+
+ return prompt_image_pairs
+
+ def train(self, epochs: Optional[int] = None):
+ """
+ Train the model for a given number of epochs
+ """
+ global_step = 0
+ if epochs is None:
+ epochs = self.config.num_epochs
+ for epoch in range(self.first_epoch, epochs):
+ global_step = self.step(epoch, global_step)
+
+ def create_model_card(self, path: str, model_name: Optional[str] = "TRL AlignProp Model") -> None:
+ """Creates and saves a model card for a TRL model.
+
+ Args:
+ path (`str`): The path to save the model card to.
+ model_name (`str`, *optional*): The name of the model, defaults to `TRL AlignProp Model`.
+ """
+ try:
+ user = whoami()["name"]
+ # handle the offline case
+ except Exception:
+ warnings.warn("Cannot retrieve user information assuming you are running in offline mode.")
+ return
+
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
+ with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+ f.write(model_card_content)
+
+ def _save_pretrained(self, save_directory):
+ self.sd_pipeline.save_pretrained(save_directory)
+ self.create_model_card(save_directory)