Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions trl/models/activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,23 @@ def update_model_params(self, model: nn.Module):

self.param_storages = param_storages

def __exit__(self, *args, **kwargs):
    """Drain the offload streams and empty the stashes, then defer to the parent.

    When ``self.use_streams`` is truthy, ``self.s0`` and ``self.s1`` are
    synchronized (in that order) and the backward tensor stash, backward
    event stash and forward stash are cleared.

    The parent ``__exit__`` is called from a ``finally`` clause, so it runs
    even when one of the steps above raises. In that case the exception
    propagates once the parent call has finished and nothing is returned.

    Args:
        *args: Forwarded unchanged to the parent ``__exit__``.
        **kwargs: Forwarded unchanged to the parent ``__exit__``.

    Returns:
        Whatever the parent ``__exit__`` returns.
    """
    try:
        if self.use_streams:
            # Synchronize both streams before dropping the stashed entries.
            self.s0.synchronize()
            self.s1.synchronize()
            for stash in (self.bwd_tensor_stash, self.bwd_ev_stash, self.fwd_stash):
                stash.clear()
    finally:
        # Deliberately no `return` inside `finally`: that would swallow an
        # in-flight exception raised by the block above.
        parent_result = super().__exit__(*args, **kwargs)
    return parent_result

Comment on lines +571 to +587

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __exit__(self, *args, **kwargs):
"""Sync streams and clear stashes before parent cleanup.

Prevents leaked CUDA blocks with garbage stream IDs when the
context manager exits with in-flight async copies or stashed
tensors still referencing the offload stream.

NOTE: tracker is NOT cleared — the backward pass unpacks tensors
via tracker AFTER __exit__ (see class docstring example).

Returns whatever the parent ``__exit__`` returns. There is no
try/finally here, so if any call below raises, the parent
``__exit__`` on the last line is never reached.
"""
# NOTE(review): `torch` is not referenced anywhere in this body; the local import looks removable — confirm.
import torch
# 1. Sync both streams so async copies finish before cleanup
# NOTE(review): the `is not None` guards only matter if s0/s1 can actually be None while use_streams is True — confirm against __init__.
if self.use_streams and self.s0 is not None:
self.s0.synchronize()
if self.use_streams and self.s1 is not None:
self.s1.synchronize()
# 2. Clear stashed tensors (only exist when use_streams=True)
if self.use_streams:
self.bwd_tensor_stash.clear()
self.bwd_ev_stash.clear()
self.fwd_stash.clear()
return super().__exit__(*args, **kwargs)
def __exit__(self, *args, **kwargs):
"""Sync streams and clear stashes before parent cleanup.

The parent ``__exit__`` is called from a ``finally`` clause, so it runs
even if synchronizing a stream or clearing a stash raises. In that case
the exception propagates after the parent call and ``return result`` is
not reached.

Returns whatever the parent ``__exit__`` returns.
"""
try:
if self.use_streams:
# Synchronize s0, then s1, before the stashes are cleared.
self.s0.synchronize()
self.s1.synchronize()
self.bwd_tensor_stash.clear()
self.bwd_ev_stash.clear()
self.fwd_stash.clear()
finally:
# Assign here and return after the block: a `return` inside `finally` would swallow a pending exception.
result = super().__exit__(*args, **kwargs)
return result

@kashif kashif May 6, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make the parent __exit__ run in a finally so the saved-tensor hooks are always popped. I’d also drop the is not None checks unless there is a real code path where self.s0 or self.s1 can be None. In this class, self.s0 is always initialized, and self.s1 exists whenever use_streams=True.


class NoOpManager(saved_tensors_hooks):
"""
Expand Down
Loading