From 4323d786ac31cc2d007eb9b13174ed129ab4a082 Mon Sep 17 00:00:00 2001 From: Harish Subramony Date: Fri, 12 Jan 2024 11:50:15 -0800 Subject: [PATCH 1/2] remove hpu:X notation untill fully supported by bridge --- optimum/habana/accelerate/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index b8626d34f0..c8cfa66363 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -76,7 +76,7 @@ def __init__(self, cpu: bool = False, **kwargs): self.process_index = rank self.local_process_index = local_rank if self.device is None: - self.device = torch.device("hpu", self.local_process_index) + self.device = torch.device("hpu") else: self.distributed_type = GaudiDistributedType.NO self.num_processes = 1 From 17f244c40651682411176846ac30faac6d4a7c65 Mon Sep 17 00:00:00 2001 From: Harish Subramony <81822986+hsubramony@users.noreply.github.com> Date: Tue, 16 Jan 2024 09:46:09 -0800 Subject: [PATCH 2/2] Update optimum/habana/accelerate/state.py Co-authored-by: regisss <15324346+regisss@users.noreply.github.com> --- optimum/habana/accelerate/state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/habana/accelerate/state.py b/optimum/habana/accelerate/state.py index c8cfa66363..7ccbdbf593 100644 --- a/optimum/habana/accelerate/state.py +++ b/optimum/habana/accelerate/state.py @@ -76,6 +76,7 @@ def __init__(self, cpu: bool = False, **kwargs): self.process_index = rank self.local_process_index = local_rank if self.device is None: + # TODO: replace by `torch.device("hpu", self.local_process_index)` when hpu:x is supported self.device = torch.device("hpu") else: self.distributed_type = GaudiDistributedType.NO