Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
d5731a8
Added gspo
ertkonuk Jul 28, 2025
e6ba491
Port GSPO ratio clip from dev branch.
pjin-nvidia Aug 1, 2025
ce337d8
Micro batch size.
pjin-nvidia Aug 1, 2025
ddc6137
Merge branch 'main' into pjin/gspo-algo
pjin-nvidia Aug 6, 2025
76c842d
Example GSPO config.
pjin-nvidia Aug 6, 2025
8b65703
Cleaner impl.
pjin-nvidia Aug 7, 2025
7938879
GSPO version of deepscaler example (WIP).
pjin-nvidia Aug 7, 2025
fef9120
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 7, 2025
9782097
Alternative 2-node TP=2 deepscaler GSPO config (WIP).
pjin-nvidia Aug 7, 2025
95867ed
Remove Qwen3 GSPO config (will take too long to test).
pjin-nvidia Aug 7, 2025
4d0568d
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 7, 2025
c536aff
Rename sequence_level_importance_sampling => sequence_level_importance_ratios.
pjin-nvidia Aug 8, 2025
2cbc0ca
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 8, 2025
5e384dc
Using the new deepscaler config.
pjin-nvidia Aug 9, 2025
8882bfe
Restore comment.
pjin-nvidia Aug 9, 2025
b795e65
Restore comment.
pjin-nvidia Aug 9, 2025
808953e
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 9, 2025
446b44d
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 11, 2025
b9127f4
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 12, 2025
b6598df
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 12, 2025
931b744
Minor fix.
pjin-nvidia Aug 12, 2025
b90ea62
Remove inaccurate comment.
pjin-nvidia Aug 13, 2025
e87d144
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 14, 2025
dfbd89d
GSPO config test.
pjin-nvidia Aug 14, 2025
825f8a7
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 14, 2025
15b234c
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 19, 2025
3d80fb0
Merge remote-tracking branch 'origin/main' into HEAD
pjin-nvidia Aug 19, 2025
d95dea0
Batch size = 1 GSPO unit tests (should be equivalent to GRPO).
pjin-nvidia Aug 19, 2025
c957966
Clipped PG loss unit tests support for batch size > 1.
pjin-nvidia Aug 20, 2025
a28c974
GSPO unit tests.
pjin-nvidia Aug 21, 2025
969cc88
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo
pjin-nvidia Aug 21, 2025
c69ae82
Loop var name. Format.
pjin-nvidia Aug 21, 2025
a752b5b
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo-retry
pjin-nvidia Aug 25, 2025
5f2e53b
GSPO test needs GRPO base config.
pjin-nvidia Aug 25, 2025
f4970a5
Move Megatron Qwen2.5 7B run test from nightly to release.
pjin-nvidia Aug 26, 2025
d43fe09
Merge remote-tracking branch 'origin/main' into pjin/gspo-algo-retry
pjin-nvidia Aug 26, 2025
16258e3
Correct nightly GSPO run test.
pjin-nvidia Aug 26, 2025
abe4f91
Base config needs the added field.
pjin-nvidia Aug 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ loss_fn:
# (default off) loss formulation improvements (docs/guides/grpo.md#loss)
use_on_policy_kl_approximation: false
use_importance_sampling_correction: false
sequence_level_importance_ratios: false
token_level_loss: true

checkpointing:
Expand Down
146 changes: 146 additions & 0 deletions examples/configs/recipes/llm/grpo-gspo-deepscaler-1.5b-8K.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# GSPO recipe: GRPO with sequence-level importance ratios
# (https://arxiv.org/abs/2507.18071) on DeepScaler 1.5B at 8K context.
# NOTE(review): indentation was reconstructed from the standard NeMo-RL GRPO
# config schema (cf. examples/configs/grpo_math_1B.yaml) — verify nesting
# against that base config.
grpo:
  num_prompts_per_step: 128
  num_generations_per_prompt: 8
  max_rollout_turns: 1 # for multi-turn rollouts. Math Environments just have 1 turn (answering the question)
  max_num_steps: 1000000
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  max_val_samples: 480
  val_batch_size: 32
  seed: 42

loss_fn:
  reference_policy_kl_penalty: 0.0
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  ratio_clip_c: null
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  # GSPO: per-sequence importance ratios are mutually exclusive with
  # token-level loss, so token_level_loss must be false when this is true.
  sequence_level_importance_ratios: true
  token_level_loss: false

checkpointing:
  enabled: true
  checkpoint_dir: "results/grpo"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 10
  save_period: 10
  checkpoint_must_save_by: null

