
fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attention) path under activation check…#3611

Merged
winglian merged 6 commits into axolotl-ai-cloud:main from thad0ctor:fix/gemma4-hybrid-vram-leak
Apr 21, 2026

Conversation

@thad0ctor

@thad0ctor thad0ctor commented Apr 19, 2026

Contributor

[gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation checkpointing

Route shared_kv_states through a thread-local side channel instead of the decoder-layer kwargs so the checkpoint partial never references the dict.

Description

HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used for cross-layer K/V sharing) as a kwarg to every decoder_layer(...) call. GradientCheckpointingLayer.__call__ then forms partial(super().__call__, **kwargs), and whichever checkpoint implementation runs — axolotl's CPU_Offloaded_Gradient_Checkpointer (via ctx.forward_function = forward_function at src/axolotl/monkeypatch/gradient_checkpointing/offload_cpu.py:51) or torch's stock torch.utils.checkpoint.checkpoint — captures that partial. The partial holds a reference to the dict, which holds the K/V tensors produced by store_full_length_kv layers. Those tensors stay pinned for the full duration of backward, and delayed ref-cycle cleanup in torch's caching allocator under FSDP2 + activation checkpointing bleeds the residual across training steps.

Violated invariant: anything crossing an activation-checkpoint boundary must be a tensor (refcounted by autograd) or plain Python data — never a mutable container holding tensor references.
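
The capture mechanism can be reproduced without torch. The following is a minimal sketch, assuming a hypothetical `FakeTensor` class standing in for a K/V tensor and a hypothetical `layer_forward` function standing in for the decoder-layer call. Neither name comes from axolotl or transformers. It only shows that a `functools.partial` keeps a kwarg dict, and everything inside it, reachable for as long as the partial itself is alive.

```python
import functools
import gc
import weakref


class FakeTensor:
    """Hypothetical stand-in for a K/V tensor; used only to observe object lifetime."""


def layer_forward(hidden, shared_kv_states=None):
    # Hypothetical decoder-layer body; the kwarg is deliberately unused.
    return hidden


def build_partial():
    shared_kv_states = {0: FakeTensor()}
    probe = weakref.ref(shared_kv_states[0])
    # Same shape as partial(super().__call__, **kwargs): the dict is bound into the partial.
    bound = functools.partial(layer_forward, shared_kv_states=shared_kv_states)
    return bound, probe


bound, probe = build_partial()
gc.collect()
alive_while_partial_held = probe() is not None  # the partial still references the dict

del bound
gc.collect()
alive_after_partial_dropped = probe() is not None  # nothing references the dict any more
```

In the real training loop the partial is held by the checkpoint machinery until backward finishes, which is why the K/V tensors in the dict outlive the forward pass.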

All changes live in src/axolotl/monkeypatch/models/gemma4/fused_attn.py (+58 / −2):

  • threading.local() store with _get_shared_kv_states() / _set_shared_kv_states() helpers.
  • _patch_decoder_layer_call() — monkeypatches Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs and stash it in TLS before delegating to GradientCheckpointingLayer.__call__. The partial formed downstream no longer references the dict.
  • fused_forward reads TLS first and falls back to the kwarg for callers that bypass the patched __call__ (e.g. direct attention invocation during eval hooks).
  • Wired into patch_gemma4_fused_attn(); idempotent via an _axolotl_shared_kv_tls_patched sentinel.

TLS is overwritten on each new step's first decoder-layer call, so the previous step's dict is released promptly. No changes to the hybrid dispatch, FSDP wrap policy, config schema, or any public behavior. The patch is unconditional — it works correctly for hybrid, flex, and eager paths, because the TLS rerouting is semantically equivalent to the kwarg path (it is purely a capture-avoidance change).
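
The pattern described above can be sketched in isolation. This is an illustrative sketch only, not the code in `fused_attn.py`: `BaseLayer` and `DecoderLayer` are hypothetical stand-ins for `GradientCheckpointingLayer` and `Gemma4TextDecoderLayer`, and `_STORE`, `patch_decoder_layer_call`, and `_shared_kv_patched` are assumed names chosen for the example.

```python
import functools
import threading

_STORE = threading.local()  # hypothetical side-channel store


def _set_shared_kv_states(value):
    _STORE.value = value


def _get_shared_kv_states():
    return getattr(_STORE, "value", None)


class BaseLayer:
    """Hypothetical stand-in for GradientCheckpointingLayer."""

    captured = []  # kwargs bound into each partial, recorded for inspection

    def __call__(self, *args, **kwargs):
        bound = functools.partial(self.forward, **kwargs)
        BaseLayer.captured.append(bound.keywords)
        return bound(*args)


class DecoderLayer(BaseLayer):
    """Hypothetical stand-in for Gemma4TextDecoderLayer."""

    def forward(self, hidden, shared_kv_states=None):
        # Read the side channel first, then fall back to the kwarg.
        store = _get_shared_kv_states()
        kv = store if store is not None else shared_kv_states
        return hidden, kv


def patch_decoder_layer_call(cls):
    if getattr(cls, "_shared_kv_patched", False):  # sentinel keeps the patch idempotent
        return
    original_call = cls.__call__

    def patched_call(self, *args, **kwargs):
        # Pop the dict before the partial is formed, so the partial never sees it.
        _set_shared_kv_states(kwargs.pop("shared_kv_states", None))
        return original_call(self, *args, **kwargs)

    cls.__call__ = patched_call
    cls._shared_kv_patched = True


patch_decoder_layer_call(DecoderLayer)
patch_decoder_layer_call(DecoderLayer)  # second call is a no-op

kv = {0: "k/v placeholder"}
hidden, seen = DecoderLayer()("h", shared_kv_states=kv)
```

After the call, `seen` is the original dict while the kwargs recorded in `BaseLayer.captured` are empty, which is the capture-avoidance property the patch relies on.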

Motivation and Context

Fixes #3610.

Introduced by PR #3598 (commit b8358aa5, "[gemma4] use mixed Flash Attention and SDPA and add fused RMSNorm+RoPE Triton kernels"), which added the hybrid flag, the unconditional patch_gemma4_fused_attn() monkeypatch, and the fused RMSNorm+RoPE kernels together.

Observed symptom (pre-patch):

| Path | VRAM behavior | Throughput (steady-state) |
| --- | --- | --- |
| gemma4_hybrid_attn_impl: true | Climbs ~0.47 GiB/step from 42 GiB baseline; OOMs at step 73 (~94 GiB peak) | ~13–14 s/step |
| flex_attention: true (all 60 layers) | Flat at 64 GiB | ~30 s/step (~2× slower) |

The climb is independent of actual sequence length or image size — confirmed by the reporter, which rules out shape variability, torch.compile recompiles, and head_dim=512 SDPA peak allocations (which would scale with T²). The flex path is flat but ~2× slower and not a practical workaround for a multi-day training run.

Hypotheses considered and rejected during diagnosis:

  • Per-layer shallow-copied config breaks FSDP/dynamo equivalence — contributes to path divergence but is not the leak source.
  • Silent cfg.flash_attention = True activates downstream per-step allocations — audited all cfg.flash_attention consumers; none allocate per-step state under the reported config (sample_packing: false, SM120 skips the FA4 auto-patch).
  • FA2 workspace retention under mixed dispatch — FA2 workspace is stateless per-call; no accumulation.

The violated-invariant analysis also cleanly explains why the all-flex path doesn't leak: flex compiles all 60 layers identically into a single graph context where saved tensors are managed by the compiled graph rather than by Python refcounting, so the dict's refs don't pin anything.

How has this been tested?

Static / structural validation (performed):

  • ast.parse on the modified module — clean.

  • Patch installation smoke test — both Gemma4TextDecoderLayer.__call__ and Gemma4TextAttention.forward are replaced after patch_gemma4_fused_attn().

  • Sentinel (_axolotl_shared_kv_tls_patched) prevents double-wrapping on repeated invocations.

  • TLS _get_shared_kv_states() / _set_shared_kv_states() round-trip correctly.

  • Byte-identical to the editable install being used on the reporter's hardware.

  • Ran pre-commit run --files src/axolotl/monkeypatch/models/gemma4/fused_attn.py on the modified file:

    | Hook | Result |
    | --- | --- |
    | ruff (legacy alias) | ✅ Passed |
    | ruff format | ✅ Passed |
    | mypy | ✅ Passed |
    | bandit | ✅ Passed |
    | fix end of files | ✅ Passed |
    | trim trailing whitespace | ✅ Passed |

End-to-end VRAM-flatness validation (confirmed on reporter's hardware):

Reran the failing config: Gemma-4 31B multimodal, LoRA r=256 on language layers, vision tower + embed_vision trained full, sample_packing: false, pad_to_sequence_len: false, sequence_len: 5248, FSDP2 with reshard_after_forward: true, activation_checkpointing: true under fsdp_config (HF Trainer-level gradient_checkpointing: false), on 2× RTX PRO 6000 Blackwell 96 GB. The run covered 200+ steps, well past the previous OOM point (step 73), and included a full eval + checkpoint-save cycle at step 200. VRAM is flat; throughput matches pre-patch hybrid.

| Step | Pre-patch | Post-patch |
| --- | --- | --- |
| eval (pre-train) | 42.5 GiB | 42.6 GiB |
| 9 | 57.2 GiB | 55.3 GiB |
| 19 | 64.8 GiB | 56.3 GiB |
| 29 | 62.0 GiB | 48.7 GiB |
| 39 | 72.5 GiB | 55.0 GiB |
| 49 | 74.7 GiB | 52.1 GiB |
| 59 | 79.2 GiB | 49.0 GiB |
| 69 | 85.6 GiB | 52.9 GiB |
| 73 | OOM (~94 GiB) | |
| 79 | | 57.1 GiB |
| 89 | | 52.7 GiB |
| 99 | | 57.5 GiB |
| 109 | | 48.5 GiB |
| 119 | | 53.4 GiB |
| 129 | | 52.3 GiB |
| 139 | | 58.5 GiB |
| 149 | | 49.9 GiB |
| 159 | | 57.9 GiB |
| 169 | | 53.7 GiB |
| 179 | | 54.3 GiB |
| 189 | | 66.2 GiB |
| 199 | | 49.3 GiB |
| eval @ 200 | | 44.6 GiB |
| save @ 200 | | ✅ checkpoint saved |

Post-patch VRAM oscillates in a ~48–58 GiB band with one excursion to 66.2 GiB at step 189; no upward trend. Step time is ~13–14 s/it (steady-state), identical to pre-patch hybrid. Eval loss dropped 5.43 → 0.69 over the first 200 steps; the first save_strategy: steps checkpoint wrote cleanly under SHARDED_STATE_DICT FSDP2.
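
The "no upward trend" reading can be checked with a least-squares fit over the training-step rows of the table above. This is a rough sketch: the helper `least_squares_slope` is written for this example and is not part of the PR, and the eval and save rows are excluded. On these numbers the pre-patch slope comes out at roughly 0.45 GiB/step, in line with the ~0.47 GiB/step reported, while the post-patch slope is close to zero.

```python
def least_squares_slope(steps, values):
    """Ordinary least-squares slope of values against steps."""
    n = len(steps)
    mean_x = sum(steps) / n
    mean_y = sum(values) / n
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(steps, values))
    sxx = sum((x - mean_x) ** 2 for x in steps)
    return sxy / sxx


# Training-step rows from the table (GiB); eval and save rows are excluded.
pre_steps = list(range(9, 70, 10))  # steps 9, 19, ..., 69
pre_gib = [57.2, 64.8, 62.0, 72.5, 74.7, 79.2, 85.6]

post_steps = list(range(9, 200, 10))  # steps 9, 19, ..., 199
post_gib = [55.3, 56.3, 48.7, 55.0, 52.1, 49.0, 52.9, 57.1, 52.7, 57.5,
            48.5, 53.4, 52.3, 58.5, 49.9, 57.9, 53.7, 54.3, 66.2, 49.3]

pre_slope = least_squares_slope(pre_steps, pre_gib)
post_slope = least_squares_slope(post_steps, post_gib)
print(f"pre-patch:  {pre_slope:.3f} GiB/step")
print(f"post-patch: {post_slope:.3f} GiB/step")
```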

Testing environment of record:

  • transformers 5.5.4
  • Hardware: 2× RTX PRO 6000 Blackwell 96 GB (SM120); FSDP2; activation checkpointing enabled.
  • Model: Gemma-4 31B (shimmed to enforce max image tokens in training)
  • torch 2.11.0+cu130
  • main/323da791 served as basis for fork with these changes

What reviewers may want to sanity-check:

  • The Gemma4TextDecoderLayer.__call__ intercept sits above GradientCheckpointingLayer.__call__, so the partial formed downstream does not re-capture the popped kwarg.
  • The kwarg fallback in fused_forward preserves behavior for any caller that invokes attention without going through the patched decoder-layer __call__ (e.g. custom eval hooks that reach into self_attn directly).
  • The unconditional application to the flex and eager paths is intentional and benign: the TLS rerouting is a capture-avoidance change, not a behavior change.

AI Usage Disclaimer

Claude Code (Opus 4.7) was used in the following scopes:

  • Root-cause investigation — reading the axolotl + HF transformers source for the Gemma-4 hybrid attention path, CPU_Offloaded_Gradient_Checkpointer, GradientCheckpointingLayer, and the shared_kv_states dataflow through the decoder loop; validating four hypotheses against the code; producing the "violated invariant" analysis.
  • Patch design — proposing the thread-local side channel as the minimal diff shape that preserves the hybrid dispatch and FSDP wrap policy.
  • Patch implementation — writing the TLS helpers, the _patch_decoder_layer_call() intercept, and the fused_forward fallback; wiring into patch_gemma4_fused_attn(); adding the idempotency sentinel.
  • Static validation — the AST parse, patch-installation smoke test, TLS round-trip, and idempotency verification.

Supplied by the human author: the training config, observed VRAM numbers (0.47 GiB/step, step-73 OOM, 42 GiB baseline, 94 GiB peak, 13–14 vs 30 s/step steady-state throughput), hardware details, confirmation that the climb is independent of sequence length / image size, and the end-to-end 200+ step VRAM-flatness rerun on the reporter's 2× RTX PRO 6000 Blackwell hardware including a full eval + checkpoint-save cycle.

Screenshots (if appropriate)

N/A — the change is a structural fix with no user-visible surface.

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Performance / memory optimization (eliminates per-step VRAM accumulation on the hybrid path)
  • Documentation update

Social Handles (Optional)

(fill in if desired)

Summary by CodeRabbit

  • Refactor
    • Improved internal memory management and state handling for Gemma4 model attention processing, enhancing overall efficiency and system reliability.

@coderabbitai

coderabbitai Bot commented Apr 19, 2026

Contributor

Walkthrough

Introduces thread-local storage to manage shared_kv_states in Gemma4 fused attention monkeypatch, preventing mutable container capture by gradient checkpointing. Adds decoder-layer call interception to route shared state through thread-local storage instead of kwargs, mitigating VRAM leaks under FSDP2 with activation checkpointing.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Gemma4 Fused Attention TLS Routing: src/axolotl/monkeypatch/models/gemma4/fused_attn.py | Added thread-local storage (_GEMMA4_SHARED_KV_TLS) with getter/setter helpers; made shared_kv_states kwarg optional in fused_forward with TLS-preference logic; introduced _patch_decoder_layer_call() to monkeypatch Gemma4TextDecoderLayer.__call__ for intercepting and routing shared_kv_states into TLS; integrated decoder-layer patch into patch_gemma4_fused_attn(). |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.86% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Linked Issues check | ✅ Passed | The pull request directly implements the proposed minimal monkeypatch for issue #3610: routing shared_kv_states via thread-local storage, patching Gemma4TextDecoderLayer.__call__, updating fused_forward logic, and wiring into patch_gemma4_fused_attn() idempotently. |
| Out of Scope Changes check | ✅ Passed | All changes are scoped to the identified issue: internal helpers for TLS management, decoder-layer call patching, fused_forward signature updates, and integration into the existing patch function. No unrelated modifications to hybrid dispatch, FSDP policy, configs, or public APIs. |
| Title check | ✅ Passed | The PR title accurately describes the main change: fixing a VRAM leak in the Gemma4 hybrid FA2+SDPA attention path that occurs under activation checkpointing. The title directly reflects the core problem and solution being addressed. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@src/axolotl/monkeypatch/models/gemma4/fused_attn.py`:
- Around line 178-182: patched_call currently only calls _set_shared_kv_states
when shared_kv is non-None, leaving TLS stale when callers omit the kwarg;
change patched_call to always call _set_shared_kv_states(shared_kv) after
popping shared_kv from kwargs (i.e., keep shared_kv =
kwargs.pop("shared_kv_states", None) but remove the conditional and call
_set_shared_kv_states(shared_kv) unconditionally) so TLS is cleared/overwritten
for every invocation and fused_forward's tls_store fallback works reliably.

📥 Commits

Reviewing files that changed from the base of the PR and between 323da79 and a24fca2.

📒 Files selected for processing (1)
  • src/axolotl/monkeypatch/models/gemma4/fused_attn.py

Comment thread src/axolotl/monkeypatch/models/gemma4/fused_attn.py
@thad0ctor thad0ctor changed the title [gemma4] fix VRAM leak in hybrid FA2+SDPA path under activation check… fix: [gemma4] fix VRAM leak in hybrid FA2+SDPA (hybrid attention) path under activation check… Apr 19, 2026
thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request Apr 19, 2026
…r patched __call__

  Overwrite the thread-local shared_kv_states store on every invocation
  (including with None) instead of only when the kwarg is present.

  The previous conditional write left stale dicts in TLS on any path that
  reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
  kwarg — e.g. generation, eval hooks, or future HF refactors that make
  the kwarg optional. fused_forward would then silently consume a prior
  step's K/V dict instead of falling back to its own kwarg path.

  Unconditional write makes the invariant in the surrounding comment
  ("TLS is overwritten on each new step's first decoder-layer call, so
  the previous step's dict is released promptly") actually hold.

  No behavior change for the training happy path, which always passes
  the kwarg. Addresses CodeRabbit review on PR axolotl-ai-cloud#3611

@NanoCode012 NanoCode012 left a comment

Collaborator

Ok, I think this makes sense. I'm a bit unclear whether having the store on module scope is the best overall workaround (possible clash with any async runs tho likely rare).

I think we can also make that big chunk of comments more concise. If necessary, just point it to this PR with the explanation here.

Edit: Could we also add unit test (no need e2e) where appropriate? Maybe one to check that the shared kv is popped, the moe issue you met, ?

@thad0ctor

Contributor Author

Ok, I think this makes sense. I'm a bit unclear whether having the store on module scope is the best overall workaround (possible clash with any async runs tho likely rare).

I think we can also make that big chunk of comments more concise. If necessary, just point it to this PR with the explanation here.

Edit: Could we also add unit test (no need e2e) where appropriate? Maybe one to check that the shared kv is popped, the moe issue you met, ?

Good point on the scope. Let me think about whether there needs to be a gating check to prevent regressions in other functionality. I'm leaning towards this, as other options would likely be too invasive elsewhere and add latency.

I'll address the comments and unit tests for sure.

@winglian

Collaborator

thanks! lmk when this is ready to review again

thad0ctor added a commit to thad0ctor/axolotl that referenced this pull request Apr 20, 2026
Update PR axolotl-ai-cloud#3611 with gate for checkpointed training to avoid regressions across async flows.

Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
@thad0ctor

thad0ctor commented Apr 20, 2026

Contributor Author

thanks! lmk when this is ready to review again

I just pushed the changes with a gate (for checkpointed training), unit tests and shorter docstrings

Please let me know if you need me to update the PR summary itself, as some of it is now OBE (overtaken by events).

Edit: I missed a test for the MoE 'NoneType' error I observed and corrected in e3669b2; commit 1bdc84c adds a 5th test per @NanoCode012's comments. Should be good for review now. Thanks!

@thad0ctor thad0ctor requested a review from winglian April 20, 2026 16:41
@@ -16,6 +17,18 @@

logger = logging.getLogger(__name__)

Collaborator

Not due to your change, but we should use the axolotl logger as shown in other files.

Contributor Author

Do you need me to update this?

Collaborator

I'll fix this

thad0ctor and others added 6 commits April 21, 2026 11:54
…pointing

Route shared_kv_states through a thread-local side channel instead of the
decoder-layer kwargs so the checkpoint partial never references the dict.

HF's Gemma4TextModel.forward passes shared_kv_states (a mutable dict used
for cross-layer K/V sharing) as a kwarg to every decoder_layer call.
GradientCheckpointingLayer.__call__ then forms
partial(super().__call__, **kwargs), and whichever checkpoint runs
(axolotl's CPU_Offloaded_Gradient_Checkpointer or torch's stock
checkpoint) captures that partial. The partial holds a reference to the
dict, which holds the K/V tensors produced by store_full_length_kv
layers. Those tensors stay pinned for the full duration of backward, and
delayed ref-cycle cleanup in torch's caching allocator under FSDP2 +
activation checkpointing bleeds the residual across steps.

Observed symptom: VRAM climbs ~0.47 GiB/step from a 42 GiB baseline,
OOMs around step 73 (~94 GiB peak) on Gemma-4 31B multimodal with
gemma4_hybrid_attn_impl: true. Independent of seq len / image size.
All-flex-attention path is flat but ~2x slower.

Violated invariant: anything crossing an activation-checkpoint boundary
must be a tensor (refcounted by autograd) or plain Python data -- never
a mutable container holding tensor references.

Fix (all in src/axolotl/monkeypatch/models/gemma4/fused_attn.py):
  * threading.local() store with _get/_set_shared_kv_states helpers
  * _patch_decoder_layer_call(): monkeypatches
    Gemma4TextDecoderLayer.__call__ to pop shared_kv_states from kwargs
    and stash it in TLS before delegating to GradientCheckpointingLayer.
    The partial formed downstream no longer references the dict.
  * fused_forward reads TLS first, falls back to kwarg for callers that
    bypass the patched __call__ (e.g. direct attention invocation).
  * wired into patch_gemma4_fused_attn; idempotent via a sentinel.

TLS is overwritten on each new step's first decoder-layer call, so the
previous step's dict is released promptly. No changes to hybrid dispatch,
FSDP wrap policy, or any config behaviour. Works for hybrid, flex, and
eager paths.

Introduced by PR axolotl-ai-cloud#3598 (commit b8358aa).
…r patched __call__

  Overwrite the thread-local shared_kv_states store on every invocation
  (including with None) instead of only when the kwarg is present.

  The previous conditional write left stale dicts in TLS on any path that
  reaches Gemma4TextDecoderLayer.__call__ without a shared_kv_states
  kwarg — e.g. generation, eval hooks, or future HF refactors that make
  the kwarg optional. fused_forward would then silently consume a prior
  step's K/V dict instead of falling back to its own kwarg path.

  Unconditional write makes the invariant in the surrounding comment
  ("TLS is overwritten on each new step's first decoder-layer call, so
  the previous step's dict is released promptly") actually hold.

  No behavior change for the training happy path, which always passes
  the kwarg. Addresses CodeRabbit review on PR axolotl-ai-cloud#3611
… threads see shared_kv_states during backward recompute

Previous commits fixed the memory leak on 31B but caused a type error with MoE Gemma4 variants; this fixes that:

PR 3611's TLS variant only works when recompute runs on the same thread
  that set TLS during forward. PyTorch's C++ autograd engine
  (_engine_run_backward) spawns per-device worker threads to dispatch
  backward, and HF-Trainer gradient_checkpointing (stock
  torch.utils.checkpoint, non-reentrant / saved-tensor-hooks) fires
  unpack_hook -> recompute_fn on those worker threads. TLS set on the main
  thread during forward is invisible there, so _get_shared_kv_states()
  returns None and the consumer-layer lookup crashes with
  "'NoneType' object is not subscriptable" at
  fused_attn.py:97 (shared_kv_states[self.kv_shared_layer_index]).

  A plain module-level dict is visible to all threads in the process.
  Lifecycle is identical: the slot is overwritten each forward, releasing
  the previous step's dict and allowing its K/V tensors to be GC'd, so
  the original VRAM-leak fix still holds under FSDP2 AC too.
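
The thread-visibility difference described in this commit can be demonstrated with the standard library alone. This is a sketch under assumptions: a plain `threading.Thread` stands in for the autograd worker thread, and `_tls`, `_module_store`, `forward_on_main_thread`, and `recompute_on_worker` are hypothetical names for the example.

```python
import threading

_tls = threading.local()
_module_store = {"shared_kv_states": None}  # hypothetical module-level slot


def forward_on_main_thread(kv):
    # Write the same dict to both stores during "forward".
    _tls.shared_kv_states = kv
    _module_store["shared_kv_states"] = kv


def recompute_on_worker(result):
    # Stands in for recompute running on a different thread during backward.
    result["tls"] = getattr(_tls, "shared_kv_states", None)
    result["module"] = _module_store["shared_kv_states"]


kv = {0: "k/v placeholder"}
forward_on_main_thread(kv)

seen = {}
worker = threading.Thread(target=recompute_on_worker, args=(seen,))
worker.start()
worker.join()
```

The worker reads `None` from the thread-local and the original dict from the module-level slot. The trade-off is that a module-level slot is shared by every thread in the process, which is the possible clash with async runs raised in review.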
Update PR axolotl-ai-cloud#3611 with gate for checkpointed training to avoid regressions across async flows.

Added unit tests for kwargs pop, store-clear regression, and flag gating. Condensed verbose comments
Additional regression test for MoE gemma4 variants: asserts that the module-level store is readable from threads other than the one that set it, in response to the previously observed 'NoneType' error.
@winglian winglian force-pushed the fix/gemma4-hybrid-vram-leak branch from 1bdc84c to e91d711 Compare April 21, 2026 15:56
@codecov

codecov Bot commented Apr 21, 2026


Codecov Report

❌ Patch coverage is 75.86207% with 7 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...rc/axolotl/monkeypatch/models/gemma4/fused_attn.py | 84.61% | 4 Missing ⚠️ |
| src/axolotl/loaders/patch_manager.py | 0.00% | 3 Missing ⚠️ |


@winglian winglian merged commit e562e14 into axolotl-ai-cloud:main Apr 21, 2026
17 checks passed
Development

Successfully merging this pull request may close these issues.

VRAM leak in Gemma-4 hybrid attention path under activation checkpointing

3 participants