
Remove gradient checkpointing inside gradient checkpointing #4474

Merged
merged 3 commits into huggingface:main on Aug 7, 2023

Conversation

ethansmith2000
Contributor

What does this PR do?

Fixes to gradient checkpointing, based on this diagram:
[diagram: gradient checkpointing structure]

I originally made these changes but then saw that Transformer2DModel already had the necessary changes in main! But the down blocks then need to have their checkpoint functions removed, I believe, which I have done in this commit.
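
For illustration, here is a minimal sketch (not the exact diff in this PR) of the pattern being changed: an outer block wrapping an inner module in torch.utils.checkpoint even though that inner module already checkpoints its own sub-blocks, the way Transformer2DModel checkpoints its BasicTransformerBlocks. The class names below are stand-ins, not diffusers classes.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class InnerBlock(nn.Module):
    # Stand-in for Transformer2DModel: it checkpoints its own sub-blocks.
    def __init__(self, dim, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        self.gradient_checkpointing = True

    def forward(self, x):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Inner checkpointing, analogous to checkpointing each BasicTransformerBlock.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x


class OuterBlock(nn.Module):
    # Stand-in for a UNet down block that contains the inner module.
    def __init__(self, dim):
        super().__init__()
        self.inner = InnerBlock(dim)

    def forward(self, x):
        # Before: the outer block re-wrapped the inner module, i.e.
        #     x = checkpoint(self.inner, x)
        # which nests checkpointing inside checkpointing.
        # After (the direction of this PR): call it directly and let the
        # inner module handle its own checkpointing.
        return self.inner(x)


x = torch.randn(4, 32, requires_grad=True)
block = OuterBlock(32).train()
block(x).sum().backward()
print(x.grad.shape)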

here are the differences in VRAM consumption on the LoRA training script you have provided for SDXL running at default parameters
Before:
[screenshot: VRAM usage before this change]

After:
[screenshot: VRAM usage after this change]

Fixes # (issue)

Improper gradient checkpointing that becomes problematic when a Transformer2DModel consists of more than one BasicTransformerBlock.

@patrickvonplaten @sayakpaul

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Aug 4, 2023

The documentation is not available anymore as the PR was closed or merged.

@bghira
Contributor

bghira commented Aug 4, 2023

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA RTX A6000    On   | 00000000:56:00.0 Off |                  Off |
| 30%   49C    P2   112W / 300W |  43984MiB / 49140MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    On   | 00000000:D6:00.0 Off |                  Off |
| 30%   49C    P2   116W / 300W |  43508MiB / 49140MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Thanks. Your fix, in combination with the earlier fix from today, makes it possible to also fine-tune the whole U-Net on 48 GB of VRAM at 768x768.

Tested using this script.

Without this fix, the forward pass "runs away" after several steps and OOMs.

With it, memory use is quite stable.

[screenshot]

@sayakpaul
Member

@ethansmith2000 thanks for your PR!

Could you please run make fix-copies from your diffusers fork and push the changes?

@patrickvonplaten
Contributor

patrickvonplaten commented Aug 7, 2023

Hey @ethansmith2000,

Thanks for the nice issue!

I'm actually not sure we have to remove gradient checkpointing from the unet block here - I don't think it's problematic to have torch checkpointing wrapped into torch checkpointing.

I couldn't reproduce any memory savings with this PR. I'm running the following script:

#!/usr/bin/env bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export INSTANCE_DIR="ddogog"
export OUTPUT_DIR="lora-trained-xl"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --enable_xformers_memory_efficient_attention \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

where I'm generating the instance dir "dog" with:

from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir, repo_type="dataset",
    ignore_patterns=".gitattributes",
)

For both your PR and current main, I'm getting a max memory usage of 14.7 GB VRAM.
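
As a side note, one common way to compare peak VRAM between branches is to reset and read PyTorch's peak-memory counters around a training step. This is a hedged sketch of that approach with a stand-in workload, not necessarily how the 14.7 GB figure above was obtained.

import torch

def run_with_peak_vram(fn):
    # Reset the peak-memory counter, run the workload, and report peak allocated GB.
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**3

if torch.cuda.is_available():
    model = torch.nn.Linear(4096, 4096).cuda()
    x = torch.randn(64, 4096, device="cuda")
    peak_gb = run_with_peak_vram(lambda: model(x).sum().backward())
    print(f"peak VRAM: {peak_gb:.2f} GB")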

@patrickvonplaten
Contributor

@bghira I've just updated #4505 to "main" so that the initial gradient checkpointing fixes are included - could you double check if the PR with current main really leads to worse memory results?

@bghira
Contributor

bghira commented Aug 7, 2023

Oh, it absolutely does. It allowed me to actually rent 48 GB GPUs instead of 80 GB ones. I tested both, as it would have just been more convenient for me if it had worked with main.

I haven't been using that trainer script yet; I have a local implementation I've been mashing together.

Settings:

export LEARNING_RATE=4e-7 #@param {type:"number"}
export NUM_EPOCHS=25
export LR_SCHEDULE="constant"
export LR_WARMUP_STEPS=$((MAX_NUM_STEPS / 10))
export TRAIN_BATCH_SIZE=10
export RESOLUTION=1024
export GRADIENT_ACCUMULATION_STEPS=15
export MIXED_PRECISION="bf16"
export TRAINING_DYNAMO_BACKEND='no'
export TRAINING_NUM_PROCESSES=1
export TRAINING_NUM_MACHINES=1
export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema"
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"

@patrickvonplaten
Contributor

@bghira could you maybe post a reproducible example of that form (#4474 (comment)) for your script? Would love to reproduce the difference in memory here.

@bghira
Contributor

bghira commented Aug 7, 2023

Might take me some time :) all my GPUs are training.

@patrickvonplaten
Contributor

Ok, after some internal discussion it seems like this "gradient checkpointing" inside "gradient checkpointing" works when using use_reentrant=True, but it is currently being worked on and is quite a novel feature (so it probably doesn't work with PyTorch < 2). So it seems like it's not very robust right now.

The advantage of allowing gradient checkpointing inside gradient checkpointing is that it lets us build a more modular library where each component can have gradient checkpointing enabled. E.g., we ideally want both the Transformer2DModel and the UNet blocks to be able to use gradient checkpointing independently of each other.

For now, however, the solution in this PR seems like the better, more robust way going forward - so let's go with it, I'd say. I also don't see any real downsides to it :-)
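
For reference, here is a minimal sketch of what "checkpointing inside checkpointing" looks like at the torch.utils.checkpoint level. Which use_reentrant setting and which PyTorch version actually support the nesting is exactly the open question above, so this is an illustration rather than a supported-configuration claim; the use_reentrant=False choice below is an assumption for the sketch.

import torch
from torch.utils.checkpoint import checkpoint

# Assumption for this sketch: non-reentrant mode. See the discussion above on
# which mode and PyTorch version actually support nested checkpointing.
USE_REENTRANT = False

inner = torch.nn.Linear(16, 16)
outer = torch.nn.Linear(16, 16)

def inner_fn(x):
    # Inner checkpoint, analogous to Transformer2DModel checkpointing its blocks.
    return checkpoint(inner, x, use_reentrant=USE_REENTRANT)

def outer_fn(x):
    # Outer checkpoint, analogous to a UNet block checkpointing the whole module.
    return checkpoint(lambda t: outer(inner_fn(t)), x, use_reentrant=USE_REENTRANT)

x = torch.randn(2, 16, requires_grad=True)
outer_fn(x).sum().backward()
print(x.grad is not None)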

@patrickvonplaten patrickvonplaten merged commit f4f8541 into huggingface:main Aug 7, 2023
@patrickvonplaten patrickvonplaten changed the title grad checkpointing Remove gradient checkpointing inside gradient checkpointing Aug 7, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* grad checkpointing

* fix make fix-copies

* fix

---------

Co-authored-by: Patrick von Platen <[email protected]>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* grad checkpointing

* fix make fix-copies

* fix

---------

Co-authored-by: Patrick von Platen <[email protected]>