2 changes: 1 addition & 1 deletion .github/workflows/e2e_ppo_trainer_megatron_vllm.yml
@@ -145,7 +145,7 @@ jobs:
- name: clean up and install Megatron-Bridge
run: |
rm -rf checkpoints
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@550924c --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@83a7c11 --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@5455f0a --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron, use Megatron-Bridge LoRA e2e to pre-load and save (Deepseek)
4 changes: 2 additions & 2 deletions .github/workflows/e2e_ppo_trainer_megatron_vllm_2.yml
@@ -124,7 +124,7 @@ jobs:
run: |
pip3 install -r requirements-test.txt
pip3 install --no-deps --force-reinstall .
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@550924c --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA-NeMo/Megatron-Bridge.git@83a7c11 --no-deps --no-build-isolation
pip3 install git+https://github.com/NVIDIA/Megatron-LM.git@5455f0a --no-deps --no-build-isolation
pip3 install "nvidia-modelopt[torch]>=0.37.0" transformers==4.57.1
- name: Prepare GSM8K dataset
@@ -159,7 +159,7 @@ jobs:
MAX_PROMPT_LENGTH=512 MAX_RESPONSE_LENGTH=512 LORA_RANK=8 CRITIC_LORA_RANK=8 \
MODEL_ID=Qwen/Qwen3-30B-A3B-Instruct-2507 USE_MBRIDGE=True VANILLA_MBRIDGE=False VALUE_VANILLA_MBRIDGE=False \
COMMON_PP=2 COMMON_VPP=null COMMON_CP=1 COMMON_TP=4 COMMON_EP=2 COMMON_ETP=1 INFER_TP=8 \
USE_DIST_CKPT=False ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh
USE_DIST_CKPT=False LORA_MERGE=True ALL_OFFLOAD=True SKIP_SAVE_HF_MODEL=1 bash tests/special_e2e/run_ppo_trainer_megatron.sh
- name: clean up
run: |
rm -rf checkpoints
21 changes: 14 additions & 7 deletions docs/advance/ppo_lora.rst
@@ -62,7 +62,7 @@ Megatron Backend Usage Guide

You need to install and enable Megatron-Bridge for Megatron LoRA support.

Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:
Make sure you use a Megatron-Bridge version later than 0.2.0; we recommend using `this commit <https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44>`_ or later for proper support, and use the following settings to enable Megatron-Bridge:

- ``actor_rollout_ref.actor.megatron.use_mbridge=True``
- ``actor_rollout_ref.actor.megatron.vanilla_mbridge=False``
@@ -71,7 +71,7 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th

1. **LoRA Implementation**: Verl Megatron backend uses Megatron-Bridge's native LoRA implementation, which differs from HuggingFace PEFT.

2. **Weight Sync Mechanism**: Currently, Megatron-Bridge syncs weights by merging LoRA adapters into the base model weights before transferring to vLLM rather than loading separate adapters. This is necessary because Megatron-Bridge's LoRA format is not directly integratable with vLLM's LoRA loading mechanism (HF PEFT format), and LoRA bridge is not yet supported.
2. **Weight Sync / Refit Mechanism**: Currently, Megatron-Bridge supports syncing weights either by merging LoRA adapters into the base model weights before transferring to vLLM (faster inference, but longer refit time and potential precision loss) or by loading separate adapters.
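The merge path is the standard LoRA fold, ``W' = W + (alpha / rank) * (B @ A)``. A minimal pure-Python sketch with illustrative shapes (not Megatron-Bridge's actual code):

```python
# Pure-Python sketch of the LoRA merge: W' = W + (alpha / rank) * (B @ A).
def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

W = [[1.0, 2.0], [3.0, 4.0]]  # base weight, shape (2, 2)
A = [[0.5, -0.5]]             # LoRA down projection, shape (rank=1, 2)
B = [[2.0], [0.0]]            # LoRA up projection, shape (2, rank=1)
alpha, rank = 4, 1
scaling = alpha / rank

BA = matmul(B, A)             # shape (2, 2)
W_merged = [[W[i][j] + scaling * BA[i][j] for j in range(2)] for i in range(2)]
print(W_merged)  # [[5.0, -2.0], [3.0, 4.0]]
```

After the fold, vLLM serves a single dense weight per layer, which is why inference is faster but each refit pays the merge cost.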

**Configuration for Megatron LoRA:**

@@ -83,6 +83,9 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
type: lora

# Whether to merge LoRA adapters into the base model weights before transferring to vLLM during weight sync / refit (faster inference, but longer refit time and potential precision loss). If False, separate adapters are loaded instead.
merge: False

# LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA
rank: 0

@@ -101,6 +104,11 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
# - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP
# Target modules can also contain wildcards. For example, you can specify
# target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers
#
# Note:
# For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"]
# instead of "linear_qkv" or ["linear_q","linear_k","linear_v"].
# By default, MoE routers are excluded from LoRA adaptation, and you will need to specify "router" in target_modules to include them.
target_modules:
- linear_qkv
- linear_proj
@@ -136,12 +144,11 @@ Make sure you use Megatron-Bridge later than 0.2.0, and we recommended using `th
freeze_vision_projection: True
freeze_language_model: True

A LoRA training experiment with Qwen3-8B on a single node with 8 H200 GPUs, comparing the FSDP and Megatron backends (script adapted from examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh):

**Current Limitations:**

1. **No HuggingFace PEFT Export**: Currently there is no built-in way to export Megatron LoRA adapters to HuggingFace PEFT format for inference with standard HF/vLLM pipelines; such support is coming soon with Megatron-Bridge `LoRA bridge <https://github.com/NVIDIA-NeMo/Megatron-Bridge/issues/1536>`_.

2. **LoRA Merge Overhead**: As there is no LoRA bridge yet, each weight sync (refit) requires merging LoRA weights, which adds overhead compared to direct dynamic adapter loading.
.. image:: https://github.com/user-attachments/assets/0482f423-01a3-4e52-a7ee-8b9cd79b7b1a
.. image:: https://github.com/user-attachments/assets/6ce10400-8164-47d8-90a6-c1bf002fb9e8
.. image:: https://github.com/user-attachments/assets/092d3a43-4eba-425e-a584-8d83c1f02de4


Best Practices and Notes
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen2-7b_math_megatron_lora.sh
@@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
2 changes: 1 addition & 1 deletion examples/grpo_trainer/run_qwen3moe-30b_megatron_lora.sh
@@ -3,7 +3,7 @@ set -xeuo pipefail

# Need to install Megatron-Bridge
# NOTE: Make sure you use Megatron-Bridge later than 0.2.0
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/550924c04368a175ef261a72230204410f455260 or later)
# (Recommend https://github.com/NVIDIA-NeMo/Megatron-Bridge/commit/83a7c1134c562d8c6decd10a1f0a6e6a7a8a3a44 or later)
# for proper MoE LoRA support.

# For Megatron communication/computation overlapping
2 changes: 2 additions & 0 deletions tests/special_e2e/run_ppo_trainer_megatron.sh
@@ -55,6 +55,7 @@ LORA_RANK=${LORA_RANK:-0}
CRITIC_LORA_RANK=${CRITIC_LORA_RANK:-$LORA_RANK}
LORA_ALPHA=${LORA_ALPHA:-${LORA_RANK}}
LORA_TARGET_MODULES=${LORA_TARGET_MODULES:-"['linear_qkv','linear_proj','linear_fc1','linear_fc2']"}
LORA_MERGE=${LORA_MERGE:-False}

MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512}
MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-512}
@@ -163,6 +164,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.model.lora.rank=${LORA_RANK} \
actor_rollout_ref.model.lora.alpha=${LORA_ALPHA} \
actor_rollout_ref.model.lora.target_modules=${LORA_TARGET_MODULES} \
actor_rollout_ref.model.lora.merge=${LORA_MERGE} \
actor_rollout_ref.actor.optim.lr_warmup_steps=$LR_WARMUP_STEPS \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=$OPTIM_MEMORY_EFFICIENT \
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=$OPTIM_MEMORY_EFFICIENT \
1 change: 1 addition & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -341,6 +341,7 @@ actor_rollout_ref:
num_speculative_tokens: 1
lora:
type: lora
merge: false
rank: 0
alpha: 32
dropout: 0.0
5 changes: 5 additions & 0 deletions verl/trainer/config/critic/megatron_critic.yaml
@@ -59,6 +59,11 @@ model:
# - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP
# Target modules can also contain wildcards. For example, you can specify
# target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers
#
# Note:
# For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"]
# instead of "linear_qkv" or ["linear_q","linear_k","linear_v"].
# By default, MoE routers are excluded from LoRA adaptation, and you will need to specify "router" in target_modules to include them.
target_modules:
- linear_qkv
- linear_proj
8 changes: 8 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -46,6 +46,9 @@ actor_rollout_ref:
# LoRA type: "lora", "vlm_lora", "canonical_lora", or "dora"
type: lora

# Whether to merge LoRA adapters into the base model weights before transferring to vLLM during weight sync / refit (faster inference, but longer refit time and potential precision loss). If False, separate adapters are loaded instead.
merge: False

# LoRA rank (Dimension of the low-rank projection space.). Set to 0 to disable LoRA
rank: 0 # typical values: 8, 16, 32, 64

@@ -64,6 +67,11 @@ actor_rollout_ref:
# - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP
# Target modules can also contain wildcards. For example, you can specify
# target_modules=['*.layers.0.*.linear_qkv', '*.layers.1.*.linear_qkv'] to add LoRA to only linear_qkv on the first two layers
#
# Note:
# For MLA (e.g., DeepSeek), you should use ["linear_kv_down_proj","linear_kv_up_proj","linear_q_down_proj","linear_q_up_proj","linear_q_proj"]
# instead of "linear_qkv" or ["linear_q","linear_k","linear_v"].
# By default, MoE routers are excluded from LoRA adaptation, and you will need to specify "router" in target_modules to include them.
target_modules:
- linear_qkv
- linear_proj
1 change: 1 addition & 0 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
@@ -538,6 +538,7 @@ def save_checkpoint(self, local_path: str, hdfs_path: str = None, global_step: i
"grad_sync_func",
"param_sync_func",
"generation_config",
"_pg_collection",
]
backup = {}
for k in bypass_keys:
12 changes: 10 additions & 2 deletions verl/utils/config.py
@@ -199,7 +199,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
)

# check LoRA rank in vLLM
if config.actor_rollout_ref.model.get("lora_rank", 0) > 0 and config.actor_rollout_ref.rollout.name == "vllm":
assert config.actor_rollout_ref.model.lora_rank <= 512, "LoRA rank in vLLM must be less than or equal to 512"
lora_config = config.actor_rollout_ref.model.get("lora", {})
lora_rank = lora_config.get("rank", 0)
if lora_config.get("merge", False):
lora_rank = 0
if lora_rank <= 0:
lora_rank = config.actor_rollout_ref.model.get("lora_rank", 0)
if lora_rank > 0 and config.actor_rollout_ref.rollout.name == "vllm":
from verl.workers.rollout.vllm_rollout.utils import get_vllm_max_lora_rank

get_vllm_max_lora_rank(lora_rank)

print("[validate_config] All configuration checks passed successfully!")
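The rank-resolution precedence in the check above can be sketched as a standalone helper (hypothetical name, mirroring the logic shown; not verl's actual API):

```python
def resolve_vllm_lora_rank(model_cfg: dict) -> int:
    """Sketch of the precedence above: nested lora.rank wins, but merge=True
    zeroes it (merged weights need no vLLM adapter slot), and the legacy flat
    lora_rank field is the final fallback."""
    lora_cfg = model_cfg.get("lora", {})
    rank = lora_cfg.get("rank", 0)
    if lora_cfg.get("merge", False):
        rank = 0  # adapters are merged into base weights before refit
    if rank <= 0:
        rank = model_cfg.get("lora_rank", 0)  # legacy flat field
    return rank

print(resolve_vllm_lora_rank({"lora": {"rank": 8}}))                 # 8
print(resolve_vllm_lora_rank({"lora": {"rank": 8, "merge": True}}))  # 0
print(resolve_vllm_lora_rank({"lora_rank": 16}))                     # 16
```

Only a resolved rank greater than 0 triggers the vLLM max-LoRA-rank check.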
120 changes: 120 additions & 0 deletions verl/utils/megatron_peft_utils.py
@@ -15,9 +15,53 @@

import os
from pathlib import Path
from typing import Iterator

import torch

# Map megatron lora target modules to HF-style module names for vLLM
Review thread:
- Collaborator: This is only advice, not a necessity: can we move the LoRA mapping to Megatron-Bridge instead of keeping it in verl?
- Collaborator (author): Yes, I think this is under discussion, and I guess it would need some API design.
- Collaborator: Adding @yaoyu-33 for visibility. We need to extend the bridge APIs with the capability of exporting LoRA weights.
MEGATRON_TO_HF_MODULES = {
"linear_qkv": ["q_proj", "k_proj", "v_proj"],
"linear_proj": ["o_proj"],
"linear_fc1": ["gate_proj", "up_proj"],
"linear_fc2": ["down_proj"],
"router": ["gate"],
# Canonical LoRA mappings
"linear_q": ["q_proj"],
"linear_k": ["k_proj"],
"linear_v": ["v_proj"],
"linear_fc1_up": ["up_proj"],
"linear_fc1_gate": ["gate_proj"],
# MLA mappings
"linear_kv_down_proj": ["kv_a_proj_with_mqa"],
"linear_kv_up_proj": ["kv_b_proj"],
"linear_q_down_proj": ["q_a_proj"],
"linear_q_up_proj": ["q_b_proj"],
"linear_q_proj": ["q_proj"],
}

# Modules with stacked parameters that need .base_layer suffix in vLLM
STACKED_PARAMS = [
".q_proj.weight",
".q_proj.bias",
".k_proj.weight",
".k_proj.bias",
".v_proj.weight",
".v_proj.bias",
".o_proj.weight",
".o_proj.bias",
".gate_proj.weight",
".up_proj.weight",
".down_proj.weight",
".mlp.gate.weight",
".mlp.gate.bias",
".mlp.gate.e_score_correction_bias",
".kv_a_proj_with_mqa.weight",
".kv_b_proj.weight",
".q_a_proj.weight",
".q_b_proj.weight",
]


def _get_rank_checkpoint_path(base_path: str) -> str:
"""Get rank-specific checkpoint path following Megatron's convention.
@@ -224,10 +268,86 @@ def print_adapter_info(model):
print(f"{'=' * 60}\n")


def convert_megatron_to_hf_target_modules(megatron_modules: list[str]) -> list[str]:
"""Convert megatron lora target modules to HF-style module names.

Args:
megatron_modules: List of megatron-style module names.

Returns:
List of HF-style module names with duplicates removed.
"""
hf_target_modules = []
for module in megatron_modules:
if module in MEGATRON_TO_HF_MODULES:
hf_target_modules.extend(MEGATRON_TO_HF_MODULES[module])
else:
hf_target_modules.append(module)
# Remove duplicates while preserving order
return list(dict.fromkeys(hf_target_modules))
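A usage sketch of the conversion (mapping abridged to two entries from the table above; unknown names pass through, and duplicates are dropped while preserving first-seen order):

```python
MEGATRON_TO_HF_MODULES = {  # abridged copy of the mapping above
    "linear_qkv": ["q_proj", "k_proj", "v_proj"],
    "linear_proj": ["o_proj"],
}

def convert_megatron_to_hf_target_modules(megatron_modules):
    hf = []
    for m in megatron_modules:
        hf.extend(MEGATRON_TO_HF_MODULES.get(m, [m]))
    return list(dict.fromkeys(hf))  # dedupe, preserve first-seen order

mods = convert_megatron_to_hf_target_modules(["linear_qkv", "linear_proj", "linear_qkv"])
print(mods)  # ['q_proj', 'k_proj', 'v_proj', 'o_proj']
```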


def build_peft_config_for_vllm(lora_config: dict) -> dict:
"""Build a peft_config dict compatible with vLLM's PEFTHelper from megatron lora config.

Args:
lora_config: Megatron lora configuration dictionary.

Returns:
A dictionary compatible with vLLM's PEFTHelper.from_dict().
"""
from peft import TaskType

target_modules = lora_config.get("target_modules", ["linear_qkv", "linear_proj", "linear_fc1", "linear_fc2"])
exclude_modules = lora_config.get("exclude_modules", [])
hf_target_modules = convert_megatron_to_hf_target_modules(target_modules)
hf_exclude_modules = convert_megatron_to_hf_target_modules(exclude_modules)

return {
"task_type": TaskType.CAUSAL_LM,
"r": lora_config.get("rank", 0),
"lora_alpha": lora_config.get("alpha", 32),
"target_modules": hf_target_modules,
"exclude_modules": hf_exclude_modules,
"bias": "none",
"lora_dropout": lora_config.get("dropout", 0.0),
}
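For illustration, the shape of the resulting dict can be sketched as below. This replicates the logic above with two assumptions so it runs without peft installed: the `TaskType.CAUSAL_LM` enum is stood in by its string value, and the module mapping is abridged.

```python
def build_peft_config_sketch(lora_config: dict) -> dict:
    # "CAUSAL_LM" stands in for peft's TaskType.CAUSAL_LM enum (sketch assumption)
    target = lora_config.get("target_modules", ["linear_qkv", "linear_proj"])
    mapping = {"linear_qkv": ["q_proj", "k_proj", "v_proj"], "linear_proj": ["o_proj"]}
    hf_targets = list(dict.fromkeys(m2 for m in target for m2 in mapping.get(m, [m])))
    return {
        "task_type": "CAUSAL_LM",
        "r": lora_config.get("rank", 0),
        "lora_alpha": lora_config.get("alpha", 32),
        "target_modules": hf_targets,
        "bias": "none",
        "lora_dropout": lora_config.get("dropout", 0.0),
    }

cfg = build_peft_config_sketch({"rank": 8, "alpha": 16, "target_modules": ["linear_qkv"]})
print(cfg["r"], cfg["target_modules"])  # 8 ['q_proj', 'k_proj', 'v_proj']
```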


# vLLM needs to target all linear layers regardless of the specific LoRA config
def add_base_layer_suffix(
params: Iterator[tuple[str, torch.Tensor]],
model_type: str,
) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield param pairs with a base-layer suffix added to the param name.

Args:
params: Iterator of (param_name, tensor)
model_type: The type of the model (e.g., "llama").
"""
stacked_params = STACKED_PARAMS
    # TODO: other models may need more special handling; consider integrating this into Megatron-Bridge
if model_type == "llama":
stacked_params = [".embed_tokens.weight", *STACKED_PARAMS]
for name, param in params:
ending_suffix = ""
for suffix in stacked_params:
if name.endswith(suffix):
ending_suffix = suffix
break
if ending_suffix:
suffix = ending_suffix.rsplit(".", 1)[-1]
name = f"{name[: -len(suffix)]}base_layer.{suffix}"
yield name, param
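A usage sketch of the renaming (suffix list abridged to two entries; tensors stubbed with None):

```python
STACKED_PARAMS = [".q_proj.weight", ".o_proj.weight"]  # abridged from the list above

def add_base_layer_suffix_sketch(params):
    for name, param in params:
        for suffix in STACKED_PARAMS:
            if name.endswith(suffix):
                leaf = suffix.rsplit(".", 1)[-1]           # e.g. "weight"
                name = f"{name[: -len(leaf)]}base_layer.{leaf}"
                break
        yield name, param

renamed = list(add_base_layer_suffix_sketch([
    ("model.layers.0.self_attn.q_proj.weight", None),
    ("model.embed_tokens.weight", None),                   # not stacked: unchanged
]))
print([n for n, _ in renamed])
# ['model.layers.0.self_attn.q_proj.base_layer.weight', 'model.embed_tokens.weight']
```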


__all__ = [
"get_adapter_state_dict",
"save_adapter_checkpoint",
"load_adapter_checkpoint",
"count_adapter_parameters",
"print_adapter_info",
"convert_megatron_to_hf_target_modules",
"build_peft_config_for_vllm",
"add_base_layer_suffix",
]
16 changes: 13 additions & 3 deletions verl/workers/engine/megatron/transformer_impl.py
@@ -34,6 +34,7 @@
from verl.utils.device import get_device_id, get_device_name
from verl.utils.megatron.pipeline_parallel import make_batch_generator
from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy, vocab_parallel_log_probs_from_logits
from verl.utils.megatron_peft_utils import add_base_layer_suffix, build_peft_config_for_vllm
from verl.utils.megatron_utils import (
get_megatron_module_device,
load_megatron_model_to_gpu,
@@ -524,14 +525,23 @@ def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forw
else:
return {}

def get_per_tensor_param(self, **kwargs):
def get_per_tensor_param(self, base_sync_done=False, **kwargs):
load_megatron_model_to_gpu(self.module, load_grad=False)
peft_config = None
non_merge_lora_sync = self.peft_cls is not None and not self.model_config.lora.get("merge", False)
if self.vanilla_bridge:
per_tensor_param = self.bridge.export_weights(self.module)
elif base_sync_done and non_merge_lora_sync:
# Only export adapter weights
peft_config = build_peft_config_for_vllm(self.model_config.lora)
per_tensor_param = self.bridge.export_adapter_weights(self.module)
else:
per_tensor_param = self.bridge.export_hf_weights(self.module)
# TODO: support megatron LoRA
return per_tensor_param, None
if non_merge_lora_sync:
per_tensor_param = add_base_layer_suffix(
per_tensor_param, model_type=self.model_config.hf_config.model_type
)
return per_tensor_param, peft_config

def disable_adapter(self) -> ContextManager:
return self.peft_cls.disable_adapter(self.module)