Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ policy:
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null
# LoRA (Low-Rank Adaptation) Configuration
lora_cfg:
enabled: False # Set to True to enable LoRA fine-tuning
target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers)
exclude_modules: [] # List of module names to exclude from LoRA
match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64
dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1

megatron_cfg:
enabled: false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defaults: ../../grpo_math_1B.yaml
grpo:
val_at_start: true
checkpointing:
checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora
policy:
model_name: Qwen/Qwen3-8B-Base
max_total_sequence_length: 2048
dtensor_cfg:
activation_checkpointing: true
lora_cfg:
enabled: True
dim: 128
alpha: 128
sequence_packing:
enabled: false
logger:
log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora
wandb_enabled: true
tensorboard_enabled: true
wandb:
project: nemo-rl
name: grpo-qwen3-8B-base-1n8g-fsdp2-lora
cluster:
gpus_per_node: 8
72 changes: 59 additions & 13 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,23 @@ def setup(
)
policy_config["megatron_cfg"]["train_iters"] = total_train_iters

if "dtensor_cfg" in policy_config and policy_config["dtensor_cfg"]["enabled"]:
lora_cfg = (
policy_config["dtensor_cfg"]["lora_cfg"]
if "lora_cfg" in policy_config["dtensor_cfg"]
else None
)
if "enabled" in lora_cfg and lora_cfg["enabled"]:
# Override the vLLM lora config with the DTensor lora config
generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be okay for this PR, but do you have any ideas how to unify it when supporting grpo lora in mcore?
also cc @vadam5


assert colocated_inference, (
"LoRA in DTensor backend is only supported with colocated inference."
)
assert not _should_use_async_rollouts(master_config), (
"Async rollouts are not supported with LoRA in DTensor backend."
)

