56 changes: 56 additions & 0 deletions test/allreduce_torchrun.py
@@ -0,0 +1,56 @@
import argparse
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch.distributed as dist
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.xrt_init import init_xrt_context
import torch_xla.distributed.xla_backend


def _mp_fn_xrt_init():
  rank = int(os.environ['RANK'])
  size = int(os.environ['WORLD_SIZE'])

  init_xrt_context()

  device = xm.xla_device()
  ones = torch.ones((2, 3))
  xones = ones.to(device)
  result = xm.all_reduce('sum', xones)

  result_cpu = result.cpu()
  expected = torch.ones((2, 3)) * size
  assert torch.all(result_cpu == expected), f'{result_cpu} != {expected}'


def _mp_fn_xla_backend():
  rank = int(os.environ['RANK'])
  size = int(os.environ['WORLD_SIZE'])

  dist.init_process_group('xla')
  device = xm.xla_device()

  ones = torch.ones((2, 3))
  xones = ones.to(device)
  dist.all_reduce(xones, op=torch.distributed.ReduceOp.SUM)

  result_cpu = xones.cpu()
  expected = torch.ones((2, 3)) * size
  assert torch.all(result_cpu == expected), f'{result_cpu} != {expected}'


if __name__ == '__main__':
  print(
      'master_port:{}, master_addr:{}, rank:{}, local_rank:{}, size:{}'.format(
          os.environ['MASTER_PORT'], os.environ['MASTER_ADDR'],
          os.environ['RANK'], os.environ['LOCAL_RANK'],
          os.environ['WORLD_SIZE']))
  parser = argparse.ArgumentParser()
  parser.add_argument('--use_xla_backend', action="store_true")
  args = parser.parse_args()
  if args.use_xla_backend:
    _mp_fn_xla_backend()
  else:
    _mp_fn_xrt_init()
10 changes: 10 additions & 0 deletions test/run_tests.sh
@@ -91,6 +91,15 @@ function run_async_scalar {
  XLA_TRANSFER_SCALAR_ASYNC=1 run_test "$@"
}

function run_torchrun {
  echo "Running tests spawned by torchrun"
  if [ -x "$(command -v nvidia-smi)" ]; then
    run_test "$@"
  else
    echo "these tests need at least two XLA workers to validate"
  fi
}

function run_op_tests {
  run_dynamic python3 "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
  run_test python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
@@ -144,6 +153,7 @@ function run_mp_op_tests {
  run_test python3 "$CDIR/test_mp_save.py"
  run_test python3 "$CDIR/test_mp_mesh_reduce.py"
  run_test python3 "$CDIR/test_mp_sync_batch_norm.py"
  run_torchrun python3 "$CDIR/test_allreduce_torchrun.py"
  run_xla_backend_mp python3 "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
  run_xla_backend_mp python3 "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
  run_xla_backend_mp python3 "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py"
26 changes: 26 additions & 0 deletions test/test_allreduce_torchrun.py
@@ -0,0 +1,26 @@
import os
import subprocess
import pathlib


def test_local_torchrun_xrt_init():
  # Launches an allreduce via the torchrun launcher, using the native
  # xla_model collective communication (CC) ops.
  ci_dir = pathlib.Path(__file__).parent.resolve()
  cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py'
  proc = subprocess.Popen(cmd, shell=True)
  return_code = proc.wait()
  assert return_code == 0


def test_local_torchrun_xla_backend():
  # Launches an allreduce via the torchrun launcher, using the xla
  # torch.distributed backend.
  ci_dir = pathlib.Path(__file__).parent.resolve()
  cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py --use_xla_backend'
  proc = subprocess.Popen(cmd, shell=True)
  return_code = proc.wait()
  assert return_code == 0


if __name__ == '__main__':
  test_local_torchrun_xrt_init()
  test_local_torchrun_xla_backend()
45 changes: 45 additions & 0 deletions torch_xla/distributed/_xrt_run_server.py
@@ -0,0 +1,45 @@
"""
This script is for starting the xrt_server. It also polls the PID and
checks if it exist. It would kill the server, when the process whose
PID it was tracking dies.
NOTE: This script should be used only by xrt_init.py and not anyone else.
"""
import os
import argparse
import psutil
import time
import signal
import multiprocessing
import torch_xla


def _polling(pid_to_track):

  def is_pid_alive(pid):
    # If the process does not exist, querying its status raises an error.
    # If it does exist, we also check that it has not gone into a zombie
    # state, which can happen when torchrun is launched from
    # neuron_parallel_compile.
    try:
      return psutil.Process(pid).status() != psutil.STATUS_ZOMBIE
    except psutil.Error:
      return False

  while is_pid_alive(pid_to_track):
    time.sleep(10)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument("--port", required=True)
  parser.add_argument("--pid_to_track", default=None)
  args = parser.parse_args()
  polling_process = multiprocessing.Process(
      target=_polling, args=(int(args.pid_to_track),))
  server_process = multiprocessing.Process(
      target=torch_xla._XLAC._run_xrt_local_service, args=(int(args.port),))
  polling_process.start()
  server_process.start()
  polling_process.join()
  os.kill(server_process.pid, signal.SIGKILL)
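
For context, the sketch below illustrates one way a launcher such as xrt_init.py (named in the note above as the only intended caller) might start this server module and tie its lifetime to the launching process. The module path and the --port/--pid_to_track flags come from the file above; the helper name, the port value, and the python -m invocation are illustrative assumptions, not part of this patch.

# Hypothetical launcher-side sketch; not part of this patch.
import os
import subprocess
import sys


def _start_xrt_server(port):
  # Run torch_xla.distributed._xrt_run_server in its own process and hand it
  # our PID, so the polling loop above kills the server once we exit.
  return subprocess.Popen([
      sys.executable, '-m', 'torch_xla.distributed._xrt_run_server', '--port',
      str(port), '--pid_to_track',
      str(os.getpid())
  ])


if __name__ == '__main__':
  server_proc = _start_xrt_server(51011)  # the port number here is arbitrary
  # ... training would run here; no explicit shutdown of server_proc is needed.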
8 changes: 8 additions & 0 deletions torch_xla/distributed/xla_backend.py
@@ -9,6 +9,7 @@
    ProcessGroup,
    Work,
)
from .xrt_init import init_xrt_context


def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -41,6 +42,13 @@ def __init__(self, prefix_store, rank, size, timeout):
    self.prefix_store = prefix_store  # reserved for future use.
    self.timeout = timeout
    self._mesh = []
    # Initialize the XRT Neuron environment. Pass in the store created by
    # torch.distributed to avoid creating two TCP stores. We only want to do
    # this when the user launches with torchrun, not with xmp.spawn() or some
    # other flow.
    if os.getenv('TORCHELASTIC_RUN_ID') is not None:
      init_xrt_context(store=prefix_store)

  def getBackendName(self):
    return 'xla'
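
To make the guard above concrete, here is a minimal usage sketch of the two launch flows it distinguishes. It mirrors test/allreduce_torchrun.py from this patch and assumes only what torchrun already provides (it exports TORCHELASTIC_RUN_ID, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for every worker); it is illustrative, not a new API.

# Illustrative sketch; mirrors test/allreduce_torchrun.py above.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend

# Flow 1: launched by torchrun. TORCHELASTIC_RUN_ID is set, so
# ProcessGroupXla.__init__ calls init_xrt_context(store=prefix_store) itself;
# the user only needs the usual torch.distributed setup.
dist.init_process_group('xla')
device = xm.xla_device()
t = torch.ones((2, 3)).to(device)
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# Flow 2: launched by xmp.spawn(). TORCHELASTIC_RUN_ID is unset, the guard is a
# no-op, and the pre-existing xmp initialization path is used unchanged.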