policy:
  model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 64
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 4
  max_total_sequence_length: 8192
  precision: "bfloat16"

  dtensor_cfg:
    enabled: true
    cpu_offload: true
    sequence_parallel: true
    activation_checkpointing: true
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null

  dynamic_batching:
    enabled: false

  sequence_packing:
    enabled: false

  # makes the training sequence length divisible by the tensor parallel size
  # this is useful for sequence parallel training
  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 2.0e-6
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8
      # when using Dtensor, we need to set foreach
      # and fused to False
      foreach: false
      fused: false

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        total_iters: 50
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [50]

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: ${policy.precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      enforce_eager: true
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources
      enabled: true
      # only relevant when enabled is false
      resources:
        gpus_per_node: null # Decides num gpus to be dedicated to generation when there is one node in the cluster i.e cluster.num_nodes == 1
        num_nodes: null # Decides number of nodes to be dedicated to generation

data:
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null
  dataset_name: "DeepScaler"
  shuffle: true

env:
  math:
    num_workers: 16

logger:
  log_dir: "logs" # Base directory for all logs
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: false
  tensorboard_enabled: false
  mlflow_enabled: false
  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "grpo-dev"
    name: "grpo-dev-logger"
  tensorboard: {}
  mlflow:
    experiment_name: "grpo-dev"
    run_name: "grpo-dev-logger"
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 1
74 changes: 62 additions & 12 deletions nemo_rl/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, TypedDict, TypeVar
from typing import Any, NotRequired, Optional, TypedDict, TypeVar

import torch

Expand All @@ -37,6 +37,11 @@ class ClippedPGLossConfig(TypedDict):
use_on_policy_kl_approximation: bool
use_importance_sampling_correction: bool
token_level_loss: bool
# If True, apply the off-policy importance-sampling correction at the
# sequence level (one weight per generated sample), as in GSPO.
# If False (default), correction is applied at the token level as in the
# original GRPO paper.
sequence_level_importance_ratios: NotRequired[bool]


class ClippedPGLossDataDict(TypedDict):
Expand All @@ -60,6 +65,7 @@ class ClippedPGLossFn(LossFunction):
- PPO (Clipped) - https://arxiv.org/abs/1707.06347
- GRPO - https://arxiv.org/abs/2402.03300
- REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740
- GSPO (set sequence_level_importance_ratios = True and token_level_loss = False) - https://arxiv.org/abs/2507.18071

Formula:
L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref)
Expand Down Expand Up @@ -101,10 +107,18 @@ def __init__(self, cfg: ClippedPGLossConfig):
self.use_importance_sampling_correction = cfg[
"use_importance_sampling_correction"
]

# Whether to compute importance weights per-sequence instead of per-token.
self.sequence_level_importance_ratios = cfg.get(
"sequence_level_importance_ratios",
False,
)
self.loss_type = (
LossType.TOKEN_LEVEL if cfg["token_level_loss"] else LossType.SEQUENCE_LEVEL
)
if self.sequence_level_importance_ratios:
assert self.loss_type == LossType.SEQUENCE_LEVEL, (
"sequence-level importance sampling (e.g. GSPO) is mutually exclusive with token-level loss"
)

