-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[fsdp] feat: Merge lora in fsdp training to speed up rollout #5115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
fd8a7de
add merging lora for fsdp.
amzfang 0660408
fix the lint and license.
amzfang ca93c41
minor update in examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_me…
amzfang 2b701fe
address review comments.
amzfang 479c28e
Apply suggestion from @Copilot
amzfang e48aebb
fix workflow tests
amzfang f5914e7
fix fully_sharded_loras config.
amzfang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_merge.sh
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| set -x | ||
|
|
||
| # initial "val-core/openai/gsm8k/acc/mean@1":0.378316906747536 | ||
| # after training: "val-core/openai/gsm8k/acc/mean@1":0.9264594389689158 | ||
|
|
||
| TIMESTAMP=$(date +%Y%m%d.%H%M%S) | ||
| project_name=verl_grpo_example_gsm8k | ||
| experiment_name=qwen3_4b_grpo-lora-merged-${TIMESTAMP} | ||
| train_dir=outputs/$project_name/$experiment_name/ | ||
| mkdir -p $train_dir | ||
| export TENSORBOARD_DIR=$train_dir/tensorboard_log/ | ||
| export VERL_FILE_LOGGER_PATH=$train_dir/metrics.jsonl | ||
|
|
||
| max_token_len_per_gpu=24576 | ||
|
|
||
| python3 -m verl.trainer.main_ppo \ | ||
| algorithm.adv_estimator=grpo \ | ||
| trainer.val_before_train=True \ | ||
| data.train_files=$HOME/data/gsm8k/train.parquet \ | ||
| data.val_files=$HOME/data/gsm8k/test.parquet \ | ||
| data.train_batch_size=128 \ | ||
| data.max_prompt_length=1024 \ | ||
| data.max_response_length=1024 \ | ||
| data.filter_overlong_prompts=True \ | ||
| data.truncation='error' \ | ||
| data.shuffle=False \ | ||
| actor_rollout_ref.model.path=Qwen/Qwen3-4B \ | ||
| actor_rollout_ref.model.use_remove_padding=True \ | ||
| actor_rollout_ref.model.enable_gradient_checkpointing=True \ | ||
| +actor_rollout_ref.model.lora.merge=True \ | ||
| actor_rollout_ref.model.lora_rank=32 \ | ||
| actor_rollout_ref.model.lora_alpha=64 \ | ||
| actor_rollout_ref.actor.optim.lr=1.0e-05 \ | ||
| actor_rollout_ref.actor.use_dynamic_bsz=True \ | ||
| actor_rollout_ref.actor.ppo_mini_batch_size=64 \ | ||
| actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token_len_per_gpu} \ | ||
| actor_rollout_ref.actor.use_kl_loss=True \ | ||
| actor_rollout_ref.actor.kl_loss_coef=0.001 \ | ||
| actor_rollout_ref.actor.kl_loss_type=low_var_kl \ | ||
| actor_rollout_ref.actor.entropy_coeff=0 \ | ||
| actor_rollout_ref.actor.strategy=fsdp2 \ | ||
| actor_rollout_ref.actor.fsdp_config.model_dtype=bf16 \ | ||
| actor_rollout_ref.actor.fsdp_config.param_offload=False \ | ||
| actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ | ||
| actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ | ||
| actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token_len_per_gpu} \ | ||
| actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ | ||
| actor_rollout_ref.rollout.name=vllm \ | ||
| actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ | ||
| actor_rollout_ref.rollout.n=8 \ | ||
| actor_rollout_ref.rollout.load_format=safetensors \ | ||
| actor_rollout_ref.rollout.layered_summon=True \ | ||
| actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ | ||
| actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token_len_per_gpu} \ | ||
| actor_rollout_ref.ref.fsdp_config.param_offload=True \ | ||
| actor_rollout_ref.ref.strategy=fsdp2 \ | ||
| actor_rollout_ref.ref.fsdp_config.model_dtype=bf16 \ | ||
| algorithm.use_kl_in_reward=False \ | ||
| trainer.use_legacy_worker_impl=disable \ | ||
| trainer.critic_warmup=0 \ | ||
| trainer.logger='["console","tensorboard","file"]' \ | ||
| trainer.project_name=$project_name \ | ||
| trainer.experiment_name=$experiment_name \ | ||
| trainer.default_local_dir=$train_dir \ | ||
| trainer.n_gpus_per_node=8 \ | ||
| trainer.nnodes=1 \ | ||
| trainer.save_freq=20 \ | ||
| trainer.test_freq=5 \ | ||
| trainer.total_epochs=1 \ | ||
| 2>&1 | tee $train_dir/train_log.txt | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,221 @@ | ||
| # Copyright 2026 Amazon.com Inc and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import os | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.distributed | ||
| import torch.multiprocessing as mp | ||
| from peft import LoraConfig, get_peft_model | ||
| from torch.distributed import init_device_mesh | ||
| from torch.distributed.fsdp import FullyShardedDataParallel as FSDP | ||
| from torch.distributed.fsdp import MixedPrecision, ShardingStrategy | ||
| from transformers import AutoModelForCausalLM, GptOssConfig, Qwen2Config | ||
|
|
||
| from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device | ||
| from verl.utils.fsdp_utils import ( | ||
| MixedPrecisionPolicy, | ||
| apply_fsdp2, | ||
| get_fsdp_wrap_policy, | ||
| merged_lora_context, | ||
| ) | ||
|
|
||
|
|
||
| def _test_merged_lora_context_worker( | ||
| rank, world_size, rendezvous_file, strategy, model_config, lora_config_dict, backup_adapters | ||
| ): | ||
| """Worker function for testing merged_lora_context with FSDP. | ||
|
|
||
| Args: | ||
| rank: Process rank | ||
| world_size: Total number of processes | ||
| rendezvous_file: Path to rendezvous file for distributed init | ||
| strategy: FSDP strategy ("fsdp" or "fsdp2") | ||
| model_config: Model configuration object (Qwen2Config, GptOssConfig, etc.) | ||
| lora_config_dict: Dictionary of LoRA configuration parameters | ||
| backup_adapters: Whether to backup adapter weights before merging | ||
| """ | ||
| get_torch_device().set_device(rank) | ||
| torch.distributed.init_process_group( | ||
| backend=get_nccl_backend(), | ||
| init_method=f"file://{rendezvous_file}", | ||
| rank=rank, | ||
| world_size=world_size, | ||
| ) | ||
| device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=("dp",)) | ||
|
|
||
| # Create model from provided config | ||
| with torch.device(get_device_name()): | ||
| model = AutoModelForCausalLM.from_config( | ||
| config=model_config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" | ||
| ) | ||
| model = model.to(device=get_device_name()) | ||
|
|
||
| # Add LoRA with provided config | ||
| lora_config = LoraConfig(**lora_config_dict) | ||
| model = get_peft_model(model, lora_config) | ||
|
|
||
| # Initialize LoRA adapter weights to non-zero values for testing | ||
| from peft.tuners.lora import LoraLayer | ||
|
|
||
| with torch.no_grad(): | ||
| for name, module in model.named_modules(): | ||
| if isinstance(module, LoraLayer): | ||
| for adapter_name in module.lora_A.keys(): | ||
| if adapter_name in module.lora_A: | ||
| # Initialize lora_A with values around 1.0 | ||
| module.lora_A[adapter_name].weight.data.uniform_(0.5, 1.5) | ||
| if adapter_name in module.lora_B: | ||
| # Initialize lora_B with values around 2.0 | ||
| module.lora_B[adapter_name].weight.data.uniform_(1.5, 2.5) | ||
|
|
||
| # Wrap model with FSDP | ||
| if strategy == "fsdp": | ||
| mixed_precision = MixedPrecision( | ||
| param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32 | ||
| ) | ||
| model = FSDP( | ||
| model, | ||
| use_orig_params=True, | ||
| device_id=get_torch_device().current_device(), | ||
| sharding_strategy=ShardingStrategy.FULL_SHARD, | ||
| mixed_precision=mixed_precision, | ||
| device_mesh=device_mesh, | ||
| auto_wrap_policy=get_fsdp_wrap_policy(module=model, is_lora=True), | ||
| ) | ||
| else: | ||
| mp_policy = MixedPrecisionPolicy( | ||
| param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True | ||
| ) | ||
| fsdp_kwargs = { | ||
| "mesh": device_mesh, | ||
| "mp_policy": mp_policy, | ||
| } | ||
| apply_fsdp2(model, fsdp_kwargs, {}) | ||
|
|
||
| # Test: backup adapter weights, merge, restore | ||
| from peft.tuners.lora import LoraLayer | ||
|
|
||
| lora_layers = [m for m in model.modules() if isinstance(m, LoraLayer)] | ||
|
|
||
| # Verify LoRA layers exist | ||
| assert len(lora_layers) > 0, "Model should have LoRA layers" | ||
|
|
||
| # Initially not merged | ||
| for layer in lora_layers: | ||
| assert not getattr(layer, "merged", False), "LoRA should not be merged initially" | ||
|
|
||
| # Backup adapter weights before merge | ||
| from peft.utils.save_and_load import get_peft_model_state_dict | ||
|
|
||
| original_adapter_weights = get_peft_model_state_dict(model) | ||
|
|
||
| # Use merged_lora_context with the specified backup_adapters flag | ||
| for _ in range(3): | ||
| with merged_lora_context(model, backup_adapters=backup_adapters): | ||
| # Inside context, LoRA should be merged | ||
| for layer in lora_layers: | ||
| assert getattr(layer, "merged", False), "LoRA should be merged inside context" | ||
|
|
||
| # After context, check the state based on backup_adapters flag | ||
| for layer in lora_layers: | ||
| assert not getattr(layer, "merged", False), "LoRA should be unmerged after context" | ||
|
|
||
| restored_adapter_weights = get_peft_model_state_dict(model) | ||
|
|
||
| # Verify adapter weights are restored exactly | ||
| for key in original_adapter_weights.keys(): | ||
| assert key in restored_adapter_weights, f"Key {key} should be in restored weights" | ||
| torch.testing.assert_close( | ||
| original_adapter_weights[key].cpu(), | ||
| restored_adapter_weights[key].cpu(), | ||
| rtol=1e-5, | ||
| atol=1e-6, | ||
| msg=f"Adapter weight {key} should be restored to original value", | ||
| ) | ||
|
|
||
| if rank == 0: | ||
| model_name = model_config.__class__.__name__ | ||
| backup_mode = "with backup" if backup_adapters else "without backup" | ||
| print(f"merged_lora_context test with {model_name} {strategy} {backup_mode} passed on {world_size} GPUs!") | ||
|
|
||
| torch.distributed.barrier() | ||
| torch.distributed.destroy_process_group() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("world_size", (2,)) | ||
| @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) | ||
| @pytest.mark.parametrize("backup_adapters", (True, False)) | ||
| def test_merged_lora_context_qwen2(world_size, strategy, backup_adapters, tmp_path): | ||
| """Test merged_lora_context with FSDP on Qwen2 model.""" | ||
| rendezvous_file = str(tmp_path / f"rdzv_file_qwen2_{backup_adapters}") | ||
| os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) | ||
|
|
||
| # Create Qwen2 model config | ||
| model_config = Qwen2Config(num_hidden_layers=2, num_attention_heads=2, hidden_size=128) | ||
|
|
||
| # Create LoRA config for Qwen2 | ||
| lora_config_dict = { | ||
| "r": 8, | ||
| "lora_alpha": 16, | ||
| "target_modules": ["q_proj", "v_proj"], | ||
| "lora_dropout": 0.0, | ||
| "bias": "none", | ||
| "task_type": "CAUSAL_LM", | ||
| } | ||
|
|
||
| mp.spawn( | ||
| fn=_test_merged_lora_context_worker, | ||
| args=(world_size, rendezvous_file, strategy, model_config, lora_config_dict, backup_adapters), | ||
| nprocs=world_size, | ||
| join=True, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("world_size", (2,)) | ||
| @pytest.mark.parametrize("strategy", ("fsdp", "fsdp2")) | ||
| @pytest.mark.parametrize("backup_adapters", (True, False)) | ||
| def test_merged_lora_context_gptoss(world_size, strategy, backup_adapters, tmp_path): | ||
| """Test merged_lora_context with FSDP on GPT-OSS model.""" | ||
| rendezvous_file = str(tmp_path / f"rdzv_file_gptoss_{backup_adapters}") | ||
| os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True) | ||
|
|
||
| # Create GPT-OSS model config | ||
| model_config = GptOssConfig( | ||
| num_hidden_layers=2, | ||
| num_attention_heads=2, | ||
| num_key_value_heads=2, | ||
| hidden_size=128, | ||
| intermediate_size=256, | ||
| ) | ||
|
|
||
| # Create LoRA config for GPT-OSS | ||
| lora_config_dict = { | ||
| "r": 8, | ||
| "lora_alpha": 16, | ||
| "target_modules": "all-linear", | ||
| "target_parameters": ["mlp.experts.gate_up_proj", "mlp.experts.down_proj"], | ||
| "exclude_modules": ["mlp.router"], | ||
| "lora_dropout": 0.0, | ||
| "bias": "none", | ||
| "task_type": "CAUSAL_LM", | ||
| } | ||
|
|
||
| mp.spawn( | ||
| fn=_test_merged_lora_context_worker, | ||
| args=(world_size, rendezvous_file, strategy, model_config, lora_config_dict, backup_adapters), | ||
| nprocs=world_size, | ||
| join=True, | ||
| ) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need
+here asmergeis already a existing field, right?