# Define initialization functions that will be used in all paths
def init_policy():
"""Initialize policy training workers."""
Expand Down Expand Up @@ -505,6 +522,9 @@ def init_vllm():
assert loss_config["use_importance_sampling_correction"] is True, (
"Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
)
assert not policy_config["dtensor_cfg"]["lora_cfg"]["enabled"], (
"LoRA is not supported with vLLM FP8 generation."
)
if generation_config["vllm_cfg"]["kv_cache_dtype"].startswith("fp8"):
# FP8 KV cache requires FP8 model precision
assert generation_config["vllm_cfg"]["precision"] == "fp8", (
Expand Down Expand Up @@ -933,18 +953,11 @@ def refit_policy_generation(
timer: Optional Timer used to time the prepare/transfer/update phase
kv_scales: Optional dictionary of KV cache scales for FP8 quantization.
"""
if colocated_inference:
policy.offload_before_refit()
policy_generation.prepare_for_generation(tags=["weights"])

# Create a context manager that does nothing when timer is None
timer_context = (
timer.time("prepare_for_generation/transfer_and_update_weights")
if timer is not None
else nullcontext()
)
with timer_context:
# update weights
def _perform_refit_weights(refit_mode: str):
assert refit_mode in ("base_model", "lora"), (
"refit_mode must be either 'base_model' or 'lora'"
)
update_success = False
if colocated_inference:
# get model param keys, which is grouped by size
Expand All @@ -959,9 +972,13 @@ def refit_policy_generation(
)

futures_train = policy.stream_weights_via_ipc_zmq(
buffer_size_bytes=buffer_size_bytes, kv_scales=kv_scales
buffer_size_bytes=buffer_size_bytes,
kv_scales=kv_scales,
refit_mode=refit_mode,
)
futures_inference = policy_generation.update_weights_via_ipc_zmq(
refit_mode=refit_mode,
)
futures_inference = policy_generation.update_weights_via_ipc_zmq()
# wait for all futures to complete
ray.get(futures_train)
results = ray.get(futures_inference)
Expand All @@ -985,6 +1002,35 @@ def refit_policy_generation(
)
raise RuntimeError(error_message)

lora_enabled, lora_base_refit_done = policy.check_lora_base_refit_done()
refit_lora_weights = lora_enabled
refit_base_model_weights = not lora_enabled or not lora_base_refit_done

if colocated_inference:
policy.offload_before_refit()
policy_generation.prepare_for_generation(tags=["weights"])

# Create a context manager that does nothing when timer is None
timer_context = (
timer.time("prepare_for_generation/transfer_and_update_weights")
if timer is not None
else nullcontext()
)
with timer_context:
if refit_base_model_weights:
_perform_refit_weights(refit_mode="base_model")
print(
" ▶ Refitting base model weights...",
flush=True,
)

if refit_lora_weights:
_perform_refit_weights(refit_mode="lora")
print(
" ▶ Refitting LoRA weights...",
flush=True,
)

if colocated_inference:
policy.offload_after_refit()
policy_generation.prepare_for_generation(tags=["kv_cache"])
Expand Down
6 changes: 4 additions & 2 deletions nemo_rl/models/generation/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, NotRequired, TypedDict, Union
from typing import Any, NotRequired, Optional, TypedDict, Union

import ray
import torch
Expand Down Expand Up @@ -245,7 +245,9 @@ def prepare_refit_info(self, state_dict_info: dict[str, Any]) -> None:
"""Prepare the info for refit."""
raise NotImplementedError

def update_weights_via_ipc_zmq(self) -> list[ray.ObjectRef]:
def update_weights_via_ipc_zmq(
self, refit_mode: Optional[str] = "base_model"
) -> list[ray.ObjectRef]:
"""Update the model weights from the given IPC handles."""
raise NotImplementedError

Expand Down
211 changes: 211 additions & 0 deletions nemo_rl/models/generation/vllm/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Comment on lines +1 to +13
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add file to pyrefly.toml whitelist.

The pipeline failure indicates this new file needs to be added to the pyrefly.toml project-includes list.

🧰 Tools
🪛 GitHub Actions: CICD NeMo RL

[error] 1-1: File nemo_rl/models/generation/lora.py has zero errors but is not in pyrefly.toml in the 'project-includes' list. Please add it to this whitelist.

🤖 Prompt for AI Agents
In @nemo_rl/models/generation/lora.py around lines 1 - 13, The new module
nemo_rl/models/generation/lora.py is missing from the pyrefly.toml whitelist and
must be added to the project-includes so CI stops failing; open pyrefly.toml and
add the relative path "nemo_rl/models/generation/lora.py" (or the appropriate
glob such as "nemo_rl/models/generation/*.py") to the project-includes array,
save the file, and re-run the pipeline to verify the whitelist change fixes the
failure.



from typing import Any, Optional

import torch
import vllm
from torch import nn
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase


class LoRARequestWithCfgAndWeights(LoRARequest):
    """``LoRARequest`` variant carrying the adapter config and weights in memory.

    When both fields are set, the patched ``_load_adapter`` skips the
    from-disk path entirely: ``lora_cfg`` is fed to ``PEFTHelper.from_dict``
    and ``lora_weights`` to ``from_lora_tensors``.
    """

    # PEFT-style adapter config dict, consumed by PEFTHelper.from_dict.
    lora_cfg: Optional[dict] = None
    # Parameter-name -> LoRA tensor mapping, loaded directly from memory.
    lora_weights: Optional[dict[str, Any]] = None


def patched_load_adapter(self, lora_request: LoRARequestWithCfgAndWeights):
    """Load a LoRA adapter, supporting in-memory config and weights.

    Patched replacement for ``LRUCacheWorkerLoRAManager._load_adapter``.
    Upstream vLLM always reads the adapter from disk; this version adds a
    fast path: when *lora_request* is a ``LoRARequestWithCfgAndWeights`` the
    PEFT config comes from ``lora_request.lora_cfg`` and the tensors from
    ``lora_request.lora_weights`` with no filesystem access. A plain
    ``LoRARequest`` falls back to the original from-disk loading.

    Returns:
        The constructed LoRA model object.

    Raises:
        ValueError: if no adapter is found at ``lora_request.lora_path``, or
            the adapter's extra vocab size exceeds
            ``lora_config.lora_extra_vocab_size``.
    """
    try:
        supported_lora_modules = self._adapter_manager.supported_lora_modules
        packed_modules_mapping = self._adapter_manager.packed_modules_mapping
        # Expand packed modules (e.g. qkv_proj) into their member names so we
        # know which LoRA tensors to expect in a checkpoint.
        expected_lora_lst: list[str] = []
        for module in supported_lora_modules:
            if module in packed_modules_mapping:
                expected_lora_lst.extend(packed_modules_mapping[module])
            else:
                expected_lora_lst.append(module)
            # NOTE(review): reconstructed placement — "experts" is appended at
            # loop level as in upstream vLLM; confirm against the original diff.
            if module == "experts":
                expected_lora_lst.append(module)
        expected_lora_modules = set(expected_lora_lst)
        lora_weights = None

        if isinstance(lora_request, LoRARequestWithCfgAndWeights):
            # In-memory path: build the PEFT helper straight from the dict.
            lora_cfg = lora_request.lora_cfg
            lora_weights = lora_request.lora_weights
            peft_helper = PEFTHelper.from_dict(lora_cfg)
        else:
            # From-disk path (original vLLM behavior).
            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path,
                self.max_position_embeddings,
                lora_request.tensorizer_config_dict,
            )

        # Validates the LoRA configuration against requirements before
        # loading weights, throwing an exception if validation fails.
        peft_helper.validate_legal(self.lora_config)

        # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
        # to ensure correct loading of lora weights.
        model = self._adapter_manager.model
        hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
        if isinstance(lora_request, LoRARequestWithCfgAndWeights):
            # Build the LoRA model directly from in-memory tensors.
            lora = self._lora_model_cls.from_lora_tensors(
                lora_model_id=lora_request.lora_int_id,
                tensors=lora_weights,
                peft_helper=peft_helper,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                embeddings=None,
                target_embedding_padding=self.vocab_size
                + self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper,
            )
        else:
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size
                + self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                weights_mapper=hf_to_vllm_mapper,
            )

    except FileNotFoundError as e:
        # FileNotFoundError should be raised if both
        # - No adapter found to download from huggingface (or in
        #   offline mode)
        # - No local adapter files found at `lora_request.lora_path`
        # For NotFoundError
        raise ValueError(
            f"Loading lora {lora_request.lora_name} failed: No adapter "
            f"found for {lora_request.lora_path}"
        ) from e
    except Exception as e:
        # For BadRequestError
        raise e

    if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
        raise ValueError(
            f"LoRA added vocab size {lora.extra_vocab_size} is greater than lora_extra_vocab_size "
            f"{self.lora_config.lora_extra_vocab_size}."
        )
    return lora


def patched_get_supported_lora_modules(model: nn.Module) -> list[str]:
    """Collect module-name suffixes that support LoRA, skipping ``lm_head``.

    In vLLM, all linear layers support LoRA, but Automodel does not apply LoRA
    to ``lm_head``, so it is excluded here to keep both sides consistent.
    Refer to https://github.com/NVIDIA-NeMo/Automodel/blob/50253d14c2aefa2206036022b4ccce9f3476ba4d/nemo_automodel/components/_peft/module_matcher.py#L99 for more details.

    Args:
        model: Model whose named modules are scanned.

    Returns:
        Unordered list of unique module-name suffixes (e.g. ``"q_proj"``)
        that may carry LoRA adapters.
    """
    supported_lora_modules: set[str] = set()
    for name, module in model.named_modules():
        # get the embedding modules if the module's embedding_modules
        # is not empty.
        embedding_modules = getattr(module, "embedding_modules", None)
        if embedding_modules is not None:
            # BUGFIX: use a separate loop variable — the original iterated
            # `for name in embedding_modules`, clobbering the `name` bound by
            # named_modules() that is still needed for the linear check below.
            for emb_name in embedding_modules:
                if "lm_head" in emb_name:
                    continue
                supported_lora_modules.add(emb_name)

        # get all the linear subfixes: every linear (and fused-MoE) layer
        # supports LoRA, keyed by its final attribute-name component.
        if isinstance(module, (LinearBase, FusedMoE)):
            supported_lora_modules.add(name.split(".")[-1])

    return list(supported_lora_modules)


def apply_lora_patches():
    """Monkey-patch vLLM's LoRA machinery to support in-memory adapters.

    Installs three patches:
      1. ``vllm.lora.utils.get_supported_lora_modules`` -> our version that
         skips ``lm_head`` (Automodel does not LoRA it).
      2. The same function reference already imported into
         ``vllm.lora.models``.
      3. ``LRUCacheWorkerLoRAManager._load_adapter`` -> version that accepts
         ``LoRARequestWithCfgAndWeights`` (config/weights from memory).

    Raises:
        AssertionError: if the installed vLLM is not 0.11.x.
    """
    # BUGFIX: verify the vLLM version BEFORE mutating anything, so a version
    # mismatch cannot leave vLLM in a partially patched state (the original
    # applied two patches before asserting).
    assert vllm.__version__.startswith("0.11."), (
        "vLLM version must be == 0.11.x to apply the patches. "
        "If this assertion fails, please check the vLLM version and remove the patching on condition. "
        "You can:\n"
        "1. Check whether vllm support load lora from memory.\n"
        "2. If yes, remove the patching call\n"
        "3. Delete this assertion\n"  # BUGFIX: missing "\n" fused items 3 and 4
        "4. Delete this patch: patched_load_adapter"
    )

    # patch the get_supported_lora_modules function
    import vllm.lora.utils as lora_utils

    setattr(
        lora_utils, "get_supported_lora_modules", patched_get_supported_lora_modules
    )

    # patch the get_supported_lora_modules function in lora_models
    # (it holds its own reference from `from ... import ...`).
    import vllm.lora.models as lora_models

    setattr(
        lora_models, "get_supported_lora_modules", patched_get_supported_lora_modules
    )

    # patch the load_adapter function in LRUCacheWorkerLoRAManager
    from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager

    setattr(LRUCacheWorkerLoRAManager, "_load_adapter", patched_load_adapter)


def apply_weight_name_mapping(
    weights: list[tuple[str, torch.Tensor]],
    supported_modules: list[str],
    packed_modules_mapping: dict[str, list[str]],
) -> list[tuple[str, torch.Tensor]]:
    """Rewrite base-model weight names into their LoRA ``base_layer`` form.

    For every parameter whose module supports LoRA — either directly or as a
    member of a packed module (e.g. ``q_proj`` under ``qkv_proj``) — a
    ``base_layer`` segment is inserted before the trailing field name
    (``weight``/``bias``). All other names pass through unchanged.
    ``lm_head`` weights are never remapped: vLLM attaches lm_head to the
    logits processor (https://github.com/vllm-project/vllm/blob/b8b302cde434df8c9289a2b465406b47ebab1c2d/vllm/lora/models.py#L506).
    """
    # Reverse lookup: member module name -> packed module name. setdefault
    # keeps first-match semantics if a member appears under several packs.
    member_to_packed: dict[str, str] = {}
    for packed, members in packed_modules_mapping.items():
        for member in members:
            member_to_packed.setdefault(member, packed)

    def _rename(param_name: str) -> str:
        if "lm_head" in param_name:
            return param_name
        pieces = param_name.split(".")
        if len(pieces) < 2:
            # No module/field split possible (e.g. a bare parameter name).
            return param_name
        prefix = ".".join(pieces[:-2])
        module_part = pieces[-2]  # e.g. q_proj / gate_proj / ...
        field = pieces[-1]  # weight / bias
        # Check support under the packed name, but emit the original module name.
        if member_to_packed.get(module_part, module_part) not in supported_modules:
            return param_name
        renamed = f"{module_part}.base_layer.{field}"
        return f"{prefix}.{renamed}" if prefix else renamed

    return [(_rename(name), tensor) for name, tensor in weights]
Loading
Loading