
feat: refactor train utilities for dtensor policy v2#1757

Merged
terrykong merged 6 commits into main from hemil/automodel-train-refactor on Feb 6, 2026

Conversation


@hemildesai hemildesai commented Jan 10, 2026

Devin review (helpful context for reviewers): https://app.devin.ai/review/NVIDIA-NeMo/RL/pull/1757

Nightlies:

dpo-llama3.1-8b-instruct-4n8g-fsdp2tp2-quick.v2: https://wandb.ai/nvidia/nemo-rl/runs/fjrjhcsc

grpo-moonlight-16b-automodel-1n8g-ep8: https://wandb.ai/nvidia/nemo-rl/runs/gq9wibah

grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3: https://wandb.ai/nvidia/nemo-rl/runs/r8qaeedc

grpo-qwen3-8B-base-1n8g-fsdp2-lora: https://wandb.ai/nvidia/nemo-rl/runs/xf11uk3t

sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel: https://wandb.ai/nvidia/nemo-rl/runs/qsoscd63

sft-llama3.1-8b-1n8g-fsdp2tp1-lora: https://wandb.ai/nvidia/nemo-rl/runs/n5anaa2b

Issues

#1589

Before your PR is "Ready for review"

Pre-checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced centralized training utilities with modular post-processing framework supporting loss computation, log-probability extraction, top-k logits, and score calculation
    • Enhanced training infrastructure with improved support for multimodal inputs and advanced parallel configurations
    • Added validation for model architecture consistency
  • Tests

    • Added comprehensive test coverage for training components and distributed execution scenarios


@github-actions

⚠️ File Consistency Check

Check based on commit: 8dbdbef (PR #1757 from hemil/automodel-train-refactor)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@hemildesai hemildesai force-pushed the hemil/automodel-data-refactor branch from 2959be7 to be28e23 on January 20, 2026 20:10
@hemildesai hemildesai mentioned this pull request Jan 22, 2026
@hemildesai hemildesai force-pushed the hemil/automodel-data-refactor branch from 1ff4e30 to aba0a9d on January 27, 2026 20:23
Base automatically changed from hemil/automodel-data-refactor to main January 28, 2026 06:42
@hemildesai hemildesai force-pushed the hemil/automodel-train-refactor branch from 8dbdbef to 2c65235 on January 28, 2026 18:52
@github-actions
⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit 2c65235.

@hemildesai hemildesai marked this pull request as ready for review January 28, 2026 18:53
@hemildesai hemildesai requested review from a team as code owners January 28, 2026 18:53
@hemildesai hemildesai added the CI:L1 Run doctests, unit tests, and functional tests label Jan 28, 2026
@coderabbitai
Contributor

coderabbitai bot commented Jan 28, 2026

📝 Walkthrough

Walkthrough

Introduces a modular post-processing framework for automodel training with DTensors. Adds a sequence dimensionality validator in data.py, creates a comprehensive training utilities module (train.py) providing forward/backward orchestration and post-processors (Loss, Logprobs, TopK-Logits, Score), and refactors dtensor_policy_worker_v2.py to use these centralized components instead of ad-hoc logic.
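The post-processors named above share one callable shape: each receives the model's logits plus the microbatch and returns a `(result, metrics)` pair. A minimal sketch of that interface, with illustrative names and signatures (the real classes in nemo_rl/models/automodel/train.py may differ):

```python
from typing import Any, Callable, Protocol


class PostProcessor(Protocol):
    """Common shape of the Loss/Logprobs/TopkLogits/Score post-processors.

    Sketch only; the merged signatures may carry extra arguments.
    """

    def __call__(self, logits: Any, mb: dict, **kwargs: Any) -> tuple[Any, dict]: ...


class LossPostProcessor:
    """Illustrative loss post-processor: delegates to an injected loss_fn and
    reports the scalar loss as a microbatch metric."""

    def __init__(self, loss_fn: Callable[[Any, dict], float]) -> None:
        self.loss_fn = loss_fn

    def __call__(self, logits: Any, mb: dict, **kwargs: Any) -> tuple[Any, dict]:
        loss = self.loss_fn(logits, mb)
        return loss, {"loss": float(loss)}
```

A Logprobs or TopK variant would follow the same protocol, differing only in what it computes from the logits.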

Changes

Cohort / File(s) Summary
Core automodel training infrastructure
nemo_rl/models/automodel/data.py
Added check_sequence_dim() validation function to assert consistent sequence dimensionality across batched tensors; returns sequence dimension and size.
Core automodel training infrastructure
nemo_rl/models/automodel/train.py
New training utilities module introducing model forward/logits extraction, temperature scaling, DTensor preparation for context/tensor parallelism, and modular post-processors (LossPostProcessor, LogprobsPostProcessor, TopkLogitsPostProcessor, ScorePostProcessor). Provides automodel_forward_backward() and forward_with_post_processing_fn() for unified training loop orchestration with support for sequence packing, distributed resharding, and microbatch metrics.
Policy worker integration
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Refactored training, get_logprobs, score, and get_topk_logits methods to use new post-processor classes and centralized training functions; removed inline temperature scaling, DTensor handling, and loss computation logic; added sequence dimension validation via check_sequence_dim(), train-context factory, and optional CUDA cache-clearing hook via on_microbatch_start.
Unit tests
tests/unit/models/automodel/test_automodel_train.py
Comprehensive test suite covering model_forward, extract_logits, temperature scaling, all post-processors, sequence packing integration, DTensor redistribution, microbatch processing, and end-to-end orchestration with mock models and data iterators.
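The check_sequence_dim() addition in data.py can be pictured roughly as follows. This is a duck-typed sketch: the real implementation checks isinstance(v, torch.Tensor), and its exact signature may differ.

```python
def check_sequence_dim(data: dict, sequence_dim: int = 1) -> tuple[int, int]:
    """Assert every batched tensor agrees on the sequence-dimension size.

    Returns (sequence_dim, seq_dim_size). Sketch only; here anything with a
    .shape attribute qualifies, standing in for torch.Tensor.
    """
    seq_dim_size = None
    for _key, v in data.items():  # underscore prefix: key itself is unused
        shape = getattr(v, "shape", None)
        if shape is None or len(shape) <= sequence_dim:
            continue  # skip non-tensors and tensors without a sequence dim
        if seq_dim_size is None:
            seq_dim_size = shape[sequence_dim]
        else:
            assert shape[sequence_dim] == seq_dim_size, (
                f"inconsistent sequence dim: {shape[sequence_dim]} != {seq_dim_size}"
            )
    return sequence_dim, seq_dim_size
```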

Sequence Diagram

```mermaid
sequenceDiagram
    participant DataIter as Data Iterator
    participant FwdBwd as automodel_forward_backward()
    participant MB as ProcessedMicrobatch
    participant Model as Model.forward()
    participant PostProc as PostProcessor.__call__()
    participant Loss as Loss Computation
    participant Backward as .backward()

    DataIter->>FwdBwd: iterate ProcessedMicrobatch
    loop For each microbatch
        FwdBwd->>MB: get processed_mb
        FwdBwd->>Model: model_forward(processed_mb)
        Model-->>FwdBwd: logits
        FwdBwd->>PostProc: __call__(logits, mb, ...)
        PostProc->>Loss: compute loss
        Loss-->>PostProc: loss_tensor
        PostProc-->>FwdBwd: (result, metrics)
        alt If not forward_only
            FwdBwd->>Backward: loss.backward()
            Backward-->>FwdBwd: gradients computed
        end
        FwdBwd->>FwdBwd: collect metrics & results
    end
    FwdBwd-->>DataIter: list[(result, metrics), ...]
```
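
The loop in the diagram can be sketched as follows. This is simplified: the real automodel_forward_backward() also handles sequence packing, distributed resharding, and the on_microbatch_start hook.

```python
def automodel_forward_backward(model_forward, data_iter, post_processing_fn,
                               forward_only=False):
    """Simplified sketch of the orchestration shown in the sequence diagram."""
    results = []
    for processed_mb in data_iter:            # iterate ProcessedMicrobatch items
        logits = model_forward(processed_mb)  # Model.forward
        result, metrics = post_processing_fn(logits, processed_mb)
        if not forward_only:
            result.backward()                 # result is the loss tensor here
        results.append((result, metrics))     # collect metrics & results
    return results
```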

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

CI:L1

Suggested reviewers

  • terrykong
  • yuki-97
🚥 Pre-merge checks | ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)

  • Docstring Coverage ⚠️ — docstring coverage is 31.88%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Test Results For Major Changes ⚠️ — the PR implements a major refactor with significant new functionality but lacks a concrete description, test results, performance benchmarks, and evidence of no regressions. Resolution: update the PR description with unit test results, comprehensive test suite outcomes, evidence of preserved numeric/convergence behavior, and performance impact analysis.

✅ Passed checks (2 passed)

  • Title check ✅ — the title accurately summarizes the main change: refactoring train utilities for dtensor policy v2, which aligns with the new train utilities module, post-processing classes, and policy worker integration.
  • Description Check ✅ — check skipped; CodeRabbit's high-level summary is enabled.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
nemo_rl/models/automodel/data.py (1)

1-13: Update the copyright year to 2026.

This is a non‑test Python file modified in this PR; the header still shows 2025.

🛠️ Proposed fix

```diff
-# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
```

As per coding guidelines "The NVIDIA copyright header should include the current year".

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

1-13: Update the copyright year to 2026.

This non‑test file was modified but still uses the 2025 header.

🛠️ Proposed fix

```diff
-# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION.  All rights reserved.
```

As per coding guidelines "The NVIDIA copyright header should include the current year".

🤖 Fix all issues with AI agents
In `@nemo_rl/models/automodel/data.py`:
- Around line 362-366: The loop in check_sequence_dim uses an unused variable
named k which triggers lint warnings; rename k to _ or _key in the for loop
header (for _key, v in data.items() or for _, v in data.items()) so the unused
key is clear to linters while keeping the logic that checks torch tensors and
sequence_dim size in the existing block (references: function
check_sequence_dim, loop variables k and v, variable sequence_dim and
seq_dim_size).

In `@nemo_rl/models/automodel/train.py`:
- Around line 1-13: Update the copyright header in
nemo_rl/models/automodel/train.py by changing the year from 2025 to 2026 at the
top of the file; locate the existing Apache License block in train.py and
replace the "2025" year token with "2026" so the file header reflects the
current year.
- Around line 294-306: forward_with_post_processing_fn currently applies
apply_temperature_scaling(logits, cfg) before calling post_processing_fn, which
causes LossPostProcessor and ScorePostProcessor to see temperature‑scaled
logits; change the logic so temperature scaling is only applied for
sampling-oriented post processors (e.g., those used for logprob/top‑k sampling
or DistributedLogprobWithSampling) or gated by an explicit flag (e.g.,
cfg.apply_temperature_for_sampling). Locate the scaling call around
forward_with_post_processing_fn and replace it with a conditional that checks
the post_processing_fn type (exclude LossPostProcessor and ScorePostProcessor)
or checks the new cfg flag, ensuring LossPostProcessor and ScorePostProcessor
receive unscaled logits while sampling post processors still get
temperature‑scaled logits.
- Around line 140-176: The cp_mesh parameter in redistribute_logits_for_cp is
unused and triggers lint ARG001; keep the signature for API stability but
explicitly mark it unused by adding a no-op reference such as "_ = cp_mesh" or
"del cp_mesh" at the start of the function (with a short comment like "#
intentionally unused" or "# silence ARG001") so linters stop complaining while
behavior remains unchanged.

In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py`:
- Around line 301-303: The unused variable warnings come from unpacking
check_sequence_dim(data) into sequence_dim, seq_dim_size and from unused metrics
in post-processing forward calls; to silence Ruff, rename unused bindings by
prefixing with an underscore (e.g., use sequence_dim, _seq_dim_size =
check_sequence_dim(data)) and rename unused metrics to _metrics in the
post-processing forward invocations (or assign to _metrics where returned), and
apply the same pattern wherever seq_dim_size or metrics are currently unused in
train(), score(), and the post-processing forward calls in
DTensorPolicyWorkerV2.
- Around line 339-341: The warning emitted in dtensor_policy_worker_v2 (inside
the code that calls warnings.warn) has a spelling mistake ("unnnecessarily") and
is missing a stacklevel; update the warnings.warn call in
dtensor_policy_worker_v2 (the cache-clearing warning) to correct the message to
"unnecessarily" and pass stacklevel=2 so the warning points to the caller (i.e.,
change the warnings.warn invocation used for emptying cache every
{empty_cache_steps} microbatches to include the corrected string and
stacklevel=2).
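
The corrected warning described above might look like this sketch (the helper name and message wording are illustrative):

```python
import warnings


def warn_empty_cache(empty_cache_steps: int) -> None:
    """Emit the cache-clearing warning with the spelling fixed and
    stacklevel=2 so the warning is attributed to the caller, not this helper."""
    warnings.warn(
        f"Emptying cache every {empty_cache_steps} microbatches may "
        "unnecessarily slow down training.",
        stacklevel=2,
    )
```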

In `@tests/unit/models/automodel/test_automodel_train.py`:
- Line 147: Tests assign unused variables (e.g., result, loss, metrics) from
calls like model_forward(...) which Ruff flags; remove the unused assignments by
either calling the function without assignment (model_forward(...)) or renaming
the targets to throwaway names (_ , _loss, _metrics) wherever you see result,
loss, or metrics being assigned (including uses of model_forward and any
train/metric-returning helpers referenced in the diff) so the values are still
produced but the linter no longer reports unused variables.
🧹 Nitpick comments (1)
nemo_rl/models/automodel/train.py (1)

216-218: Prefer isinstance over type(...) == torch.Tensor.

type(current_tensor) == torch.Tensor rejects subclasses (e.g., torch.nn.Parameter). Use isinstance for idiomatic and safer checks.

🛠️ Proposed fix

```diff
-                assert type(current_tensor) == torch.Tensor, (
+                assert isinstance(current_tensor, torch.Tensor), (
                     f"tensor {tensor_name} is not a tensor"
                 )
```

@github-actions
⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit 3618531.

@hemildesai hemildesai added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 29, 2026
@hemildesai hemildesai force-pushed the hemil/automodel-train-refactor branch from 3618531 to 670a2cf on January 29, 2026 06:18
@github-actions
⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit 670a2cf.

@hemildesai hemildesai added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jan 29, 2026
@github-actions
⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit 7cbeac7.

@hemildesai hemildesai removed the CI:L1 Run doctests, unit tests, and functional tests label Jan 31, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai hemildesai force-pushed the hemil/automodel-train-refactor branch from 7cbeac7 to 5359dbf on February 4, 2026 22:03
@github-actions
github-actions bot commented Feb 4, 2026

⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit 5359dbf.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
@github-actions
github-actions bot commented Feb 5, 2026

⚠️ File Consistency Check — same DTensor Policy Worker Synchronization Warning as the first check above, based on commit bca57d5.

@hemildesai hemildesai added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 5, 2026
@terrykong terrykong enabled auto-merge (squash) February 6, 2026 04:18
@terrykong terrykong added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Feb 6, 2026
@terrykong terrykong merged commit 3bdb852 into main Feb 6, 2026
62 of 63 checks passed
@terrykong terrykong deleted the hemil/automodel-train-refactor branch February 6, 2026 12:51
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests
