Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# GRPO training config: Llama-3.1-8B-Instruct, 2 nodes x 8 GPUs,
# FSDP2 (dtensor) TP=1, non-colocated vLLM generation.

# Top-level GRPO algorithm settings.
grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
  max_rollout_turns: 1
  max_num_steps: 500
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  max_val_samples: 256
  val_batch_size: 256
  seed: 42

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  # null disables the dual-clip (clip-c) term.
  ratio_clip_c: null
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  token_level_loss: true

checkpointing:
  enabled: true
  checkpoint_dir: results/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  metric_name: val_reward
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null

policy:
  model_name: meta-llama/Llama-3.1-8B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.1-8B-Instruct
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32
  logprob_batch_size: 2
  max_total_sequence_length: 4096
  precision: bfloat16
  dtensor_cfg:
    enabled: true
    cpu_offload: false
    sequence_parallel: false
    activation_checkpointing: false
    tensor_parallel_size: 1
    context_parallel_size: 1
    custom_parallel_plan: null
  dynamic_batching:
    enabled: true
    # Token budgets derived from sequence length x micro-batch sizes.
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64
  sequence_packing:
    enabled: false
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64
  make_sequence_length_divisible_by: 1
  max_grad_norm: 1
  optimizer:
    name: torch.optim.AdamW
    kwargs:
      lr: 3e-07
      weight_decay: 0.01
      betas:
        - 0.9
        - 0.999
      eps: 1e-08
      foreach: false
      fused: false
  scheduler:
    # Linear warmup for the first 13 steps, then constant LR.
    - name: torch.optim.lr_scheduler.LinearLR
      kwargs:
        start_factor: 0.1
        end_factor: 1
        total_iters: 13
    - name: torch.optim.lr_scheduler.ConstantLR
      kwargs:
        factor: 1
        total_iters: 10000000000
    - milestones:
        - 13
  generation:
    backend: vllm
    max_new_tokens: 4096
    temperature: 1
    top_p: 1
    top_k: null
    stop_token_ids:
      - 128009
    stop_strings: null
    vllm_cfg:
      async_engine: true
      precision: ${policy.precision}
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: 4096
      enforce_eager: false
    # Non-colocated: generation workers get their own resources.
    colocated:
      enabled: false
      resources:
        gpus_per_node: null
        num_nodes: 1

data:
  max_input_seq_length: 4096
  prompt_file: examples/prompts/cot.txt
  system_prompt_file: null
  dataset_name: OpenMathInstruct-2
  shuffle: true

env:
  math:
    num_workers: 8

logger:
  log_dir: logs/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  num_val_samples_to_print: 0
  wandb_enabled: true
  tensorboard_enabled: true
  mlflow_enabled: false
  monitor_gpus: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10

cluster:
  gpus_per_node: 8
  num_nodes: 2
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/preference_loss"]["1"] < 0.69316' \
'data["train/preference_loss"]["20"] < 0.6' \
'mean(data["timing/train/total_step_time"], -10, -1) < 7.8'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/preference_loss"]["1"] < 0.69316' \
'data["train/preference_loss"]["150"] < 0.4' \
'mean(data["timing/train/total_step_time"], -11, -1) < 24'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/preference_loss"]["1"] < 0.69316' \
'data["train/preference_loss"]["150"] < 0.4' \
'mean(data["timing/train/total_step_time"], -11, -1) < 11.5'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/preference_loss"]["1"] < 0.69316' \
'data["train/preference_loss"]["20"] < 0.6' \
'mean(data["timing/train/total_step_time"], -10) < 6.7'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/preference_loss"]["1"] < 0.6932' \
'data["train/preference_loss"]["150"] < 0.68'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/loss"]["1"] < 0.69316' \
'data["train/loss"]["150"] < 0.55' \
'mean(data["timing/train/total_step_time"], -11, -1) < 1.3'
fi
fi
1 change: 0 additions & 1 deletion tests/test_suites/llm/grpo-deepscaler-1.5b-16K.sh
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,3 @@ cat ${RUN_LOG}.aime-16k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"sco
# 240 step checkpoint 0.3
uv run tests/check_metrics.py ${RUN_LOG}-16k-metric.json \
'data["score"] >= 0.2396'

