Remove gradient checkpointing inside gradient checkpointing #4474
Conversation
thanks. your fix, in combination with the earlier fix from today, makes it possible to also fine-tune the whole U-Net on 48G of VRAM at 768x768 (tested using this script). without this fix here, the forward pass 'runs away' after several steps and OOMs; with it, memory use is quite stable. |
@ethansmith2000 thanks for your PR! Could you please run |
Hey @ethansmith2000, Thanks for the nice issue! I'm actually not sure we have to remove gradient checkpointing from the unet block here - I don't think it's problematic to have torch checkpointing wrapped into torch checkpointing. I couldn't reproduce any memory savings with this PR. I'm running the following script:

```bash
#!/usr/bin/env bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export INSTANCE_DIR="ddogog"
export OUTPUT_DIR="lora-trained-xl"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --enable_xformers_memory_efficient_attention \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

whereas I'm generating the instance dir from:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

For both your PR and current main, I'm getting a max memory usage of 14.7 GB VRAM. |
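A minimal sketch (hypothetical, not part of the scripts in this thread) of how a peak-VRAM figure like the 14.7 GB above can be read out with PyTorch's built-in memory counters:

```python
import torch

def report_peak_vram(tag: str) -> float:
    """Print and return the peak allocated VRAM (in GB) since the last reset."""
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{tag}] peak VRAM: {peak_gb:.2f} GB")
    return peak_gb

# Usage sketch: reset the counter, run a few training steps on one branch,
# then read the peak back and repeat on the other branch, e.g.
#   torch.cuda.reset_peak_memory_stats()
#   ... training steps ...
#   report_peak_vram("main")
```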
oh, it absolutely does. it allowed me to actually rent 48G GPUs instead of 80G. i tested both, as it would have just been more convenient for me if it had worked with main. haven't been using that trainer script yet, i have a local implementation i've been mashing together. settings:

```bash
export LEARNING_RATE=4e-7 #@param {type:"number"}
export NUM_EPOCHS=25
export LR_SCHEDULE="constant"
export LR_WARMUP_STEPS=$((MAX_NUM_STEPS / 10))
export TRAIN_BATCH_SIZE=10
export RESOLUTION=1024
export GRADIENT_ACCUMULATION_STEPS=15
export MIXED_PRECISION="bf16"
export TRAINING_DYNAMO_BACKEND='no'
export TRAINING_NUM_PROCESSES=1
export TRAINING_NUM_MACHINES=1
export TRAINER_EXTRA_ARGS="--allow_tf32 --use_8bit_adam --use_ema"
export TRAINER_EXTRA_ARGS="${TRAINER_EXTRA_ARGS} --enable_xformers_memory_efficient_attention --use_original_images=true"
```
|
Also asked on the PyTorch forum here: https://discuss.pytorch.org/t/gradient-checkpointing-inside-gradient-checkpointing/185861 |
@bghira could you maybe post a reproducible example of that form: #4474 (comment) for your script - would love to reproduce the difference in memory here. |
might take me time :) all my GPUs are training |
Ok after some internal discussion it seems like this "gradient checkpointing" inside "gradient checkpointing" works when using the non-reentrant implementation (use_reentrant=False), but not with the default reentrant one.

The advantage of allowing gradient checkpointing inside gradient checkpointing is that it allows us to build a more modular library where each component can have gradient checkpointing enabled. E.g. we ideally want both the UNet down/up blocks and the Transformer2DModel to be able to enable gradient checkpointing on their own.

For now however, the solution in this PR seems like the better, more robust way going forward - so let's go with it I'd say. I also don't see any real downsides to it :-) |
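A minimal sketch of what "gradient checkpointing inside gradient checkpointing" looks like in plain PyTorch, assuming a recent PyTorch version where the non-reentrant variant supports nesting; the modules and shapes are made up for illustration:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

inner_block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
outer_proj = nn.Linear(64, 64)

def inner_forward(x):
    # inner level of checkpointing (think: a transformer block)
    return checkpoint(inner_block, x, use_reentrant=False)

def outer_forward(x):
    # outer level of checkpointing (think: a UNet down/up block containing the transformer)
    def fn(t):
        return outer_proj(inner_forward(t))
    return checkpoint(fn, x, use_reentrant=False)

x = torch.randn(8, 64, requires_grad=True)
outer_forward(x).sum().backward()  # backward pass runs with the non-reentrant variant
```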
* grad checkpointing
* fix make fix-copies
* fix

Co-authored-by: Patrick von Platen <[email protected]>
What does this PR do?
fixes to gradient checkpointing based on this diagram
![Untitled-19](https://private-user-images.githubusercontent.com/98723285/258501474-0569e92a-a428-4952-8d36-3f95aba3cf63.png)
I originally made these changes, but then saw Transformer2DModel already had the necessary changes in main! But the down blocks then need to have their checkpoint functions removed, I believe, which I have done in this commit.
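A schematic sketch of the structural change described above, using made-up stand-ins rather than the actual diffusers classes: before, the outer block wrapped the whole transformer in checkpoint() while the transformer already checkpointed each of its blocks internally; after, only the inner level remains.

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyTransformer(nn.Module):
    """Stand-in for Transformer2DModel: checkpoints each of its blocks internally."""
    def __init__(self, dim=64, num_blocks=2):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_blocks)])
        self.gradient_checkpointing = True

    def forward(self, x):
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = checkpoint(block, x, use_reentrant=False)  # inner checkpointing (kept)
            else:
                x = block(x)
        return x

class TinyDownBlock(nn.Module):
    """Stand-in for a UNet down block that contains the transformer."""
    def __init__(self, dim=64):
        super().__init__()
        self.transformer = TinyTransformer(dim)

    def forward(self, x):
        # Before the change, the call below was itself wrapped in checkpoint(),
        # giving two nested levels; after the change the transformer is called directly.
        return self.transformer(x)
```

With more than one block, checkpointing at the block level means only one block's activations are recomputed and held at a time during backward, which is the behaviour kept here while the redundant outer wrapper is dropped.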
here are the differences in VRAM consumption on the LoRA training script you have provided for SDXL, running at default parameters:
![Screen Shot 2023-08-04 at 2 31 31 PM](https://private-user-images.githubusercontent.com/98723285/258502213-383ab8d6-7210-4aba-ae84-efad32de9f19.png)
Before:
After:
Fixes # (issue)
Improper gradient checkpointing that becomes problematic when a Transformer2DModel consists of more than one BasicTransformerBlock.
@patrickvonplaten @sayakpaul