def __call__(
self,
Expand Down Expand Up @@ -204,7 +218,17 @@ def __call__(

# Calculate clipped loss function if ppo ratio is enabled.
if not self.disable_ppo_ratio:
ratios = (curr_logprobs - prev_logprobs).exp()
log_ratios = curr_logprobs - prev_logprobs
if self.sequence_level_importance_ratios:
seq_log_ratio_mean = masked_mean(
log_ratios,
token_mask,
dim=-1,
).unsqueeze(-1)
seq_ratio = seq_log_ratio_mean.exp()
ratios = seq_ratio.repeat(1, advantages.shape[1])
else:
ratios = log_ratios.exp()
ratios_clamped = ratios.clamp(
1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max
)
Expand All @@ -228,11 +252,29 @@ def __call__(
advantages < 0, torch.min(clip_loss, loss3), clip_loss
)

# -------------------------------------------------------------
# Off-policy (actor) importance-sampling correction
# -------------------------------------------------------------
# See: docs/guides/grpo.md#importance-sampling-correction
actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
if self.sequence_level_importance_ratios:
# importance weight w_i = exp(Σ_t (log π_actor − log π_behaviour))
seq_lp_diff = ((prev_logprobs - generation_logprobs) * mask).sum(dim=-1)
actor_importance_weights = torch.exp(seq_lp_diff).detach()
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
# Broadcast to token dimension so we can reuse existing reduction
actor_importance_weights_expanded = actor_importance_weights.unsqueeze(-1)
else:
# Token-level correction
actor_importance_weights_expanded = torch.exp(
prev_logprobs - generation_logprobs
)
actor_importance_weights_expanded = torch.nan_to_num(
actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0
)
actor_importance_weights = actor_importance_weights_expanded
del actor_importance_weights_expanded
if self.use_importance_sampling_correction:
importance_weights_to_use = actor_importance_weights
else:
Expand All @@ -255,12 +297,20 @@ def __call__(
global_normalization_factor=global_valid_seqs,
)

# Metric: sampling importance ratio (mean over samples)
# See: docs/guides/grpo.md#sampling-importance-ratio
sample_importance_ratio = masked_mean(
actor_importance_weights,
mask,
global_normalization_factor=global_valid_toks,
)
if self.sequence_level_importance_ratios:
sample_importance_ratio = masked_mean(
actor_importance_weights,
sample_mask,
global_normalization_factor=global_valid_seqs,
)
else:
sample_importance_ratio = masked_mean(
actor_importance_weights,
mask,
global_normalization_factor=global_valid_toks,
)

# Approximating entropy as E_{s ~ \pi_{gen}(s)}[-(\pi_{curr}/\pi_{gen})log(\pi_{curr}(s))]
# See more details and other metrics in docs/guides/grpo.md#metrics
Expand Down
67 changes: 67 additions & 0 deletions tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/bin/bash
# Nightly run test: GSPO (GRPO with sequence-level importance ratios) on
# DeepScaler 1.5B at 8K context. Launches training via run_grpo_math.py,
# dumps tensorboard logs to JSON, and gates on train/token_mult_prob_error.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
# common.env provides PROJECT_ROOT, CONFIG_PATH, LOG_DIR, CKPT_DIR, RUN_LOG,
# JSON_METRICS, EXP_NAME, and exit_if_max_steps_reached.
source "$SCRIPT_DIR/common.env"

# ===== BEGIN CONFIG =====
# NUM_NODES / NUM_RUNS / NUM_MINUTES are not read below; presumably consumed
# by the CI harness that sources this script — TODO confirm.
NUM_NODES=1
STEPS_PER_RUN=40
MAX_STEPS=40
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=240
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd "$PROJECT_ROOT"
# "$@" (quoted) forwards caller overrides intact even if they contain spaces.
uv run examples/run_grpo_math.py \
    --config "$CONFIG_PATH" \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name="$EXP_NAME" \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py "$JSON_METRICS" \
        'mean(data["train/token_mult_prob_error"]) < 1.1' \
        "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1"
fi

# TODO: enable in subsequent PR to do a quick accuracy check
## Convert 8k checkpoint
#uv run examples/converters/convert_dcp_to_hf.py \
#    --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \
#    --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \
#    --hf-ckpt-path=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf
#
## Run eval
#uv run examples/run_eval.py \
#    generation.model_name=$CKPT_DIR/gspo-deepscaler-8k-${MAX_STEPS}-hf \
#    data.prompt_file=examples/prompts/cot.txt \
#    generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-8k
#
#cat ${RUN_LOG}.aime-8k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-8k-metric.json
#
#uv run tests/check_metrics.py ${RUN_LOG}-8k-metric.json \
#    'data["score"] >= 0.25' \
#
##uv run examples/run_eval.py \
##    generation.model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
##    data.prompt_file=examples/prompts/cot.txt \
##    generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-baseline
#
##cat ${RUN_LOG}.aime-baseline | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-baseline-metric.json
#
##uv run tests/check_metrics.py ${RUN_LOG}-baseline-metric.json \
##    'data["score"] == 0.2' \
4 changes: 3 additions & 1 deletion tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.v3.sh

# Megatron
tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.sh
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.sh

# Functional 32b run
tests/test_suites/llm/grpo-qwen2.5-32b-32n8g-fsdp2tp8sp-actckpt.v3.sh
Expand All @@ -29,6 +28,9 @@ tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh
tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh
tests/test_suites/llm/grpo-deepscaler-1.5b-8K.sh

# Deepscaler (GSPO)
tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh

# GRPO math test run (32K context mcore)
tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh

Expand Down
3 changes: 3 additions & 0 deletions tests/test_suites/release.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# GRPO #
########

# Megatron (Qwen/Qwen2.5-7B-Instruct)
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-megatron.sh

# Long 8b run
tests/test_suites/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.v3.sh

Expand Down
Loading