diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py
index 35d8432d81d..5f2f6a03b4b 100644
--- a/tests/test_activation_offloading.py
+++ b/tests/test_activation_offloading.py
@@ -196,6 +196,39 @@ def forward(self, x):
 
         loss.backward()
 
+    @require_torch_accelerator
+    def test_stale_tracker_state_is_cleared_between_forwards(self):
+        """Tracker size must stay constant across steps when one graph branch never runs backward."""
+
+        class ModelWithUnusedBranch(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.used = nn.Linear(8, 8)
+                self.unused = nn.Linear(8, 8)
+
+            def forward(self, x):
+                # Both heads run forward; the caller backpropagates through the first output only.
+                used_total = self.used(x).sum()
+                unused_total = self.unused(x).sum()
+                return used_total, unused_total
+
+        model = ModelWithUnusedBranch().to(torch_device)
+        offload_ctx = OffloadActivations(use_pin_memory=False, use_streams=False, min_offload_size=1)
+        offload_ctx.update_model_params(model)
+        inp = torch.randn(4, 8, device=torch_device)
+
+        num_steps = 3
+        sizes_after_step = []
+        for _ in range(num_steps):
+            model.zero_grad(set_to_none=True)
+            with offload_ctx:
+                loss, _ = model(inp)
+            loss.backward()
+            sizes_after_step.append(len(offload_ctx.tracker))
+
+        # Entries left by the unused branch must not pile up: every step ends with the same tracker size.
+        assert all(size == sizes_after_step[0] for size in sizes_after_step)
+
     @require_torch_accelerator
     def test_parameter_filtering(self):
         """Test that model parameters are filtered during offloading"""
diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py
index b319d2ee8c3..b6cff43f7a3 100644
--- a/trl/models/activation_offloading.py
+++ b/trl/models/activation_offloading.py
@@ -19,6 +19,8 @@
 # LICENSE file in the root directory of https://github.com/pytorch/torchtune.
 
 
+import sys
+
 import psutil
 import torch
 from accelerate import logging
@@ -568,6 +570,47 @@ def update_model_params(self, model: nn.Module):
 
         self.param_storages = param_storages
 
+    def __enter__(self):
+        """Discard state left by the previous step, then install the saved-tensor hooks.
+
+        Entering the context means the previous forward/backward pair is over, so whatever is still held in
+        tracker, storage_to_tensor_id or the stashes was leaked and can be dropped safely.
+
+        Leak paths covered:
+        1. MoE + sample_packing + torch.compile: with dynamic expert routing, saved tensors can sit on subgraphs
+           whose backward nodes never execute. The unpack-then-delete logic then never fires, and the
+           tracker/stash entries of one step survive into the next.
+        2. QLoRA BNB dequantization buffers: tracker keeps references to tensors that share allocator blocks with
+           BNB buffers, and the allocator cache is never flushed between steps (~0.6 GiB/step, OOM after 30-40).
+
+        Returns:
+            The result of super().__enter__(), which registers the pack/unpack hooks through saved_tensors_hooks
+            (PyTorch autograd engine).
+        """
+        # Per-step bookkeeping starts from scratch.
+        self.tracker.clear()
+        self.storage_to_tensor_id.clear()
+        self.tensor_id = 0
+        self.is_first_forward_call = self.is_first_backward_call = True
+
+        # The stashes are cleared only when use_streams is set.
+        if self.use_streams:
+            for stash in (self.bwd_tensor_stash, self.bwd_ev_stash, self.fwd_stash):
+                stash.clear()
+
+        # Flush the allocator cache only when bitsandbytes has been imported into this process.
+        if "bitsandbytes" in sys.modules:
+            accelerator = self.accelerator_type
+            if accelerator == "xpu":
+                backend = torch.xpu
+            elif is_torch_npu_available() and accelerator == "npu":
+                backend = torch.npu
+            else:
+                backend = torch.cuda
+            backend.empty_cache()
+
+        return super().__enter__()
+
     def __exit__(self, *args, **kwargs):
         """Sync streams and clear stashes before parent cleanup.
 