diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index b8626d34f0..7ccbdbf593 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -76,7 +76,8 @@ def __init__(self, cpu: bool = False, **kwargs): self.process_index = rank self.local_process_index = local_rank if self.device is None: - self.device = torch.device("hpu", self.local_process_index) + # TODO: replace by `torch.device("hpu", self.local_process_index)` when hpu:x is supported + self.device = torch.device("hpu") else: self.distributed_type = GaudiDistributedType.NO self.num_processes = 1