-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[feat] support activation cpu offload in fsdp and fsdp2 #7201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
17395ec
feat: add activation CPU offload callback support for FSDP training
meichangsu1 20680af
feat: add FSDP2 configuration and training script for activation CPU …
meichangsu1 80e1915
feat(callbacks): add constructor to ActivationCpuOffloadCallBack
meichangsu1 0cd6ffd
feat: enable activation checkpointing in FSDP2 config
meichangsu1 18ab86e
Updated fsdp2.json configuration for activation offload example
meichangsu1 723314f
feat(ascend): add training script for activation CPU offload
meichangsu1 de4d8dc
feat(activation_cpu_offload): update dataset and add NPU-specific syn…
meichangsu1 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| { | ||
| "_description": "FSDP2 configuration for distributed training (PyTorch native FSDP v2)", | ||
| "_requires": "torch>=2.4.0", | ||
| "_note": "This is the recommended configuration for multi-GPU training without CPU offloading. NOTE: When using FSDP2, do NOT use --gradient_checkpointing, use activation_checkpointing in fsdp_config instead.", | ||
|
|
||
| "_param_docs": { | ||
| "fsdp": "FSDP strategy string. Options: 'full_shard' (ZeRO-3 style, shards params+grads+optimizer), 'shard_grad_op' (ZeRO-2 style, shards grads+optimizer only). Add 'auto_wrap' to enable automatic layer wrapping. Add 'offload' to enable CPU offloading.", | ||
| "fsdp_version": "FSDP version. Use 2 for PyTorch native FSDP2 (recommended). FSDP2 uses DTensor for per-parameter sharding, supports LoRA/QLoRA natively.", | ||
| "auto_wrap_policy": "How to wrap model layers. 'TRANSFORMER_BASED_WRAP' wraps transformer decoder layers (from model._no_split_modules). 'SIZE_BASED_WRAP' wraps modules exceeding min_num_params.", | ||
| "cpu_ram_efficient_loading": "If true, only rank 0 loads full model weights, then broadcasts to other ranks. Reduces CPU RAM usage during initialization.", | ||
| "state_dict_type": "'SHARDED_STATE_DICT' (recommended): each rank saves its own shard without extra communication. 'FULL_STATE_DICT': gathers full model on rank 0 (higher memory, slower).", | ||
| "reshard_after_forward": "true = FULL_SHARD (ZeRO-3), reshards params after forward pass. false = SHARD_GRAD_OP (ZeRO-2), keeps params gathered during forward/backward.", | ||
| "activation_checkpointing": "Use FSDP's native activation checkpointing instead of gradient_checkpointing. This is the correct way to save memory with FSDP.", | ||
| "activation_cpu_offload": "true = offload activations to CPU. false = keep activations on GPU,can enable when using activation_checkpointing." | ||
| }, | ||
| "fsdp": "full_shard auto_wrap", | ||
| "fsdp_config": { | ||
| "fsdp_version": 2, | ||
| "reshard_after_forward": true, | ||
| "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", | ||
| "cpu_ram_efficient_loading": true, | ||
| "state_dict_type": "SHARDED_STATE_DICT", | ||
| "activation_checkpointing": false, | ||
| "activation_cpu_offload": true | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| #!/bin/bash | ||
| CUDA_VISIBLE_DEVICES=0,1 \ | ||
| NPROC_PER_NODE=2 \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这能给个体现这个技术很好的例子嘛 因为看这里例子,显存降低的不多,但是速度降了很多 |
||
| swift sft \ | ||
| --model 'Qwen/Qwen3-8B' \ | ||
| --train_type lora \ | ||
| --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \ | ||
| --torch_dtype bfloat16 \ | ||
| --num_train_epochs 1 \ | ||
| --per_device_train_batch_size 1 \ | ||
| --per_device_eval_batch_size 1 \ | ||
| --learning_rate 1e-4 \ | ||
| --lora_rank 8 \ | ||
| --lora_alpha 32 \ | ||
| --gradient_checkpointing false \ | ||
| --weight_decay 0.1 \ | ||
| --target_modules all-linear \ | ||
| --gradient_accumulation_steps 16 \ | ||
| --eval_steps 100 \ | ||
| --save_steps 5 \ | ||
| --save_total_limit 2 \ | ||
| --logging_steps 5 \ | ||
| --max_length 2048 \ | ||
| --output_dir output \ | ||
| --system You\ are\ a\ helpful\ assistant. \ | ||
| --warmup_ratio 0.05 \ | ||
| --dataloader_num_workers 4 \ | ||
| --fsdp './examples/train/activation_cpu_offload/fsdp2.json' | ||
|
|
||
|
|
||
| # --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' | ||
| # activation_cpu_offload=true | ||
|
|
||
| # {'loss': 2.1327579, 'grad_norm': 1.72890568, 'learning_rate': 8.346e-05, 'token_acc': 0.58396158, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '5m 28s', 'remaining_time': '12m 2s', 'memory(GiB)': 24.8, 'train_speed(iter/s)': 0.015218} | ||
| # Train: 31%|██████████████████████████████████████▍ | 5/16 [05:28<11:41, 63.77s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-5 | ||
| # {'loss': 1.51323957, 'grad_norm': 0.39210615, 'learning_rate': 3.455e-05, 'token_acc': 0.62368014, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '10m 22s', 'remaining_time': '6m 13s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016054} | ||
| # Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [10:22<05:37, 56.26s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v45-20251231-160511/checkpoint-10 | ||
| # {'loss': 1.36127844, 'grad_norm': 0.30676287, 'learning_rate': 1.09e-06, 'token_acc': 0.64411869, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '15m 6s', 'remaining_time': '1m 0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016547} | ||
| # ... | ||
| # {'train_runtime': 962.7184, 'train_samples_per_second': 0.519, 'train_steps_per_second': 0.017, 'train_loss': 1.61728384, 'token_acc': 0.62789828, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '16m 2s', 'remaining_time': '0s', 'memory(GiB)': 24.87, 'train_speed(iter/s)': 0.016624} | ||
| # Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [16:02<00:00, 60.16s/it] | ||
|
|
||
|
|
||
| # activation_cpu_offload=false | ||
|
|
||
| # {'loss': 2.15452981, 'grad_norm': 1.7536869, 'learning_rate': 0.0001, 'token_acc': 0.61792799, 'epoch': 0.06, 'global_step/max_steps': '1/16', 'percentage': '6.25%', 'elapsed_time': '46s', 'remaining_time': '11m 39s', 'memory(GiB)': 26.14, 'train_speed(iter/s)': 0.021458} | ||
| # {'loss': 2.13306689, 'grad_norm': 1.7279824, 'learning_rate': 8.346e-05, 'token_acc': 0.58295639, 'epoch': 0.32, 'global_step/max_steps': '5/16', 'percentage': '31.25%', 'elapsed_time': '2m 55s', 'remaining_time': '6m 26s', 'memory(GiB)': 26.59, 'train_speed(iter/s)': 0.028456} | ||
| # Train: 31%|██████████████████████████████████████▍ | 5/16 [02:55<05:59, 32.65s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-5 | ||
| # {'loss': 1.51308346, 'grad_norm': 0.39151499, 'learning_rate': 3.455e-05, 'token_acc': 0.62377399, 'epoch': 0.64, 'global_step/max_steps': '10/16', 'percentage': '62.50%', 'elapsed_time': '5m 18s', 'remaining_time': '3m 10s', 'memory(GiB)': 27.73, 'train_speed(iter/s)': 0.031432} | ||
| # Train: 62%|████████████████████████████████████████████████████████████████████████████▎ | 10/16 [05:18<02:51, 28.58s/it][INFO:swift] Saving model checkpoint to /model/ljl/output/v44-20251231-155036/checkpoint-10 | ||
| # {'loss': 1.36132231, 'grad_norm': 0.30557585, 'learning_rate': 1.09e-06, 'token_acc': 0.64442776, 'epoch': 0.96, 'global_step/max_steps': '15/16', 'percentage': '93.75%', 'elapsed_time': '7m 57s', 'remaining_time': '31s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031437} | ||
| # ... | ||
| # {'train_runtime': 507.5282, 'train_samples_per_second': 0.985, 'train_steps_per_second': 0.032, 'train_loss': 1.61732693, 'token_acc': 0.63051608, 'epoch': 1.0, 'global_step/max_steps': '16/16', 'percentage': '100.00%', 'elapsed_time': '8m 27s', 'remaining_time': '0s', 'memory(GiB)': 27.96, 'train_speed(iter/s)': 0.031543} | ||
| # Train: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [08:27<00:00, 31.70s/it] | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一个 前后的显存占用对比吧
然后用8B的模型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done