Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/advance/ppo_lora.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
RL(HF) algorithms with LoRA Support
===========================================

Last updated: 12/17/2025.
Last updated: 02/03/2026.

We support LoRA (Low-Rank Adaptation) for reinforcement learning algorithms such as PPO, GRPO, and others.

Expand Down Expand Up @@ -42,6 +42,8 @@ FSDP Backend Usage Guide
- `actor_rollout_ref.model.lora_adapter_path`: string, path to a pretrained LoRA adapter directory.
If provided, loads existing adapter instead of creating new one. Enables multi-stage training from previously saved adapters.
The directory must contain `adapter_model.safetensors` and `adapter_config.json`.
- `actor_rollout_ref.model.lora.merge`: bool, whether to merge LoRA adapters into the base model weights before transferring to vLLM.
If True, it will merge LoRA adapters into the base model weights before transferring to vLLM. If False, it will transfer only adapters to vLLM. This option is currently supported **only for engine-based rollout workers** (i.e. vLLM engine workers using the new worker implementation with ``trainer.use_legacy_worker_impl`` disabled) and is not available when using the legacy worker implementation.

5. Recommend options:

Expand Down Expand Up @@ -137,6 +139,10 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
# Path to pre-trained LoRA adapter weights (null to train from scratch)
adapter_path: null

# Whether to fully shard LoRA adapters. Defaults to False
# https://docs.vllm.ai/en/latest/api/vllm/config/lora/#vllm.config.lora.LoRAConfig.fully_sharded_loras
fully_sharded_loras: False

# VLMLoRA additionally allows the user to specify whether the language or vision models should be frozen.
# For example, a common finetuning workload for multimodal models is to apply adapters to language model and fully
# finetune the vision model.
Expand Down
71 changes: 71 additions & 0 deletions examples/grpo_trainer/run_qwen3-4b_gsm8k_grpo_lora_merge.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
set -x

# initial "val-core/openai/gsm8k/acc/mean@1":0.378316906747536
# after training: "val-core/openai/gsm8k/acc/mean@1":0.9264594389689158

TIMESTAMP=$(date +%Y%m%d.%H%M%S)
project_name=verl_grpo_example_gsm8k
experiment_name=qwen3_4b_grpo-lora-merged-${TIMESTAMP}
train_dir=outputs/$project_name/$experiment_name/
mkdir -p $train_dir
export TENSORBOARD_DIR=$train_dir/tensorboard_log/
export VERL_FILE_LOGGER_PATH=$train_dir/metrics.jsonl

max_token_len_per_gpu=24576

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
trainer.val_before_train=True \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=128 \
data.max_prompt_length=1024 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.shuffle=False \
actor_rollout_ref.model.path=Qwen/Qwen3-4B \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.model.lora.merge=True \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need + here as merge is already a existing field, right?

actor_rollout_ref.model.lora_rank=32 \
actor_rollout_ref.model.lora_alpha=64 \
actor_rollout_ref.actor.optim.lr=1.0e-05 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${max_token_len_per_gpu} \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.strategy=fsdp2 \
actor_rollout_ref.actor.fsdp_config.model_dtype=bf16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${max_token_len_per_gpu} \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${max_token_len_per_gpu} \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.ref.strategy=fsdp2 \
actor_rollout_ref.ref.fsdp_config.model_dtype=bf16 \
algorithm.use_kl_in_reward=False \
trainer.use_legacy_worker_impl=disable \
trainer.critic_warmup=0 \
trainer.logger='["console","tensorboard","file"]' \
trainer.project_name=$project_name \
trainer.experiment_name=$experiment_name \
trainer.default_local_dir=$train_dir \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=1 \
2>&1 | tee $train_dir/train_log.txt

