49 commits
a99900f  Port chunked logprob and deferred float32 logits (WIP).  (pjin-nvidia, Aug 6, 2025)
37a027a  Add copy of nemo.tron.model without logits float32 cast.  (pjin-nvidia, Aug 6, 2025)
31707db  Fix.  (pjin-nvidia, Aug 6, 2025)
6a445bc  Ruff + doc comment.  (pjin-nvidia, Aug 6, 2025)
a020289  Configurable deferring float32 logits.  (pjin-nvidia, Aug 6, 2025)
956051c  Update docstrings.  (pjin-nvidia, Aug 6, 2025)
ece8049  Ruff.  (pjin-nvidia, Aug 6, 2025)
670743f  Basic chunking support in logprobs computation with sequence packing.  (pjin-nvidia, Aug 6, 2025)
4dc2aca  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 6, 2025)
f1a9d21  Unit test for chunked logprobs.  (pjin-nvidia, Aug 6, 2025)
df70715  Ruff.  (pjin-nvidia, Aug 6, 2025)
637e131  Pyrefly.  (pjin-nvidia, Aug 6, 2025)
abbf796  Fix test. Pyrefly.  (pjin-nvidia, Aug 6, 2025)
985ba77  Ruff.  (pjin-nvidia, Aug 6, 2025)
13265aa  Stale comment.  (pjin-nvidia, Aug 6, 2025)
e0b8da0  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 7, 2025)
9758b14  Remove unused config.  (pjin-nvidia, Aug 7, 2025)
ea51715  Remove unused config.  (pjin-nvidia, Aug 7, 2025)
1670f93  Also apply to the reference model.  (pjin-nvidia, Aug 7, 2025)
584d8e0  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 7, 2025)
e958d40  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 7, 2025)
8aef5ed  Typed policy configs.  (pjin-nvidia, Aug 7, 2025)
830debe  Bump NeMo submodule.  (pjin-nvidia, Aug 7, 2025)
030ddc5  Remove duplicated nemo.tron.model code and use the new bumped submodule.  (pjin-nvidia, Aug 7, 2025)
9020390  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 7, 2025)
6f014ce  Remove leftover float32 cast. Ruff.  (pjin-nvidia, Aug 7, 2025)
9e5dcd7  Check for float32 logprobs.  (pjin-nvidia, Aug 7, 2025)
b181eb0  Lint.  (pjin-nvidia, Aug 7, 2025)
3cb90c9  Add example config (TODO: functional test for this config).  (pjin-nvidia, Aug 7, 2025)
dde98d9  Add 32K max context Qwen3 30B MoE test run.  (pjin-nvidia, Aug 7, 2025)
651fcdb  Fix deferred fp32 config.  (pjin-nvidia, Aug 8, 2025)
174d1ca  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 8, 2025)
65bbd9b  Rename deferred_fp32_logits => defer_fp32_logits.  (pjin-nvidia, Aug 8, 2025)
edd00ef  chmod +x  (pjin-nvidia, Aug 8, 2025)
a68dfa2  Missing config.  (pjin-nvidia, Aug 9, 2025)
a0044ad  More missing config.  (pjin-nvidia, Aug 9, 2025)
8de985d  Using updated NeMo branch.  (pjin-nvidia, Aug 9, 2025)
c5c83ba  More missing config.  (pjin-nvidia, Aug 9, 2025)
b8a0d7a  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 9, 2025)
194eef6  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 11, 2025)
6865c1a  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 12, 2025)
7100474  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 12, 2025)
da2f305  Lint and minor refactor.  (pjin-nvidia, Aug 12, 2025)
cd5b02a  Fix.  (pjin-nvidia, Aug 12, 2025)
81fb8e1  Unnecessary clone.  (pjin-nvidia, Aug 12, 2025)
ef9d3d5  Remove clone + exp_ with just exp.  (pjin-nvidia, Aug 12, 2025)
b66497b  Merge remote-tracking branch 'origin/main' into pjin/logprob  (pjin-nvidia, Aug 13, 2025)
3d38161  Set HF_HUB_OFFLINE=1 for github CI.  (pjin-nvidia, Aug 13, 2025)
0f0de7d  Fix test.  (pjin-nvidia, Aug 13, 2025)
1 change: 1 addition & 0 deletions .github/actions/test-template/action.yml
@@ -162,6 +162,7 @@ runs:
--shm-size=64g \
--env TRANSFORMERS_OFFLINE=0 \
--env HYDRA_FULL_ERROR=1 \
+ --env HF_HUB_OFFLINE=1 \
--env HF_HOME=/home/TestData/nemo-rl/hf_home \
--env HF_DATASETS_CACHE=/home/TestData/nemo-rl/hf_datasets_cache \
--env NEMO_RL_REPO_DIR=/opt/nemo-rl \
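For context on the new flag above: HF_HUB_OFFLINE=1 makes huggingface_hub (and libraries built on it, such as transformers) resolve models and tokenizers purely from the local cache instead of contacting the Hub, so CI cannot be broken by Hub outages or rate limits. A minimal sketch of the effect, assuming the checkpoint is already present under the mounted HF_HOME cache:

import os

# Must be set before the first import of transformers/huggingface_hub; with this flag,
# any attempt to download from the Hub errors out instead of hitting the network.
os.environ["HF_HUB_OFFLINE"] = "1"

from transformers import AutoTokenizer

# Served entirely from the local cache (e.g. /home/TestData/nemo-rl/hf_home in the CI container);
# fails fast if the files were not pre-downloaded.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-30B-A3B")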
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,7 +1,7 @@
[submodule "3rdparty/NeMo"]
path = 3rdparty/NeMo-workspace/NeMo
url = https://github.com/NVIDIA/NeMo.git
- branch = https://github.com/NVIDIA/NeMo/tree/ashors/rl-qwen3-export
+ branch = pjin/ashors/rl-qwen3-export
shallow = true
[submodule "3rdparty/Megatron-LM"]
path = 3rdparty/Megatron-LM-workspace/Megatron-LM
2 changes: 1 addition & 1 deletion 3rdparty/NeMo-workspace/NeMo
Submodule NeMo updated from aaefed to 5c4264
@@ -0,0 +1,172 @@
checkpointing:
  enabled: True
  checkpoint_dir: results/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  save_period: 3
  keep_top_k: 1
  metric_name: val_reward
  higher_is_better: True
  checkpoint_must_save_by: null

grpo:
  normalize_rewards: True
  use_leave_one_out_baseline: True
  max_num_steps: 3
  num_prompts_per_step: 64
  num_generations_per_prompt: 16
  max_rollout_turns: 1
  val_period: 3
  val_at_start: False
  max_val_samples: 256
  val_batch_size: 256
  seed: 42

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: False
  use_importance_sampling_correction: False
  token_level_loss: True
  ratio_clip_c: null

policy:
  model_name: "Qwen/Qwen3-30B-A3B"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 1
  max_total_sequence_length: 32768
  precision: "bfloat16"
  activation_checkpointing_enabled: True
  logprob_chunk_size: 2048

  dtensor_cfg:
    enabled: False

  dynamic_batching:
    enabled: False

  sequence_packing:
    enabled: False

  max_grad_norm: 1.0
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  scheduler: null # remove default FSDP scheduler

  megatron_cfg:
    enabled: True
    empty_unused_memory_level: 1
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    context_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 8
    sequence_parallel: True
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    freeze_moe_router: True
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True
Inline review thread on defer_fp32_logits:

Contributor (reviewer):
  what would be the reason to set this to False?

pjin-nvidia (author), Aug 12, 2025:
  mostly for strict backward compat, but we could instead enable it by default (i.e. make it an opt-out config like no_defer_fp32_logits or similar)
  wdyt?

Contributor (reviewer):
  I see. How about the following:
  1. this PR introduces it, default off
  2. follow up PR where we run all our nightly tests to see if defaulting to true is ok, if so, remove the arg
  wdyt? If the feature is broadly applicable we should probably switch it to true so no one else runs into the same issue (assuming no accuracy penalty)

pjin-nvidia (author):
  yup, (1) and then (2) SGTM!
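To make the tradeoff discussed in this thread concrete: with defer_fp32_logits the model's [batch, seq, vocab] logits stay in bfloat16 and are only upcast to float32 where the logprob math needs them (chunk by chunk, together with logprob_chunk_size above), instead of materializing a full float32 copy up front, which at 32K context and a vocabulary on the order of 150K entries runs to tens of gigabytes. The following is only a rough illustration of the two strategies, not the PR's actual implementation, and the helper names are made up:

import torch
import torch.nn.functional as F

def next_token_logprobs_eager(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    # Baseline behavior: upcast the whole [batch, seq, vocab] tensor to float32 first (large peak memory).
    logits32 = logits.to(torch.float32)
    logprobs = F.log_softmax(logits32[:, :-1], dim=-1)
    return logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

def next_token_logprobs_deferred(logits: torch.Tensor, input_ids: torch.Tensor,
                                 chunk_size: int = 2048) -> torch.Tensor:
    # Deferred/chunked: keep logits in their original dtype (e.g. bfloat16) and upcast only
    # one sequence chunk at a time, so the float32 scratch space is bounded by chunk_size.
    targets = input_ids[:, 1:]  # the token at position t+1 is predicted from the logits at position t
    num_pos = targets.shape[1]
    out = []
    for start in range(0, num_pos, chunk_size):
        end = min(start + chunk_size, num_pos)
        chunk32 = logits[:, start:end].to(torch.float32)  # upcast just this chunk
        logprobs = F.log_softmax(chunk32, dim=-1)
        out.append(logprobs.gather(-1, targets[:, start:end].unsqueeze(-1)).squeeze(-1))
    return torch.cat(out, dim=1)

Keeping the option opt-in, as agreed above, means existing runs keep the old numerics, while runs that enable it get the lower peak memory, pending the nightly-test check on accuracy discussed in the thread.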

    optimizer:
      optimizer: "adam"
      lr: 5.0e-7
      min_lr: 5.0e-8
      weight_decay: 0.0
      bf16: True
      fp16: False
      params_dtype: "float32"

      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      use_distributed_optimizer: True
      use_precision_aware_optimizer: True

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 2
      lr_warmup_init: 5.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: False
      overlap_grad_reduce: True
      overlap_param_gather: True
      average_in_collective: True
      use_custom_fsdp: False
      data_parallel_sharding_strategy: "optim_grads_params"

    env_vars:
      PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: False
      precision: ${policy.precision}
      tensor_parallel_size: 4
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857
      enforce_eager: True
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null

data:
  dataset_name: "OpenMathInstruct-2"
  shuffle: true
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null

env:
  math:
    num_workers: 8

logger:
  log_dir: logs/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: True
  tensorboard_enabled: True
  mlflow_enabled: False # Disable MLflow logging
  monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: nemo-rl
    name: "grpo-math-qwen3-30ba3b-megatron-tp4-32k"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 4
8 changes: 3 additions & 5 deletions nemo_rl/algorithms/loss_functions.py
@@ -137,8 +137,6 @@ def __call__(
global_normalization_factor=global_valid_toks,
).item()

- next_token_logits = next_token_logits.to(torch.float32)

if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -159,6 +157,7 @@
next_token_logits, data["input_ids"], seq_index=seq_index
)
else:
+ next_token_logits = next_token_logits.to(torch.float32)
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
@@ -327,8 +326,6 @@ def __call__(
mask = token_mask * sample_mask.unsqueeze(-1)
seq_index = data.get("seq_index", None)

- next_token_logits = next_token_logits.to(torch.float32)

# Gather the logprobs for the actual next tokens
if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
@@ -351,6 +348,7 @@
)
else:
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
+ next_token_logits = next_token_logits.to(torch.float32)
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
@@ -583,7 +581,6 @@ def _dpo_loss(
sample_mask = data["sample_mask"]
seq_index = data.get("seq_index", None)

- next_token_logits = next_token_logits.to(torch.float32)
if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -605,6 +602,7 @@
)
else:
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
+ next_token_logits = next_token_logits.to(torch.float32)
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
)
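The "Unit test for chunked logprobs" commit implies a consistency check between the chunked and unchunked paths. A hypothetical standalone version of such a check, reusing the sketch helpers from the defer_fp32_logits thread above (again, not the PR's actual test):

import torch

def test_chunked_logprobs_match_reference():
    torch.manual_seed(0)
    batch, seq, vocab = 2, 37, 128
    logits = torch.randn(batch, seq, vocab, dtype=torch.bfloat16)
    input_ids = torch.randint(0, vocab, (batch, seq))

    reference = next_token_logprobs_eager(logits, input_ids)                 # full float32 upcast
    chunked = next_token_logprobs_deferred(logits, input_ids, chunk_size=8)  # chunked, deferred upcast

    # Both paths upcast to float32 before the softmax, so they should agree to float32 precision.
    torch.testing.assert_close(chunked, reference, rtol=1e-5, atol=1e-5)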