From 18fe891f78ff997a8d299a5e31e98dd343074b5c Mon Sep 17 00:00:00 2001
From: bghira
Date: Wed, 18 Dec 2024 13:42:36 -0600
Subject: [PATCH] sd3: allow setting grad checkpointing interval

---
 helpers/configuration/cmd_args.py             |  2 +-
 helpers/models/sd3/transformer.py             | 19 ++++++++++++++++++-
 .../training/default_settings/safety_check.py |  6 +-----
 helpers/training/diffusion_model.py           |  5 ++++-
 4 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/helpers/configuration/cmd_args.py b/helpers/configuration/cmd_args.py
index 89ff22f0..5fc965c5 100644
--- a/helpers/configuration/cmd_args.py
+++ b/helpers/configuration/cmd_args.py
@@ -1068,7 +1068,7 @@ def get_argument_parser():
         default=None,
         type=int,
         help=(
-            "Some models (Flux, SDXL, SD1.x/2.x) can have their gradient checkpointing limited to every nth block."
+            "Some models (Flux, SDXL, SD1.x/2.x, SD3) can have their gradient checkpointing limited to every nth block."
             " This can speed up training but will use more memory with larger intervals."
         ),
     )
diff --git a/helpers/models/sd3/transformer.py b/helpers/models/sd3/transformer.py
index 943bdb7b..fbb20d52 100644
--- a/helpers/models/sd3/transformer.py
+++ b/helpers/models/sd3/transformer.py
@@ -140,6 +140,16 @@ def __init__(
         )
 
         self.gradient_checkpointing = False
+        self.gradient_checkpointing_interval = None
+
+    def set_gradient_checkpointing_interval(self, interval: int):
+        """
+        Sets the interval for gradient checkpointing.
+
+        Parameters:
+            interval (`int`): The interval for gradient checkpointing.
+        """
+        self.gradient_checkpointing_interval = interval
 
     # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
     def enable_forward_chunking(
@@ -384,7 +394,14 @@ def forward(
                 )
                 continue
 
-            if self.training and self.gradient_checkpointing:
+            if (
+                self.training
+                and self.gradient_checkpointing
+                and (
+                    self.gradient_checkpointing_interval is None
+                    or index_block % self.gradient_checkpointing_interval == 0
+                )
+            ):
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diff --git a/helpers/training/default_settings/safety_check.py b/helpers/training/default_settings/safety_check.py
index 0318440c..73d5d824 100644
--- a/helpers/training/default_settings/safety_check.py
+++ b/helpers/training/default_settings/safety_check.py
@@ -139,11 +139,7 @@ def safety_check(args, accelerator):
             )
             args.attention_mechanism = "diffusers"
 
-    gradient_checkpointing_interval_supported_models = [
-        "flux",
-        "sana",
-        "sdxl",
-    ]
+    gradient_checkpointing_interval_supported_models = ["flux", "sana", "sdxl", "sd3"]
     if args.gradient_checkpointing_interval is not None:
         if (
             args.model_family.lower()
diff --git a/helpers/training/diffusion_model.py b/helpers/training/diffusion_model.py
index 950b0241..8a61a94f 100644
--- a/helpers/training/diffusion_model.py
+++ b/helpers/training/diffusion_model.py
@@ -179,7 +179,10 @@ def load_diffusion_model(args, weight_dtype):
                 set_checkpoint_interval(int(args.gradient_checkpointing_interval))
 
-    if args.gradient_checkpointing_interval is not None:
+    if (
+        args.gradient_checkpointing_interval is not None
+        and args.gradient_checkpointing_interval > 1
+    ):
         if transformer is not None and hasattr(
             transformer, "set_gradient_checkpointing_interval"
         ):
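
Reviewer note, not part of the patch: the condition added to forward() checkpoints
block `index_block` only when the interval is unset or divides the block index, so
an interval of None keeps the old every-block behaviour and larger intervals
checkpoint fewer blocks. A minimal sketch of that selection logic follows, where
checkpointed_blocks is a hypothetical helper (not a function in the repo) and the
24-block depth is an assumption roughly matching SD3-medium:

    # Hypothetical helper mirroring the condition added in
    # helpers/models/sd3/transformer.py; for illustration only.
    from typing import List, Optional

    def checkpointed_blocks(num_blocks: int, interval: Optional[int]) -> List[int]:
        """Indices of transformer blocks that would be gradient-checkpointed."""
        return [
            index_block
            for index_block in range(num_blocks)
            if interval is None or index_block % interval == 0
        ]

    print(checkpointed_blocks(24, None))  # every block: the pre-patch behaviour
    print(checkpointed_blocks(24, 4))     # [0, 4, 8, 12, 16, 20]

With an interval of 4 only every fourth block re-runs its forward pass during
backprop while the rest keep their activations resident, which is why the help
text warns that larger intervals speed up training at the cost of memory. The
`> 1` guard added in diffusion_model.py skips the setter when the interval is 1,
since checkpointing every block is already the default behaviour.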