feat: refactor train utilities for dtensor policy v2 #1757
📝 Walkthrough

Introduces a modular post-processing framework for automodel training with DTensors. Adds a sequence dimensionality validator in data.py, creates a comprehensive training utilities module (train.py) providing forward/backward orchestration and post-processors (Loss, Logprobs, TopK-Logits, Score), and refactors dtensor_policy_worker_v2.py to use these centralized components instead of ad-hoc logic.
Sequence Diagram

```mermaid
sequenceDiagram
    participant DataIter as Data Iterator
    participant FwdBwd as automodel_forward_backward()
    participant MB as ProcessedMicrobatch
    participant Model as Model.forward()
    participant PostProc as PostProcessor.__call__()
    participant Loss as Loss Computation
    participant Backward as .backward()
    DataIter->>FwdBwd: iterate ProcessedMicrobatch
    loop For each microbatch
        FwdBwd->>MB: get processed_mb
        FwdBwd->>Model: model_forward(processed_mb)
        Model-->>FwdBwd: logits
        FwdBwd->>PostProc: __call__(logits, mb, ...)
        PostProc->>Loss: compute loss
        Loss-->>PostProc: loss_tensor
        PostProc-->>FwdBwd: (result, metrics)
        alt If not forward_only
            FwdBwd->>Backward: loss.backward()
            Backward-->>FwdBwd: gradients computed
        end
        FwdBwd->>FwdBwd: collect metrics & results
    end
    FwdBwd-->>DataIter: list[(result, metrics), ...]
```
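The loop in the diagram can be sketched in plain Python. The function name mirrors the PR description, but the signature and the loss/post-processor protocol shown here are illustrative, not the real nemo_rl API:

```python
# Minimal sketch of the forward/backward orchestration described above.
# Assumes post_processing_fn returns a (loss-like result, metrics dict) pair.
def automodel_forward_backward(model_forward, post_processing_fn,
                               microbatches, forward_only=False):
    results = []
    for processed_mb in microbatches:
        logits = model_forward(processed_mb)               # Model.forward()
        result, metrics = post_processing_fn(logits, processed_mb)
        if not forward_only:
            result.backward()                              # loss.backward()
        results.append((result, metrics))                  # collect per microbatch
    return results                                         # list[(result, metrics), ...]
```

The real module also handles DTensor redistribution and metric aggregation; this sketch only captures the control flow.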
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/models/automodel/data.py (1)
1-13: Update the copyright year to 2026. This is a non‑test Python file modified in this PR; the header still shows 2025.
🛠️ Proposed fix
```diff
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "The NVIDIA copyright header should include the current year".
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
1-13: Update the copyright year to 2026. This non‑test file was modified but still uses the 2025 header.
🛠️ Proposed fix
```diff
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
```

As per coding guidelines, "The NVIDIA copyright header should include the current year".
🤖 Fix all issues with AI agents
In `@nemo_rl/models/automodel/data.py`:
- Around line 362-366: The loop in check_sequence_dim uses an unused variable
named k which triggers lint warnings; rename k to _ or _key in the for loop
header (for _key, v in data.items() or for _, v in data.items()) so the unused
key is clear to linters while keeping the logic that checks torch tensors and
sequence_dim size in the existing block (references: function
check_sequence_dim, loop variables k and v, variable sequence_dim and
seq_dim_size).
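The suggested rename can be illustrated with a simplified version of the validator. The real check_sequence_dim in nemo_rl/models/automodel/data.py operates on torch tensors; a duck-typed `.shape` check stands in for that here:

```python
# Illustrative sketch: the unused dict key is renamed to `_key` so the
# loop no longer trips unused-variable lint rules.
def check_sequence_dim(data, sequence_dim=1):
    seq_dim_size = None
    for _key, v in data.items():  # `_key` makes the unused key clear to linters
        shape = getattr(v, "shape", None)  # stand-in for isinstance(v, torch.Tensor)
        if shape is None or len(shape) <= sequence_dim:
            continue
        if seq_dim_size is None:
            seq_dim_size = shape[sequence_dim]
        elif shape[sequence_dim] != seq_dim_size:
            raise ValueError(f"sequence dim mismatch for key {_key!r}")
    return sequence_dim, seq_dim_size
```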
In `@nemo_rl/models/automodel/train.py`:
- Around line 1-13: Update the copyright header in
nemo_rl/models/automodel/train.py by changing the year from 2025 to 2026 at the
top of the file; locate the existing Apache License block in train.py and
replace the "2025" year token with "2026" so the file header reflects the
current year.
- Around line 294-306: forward_with_post_processing_fn currently applies
apply_temperature_scaling(logits, cfg) before calling post_processing_fn, which
causes LossPostProcessor and ScorePostProcessor to see temperature‑scaled
logits; change the logic so temperature scaling is only applied for
sampling-oriented post processors (e.g., those used for logprob/top‑k sampling
or DistributedLogprobWithSampling) or gated by an explicit flag (e.g.,
cfg.apply_temperature_for_sampling). Locate the scaling call around
forward_with_post_processing_fn and replace it with a conditional that checks
the post_processing_fn type (exclude LossPostProcessor and ScorePostProcessor)
or checks the new cfg flag, ensuring LossPostProcessor and ScorePostProcessor
receive unscaled logits while sampling post processors still get
temperature‑scaled logits.
- Around line 140-176: The cp_mesh parameter in redistribute_logits_for_cp is
unused and triggers lint ARG001; keep the signature for API stability but
explicitly mark it unused by adding a no-op reference such as "_ = cp_mesh" or
"del cp_mesh" at the start of the function (with a short comment like "#
intentionally unused" or "# silence ARG001") so linters stop complaining while
behavior remains unchanged.
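The temperature-gating fix suggested above can be sketched as a type check on the post-processor. The class names come from the review comments; the scaling helper below is a stand-in for the real apply_temperature_scaling:

```python
# Hypothetical gating logic: loss/score paths see raw logits, while
# sampling-oriented paths (logprobs, top-k) get temperature-scaled logits.
class LossPostProcessor: pass
class ScorePostProcessor: pass
class LogprobsPostProcessor: pass  # sampling-oriented

UNSCALED_PROCESSORS = (LossPostProcessor, ScorePostProcessor)

def maybe_scale_logits(logits, post_processing_fn, temperature):
    if isinstance(post_processing_fn, UNSCALED_PROCESSORS):
        return logits                           # unscaled for loss/score
    return [x / temperature for x in logits]    # scaled for sampling paths
```

An explicit config flag (e.g., the cfg.apply_temperature_for_sampling mentioned in the comment) would work equally well in place of the isinstance check.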
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`:
- Around line 301-303: The unused variable warnings come from unpacking
check_sequence_dim(data) into sequence_dim, seq_dim_size and from unused metrics
in post-processing forward calls; to silence Ruff, rename unused bindings by
prefixing with an underscore (e.g., use sequence_dim, _seq_dim_size =
check_sequence_dim(data)) and rename unused metrics to _metrics in the
post-processing forward invocations (or assign to _metrics where returned), and
apply the same pattern wherever seq_dim_size or metrics are currently unused in
train(), score(), and the post-processing forward calls in
DTensorPolicyWorkerV2.
- Around line 339-341: The warning emitted in dtensor_policy_worker_v2 (inside
the code that calls warnings.warn) has a spelling mistake ("unnnecessarily") and
is missing a stacklevel; update the warnings.warn call in
dtensor_policy_worker_v2 (the cache-clearing warning) to correct the message to
"unnecessarily" and pass stacklevel=2 so the warning points to the caller (i.e.,
change the warnings.warn invocation used for emptying cache every
{empty_cache_steps} microbatches to include the corrected string and
stacklevel=2).
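The corrected warning call can be sketched as follows; the surrounding cache-clearing logic and the empty_cache_steps name follow the review comment, but the helper itself is illustrative:

```python
import warnings

def warn_empty_cache(empty_cache_steps):
    # Corrected spelling ("unnecessarily") and stacklevel=2 so the warning
    # is attributed to the caller rather than this helper.
    warnings.warn(
        f"Emptying cache every {empty_cache_steps} microbatches may "
        "unnecessarily slow down training.",
        stacklevel=2,
    )
```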
In `@tests/unit/models/automodel/test_automodel_train.py`:
- Line 147: Tests assign unused variables (e.g., result, loss, metrics) from
calls like model_forward(...) which Ruff flags; remove the unused assignments by
either calling the function without assignment (model_forward(...)) or renaming
the targets to throwaway names (_ , _loss, _metrics) wherever you see result,
loss, or metrics being assigned (including uses of model_forward and any
train/metric-returning helpers referenced in the diff) so the values are still
produced but the linter no longer reports unused variables.
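The throwaway-name pattern for the tests can be shown with a stub standing in for model_forward; the real test lives in test_automodel_train.py and returns richer objects:

```python
# Stub returning a (result, loss, metrics) triple, as in the flagged tests.
def model_forward(batch):
    return batch * 2, 0.5, {"acc": 1.0}

def test_model_forward_runs():
    _result, _loss, _metrics = model_forward(2)  # prefixed names satisfy Ruff
    model_forward(3)  # or drop the assignment entirely when values are unchecked
```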
🧹 Nitpick comments (1)
nemo_rl/models/automodel/train.py (1)
216-218: Prefer `isinstance` over `type(...) == torch.Tensor`.

`type(current_tensor) == torch.Tensor` rejects subclasses (e.g., `torch.nn.Parameter`). Use `isinstance` for idiomatic and safer checks.

🛠️ Proposed fix

```diff
-        assert type(current_tensor) == torch.Tensor, (
+        assert isinstance(current_tensor, torch.Tensor), (
             f"tensor {tensor_name} is not a tensor"
         )
```
Signed-off-by: Hemil Desai <hemild@nvidia.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
Devin review (helpful for review): https://app.devin.ai/review/NVIDIA-NeMo/RL/pull/1757
Nightlies:
dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2: https://wandb.ai/nvidia/nemo-rl/runs/fjrjhcsc
grpo-moonlight-16b-automodel-1n8g-ep8: https://wandb.ai/nvidia/nemo-rl/runs/gq9wibah
grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3: https://wandb.ai/nvidia/nemo-rl/runs/r8qaeedc
grpo-qwen3-8B-base-1n8g-fsdp2-lora: https://wandb.ai/nvidia/nemo-rl/runs/xf11uk3t
sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel: https://wandb.ai/nvidia/nemo-rl/runs/qsoscd63
sft-llama3.1-8b-1n8g-fsdp2tp1-lora: https://wandb.ai/nvidia/nemo-rl/runs/n5anaa2b
Issues
#1589