From 1a700540075fa3d25a7403aec5eef2f8a81bce24 Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Fri, 1 Nov 2024 00:49:32 -0600 Subject: [PATCH] Reduce Memory Cost in Flux Training (#9829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve NPU performance * Improve NPU performance * Improve NPU performance * Improve NPU performance * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * Reduce memory cost for flux training process --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_flux.py | 6 ++++++ examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index bd1c29009976..9fd95fe823a5 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1740,6 +1740,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_npu.npu.empty_cache() gc.collect() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1798,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e21485952583..2c1126109a36 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1844,6 +1844,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): del text_encoder_one, text_encoder_two free_memory() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1908,6 +1911,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training()