diff --git a/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py b/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py index 67e78a7d37c..2bbb324c8cd 100644 --- a/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py +++ b/src/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py @@ -14,7 +14,7 @@ import torch.distributed -from accelerate.test_utils import require_huggingface_suite +from accelerate.test_utils import require_huggingface_suite, torch_device from accelerate.utils import is_transformers_available @@ -27,7 +27,8 @@ @require_huggingface_suite def init_torch_dist_then_launch_deepspeed(): - torch.distributed.init_process_group(backend="nccl") + backend = "ccl" if torch_device == "xpu" else "nccl" + torch.distributed.init_process_group(backend=backend) deepspeed_config = { "zero_optimization": { "stage": 3,