Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ policy:
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

# LoRA (Low-Rank Adaptation) Configuration
# NOTE(review): in the full file this block is nested under policy.dtensor_cfg
# (overrides elsewhere address it as policy.dtensor_cfg.lora_cfg.*); the
# indentation below reflects that nesting.
lora_cfg:
  enabled: False # Set to True to enable LoRA fine-tuning
  target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers)
  exclude_modules: [] # List of module names to exclude from LoRA
  match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
  dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
  alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64
  dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
  dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
  lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
  use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1

megatron_cfg:
enabled: false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# GRPO LoRA recipe: Qwen3-8B-Base on 1 node x 8 GPUs with the FSDP2/DTensor
# backend. Everything not overridden here is inherited from the base recipe.
defaults: ../../grpo_math_1B.yaml
grpo:
  # Run validation before the first training step as an early smoke check.
  val_at_start: true
checkpointing:
  checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora
policy:
  model_name: Qwen/Qwen3-8B-Base
  max_total_sequence_length: 2048
  dtensor_cfg:
    # Recompute activations in the backward pass to fit the 8B model in memory.
    activation_checkpointing: true
    lora_cfg:
      enabled: True
      # Larger rank/alpha than the base config defaults (dim=8, alpha=32);
      # alpha/dim = 1.0 keeps the effective LoRA scaling at 1.
      dim: 128
      alpha: 128
  sequence_packing:
    enabled: false
logger:
  log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-qwen3-8B-base-1n8g-fsdp2-lora
cluster:
  gpus_per_node: 8
