From 72db456d73551c3483b7f9d091f739aeb4be2dac Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 9 May 2026 08:20:15 +0530 Subject: [PATCH 1/2] cleanup vram --- trl/models/activation_offloading.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index b319d2ee8c3..732b4d02d19 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -536,6 +536,22 @@ def hook(outputs, inputs): unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream super().__init__(pack_tensor, unpack_tensor) + def __enter__(self): + # Drop stale state from any prior step where saved tensors didn't unpack + # (e.g. MoE expert paths under torch.compile). is_first_forward_call only + # resets when tracker empties during backward, so leaked entries pin GPU + # memory across iterations -> linear VRAM leak (axolotl #3638). Backward + # has completed by the time we re-enter, so anything still here is dead. + self.tracker.clear() + self.storage_to_tensor_id.clear() + if self.use_streams: + self.fwd_stash.clear() + self.bwd_tensor_stash.clear() + self.bwd_ev_stash.clear() + self.is_first_forward_call = True + self.is_first_backward_call = True + return super().__enter__() + def update_model_params(self, model: nn.Module): """ Update the set of parameter storage pointers from the model. This allows filtering out model parameters during From 6c77a2b4b72ca888d1817a5803ce6690401df265 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 9 May 2026 08:47:53 +0530 Subject: [PATCH 2/2] lint --- trl/models/activation_offloading.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index 732b4d02d19..4756cc4fb5d 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -540,8 +540,7 @@ def __enter__(self): # Drop stale state from any prior step where saved tensors didn't unpack # (e.g. MoE expert paths under torch.compile). is_first_forward_call only # resets when tracker empties during backward, so leaked entries pin GPU - # memory across iterations -> linear VRAM leak (axolotl #3638). Backward - # has completed by the time we re-enter, so anything still here is dead. + # memory across iterations -> linear VRAM leak self.tracker.clear() self.storage_to_tensor_id.clear() if self.use_streams: