
[feat] support activation cpu offload in fsdp and fsdp2#7201

Merged
Jintao-Huang merged 7 commits into modelscope:main from tpx818:activation_cpu_offlod
Feb 11, 2026

Conversation

@meichangsu1
Contributor

@meichangsu1 meichangsu1 commented Dec 24, 2025

[feat] support activation cpu offload in fsdp and fsdp2

PR type

- [ ] Bug Fix
- [x] New Feature
- [ ] Document Updates
- [ ] More Models or Datasets Support

PR information

- Introduce activation CPU offloading built on autograd `saved_tensors` hooks, with grouped async offload/reload to reduce GPU activation memory.
- Add synchronous and asynchronous double-buffer offload handlers with stream-based D2H/H2D overlap and group-window scheduling.
- Wrap the FSDP/FSDP2 layer forward to insert group commit boundaries and manage the offload context; Embedding layers are skipped.
- Provide activation-checkpointing compatibility by replacing transformers' checkpointing with an internal checkpoint wrapper when enabled.
- Add a training callback that reads `fsdp_config` flags and enables activation offload for FSDP v1/v2 at train start.
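The hook mechanism described above can be sketched with PyTorch's `torch.autograd.graph.saved_tensors_hooks`. This is a minimal, synchronous illustration, not the PR's implementation: the PR batches saved tensors per layer group and overlaps the D2H/H2D copies with compute on a side stream.

```python
import torch

# Every saved activation is copied to CPU at pack time and restored to its
# original device at unpack time (during backward).
class CpuOffloadHooks:
    def pack(self, t: torch.Tensor):
        return (t.device, t.detach().to('cpu'))

    def unpack(self, packed):
        device, t = packed
        return t.to(device)

hooks = CpuOffloadHooks()
x = torch.randn(4, 4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(hooks.pack, hooks.unpack):
    y = (x * x).sum()
y.backward()
print(torch.allclose(x.grad, 2 * x))  # offloading must not change gradients
```

The key correctness property is that offloading is transparent to autograd: gradients are identical with and without the hooks.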

Experiment results

A simple test:

| batch_size | seq_len | peak_no_offload (MB) | peak_offload (MB) | delta (MB) |
|---:|---:|---:|---:|---:|
| 8 | 256 | 304.20 | 336.23 | -32.03 |
| 8 | 512 | 496.10 | 448.13 | 47.97 |
| 8 | 1024 | 800.49 | 660.22 | 140.27 |
| 16 | 256 | 624.23 | 576.26 | 47.97 |
| 16 | 512 | 928.61 | 788.34 | 140.27 |
| 16 | 1024 | 1504.74 | 1172.41 | 332.33 |

For model-specific results, please see the examples in the code.
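For intuition, the group-window scheduling from the PR description (commit groups during forward, reload in reverse order during backward) can be modeled with a toy scheduler. The class and window semantics below are illustrative assumptions; the real handler additionally double-buffers the transfers on CUDA streams.

```python
from collections import deque

class GroupWindowScheduler:
    """Toy model of grouped offload/reload scheduling (illustrative only)."""

    def __init__(self, window: int = 2):
        self.window = window      # how many groups may stay resident on GPU
        self.resident = deque()   # groups currently on GPU
        self.offloaded = []       # groups moved to CPU, oldest first

    def commit_group(self, group_id: int) -> None:
        # Forward pass: committing a new group evicts the oldest resident
        # group once the window is full.
        self.resident.append(group_id)
        if len(self.resident) > self.window:
            self.offloaded.append(self.resident.popleft())

    def reload_order(self):
        # Backward pass reloads offloaded groups in reverse commit order,
        # so the group needed next is always fetched first.
        return list(reversed(self.offloaded))

sched = GroupWindowScheduler(window=2)
for g in range(5):
    sched.commit_group(g)
print(list(sched.resident))   # [3, 4] stay on GPU
print(sched.reload_order())   # [2, 1, 0]
```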

@gemini-code-assist
Contributor

Summary of Changes

Hello @meichangsu1, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a critical memory optimization feature: activation CPU offloading for FSDP and FSDP2 training strategies. By intelligently moving activation tensors to CPU memory, the system can train larger models or use more extensive batch sizes on GPUs with limited memory. The implementation includes a sophisticated asynchronous offloading mechanism and is seamlessly integrated into the existing training framework via a new callback, making it easy to enable and configure.

Highlights

  • Activation CPU Offload: Implemented support for offloading activation tensors to CPU memory for FSDP and FSDP2 models, significantly reducing GPU memory consumption during training.
  • Asynchronous Offloading Mechanism: Introduced an asynchronous double-buffer group offload handler that intelligently manages CPU-GPU data transfers, allowing for overlapping communication and computation to improve efficiency.
  • Integrated Training Callback: A new ActivationCpuOffloadCallBack is added to the training pipeline, which automatically enables and configures activation CPU offloading based on the FSDP settings provided in the training arguments.
  • FSDP Configuration and Examples: Provided a new example FSDP2 configuration file (fsdp2.json) with detailed documentation and activation_cpu_offload enabled, along with a corresponding training script (train.sh).
  • Gradient Requirement for Checkpointing: Ensured proper gradient computation by calling model.enable_input_require_grads() when activation checkpointing is enabled in conjunction with CPU offloading.
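The callback wiring in the highlights above can be sketched as follows. The stub base class and the `enable_activation_offload` helper are hypothetical stand-ins for illustration, not the PR's actual API.

```python
# Hypothetical sketch: TrainerCallback is a stand-in for
# transformers.TrainerCallback, and enable_activation_offload is a
# placeholder for the real hook installation on the FSDP model.
class TrainerCallback:
    pass

def enable_activation_offload(model: dict) -> None:
    # Placeholder: the real implementation wraps each FSDP layer's forward
    # with saved_tensors hooks and group commit boundaries.
    model['activation_offload'] = True

class ActivationCpuOffloadCallBack(TrainerCallback):
    def __init__(self, fsdp_config: dict):
        self.enabled = bool(fsdp_config.get('activation_cpu_offload', False))

    def on_train_begin(self, args=None, state=None, control=None,
                       model=None, **kwargs):
        # Enable offload once, at train start, only when the flag is set.
        if self.enabled and model is not None:
            enable_activation_offload(model)

model = {}
cb = ActivationCpuOffloadCallBack({'activation_cpu_offload': True})
cb.on_train_begin(model=model)
print(model)  # {'activation_offload': True}
```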


Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces support for activation CPU offload in FSDP and FSDP2, which is a great feature for memory optimization. The implementation adds a new ActivationCpuOffloadCallBack and the associated logic to hook into the model's forward pass.

My review has identified a few issues:

  • There is a critical inconsistency in the FSDP configuration key used for the FSDP version (version vs. fsdp_version), which will prevent the new feature from working.
  • The example configuration file has a confusing note.
  • The example training script has a typo.
  • The new activation_cpu_offload.py file has some minor issues with logging configuration and type hinting.

I have provided detailed comments and suggestions to address these points. Once these are resolved, the PR should be in good shape.

CUDA_VISIBLE_DEVICES=0,1 \
swift sft \
--model 'Qwen/Qwen3-0.6B' \
--dataset 'swift/self-cognition#1000' \ \
Contributor

medium

There is an extra backslash here (\\). A single backslash is sufficient for line continuation in shell scripts. This could cause unexpected behavior in some shells.

Suggested change
--dataset 'swift/self-cognition#1000' \ \
--dataset 'swift/self-cognition#1000' \

self.model_parameters_storage = new_storage


def get_torch_device() -> any:
Contributor

medium

The type hint any is used for the return value. any is a built-in function in Python. The correct type hint for an arbitrary type is Any from the typing module, which is already imported in this file.

Suggested change
def get_torch_device() -> any:
def get_torch_device() -> Any:

@Jintao-Huang
Collaborator

Could I ask what problem this PR solves?

@@ -0,0 +1,27 @@
#!/bin/bash
Collaborator

Please add a before/after GPU memory comparison.

And use an 8B model.

Contributor Author

done

from swift.utils import get_logger

logger = get_logger()
logger.setLevel(logging.WARNING)
Collaborator

Please remove this; its impact is too broad.

Contributor Author

done

@Jintao-Huang
Collaborator

Is there a link to the reference code?

@meichangsu1
Contributor Author

> Is there a link to the reference code?

InternLM/InternEvo#391
verl-project/verl#1220


# activation_cpu_offload=false
# OOM
# {'loss': 1.13790035, 'grad_norm': 1.41472316, 'learning_rate': 5e-05, 'token_acc': 0.83174487, 'epoch': 0.04, 'global_step/max_steps': '1/27', 'percentage': '3.70%', 'elapsed_time': '46s', 'remaining_time': '20m 1s', 'memory(GiB)': 61.79, 'train_speed(iter/s)': 0.021641}
Collaborator

LoRA uses 61 GiB? Was this run with this example?

Contributor Author

No, it wasn't run with this example; it was run with an internal dataset whose token lengths are relatively long.

--lora_rank 8 \
--lora_alpha 32 \
--target_modules all-linear \
--freeze_vit true \
Collaborator

Why are there ViT parameters here?

Contributor Author

Sorry about that; the script displayed here was copied from another demo, and the log was not produced by this training script. I will update it later.

Contributor Author

The example and the corresponding GPU memory usage have been updated.

- Add ActivationCpuOffloadCallBack import and registration in callbacks mapping
- Automatically append activation_cpu_offload callback when FSDP config has activation_cpu_offload enabled
- Enables memory-efficient training by offloading activations to CPU during FSDP forward pass
…offload

- Add `fsdp2.json` configuration file for PyTorch native FSDP v2 with activation CPU offloading
- Include detailed parameter documentation and usage notes for FSDP2
- Provide example training script (`train.sh`) demonstrating multi-GPU training with LoRA
- Disable gradient checkpointing in favor of FSDP's native activation checkpointing
- Enable CPU RAM efficient loading and sharded state dicts for memory optimization
@meichangsu1 meichangsu1 deleted the activation_cpu_offlod branch January 22, 2026 03:51
@meichangsu1 meichangsu1 restored the activation_cpu_offlod branch January 22, 2026 04:30
@meichangsu1 meichangsu1 reopened this Jan 22, 2026
@meichangsu1 meichangsu1 force-pushed the activation_cpu_offlod branch from 86979cb to 20680af Compare January 22, 2026 04:33
- Add __init__ method to ActivationCpuOffloadCallBack to properly initialize parent class
- Update import to use local base TrainerCallback instead of transformers version
- Ensure callback follows consistent initialization pattern with other callbacks
- Remove activation_cpu_offload parameter from fsdp2.json
- Set activation_checkpointing to true for improved memory efficiency
- Maintain existing auto_wrap_policy and state_dict_type settings
@@ -0,0 +1,26 @@
{
Collaborator

Please put this under the examples/ascend folder.

@@ -0,0 +1,54 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 \
NPROC_PER_NODE=2 \
Collaborator

Could you give an example that better demonstrates this technique?

Looking at this example, the memory savings are small but the speed drops a lot.

self.model_parameters_storage = new_storage


def get_torch_device() -> any:
Collaborator

any -> Any

@Jintao-Huang
Collaborator

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a valuable feature for activation CPU offloading with FSDP, which can significantly reduce GPU memory usage. The implementation is thorough, including asynchronous offloading with double buffering and compatibility with activation checkpointing. I've identified a critical bug that prevents the feature from being enabled, along with a couple of other suggestions for improvement to enhance robustness and clarity. Overall, this is a great addition.

Comment on lines +70 to +80
def get_torch_device() -> any:
"""Return the corresponding torch attribute based on the device type string.
Returns:
module: The corresponding torch device namespace, or torch.cuda if not found.
"""
device_name = get_device_name()
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
Contributor

high

The get_torch_device function is not robust for CPU-only environments. If get_device_name() returns 'cpu', getattr(torch, 'cpu') raises an AttributeError, and the code falls back to torch.cuda. If CUDA is not available, the subsequent call to torch.cuda.Stream() will crash. Since this feature is designed for offloading from a GPU to save memory, it's better to explicitly raise an error in CPU-only mode to avoid unexpected crashes and provide a clear message to the user.

Suggested change
def get_torch_device() -> any:
"""Return the corresponding torch attribute based on the device type string.
Returns:
module: The corresponding torch device namespace, or torch.cuda if not found.
"""
device_name = get_device_name()
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda
def get_torch_device() -> any:
"""Return the corresponding torch attribute based on the device type string.
Returns:
module: The corresponding torch device namespace, or torch.cuda if not found.
"""
device_name = get_device_name()
if device_name == 'cpu':
raise RuntimeError('Activation CPU offload requires a device with streams (e.g., CUDA, NPU) and is not supported in CPU-only mode.')
try:
return getattr(torch, device_name)
except AttributeError:
logger.warning(f"Device namespace '{device_name}' not found in torch, try to load torch.cuda.")
return torch.cuda

- Add new training script `train.sh` for Ascend platform with activation CPU offload configuration
- Include comprehensive training parameters for Qwen3-8B model with LoRA fine-tuning
- Provide example training outputs for both activation_cpu_offload=true and false scenarios
- Move existing fsdp2.json configuration file to Ascend examples directory
@Jintao-Huang
Collaborator

Hello, please run lint.

…c copy

- Change training dataset from `AI-ModelScope/LongAlpaca-12k` to `AI-ModelScope/alpaca-gpt4-data-zh` in the example script
- Modify `SynchronizedGroupOffloadHandler.offload` to use synchronous, non-pinned memory copy when NPU is available, as NPU does not fully support async H2D/D2H with pinned memory
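The copy-path choice described in this commit can be sketched as follows; `offload_to_cpu` is a hypothetical helper for illustration, not the PR's actual function.

```python
import torch

def offload_to_cpu(t: torch.Tensor, async_ok: bool) -> torch.Tensor:
    """Sketch: pinned memory + non_blocking D2H where the backend supports
    it; otherwise a plain synchronous, non-pinned copy (the NPU fallback
    this commit introduces)."""
    if async_ok:
        # Pinned host buffer lets the D2H copy overlap with compute.
        buf = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
        buf.copy_(t, non_blocking=True)
    else:
        # Synchronous, non-pinned copy: slower, but safe on backends that
        # do not fully support async transfers with pinned memory.
        buf = t.detach().to('cpu')
    return buf

x = torch.randn(8)
print(torch.equal(offload_to_cpu(x, async_ok=False), x))  # True
```

Note that with `non_blocking=True` the caller must synchronize (e.g. via a stream event) before reading the host buffer; the synchronous path needs no such coordination.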
@meichangsu1
Contributor Author

> Hello, please run lint.

@meichangsu1 meichangsu1 closed this Feb 5, 2026
@meichangsu1
Contributor Author

> Hello, please run lint.

done

@meichangsu1 meichangsu1 reopened this Feb 5, 2026
@Jintao-Huang Jintao-Huang merged commit dc6ab89 into modelscope:main Feb 11, 2026
3 of 6 checks passed
zhichenggeng pushed a commit to zhichenggeng/ms-swift that referenced this pull request Feb 20, 2026