Added check to verify xla device is TPU (#3274)
* tpu device check

* replaced with xmp spawn

* Revert "replaced with xmp spawn"

This reverts commit 6835380

* replaced all instances of XLA_AVAILABLE

* moved inner_f to global scope

* made refactors

* added changelog

* added TPU_AVAILABLE variable

* fix codefactor issues

* removed from trainer and early stopping

* add TORCHXLA_AVAILABLE check

* added tests

* refactoring

* Update pytorch_lightning/utilities/xla_device_utils.py

Co-authored-by: Adrian Wälchli <[email protected]>

* updated function names

* fixed bug

* updated CHANGELOG.md

* added todo

* added type hints

* isort and black

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: William Falcon <[email protected]>
3 people authored Oct 6, 2020
1 parent 2cf17a3 commit 69833da
Showing 8 changed files with 138 additions and 36 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `broadcast` to `TPUBackend` ([#3814](https://github.com/PyTorchLightning/pytorch-lightning/pull/3814))

- Added `XLADeviceUtils` class to check XLA device type ([#3274](https://github.com/PyTorchLightning/pytorch-lightning/pull/3274))

### Changed

- Changed `LearningRateLogger` to `LearningRateMonitor` ([#3251](https://github.com/PyTorchLightning/pytorch-lightning/pull/3251))
18 changes: 9 additions & 9 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -21,20 +21,19 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.core import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.utilities import AMPType, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True
import torch_xla.distributed.parallel_loader as xla_pl


class TPUBackend(Accelerator):
@@ -47,7 +46,8 @@ def __init__(self, trainer, cluster_environment=None):
def setup(self, model):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

if not XLA_AVAILABLE:
# TODO: Move this check to Trainer __init__ or device parser
if not TPU_AVAILABLE:
raise MisconfigurationException('PyTorch XLA not installed.')

# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@@ -171,7 +171,7 @@ def to_device(self, batch):
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
if not XLA_AVAILABLE:
if not TPU_AVAILABLE:
raise MisconfigurationException(
'Requested to transfer batch to TPU but XLA is not available.'
' Are you sure this machine has TPUs?'
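
For reference, the guard pattern this commit introduces (and repeats in the files below) looks like the sketch here; it is an illustration of the new module preamble, not an excerpt from the diff:

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

# Probe once at import time; the result is True, False, or None
# (None when torch_xla is not installed at all).
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
    # torch_xla is only imported after a TPU has actually been detected.
    import torch_xla.core.xla_model as xm
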
17 changes: 8 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,23 +19,22 @@
Monitor a validation metric and stop training when it stops improving.
"""
import os

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import os
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()


torch_inf = torch.tensor(np.Inf)

try:
import torch_xla
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True



class EarlyStopping(Callback):
@@ -186,7 +185,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)

if trainer.use_tpu and XLA_AVAILABLE:
if trainer.use_tpu and TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):
10 changes: 5 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -30,6 +30,8 @@
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.core.step_result import TrainResult, EvalResult
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
@@ -43,12 +45,10 @@
from torch.optim.optimizer import Optimizer


try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True


class LightningModule(
10 changes: 3 additions & 7 deletions pytorch_lightning/trainer/data_loading.py
@@ -27,22 +27,18 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from copy import deepcopy


TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
try:
from apex import amp
except ImportError:
amp = None

try:
if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

try:
import horovod.torch as hvd
74 changes: 74 additions & 0 deletions pytorch_lightning/utilities/xla_device_utils.py
@@ -0,0 +1,74 @@
import functools
import importlib
from multiprocessing import Process, Queue

import torch

TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
if TORCHXLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
else:
    xm = None


def inner_f(queue, func, **kwargs):  # pragma: no cover
    try:
        queue.put(func(**kwargs))
    except Exception as _e:
        import traceback

        traceback.print_exc()
        queue.put(None)


def pl_multi_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
        proc.start()
        proc.join()
        return queue.get()

    return wrapper


class XLADeviceUtils:
    """Used to detect the type of XLA device"""

    TPU_AVAILABLE = None

    @staticmethod
    def _fetch_xla_device_type(device: torch.device) -> str:
        """
        Returns XLA device type
        Args:
            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e xla:0
        Return:
            Returns a str of the device hardware type. i.e TPU
        """
        if xm is not None:
            return xm.xla_device_hw(device)

    @staticmethod
    def _is_device_tpu() -> bool:
        """
        Check if device is TPU
        Return:
            A boolean value indicating if the xla device is a TPU device or not
        """
        if xm is not None:
            device = xm.xla_device()
            device_type = XLADeviceUtils._fetch_xla_device_type(device)
            return device_type == "TPU"

    @staticmethod
    def tpu_device_exists() -> bool:
        """
        Public method to check if TPU is available
        Return:
            A boolean value indicating if a TPU device exists on the system
        """
        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
        return XLADeviceUtils.TPU_AVAILABLE
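
For reference, a short usage sketch of the new utility (not part of the diff): `tpu_device_exists()` runs the `_is_device_tpu` probe in a separate process via `pl_multi_process` and caches the outcome on the class, so only the first call pays the probing cost.

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

if __name__ == "__main__":
    # The first call starts a helper process that asks torch_xla for the device
    # hardware type; the result is cached in XLADeviceUtils.TPU_AVAILABLE, so
    # later calls return immediately. Returns None when torch_xla is missing.
    if XLADeviceUtils.tpu_device_exists():
        print("XLA device reports hardware type TPU")
    else:
        print("No TPU detected (or torch_xla is not installed)")
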
12 changes: 6 additions & 6 deletions tests/models/test_tpu.py
@@ -1,26 +1,27 @@
import os
from multiprocessing import Process, Queue

import pytest
from torch.utils.data import DataLoader

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators.base_backend import BackendType
from pytorch_lightning.accelerators import TPUBackend
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test

try:
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

if TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()
except ImportError:
TPU_AVAILABLE = False
else:
TPU_AVAILABLE = True


_LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8))
@@ -216,7 +217,6 @@ def test_tpu_misconfiguration():
Trainer(tpu_cores=[1, 8])


# @patch('pytorch_lightning.trainer.trainer.XLA_AVAILABLE', False)
@pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""
31 changes: 31 additions & 0 deletions tests/utilities/test_xla_device_utils.py
@@ -0,0 +1,31 @@
import pytest

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
from tests.base.develop_utils import pl_multi_process_test

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError as e:
    XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xdu.tpu_device_exists() is None


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
    assert xdu.tpu_device_exists() is True


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    return device_type == "TPU"
