feat: Implement ProRLv2 recipe #1809
Merged
# An In-Depth Walkthrough of ProRLv2 in NeMo RL

This guide covers the ProRLv2 configuration pattern in NeMo RL, based on the example config [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml).

ProRLv2 (as used in this repo) is best thought of as **GRPO plus a bundle of stability/efficiency techniques** commonly used for long-horizon RL fine-tuning:

- **DAPO dynamic sampling**: skip prompt groups with zero reward variance
- **Decoupled (asymmetric) clipping**: `ratio_clip_max > ratio_clip_min`
- **Token-level policy gradient loss**
- **Importance sampling correction with TIS/ICE-POP** (especially helpful for MoE/backend-mismatch scenarios)
- **REINFORCE++: decoupled local/global advantage normalization** (`reinforce_plus_plus`)
- **"Stop properly" penalty** for truncated responses

This document focuses on ProRLv2-specific knobs and gotchas. For foundational concepts in GRPO (data, environments, generation backends, loss/metrics), see the [NeMo RL GRPO Guide](grpo.md). For the original DAPO motivation behind dynamic sampling and overlong shaping, see the [NeMo RL DAPO Guide](dapo.md).
## Quickstart: Launch a ProRLv2 Run

Use the example configuration [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml):

```bash
uv run examples/run_grpo_math.py --config examples/configs/prorlv2.yaml {overrides}
```

`prorlv2.yaml` inherits from [`examples/configs/grpo_math_1B.yaml`](../../examples/configs/grpo_math_1B.yaml) and only overrides a small set of fields under `grpo` and `loss_fn`, plus the output directories.

**Reminder**: Don't forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You'll also need to run `huggingface-cli login` for gated models.
## DAPO: Dynamic Sampling

Standard GRPO trains on all generated responses, even when all of a prompt's `num_generations_per_prompt` responses receive the same reward (no per-prompt learning signal). **Dynamic sampling** keeps only prompt groups with diverse rewards (`std > 0`), and can accumulate across multiple generation batches until it reaches the target rollout batch size.

- **Config**: enable with `grpo.use_dynamic_sampling: true` and tune:
  - `grpo.batch_multiplier`: how many extra prompts to generate to compensate for filtering
  - `grpo.dynamic_sampling_max_gen_batches`: upper bound on generation batches before raising an error
- **Implementation**: see `dynamic_sampling()` in [`nemo_rl/algorithms/grpo.py`](../../nemo_rl/algorithms/grpo.py).
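The filtering step can be sketched as follows. This is a minimal standalone illustration, not NeMo RL's actual API: `filter_zero_variance_groups` is a hypothetical helper, and the real logic (including batch accumulation) lives in `dynamic_sampling()`.

```python
import math

def filter_zero_variance_groups(rewards, num_generations_per_prompt):
    """Keep only prompt groups whose rewards have nonzero std.

    `rewards` is a flat list in which each consecutive chunk of
    `num_generations_per_prompt` entries belongs to one prompt.
    Returns the flat indices of the surviving samples.
    """
    keep = []
    for start in range(0, len(rewards), num_generations_per_prompt):
        group = rewards[start:start + num_generations_per_prompt]
        mean = sum(group) / len(group)
        var = sum((r - mean) ** 2 for r in group) / len(group)
        if math.sqrt(var) > 0.0:  # diverse rewards -> learnable signal
            keep.extend(range(start, start + len(group)))
    return keep

# Two prompts, 4 generations each: the first group is all-correct
# (filtered out), the second is mixed (kept).
rewards = [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0]
print(filter_zero_variance_groups(rewards, 4))  # -> [4, 5, 6, 7]
```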
## Advantage Estimator: REINFORCE++

The ProRLv2 recipe uses **REINFORCE++** advantage estimation instead of the standard GRPO-style group baseline.

Quick intuition:

- REINFORCE++ uses **decoupled local + global normalization**.
- Compared to GRPO-style **local-only normalization**, this decoupling can be **more stable** in longer runs (less sensitivity to per-batch scale/variance shifts).

Computation (as implemented in this repo, with the ProRLv2 example defaults):

```text
Defaults in examples/configs/prorlv2.yaml:
  grpo.adv_estimator.minus_baseline = true
  loss_fn.use_kl_in_reward = false

Steps:
1) Per prompt group, compute the mean reward, then subtract it:
     a_i = r_i - mean_{j in same prompt} r_j

2) Globally normalize across *all valid response tokens* in the batch:
     A <- (A - mean(A)) / sqrt(max(var(A), 1e-8))
```
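At the response level, the two steps can be sketched like this (a simplified illustration, not the repo's implementation: `reinforce_plus_plus_advantages` is a hypothetical function, and the real `ReinforcePlusPlusAdvantageEstimator` broadcasts per token and normalizes over valid tokens):

```python
import numpy as np

def reinforce_plus_plus_advantages(rewards, group_size, eps=1e-8):
    """Local baseline subtraction followed by global normalization."""
    r = np.asarray(rewards, dtype=np.float64).reshape(-1, group_size)
    a = r - r.mean(axis=1, keepdims=True)            # 1) per-prompt baseline
    a = (a - a.mean()) / np.sqrt(max(a.var(), eps))  # 2) global normalization
    return a.reshape(-1)

# Two prompt groups of 3: one mixed-outcome group, one all-correct group.
adv = reinforce_plus_plus_advantages([1.0, 0.0, 0.0, 1.0, 1.0, 1.0], group_size=3)
print(adv)  # approximately [2, -1, -1, 0, 0, 0]
```

Note how the all-correct group contributes zero advantage everywhere, which is exactly the signal that dynamic sampling filters on.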
```yaml
grpo:
  adv_estimator:
    name: "reinforce_plus_plus"
    normalize_rewards: true
    use_leave_one_out_baseline: false
    minus_baseline: true
```

- **Config**: `grpo.adv_estimator.name: "reinforce_plus_plus"`
- **Implementation**: the training loop wires this up via `ReinforcePlusPlusAdvantageEstimator` in [`nemo_rl/algorithms/grpo.py`](../../nemo_rl/algorithms/grpo.py).
- **Reference**: [REINFORCE++ paper](https://arxiv.org/abs/2501.03262)
## Reward Shaping: "Stop Properly" Penalty (Truncation Penalty)

When a generation hits the maximum length without emitting EOS, many pipelines mark it as **truncated**. The "stop properly" penalty scales the reward for truncated samples:

- `stop_properly_penalty_coef = 0.0`: truncated samples get **zero reward**
- `stop_properly_penalty_coef = 1.0`: **no penalty** (keep original rewards)
- Any value in \([0, 1]\) interpolates between the two.

In the example config:

```yaml
grpo:
  reward_shaping:
    enabled: true
    stop_properly_penalty_coef: 0.0
```
- **Implementation**: `apply_reward_shaping()` in [`nemo_rl/algorithms/reward_functions.py`](../../nemo_rl/algorithms/reward_functions.py).

:::{important}
In the current implementation, if `stop_properly_penalty_coef` is set (not `null`), `apply_reward_shaping()` **returns early** after applying truncation scaling. That means you **cannot** apply DAPO "overlong reward shaping" in the same run unless you set `stop_properly_penalty_coef: null` and provide the DAPO overlong parameters (`overlong_buffer_length`, `overlong_buffer_penalty`, `max_response_length`).
:::
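The scaling itself is a one-liner. The sketch below is illustrative only (`scale_truncated_rewards` is a hypothetical name; the real code is in `apply_reward_shaping()`):

```python
def scale_truncated_rewards(rewards, truncated, coef):
    """Scale the reward of each truncated sample by coef in [0, 1].

    coef=0.0 zeroes truncated rewards; coef=1.0 leaves them untouched.
    """
    assert 0.0 <= coef <= 1.0
    return [r * coef if t else r for r, t in zip(rewards, truncated)]

rewards = [1.0, 1.0, 0.0]
truncated = [False, True, False]
print(scale_truncated_rewards(rewards, truncated, 0.0))  # -> [1.0, 0.0, 0.0]
```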
## Loss: Decoupled (Asymmetric) Clipping

ProRLv2 adopts DAPO's "decoupled clipping" idea by setting different lower/upper clip bounds:

```yaml
loss_fn:
  ratio_clip_min: 0.2
  ratio_clip_max: 0.27
```

This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than contraction region, which can help exploration and reduce early collapse.

- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
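The clipped objective with asymmetric bounds can be sketched as follows (an illustrative, loss-sign-free version of the standard PPO objective; `decoupled_clip_objective` is a hypothetical name, not the `ClippedPGLossFn` API):

```python
import numpy as np

def decoupled_clip_objective(ratio, advantage, clip_min=0.2, clip_max=0.27):
    """Per-token PPO-style clipped surrogate with asymmetric bounds.

    The ratio is clipped to [1 - clip_min, 1 + clip_max]; the larger
    upper bound widens the trust region for increasing the probability
    of positive-advantage tokens.
    """
    clipped = np.clip(ratio, 1.0 - clip_min, 1.0 + clip_max)
    return np.minimum(ratio * advantage, clipped * advantage)

# With A > 0, ratios up to 1.27 still pass gradient; above that they clip.
print(decoupled_clip_objective(np.array([1.3]), np.array([1.0])))  # -> [1.27]
print(decoupled_clip_objective(np.array([0.7]), np.array([1.0])))  # -> [0.7]
```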
## Loss: Token-Level Policy Gradient

ProRLv2 enables token-level loss:

```yaml
loss_fn:
  token_level_loss: true
```

This computes the policy-gradient loss per token (under masking) instead of aggregating per sequence, which is often helpful for long CoT and variable-length rollouts.
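The difference between the two aggregation modes is easiest to see numerically. This is a simplified sketch (hypothetical `pg_loss` helper, plain REINFORCE-style term without clipping), assuming one row per sequence and a 0/1 validity mask:

```python
import numpy as np

def pg_loss(logprobs, advantages, mask, token_level=True):
    """Masked policy-gradient loss with token- or sequence-level averaging.

    token_level=True: average -A * logprob over all valid tokens in the
    batch (long sequences contribute more terms).
    token_level=False: average within each sequence first, then across
    sequences (each sequence counts equally).
    """
    per_token = -advantages * logprobs * mask
    if token_level:
        return per_token.sum() / mask.sum()
    per_seq = per_token.sum(axis=1) / np.maximum(mask.sum(axis=1), 1)
    return per_seq.mean()

lp = np.array([[-1.0, -1.0, -1.0, -1.0],   # 4 valid tokens
               [-2.0,  0.0,  0.0,  0.0]])  # 1 valid token
adv = np.ones_like(lp)
mask = np.array([[1.0, 1.0, 1.0, 1.0],
                 [1.0, 0.0, 0.0, 0.0]])
print(pg_loss(lp, adv, mask, token_level=True))   # -> 1.2  (6 / 5 tokens)
print(pg_loss(lp, adv, mask, token_level=False))  # -> 1.5  (mean of 1 and 2)
```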
## Truncated Importance Sampling

When the training and generation backends differ (e.g., in numerics, precision, MoE routing, or vLLM vs. the training framework), you may see a mismatch between:

- `generation_logprobs`: logprobs under the generation backend that produced the samples
- `prev_logprobs`: logprobs under the training-framework policy

NeMo RL supports **importance sampling correction**, and ProRLv2's example config turns it on together with **truncated importance sampling**.

Quick intuition:

- This is mainly useful for **MoE/backend-mismatch** cases, where the generation backend and the training policy can disagree on logprobs.
- We compute an importance weight from `prev_logprobs` (training policy) vs. `generation_logprobs` (generator). **ICE-POP** drops outliers by zeroing weights outside \([min, max]\).
- In the common setup of **one policy update per rollout batch** (i.e., the minibatch equals the per-step rollout batch; no PPO multi-epoch reuse), the PPO/GRPO likelihood-ratio term is effectively **1.0** at update time, so the main stability issue is the MoE/backend-mismatch importance weights.
- "Online ICE-POP" here just means applying the ICE-POP filtering **during loss computation** on the current training batch.

- **Reference**: [The Online IcePop Solution for MoE models](https://hijkzzz.notion.site/online-ice-pop)
```yaml
loss_fn:
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 5.0
  truncated_importance_sampling_ratio_min: 0.5
  truncated_importance_sampling_type: "icepop"
```

- **`use_importance_sampling_correction`**: enable token-level importance weights (must be `true` for truncated IS)
- **`truncated_importance_sampling_ratio`**: upper bound (or upper threshold)
- **`truncated_importance_sampling_ratio_min`**: lower bound used by ICE-POP filtering
- **`truncated_importance_sampling_type`**:
  - `"tis"`: clamp weights to `<= truncated_importance_sampling_ratio`
  - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers)

- **Implementation**: see the `ClippedPGLossFn` init-time checks and loss logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).
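The two truncation modes can be sketched on synthetic logprobs (an illustration only; `importance_weights` is a hypothetical helper, and the real checks live in `ClippedPGLossFn`):

```python
import numpy as np

def importance_weights(prev_logprobs, generation_logprobs,
                       ratio_max=5.0, ratio_min=0.5, kind="icepop"):
    """Token-level IS weights between the training and generation backends.

    "tis":    clamp weights to <= ratio_max.
    "icepop": zero out weights falling outside [ratio_min, ratio_max].
    """
    w = np.exp(prev_logprobs - generation_logprobs)
    if kind == "tis":
        return np.minimum(w, ratio_max)
    if kind == "icepop":
        return np.where((w < ratio_min) | (w > ratio_max), 0.0, w)
    raise ValueError(f"unknown type: {kind}")

# Synthetic mismatch: training/generation probability ratios of 1, 2, 8, 0.25.
prev = np.log(np.array([1.0, 2.0, 8.0, 0.25]))
gen = np.zeros(4)
print(importance_weights(prev, gen, kind="tis"))     # approx [1, 2, 5, 0.25]
print(importance_weights(prev, gen, kind="icepop"))  # approx [1, 2, 0, 0]
```

Note that `"tis"` still trains on outlier tokens (with a capped weight), while `"icepop"` removes them from the gradient entirely.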
## Full Example Config (Annotated)

The ProRLv2 example config is intentionally small and relies on defaults from `grpo_math_1B.yaml`.

- **Example config**: [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml)
- **Base defaults**: [`examples/configs/grpo_math_1B.yaml`](../../examples/configs/grpo_math_1B.yaml)
## Practical Overrides

A few common overrides when launching:

```bash
uv run examples/run_grpo_math.py \
  --config examples/configs/prorlv2.yaml \
  policy.model_name="Qwen/Qwen2.5-1.5B" \
  logger.wandb_enabled=true \
  logger.wandb.project="prorlv2-dev" \
  checkpointing.checkpoint_dir="results/prorlv2" \
  logger.log_dir="logs/prorlv2"
```

To enable DAPO overlong reward shaping instead of the stop-properly penalty:

```bash
uv run examples/run_grpo_math.py \
  --config examples/configs/prorlv2.yaml \
  grpo.reward_shaping.stop_properly_penalty_coef=null \
  grpo.reward_shaping.overlong_buffer_length=4096 \
  grpo.reward_shaping.overlong_buffer_penalty=1.0 \
  grpo.reward_shaping.max_response_length=20480
```
## What to Monitor

In addition to task rewards/accuracy, a few stability signals are particularly useful in ProRLv2-style runs:

- **Dynamic sampling efficiency**: if enabled, watch how often batches need multiple generation rounds (see `dapo.md` for detailed guidance).
- **Training–generation mismatch**: `token_mult_prob_error`, `gen_kl_error`, `policy_kl_error`, and `js_divergence_error` are computed in `ClippedPGLossFn` (see the [GRPO metrics section](grpo.md#metrics)).
- **Truncation rate**: if high, either increase `policy.max_total_sequence_length`/`policy.generation.max_model_len` or relax the truncation penalty (`stop_properly_penalty_coef`).
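The truncation rate is simple to derive from rollout metadata. A minimal sketch, assuming vLLM-style `"stop"`/`"length"` finish reasons (the helper name and input format are hypothetical):

```python
def truncation_rate(finish_reasons):
    """Fraction of rollouts that hit the length limit instead of stopping."""
    if not finish_reasons:
        return 0.0
    return sum(fr == "length" for fr in finish_reasons) / len(finish_reasons)

print(truncation_rate(["stop", "length", "stop", "stop"]))  # -> 0.25
```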
## References

- **ProRLv2 blog**: [Scaling LLM Reinforcement Learning with Prolonged Training using ProRL v2](https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/)
- **DAPO**: [Decoupled Clip and Dynamic Sampling Policy Optimization](https://arxiv.org/pdf/2503.14476)
- **GRPO**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300)
- **REINFORCE++**: [REINFORCE++](https://arxiv.org/abs/2501.03262)
- **DLER (stop-properly penalty explanation)**: [DLER](https://arxiv.org/pdf/2510.15110)
- [NeMo RL GRPO Guide](grpo.md)
- [NeMo RL DAPO Guide](dapo.md)
`examples/configs/prorlv2.yaml`:

```yaml
# ProRLv2 Algorithm Configuration
#
# This configuration implements ProRLv2 with TIS techniques:
# - Dynamic Sampling: filter prompts with zero reward variance
# - Decoupled Clipping: asymmetric ratio clipping (clip_max > clip_min)
# - Token-level Loss: fine-grained policy gradient
# - Truncated Importance Sampling (TIS) / IcePop for MoE models
# - REINFORCE++: decoupled local and global advantage normalization estimator
# - Stop properly penalty: reward scale coefficient for truncated responses
#
# Inherits from grpo_math_1B.yaml
#
# Usage:
#   python examples/run_grpo_math.py --config examples/configs/prorlv2.yaml
#
# Reference papers and blogs:
#   ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/
#   REINFORCE++: https://arxiv.org/abs/2501.03262
#   The Online IcePop Solution for MoE models: https://hijkzzz.notion.site/online-ice-pop
#   DLER (for the stop properly penalty): https://arxiv.org/pdf/2510.15110

defaults: "grpo_math_1B.yaml"

grpo:
  # ==========================================================================
  # DAPO: Dynamic Sampling
  # Filter out prompts where all generations have the same reward (std=0).
  # This focuses training on "learnable" examples with mixed outcomes.
  # ==========================================================================
  use_dynamic_sampling: true
  dynamic_sampling_max_gen_batches: 10  # Max generation batches before error
  batch_multiplier: 1.5  # Generate more prompts to account for filtering

  # ==========================================================================
  # Advantage Estimator
  # Options: "grpo" (default) or "reinforce_plus_plus"
  # ==========================================================================
  adv_estimator:
    name: "reinforce_plus_plus"  # Use "grpo" for standard GRPO
    # Global normalization of rewards
    normalize_rewards: true
    use_leave_one_out_baseline: false
    # REINFORCE++-baseline specific
    minus_baseline: true

  # ==========================================================================
  # Reward Shaping
  # Applied to rewards before advantage calculation.
  # Includes the DAPO overlong penalty and the stop properly penalty.
  # ==========================================================================
  reward_shaping:
    enabled: true
    # Stop properly penalty: scale factor for truncated responses (0-1)
    # 0 = zero reward for truncated (default), 1 = no penalty
    stop_properly_penalty_coef: 0.0  # Set to e.g. 0.5 to halve truncated rewards

# ============================================================================
# Loss Function Configuration
# ============================================================================
loss_fn:
  # KL regularization
  reference_policy_kl_penalty: 0.0001
  reference_policy_kl_type: "k2"
  kl_input_clamp_value: 20.0
  kl_output_clamp_value: 10.0

  # ==========================================================================
  # DAPO: Decoupled (Asymmetric) Clipping
  # ratio_clip_max > ratio_clip_min allows more exploration.
  # Standard PPO uses symmetric clipping (both = 0.2).
  # ==========================================================================
  ratio_clip_min: 0.2
  ratio_clip_max: 0.27  # Slightly larger for exploration

  # Dual-clipping (set to e.g. 3.0 to enable, null to disable)
  ratio_clip_c: null

  # ==========================================================================
  # DAPO: Token-level Loss
  # Compute loss per token instead of per sequence.
  # ==========================================================================
  token_level_loss: true

  # ==========================================================================
  # Truncated Importance Sampling (TIS / ICE-POP)
  # Requires use_importance_sampling_correction: true
  # ==========================================================================
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 5.0  # Upper bound
  truncated_importance_sampling_ratio_min: 0.5  # Lower bound (ICE-POP only)
  # Type: "tis" (clamp to max) or "icepop" (filter outside [min, max])
  truncated_importance_sampling_type: "icepop"

  # REINFORCE++: add KL penalty to reward instead of loss
  # Set to false to use the external KL loss (reference_policy_kl_penalty) for better stability
  use_kl_in_reward: false

# ============================================================================
# Output directories
# ============================================================================
checkpointing:
  checkpoint_dir: "results/prorl"

logger:
  log_dir: "logs/prorl"
```
`examples/configs/recipes/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml` (29 additions):

```yaml
defaults: ../../prorlv2.yaml
grpo:
  max_num_steps: 450
checkpointing:
  checkpoint_dir: results/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1
policy:
  model_name: Qwen/Qwen2.5-Math-1.5B-Instruct
  tokenizer:
    name: Qwen/Qwen2.5-Math-1.5B-Instruct
  dynamic_batching:
    enabled: true
  sequence_packing:
    enabled: false
  make_sequence_length_divisible_by: 1
  generation:
    max_new_tokens: 512
    vllm_cfg:
      max_model_len: 512
data:
  max_input_seq_length: 512
logger:
  log_dir: logs/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1
cluster:
  gpus_per_node: 8
```