Conversation

@amithrm (Contributor) commented May 25, 2022

This pull request enables the code needed to integrate the torchrun launcher with the XLA backend.
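
For context (not part of this PR's diff): torchrun starts one process per local worker (--nproc_per_node) and exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT to each process, which the backend integration can then map onto its own XLA worker configuration. A minimal launch sketch, where train_xla.py is a placeholder name for any PyTorch/XLA script:

# Sketch only; train_xla.py is a placeholder script name.
# torchrun provides RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT
# to every process it launches.
torchrun --nnodes=1 --nproc_per_node=2 \
         --master_addr=127.0.0.1 --master_port=2020 \
         train_xla.py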

@JackCaoG (Collaborator)

@amithrm Can you provide a test for this new feature?

@amithrm (Contributor, Author) commented May 25, 2022

@JackCaoG Sure, will add tests.

@amithrm (Contributor, Author) commented Jun 8, 2022

@JackCaoG I changed the initialization a bit to take into account how Slurm configures the devices. Please take a look at it and also the test cases. All of these would need more modifications after we discuss.

@JackCaoG (Collaborator) commented Jun 9, 2022

I have to admit that I am not an expert on torchrun; let me read up on some documentation first, lol. Looping in @will-cromar to make sure this does not conflict with our future PJRT runtime.

@amithrm (Contributor, Author) commented Jun 22, 2022

We did some internal testing. It appears that at scale, we see issues with the setup of gRPC channels. We should understand whether you see similar issues on your end too.

@amithrm (Contributor, Author) commented Jun 23, 2022

@JackCaoG
A simple test that you can run on GPU-XLA:

import os

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Print the XRT configuration that the launcher set up for this worker.
  print('XRT_LOCAL_WORKER:{}'.format(os.environ['XRT_LOCAL_WORKER']))
  print('XRT_DEVICE_MAP:{}'.format(os.environ['XRT_DEVICE_MAP']))
  print('XRT_WORKERS:{}'.format(os.environ['XRT_WORKERS']))
  print('XRT_HOST_WORLD_SIZE:{}'.format(os.environ['XRT_HOST_WORLD_SIZE']))
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # All-reduce each worker's ordinal; every worker should end up with the
  # same value (the sum of all ordinals).
  ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
  print('rank:{}, value:{}'.format(index, ordinal_tensor))
  result = xm.all_reduce('sum', ordinal_tensor)

  cpu_result = result.cpu()
  print('rank:{}, value:{}'.format(index, cpu_result))


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=(), nprocs=2, join=True)

@amithrm (Contributor, Author) commented Jun 23, 2022

Run command:

GPU_NUM_DEVICES=2 python3 allreduce_xla.py

This will output:

XRT_LOCAL_WORKER:localservice:0
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097
XRT_LOCAL_WORKER:localservice:1
XRT_DEVICE_MAP:GPU:0;/job:localservice/replica:0/task:0/device:XLA_GPU:0|GPU:1;/job:localservice/replica:0/task:1/device:XLA_GPU:0
XRT_WORKERS:localservice:0;grpc://dfda805bbe4b:49887|localservice:1;grpc://dfda805bbe4b:33097

If you look at XRT_WORKERS, it has the gRPC string for every worker. This won't scale with the number of workers.
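
A rough illustration (hypothetical hostnames and ports, not taken from the PR) of the concern: each worker's XRT_WORKERS value carries one gRPC entry per worker, so the string every process must build and parse grows linearly with world size.

# Hypothetical hosts/ports, for illustration only: the per-worker XRT_WORKERS
# string contains one grpc entry per worker, so its length is O(world size).
for n in 2 8 64; do
  workers=""
  for i in $(seq 0 $((n - 1))); do
    workers+="localservice:${i};grpc://host${i}:$((49000 + i))|"
  done
  workers=${workers%|}
  echo "${n} workers -> ${#workers} characters in XRT_WORKERS"
done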

@JackCaoG (Collaborator) commented Nov 2, 2022

We want to support torchrun with the new PJRT runtime; in the meantime, if this torchrun utility can unblock the AWS folks we can also take it.

I am a bit hesitant about claiming official support for XRT:TPU + torchrun. @will-cromar let's investigate what the gap is here; if it is free we might as well just take it.

@will-cromar (Collaborator) left a comment

This doesn't conflict with PJRT, but we will have to add a new code path for it. Runtime configuration is much simpler under PJRT, so I expect that we can simplify a lot of this.

@JackCaoG what do you think of moving this under torch_xla.experimental for now?

cmd = "torchrun --nproc_per_node=2 --master_addr=127.0.0.1 --master_port=2020 allreduce_torchrun.py "

new_env0 = os.environ.copy()
new_env0['NEURON_RT_VISIBLE_CORES'] = '0,1'
Collaborator:

I'm not familiar with Neuron at all. Do we need to install anything extra beyond PyTorch and PyTorch/XLA to run this test? If not, can you add a GPU CI test like this?

xla/test/run_tests.sh

Lines 85 to 90 in 590dee5

if [ -x "$(command -v nvidia-smi)" ]; then
  PJRT_DEVICE=GPU run_test "$@"
else
  # TODO(darisoy): run these tests with multiple CPU devices, this fails due to TF issue.
  PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_test "$@"
fi

Contributor Author:

Yes, these are Neuron-specific env variables. Will add the GPU-specific ones. But we have PJRT_DEVICE listed above; will this interfere with torchrun?

Collaborator:

Nope, there shouldn't be an issue. PJRT won't work with this PR, but you can add a new function to run_tests.sh that sets up anything Neuron-related and skips setting PJRT_DEVICE (i.e. make a run_neuron function like run_pjrt).
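
For illustration, a hypothetical run_neuron helper (not the actual change in this PR; the Neuron environment setup is an assumption based on the test above) could follow the same pattern as the existing wrappers around run_test:

# Hypothetical sketch only: configure the Neuron cores and invoke run_test
# without setting PJRT_DEVICE, since this PR targets the XRT path.
function run_neuron {
  echo "Running Neuron test: $@"
  NEURON_RT_VISIBLE_CORES=0,1 run_test "$@"
}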

Contributor Author:

Hi @will-cromar, looks like in the CI tests the file allreduce_torchrun.py is not picked up. What is a good way to fix this? Is there a global prefix that I can add before the file name?

@JackCaoG (Collaborator) commented Nov 2, 2022

In terms of moving the code to experimental, I guess it is related to how we want to set user expectations. Is there any known caveat for this feature? If it works well we don't need to put it in experimental. However, we should add a README similar to https://github.com/pytorch/xla/blob/master/docs/ddp.md

@JackCaoG (Collaborator) commented Dec 5, 2022

@will-cromar can you take another pass at this PR when you have some time?

@will-cromar (Collaborator) left a comment

Thanks! Overall LGTM. Please add a CI test (run_tests.sh) and then we can merge this

@amithrm (Contributor, Author) commented Dec 8, 2022

Looks like the file needed for the test (allreduce_torchrun.py) is not getting picked up. Checking with @will-cromar on how to fix this. And some yapf fixes are pending in one file.

@amithrm force-pushed the xrt_init branch 3 times, most recently from f263be6 to e71414a on December 10, 2022 16:55

@amithrm (Contributor, Author) commented Dec 13, 2022

@will-cromar I see build failure: NameError: name 'sympy' is not defined

@JackCaoG (Collaborator)

weird, head is green right now https://github.com/pytorch/xla/commits/master

@JackCaoG (Collaborator)

Ah ok https://github.com/pytorch/xla/pull/4313/files should fix it, can you rebase again?

@amithrm (Contributor, Author) commented Dec 13, 2022

@JackCaoG All 4 pass. @will-cromar, is there anything else needed?

@will-cromar (Collaborator) left a comment

LGTM

@will-cromar requested a review from @JackCaoG on December 13, 2022 19:10
@JackCaoG (Collaborator) left a comment

Thanks!

@JackCaoG merged commit 334d1d6 into pytorch:master on Dec 13, 2022
@jeffhataws (Collaborator)

Thanks @JackCaoG and @amithrm !
