feat: GRPO example for Qwen3 32b context length=128k #957
New config file (`@@ -0,0 +1,180 @@`):

```yaml
checkpointing:
  enabled: True
  checkpoint_dir: results/grpo-math-qwen3-32b-megatron-128k-4n8g
  save_period: 3
  keep_top_k: 1
  metric_name: val_reward
  higher_is_better: True
  checkpoint_must_save_by: null

grpo:
  normalize_rewards: True
  use_leave_one_out_baseline: True
  max_num_steps: 500
  num_prompts_per_step: 8
  num_generations_per_prompt: 8
  max_rollout_turns: 1
  val_period: 10
  val_at_start: False
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  overlong_filtering: false

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: False
  use_importance_sampling_correction: False
  token_level_loss: True
  ratio_clip_c: null

policy:
  model_name: "Qwen/Qwen3-32B"
  model_kwargs:
    rope_scaling:
      type: "yarn"
      factor: 4.0
      original_max_position_embeddings: 32768
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 16
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 1
  max_total_sequence_length: 131072
  precision: "bfloat16"
  logprob_chunk_size: 512

  dtensor_cfg:
    enabled: False

  dynamic_batching:
    enabled: False

  sequence_packing:
    enabled: True
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  max_grad_norm: 1.0
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  scheduler: null # remove default FSDP scheduler

  megatron_cfg:
    enabled: True
    empty_unused_memory_level: 1
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 8
    pipeline_model_parallel_size: 1
    context_parallel_size: 4
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    sequence_parallel: True
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    freeze_moe_router: True
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True

    optimizer:
      optimizer: "adam"
      lr: 5.0e-7
      min_lr: 5.0e-8
      weight_decay: 0.0
      bf16: True
      fp16: False
      params_dtype: "float32"

      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      use_distributed_optimizer: True
      use_precision_aware_optimizer: True

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 2
      lr_warmup_init: 5.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: False
      overlap_grad_reduce: True
      overlap_param_gather: True
      average_in_collective: True
      use_custom_fsdp: False
      data_parallel_sharding_strategy: "optim_grads_params"

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: False
      precision: ${policy.precision}
      tensor_parallel_size: 8
      pipeline_parallel_size: 1
      enable_expert_parallel: false
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857
      enforce_eager: True
      vllm_kwargs: ${policy.model_kwargs}
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null

data:
  dataset_name: "OpenMathInstruct-2"
  shuffle: true
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null

env:
  math:
    num_workers: 8

logger:
  log_dir: logs/grpo-math-qwen3-32b-128k-4n8g-megatrontp8cp4
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: True
  tensorboard_enabled: True
  mlflow_enabled: False # Disable MLflow logging
  monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: nemo-rl
    name: "grpo-math-qwen3-32b-128k-4n8g-megatrontp8cp4"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 4
```
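Note (illustrative, not part of the PR): the YaRN `rope_scaling` above extends Qwen3-32B's native 32,768-token window by `factor: 4.0`, which is exactly the 131,072 used for `max_total_sequence_length` and vLLM's `max_model_len`. A minimal sketch of that consistency check:

```python
# Sketch: verify the YaRN-extended context length implied by policy.model_kwargs
# matches policy.max_total_sequence_length / generation.vllm_cfg.max_model_len.
original_max_position_embeddings = 32768  # Qwen3-32B native context window
yarn_factor = 4.0                         # policy.model_kwargs.rope_scaling.factor
max_total_sequence_length = 131072        # policy.max_total_sequence_length (128k)

extended_context = int(original_max_position_embeddings * yarn_factor)
assert extended_context == max_total_sequence_length, (extended_context, max_total_sequence_length)
print(f"YaRN-extended context: {extended_context} tokens")
```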
Changed hunk (around `AutoBridge.from_hf_pretrained`), forwarding the YAML `model_kwargs` to the bridge loader:

```diff
@@ -759,7 +759,7 @@ def __init__(
         self.final_padded_vocab_size = tokenizer_config.padded_vocab_size
         self.dp_size = worker_sharding_annotations.get_axis_size("data_parallel")
         self.megatron_bridge = AutoBridge.from_hf_pretrained(
-            hf_model_name, trust_remote_code=True
+            hf_model_name, trust_remote_code=True, **self.cfg.get("model_kwargs", {})
         )

         self.should_disable_forward_pre_hook = (
```

Contributor (on the `**self.cfg.get("model_kwargs", {})` line): I think the mcore path cannot parse/handle the `rope_scaling` here.

Contributor: Maybe we should error out in the mcore path if we see `rope_scaling`. @yaoyu-33 to confirm whether the Qwen and Llama bridges in mbridge can parse and handle this field; I believe it is only supported for the DeepSeek model type.
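For intuition only (this is not the actual `AutoBridge` code path), the same splat pattern against a plain Hugging Face loader shows how a `model_kwargs` dict like the one in the YAML would override the checkpoint's `rope_scaling`; the `AutoConfig` call below is an illustrative stand-in:

```python
# Illustrative sketch, not the AutoBridge.from_hf_pretrained implementation:
# HF-style from_pretrained loaders accept extra keyword arguments as config
# overrides, so splatting policy.model_kwargs injects the YaRN rope_scaling.
from transformers import AutoConfig

model_kwargs = {  # mirrors policy.model_kwargs from the YAML above
    "rope_scaling": {
        "type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
    }
}

config = AutoConfig.from_pretrained("Qwen/Qwen3-32B", trust_remote_code=True, **model_kwargs)
print(config.rope_scaling)  # expected to reflect the YaRN override above
```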
New test script (`@@ -0,0 +1,39 @@`):

```bash
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=4
STEPS_PER_RUN=2
MAX_STEPS=2
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=90
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/token_mult_prob_error"]) < 1.1' \
        'data["train/token_mult_prob_error"]["$MAX_STEPS"] < 1.1'
fi
```

Contributor (on the `mean(...) < 1.1` check): Wouldn't this fail according to the wandb link you shared? I think with longer generations we'll probably run into outliers that skew the mean, and with a run this short it's hard to write something robust. Maybe use the per-step check (`data["train/token_mult_prob_error"]["$MAX_STEPS"] < 1.1`) and add a comment above explaining why. I made an issue to track this: #1039.
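To make the guard and the reviewer's concern concrete (the layout of the dumped metrics file is assumed here, not taken from the PR), a sketch of what the `jq` condition and the two `check_metrics.py` expressions evaluate:

```python
# Sketch under an assumed layout: json_dump_tb_logs.py is taken to emit
# {metric_name: {step_as_string: value}}. The jq guard finds the largest step
# logged for "train/loss" and only checks metrics once it reaches MAX_STEPS.
import json

MAX_STEPS = 2
metrics = json.loads('{"train/loss": {"1": 0.42, "2": 0.40}, '
                     '"train/token_mult_prob_error": {"1": 1.02, "2": 1.04}}')

max_logged_step = max(int(step) for step in metrics["train/loss"])
if max_logged_step >= MAX_STEPS:
    errs = metrics["train/token_mult_prob_error"]
    mean_err = sum(errs.values()) / len(errs)   # mirrors 'mean(data[...]) < 1.1'
    final_err = errs[str(MAX_STEPS)]            # mirrors 'data[...]["$MAX_STEPS"] < 1.1'
    # A single long-generation outlier can push mean_err over 1.1 even when the
    # final step is healthy, which is the reviewer's case for the per-step check.
    assert final_err < 1.1
```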
Contributor: Which models need this? Is it possible to handle it in code? In the past when we had stuff like this, the consensus was to handle it in code since we knew which model types needed it, e.g. fdb565c. Regardless of whether this is handled in code or YAML, it should probably get an entry in `model-quirks.md` so we have documentation.

Contributor: It is used by Qwen3 with long context lengths: https://huggingface.co/Qwen/Qwen3-32B#processing-long-texts. Since it's an optional configuration that the user can change, I tend to put it explicitly in the YAML. Will update `model-quirks.md` to reflect this.
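As a hypothetical sketch of the "handle it in code" alternative the reviewer mentions (not part of the PR; the helper name and placement are invented for illustration), the YaRN override could be derived from the requested sequence length instead of being written into the YAML:

```python
# Hypothetical illustration only: derive Qwen3 YaRN model kwargs in code when the
# requested training sequence length exceeds the model's native context window.
QWEN3_NATIVE_CONTEXT = 32768  # Qwen3-32B native max position embeddings

def maybe_long_context_kwargs(model_name: str, max_total_sequence_length: int) -> dict:
    """Return model kwargs enabling YaRN when a Qwen3 model is run past its native window."""
    if "Qwen3" in model_name and max_total_sequence_length > QWEN3_NATIVE_CONTEXT:
        return {
            "rope_scaling": {
                "type": "yarn",
                "factor": max_total_sequence_length / QWEN3_NATIVE_CONTEXT,
                "original_max_position_embeddings": QWEN3_NATIVE_CONTEXT,
            }
        }
    return {}

print(maybe_long_context_kwargs("Qwen/Qwen3-32B", 131072))
# {'rope_scaling': {'type': 'yarn', 'factor': 4.0, 'original_max_position_embeddings': 32768}}
```

The trade-off discussed above still applies: deriving it in code hides a user-visible knob, while keeping it in YAML (as this PR does) makes the override explicit and documented per config.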