1 change: 0 additions & 1 deletion tests/test_suites/llm/grpo-deepscaler-1.5b-24K.sh
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,3 @@ cat ${RUN_LOG}.aime-24k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"sco

uv run tests/check_metrics.py ${RUN_LOG}-24k-metric.json \
'data["score"] >= 0.2396'

Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["20"] < 1.1'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["100"] < 1.1'
fi

Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
# Nightly test driver: GRPO on Llama-3.1-8B-Instruct, non-colocated generation.
# Runs the experiment, dumps tensorboard metrics to JSON, and checks thresholds.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
# common.env provides PROJECT_ROOT, CONFIG_PATH, LOG_DIR, CKPT_DIR, EXP_NAME,
# RUN_LOG, JSON_METRICS and the exit_if_max_steps_reached helper.
source "$SCRIPT_DIR/common.env"

# ===== BEGIN CONFIG =====
NUM_NODES=2
STEPS_PER_RUN=30
MAX_STEPS=30
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=120
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd "$PROJECT_ROOT"
uv run examples/run_grpo_math.py \
    --config "$CONFIG_PATH" \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name="$EXP_NAME" \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py "$JSON_METRICS" \
        'mean(data["train/token_mult_prob_error"]) < 1.1' \
        'data["train/token_mult_prob_error"]["30"] < 1.1'
fi
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["100"] < 1.1'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/token_mult_prob_error"]["500"] < 1.1' \
'mean(data["timing/train/total_step_time"], -6, -1) < 10'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/token_mult_prob_error"]["500"] < 1.1' \
'data["train/reward"]["500"] > 0.1' \
'mean(data["timing/train/total_step_time"], -6, -1) < 10.5'

fi
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["20"] < 1.1'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["2"] < 1.1'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'mean(data["train/token_mult_prob_error"]) < 1.1' \
'data["train/token_mult_prob_error"]["30"] < 1.1'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/token_mult_prob_error"]["450"] < 1.1' \
'mean(data["timing/train/total_step_time"], 2) < 25'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
uv run tests/check_metrics.py $JSON_METRICS \
'data["train/loss"]["1"] < 0.6' \
'data["train/loss"]["1"] < 0.6' \
'data["train/loss"]["250"] < 0.36' \
'max(data["ray/node.0.gpu.0.mem_gb"]) < 70' \
'mean(data["timing/train/total_step_time"], 2) < 10'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
uv run tests/check_metrics.py $JSON_METRICS \
'data["train/loss"]["1"] < 0.6' \
'data["train/loss"]["250"] < 0.36' \
'max(data["ray/node.0.gpu.0.mem_gb"]) < 80' \
'max(data["ray/node.0.gpu.0.mem_gb"]) < 80' \
'mean(data["timing/train/total_step_time"], 2) < 22'
fi
2 changes: 1 addition & 1 deletion tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp2sp.sh
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/loss"]["50"] < 0.38' \
'max(data["ray/node.0.gpu.0.mem_gb"]) < 70' \
'mean(data["timing/train/total_step_time"], 2) < 32'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/loss"]["1"] < 0.6' \
'data["train/loss"]["250"] < 0.36' \
'mean(data["timing/train/total_step_time"], 2) < 6'
fi
fi
2 changes: 1 addition & 1 deletion tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/loss"]["1"] < 0.6' \
'data["train/loss"]["250"] < 0.36' \
'mean(data["timing/train/total_step_time"], 2) < 20'
fi
fi
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'max(data["ray/node.0.gpu.0.mem_gb"]) < 25' \
'mean(data["timing/train/total_step_time"], -6, -1) < 0.6'
fi

Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | ma
'data["train/loss"]["1"] < 0.37' \
'data["train/loss"]["20"] < 0.3' \
'max(data["ray/node.0.gpu.0.mem_gb"]) < 35'
fi
fi
3 changes: 3 additions & 0 deletions tests/test_suites/nightly.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh
# FP8
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.sh

# Non-colocated
tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh

#######
# SFT #
#######
Expand Down