
Conversation

Contributor

@ybgao-nvidia ybgao-nvidia commented Aug 14, 2025

What does this PR do ?

Memory optimizations

This PR applies memory optimizations that allow single-node (8xH100) training of the Nemotron 12B model with sequence length 12288.

We need the following optimizations to make 12K context work:

  1. The RMSNorm layers were not being activation-checkpointed, yet they contribute roughly one third of the forward/backward memory footprint (excluding weights, optimizer state, and gradients). This PR checkpoints them (see the sketch after this list).
  2. There is a lot of fragmentation because FlashAttention-2 is not supported for the Mistral architecture that Nemotron 12B is based on; attention allocates ~6 GB of memory at a time, leaving an 8-9 GB gap between active and reserved memory. Capping the allocator's max_split_size reduces this fragmentation.
  3. Adding an occasional torch garbage-collection pass got us the rest of the way to 12K.
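
For illustration only, here is a minimal sketch of what activation-checkpointing the norm layers can look like with torch.utils.checkpoint. The class and helper names are hypothetical and this is not the code added by this PR:

```python
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedNorm(torch.nn.Module):
    """Wraps a norm layer so its activations are recomputed in backward instead of stored."""

    def __init__(self, norm: torch.nn.Module):
        super().__init__()
        self.norm = norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Non-reentrant checkpointing recomputes self.norm(x) during backward,
        # trading a small amount of extra compute for activation memory.
        return checkpoint(self.norm, x, use_reentrant=False)


def checkpoint_norm_layers(model: torch.nn.Module) -> None:
    """Recursively wrap every RMSNorm-like module in the model (hypothetical helper)."""
    for name, child in model.named_children():
        if child.__class__.__name__.endswith("RMSNorm"):
            setattr(model, name, CheckpointedNorm(child))
        else:
            checkpoint_norm_layers(child)
```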

The additional checkpointed layers provide a significant decrease in peak memory usage with minimal performance impact, although enabling a smaller max_split_size in the allocator does increase the step time slightly. The collated performance results are below:

| Seqlen | Config           | Peak Allocated (GB) | Peak Reserved (GB) | Step Time (s) |
|--------|------------------|---------------------|--------------------|---------------|
| 5500   | baseline         | 61.66               | 67.69              | 13.88         |
| 5500   | +checkpoint norm | 42.71               | 48.75              | 10.04         |
| 5500   | +allocator frag  | 43.05               | 51.20              | 14.04         |
| 8192   | baseline         | OOM                 | OOM                |               |
| 8192   | +checkpoint norm | 52.17               | 60.38              | 12.94         |
| 8192   | +allocator frag  | 52.42               | 61.08              | 16.36         |
| 12288  | baseline         | OOM                 | OOM                |               |
| 12288  | +checkpoint norm | OOM (66.56)         | OOM (73.24)        | 18.19         |
| 12288  | +allocator frag  | 66.81               | 73.26              | 23.43         |

Removal of configure_expandable_segments

Furthermore, the current implementation of configure_expandable_segments does not actually perform its intended function.

  • It calls torch.cuda.get_device_properties(0).major, which initializes torch's CUDA state, including the memory allocator. The subsequent assignment to the environment variable therefore has no effect on the allocator. Instead, the torch.cuda.memory._set_allocator_settings function should be used (see the sketch below).
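
A minimal sketch of the ordering problem (not the actual NeMo RL code), assuming the helper sets the allocator through the PYTORCH_CUDA_ALLOC_CONF environment variable:

```python
import os
import torch

# Querying device properties lazily initializes CUDA, and with it the caching
# allocator, which reads PYTORCH_CUDA_ALLOC_CONF exactly once at that point.
major = torch.cuda.get_device_properties(0).major

# Too late: the allocator is already configured, so this assignment is ignored.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# After initialization, settings have to go through the runtime API instead
# (a private function, but the one referenced above).
torch.cuda.memory._set_allocator_settings("expandable_segments:True")
```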

However, setting expandable segments has minimal effect on peak memory usage while causing a large performance overhead (from 20s to 80s per training iteration).

We have therefore deleted the function, along with its invocations and tests, to keep the runtime behaviour consistent. If expandable segments are ever needed, they can be enabled via env_vars in the recipe configuration instead.

Minor fixes for config schema

Some tweaks were made so that config validation passes.

  • Made the tensorboard field of logger optional (sketched below)
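
As a rough illustration of the kind of schema change this refers to, assuming a TypedDict-based config (the field names below are hypothetical):

```python
from typing import NotRequired, TypedDict  # NotRequired: Python 3.11+, or typing_extensions


class TensorboardConfig(TypedDict):
    log_dir: str  # hypothetical field


class LoggerConfig(TypedDict):
    wandb_enabled: bool  # hypothetical field
    # Marking the field NotRequired lets validation pass when tensorboard is omitted.
    tensorboard: NotRequired[TensorboardConfig]
```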

Issues

This PR resolves #848.

Usage

It is recommended to run DPO training with PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:64 to reduce allocator fragmentation.
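
As a rough sketch (not taken from this repository), the same allocator setting and the occasional garbage-collection pass mentioned above could also be applied programmatically; the interval below is a made-up example:

```python
import gc
import os

# Must be in the environment before torch initializes CUDA for it to take effect.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:64")

import torch


def maybe_collect(step: int, every: int = 50) -> None:
    """Run an occasional GC pass to release unused cached blocks back to the driver."""
    if step % every == 0:
        gc.collect()
        torch.cuda.empty_cache()
```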

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@ybgao-nvidia ybgao-nvidia marked this pull request as ready for review August 14, 2025 19:50
Contributor

@terrykong terrykong left a comment

thanks for improving the performance @ybgao-nvidia!

parthchadha previously approved these changes Aug 18, 2025
@wangshangsam
Contributor

> However, setting expandable segments results in minimal affect to peak memory usage while causing a large performance overhead (from 20s to 80s per training iteration). We have disabled it by default.

Wait ... where is it disabled by default?

wangshangsam previously approved these changes Aug 20, 2025

Contributor

@wangshangsam wangshangsam left a comment

Some small nits, but otherwise LGTM

@ybgao-nvidia ybgao-nvidia removed the CI:L1 Run doctests, unit tests, and functional tests label Aug 25, 2025
@github-actions

⚠️ File Synchronization Check

Check based on commit: 189868b (PR #926 from ybgao/aug13-dpo-12k-memory)

⚠️ Parallel Plans Synchronization Warning

The file nemo_rl/models/dtensor/parallelize.py was modified in this PR, but 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py was not updated.

Why this matters:
These files contain similar parallel plan implementations that should be kept synchronized to ensure consistency across the codebase.

Action required:

  • Please review if the changes in nemo_rl/models/dtensor/parallelize.py should also be applied to 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py
  • Update 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py if necessary to maintain synchronization
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/dtensor/parallelize.py
  • Not modified: 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@ybgao-nvidia ybgao-nvidia added the CI:L1 Run doctests, unit tests, and functional tests label Aug 25, 2025

@github-actions

⚠️ File Consistency Check

Check based on commit: 3271a08 (PR #926 from ybgao/aug13-dpo-12k-memory)

⚠️ Parallel Plans Synchronization Warning

The file nemo_rl/models/dtensor/parallelize.py was modified in this PR, but neither 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py nor 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py was updated.

Why this matters:
These files contain similar parallel plan implementations that should be kept synchronized to ensure consistency across the codebase.

Action required:

  • Please review if the changes in nemo_rl/models/dtensor/parallelize.py should also be applied to 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py or 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py
  • Update the appropriate related file(s) if necessary to maintain functional consistency
  • Request access to the NVIDIA-NeMo/Automodel repository, create a PR against the nemo-rl-submodule branch, and update the Automodel submodule in the nemo-rl index
  • Add @ffrujeri as a reviewer of this PR if you have any questions about the consistency requirements
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/dtensor/parallelize.py
  • Not modified: 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py
  • Not modified: 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.


@wangshangsam
Contributor

> ⚠️ File Consistency Check: nemo_rl/models/dtensor/parallelize.py was modified in this PR, but neither 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/optimized_tp_plans.py nor 3rdparty/Automodel-workspace/Automodel/nemo_automodel/components/distributed/parallelizer.py was updated.

Corresponding fix in Automodel: NVIDIA-NeMo/Automodel#391

@terrykong terrykong added this pull request to the merge queue Aug 26, 2025
Merged via the queue into main with commit 989f177 Aug 26, 2025
21 checks passed
@terrykong terrykong deleted the ybgao/aug13-dpo-12k-memory branch August 26, 2025 22:34
soodoshll pushed a commit to soodoshll/RL that referenced this pull request Aug 28, 2025
…IDIA-NeMo#926)

Signed-off-by: Yubo Gao <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Co-authored-by: Terry Kong <[email protected]>
Signed-off-by: Qidong Su <[email protected]>
skirdey-inflection pushed a commit to skirdey-inflection/RL that referenced this pull request Aug 30, 2025
…IDIA-NeMo#926)

Signed-off-by: Yubo Gao <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Co-authored-by: Terry Kong <[email protected]>
Signed-off-by: Stanislav Kirdey <[email protected]>
soodoshll pushed a commit to soodoshll/RL that referenced this pull request Sep 4, 2025
…IDIA-NeMo#926)

Signed-off-by: Yubo Gao <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Co-authored-by: Terry Kong <[email protected]>
Signed-off-by: Qidong Su <[email protected]>
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025

Labels

CI:L1 Run doctests, unit tests, and functional tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Nemotron 12B][DPO] GPU memory footprint higher than expectation

5 participants