1 change: 1 addition & 0 deletions tests/special_e2e/run_ppo_trainer_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.model.lora.alpha=${LORA_ALPHA} \
actor_rollout_ref.model.lora.target_modules=${LORA_TARGET_MODULES} \
actor_rollout_ref.model.lora.merge=${LORA_MERGE} \
+actor_rollout_ref.model.lora.fully_sharded_loras=True \
actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=$OPTIM_MEMORY_EFFICIENT \
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=$OPTIM_MEMORY_EFFICIENT \
Expand Down
2 changes: 2 additions & 0 deletions tests/special_sanity/check_license.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
license_head_sglang = "Copyright 2023-2024 SGLang Team"
license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates"
license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates"
license_head_amazon_26 = "Copyright 2026 Amazon.com Inc and/or its affiliates"
license_head_facebook = "Copyright (c) 2016- Facebook, Inc"
license_head_meituan = "Copyright 2025 Meituan Ltd. and/or its affiliates"
license_head_huawei = "Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved."
Expand All @@ -36,6 +37,7 @@
license_head_sglang,
license_head_modelbest,
license_head_amazon,
license_head_amazon_26,
license_head_facebook,
license_head_meituan,
license_head_huawei,
Expand Down
221 changes: 221 additions & 0 deletions tests/utils/test_fsdp_lora_merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
# Copyright 2026 Amazon.com Inc and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import pytest
import torch
import torch.distributed
import torch.multiprocessing as mp
from peft import LoraConfig, get_peft_model
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from transformers import AutoModelForCausalLM, GptOssConfig, Qwen2Config

from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
from verl.utils.fsdp_utils import (
MixedPrecisionPolicy,
apply_fsdp2,
get_fsdp_wrap_policy,
merged_lora_context,
)


def _test_merged_lora_context_worker(
    rank, world_size, rendezvous_file, strategy, model_config, lora_config_dict, backup_adapters
):
    """Worker function for testing merged_lora_context with FSDP.

    Builds a small causal-LM from ``model_config``, attaches LoRA adapters,
    wraps the model with FSDP (v1 or v2), then repeatedly enters
    ``merged_lora_context`` and asserts that (a) adapters are merged inside the
    context, (b) they are unmerged afterwards, and (c) the adapter weights are
    restored bit-close to their pre-merge values.

    Args:
        rank: Process rank.
        world_size: Total number of processes.
        rendezvous_file: Path to rendezvous file for distributed init.
        strategy: FSDP strategy ("fsdp" or "fsdp2").
        model_config: Model configuration object (Qwen2Config, GptOssConfig, etc.).
        lora_config_dict: Dictionary of LoRA configuration parameters.
        backup_adapters: Whether to backup adapter weights before merging.
    """
    # Function-scope peft imports, done once (the original imported LoraLayer twice).
    from peft.tuners.lora import LoraLayer
    from peft.utils.save_and_load import get_peft_model_state_dict

    get_torch_device().set_device(rank)
    torch.distributed.init_process_group(
        backend=get_nccl_backend(),
        init_method=f"file://{rendezvous_file}",
        rank=rank,
        world_size=world_size,
    )
    device_mesh = init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=("dp",))

    # Create model from provided config
    with torch.device(get_device_name()):
        model = AutoModelForCausalLM.from_config(
            config=model_config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
    model = model.to(device=get_device_name())

    # Add LoRA with provided config
    lora_config = LoraConfig(**lora_config_dict)
    model = get_peft_model(model, lora_config)

    # Initialize LoRA adapter weights to non-zero values so that merging
    # actually changes the base weights and restoration is observable.
    with torch.no_grad():
        for _, module in model.named_modules():
            if isinstance(module, LoraLayer):
                # Iterating lora_A.keys() guarantees membership in lora_A, so no
                # extra guard is needed there; lora_B may not hold every adapter
                # for all layer types, so keep that check.
                for adapter_name in module.lora_A.keys():
                    # Initialize lora_A with values around 1.0
                    module.lora_A[adapter_name].weight.data.uniform_(0.5, 1.5)
                    if adapter_name in module.lora_B:
                        # Initialize lora_B with values around 2.0
                        module.lora_B[adapter_name].weight.data.uniform_(1.5, 2.5)

    # Wrap model with FSDP (v1 with full sharding, or FSDP2 via apply_fsdp2)
    if strategy == "fsdp":
        mixed_precision = MixedPrecision(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, buffer_dtype=torch.float32
        )
        model = FSDP(
            model,
            use_orig_params=True,
            device_id=get_torch_device().current_device(),
            sharding_strategy=ShardingStrategy.FULL_SHARD,
            mixed_precision=mixed_precision,
            device_mesh=device_mesh,
            auto_wrap_policy=get_fsdp_wrap_policy(module=model, is_lora=True),
        )
    else:
        mp_policy = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True
        )
        fsdp_kwargs = {
            "mesh": device_mesh,
            "mp_policy": mp_policy,
        }
        apply_fsdp2(model, fsdp_kwargs, {})

    lora_layers = [m for m in model.modules() if isinstance(m, LoraLayer)]

    # Verify LoRA layers exist
    assert len(lora_layers) > 0, "Model should have LoRA layers"

    # Initially not merged
    for layer in lora_layers:
        assert not getattr(layer, "merged", False), "LoRA should not be merged initially"

    # Snapshot adapter weights before any merge so restoration can be checked.
    original_adapter_weights = get_peft_model_state_dict(model)

    # Enter/exit the context several times to catch state leaking across uses.
    for _ in range(3):
        with merged_lora_context(model, backup_adapters=backup_adapters):
            # Inside context, LoRA should be merged
            for layer in lora_layers:
                assert getattr(layer, "merged", False), "LoRA should be merged inside context"

        # After context, LoRA must be unmerged regardless of backup_adapters
        for layer in lora_layers:
            assert not getattr(layer, "merged", False), "LoRA should be unmerged after context"

        restored_adapter_weights = get_peft_model_state_dict(model)

        # Verify adapter weights are restored exactly
        for key in original_adapter_weights.keys():
            assert key in restored_adapter_weights, f"Key {key} should be in restored weights"
            torch.testing.assert_close(
                original_adapter_weights[key].cpu(),
                restored_adapter_weights[key].cpu(),
                rtol=1e-5,
                atol=1e-6,
                msg=f"Adapter weight {key} should be restored to original value",
            )

    if rank == 0:
        model_name = model_config.__class__.__name__
        backup_mode = "with backup" if backup_adapters else "without backup"
        print(f"merged_lora_context test with {model_name} {strategy} {backup_mode} passed on {world_size} GPUs!")

    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


@pytest.mark.parametrize("world_size", (2,))
@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
@pytest.mark.parametrize("backup_adapters", (True, False))
def test_merged_lora_context_qwen2(world_size, strategy, backup_adapters, tmp_path):
    """Test merged_lora_context with FSDP on Qwen2 model."""
    # Per-test rendezvous file; tmp_path is unique per parameterized invocation.
    rdzv_path = str(tmp_path / f"rdzv_file_qwen2_{backup_adapters}")
    os.makedirs(os.path.dirname(rdzv_path), exist_ok=True)

    # Tiny Qwen2 config to keep the test fast.
    qwen2_cfg = Qwen2Config(num_hidden_layers=2, num_attention_heads=2, hidden_size=128)

    # LoRA on the attention query/value projections only.
    qwen2_lora_kwargs = dict(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    mp.spawn(
        fn=_test_merged_lora_context_worker,
        args=(world_size, rdzv_path, strategy, qwen2_cfg, qwen2_lora_kwargs, backup_adapters),
        nprocs=world_size,
        join=True,
    )


@pytest.mark.parametrize("world_size", (2,))
@pytest.mark.parametrize("strategy", ("fsdp", "fsdp2"))
@pytest.mark.parametrize("backup_adapters", (True, False))
def test_merged_lora_context_gptoss(world_size, strategy, backup_adapters, tmp_path):
    """Test merged_lora_context with FSDP on GPT-OSS model."""
    # Per-test rendezvous file; tmp_path is unique per parameterized invocation.
    rdzv_path = str(tmp_path / f"rdzv_file_gptoss_{backup_adapters}")
    os.makedirs(os.path.dirname(rdzv_path), exist_ok=True)

    # Tiny GPT-OSS config to keep the test fast.
    gptoss_cfg = GptOssConfig(
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=2,
        hidden_size=128,
        intermediate_size=256,
    )

    # LoRA on all linear layers plus expert parameters, excluding the router.
    gptoss_lora_kwargs = dict(
        r=8,
        lora_alpha=16,
        target_modules="all-linear",
        target_parameters=["mlp.experts.gate_up_proj", "mlp.experts.down_proj"],
        exclude_modules=["mlp.router"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )

    mp.spawn(
        fn=_test_merged_lora_context_worker,
        args=(world_size, rdzv_path, strategy, gptoss_cfg, gptoss_lora_kwargs, backup_adapters),
        nprocs=world_size,
        join=True,
    )
Loading
Loading