From 5a10fa336057f3b95785bb845fcc661ebf693a84 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Sep 2025 11:48:58 +0000 Subject: [PATCH 1/2] revert --- optimum/habana/transformers/trainer.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index f1c6774dd7..1157f3a05f 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1677,6 +1677,29 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, return data.to(**kwargs) return data + # TODO: investigate why the accelerator's autocast wrapper is not enough to trigger autocast in some edge cases + # see PR: + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + Modified by Habana to enable using `autocast` on Gaudi devices. + """ + if self.use_cpu_amp: + ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled) + elif self.use_hpu_amp: + ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True) + else: + ctx_manager = contextlib.nullcontext() + + # Merge autocast context and `fp8_autocast` context if FP8 is enabled. + # Currently FP8 is enabled only for training. + if self.accelerator.fp8_enabled and self.model.training: + ctx_manager = FP8ContextWrapper(ctx_manager, fp8_recipe=self.accelerator.fp8_recipe) + + return ctx_manager + + def training_step( self, model: torch.nn.Module, From a99eef4adc4c9336b96c8e467c7cca9fb1f8112b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Wed, 17 Sep 2025 11:50:42 +0000 Subject: [PATCH 2/2] style --- optimum/habana/transformers/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py index 1157f3a05f..4dd77b06fd 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -1699,7 +1699,6 @@ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): return ctx_manager - def training_step( self, model: torch.nn.Module,