fix: prevent 5 GB+ CUDA memory leak in activation offloading by syncing streams and clear stashes in OffloadActivations.__exit__#5700

Merged
kashif merged 2 commits into huggingface:main from butterwecksolutions:stream-cleanup
May 7, 2026

Conversation

@butterwecksolutions

@butterwecksolutions butterwecksolutions commented May 4, 2026

Contributor

Problem

OffloadActivations.__exit__ defers entirely to
torch.autograd.graph.saved_tensors_hooks.__exit__, which does not
synchronize the offload CUDA stream (s1) or the compute stream (s0).
Async copies still in-flight when the context manager exits can leave
stashed tensors referencing a stream whose lifetime may already be
over, leaking CUDA blocks with garbage stream IDs.

Observed via torch.cuda.memory._dump_snapshot during QLoRA vision
training on a 9B model (RTX 3090 24 GB):

Stream 93898984661968 (garbage ID):   5.70 GB  ← likely orphaned offload tensors
Stream 0 (compute):                  14.69 GB  ← normal model
Fragmented/unreleased:                5.83 GB

The garbage stream blocks (~8.4 MB each) match the size of activation
offload buffers. After adding stream sync + stash clear, garbage-stream
blocks no longer appear in post-training snapshots.
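
Per-stream totals like the ones above can be obtained by grouping the snapshot's segments by stream ID. The following is a minimal sketch, not the script used for this PR. It assumes the snapshot is a dict with a "segments" list whose entries carry "stream" and "total_size" keys (believed to match the pickle written by torch.cuda.memory._dump_snapshot in recent PyTorch releases, but treat that layout as an assumption). bytes_per_stream is a hypothetical helper name, and the snapshot below is synthetic.

```python
from collections import defaultdict


def bytes_per_stream(snapshot):
    """Sum reserved segment bytes per CUDA stream ID.

    Assumes `snapshot` looks like
    {"segments": [{"stream": int, "total_size": int, ...}, ...]}.
    """
    totals = defaultdict(int)
    for segment in snapshot.get("segments", []):
        totals[segment["stream"]] += segment["total_size"]
    return dict(totals)


MiB = 1024**2

# Synthetic stand-in for a real snapshot; the sizes are made up for illustration.
synthetic_snapshot = {
    "segments": [
        {"stream": 0, "total_size": 20 * MiB},
        {"stream": 0, "total_size": 2 * MiB},
        {"stream": 93898984661968, "total_size": 8 * MiB},
        {"stream": 93898984661968, "total_size": 8 * MiB},
    ]
}

for stream, size in sorted(bytes_per_stream(synthetic_snapshot).items()):
    print(f"stream {stream}: {size / MiB:.1f} MiB")
```

With a real dump, the same function could be applied to the object loaded from the snapshot file with pickle. A stream ID that is neither 0 nor a stream the program still holds is the pattern described above.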

Fix

Add an __exit__ method that runs BEFORE the parent's cleanup:

  1. Synchronize s0 (compute stream) and s1 (offload stream)
  2. Clear stashed tensors (bwd_tensor_stash, bwd_ev_stash, fwd_stash)
    — only when use_streams=True (stashes don't exist otherwise)

Tracker is intentionally NOT cleared — the backward pass unpacks
tensors via tracker AFTER __exit__ (see class docstring example).

Negligible performance impact: the method runs once per context-manager
exit, and sync + clear takes ~0.1 ms.
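
The ordering described above (sync both streams, clear the stashes, leave the tracker, then run the parent's cleanup) can be sketched without a GPU. Everything here is a stand-in: FakeStream and FakeSavedTensorsHooks are hypothetical classes that only record calls, and OffloadSketch mirrors the attribute names from this PR without being the TRL implementation.

```python
class FakeStream:
    """Stand-in for a CUDA stream; only records that it was synchronized."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def synchronize(self):
        self.log.append(f"sync {self.name}")


class FakeSavedTensorsHooks:
    """Stand-in for torch.autograd.graph.saved_tensors_hooks."""

    def __enter__(self):
        return None

    def __exit__(self, *args):
        self.log.append("parent __exit__")
        return None


class OffloadSketch(FakeSavedTensorsHooks):
    def __init__(self, use_streams=True):
        self.log = []
        self.use_streams = use_streams
        self.tracker = {}  # read by backward, which runs after __exit__
        self.s0 = FakeStream("s0", self.log)
        if use_streams:
            # Offload stream and stashes only exist in the streamed mode.
            self.s1 = FakeStream("s1", self.log)
            self.fwd_stash = {}
            self.bwd_tensor_stash = {}
            self.bwd_ev_stash = {}

    def __exit__(self, *args):
        if self.use_streams:
            # 1. Wait for in-flight copies on both streams.
            self.s0.synchronize()
            self.s1.synchronize()
            # 2. Drop stashed references; the tracker is left alone.
            self.bwd_tensor_stash.clear()
            self.bwd_ev_stash.clear()
            self.fwd_stash.clear()
        return super().__exit__(*args)


ctx = OffloadSketch(use_streams=True)
with ctx:
    ctx.tracker[0] = "offloaded activation"
    ctx.fwd_stash[0] = "in-flight buffer"

print(ctx.log)      # order of cleanup calls
print(ctx.tracker)  # still populated for the backward pass
```

The tracker entry survives the exit, which is the property the backward pass relies on, while the stash is empty by the time the parent's cleanup runs.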

Verification

Applied as a monkey-patch since May 2026. Post-training snapshots show
zero garbage-stream blocks. Without the fix, garbage-stream blocks
accumulate monotonically per step at ~0.2 GiB/step.


  • This PR fixes a typo or improves the docs

  • Read the contributor guideline

  • Discussed via GitHub issue (N/A — found during QLoRA debugging)

  • Documentation update needed? No — internal behavior change only

  • New tests? Not easily testable — requires CUDA memory snapshot assertions

  • AI-assisted: PR text and structure refined with AI; root cause found
    through human debugging


Note

Medium Risk
Touches CUDA stream/autograd hook lifecycle in OffloadActivations, so mistakes could cause hangs or perf regressions, but the change is small and gated to use_streams=True.

Overview
Prevents activation-offloading CUDA memory from being leaked after leaving the OffloadActivations context by adding a custom __exit__.

When use_streams=True, __exit__ now synchronizes the compute/offload streams (s0, s1) and clears the forward/backward stash dictionaries before delegating to saved_tensors_hooks.__exit__, using a try/finally to ensure the parent hook cleanup always runs.

Reviewed by Cursor Bugbot for commit f1377a4.

Comment thread trl/models/activation_offloading.py Outdated
@butterwecksolutions butterwecksolutions force-pushed the stream-cleanup branch 2 times, most recently from 26ab3ee to 0fc7d4d Compare May 4, 2026 17:57
@butterwecksolutions butterwecksolutions changed the title fix: sync streams and clear stashes in OffloadActivations.__exit__ fix: prevent 5 GB+ CUDA memory leak in activation offloading by syncing streams and clear stashes in OffloadActivations.__exit__ May 4, 2026
@qgallouedec qgallouedec assigned kashif and unassigned kashif May 4, 2026
Comment thread trl/models/activation_offloading.py Outdated

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Reviewed by Cursor Bugbot for commit e259a88.

Comment thread trl/models/activation_offloading.py Outdated
Comment on lines +571 to +595
def __exit__(self, *args, **kwargs):
    """Sync streams and clear stashes before parent cleanup.

    Prevents leaked CUDA blocks with garbage stream IDs when the
    context manager exits with in-flight async copies or stashed
    tensors still referencing the offload stream.

    NOTE: tracker is NOT cleared; the backward pass unpacks tensors
    via tracker AFTER __exit__ (see class docstring example).
    """
    import torch
    # 1. Sync both streams so async copies finish before cleanup
    if self.use_streams and self.s0 is not None:
        self.s0.synchronize()
    if self.use_streams and self.s1 is not None:
        self.s1.synchronize()

    # 2. Clear stashed tensors (only exist when use_streams=True)
    if self.use_streams:
        self.bwd_tensor_stash.clear()
        self.bwd_ev_stash.clear()
        self.fwd_stash.clear()

    return super().__exit__(*args, **kwargs)


Suggested change
def __exit__(self, *args, **kwargs):
    """Sync streams and clear stashes before parent cleanup."""
    try:
        if self.use_streams:
            self.s0.synchronize()
            self.s1.synchronize()
            self.bwd_tensor_stash.clear()
            self.bwd_ev_stash.clear()
            self.fwd_stash.clear()
    finally:
        result = super().__exit__(*args, **kwargs)
    return result

@kashif kashif May 6, 2026


Let's make the parent __exit__ run in a finally block so the saved tensor hooks are always popped. I'd also drop the is not None checks unless there is a real code path where self.s0 or self.s1 can be None. In this class, self.s0 is always initialized, and self.s1 exists whenever use_streams=True.
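
To illustrate why the finally matters, here is a small sketch with stand-in classes (no CUDA involved). ParentHooks and FailingStream are hypothetical; the failing synchronize simply simulates any exception raised during cleanup. Without finally, the parent's __exit__ is skipped and the hooks stay installed; with it, they are popped even though the exception still propagates.

```python
class ParentHooks:
    """Stand-in for saved_tensors_hooks: tracks whether hooks are installed."""

    def __init__(self):
        self.hooks_installed = False

    def __enter__(self):
        self.hooks_installed = True

    def __exit__(self, *args):
        self.hooks_installed = False


class FailingStream:
    def synchronize(self):
        raise RuntimeError("simulated error during synchronize")


class ExitWithoutFinally(ParentHooks):
    def __init__(self):
        super().__init__()
        self.s1 = FailingStream()

    def __exit__(self, *args):
        self.s1.synchronize()  # raises, so the next line never runs
        return super().__exit__(*args)


class ExitWithFinally(ParentHooks):
    def __init__(self):
        super().__init__()
        self.s1 = FailingStream()

    def __exit__(self, *args):
        try:
            self.s1.synchronize()
        finally:
            result = super().__exit__(*args)  # runs even when sync raises
        return result


def hooks_left_installed(ctx):
    """Run an empty `with` block and report whether hooks survived the exit."""
    try:
        with ctx:
            pass
    except RuntimeError:
        pass
    return ctx.hooks_installed


print(hooks_left_installed(ExitWithoutFinally()))  # hooks remain installed
print(hooks_left_installed(ExitWithFinally()))     # hooks were popped
```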

@kashif

kashif commented May 6, 2026


Thanks @butterwecksolutions, if you can try my suggestion then we should be good!

Async CUDA streams (s0 for CPU→GPU, s1 for GPU→CPU) accumulate garbage
when tensors in bwd_tensor_stash, bwd_ev_stash, and fwd_stash are not
freed on context manager exit. This grows VRAM ~0.2 GB/step until OOM
at step 19 on RTX 3090 (24 GB).

Fix: synchronize both streams and clear all stashes before cleanup.
The try/finally ensures the saved_tensors_hooks parent __exit__ always runs,
so the hooks cannot stay permanently installed and create a silent memory leak.

Verified: 7.0 GB VRAM saved in VL training (20.3→13.3 GB) via systematic
14-test reproducer with fresh GPU isolation per test.

Co-Authored-By: Claude Opus 4.6 <noreply@openclaude.dev>
@butterwecksolutions

butterwecksolutions commented May 7, 2026


Fix Verified — 7.0 GB Saved via Systematic Isolation

Updated to @kashif's suggested pattern (try/finally, no is-not-None guards,
no tracker.clear).

Then verified against a 14-test reproducer (7 patches × 2 modes, fresh
GPU isolation per test, Qwopus3.5-9B, RTX 3090, PyTorch 2.10, TRL 0.29.0):

Mode                  Baseline   P3 (#5700)   Saved
VL (seq_len=512)      20.3 GB    13.3 GB      7.0 GB
TEXT (seq_len=4096)   8.5 GB     8.5 GB       0.0 GB (leak is VL-triggered)

Sole root cause among 5 candidate fixes:

  • P1 (bitsandbytes-foundation/bitsandbytes 1935) = 0.0 GB → closed
  • P2 (transformers #45769) = 0.0 GB → closed
  • P5 (naive backward hook) = OOM at 24.2 GB

Thanks @kashif for the review and the suggestion — the try/finally is the
right call. Thanks @qgallouedec for merging #5694 and maintaining TRL.
Happy I could contribute again!

Full reproducer, rendered HTML report, raw training logs:

Greetings from Arnsberg! 🇩🇪

@kashif kashif self-requested a review May 7, 2026 07:02
Comment thread trl/models/activation_offloading.py Outdated
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif kashif merged commit 8a6cc03 into huggingface:main May 7, 2026
12 checks passed
butterwecksolutions added a commit to butterwecksolutions/trl that referenced this pull request May 8, 2026
After huggingface#5700 fixed the CUDA stream leak, a residual VRAM leak from
bitsandbytes 4-bit dequantization buffers remains: the tracker dict
retains references to tensors whose storage shares allocator blocks
with BNB buffers, and torch.cuda.empty_cache() is never called.
This causes ~0.6 GiB monotonic VRAM growth per step, triggering
OOM after 30-40 steps on 24 GB GPUs.

Fix: clear tracker + storage_to_tensor_id after parent __exit__
and call accelerator.empty_cache() to release cached BNB buffer
blocks. Uses explicit if/elif dispatcher matching the file's
established pattern (no getattr).
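
As a back-of-envelope check on figures like these: 30-40 steps at ~0.6 GiB/step corresponds to roughly 18-24 GiB of cumulative growth. The sketch below turns that into a tiny estimator. steps_until_oom is a hypothetical helper, and the 20 GiB of free headroom in the example call is an assumed value, not a measurement from this thread.

```python
import math


def steps_until_oom(headroom_gib, growth_gib_per_step):
    """Number of full steps that fit before linear growth exhausts the headroom."""
    if growth_gib_per_step <= 0:
        raise ValueError("growth must be positive")
    return math.floor(headroom_gib / growth_gib_per_step)


# 0.6 GiB/step is the growth quoted above; 20 GiB of headroom is hypothetical.
print(steps_until_oom(headroom_gib=20.0, growth_gib_per_step=0.6))
```

This assumes strictly linear growth; allocator fragmentation can bring the failure forward.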
butterwecksolutions added a commit to butterwecksolutions/trl that referenced this pull request May 9, 2026
After huggingface#5700 fixed the CUDA stream leak, a residual VRAM leak remains
from stale state and BNB 4-bit dequantization buffers.

Two independent leak paths are fixed:

1. __enter__ (stale state from previous step):
   - MoE + sample_packing + torch.compile: dynamic expert routing leaves
     saved tensors on subgraphs that never contribute to loss. Their
     backward nodes never execute, so tracker/state from previous step
     survives into the next.
   - Fix: clear tracker, storage_to_tensor_id, stashes, tensor_id, and
     forward/backward flags in __enter__ before re-registering hooks
     via super().__enter__().

2. __exit__ (BNB dequantization buffers after parent cleanup):
   - QLoRA training: BNB 4-bit dequantization buffers accumulate because
     tracker retains references to tensors sharing allocator blocks with
     BNB buffers, and empty_cache() is never called.
   - ~0.6 GiB monotonic VRAM growth per step, OOM after 30-40 steps.
   - Fix: tracker.clear(), storage_to_tensor_id.clear(), tensor_id=0,
     state reset, and accelerator-aware empty_cache() in finally block.
     Conditional on bitsandbytes in sys.modules to avoid penalizing
     non-BNB workloads. Empty-cache uses explicit if/elif dispatch
     (xpu/npu/cuda) matching the file's established pattern.

A complete audit confirms no further leak sources exist.
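
The reset-on-enter idea from the first leak path can be sketched with plain dicts. This is an illustration under assumptions, not the code from the referenced commit: ParentHooks is a stand-in for saved_tensors_hooks, _reset_state is a hypothetical helper name, and only attribute names mentioned in this thread are used.

```python
class ParentHooks:
    """Stand-in for saved_tensors_hooks."""

    def __enter__(self):
        self.entered = True

    def __exit__(self, *args):
        self.entered = False


class ResetOnEnterSketch(ParentHooks):
    def __init__(self):
        self.tracker = {}
        self.storage_to_tensor_id = {}
        self.tensor_id = 0
        self.fwd_stash = {}
        self.bwd_tensor_stash = {}
        self.bwd_ev_stash = {}

    def _reset_state(self):
        # Hypothetical helper: forget everything a previous step left behind.
        self.tracker.clear()
        self.storage_to_tensor_id.clear()
        self.fwd_stash.clear()
        self.bwd_tensor_stash.clear()
        self.bwd_ev_stash.clear()
        self.tensor_id = 0

    def __enter__(self):
        self._reset_state()
        return super().__enter__()


ctx = ResetOnEnterSketch()

# Step 1: a saved tensor whose backward node never runs is never unpacked,
# so its tracker entry outlives the step.
with ctx:
    ctx.tracker[0] = "activation on an unused subgraph"
    ctx.tensor_id = 1
leftover_after_step_1 = len(ctx.tracker)

# Step 2: entering again drops the stale entry before new hooks are used.
with ctx:
    entries_at_start_of_step_2 = len(ctx.tracker)

print(leftover_after_step_1, entries_at_start_of_step_2)
```

Clearing in __enter__ is only safe if the previous step's backward pass has finished before the context is entered again, since backward reads the tracker after __exit__.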
butterwecksolutions added a commit to butterwecksolutions/trl that referenced this pull request May 9, 2026
Two independent VRAM leak paths in OffloadActivations are fixed
by cleaning up stale state and releasing allocator cache blocks
in __enter__, where the previous backward is guaranteed complete:

1. MoE + sample_packing + torch.compile — saved tensors on subgraphs
   whose backward nodes never execute leak ~60 tensors/micro-step
   because the unpack-then-delete logic never fires for them.

2. QLoRA BNB 4-bit dequantization buffers — tracker references keep
   allocator blocks alive across steps, and empty_cache() is never
   called (~0.6 GiB/step, OOM after 30-40 steps on 24 GB GPUs).

__enter__ clears tracker, storage_to_tensor_id, tensor_id, stashes,
and calls accelerator-aware empty_cache() (conditional on bitsandbytes
in sys.modules to avoid penalizing non-BNB workloads).

__exit__ handles stream sync and stash cleanup as before (huggingface#5700).
All cleanup uses explicit if/elif dispatch matching the file's
established accelerator pattern.
butterwecksolutions added a commit to butterwecksolutions/trl that referenced this pull request May 9, 2026
Two independent VRAM leak paths in OffloadActivations are fixed
by cleaning up stale state and releasing allocator cache blocks
in __enter__, where the previous backward has already completed:

1. MoE + sample_packing + torch.compile — saved tensors on subgraphs
   whose backward nodes never execute leak ~60 tensors/micro-step
   because the unpack-then-delete logic never fires for them.

2. QLoRA BNB 4-bit dequantization buffers — tracker references keep
   allocator blocks alive across steps, and empty_cache() is never
   called (~0.6 GiB/step, OOM after 30-40 steps on 24 GB GPUs).

__enter__ clears tracker, storage_to_tensor_id, tensor_id, stashes,
and calls accelerator-aware empty_cache() (conditional on bitsandbytes
in sys.modules to avoid penalizing non-BNB workloads).

__exit__ handles stream sync and stash cleanup as before (huggingface#5700).
All cleanup uses explicit if/elif dispatch matching the file's
established accelerator pattern.
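
The bitsandbytes-gated cache release with explicit if/elif dispatch could look roughly like the sketch below. It is a hedged illustration only: maybe_empty_cache is a hypothetical name, the backends are passed in as plain callables so the example runs without PyTorch, and the xpu/npu/cuda branch order is an assumption rather than something this thread specifies.

```python
import sys


def maybe_empty_cache(backends, loaded_modules=None):
    """Release cached allocator blocks on one backend if bitsandbytes is loaded.

    `backends` maps a name ("xpu", "npu", "cuda") to an
    (is_available, empty_cache) pair of zero-argument callables.
    Returns the name of the backend that was flushed, or None.
    """
    loaded = sys.modules if loaded_modules is None else loaded_modules
    if "bitsandbytes" not in loaded:
        return None  # non-BNB workloads skip the flush entirely

    def available(name):
        return name in backends and backends[name][0]()

    # Explicit if/elif dispatch, one branch per accelerator.
    if available("xpu"):
        backends["xpu"][1]()
        return "xpu"
    elif available("npu"):
        backends["npu"][1]()
        return "npu"
    elif available("cuda"):
        backends["cuda"][1]()
        return "cuda"
    return None


calls = []
fake_backends = {
    "xpu": (lambda: False, lambda: calls.append("xpu")),
    "cuda": (lambda: True, lambda: calls.append("cuda")),
}

print(maybe_empty_cache(fake_backends, loaded_modules={"bitsandbytes": object()}))
print(maybe_empty_cache(fake_backends, loaded_modules={}))
print(calls)
```

With PyTorch installed, the CUDA entry could be (torch.cuda.is_available, torch.cuda.empty_cache). Checking sys.modules avoids importing bitsandbytes just to test for it.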
@butterwecksolutions butterwecksolutions mentioned this pull request May 9, 2026
8 tasks
3 participants