From 4cf3fe130627371041d9be838b6a9d94c215b4ae Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Fri, 1 Nov 2024 00:49:32 -0600 Subject: [PATCH] Reduce Memory Cost in Flux Training (#9829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve NPU performance * Improve NPU performance * Improve NPU performance * Improve NPU performance * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * Reduce memory cost for flux training process --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- examples/dreambooth/train_dreambooth_flux.py | 6 ++++++ examples/dreambooth/train_dreambooth_lora_flux.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index bd1c29009976b..9fd95fe823a5f 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1740,6 +1740,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): torch_npu.npu.empty_cache() gc.collect() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1798,6 +1801,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training() diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e214859525830..2c1126109a36d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1844,6 +1844,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): del text_encoder_one, text_encoder_two free_memory() + images = None + del pipeline + # Save the lora layers accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -1908,6 +1911,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ignore_patterns=["step_*", "epoch_*"], ) + images = None + del pipeline + accelerator.end_training()