test for moe activation vram leak #3649
📝 Walkthrough
This pull request adds a regression test for activation offloading state leakage and updates a training example configuration. The test validates that the offloading context state is properly cleared between training steps by instrumenting
Changes
Activation Offloading State Leakage Fix
Estimated Code Review Effort
🎯 2 (Simple) | ⏱️ ~12 minutes
Suggested Reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/e2e/test_activation_offloading.py (2)
95-99: ⚡ Quick win
Recorded `bwd_tensor_stash`/`bwd_ev_stash` are never asserted.
You instrument both backward stashes on lines 95–99 but only assert `tracker`, `storage_dedup`, and `fwd_stash` (lines 119–131). Either drop the unused recordings, or add the matching assertions so a backward-side leak doesn't slip past this regression guard.
Suggested addition

         for rec in recorded_states:
             assert rec["tracker"] == 0, (
                 f"OffloadActivations.tracker not empty at start of step "
                 f"{rec['step']}: {rec} — cross-step leak (`#3638`) regressed"
             )
             assert rec["storage_dedup"] == 0, (
                 f"OffloadActivations.storage_to_tensor_id not empty at start "
                 f"of step {rec['step']}: {rec}"
             )
             assert rec["fwd_stash"] == 0, (
                 f"OffloadActivations.fwd_stash not empty at start of step "
                 f"{rec['step']}: {rec}"
             )
    +        assert rec["bwd_tensor_stash"] == 0, (
    +            f"OffloadActivations.bwd_tensor_stash not empty at start of "
    +            f"step {rec['step']}: {rec}"
    +        )
    +        assert rec["bwd_ev_stash"] == 0, (
    +            f"OffloadActivations.bwd_ev_stash not empty at start of step "
    +            f"{rec['step']}: {rec}"
    +        )

Also applies to: 119-131
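One way to keep the recorded keys and the asserted keys from drifting apart is to drive both from a single tuple. The sketch below is illustrative only: `STATE_KEYS` and `assert_no_cross_step_state` are hypothetical names, and it assumes each entry of `recorded_states` is a dict of container lengths plus a `step` field, as in the test under review.

```python
# Hypothetical helper: these names are illustrative and not part of axolotl or trl.
STATE_KEYS = (
    "tracker",
    "storage_dedup",
    "fwd_stash",
    "bwd_tensor_stash",
    "bwd_ev_stash",
)


def assert_no_cross_step_state(recorded_states, keys=STATE_KEYS):
    """Fail if any recorded container was non-empty at the start of a step."""
    for rec in recorded_states:
        # Collect every non-zero counter so the failure message reports
        # all leaking containers at once, not just the first one.
        leaked = {key: rec[key] for key in keys if rec.get(key, 0) != 0}
        assert not leaked, (
            f"offload state not empty at start of step {rec.get('step')}: {leaked}"
        )
```

With this shape, instrumenting a new stash only requires extending `STATE_KEYS`; the backward stashes can no longer be recorded without also being checked.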
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/e2e/test_activation_offloading.py` around lines 95 - 99, The test records bwd_tensor_stash and bwd_ev_stash (via getattr on ctx) but never asserts them, so add assertions mirroring the existing ones for tracker/storage_dedup/fwd_stash to prevent backward-side leaks: after the current assertions for tracker, storage_dedup, and fwd_stash in the test (the same scope that references ctx), assert that the lengths of bwd_tensor_stash and bwd_ev_stash equal the expected values (likely 0) using the same assertion style; alternatively remove the instrumentation lines if you prefer not to check backward stashes. Ensure you reference the exact attributes bwd_tensor_stash and bwd_ev_stash on ctx to keep consistency with the recorded values.
79-79: 💤 Low value
Hoist the local import to module top.
`from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod` is only used to set up the monkeypatch; placing it at the top alongside the other axolotl imports keeps import-time errors visible early and follows the convention used elsewhere in this file.
Proposed change

     from axolotl.common.datasets import load_datasets
    +from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod
     from axolotl.train import train
     from axolotl.utils.config import normalize_config, validate_config
     from axolotl.utils.dict import DictDefault
    @@
    -    from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod
    -
         recorded_states: list[dict] = []

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/e2e/test_activation_offloading.py` at line 79, The local import "from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod" should be hoisted to the module top with the other axolotl imports so import-time errors surface early and follow the file convention; move the import out of the test body into the file-level imports, keep the alias ac_mod, and ensure the test's monkeypatch setup still references ac_mod (used in the activation checkpointing monkeypatching).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/e2e/test_activation_offloading.py`:
- Around line 79-101: Run ruff format on tests/e2e/test_activation_offloading.py
and commit the reformat so the pre-commit hook passes; specifically reformat the
block that assigns original_training_step =
ac_mod.ActivationOffloadingMixin.training_step and the recording_training_step
function (the lines referencing recorded_states, recording_training_step,
activation_offload_context, ac_mod.OffloadActivations, and the
bwd_tensor_stash/bwd_ev_stash collects) to match ruff's wrapping/style (you can
run `ruff format tests/e2e/test_activation_offloading.py` or `pre-commit run
--all-files` to apply the exact changes).
- Around line 75-107: The monkeypatched recording_training_step incorrectly
reads self._offload_step_counter (which doesn't exist) causing AttributeError;
change it to read the counter from the offload context
(ctx._offload_step_counter) and/or safely use getattr(ctx,
"_offload_step_counter", None) when building the recorded_states entry; update
the recording_training_step used to monkeypatch
ac_mod.ActivationOffloadingMixin.training_step so it checks isinstance(ctx,
ac_mod.OffloadActivations) then uses ctx._offload_step_counter (or the getattr
fallback) when appending to recorded_states to ensure the test doesn't crash.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: fb962756-19ae-45de-b8bb-0c914c7dddd8
📒 Files selected for processing (2)
examples/mistral4/qlora-text.yml
tests/e2e/test_activation_offloading.py
    # Record OffloadActivations state at the start of each training_step.
    # Regression guard for #3638: tracker / dedup map / forward stash must
    # be empty at the start of every step. With the leak (pre-fix), these
    # grow monotonically and pin GPU memory until OOM.
    from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod

    recorded_states: list[dict] = []
    original_training_step = (
        ac_mod.ActivationOffloadingMixin.training_step
    )

    def recording_training_step(self, *args, **kwargs):
        ctx = self.activation_offload_context
        if isinstance(ctx, ac_mod.OffloadActivations):
            recorded_states.append(
                {
                    "step": self._offload_step_counter,
                    "tracker": len(ctx.tracker),
                    "storage_dedup": len(ctx.storage_to_tensor_id),
                    "fwd_stash": len(getattr(ctx, "fwd_stash", {})),
                    "bwd_tensor_stash": len(
                        getattr(ctx, "bwd_tensor_stash", {})
                    ),
                    "bwd_ev_stash": len(getattr(ctx, "bwd_ev_stash", {})),
                }
            )
        return original_training_step(self, *args, **kwargs)

    monkeypatch.setattr(
        ac_mod.ActivationOffloadingMixin,
        "training_step",
        recording_training_step,
    )
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify the attributes/symbols this test depends on.
# 1) Locate the activation_checkpointing module and inspect its public symbols.
fd -t f activation_checkpointing.py
echo "--- OffloadActivations definitions/exports ---"
rg -nP -C2 '\b(class\s+OffloadActivations|OffloadActivations\s*=)\b'
echo "--- ActivationOffloadingMixin definition ---"
ast-grep --pattern 'class ActivationOffloadingMixin($$$):
$$$'
echo "--- _offload_step_counter usages/initializations ---"
rg -nP -C3 '_offload_step_counter'
echo "--- activation_offload_context assignments ---"
rg -nP -C2 '\bactivation_offload_context\b'
echo "--- training_step inside the mixin ---"
ast-grep --pattern 'class ActivationOffloadingMixin($$$):
$$$
def training_step($$$):
$$$
$$$'
Repository: axolotl-ai-cloud/axolotl
Length of output: 5462
🏁 Script executed:
#!/bin/bash
# Check imports in activation_checkpointing.py
echo "=== Imports in activation_checkpointing.py ==="
head -30 src/axolotl/core/trainers/mixins/activation_checkpointing.py
# Check full test file to see all imports
echo ""
echo "=== Test file imports ==="
head -80 tests/e2e/test_activation_offloading.py
# Search for OffloadActivations in entire codebase
echo ""
echo "=== OffloadActivations anywhere in repo ==="
rg 'OffloadActivations' --type py
# Search for _offload_step_counter anywhere
echo ""
echo "=== _offload_step_counter anywhere in repo ==="
rg '_offload_step_counter' --type py
# Check the functions used to get context managers
echo ""
echo "=== get_act_offloading_ctx_manager definition ==="
rg -B2 -A10 'def get_act_offloading_ctx_manager|def get_lora_act_offloading_ctx_manager'
Repository: axolotl-ai-cloud/axolotl
Length of output: 5993
🏁 Script executed:
#!/bin/bash
# Show the full activation_checkpointing.py file since it's implementing custom context managers
echo "=== Full activation_checkpointing.py ==="
cat -n src/axolotl/core/trainers/mixins/activation_checkpointing.py
# Check if there's any setup of _offload_step_counter in the config or elsewhere
echo ""
echo "=== Search for step_counter or offload initialization ==="
rg 'step_counter|offload.*init|OffloadActivations\(' -B2 -A2
# Check git history or PR comments about this test
echo ""
echo "=== Check test file for assertions about recorded_states ==="
tail -50 tests/e2e/test_activation_offloading.py
Repository: axolotl-ai-cloud/axolotl
Length of output: 14633
`self._offload_step_counter` does not exist on the trainer instance and will cause an `AttributeError` at runtime.
Line 91 dereferences `self._offload_step_counter` unconditionally. The attribute is never initialized on `ActivationOffloadingMixin` in this codebase, and a search of the entire repository finds no initialization. If the counter exists at all, it is presumably provided by the upstream `trl` library as part of `OffloadActivations`, in which case the test reads it from the wrong object: it should access `ctx._offload_step_counter` (the context manager) rather than `self._offload_step_counter` (the trainer). As written, `recording_training_step` will raise `AttributeError` on the first training step and the assertions will never run. If that exception is swallowed anywhere along the way, the test would appear to pass, leaving a critical regression guard that does not actually guard anything.
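If it is unclear which object owns the counter, the recording hook can look it up defensively instead of dereferencing one attribute unconditionally. A minimal sketch, assuming the counter (if present at all) lives on either the offload context or the trainer; `read_step_counter` is a hypothetical helper, and the `SimpleNamespace` objects merely stand in for the real trainer and context.

```python
from types import SimpleNamespace


def read_step_counter(trainer, ctx, name="_offload_step_counter"):
    """Return the first counter found on ctx, then on trainer, else None."""
    for owner in (ctx, trainer):
        value = getattr(owner, name, None)
        if value is not None:
            return value
    return None


# Stand-ins for illustration only; the real objects come from axolotl / trl.
trainer = SimpleNamespace()
ctx_with_counter = SimpleNamespace(_offload_step_counter=3)
ctx_without_counter = SimpleNamespace()

step_from_ctx = read_step_counter(trainer, ctx_with_counter)
step_missing = read_step_counter(trainer, ctx_without_counter)
```

Returning `None` rather than raising keeps the recording hook from aborting the training step, so the assertions on the container sizes still run even when no counter is available.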
🧰 Tools
🪛 GitHub Actions: lint / 0_pre-commit.txt
[error] 79-93: pre-commit failed at hook 'ruff-format' (ruff format). 1 file reformatted; files modified by this hook.
🪛 GitHub Actions: lint / pre-commit
[error] 79-92: pre-commit hook 'ruff-format' failed: file was reformatted (1 file reformatted, 668 unchanged). Run 'pre-commit run --all-files' or apply the formatting changes. Diff indicates line wrapping changes for 'original_training_step' and 'bwd_tensor_stash'.
Apply the ruff-format reformat to unblock CI.
The `ruff-format` pre-commit hook failed on this hunk (lines ~79–93) and reformatted the file. Run `pre-commit run --all-files` (or `ruff format tests/e2e/test_activation_offloading.py`) and commit the result. The reported diff specifically rewraps `original_training_step = ac_mod.ActivationOffloadingMixin.training_step` and the `bwd_tensor_stash` line; both are currently wrapped in a way ruff disagrees with.
Likely reformatted shape (verify with `ruff format`)

    -    recorded_states: list[dict] = []
    -    original_training_step = (
    -        ac_mod.ActivationOffloadingMixin.training_step
    -    )
    +    recorded_states: list[dict] = []
    +    original_training_step = ac_mod.ActivationOffloadingMixin.training_step
    @@
    -                    "bwd_tensor_stash": len(
    -                        getattr(ctx, "bwd_tensor_stash", {})
    -                    ),
    +                    "bwd_tensor_stash": len(getattr(ctx, "bwd_tensor_stash", {})),

📝 Committable suggestion
    from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod

    recorded_states: list[dict] = []
    original_training_step = ac_mod.ActivationOffloadingMixin.training_step

    def recording_training_step(self, *args, **kwargs):
        ctx = self.activation_offload_context
        if isinstance(ctx, ac_mod.OffloadActivations):
            recorded_states.append(
                {
                    "step": self._offload_step_counter,
                    "tracker": len(ctx.tracker),
                    "storage_dedup": len(ctx.storage_to_tensor_id),
                    "fwd_stash": len(getattr(ctx, "fwd_stash", {})),
                    "bwd_tensor_stash": len(getattr(ctx, "bwd_tensor_stash", {})),
                    "bwd_ev_stash": len(getattr(ctx, "bwd_ev_stash", {})),
                }
            )
        return original_training_step(self, *args, **kwargs)
    evals_per_epoch: 1
    saves_per_epoch: 1
    weight_decay: 0.0
    ddp_find_unused_parameters: true
Is this all that's needed?
was surprised to see it fail without ddp_find_unused_parameters: true for multi gpu tho
not needed here, removing
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 472, in __call__
[rank0]: return super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1024, in compile_wrapper
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/_dynamo/external_utils.py", line 69, in inner
[rank0]: return fn(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1779, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1695, in forward
[rank0]: inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/root/miniconda3/envs/py3.11/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1584, in _pre_forward
[rank0]: if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
[rank0]: making sure all `forward` function outputs participate in calculating loss.
[rank0]: If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
[rank0]: Parameter indices which did not receive grad for rank 0: 14 15 32 33 50 51 68 69 86 87 102 103 120 121 138 139 156 157 174 175 192 193 208 209 226 227 244 245 262 263 280 281 298 299 314 315 332 333 350 351 368 369 386 387 404 405 420 421 438 439 456 457 474 475 492 493 510 511 526 527
[rank0]: In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
0%| | 0/321 [00:28<?, ?it/s]
W0508 15:03:29.992000 2524 site-packages/torch/distributed/elastic/multiprocessing/api.py:1012] Sending process 2788 closing signal SIGTERM
E0508 15:03:30.609000 2524 site-packages/torch/distributed/elastic/multiprocessing/api.py:986] failed (exitcode: 1) local_rank: 0 (pid: 2787) of binary: /root/miniconda3/envs/py3.11/bin/python3.11
Traceback (most recent call last):
File "/root/miniconda3/envs/py3.11/bin/accelerate", line 6, in <module>
sys.exit(main())
can you address the coderabbit issues pls?
fixes #3638
test here should go after huggingface/trl#5738
context
Fixes #3638.
`OffloadActivations` only resets its cross-step state when the last saved tensor of a step is unpacked during backward. With MoE / `torch.compile`, some saved tensors never unpack (their backward node never executes), so `tracker` retains GPU activation references that accumulate every step, producing a linear VRAM leak. On Gemma-4 26B-A4B + ScatterMoE + sample_packing this OOMs by step 2.
Summary by CodeRabbit
Release Notes
New Features
Tests
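The leak mechanism described in the PR context can be illustrated with a toy model. This is only a sketch under stated assumptions, not the trl implementation: `ToyOffloadContext` and `tracker_sizes` are hypothetical names, activations are plain strings, and the only behaviour modelled is "cross-step state is reset only once the last saved tensor of a step has been unpacked".

```python
class ToyOffloadContext:
    """Toy stand-in for an offload context; not the real OffloadActivations."""

    def __init__(self):
        self.tracker = {}           # tensor id -> retained activation
        self.next_id = 0
        self.last_id = None         # id of the most recently packed tensor
        self.pending_reset = False  # set once the last tensor is unpacked

    def pack(self, activation):
        # The reset only happens if the previous backward reached the
        # last saved tensor; otherwise old entries stay in the tracker.
        if self.pending_reset:
            self.tracker.clear()
            self.next_id = 0
            self.pending_reset = False
        tensor_id = self.next_id
        self.next_id += 1
        self.tracker[tensor_id] = activation
        self.last_id = tensor_id
        return tensor_id

    def unpack(self, tensor_id):
        activation = self.tracker.pop(tensor_id)
        if tensor_id == self.last_id:
            self.pending_reset = True
        return activation


def tracker_sizes(num_steps, skip_last_unpack, tensors_per_step=3):
    """Return len(tracker) after each simulated forward/backward step."""
    ctx = ToyOffloadContext()
    sizes = []
    for step in range(num_steps):
        ids = [ctx.pack(f"act-{step}-{i}") for i in range(tensors_per_step)]
        # Simulate a backward node that never executes by never unpacking
        # the last saved tensor of the step.
        to_unpack = ids[:-1] if skip_last_unpack else ids
        for tensor_id in reversed(to_unpack):
            ctx.unpack(tensor_id)
        sizes.append(len(ctx.tracker))
    return sizes
```

In this toy, `tracker_sizes(3, skip_last_unpack=True)` grows by one entry per step while the fully unpacked case stays at zero, which is the same "empty at the start of every step" property the regression test asserts on the real context.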