From cdcfa6059a7727bb6559ddfcd533aeffb86fc80d Mon Sep 17 00:00:00 2001 From: Christian Butterweck Date: Sat, 9 May 2026 09:29:18 +0200 Subject: [PATCH 1/3] Release BNB dequantization buffers after activation offloading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two independent VRAM leak paths in OffloadActivations are fixed by cleaning up stale state and releasing allocator cache blocks in __enter__, where the previous backward has already completed: 1. MoE + sample_packing + torch.compile — saved tensors on subgraphs whose backward nodes never execute leak ~60 tensors/micro-step because the unpack-then-delete logic never fires for them. 2. QLoRA BNB 4-bit dequantization buffers — tracker references keep allocator blocks alive across steps, and empty_cache() is never called (~0.6 GiB/step, OOM after 30-40 steps on 24 GB GPUs). __enter__ clears tracker, storage_to_tensor_id, tensor_id, stashes, and calls accelerator-aware empty_cache() (conditional on bitsandbytes in sys.modules to avoid penalizing non-BNB workloads). __exit__ handles stream sync and stash cleanup as before (#5700). All cleanup uses explicit if/elif dispatch matching the file's established accelerator pattern. --- trl/models/activation_offloading.py | 44 +++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index b319d2ee8c3..8f0557d474d 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -19,6 +19,8 @@ # LICENSE file in the root directory of https://github.com/pytorch/torchtune. +import sys + import psutil import torch from accelerate import logging @@ -568,11 +570,49 @@ def update_model_params(self, model: nn.Module): self.param_storages = param_storages + def __enter__(self): + """Clear stale state and release BNB buffers before entering. + + By the time __enter__ is called, the previous forward/backward has + already completed, so anything still in tracker, + storage_to_tensor_id, or the stashes is leaked and safe to drop. + + Two leak paths are handled: + 1. MoE + sample_packing + torch.compile: dynamic expert routing may + leave saved tensors on subgraphs whose backward nodes never execute, + so the unpack-then-delete logic never fires. tracker/stashes from + the previous step survive into the next. + 2. QLoRA BNB dequantization buffers: tracker retains references to + tensors sharing allocator blocks with BNB buffers, and the allocator + cache is never flushed between steps (~0.6 GiB/step, OOM after 30-40). + + Returns super().__enter__() to register pack/unpack hooks via + saved_tensors_hooks (PyTorch autograd engine). + """ + self.tracker.clear() + self.storage_to_tensor_id.clear() + self.tensor_id = 0 + self.is_first_forward_call = True + self.is_first_backward_call = True + if self.use_streams: + self.bwd_tensor_stash.clear() + self.bwd_ev_stash.clear() + self.fwd_stash.clear() + if "bitsandbytes" in sys.modules: + if self.accelerator_type == "xpu": + torch.xpu.empty_cache() + elif is_torch_npu_available() and self.accelerator_type == "npu": + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + return super().__enter__() + def __exit__(self, *args, **kwargs): """Sync streams and clear stashes before parent cleanup. - try/finally ensures the saved_tensors_hooks parent cleanup runs even if stream sync raises — otherwise hooks - stay permanently installed, creating a silent memory leak. + try/finally ensures the saved_tensors_hooks parent cleanup runs even if + stream sync raises — otherwise hooks stay permanently installed, creating + a silent memory leak. """ try: if self.use_streams: From 909416ac53461bad258dfc51ffa1e37ca51827cc Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 12:10:40 +0200 Subject: [PATCH 2/3] test activation offloading stale state --- tests/test_activation_offloading.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index 35d8432d81d..5f2f6a03b4b 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -196,6 +196,34 @@ def forward(self, x): loss.backward() + @require_torch_accelerator + def test_stale_tracker_state_is_cleared_between_forwards(self): + """Test that tensors from unused graph branches don't accumulate across steps.""" + + class ModelWithUnusedBranch(nn.Module): + def __init__(self): + super().__init__() + self.used = nn.Linear(8, 8) + self.unused = nn.Linear(8, 8) + + def forward(self, x): + return self.used(x).sum(), self.unused(x).sum() + + model = ModelWithUnusedBranch().to(torch_device) + offload_ctx = OffloadActivations(use_pin_memory=False, use_streams=False, min_offload_size=1) + offload_ctx.update_model_params(model) + inp = torch.randn(4, 8, device=torch_device) + + tracker_counts = [] + for _ in range(3): + model.zero_grad(set_to_none=True) + with offload_ctx: + loss, _ = model(inp) + loss.backward() + tracker_counts.append(len(offload_ctx.tracker)) + + assert tracker_counts == [tracker_counts[0]] * len(tracker_counts) + @require_torch_accelerator def test_parameter_filtering(self): """Test that model parameters are filtered during offloading""" From 22b5210a4d397c474ed3c86ebdad92915caf79b2 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sat, 9 May 2026 12:13:10 +0200 Subject: [PATCH 3/3] style activation offloading docstrings --- trl/models/activation_offloading.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 8f0557d474d..b6cff43f7a3 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -573,21 +573,17 @@ def update_model_params(self, model: nn.Module): def __enter__(self): """Clear stale state and release BNB buffers before entering. - By the time __enter__ is called, the previous forward/backward has - already completed, so anything still in tracker, - storage_to_tensor_id, or the stashes is leaked and safe to drop. + By the time __enter__ is called, the previous forward/backward has already completed, so anything still in + tracker, storage_to_tensor_id, or the stashes is leaked and safe to drop. Two leak paths are handled: - 1. MoE + sample_packing + torch.compile: dynamic expert routing may - leave saved tensors on subgraphs whose backward nodes never execute, - so the unpack-then-delete logic never fires. tracker/stashes from - the previous step survive into the next. - 2. QLoRA BNB dequantization buffers: tracker retains references to - tensors sharing allocator blocks with BNB buffers, and the allocator - cache is never flushed between steps (~0.6 GiB/step, OOM after 30-40). - - Returns super().__enter__() to register pack/unpack hooks via - saved_tensors_hooks (PyTorch autograd engine). + 1. MoE + sample_packing + torch.compile: dynamic expert routing may leave saved tensors on subgraphs whose + backward nodes never execute, so the unpack-then-delete logic never fires. tracker/stashes from the previous + step survive into the next. + 2. QLoRA BNB dequantization buffers: tracker retains references to tensors sharing allocator blocks with BNB + buffers, and the allocator cache is never flushed between steps (~0.6 GiB/step, OOM after 30-40). + + Returns super().__enter__() to register pack/unpack hooks via saved_tensors_hooks (PyTorch autograd engine). """ self.tracker.clear() self.storage_to_tensor_id.clear() @@ -610,9 +606,8 @@ def __enter__(self): def __exit__(self, *args, **kwargs): """Sync streams and clear stashes before parent cleanup. - try/finally ensures the saved_tensors_hooks parent cleanup runs even if - stream sync raises — otherwise hooks stay permanently installed, creating - a silent memory leak. + try/finally ensures the saved_tensors_hooks parent cleanup runs even if stream sync raises — otherwise hooks + stay permanently installed, creating a silent memory leak. """ try: if self.use_streams: