Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,28 @@ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor,
return data.to(**kwargs)
return data

# TODO: investigate why the accelerator's autocast wrapper is not enough to trigger autocast in some edge cases
# see PR:
def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
"""
A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
arguments, depending on the situation.
Modified by Habana to enable using `autocast` on Gaudi devices.
"""
if self.use_cpu_amp:
ctx_manager = torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled)
elif self.use_hpu_amp:
ctx_manager = torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)
else:
ctx_manager = contextlib.nullcontext()

# Merge autocast context and `fp8_autocast` context if FP8 is enabled.
# Currently FP8 is enabled only for training.
if self.accelerator.fp8_enabled and self.model.training:
ctx_manager = FP8ContextWrapper(ctx_manager, fp8_recipe=self.accelerator.fp8_recipe)

return ctx_manager

def training_step(
self,
model: torch.nn.Module,
Expand Down
Loading