2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -577,6 +577,7 @@
"is_timm_available",
"is_tokenizers_available",
"is_torch_available",
"is_torch_neuroncore_available",
"is_torch_tpu_available",
"is_vision_available",
"logging",
@@ -3942,6 +3943,7 @@
        is_timm_available,
        is_tokenizers_available,
        is_torch_available,
        is_torch_neuroncore_available,
        is_torch_tpu_available,
        is_vision_available,
        logging,
10 changes: 10 additions & 0 deletions src/transformers/testing_utils.py
@@ -83,6 +83,7 @@
    is_torch_available,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
    is_torch_neuroncore_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
@@ -500,6 +501,15 @@ def require_torch_tpu(test_case):
    return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)


def require_torch_neuroncore(test_case):
    """
    Decorator marking a test that requires NeuronCore (in PyTorch).
    """
    return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(
        test_case
    )


if is_torch_available():
    # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
    import torch
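For context, a minimal sketch of how the new `require_torch_neuroncore` decorator would typically be used in a test module is shown below. The test class, test name, and tensor shapes are illustrative and not part of this PR; the sketch assumes the Neuron SDK's torch-xla build is installed so that `xm.xla_device()` resolves to a NeuronCore.

```python
# Hypothetical usage of require_torch_neuroncore; names here are illustrative only.
import unittest

from transformers.testing_utils import require_torch_neuroncore


class NeuronCoreSmokeTest(unittest.TestCase):
    @require_torch_neuroncore
    def test_tensor_on_neuroncore(self):
        # Import lazily so test collection still works on machines where the test is skipped.
        import torch
        import torch_xla.core.xla_model as xm

        device = xm.xla_device()  # resolves to a NeuronCore device when the Neuron SDK is present
        x = torch.ones(2, 2, device=device)
        self.assertEqual(x.sum().item(), 4.0)
```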
12 changes: 12 additions & 0 deletions src/transformers/training_args.py
@@ -46,6 +46,7 @@
    is_torch_available,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
    is_torch_neuroncore_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    logging,
@@ -60,6 +61,17 @@
if is_torch_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

if is_torch_neuroncore_available(check_device=False):
    # torchrun support
    # https://github.com/pytorch/xla/pull/3609
    if os.environ.get("TORCHELASTIC_RUN_ID"):
        import torch_xla.distributed.xla_backend as xbn

        if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
            torch.distributed.init_process_group(backend="xla")
            if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
                raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")


if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
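The training_args.py hunk above wires up torchrun support: when the process was launched through torchrun/torchelastic (signalled by the `TORCHELASTIC_RUN_ID` environment variable), importing `torch_xla.distributed.xla_backend` registers the "xla" backend and the default process group is created before any Trainer code touches `torch.distributed`. A standalone sketch of that flow is below, assuming torch_xla and the Neuron SDK are installed; the helper name is illustrative, and it uses a simpler `dist.is_initialized()` guard where the diff checks the concrete `ProcessGroupXla` type.

```python
# Illustrative sketch of the torchrun/XLA initialization path; the helper name
# maybe_init_xla_process_group is not part of this PR.
import os

import torch.distributed as dist


def maybe_init_xla_process_group():
    """Create the default process group with the XLA backend when running under torchrun."""
    if not os.environ.get("TORCHELASTIC_RUN_ID"):
        return  # not launched via torchrun / torchelastic, nothing to do
    # Importing this module registers the "xla" backend with torch.distributed.
    import torch_xla.distributed.xla_backend  # noqa: F401

    if not dist.is_initialized():
        # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR, so no extra arguments are needed here.
        dist.init_process_group(backend="xla")
```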
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -153,6 +153,7 @@
    is_torch_cuda_available,
    is_torch_fx_available,
    is_torch_fx_proxy,
    is_torch_neuroncore_available,
    is_torch_onnx_dict_inputs_support_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
7 changes: 7 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -451,6 +451,13 @@ def is_torch_tpu_available(check_device=True):
    return False


@lru_cache()
def is_torch_neuroncore_available(check_device=True):
    if importlib.util.find_spec("torch_neuronx") is not None:
        return is_torch_tpu_available(check_device)
    return False


def is_torchdynamo_available():
    if not is_torch_available():
        return False
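`is_torch_neuroncore_available` simply checks that the `torch_neuronx` package can be imported and then defers to the existing torch_xla probe, since NeuronCores are exposed through the same XLA device API. A hedged example of how downstream code could use the new check follows; `pick_device` is a hypothetical helper, not part of transformers.

```python
# Hypothetical device-selection helper built on the new availability check; not part of this PR.
import torch

from transformers.utils import is_torch_neuroncore_available


def pick_device() -> torch.device:
    if is_torch_neuroncore_available():
        import torch_xla.core.xla_model as xm

        return xm.xla_device()  # NeuronCores are addressed as XLA devices
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
```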
18 changes: 18 additions & 0 deletions tests/trainer/test_trainer_distributed.py
@@ -21,6 +21,7 @@
    execute_subprocess_async,
    get_torch_dist_unique_port,
    require_torch_multi_gpu,
    require_torch_neuroncore,
)
from transformers.utils import logging

@@ -62,6 +63,23 @@ def forward(self, input_ids, labels=None):
            return input_ids


class TestTrainerDistributedNeuronCore(TestCasePlus):
    @require_torch_neuroncore
    def test_trainer(self):

        distributed_args = f"""
            -m torch.distributed.launch
            --nproc_per_node=2
            --master_port={get_torch_dist_unique_port()}
            {self.test_file_dir}/test_trainer_distributed.py
        """.split()
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"--output_dir {output_dir}".split()
        cmd = [sys.executable] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())
        # successful return here == success - any errors would have caused an error in the sub-call


class TestTrainerDistributed(TestCasePlus):
    @require_torch_multi_gpu
    def test_trainer(self):
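The new test mirrors the existing multi-GPU test: it spawns two worker processes with `torch.distributed.launch` and points them back at this test file, which is then executed as the distributed worker script. A rough, hypothetical equivalent of composing that command outside the test harness is sketched below; the master port and output directory are placeholders, not values from the PR.

```python
# Hypothetical manual launch mirroring the command the test assembles; the
# port and output directory below are placeholders.
import subprocess
import sys

cmd = [
    sys.executable,
    "-m",
    "torch.distributed.launch",
    "--nproc_per_node=2",  # one process per NeuronCore exercised by the test
    "--master_port=29510",  # placeholder port
    "tests/trainer/test_trainer_distributed.py",
    "--output_dir",
    "/tmp/trainer_neuroncore_test",  # placeholder output directory
]
subprocess.run(cmd, check=True)
```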