Activation Offloading#1220

Merged
vermouth1992 merged 9 commits into verl-project:main from imh966:activation_offload
May 23, 2025
Conversation

@imh966
Contributor

@imh966 imh966 commented Apr 23, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

This PR adds support for activation offloading; currently it is available only for the FSDP backend.

High-Level Design

Our implementation is based on the one in TransformerEngine (transformer_engine/pytorch/cpu_offload.py). For efficiency, it groups activations by TransformerLayer and offloads activation groups asynchronously: the offload of the i-th activation group overlaps with the computation of the (i+1)-th group, so at most two activation groups reside in GPU memory at any time.
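The double-buffered schedule described above can be sketched in plain Python (a schematic only; `run_forward` and its event log are illustrative stand-ins for the real asynchronous device-to-host copies and layer forwards):

```python
# Schematic of the grouped, double-buffered offload schedule: while group i-1
# is being copied to CPU, group i is being computed, so at most two activation
# groups are resident in GPU memory at any moment.
def run_forward(num_groups):
    resident = set()   # activation groups currently in GPU memory
    max_resident = 0
    events = []

    for i in range(num_groups):
        resident.add(i)  # compute group i: its activations appear on GPU
        max_resident = max(max_resident, len(resident))
        events.append(f"compute {i}")
        if i > 0:
            # the offload of group i-1 overlaps with the compute of group i;
            # once the copy finishes, group i-1 leaves GPU memory
            events.append(f"offload {i - 1}")
            resident.discard(i - 1)

    # the last group is offloaded after the loop finishes
    events.append(f"offload {num_groups - 1}")
    resident.discard(num_groups - 1)
    return events, max_resident

events, peak = run_forward(4)
print(peak)  # -> 2
```

The invariant to notice is that `peak` never exceeds two, independent of the number of TransformerLayer groups.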

Specific Changes

  1. Add activation offloading support.

API

Usage Example

export VLLM_ATTENTION_BACKEND=XFORMERS

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=./data/gsm8k/train.parquet \
    data.val_files=./data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=./huggingface.co/Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.enable_activation_offload=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.val_before_train=False \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15

Test

We conducted experiments on the Qwen2 7B model based on the above script. The memory and throughput data are shown in the figures below, where the blue line represents activation offloading.
<img width="351" alt="image" src="https://github.com/user-attachments/assets/207576a1-3f47-4b40-bf19-60cf8105d609" /> <img width="361" alt="image" src="https://github.com/user-attachments/assets/d58f0f8b-eb5f-4e19-a892-4d778ff26135" />

Additional Info.

  • Issue Number: none
  • Training: This PR affects the FSDP backend
  • Inference: none

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

@CLAassistant

CLAassistant commented Apr 23, 2025

CLA assistant check
All committers have signed the CLA.

len(layers) - 1, len(layers), tensor_filter
)
if enable_ckpt:
# The implementation of activation checkpointing in transformers is incompatible with activation offload
Contributor

What makes the activation checkpointing incompatible with activation offloading? In most cases, you would like to combine both for maximum memory savings.

Contributor Author

It's just that the activation checkpointing implementation in the transformers library does not work well with activation offloading directly. On one hand, when activation checkpointing is enabled, verl sets `use_reentrant` to false, which selects an implementation based on saved_tensor_hooks. My implementation is also based on saved_tensor_hooks, and PyTorch only allows one pair of saved_tensor_hooks to be active at a time, unlike forward hooks, of which multiple can be registered. On the other hand, if we set `use_reentrant` to true, the forward function of the FSDP module is executed again during backward, resulting in an additional all-gather for the parameters of the FSDP module.

In my implementation, you can still combine them to save more memory: when activation offloading is enabled, I disable activation checkpointing in the transformers library and re-enable it in another way.
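The idea of re-enabling checkpointing "in another way" can be illustrated schematically: save only the layer inputs and recompute the activations when the backward pass needs them. This is a plain-Python analogy, not verl's actual code; `Layer`, `checkpointed_forward`, and `backward` are hypothetical names:

```python
# Schematic of recompute-based checkpointing: instead of saving intermediate
# activations, save only the (small) layer input and recompute the activation
# during backward. The layer's forward therefore runs twice per step.
class Layer:
    def __init__(self):
        self.forward_calls = 0

    def forward(self, x):
        self.forward_calls += 1
        return x * 2  # stand-in for an expensive computation

def checkpointed_forward(layer, x):
    y = layer.forward(x)
    saved = x          # keep only the input; the activation can be dropped
    return y, saved

def backward(layer, saved):
    # recompute the activation from the saved input, then use it for grads
    return layer.forward(saved)

layer = Layer()
y, saved = checkpointed_forward(layer, 3)
_ = backward(layer, saved)
print(layer.forward_calls)  # -> 2: once in forward, once recomputed in backward
```

Combined with offloading, the recomputed activations (rather than stored ones) are what get grouped and moved to CPU.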

self,
offload_handler: OffloadHandler,
handler_extra_kwargs: Optional[Dict[str, Any]] = None,
debug: bool = False,
Contributor

Remove the debug argument in the signature

# Data Structure to maintain reference to activation tensors
self.tensor_tag_to_buf = {}
# Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {}
Contributor

not used?

@ZihengJiang
Contributor

ZihengJiang commented Apr 29, 2025

Thanks for the PR. Took a glance and left some comments. Please add some tests to make sure that the feature works properly.

Also, it seems fsdp has some support for activation offloading: https://github.com/pytorch/pytorch/blob/main/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py#L275

Have you checked this?

@BearBiscuit05
Collaborator

#1118: this PR can fix the issue where vLLM doesn't detect memory release.

)

if config.model.get('enable_activation_offload', False):
enable_gradient_checkpointing = config.model.get('enable_gradient_checkpointing', False)
Collaborator

please add a unit test to CI

@eric-haibin-lin
Collaborator

eric-haibin-lin commented Apr 30, 2025

Please update the PR description according to https://github.com/volcengine/verl/blob/main/.github/PULL_REQUEST_TEMPLATE.md and complete the checklist.

@imh966
Contributor Author

imh966 commented May 7, 2025

Thanks for the PR. Took a glance and left some comments. Please add some tests to make sure that the feature works properly.

Also, it seems fsdp has some support for activation offloading: https://github.com/pytorch/pytorch/blob/main/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py#L275

Have you checked this?

Yes, I checked that support before opening this PR. It does not support overlapping offloading with computation, and it also offloads some of the module's parameters.

@imh966 imh966 force-pushed the activation_offload branch from 6a9aa7c to d203ab8 Compare May 7, 2025 11:53
@ZihengJiang
Contributor

@imh966 thanks for confirming. Please add a test in the CI

else:
raise NotImplementedError(f"Unknown strategy {config.strategy}")

if config.model.get('enable_activation_offload', False):
Contributor

It will silently disable gradient checkpointing when activation offload is enabled at the same time, right? That would be surprising for users.

Let's just verify the configuration and raise an explicit error if enable_activation_offload and enable_gradient_checkpointing are enabled at the same time.
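The suggested validation might look roughly like the sketch below (hypothetical function name and error messages; the config keys mirror the `config.model.get(...)` calls quoted from the diff):

```python
# Sketch of explicit configuration validation: fail loudly instead of
# silently changing behavior when incompatible options are combined.
def validate_offload_config(config: dict) -> None:
    model_cfg = config.get("model", {})
    enable_offload = model_cfg.get("enable_activation_offload", False)
    enable_ckpt = model_cfg.get("enable_gradient_checkpointing", False)

    if enable_offload:
        # the feature only supports the FSDP backend for now
        if config.get("strategy") != "fsdp":
            raise ValueError(
                "enable_activation_offload requires strategy == 'fsdp'"
            )
        if enable_ckpt:
            raise ValueError(
                "enable_activation_offload is incompatible with the "
                "transformers implementation of gradient checkpointing"
            )

cfg = {"strategy": "fsdp",
       "model": {"enable_activation_offload": True,
                 "enable_gradient_checkpointing": False}}
validate_offload_config(cfg)  # passes silently
```

(As discussed below, the merged PR instead keeps both flags usable by re-enabling checkpointing internally in a compatible way; this sketch only illustrates the fail-fast alternative proposed here.)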

Contributor

Also, given this feature only supports FSDP for now, add an assertion in the configuration verification to make sure `strategy == "fsdp"`.

Contributor Author

Thank you for your suggestions. Actually, it just disables gradient checkpointing in the transformers library and re-enables it in another way, because the transformers implementation of gradient checkpointing is incompatible with my implementation of activation offloading. I will add some comments to make this less surprising for users.

module.forward = wrapped_method.__get__(module, type(module))


def enable_activation_offload_for_fsdp_model(model, enable_ckpt=False):
Contributor

I suggest a shorter name for this utility function, just offload_activation(model), and add an assert for FSDP only.

Contributor Author

I've changed the name of this function to enable_activation_offloading. Do you think it's fine?

Contributor

sounds good to me

@ZihengJiang
Contributor

@eric-haibin-lin @vermouth1992 shall we merge this?

@imh966 imh966 force-pushed the activation_offload branch from dccc0c2 to a77e537 Compare May 13, 2025 09:19
@imh966 imh966 requested a review from eric-haibin-lin May 15, 2025 03:16
@wplf

wplf commented May 21, 2025

@imh966 I've compared your offload implementation with TE's.
Did you modify TE's offload to support FSDP?
Are there other features besides that?

@imh966
Contributor Author

imh966 commented May 21, 2025

@imh966 I've compared your offload implementation with TE's. Did you modify TE's offload to support FSDP? Are there other features besides that?

@wplf Actually, it's almost a copy from TE. I just deleted some unused code to eliminate the dependency on TE and added a few modifications to make it work well with FSDP.

@wplf

wplf commented May 21, 2025

@imh966 I've compared your offload implementation with TE's. Did you modify TE's offload to support FSDP? Are there other features besides that?

@wplf Actually, it's almost a copy from TE. I just deleted some unused code to eliminate the dependency on TE and added a few modifications to make it work well with FSDP.

Thank you for your reply. I'm very impressed!

vermouth1992
vermouth1992 previously approved these changes May 22, 2025
@imh966 imh966 dismissed stale reviews from vermouth1992 and eric-haibin-lin via c475aa9 May 22, 2025 09:59
@imh966 imh966 force-pushed the activation_offload branch from 0dbc26e to 90397c9 Compare May 23, 2025 02:27
@vermouth1992 vermouth1992 merged commit aaaaaab into verl-project:main May 23, 2025
37 checks passed
ETOgaosion pushed a commit to Jianbing-D/verl that referenced this pull request Jun 8, 2025
wwwjn pushed a commit to wwwjn/verl that referenced this pull request Jun 10, 2025
@wplf

wplf commented Sep 12, 2025

Thank you for the great work. I've tested this pull request in my own FSDP repo, and it works very well.

chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026