52 changes: 51 additions & 1 deletion nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import ray
import torch
from nemo_automodel.components._peft.lora import LinearLoRA
from nemo_automodel.components.distributed.cp_utils import (
create_context_parallel_ctx,
)
Expand Down Expand Up @@ -85,19 +86,66 @@ def dtensor_params_generator(
Args:
model: The model whose parameters to generate.
target_dtype: The dtype to convert tensors to.
peft_config: Optional LoRA config for filtering which layers to merge.

Yields:
Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous.
"""
module_map = dict(model.named_modules())
for name, tensor in model.state_dict().items():
if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
continue
full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, full_tensor)
merged_tensor = _maybe_merge_lora_weight(module_map, name, full_tensor)

adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, merged_tensor)
for adapted_fqn, adapted_tensor in adapted_fqn_tensors:
# Convert to target dtype
yield (
adapted_fqn,
adapted_tensor.to(target_dtype, non_blocking=True).contiguous(),
)
del adapted_tensor
del adapted_fqn_tensors
del merged_tensor
del full_tensor


@torch.no_grad()
def _maybe_merge_lora_weight(
module_map: dict[str, nn.Module],
fqn: str,
tensor: torch.Tensor,
) -> torch.Tensor:
if not fqn.endswith(".weight"):
return tensor
module_name = fqn[: -len(".weight")]
module = module_map.get(module_name)
if not isinstance(module, LinearLoRA):
return tensor
if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
return tensor

lora_a = (
module.lora_A.weight.full_tensor()
if isinstance(module.lora_A.weight, DTensor)
else module.lora_A.weight
)
lora_b = (
module.lora_B.weight.full_tensor()
if isinstance(module.lora_B.weight, DTensor)
else module.lora_B.weight
)
lora_a = lora_a.to(device=tensor.device, dtype=tensor.dtype)
lora_b = lora_b.to(device=tensor.device, dtype=tensor.dtype)
scale = getattr(module, "scale", None)

if scale is None and hasattr(module, "alpha") and hasattr(module, "dim"):
scale = module.alpha / module.dim
if scale is None:
scale = 1.0

return tensor + torch.matmul(lora_b, lora_a) * scale


def _maybe_adapt_tensor_to_hf(
Expand Down Expand Up @@ -1208,6 +1256,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]:
"""Prepare state dict metadata for weight refitting and IPC streaming."""
state_dict_info = {}
for name, tensor in self.model.state_dict().items():
if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
continue
full_tensor = (
tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
)
Expand Down
3 changes: 3 additions & 0 deletions tests/functional/L1_Functional_Tests_GPU.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ time uv run --no-sync bash ./tests/functional/sft.sh
time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
time uv run --no-sync bash ./tests/functional/grpo.sh
time uv run --no-sync bash ./tests/functional/grpo_async.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
Expand Down
46 changes: 46 additions & 0 deletions tests/functional/grpo_automodel_lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash
# Functional test: GRPO math training with the Automodel/DTensor LoRA path
# enabled (policy.dtensor_cfg.lora_cfg). Runs 3 tiny steps on 2 GPUs, dumps
# TensorBoard metrics to JSON, and asserts the training reward is non-trivial.

# clean up checkpoint directory on exit
# NOTE(review): checkpointing.enabled=false below and nothing in this run
# points at /tmp/lora_sft_checkpoints -- this trap looks inherited from an
# SFT test; confirm it is still needed.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

# Resolve the repo root relative to this script's location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

# Derive experiment name and output paths from the script filename so each
# functional test writes to its own directory.
EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean experiment directory.
rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
# Run training under coverage. LoRA is enabled with rank 32; extra CLI
# overrides may be appended via "$@".
uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_grpo_math.py\
grpo.max_num_steps=3 \
grpo.num_prompts_per_step=8 \
grpo.num_generations_per_prompt=4 \
data.shuffle=false \
policy.dtensor_cfg.lora_cfg.enabled=True \
policy.dtensor_cfg.lora_cfg.dim=32 \
policy.train_global_batch_size=32 \
policy.train_micro_batch_size=1 \
cluster.gpus_per_node=2 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
"$@" \
2>&1 | tee $RUN_LOG

# Flatten TensorBoard event files into JSON for the metric check below.
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Smoke threshold: training must produce some positive reward within 3 steps.
uv run tests/check_metrics.py $JSON_METRICS \
'max(data["train/reward"]) > 0.03'
52 changes: 52 additions & 0 deletions tests/functional/grpo_automodel_lora_async.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/bin/bash
# Functional test: async GRPO with Automodel/DTensor LoRA. Same short run as
# grpo_automodel_lora.sh, but with non-colocated generation (1 GPU each for
# training and generation), the vLLM async engine, async GRPO, and
# importance-sampling correction enabled.

# clean up checkpoint directory on exit
# NOTE(review): checkpointing.enabled=false below and nothing in this run
# points at /tmp/lora_sft_checkpoints -- this trap looks inherited from an
# SFT test; confirm it is still needed.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

# Resolve the repo root relative to this script's location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

# Derive experiment name and output paths from the script filename so each
# functional test writes to its own directory.
EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean experiment directory.
rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
# NRL_FORCE_REBUILD_VENVS=true forces worker virtualenvs to be rebuilt, since
# the non-colocated generation workers run in their own environments.
NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_grpo_math.py\
grpo.max_num_steps=3 \
grpo.num_prompts_per_step=8 \
grpo.num_generations_per_prompt=4 \
data.shuffle=false \
policy.dtensor_cfg.lora_cfg.enabled=True \
policy.dtensor_cfg.lora_cfg.dim=32 \
policy.train_global_batch_size=32 \
policy.train_micro_batch_size=1 \
policy.generation.colocated.enabled=false \
policy.generation.colocated.resources.gpus_per_node=1 \
policy.generation.colocated.resources.num_nodes=1 \
policy.generation.vllm_cfg.async_engine=true \
grpo.async_grpo.enabled=true \
loss_fn.use_importance_sampling_correction=true \
cluster.gpus_per_node=2 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
"$@" \
2>&1 | tee $RUN_LOG

# Flatten TensorBoard event files into JSON for the metric check below.
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Smoke threshold: training must produce some positive reward within 3 steps.
uv run tests/check_metrics.py $JSON_METRICS \
'max(data["train/reward"]) > 0.03'
49 changes: 49 additions & 0 deletions tests/functional/grpo_automodel_lora_non_colocated.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/bin/bash
# Functional test: GRPO with Automodel/DTensor LoRA and non-colocated
# generation (training and generation each get 1 of the 2 GPUs), using the
# default synchronous engine. Mirrors grpo_automodel_lora.sh otherwise.

# clean up checkpoint directory on exit
# NOTE(review): checkpointing.enabled=false below and nothing in this run
# points at /tmp/lora_sft_checkpoints -- this trap looks inherited from an
# SFT test; confirm it is still needed.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

# Resolve the repo root relative to this script's location.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory $PROJECT_ROOT

set -eou pipefail

# Derive experiment name and output paths from the script filename so each
# functional test writes to its own directory.
EXP_NAME=$(basename $0 .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean experiment directory.
rm -rf $EXP_DIR $LOG_DIR
mkdir -p $EXP_DIR $LOG_DIR

cd $PROJECT_ROOT
# NRL_FORCE_REBUILD_VENVS=true forces worker virtualenvs to be rebuilt, since
# the non-colocated generation workers run in their own environments.
NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
$PROJECT_ROOT/examples/run_grpo_math.py\
grpo.max_num_steps=3 \
grpo.num_prompts_per_step=8 \
grpo.num_generations_per_prompt=4 \
data.shuffle=false \
policy.dtensor_cfg.lora_cfg.enabled=True \
policy.dtensor_cfg.lora_cfg.dim=32 \
policy.train_global_batch_size=32 \
policy.train_micro_batch_size=1 \
policy.generation.colocated.enabled=false \
policy.generation.colocated.resources.gpus_per_node=1 \
policy.generation.colocated.resources.num_nodes=1 \
cluster.gpus_per_node=2 \
logger.tensorboard_enabled=true \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=false \
logger.monitor_gpus=true \
checkpointing.enabled=false \
"$@" \
2>&1 | tee $RUN_LOG

# Flatten TensorBoard event files into JSON for the metric check below.
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Smoke threshold: training must produce some positive reward within 3 steps.
uv run tests/check_metrics.py $JSON_METRICS \
'max(data["train/reward"]) > 0.03'
44 changes: 44 additions & 0 deletions tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/bin/bash
# Nightly test-suite entry: GRPO on Qwen3-8B-Base, 1 node x 8 GPUs, FSDP2 +
# LoRA (see the matching recipe YAML referenced by CONFIG_PATH).
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
# common.env is expected to provide PROJECT_ROOT, CONFIG_PATH, EXP_NAME,
# LOG_DIR, CKPT_DIR, RUN_LOG, JSON_METRICS and exit_if_max_steps_reached.
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=20
MAX_STEPS=20
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=30
# ===== END CONFIG =====

# Skip the run entirely if a previous invocation already reached MAX_STEPS.
exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
# (jq extracts the highest recorded step for train/loss from the JSON dump).
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    # Thresholds: generation/training KL drift stays tiny, reward learns past
    # 0.35, and average step time stays under 80s.
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/gen_kl_error"], 20) < 0.002' \
        'data["train/gen_kl_error"]["20"] < 0.002' \
        'max(data["train/reward"]) > 0.35' \
        'mean(data["timing/train/total_step_time"], 2) < 80'

    # Clean up checkpoint directory after successful run to save space.
    rm -rf "$CKPT_DIR"
fi
3 changes: 3 additions & 0 deletions tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh
tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh
tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh

# lora
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh

#######
# SFT #
#######
Expand Down
Loading
Loading