fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attentiuon) path under activation check…#3611
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughIntroduces thread-local storage to manage Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly related PRs
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/axolotl/monkeypatch/models/gemma4/fused_attn.py`:
- Around line 178-182: patched_call currently only calls _set_shared_kv_states
when shared_kv is non-None, leaving TLS stale when callers omit the kwarg;
change patched_call to always call _set_shared_kv_states(shared_kv) after
popping shared_kv from kwargs (i.e., keep shared_kv =
kwargs.pop("shared_kv_states", None) but remove the conditional and call
_set_shared_kv_states(shared_kv) unconditionally) so TLS is cleared/overwritten
for every invocation and fused_forward's tls_store fallback works reliably.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: d606f7e3-97be-4596-a897-30e530fa3f67
📒 Files selected for processing (1)
src/axolotl/monkeypatch/models/gemma4/fused_attn.py
…r patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR axolotl-ai-cloud#3611
There was a problem hiding this comment.
Ok, I think this makes sense. I'm a bit unclear whether having the store on module scope is the best overall workaround (possible clash with any async runs tho likely rare).
I think we can also make that big chunk of comments more concise. If necessary, just point it to this PR with the explanation here.
Edit: Could we also add unit test (no need e2e) where appropriate? Maybe one to check that the shared kv is popped, the moe issue you met, ?
Good point on the scope, let me think of if there needs to be a gating check to prevent any regressions across other functionality - I'm leaning towards this as other options would likely be too invasive elsewhere and add latency. I'll address the comments and unit tests for sure. |
|
thanks! lmk when this is ready to review again |
Update PR axolotl-ai-cloud#3611 with gate for checkpointed training to avoid regressions across async flows. Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
I just pushed the changes with a gate (for checkpointed training), unit tests and shorter docstrings Please let me know if you need me to update the PR summary itself as some of it is now OBE Edit: I missed a test for the MoE 'NoneType' error I observed and corrected in e3669b2 , commit 1bdc84c adds a 5th test per @NanoCode012 's comments, should be good for review now. Thanks! |
| @@ -16,6 +17,18 @@ | |||
|
|
|||
| logger = logging.getLogger(__name__) | |||
There was a problem hiding this comment.
Not due to your change, but we should use the axolotl logger as shown in other files.
There was a problem hiding this comment.
Do you need me to update this?
…pointing
Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.
HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.
Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~22x slower.
Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.
Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
* threading.local() store with _get/_set_shared_kv_states helpers
* _patch_decoder_layer_call(): monkeypatches
Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
and stash it in TLS before delegating to GradientCheckpointingLayer.
The partial formed downstream no longer references the dict.
* fused_forward reads TLS first, falls back to kwarg for callers that
bypass the patched __call__ (e.g. direct attention invocation).
* wired into patch_gemma4_fused_attn; idempotent via a sentinel.
TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.
Introduced by PR axolotl-ai-cloud#3598 (commit b8358aa).
…r patched __call__
Overwrite the thread-local shared_kv_states store on every invocation
(including with None) instead of only when the kwarg is present.
The previous conditional write left stale dicts in TLS on any path that
reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
kwarg — e.g. generation, eval hooks, or future HF refactors that make
the kwarg optional. fused_forward would then silently consume a prior
step's K/V dict instead of falling back to its own kwarg path.
Unconditional write makes the invariant in the surrounding comment
("TLS is overwritten on each new step's first decoder-layer call, so
the previous step's dict is released promptly") actually hold.
No behavior change for the training happy path, which always passes
the kwarg. Addresses CodeRabbit review on PR axolotl-ai-cloud#3611
… threads see shared_kv_states during backward recompute Previous commits fixed memory leak on 31B but caused type error with MOE Gemma4 variants - this fixes that: PR 3611's TLS variant only works when recompute runs on the same thread that set TLS during forward. PyTorch's C++ autograd engine (_engine_run_backward) spawns per-device worker threads to dispatch backward, and HF-Trainer gradient_checkpointing (stock torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires unpack_hook -> recompute_fn on those worker threads. TLS set on the main thread during forward is invisible there, so _get_shared_kv_states() returns None and the consumer-layer lookup crashes with "'NoneType' object is not subscriptable" at fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]). A plain module-level dict is visible to all threads in the process. Lifecycle is identical: the slot is overwritten each forward, releasing the previous step's dict and allowing its K/V tensors to be GC'd, so the original VRAM-leak fix still holds under FSDP2 AC too.
Update PR axolotl-ai-cloud#3611 with gate for checkpointed training to avoid regressions across async flows. Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
Additional regression test for MoE gemma4 variants - asserts the module-level store is readable from threads other than the one that set it in response to previously observed 'NoneType' error
1bdc84c to
e91d711
Compare
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
[gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing
Route
shared_kv_statesthrough a thread-local side channel instead of the decoder-layer kwargs so the checkpoint partial never references the dict.Description
HF's
Gemma4TextModel.forwardpassesshared_kv_states(a mutable dict used for cross-layer K/V sharing) as a kwarg to everydecoder_layer(...)call.GradientCheckpointingLayer.__call__then formspartial(super().__call__, **kwargs), and whichever checkpoint implementation runs — axolotl'sCPU_Offloaded_Gradient_Checkpointer(viactx.forward_function = forward_functionatsrc/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py:51) or torch's stocktorch.utils.checkpoint.checkpoint— captures that partial. The partial holds a reference to the dict, which holds the K/V tensors produced bystore_full_length_kvlayers. Those tensors stay pinned for the full duration of backward, and delayed ref-cycle cleanup in torch's caching allocator under FSDP2 + activation checkpointing bleeds the residual across training steps.Violated invariant: anything crossing an activation-checkpoint boundary must be a tensor (refcounted by autograd) or plain Python data — never a mutable container holding tensor references.
All changes live in
src/axolotl/monkeypatch/models/gemma4/fused_attn.py(+58 / −2):threading.local()store with_get_shared_kv_states()/_set_shared_kv_states()helpers._patch_decoder_layer_call()— monkeypatchesGemma4TextDecoderLayer.__call__to popshared_kv_statesfrom kwargs and stash it in TLS before delegating toGradientCheckpointingLayer.__call__. The partial formed downstream no longer references the dict.fused_forwardreads TLS first and falls back to the kwarg for callers that bypass the patched__call__(e.g. direct attention invocation during eval hooks).patch_gemma4_fused_attn(); idempotent via an_axolotl_shared_kv_tls_patchedsentinel.TLS is overwritten on each new step's first decoder-layer call, so the previous step's dict is released promptly. No changes to the hybrid dispatch, FSDP wrap policy, config schema, or any public behavior. The patch is unconditional — it works correctly for hybrid, flex, and eager paths, because the TLS rerouting is semantically equivalent to the kwarg path (it is purely a capture-avoidance change).
Motivation and Context
Fixes #3610.
Introduced by PR #3598 (commit
b8358aa5, "[gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels"), which added the hybrid flag, the unconditionalpatch_gemma4_fused_attn()monkeypatch, and the fused RMSNorm+RoPE kernels together.Observed symptom (pre-patch):
gemma4_hybrid_attn_impl: trueflex_attention: true(all 60 layers)The climb is independent of actual sequence length or image size — confirmed by the reporter, which rules out shape variability,
torch.compilerecompiles, and head_dim=512 SDPA peak allocations (which would scale with T²). The flex path is flat but ~2× slower and not a practical workaround for a multi-day training run.Hypotheses considered and rejected during diagnosis:
cfg.flash_attention = Trueactivates downstream per-step allocations — audited allcfg.flash_attentionconsumers; none allocate per-step state under the reported config (sample_packing: false, SM120 skips the FA4 auto-patch).The violated-invariant analysis also cleanly explains why the all-flex path doesn't leak: flex compiles all 60 layers identically into a single graph context where saved tensors are managed by the compiled graph rather than by Python refcounting, so the dict's refs don't pin anything.
How has this been tested?
Static / structural validation (performed):
ast.parseon the modified module — clean.Patch installation smoke test — both
Gemma4TextDecoderLayer.__call__andGemma4TextAttention.forwardare replaced afterpatch_gemma4_fused_attn().Sentinel (
_axolotl_shared_kv_tls_patched) prevents double-wrapping on repeated invocations.TLS
_get_shared_kv_states()/_set_shared_kv_states()round-trip correctly.Byte-identical to the editable install being used on the reporter's hardware.
Ran
pre-commit run --files src/axolotl/monkeypatch/models/gemma4/fused_attn.pyon the modified file:End-to-end VRAM-flatness validation (confirmed on reporter's hardware):
Reran the failing config — Gemma-4 31B multimodal, LoRA r=256 on language layers, vision tower +
embed_visiontrained full,sample_packing: false,pad_to_sequence_len: false,sequence_len: 5248, FSDP2 withreshard_after_forward: true,activation_checkpointing: trueunderfsdp_config(HF Trainer-levelgradient_checkpointing: false), on 2× RTX PRO 6000 Blackwell 96 GB. Ran 200+ steps past the previous OOM point (step 73) including a full eval + checkpoint-save cycle at step 200. VRAM is flat; throughput matches pre-patch hybrid.Post-patch VRAM oscillates in a ~48–58 GiB band with one excursion to 66.2 GiB at step 189; no upward trend. Step time is ~13–14 s/it (steady-state), identical to pre-patch hybrid. Eval loss dropped 5.43 → 0.69 over the first 200 steps; the first
save_strategy: stepscheckpoint wrote cleanly underSHARDED_STATE_DICTFSDP2.Testing environment of record:
transformers5.5.4torch2.11.0+cu130What reviewers may want to sanity-check:
Gemma4TextDecoderLayer.__call__intercept sits aboveGradientCheckpointingLayer.__call__, so the partial formed downstream does not re-capture the popped kwarg.fused_forwardpreserves behavior for any caller that invokes attention without going through the patched decoder-layer__call__(e.g. custom eval hooks that reach intoself_attndirectly).AI Usage Disclaimer
Claude Code (Opus 4.7) was used in the following scopes:
CPU_Offloaded_Gradient_Checkpointer,GradientCheckpointingLayer, and theshared_kv_statesdataflow through the decoder loop; validating four hypotheses against the code; producing the "violated invariant" analysis._patch_decoder_layer_call()intercept, and thefused_forwardfallback; wiring intopatch_gemma4_fused_attn(); adding the idempotency sentinel.Supplied by the human author: the training config, observed VRAM numbers (0.47 GiB/step, step-73 OOM, 42 GiB baseline, 94 GiB peak, 13–14 vs 30 s/step steady-state throughput), hardware details, confirmation that the climb is independent of sequence length / image size, and the end-to-end 200+ step VRAM-flatness rerun on the reporter's 2× RTX PRO 6000 Blackwell hardware including a full eval + checkpoint-save cycle.
Screenshots (if appropriate)
N/A — the change is a structural fix with no user-visible surface.
Types of changes
Social Handles (Optional)
(fill in if desired)
Summary by CodeRabbit