diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index b6cff43f7a3..a70655a5828 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -538,6 +538,21 @@ def hook(outputs, inputs): unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream super().__init__(pack_tensor, unpack_tensor) + def __enter__(self): + # Drop stale state from any prior step where saved tensors didn't unpack + # (e.g. MoE expert paths under torch.compile). is_first_forward_call only + # resets when tracker empties during backward, so leaked entries pin GPU + # memory across iterations -> linear VRAM leak + self.tracker.clear() + self.storage_to_tensor_id.clear() + if self.use_streams: + self.fwd_stash.clear() + self.bwd_tensor_stash.clear() + self.bwd_ev_stash.clear() + self.is_first_forward_call = True + self.is_first_backward_call = True + return super().__enter__() + def update_model_params(self, model: nn.Module): """ Update the set of parameter storage pointers from the model. This allows filtering out model parameters during