Activation Offloading #1220
Conversation
verl/utils/activation_offload.py
Outdated
        len(layers) - 1, len(layers), tensor_filter
    )
    if enable_ckpt:
        # The implementation of activation checkpointing in transformers is incompatible with activation offload
What makes the activation checkpointing incompatible with activation offloading? In most cases, you would like to combine both for maximum memory saved.
It's just that the activation checkpointing implementation in the transformers library does not work well with activation offloading directly. On the one hand, when activation checkpointing is enabled, verl sets `use_reentrant` to false, which selects an implementation based on `saved_tensors_hooks`. My implementation is also based on `saved_tensors_hooks`, and PyTorch applies only one pair of saved-tensor hooks at a time, unlike forward hooks, where multiple hooks can be registered. On the other hand, if we set `use_reentrant` to true, the forward function of the FSDP module is executed again during backward, resulting in an additional all-gather for the FSDP module's parameters.
In my implementation, you can still combine them to save more memory. When activation offloading is enabled, I disable activation checkpointing in the transformers library and enable it again in another way.
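To illustrate the hook conflict described above, here is a minimal, hypothetical sketch in plain PyTorch (not verl code): when `saved_tensors_hooks` contexts are nested, only the innermost pair applies to tensors saved inside it, so an offloading hook and a checkpointing hook registered this way cannot both observe the same activation.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

events = []

def make_pack(name):
    def pack(tensor):
        events.append(name)  # record which pack hook saw the saved tensor
        return tensor
    return pack

def unpack(tensor):
    return tensor

x = torch.randn(3, requires_grad=True)

# Nest two pairs of saved-tensor hooks, e.g. "offload" outside, "ckpt" inside.
with saved_tensors_hooks(make_pack("offload"), unpack):
    with saved_tensors_hooks(make_pack("ckpt"), unpack):
        y = (x * x).sum()  # mul saves its inputs for backward

# Only the innermost pair fired; the outer "offload" hook never saw the tensors.
assert "ckpt" in events and "offload" not in events
```

This is why the two features cannot simply be stacked as independent `saved_tensors_hooks` contexts.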
verl/utils/activation_offload.py
Outdated
        self,
        offload_handler: OffloadHandler,
        handler_extra_kwargs: Optional[Dict[str, Any]] = None,
        debug: bool = False,
Remove the `debug` argument from the signature.
verl/utils/activation_offload.py
Outdated
        # Data structure to maintain references to activation tensors
        self.tensor_tag_to_buf = {}
        # Data structure to hold the FP8/MXFP8 tensor objects
        self.fp8_tensor_object_map = {}
Thanks for the PR. Took a glance and left some comments. Please add some tests to make sure that the feature works properly. Also, it seems FSDP has some support for activation offloading: https://github.com/pytorch/pytorch/blob/main/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py#L275 Have you checked this?

#1118: this PR can fix the issue where vLLM doesn't detect memory release.
verl/workers/fsdp_workers.py
Outdated
        )

        if config.model.get('enable_activation_offload', False):
            enable_gradient_checkpointing = config.model.get('enable_gradient_checkpointing', False)
Please add a unit test to CI.
Please update the PR description according to https://github.com/volcengine/verl/blob/main/.github/PULL_REQUEST_TEMPLATE.md and complete the checklist.
Yes, I checked this support before the PR. It does not support overlapping offloading with computation, and it also offloads some of the module's parameters.
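For context, the kind of non-overlapping offload that `checkpoint_wrapper.py` provides can be approximated with PyTorch's public `saved_tensors_hooks` API: each saved activation is copied to CPU synchronously during forward and copied back on demand during backward, with no attempt to overlap transfers with computation. A minimal sketch (run on CPU here purely for illustration):

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_to_cpu(tensor):
    # Synchronously move the saved activation to CPU memory.
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    device, tensor = packed
    return tensor.to(device)  # copied back on demand during backward

x = torch.randn(4, requires_grad=True)
with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    y = (x.sin() * 2.0).sum()  # sin saves x for backward; it is offloaded here
y.backward()

# Gradient of 2*sin(x) is 2*cos(x); offloading must not change the result.
assert torch.allclose(x.grad, 2.0 * torch.cos(x.detach()))
```

The PR's scheme differs by grouping activations per TransformerLayer and issuing the copies asynchronously so they overlap with the next layer's computation.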
@imh966 thanks for confirming. Please add a test in the CI.
verl/workers/fsdp_workers.py
Outdated
        else:
            raise NotImplementedError(f"Unknown strategy {config.strategy}")

        if config.model.get('enable_activation_offload', False):
It will silently disable gradient checkpointing while activation offload is enabled at the same time, right? That sounds surprising for the user.
Let's just verify the configuration and raise an explicit error if enable_activation_offload and enable_gradient_checkpointing are enabled at the same time.
Also, given this feature only supports FSDP for now, add an assertion in configuration verification to make sure `strategy == "fsdp"`.
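The two suggestions above could look roughly like the following (a hypothetical sketch using a plain dict in place of verl's actual config object; the function name is illustrative). Note that it reflects the reviewer's proposal of rejecting the combination outright, not the merged behaviour, in which gradient checkpointing is re-enabled internally in a compatible way.

```python
def validate_activation_offload_config(config):
    """Fail loudly instead of silently changing checkpointing behaviour."""
    model_cfg = config.get("model", {})
    if not model_cfg.get("enable_activation_offload", False):
        return
    strategy = config.get("strategy")
    if strategy != "fsdp":
        raise NotImplementedError(
            f"activation offload only supports strategy == 'fsdp', got {strategy!r}"
        )
    if model_cfg.get("enable_gradient_checkpointing", False):
        raise ValueError(
            "enable_activation_offload and enable_gradient_checkpointing "
            "cannot be enabled at the same time"
        )
```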
Thank you for your suggestions. Actually, it just disables gradient checkpointing in the transformers library and enables it in another way, because gradient checkpointing in the transformers library is incompatible with my implementation of activation offloading. I will add some comments to make it less surprising for users.
verl/utils/activation_offload.py
Outdated
    module.forward = wrapped_method.__get__(module, type(module))


def enable_activation_offload_for_fsdp_model(model, enable_ckpt=False):
I suggest a shorter name for this utility function, just offload_activation(model), and add an assert for FSDP only.
I've changed the name of this function to enable_activation_offloading. Do you think it's fine?
@eric-haibin-lin @vermouth1992 shall we merge this?
@imh966 I've compared the offload implementations between yours and TE's.
@wplf Actually, it's almost a copy from TE. I just deleted some unused code to remove the dependency on TE and made a few modifications so it works well with FSDP.
Thank you for your reply. I'm very impressed!
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

This PR supports activation offloading; currently it is only for the FSDP backend.

### High-Level Design

Our implementation is based on the [one](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/cpu_offload.py) in TransformerEngine. For efficiency, it groups activations by TransformerLayer and offloads activation groups asynchronously: the offloading of the i-th activation group overlaps with the computation of the (i+1)-th group, so at most two activation groups are held in GPU memory at any time.

### Specific Changes

1. Add activation offloading support.

### API

### Usage Example

```
export VLLM_ATTENTION_BACKEND=XFORMERS
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=./data/gsm8k/train.parquet \
    data.val_files=./data/gsm8k/test.parquet \
    data.train_batch_size=512 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=./huggingface.co/Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.model.enable_activation_offload=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=64 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.val_before_train=False \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15
```

### Test

We conducted experiments on the Qwen2 7B model based on the above script. The memory and throughput data are shown in the figures below, where the blue line represents activation offloading.

<img width="351" alt="image" src="https://github.com/user-attachments/assets/207576a1-3f47-4b40-bf19-60cf8105d609" />
<img width="361" alt="image" src="https://github.com/user-attachments/assets/d58f0f8b-eb5f-4e19-a892-4d778ff26135" />

### Additional Info.

- **Issue Number**: none
- **Training**: This PR will affect the FSDP backend
- **Inference**: none

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.
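The double-buffered schedule from the High-Level Design section can be sketched in a few lines of plain Python. This is a conceptual model only: `compute` and `offload` stand in for the real kernel launches and device-to-host copies, and true overlap would use a separate CUDA stream.

```python
def forward_with_offload(num_layers, compute, offload):
    """Sketch of the grouped, double-buffered scheme: group i-1 is offloaded
    while group i is being computed, so at most two activation groups are
    resident in (simulated) GPU memory at once."""
    resident = {}      # activation groups still on the "GPU"
    cpu_store = {}     # offloaded copies
    max_resident = 0
    for i in range(num_layers):
        resident[i] = compute(i)                  # compute activation group i
        max_resident = max(max_resident, len(resident))
        if i - 1 in resident:                     # overlap: offload previous group
            cpu_store[i - 1] = offload(resident.pop(i - 1))
    if resident:                                  # last group offloaded after forward
        cpu_store[num_layers - 1] = offload(resident.pop(num_layers - 1))
    return cpu_store, max_resident

cpu, peak = forward_with_offload(5, lambda i: i * i, lambda a: a)
# peak == 2: never more than two groups resident, regardless of depth
```

In the real implementation the "at most two groups" invariant is what bounds activation memory to roughly two TransformerLayers' worth, independent of model depth.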
Thank you for the great work. I've tested this pull request on my own FSDP repo, and it works very well.