feat: chunked logprob calculation with deferred fp32 cast to help with OOM #856
Closed
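In rough terms, the PR avoids materializing a full float32 logits tensor when computing token logprobs: logprobs are computed over fixed-size chunks along the sequence, and the bf16 logits are cast to float32 one chunk at a time (the "deferred fp32 cast"). Below is a minimal PyTorch sketch of that idea, not the actual NeMo RL implementation; the function name and signature are illustrative only. The recipe added in this PR sets `policy.logprob_chunk_size: 2048`.

```python
import torch
import torch.nn.functional as F


def chunked_logprobs(logits: torch.Tensor,     # [batch, seq_len, vocab], e.g. bfloat16
                     input_ids: torch.Tensor,  # [batch, seq_len]
                     chunk_size: int = 2048) -> torch.Tensor:
    """Illustrative sketch: per-token logprobs computed chunk by chunk,
    casting only one [batch, chunk, vocab] slice to float32 at a time."""
    batch, seq_len, _ = logits.shape
    # Position 0 has no previous-token prediction, so its entry stays 0.
    out = torch.zeros(batch, seq_len, dtype=torch.float32, device=logits.device)
    for start in range(0, seq_len - 1, chunk_size):
        end = min(start + chunk_size, seq_len - 1)
        # Deferred fp32 cast: only this chunk's logits are float32.
        chunk = logits[:, start:end, :].float()
        logps = F.log_softmax(chunk, dim=-1)
        # Logits at position t predict token t + 1.
        targets = input_ids[:, start + 1:end + 1].unsqueeze(-1)
        out[:, start + 1:end + 1] = logps.gather(-1, targets).squeeze(-1)
    return out
```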
Commits (49, all by pjin-nvidia)
a99900f  Port chunked logprob and deferred float32 logits (WIP).
37a027a  Add copy of nemo.tron.model without logits float32 cast.
31707db  Fix.
6a445bc  Ruff + doc comment.
a020289  Configurable deferring float32 logits.
956051c  Update docstrings.
ece8049  Ruff.
670743f  Basic chunking support in logprobs computation with sequence packing.
4dc2aca  Merge remote-tracking branch 'origin/main' into pjin/logprob
f1a9d21  Unit test for chunked logprobs.
df70715  Ruff.
637e131  Pyrefly.
abbf796  Fix test. Pyrefly.
985ba77  Ruff.
13265aa  Stale comment.
e0b8da0  Merge remote-tracking branch 'origin/main' into pjin/logprob
9758b14  Remove unused config.
ea51715  Remove unused config.
1670f93  Also apply to the reference model.
584d8e0  Merge remote-tracking branch 'origin/main' into pjin/logprob
e958d40  Merge remote-tracking branch 'origin/main' into pjin/logprob
8aef5ed  Typed policy configs.
830debe  Bump NeMo submodule.
030ddc5  Remove duplicated nemo.tron.model code and use the new bumped submodule.
9020390  Merge remote-tracking branch 'origin/main' into pjin/logprob
6f014ce  Remove leftover float32 cast. Ruff.
9e5dcd7  Check for float32 logprobs.
b181eb0  Lint.
3cb90c9  Add example config (TODO: functional test for this config).
dde98d9  Add 32K max context Qwen3 30B MoE test run.
651fcdb  Fix deferred fp32 config.
174d1ca  Merge remote-tracking branch 'origin/main' into pjin/logprob
65bbd9b  Rename deferred_fp32_logits => defer_fp32_logits.
edd00ef  chmod +x
a68dfa2  Missing config.
a0044ad  More missing config.
8de985d  Using updated NeMo branch.
c5c83ba  More missing config.
b8a0d7a  Merge remote-tracking branch 'origin/main' into pjin/logprob
194eef6  Merge remote-tracking branch 'origin/main' into pjin/logprob
6865c1a  Merge remote-tracking branch 'origin/main' into pjin/logprob
7100474  Merge remote-tracking branch 'origin/main' into pjin/logprob
da2f305  Lint and minor refactor.
cd5b02a  Fix.
81fb8e1  Unnecessary clone.
ef9d3d5  Remove clone + exp_ with just exp.
b66497b  Merge remote-tracking branch 'origin/main' into pjin/logprob
3d38161  Set HF_HUB_OFFLINE=1 for github CI.
0f0de7d  Fix test.
Submodule NeMo updated from aaefed to 5c4264
examples/configs/recipes/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.yaml (new file, 172 additions, 0 deletions):
checkpointing:
  enabled: True
  checkpoint_dir: results/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  save_period: 3
  keep_top_k: 1
  metric_name: val_reward
  higher_is_better: True
  checkpoint_must_save_by: null

grpo:
  normalize_rewards: True
  use_leave_one_out_baseline: True
  max_num_steps: 3
  num_prompts_per_step: 64
  num_generations_per_prompt: 16
  max_rollout_turns: 1
  val_period: 3
  val_at_start: False
  max_val_samples: 256
  val_batch_size: 256
  seed: 42

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: False
  use_importance_sampling_correction: False
  token_level_loss: True
  ratio_clip_c: null

policy:
  model_name: "Qwen/Qwen3-30B-A3B"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 1
  max_total_sequence_length: 32768
  precision: "bfloat16"
  activation_checkpointing_enabled: True
  logprob_chunk_size: 2048

  dtensor_cfg:
    enabled: False

  dynamic_batching:
    enabled: False

  sequence_packing:
    enabled: False

  max_grad_norm: 1.0
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  scheduler: null # remove default FSDP scheduler

  megatron_cfg:
    enabled: True
    empty_unused_memory_level: 1
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    context_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 8
    sequence_parallel: True
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    freeze_moe_router: True
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True

    optimizer:
      optimizer: "adam"
      lr: 5.0e-7
      min_lr: 5.0e-8
      weight_decay: 0.0
      bf16: True
      fp16: False
      params_dtype: "float32"

      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      use_distributed_optimizer: True
      use_precision_aware_optimizer: True

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 2
      lr_warmup_init: 5.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: False
      overlap_grad_reduce: True
      overlap_param_gather: True
      average_in_collective: True
      use_custom_fsdp: False
      data_parallel_sharding_strategy: "optim_grads_params"

    env_vars:
      PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: False
      precision: ${policy.precision}
      tensor_parallel_size: 4
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857
      enforce_eager: True
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null

data:
  dataset_name: "OpenMathInstruct-2"
  shuffle: true
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null

env:
  math:
    num_workers: 8

logger:
  log_dir: logs/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: True
  tensorboard_enabled: True
  mlflow_enabled: False # Disable MLflow logging
  monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: nemo-rl
    name: "grpo-math-qwen3-30ba3b-megatron-tp4-32k"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 4
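For a rough sense of why chunking and the deferred cast matter at this recipe's 32K context, the sketch below estimates the float32 logits footprint per sequence versus per 2048-token chunk. The ~152K vocabulary size assumed for Qwen/Qwen3-30B-A3B is not stated in this PR and is only an assumption for the arithmetic.

```python
# Back-of-envelope logits memory for this recipe (illustrative only).
seq_len = 32768     # policy.max_total_sequence_length
chunk = 2048        # policy.logprob_chunk_size
vocab = 151_936     # assumed Qwen3 vocabulary size, not taken from this PR

full_fp32_gib = seq_len * vocab * 4 / 2**30   # full-sequence float32 logits
chunk_fp32_gib = chunk * vocab * 4 / 2**30    # one chunk's float32 logits
print(f"full fp32 logits   : {full_fp32_gib:.1f} GiB per sequence")  # ~18.5 GiB
print(f"chunked fp32 logits: {chunk_fp32_gib:.2f} GiB per chunk")    # ~1.16 GiB
```

With `logprob_batch_size: 1`, only one chunk's float32 slice needs to be live at a time during logprob computation, on top of the bf16 logits the model already produces.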
Review comments
what would be the reason to set this to False?
mostly for strict backward compat, but we could instead enable it by default (i.e. make it an opt-out config like no_defer_fp32_logits or similar) wdyt?
I see. How about the following:
wdyt? If the feature is broadly applicable we should probably switch it to true so no one else runs into the same issue (assuming no accuracy penalty)
yup, (1) and then (2) SGTM!