diff --git a/test/allreduce_torchrun.py b/test/allreduce_torchrun.py
new file mode 100644
index 000000000000..173999e61b34
--- /dev/null
+++ b/test/allreduce_torchrun.py
@@ -0,0 +1,56 @@
+import argparse
+import os
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+import torch.distributed as dist
+import torch_xla.distributed.xla_multiprocessing as xmp
+from torch_xla.distributed.xrt_init import init_xrt_context
+import torch_xla.distributed.xla_backend
+
+
+def _mp_fn_xrt_init():
+  rank = int(os.environ['RANK'])
+  size = int(os.environ['WORLD_SIZE'])
+
+  init_xrt_context()
+
+  device = xm.xla_device()
+  ones = torch.ones((2, 3))
+  xones = ones.to(device)
+  result = xm.all_reduce('sum', xones)
+
+  result_cpu = result.cpu()
+  expected = torch.ones((2, 3)) * size
+  assert torch.all(result_cpu == expected), f'{result_cpu} != {expected}'
+
+
+def _mp_fn_xla_backend():
+  rank = int(os.environ['RANK'])
+  size = int(os.environ['WORLD_SIZE'])
+
+  dist.init_process_group('xla')
+  device = xm.xla_device()
+
+  ones = torch.ones((2, 3))
+  xones = ones.to(device)
+  dist.all_reduce(xones, op=torch.distributed.ReduceOp.SUM)
+
+  result_cpu = xones.cpu()
+  expected = torch.ones((2, 3)) * size
+  assert torch.all(result_cpu == expected), f'{result_cpu} != {expected}'
+
+
+if __name__ == '__main__':
+  print(
+      'master_port:{}, master_addr:{}, rank:{}, local_rank:{}, size:{}'.format(
+          os.environ['MASTER_PORT'], os.environ['MASTER_ADDR'],
+          os.environ['RANK'], os.environ['LOCAL_RANK'],
+          os.environ['WORLD_SIZE']))
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--use_xla_backend', action="store_true")
+  args = parser.parse_args()
+  if args.use_xla_backend:
+    _mp_fn_xla_backend()
+  else:
+    _mp_fn_xrt_init()
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 0338bd0401ae..8ab19d244f7b 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -91,6 +91,15 @@ function run_async_scalar {
   XLA_TRANSFER_SCALAR_ASYNC=1 run_test "$@"
 }
 
+function run_torchrun {
+  echo "Running tests spawned by torchrun"
+  if [ -x "$(command -v nvidia-smi)" ]; then
+    run_test "$@"
+  else
+    echo "these tests need at least two XLA workers to validate"
+  fi
+}
+
 function run_op_tests {
   run_dynamic python3 "$CDIR/../../test/test_view_ops.py" "$@" -v TestViewOpsXLA
   run_test python3 "$CDIR/../../test/test_torch.py" "$@" -v TestTorchDeviceTypeXLA
@@ -144,6 +153,7 @@ function run_mp_op_tests {
   run_test python3 "$CDIR/test_mp_save.py"
   run_test python3 "$CDIR/test_mp_mesh_reduce.py"
   run_test python3 "$CDIR/test_mp_sync_batch_norm.py"
+  run_torchrun python3 "$CDIR/test_allreduce_torchrun.py"
   run_xla_backend_mp python3 "$CDIR/test_torch_distributed_all_gather_xla_backend.py"
   run_xla_backend_mp python3 "$CDIR/test_torch_distributed_all_reduce_xla_backend.py"
   run_xla_backend_mp python3 "$CDIR/test_torch_distributed_multi_all_reduce_xla_backend.py"
diff --git a/test/test_allreduce_torchrun.py b/test/test_allreduce_torchrun.py
new file mode 100644
index 000000000000..c81bb3abf445
--- /dev/null
+++ b/test/test_allreduce_torchrun.py
@@ -0,0 +1,26 @@
+import os
+import subprocess
+import pathlib
+
+
+def test_local_torchrun_xrt_init():
+  # Launches an all-reduce via torchrun using the native xla_model collectives.
+  ci_dir = pathlib.Path(__file__).parent.resolve()
+  cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py'
+  proc = subprocess.Popen(cmd, shell=True)
+  return_code = proc.wait()
+  assert return_code == 0
+
+
+def test_local_torchrun_xla_backend():
+  # Launches an all-reduce via torchrun using the xla torch.distributed backend.
+  ci_dir = pathlib.Path(__file__).parent.resolve()
+  cmd = f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 {ci_dir}/allreduce_torchrun.py --use_xla_backend'
+  proc = subprocess.Popen(cmd, shell=True)
+  return_code = proc.wait()
+  assert return_code == 0
+
+
+if __name__ == '__main__':
+  test_local_torchrun_xrt_init()
+  test_local_torchrun_xla_backend()
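The two wrapper tests above differ only in the `--use_xla_backend` flag and are run directly as a script by run_tests.sh. If the file were ever collected by pytest instead, the duplication could be collapsed with parametrization; a sketch under that assumption (pytest is not required by this change):

```python
# Hypothetical pytest-parametrized variant of test/test_allreduce_torchrun.py.
# Only a sketch: run_tests.sh executes the file directly, so the explicit
# test functions above are what this change actually relies on.
import pathlib
import subprocess

import pytest


@pytest.mark.parametrize('extra_args', ['', '--use_xla_backend'])
def test_local_torchrun(extra_args):
  ci_dir = pathlib.Path(__file__).parent.resolve()
  cmd = (f'torchrun --nproc_per_node=2 --master_addr=127.0.0.1 '
         f'--master_port=2020 {ci_dir}/allreduce_torchrun.py {extra_args}')
  # check=True raises CalledProcessError on a non-zero exit code.
  subprocess.run(cmd, shell=True, check=True)
```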
diff --git a/torch_xla/distributed/_xrt_run_server.py b/torch_xla/distributed/_xrt_run_server.py
new file mode 100644
index 000000000000..e8ed8feb339e
--- /dev/null
+++ b/torch_xla/distributed/_xrt_run_server.py
@@ -0,0 +1,45 @@
+"""
+This script starts the xrt_server. It also polls a given PID and checks
+whether that process still exists; when the process it is tracking dies,
+it kills the server as well.
+NOTE: This script should only be used by xrt_init.py and not by anyone else.
+"""
+import os
+import argparse
+import psutil
+import time
+import signal
+import multiprocessing
+import torch_xla
+
+
+def _polling(pid_to_track):
+
+  def is_pid_alive(pid):
+    # The idea behind this is: if the process doesn't exist, asking
+    # for its status should raise an error. If the process does
+    # exist, we also check that it hasn't gone into the zombie
+    # state. This can happen when torchrun is launched from
+    # neuron_parallel_compile.
+    try:
+      return psutil.Process(pid).status() != psutil.STATUS_ZOMBIE
+    except psutil.Error:
+      return False
+
+  while is_pid_alive(pid_to_track):
+    time.sleep(10)
+
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument("--port", required=True)
+  parser.add_argument("--pid_to_track", required=True)
+  args = parser.parse_args()
+  polling_process = multiprocessing.Process(
+      target=_polling, args=(int(args.pid_to_track),))
+  server_process = multiprocessing.Process(
+      target=torch_xla._XLAC._run_xrt_local_service, args=(int(args.port),))
+  polling_process.start()
+  server_process.start()
+  polling_process.join()
+  os.kill(server_process.pid, signal.SIGKILL)
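For context, `init_xrt_context()` in xrt_init.py below starts this module once per node from the local rank-0 process and hands it the parent PID to watch. A condensed sketch of that call (the port value here is just an example; the real one comes from `get_free_tcp_ports()`):

```python
# Condensed from init_xrt_context() below: start the XRT server helper in its
# own session and let _xrt_run_server kill it once the launching process exits.
import os
import subprocess

tpu_config_port = 51011  # example value only
subprocess.Popen(
    [
        'python3', '-m', 'torch_xla.distributed._xrt_run_server', '--port',
        str(tpu_config_port), '--pid_to_track',
        str(os.getppid())
    ],
    env=os.environ.copy(),
    start_new_session=True)
```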
diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py
index f580c9c78712..b89c92ed295e 100644
--- a/torch_xla/distributed/xla_backend.py
+++ b/torch_xla/distributed/xla_backend.py
@@ -9,6 +9,7 @@
     ProcessGroup,
     Work,
 )
+from .xrt_init import init_xrt_context
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -41,6 +42,13 @@ def __init__(self, prefix_store, rank, size, timeout):
     self.prefix_store = prefix_store  # reserved for future use.
     self.timeout = timeout
     self._mesh = []
+    # Initialize the XRT Neuron environment. Pass in the store created by
+    # torch.distributed so that we do not end up creating two TCP stores.
+    # We only want to call this when the user launches the job with
+    # torchrun (signalled by TORCHELASTIC_RUN_ID being set), and not when
+    # using xmp.spawn() or some other flow.
+    if os.getenv('TORCHELASTIC_RUN_ID') is not None:
+      init_xrt_context(store=prefix_store)
 
   def getBackendName(self):
     return 'xla'
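With this hook in place, a script launched by torchrun only needs the usual process-group setup before touching the XLA device. A minimal sketch that mirrors test/allreduce_torchrun.py (no APIs beyond the ones this change already uses):

```python
# Minimal torchrun-launched worker using the xla backend. init_process_group()
# triggers init_xrt_context() above because torchrun sets TORCHELASTIC_RUN_ID.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the 'xla' backend

dist.init_process_group('xla')
device = xm.xla_device()
t = torch.ones((2, 3)).to(device)
dist.all_reduce(t, op=dist.ReduceOp.SUM)
print(t.cpu())  # every element equals WORLD_SIZE after the reduction
```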
diff --git a/torch_xla/distributed/xrt_init.py b/torch_xla/distributed/xrt_init.py
new file mode 100644
index 000000000000..9bd1324e71fa
--- /dev/null
+++ b/torch_xla/distributed/xrt_init.py
@@ -0,0 +1,250 @@
+import os
+import re
+import socket
+import subprocess
+import torch.distributed as dist
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_env_vars as xenv
+from torch_xla.utils.utils import get_free_tcp_ports
+from torch_xla.distributed.xla_multiprocessing import _get_devices_per_worker
+
+XRT_SERVER_REGEX = 'torch_xla.distributed._xrt_run_server'
+_TCP_STORE = None
+_INIT_XRT_ALREADY_CALLED = False
+
+
+def _create_devices(dev_kind, world_size):
+  # Create global XLA devices. Adapted from xmp.spawn() to work across nodes.
+  devices = []
+  dev_type = 'GPU'
+
+  for gindex in range(0, world_size):
+    tfdevice = f'{dev_type}:{gindex};/job:localservice/replica:0/task:{gindex}/device:XLA_{dev_type}:0'
+    devices.append(tfdevice)
+  os.environ[xenv.DEVICE_MAP] = '|'.join(devices)
+
+
+def _setup_workers(world_size, rank, local_world_size, local_rank):
+  # Set up workers across nodes. xmp.spawn() does this locally by finding free ports on the node.
+  # We do this globally by doing an allgather of the locally obtained free socket addresses.
+  # Note that this follows the original scheme; in the new scheme only one address per node needs to be exchanged.
+  host = socket.gethostname()
+  if local_rank == 0:
+    ports = [str(i) for i in get_free_tcp_ports(local_world_size)]
+    _TCP_STORE.set(host, ' '.join(ports))
+  else:
+    ports_str = _TCP_STORE.get(host).decode('UTF-8')
+    ports = list(ports_str.split(' '))
+
+  my_worker = '{}:{};grpc://{}:{}'.format('localservice', rank, host,
+                                          ports[local_rank])
+  all_workers = []
+  for i in range(0, world_size):
+    if rank == i:
+      _TCP_STORE.set(f'worker:{i}', my_worker)
+      all_workers.append(my_worker)
+    else:
+      worker = _TCP_STORE.get(f'worker:{i}').decode('UTF-8')
+      all_workers.append(worker)
+  os.environ['XRT_WORKERS'] = '|'.join(all_workers)
+
+
+def _get_address_from_store(key, rank):
+  if rank == 0:
+    port = get_free_tcp_ports()[0]
+    host = socket.getfqdn()
+    service_addr = '{}:{}'.format(host, port)
+    _TCP_STORE.set(key, service_addr)
+  else:
+    service_addr = _TCP_STORE.get(key).decode('UTF-8')
+
+  return service_addr
+
+
+def _set_mesh_config(rank):
+  address = _get_address_from_store('xrt_mesh_config', rank)
+  if not os.environ.get(xenv.SERVICE_ADDRESS, None):
+    os.environ[xenv.SERVICE_ADDRESS] = address
+  if not os.environ.get("TPU_MESH_CONTROLLER_ADDRESS", None):
+    address = _get_address_from_store('tpu_mesh_config', rank)
+    _, port = address.split(":")
+    os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = address
+    os.environ["TPU_MESH_CONTROLLER_PORT"] = port
+
+
+def _set_tpu_xrt_envs(local_rank, rank, group_rank, local_world_size,
+                      world_size):
+  total_nodes = world_size // local_world_size
+
+  xrt_tpu_config = []
+  tpu_config_port = None
+  for i in range(total_nodes):
+    key = f'worker_{i}_address'
+    if group_rank == i and local_rank == 0:
+      tpu_config_port = get_free_tcp_ports()[0]
+      host = socket.getfqdn()
+      address = '{}:{}'.format(host, tpu_config_port)
+      _TCP_STORE.set(key, address)
+    else:
+      address = _TCP_STORE.get(key).decode('UTF-8')
+    if total_nodes == 1:
+      xrt_tpu_config.append(f'localservice;{i};{address}')
+    else:
+      xrt_tpu_config.append(f'c_localservice;{i};{address}')
+
+  if rank == 0:
+    os.environ[xenv.TPU_CONFIG] = '|'.join(xrt_tpu_config)
+  os.environ[xenv.TPU_NUM_DEVICES] = str(local_world_size)
+
+  os.environ[
+      xenv.
+      LOCAL_WORKER] = f'localservice:{group_rank}' if total_nodes == 1 else f'c_localservice:{group_rank}'
+  os.environ[xenv.WORLD_SIZE] = str(world_size)
+  os.environ[xenv.HOST_WORLD_SIZE] = str(total_nodes)
+  os.environ[xenv.ORDINAL] = str(rank)
+  os.environ[xenv.LOCAL_ORDINAL] = str(local_rank)
+  os.environ[xenv.MP_DEVICE] = f'TPU:{rank}'
+  if not os.environ.get('TF_GRPC_DEFAULT_OPTIONS', None):
+    os.environ['TF_GRPC_DEFAULT_OPTIONS'] = (
+        'grpc.keepalive_time_ms=60000,grpc.keepalive_timeout_ms=14400000,'
+        'grpc.http2.max_pings_without_data=0,grpc.http2.min_ping_interval_without_data_ms=300000'
+    )
+  # We don't want torch_xla to start the local server internally;
+  # we are starting the XRT server ourselves.
+  os.environ['XRT_START_LOCAL_SERVER'] = '0'
+
+  return tpu_config_port
+
+
+def _set_neuron_envs(rank, world_size, local_world_size):
+  os.environ["NEURON_USE_LOAD_COLLECTIVES"] = '1'
+  os.environ['NEURON_GLOBAL_DEVICE_ID'] = str(rank)
+  os.environ['NEURON_GLOBAL_DEVICE_COUNT'] = str(world_size)
+  if not os.environ.get('NEURON_RT_VISIBLE_CORES', None):
+    os.environ['NEURON_RT_VISIBLE_CORES'] = ','.join(
+        [str(i) for i in range(local_world_size)])
+
+
+def _setup_nccl_service(dev_kind, rank):
+  # Set up the NCCL COMM ID required to create NCCL communicators
+  address = _get_address_from_store('nccl_info', rank)
+  if dev_kind == 'NEURON':
+    os.environ['NEURON_RT_ROOT_COMM_ID'] = address
+  elif dev_kind == 'GPU':
+    os.environ['NEURON_RT_ROOT_COMM_ID'] = address
+    os.environ['XRT_MESH_SERVICE_ADDRESS'] = address
+  else:
+    raise RuntimeError('NCCL service setup failed!')
+
+
+def set_xrt_envs(world_size, rank, local_rank):
+  # Set up all the XRT-specific env variables, adapted from xmp.spawn()
+  os.environ[xenv.WORLD_SIZE] = str(world_size)
+  os.environ[xenv.ORDINAL] = str(rank)
+  os.environ[xenv.LOCAL_ORDINAL] = str(local_rank)
+  os.environ[xenv.LOCAL_WORKER] = 'localservice:' + str(rank)
+
+  os.environ[xenv.MP_DEVICE] = f'GPU:{rank}'
+  gpus_to_use = os.environ.get('CUDA_VISIBLE_DEVICES')
+  if gpus_to_use is not None:
+    # If GPU devices are set by a scheduling entity (e.g. SLURM), we index into the
+    # comma-separated string containing the numbered GPU devices
+    gpus_to_use_list = gpus_to_use.split(',')
+    os.environ['CUDA_VISIBLE_DEVICES'] = gpus_to_use_list[local_rank]
+  else:
+    # If no explicit visible devices are provided, local_rank is used to identify
+    # the GPU used by this process
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(local_rank)
+
+
+def init_xrt_context(master_addr=None, master_port=None, store=None):
+  """Initializes the XLA device for the detected device kind, or is a no-op
+  if init_xrt_context has already been called.
+
+  Args:
+    master_addr (string): Used to set up the TCPStore. If None, it is read
+      from the MASTER_ADDR environment variable. Not used if store is passed.
+
+    master_port (int): Used to set up the TCPStore. If None, it is read
+      from the MASTER_PORT environment variable. Not used if store is passed.
+
+    store (TCPStore): A TCPStore object to use instead of creating a new one.
+      If None, a TCPStore object will be set up for you.
+      Default: None
+  """
+  global _INIT_XRT_ALREADY_CALLED
+
+  if _INIT_XRT_ALREADY_CALLED:
+    return
+
+  # These environment variables are set by torchrun for each worker.
+  rank = int(os.environ['RANK'])
+  local_rank = int(os.environ['LOCAL_RANK'])
+  world_size = int(os.environ['WORLD_SIZE'])
+  group_rank = int(os.environ['GROUP_RANK'])
+  local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])
+
+  if master_addr is None:
+    master_addr = os.environ['MASTER_ADDR']
+
+  if master_port is None:
+    master_port = os.environ['MASTER_PORT']
+
+  dev_list = os.listdir('/dev/')
+  # Check the device kind; a similar filter is needed for TPU
+  neuron_devs = list(filter(lambda v: re.match('neuron', v), dev_list))
+  if neuron_devs:
+    dev_kind = 'NEURON'
+  else:
+    dev_kind = 'GPU'
+
+  os.environ.pop(xenv.TPU_CONFIG, None)
+  os.environ.pop(xenv.TPU_NUM_DEVICES, None)
+  os.environ.pop(xenv.GPU_NUM_DEVICES, None)
+
+  # This is required if we want to dynamically grab free ports.
+  # Useful in shared settings where we cannot predetermine which ports are taken.
+  is_server = rank == 0
+  global _TCP_STORE
+  if store is None:
+    assert master_addr is not None
+    assert master_port is not None
+    _TCP_STORE = dist.TCPStore(master_addr, int(master_port), world_size,
+                               is_server)
+  else:
+    _TCP_STORE = store
+
+  node_list = None
+
+  if dev_kind == 'NEURON':  # a similar check will be needed for TPU
+    tpu_config_port = _set_tpu_xrt_envs(local_rank, rank, group_rank,
+                                        local_world_size, world_size)
+  elif dev_kind == 'GPU':
+    _setup_nccl_service(dev_kind, rank)
+    set_xrt_envs(world_size, rank, local_rank)
+    _create_devices(dev_kind, world_size)
+    _setup_workers(world_size, rank, local_world_size, local_rank)
+
+  _set_mesh_config(rank)
+
+  if dev_kind == 'NEURON':  # a similar check will be needed for TPU
+    _setup_nccl_service(dev_kind, rank)
+    _set_neuron_envs(rank, world_size, local_world_size)
+
+    total_nodes = world_size // local_world_size
+    if local_rank == 0:
+      local_env = os.environ.copy()
+      subprocess.Popen([
+          'python3', '-m', XRT_SERVER_REGEX, '--port',
+          str(tpu_config_port), '--pid_to_track',
+          str(os.getppid())
+      ],
+                       env=local_env,
+                       start_new_session=True)
+
+  dev = xm.xla_device()
+  xm.set_replication(dev, [dev])
+
+  # If we get to this point, the function completed successfully,
+  # so we can switch the flag to True.
+  _INIT_XRT_ALREADY_CALLED = True
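Most of the coordination above reduces to one TCPStore pattern: rank 0 publishes a value under a well-known key and every other rank blocks in `get()` until it appears (see `_get_address_from_store` and `_setup_workers`). A standalone sketch of that pattern, with illustrative host/port values:

```python
# Stand-alone illustration of the rank-0-publishes / everyone-else-waits
# pattern used by _get_address_from_store() and _setup_workers() above.
import torch.distributed as dist


def exchange(store, key, rank, value_from_rank0):
  if rank == 0:
    store.set(key, value_from_rank0)  # publish once
    return value_from_rank0
  # get() blocks until rank 0 has set the key, then returns bytes.
  return store.get(key).decode('UTF-8')


# Example wiring; under torchrun, world_size and rank come from the
# WORLD_SIZE and RANK environment variables:
# store = dist.TCPStore('127.0.0.1', 29500, world_size, rank == 0)
# addr = exchange(store, 'xrt_mesh_config', rank, 'somehost:43211')
```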