diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py index eef379a9324..b319d2ee8c3 100644 --- a/trl/models/activation_offloading.py +++ b/trl/models/activation_offloading.py @@ -568,6 +568,23 @@ def update_model_params(self, model: nn.Module): self.param_storages = param_storages + def __exit__(self, *args, **kwargs): + """Sync streams and clear stashes before parent cleanup. + + try/finally ensures the saved_tensors_hooks parent cleanup runs even if stream sync raises — otherwise hooks + stay permanently installed, creating a silent memory leak. + """ + try: + if self.use_streams: + self.s0.synchronize() + self.s1.synchronize() + self.bwd_tensor_stash.clear() + self.bwd_ev_stash.clear() + self.fwd_stash.clear() + finally: + result = super().__exit__(*args, **kwargs) + return result + class NoOpManager(saved_tensors_hooks): """