Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions tests/test_activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,34 @@ def forward(self, x):

loss.backward()

@require_torch_accelerator
def test_stale_tracker_state_is_cleared_between_forwards(self):
"""Test that tensors from unused graph branches don't accumulate across steps."""

class ModelWithUnusedBranch(nn.Module):
def __init__(self):
super().__init__()
self.used = nn.Linear(8, 8)
self.unused = nn.Linear(8, 8)

def forward(self, x):
return self.used(x).sum(), self.unused(x).sum()

model = ModelWithUnusedBranch().to(torch_device)
offload_ctx = OffloadActivations(use_pin_memory=False, use_streams=False, min_offload_size=1)
offload_ctx.update_model_params(model)
inp = torch.randn(4, 8, device=torch_device)

tracker_counts = []
for _ in range(3):
model.zero_grad(set_to_none=True)
with offload_ctx:
loss, _ = model(inp)
loss.backward()
tracker_counts.append(len(offload_ctx.tracker))

assert tracker_counts == [tracker_counts[0]] * len(tracker_counts)

@require_torch_accelerator
def test_parameter_filtering(self):
"""Test that model parameters are filtered during offloading"""
Expand Down
35 changes: 35 additions & 0 deletions trl/models/activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# LICENSE file in the root directory of https://github.com/pytorch/torchtune.


import sys

import psutil
import torch
from accelerate import logging
Expand Down Expand Up @@ -568,6 +570,39 @@ def update_model_params(self, model: nn.Module):

self.param_storages = param_storages

def __enter__(self):
"""Clear stale state and release BNB buffers before entering.

By the time __enter__ is called, the previous forward/backward has already completed, so anything still in
tracker, storage_to_tensor_id, or the stashes is leaked and safe to drop.

Two leak paths are handled:
1. MoE + sample_packing + torch.compile: dynamic expert routing may leave saved tensors on subgraphs whose
backward nodes never execute, so the unpack-then-delete logic never fires. tracker/stashes from the previous
step survive into the next.
2. QLoRA BNB dequantization buffers: tracker retains references to tensors sharing allocator blocks with BNB
buffers, and the allocator cache is never flushed between steps (~0.6 GiB/step, OOM after 30-40).

Returns super().__enter__() to register pack/unpack hooks via saved_tensors_hooks (PyTorch autograd engine).
"""
self.tracker.clear()
self.storage_to_tensor_id.clear()
self.tensor_id = 0
self.is_first_forward_call = True
self.is_first_backward_call = True
if self.use_streams:
self.bwd_tensor_stash.clear()
self.bwd_ev_stash.clear()
self.fwd_stash.clear()
Comment thread
cursor[bot] marked this conversation as resolved.
if "bitsandbytes" in sys.modules:
if self.accelerator_type == "xpu":
torch.xpu.empty_cache()
elif is_torch_npu_available() and self.accelerator_type == "npu":
torch.npu.empty_cache()
else:
torch.cuda.empty_cache()
return super().__enter__()

def __exit__(self, *args, **kwargs):
"""Sync streams and clear stashes before parent cleanup.

Expand Down
Loading