forked from pytorch/xla

Enabling PyTorch on Google TPU


davidel/xla

 
 


PyTorch/XLA

Current CI status: CircleCI

PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU with Google Colab, and use it in production and on Cloud TPU Pods with Google Cloud.

Take a look at one of our Colab notebooks to quickly try different PyTorch networks running on Cloud TPUs and learn how to use Cloud TPUs as PyTorch devices:

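The rest of this README covers:

  • Running PyTorch on Cloud TPUs with Google Cloud Platform
  • API & Best Practices
  • Troubleshooting
  • Providing Feedback
  • Contributing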

Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org.

Running PyTorch on Cloud TPUs with Google Cloud Platform

Google Cloud Platform lets you deploy PyTorch networks running on Cloud TPUs. This guide is split into two parts:

  • Running on a Single Cloud TPU
  • How to Run on TPU Pods (distributed training)

Running on a Single Cloud TPU

The following tutorials are available to help you train models on a single Cloud TPU:

To start, create a Cloud TPU node with the TPU software version that matches the release you want to use (e.g. pytorch-1.5).
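
The TPU node can also be created from the command line. The following is a minimal sketch, assuming the gcloud CLI and its `gcloud compute tpus create` command for TPU nodes; the helper name `create_tpu_node`, the `DRY_RUN` switch, the `default` network, and the `v3-8` default accelerator type are illustrative choices, not part of this README. Check `gcloud compute tpus create --help` for the flags your gcloud version expects.

```shell
# Hypothetical helper: create a Cloud TPU node with a given TPU software version.
# Set DRY_RUN=1 to print the gcloud command instead of running it.
create_tpu_node() {
  # Arguments: TPU_NAME ZONE TPU_SOFTWARE_VERSION [ACCELERATOR_TYPE]
  set -- gcloud compute tpus create "$1" \
    --zone="$2" \
    --network=default \
    --version="$3" \
    --accelerator-type="${4:-v3-8}"
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}
```

For example, `DRY_RUN=1 create_tpu_node my-tpu us-central1-b pytorch-nightly` prints the command for a nightly TPU, which pairs with the nightly Docker images and conda environment described below.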

Once you've created a Cloud TPU node, you can train your PyTorch models in either of two ways:

  • Consume Prebuilt Docker Images
  • Consume Prebuilt Compute VM Images

Consume Prebuilt Docker Images

Follow these steps to train a PyTorch model with Docker on a Cloud TPU:

  1. Create a Compute VM and install Docker (or use a Container-Optimized OS (COS) VM image).

    • Note: make sure the Compute VM is in the same zone as the TPU node you created, or performance will suffer. Ideally, create a VM with at least 16 cores (n1-standard-16) so that the VM is not compute- or network-bound.

    Docker images with torch and torch_xla preinstalled in the pytorch conda environment are distributed under: gcr.io/tpu-pytorch/xla.

  2. SSH into the VM and pull a version of the Docker image. The currently available versions are:

    • gcr.io/tpu-pytorch/xla:r1.5: The current stable version.
    • gcr.io/tpu-pytorch/xla:nightly_3.6: Nightly version using Python 3.6.
    • gcr.io/tpu-pytorch/xla:nightly_3.7: Nightly version using Python 3.7.
    • gcr.io/tpu-pytorch/xla:nightly_3.6_YYYYMMDD (e.g.: gcr.io/tpu-pytorch/xla:nightly_3.6_20190531): The nightly version of the given day. You can replace 3.6 with 3.7 if desired.

    At this time, we recommend using the nightly versions and switching to the stable version if you run into issues with nightly. Remember to create a TPU with the pytorch-nightly version when using nightly.

    To pull an image, run one of the following commands:

    (vm)$ docker pull gcr.io/tpu-pytorch/xla:nightly_3.6
    (vm)$ docker pull gcr.io/tpu-pytorch/xla:nightly_3.6_YYYYMMDD
    (vm)$ docker pull gcr.io/tpu-pytorch/xla:r1.5
  3. After pulling the Docker image, you can do either of the following, where $TPU_IP_ADDRESS (e.g. 10.1.1.2) is your TPU's internal IP address as displayed in the GCP UI:

    • Run the container with a single command:

      (vm)$ docker run --shm-size 16G -e XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470" gcr.io/tpu-pytorch/xla:r1.5 python /pytorch/xla/test/test_train_mp_mnist.py
    • Run the script in an interactive shell:

      (vm)$ docker run -it --shm-size 16G gcr.io/tpu-pytorch/xla:r1.5
      (pytorch) root@CONTAINERID:/$ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
      (pytorch) root@CONTAINERID:/$ python pytorch/xla/test/test_train_mp_mnist.py
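
The XRT_TPU_CONFIG value in both commands follows a fixed pattern: tpu_worker;0; followed by the TPU's internal IP and port 8470. A minimal sketch that builds this value and wraps the single-command `docker run` invocation; the function names `xrt_tpu_config` and `run_in_docker` and the `DRY_RUN` switch are hypothetical.

```shell
# Hypothetical helper: build the XRT_TPU_CONFIG value for a single TPU node,
# following the pattern used above: tpu_worker;0;<TPU internal IP>:8470
xrt_tpu_config() {
  printf 'tpu_worker;0;%s:8470\n' "$1"
}

# Hypothetical wrapper around the single-command docker run shown above.
# Set DRY_RUN=1 to print the command instead of running it.
run_in_docker() {
  # Arguments: TPU_IP_ADDRESS IMAGE COMMAND...
  if [ "$#" -lt 3 ]; then
    echo "usage: run_in_docker TPU_IP_ADDRESS IMAGE COMMAND..." >&2
    return 1
  fi
  config="$(xrt_tpu_config "$1")"
  image="$2"
  shift 2
  set -- docker run --shm-size 16G -e "XRT_TPU_CONFIG=$config" "$image" "$@"
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}
```

For example, `run_in_docker "$TPU_IP_ADDRESS" gcr.io/tpu-pytorch/xla:r1.5 python /pytorch/xla/test/test_train_mp_mnist.py` runs the same command as the single-command option above.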

Consume Prebuilt Compute VM Images

Follow these steps to train a PyTorch model with a VM Image on a Cloud TPU:

  1. Create a Compute VM with PyTorch/XLA Image.

    • In the GCP Console, go to the VM Instances page.
    • Click Create Instance.
    • Make sure the Compute VM is in the same zone as the TPU node you created, or performance will suffer. Ideally, create a VM with at least 16 cores (n1-standard-16) so that the VM is not compute- or network-bound.
    • In the Boot disk section, click Change to choose our PyTorch/XLA image.
    • At the bottom of the OS Images tab select the Debian GNU/Linux 9 Stretch + PyTorch/XLA image.
    • Choose an appropriate disk size based on your dataset and click Select.
    • Click Create to create the instance.
  2. SSH into the VM and activate the conda environment you wish to use. Each release (e.g. 0.1, 0.5, 1.5, nightly) is a separate conda environment.

    (vm)$ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
    (vm)$ conda env list
    # conda environments:
    #
    base                  *  /anaconda3
    torch-xla-0.1              /anaconda3/envs/torch-xla-0.1
    torch-xla-0.5              /anaconda3/envs/torch-xla-0.5
    torch-xla-1.5              /anaconda3/envs/torch-xla-1.5
    torch-xla-nightly          /anaconda3/envs/torch-xla-nightly
    
    (vm)$ conda activate torch-xla-1.5
    (torch-xla-1.5)$ cd /usr/share/torch-xla-1.5/pytorch/xla
    (torch-xla-1.5)$ python test/test_train_mp_mnist.py

    To update the torch and torch_xla wheels to the latest nightly build (this only updates your torch-xla-nightly conda environment), run:

    (vm)$ cd /usr/share/torch-xla-nightly/pytorch/xla
    (vm)$ . ./scripts/update_nightly_torch_wheels.sh
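
If you prefer the command line to the console, step 1 can be scripted with gcloud. A minimal sketch, assuming the PyTorch/XLA image is published as the `torch-xla` image family in the `ml-images` project (verify this against the image shown in the console's OS Images tab); the helper name `create_xla_vm`, the `DRY_RUN` switch, and the 200GB default boot disk size are hypothetical.

```shell
# Hypothetical helper mirroring the console steps above with gcloud.
# ASSUMPTION: the PyTorch/XLA image is the "torch-xla" image family in the
# "ml-images" project; check the image listed in the console before relying on it.
# Set DRY_RUN=1 to print the command instead of running it.
create_xla_vm() {
  # Arguments: VM_NAME ZONE [BOOT_DISK_SIZE]
  # Use the same zone as the TPU node and at least 16 cores (n1-standard-16).
  set -- gcloud compute instances create "$1" \
    --zone="$2" \
    --machine-type=n1-standard-16 \
    --image-family=torch-xla \
    --image-project=ml-images \
    --boot-disk-size="${3:-200GB}"
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}
```

Pass a larger third argument (for example `500GB`) if your dataset needs more disk space.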

How to Run on TPU Pods (distributed training)

Whereas the previous section focused on training on a single TPU node, this section discusses distributed training in TPU Pods. The tutorial, Training PyTorch models on Cloud TPU Pods, is a great place to start.

The recommended setup for distributed training on TPU Pods pairs a Compute VM instance group with a TPU Pod. Each Compute VM in the instance group drives 8 cores of the TPU Pod, and using an instance group ensures that every Compute VM uses the identical base image.

Training on Pods breaks down into three main steps:

  1. Create your instance group (recommended) or use a list of VM instances
  2. Create your TPU Pod
  3. Start distributed training

Create your instance group

  1. Create an instance template.
  • During creation, make sure to go to section "Identity and API access" → "Access Scopes" and select "Allow full access to all Cloud APIs".
  • If you already have a VM instance running that you used to train PyTorch/TPU workloads and want to use that exact setup for distributed training: instructions.
  • Or, you can create an instance template using the PyTorch/XLA VM image we provide: instructions.
  2. Create an instance group to drive the TPU Pod.
  • This instance group runs the entire input pipeline and feeds all the tensors into the TPUs for training.
  • Use the instance template created in step (1) to create your instance group.
  • Make sure to (a) create the instance group in a single zone (the same zone as the TPU Pod you'll create), (b) disable autoscaling and health checks, and (c) set the number of instances (the size of the instance group) to the number of TPU cores / 8 (e.g. for a v3-32, create an instance group of size 32 / 8 = 4).
  • Here are the instructions for creating an instance group: instructions.
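
The sizing rule in (c) is easy to get wrong by hand, so it can be computed in a script. A minimal sketch, assuming the accelerator type is written as vN-CORES (e.g. v3-32) and that the instance template from step 1 already exists; the helper names `instance_group_size` and `create_instance_group` and the `DRY_RUN` switch are hypothetical, while `gcloud compute instance-groups managed create` with `--template`, `--zone`, and `--size` is a standard gcloud command.

```shell
# Hypothetical helper: instance group size for a TPU Pod, using the rule
# "number of instances = number of TPU cores / 8".
# ASSUMPTION: the accelerator type is written as vN-CORES, e.g. v3-32.
instance_group_size() {
  cores="${1##*-}"  # strip everything up to the last "-": "v3-32" -> "32"
  case "$cores" in
    ''|*[!0-9]*)
      echo "unrecognized accelerator type: $1" >&2
      return 1
      ;;
  esac
  if [ $((cores % 8)) -ne 0 ]; then
    echo "core count $cores is not a multiple of 8" >&2
    return 1
  fi
  echo $((cores / 8))
}

# Hypothetical wrapper: create a managed instance group of the right size from
# an existing instance template, in the same zone as the TPU Pod.
# Set DRY_RUN=1 to print the command instead of running it.
create_instance_group() {
  # Arguments: GROUP_NAME TEMPLATE_NAME ZONE ACCELERATOR_TYPE
  size="$(instance_group_size "$4")" || return 1
  set -- gcloud compute instance-groups managed create "$1" \
    --template="$2" \
    --zone="$3" \
    --size="$size"
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}
```

For a v3-32, `instance_group_size v3-32` prints 4, matching the example in (c).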

Create your TPU Pod

  1. Create a TPU Pod (the same as creating a regular TPU; just select more cores when choosing the TPU type).
  • Make sure that the TPU is in the same zone as the instance group.
  • Make sure that the size of your instance group follows: # instances in group = number of TPU cores / 8.

Start distributed training

  1. SSH into any of the VMs in the instance group and enter an environment where torch and torch_xla are installed (either a conda environment or a Docker container).
  2. Suppose the command you used to train on a v3-8 was: XLA_USE_BF16=1 python test/test_train_mp_imagenet.py --fake_data.
  • To distribute training as a conda environment process:
(torch-xla-nightly)$ python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME --conda-env=torch-xla-nightly --env=XLA_USE_BF16=1 -- python /usr/share/torch-xla-nightly/pytorch/xla/test/test_train_mp_imagenet.py --fake_data
  • Or, to distribute training as a docker container:
(torch-xla-nightly)$ python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME --docker-image=gcr.io/tpu-pytorch/xla:nightly_3.6 --docker-run-flag=--rm=true --docker-run-flag=--shm-size=50GB --env=XLA_USE_BF16=1 -- python /pytorch/xla/test/test_train_mp_imagenet.py --fake_data

List of VMs

If you prefer not to use an instance group, you can use a list of VM instances that you have already created (or create individually). Make sure that all the VM instances are in the same zone as the TPU Pod and have the same configuration (datasets, VM size, disk size, etc.). Then, after creating your TPU Pod, you can start distributed training. The only difference is the python -m torch_xla.distributed.xla_dist command. For example, to use a list of VMs (e.g. conda with a v3-32), run the following commands:

(torch-xla-nightly)$ cd /usr/share/torch-xla-nightly/pytorch/xla
(torch-xla-nightly)$ python -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME --vm $VM1 --vm $VM2 --vm $VM3 --vm $VM4 --conda-env=torch-xla-nightly --env=XLA_USE_BF16=1 -- python test/test_train_imagenet.py --fake_data
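
The instance-group and list-of-VMs conda commands differ only in the --vm flags passed to xla_dist. A minimal sketch of a wrapper that builds either form, using only the xla_dist flags shown in this README (--tpu, --vm, --conda-env, --env); the function name `xla_dist_conda`, the `DRY_RUN` switch, and the hard-coded XLA_USE_BF16=1 are illustrative assumptions.

```shell
# Hypothetical wrapper around torch_xla.distributed.xla_dist for conda setups.
# Usage: xla_dist_conda TPU_POD_NAME CONDA_ENV [VM...] -- COMMAND...
# With no VM names it matches the instance-group command; with VM names it
# matches the list-of-VMs command. Set DRY_RUN=1 to print instead of running.
xla_dist_conda() {
  tpu="$1"
  conda_env="$2"
  shift 2
  vm_args=""
  # Collect optional VM names up to the "--" separator.
  while [ "$#" -gt 0 ] && [ "$1" != "--" ]; do
    vm_args="$vm_args --vm $1"
    shift
  done
  if [ "$#" -eq 0 ]; then
    echo "missing -- before the training command" >&2
    return 1
  fi
  shift  # drop the "--" separator
  # $vm_args is intentionally unquoted so it splits into separate --vm NAME words
  # (VM names are assumed to contain no spaces or glob characters).
  set -- python -m torch_xla.distributed.xla_dist --tpu="$tpu" $vm_args \
    --conda-env="$conda_env" --env=XLA_USE_BF16=1 -- "$@"
  if [ "${DRY_RUN:-0}" = 1 ]; then
    echo "$*"
  else
    "$@"
  fi
}
```

For example, `xla_dist_conda "$TPU_POD_NAME" torch-xla-nightly "$VM1" "$VM2" "$VM3" "$VM4" -- python test/test_train_imagenet.py --fake_data` reproduces the list-of-VMs command above.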

To learn more about TPU Pods check out this blog post. For more information regarding system architecture, please refer to the Cloud TPU System Architecture page.

API & Best Practices

In general, PyTorch/XLA follows the PyTorch APIs. Some additional torch_xla-specific APIs are documented at:

Documentation for the latest release

Documentation for master branch

See the API Guide for best practices when writing networks that run on Cloud TPUs and Cloud TPU Pods.

Troubleshooting

If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).

Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue on this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!

Contributing

See the contribution guide.
