diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index fbff01010c95..c98cb9633ea8 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -342,6 +342,13 @@ def INPUT_TYPES(s):
                     ["bf16", "fp32"],
                     {"default": "bf16", "tooltip": "The dtype to use for lora."},
                 ),
+                "gradient_checkpointing": (
+                    IO.BOOLEAN,
+                    {
+                        "default": True,
+                        "tooltip": "Use gradient checkpointing to reduce memory usage at the cost of speed.",
+                    },
+                ),
                 "existing_lora": (
                     folder_paths.get_filename_list("loras") + ["[None]"],
                     {
@@ -372,9 +379,11 @@ def train(
         seed,
         training_dtype,
         lora_dtype,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
+        device = comfy.model_management.get_torch_device()
         dtype = node_helpers.string_to_torch_dtype(training_dtype)
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
@@ -384,8 +393,9 @@ def train(
 
         with torch.inference_mode(False):
             lora_sd = {}
-            generator = torch.Generator()
-            generator.manual_seed(seed)
+            old_cpu_rng_state = torch.get_rng_state()
+            old_device_rng_state = torch.cuda.get_rng_state(device)
+            torch.manual_seed(seed)
 
             # Load existing LoRA weights if provided
             existing_weights = {}
@@ -472,8 +482,12 @@ def train(
             criterion = torch.nn.SmoothL1Loss()
 
             # setup models
-            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-                patch(m)
+            if gradient_checkpointing:
+                modules_to_patch = find_all_highest_child_module_with_forward(mp.model.diffusion_model)
+                for m in modules_to_patch:
+                    patch(m)
+                logging.info(f"Added gradient checkpoints to {len(modules_to_patch)} modules")
+
             comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
 
             # Setup sampler and guider like in test script
@@ -493,6 +507,9 @@ def loss_callback(loss):
             # Training loop
             torch.cuda.empty_cache()
             try:
+                ui_pbar = None
+                if comfy.utils.PROGRESS_BAR_ENABLED:
+                    ui_pbar = comfy.utils.ProgressBar(steps)
                 for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
                     # Generate random sigma
                     sigma = mp.model.model_sampling.percent_to_sigma(
@@ -506,6 +523,7 @@ def loss_callback(loss):
                     ss.sample(
                         noise, guider, train_sampler, sigma, {"samples": latents[indices].clone()}
                     )
+                    if ui_pbar is not None: ui_pbar.update(1)
             finally:
                 for m in mp.model.modules():
                     unpatch(m)
@@ -518,6 +536,9 @@ def loss_callback(loss):
 
             for param in lora_sd:
                 lora_sd[param] = lora_sd[param].to(lora_dtype)
 
+            torch.set_rng_state(old_cpu_rng_state)
+            torch.cuda.set_rng_state(old_device_rng_state, device)
+
 
         return (mp, lora_sd, loss_map, steps + existing_steps)
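
For reviewers: the patch combines two independent techniques — wrapping the highest-level diffusion blocks so their activations are recomputed during the backward pass (gradient checkpointing), and saving/restoring the global RNG state so seeding the trainer does not disturb the rest of the workflow. The sketch below illustrates both patterns in isolation; it is not ComfyUI code, the names `Block`, `patch_with_checkpointing`, and `run_seeded` are invented for illustration, and it only approximates what `patch()` / `unpatch()` and the RNG bookkeeping in `train()` do.

```python
# Minimal, standalone sketch of the two patterns used in this patch
# (illustrative only -- not the ComfyUI implementation).
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Hypothetical stand-in for one high-level diffusion-model block."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


def patch_with_checkpointing(module: nn.Module):
    """Replace module.forward so activations are recomputed during backward,
    trading extra compute for lower peak memory."""
    original_forward = module.forward

    def checkpointed_forward(*args, **kwargs):
        return checkpoint(original_forward, *args, use_reentrant=False, **kwargs)

    module._original_forward = original_forward  # kept so the patch can be undone
    module.forward = checkpointed_forward


def run_seeded(seed: int, model: nn.Module, x: torch.Tensor):
    """Seed the global RNG for a reproducible run, then restore the caller's state
    (mirrors the save/seed/restore pattern added to train())."""
    old_cpu_rng_state = torch.get_rng_state()  # CUDA state would be saved the same way
    try:
        torch.manual_seed(seed)
        noise = torch.randn_like(x)
        out = model(x + noise)
        out.sum().backward()  # backward recomputes the checkpointed forward
        return out
    finally:
        torch.set_rng_state(old_cpu_rng_state)  # leave the global RNG as we found it


if __name__ == "__main__":
    block = Block()
    patch_with_checkpointing(block)
    x = torch.randn(4, 8, requires_grad=True)
    run_seeded(1234, block, x)
    print(block.linear.weight.grad.abs().sum())  # gradients flowed through the checkpoint
```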