diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8e91bacb26f6..59829f4f7fd7 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -577,6 +577,7 @@ "is_timm_available", "is_tokenizers_available", "is_torch_available", + "is_torch_neuroncore_available", "is_torch_tpu_available", "is_vision_available", "logging", @@ -3942,6 +3943,7 @@ is_timm_available, is_tokenizers_available, is_torch_available, + is_torch_neuroncore_available, is_torch_tpu_available, is_vision_available, logging, diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 31760557aa9c..149b31758485 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -83,6 +83,7 @@ is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_neuroncore_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, is_torch_tpu_available, @@ -500,6 +501,15 @@ def require_torch_tpu(test_case): return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case) +def require_torch_neuroncore(test_case): + """ + Decorator marking a test that requires NeuronCore (in PyTorch). + """ + return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")( + test_case + ) + + if is_torch_available(): # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode import torch diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b5c6025176bf..1a907107f865 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -46,6 +46,7 @@ is_torch_available, is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, + is_torch_neuroncore_available, is_torch_tf32_available, is_torch_tpu_available, logging, @@ -60,6 +61,17 @@ if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm +if is_torch_neuroncore_available(check_device=False): + # torchrun support + # https://github.com/pytorch/xla/pull/3609 + if os.environ.get("TORCHELASTIC_RUN_ID"): + import torch_xla.distributed.xla_backend as xbn + + if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): + torch.distributed.init_process_group(backend="xla") + if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla): + raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.") + if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 353fe45e8e41..6e98c5716647 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -153,6 +153,7 @@ is_torch_cuda_available, is_torch_fx_available, is_torch_fx_proxy, + is_torch_neuroncore_available, is_torch_onnx_dict_inputs_support_available, is_torch_tensorrt_fx_available, is_torch_tf32_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 80ffd38c10ec..d1457b67095d 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -451,6 +451,13 @@ def is_torch_tpu_available(check_device=True): return False +@lru_cache() +def is_torch_neuroncore_available(check_device=True): + if importlib.util.find_spec("torch_neuronx") is not None: + return is_torch_tpu_available(check_device) + return False + + def is_torchdynamo_available(): if not is_torch_available(): return False diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py index 6ed74efe510c..68d07e0f60f3 100644 --- a/tests/trainer/test_trainer_distributed.py +++ b/tests/trainer/test_trainer_distributed.py @@ -21,6 +21,7 @@ execute_subprocess_async, get_torch_dist_unique_port, require_torch_multi_gpu, + require_torch_neuroncore, ) from transformers.utils import logging @@ -62,6 +63,23 @@ def forward(self, input_ids, labels=None): return input_ids +class TestTrainerDistributedNeuronCore(TestCasePlus): + @require_torch_neuroncore + def test_trainer(self): + + distributed_args = f""" + -m torch.distributed.launch + --nproc_per_node=2 + --master_port={get_torch_dist_unique_port()} + {self.test_file_dir}/test_trainer_distributed.py + """.split() + output_dir = self.get_auto_remove_tmp_dir() + args = f"--output_dir {output_dir}".split() + cmd = [sys.executable] + distributed_args + args + execute_subprocess_async(cmd, env=self.get_env()) + # successful return here == success - any errors would have caused an error in the sub-call + + class TestTrainerDistributed(TestCasePlus): @require_torch_multi_gpu def test_trainer(self):