Added check to verify xla device is TPU (#3274)
* tpu device check
* replaced with xmp spawn
* Revert "replaced with xmp spawn" — This reverts commit 6835380
* replaced all instances of XLA_AVAILABLE
* moved inner_f to global scope
* made refactors
* added changelog
* added TPU_AVAILABLE variable
* fix codefactor issues
* removed from trainer and early stopping
* add TORCHXLA_AVAILABLE check
* added tests
* refactoring
* Update pytorch_lightning/utilities/xla_device_utils.py (Co-authored-by: Adrian Wälchli <[email protected]>)
* updated function names
* fixed bug
* updated CHANGELOG.md
* added todo
* added type hints
* isort and black

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 2cf17a3 · commit 69833da
Showing 8 changed files with 138 additions and 36 deletions.
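Taken together, the diff below adds a cached TPU-availability probe. As a minimal, hedged sketch of how downstream code might consume the new utility (the tpu_cores guard here is illustrative only and is not part of this commit):

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

# Only request TPU cores when torch_xla is installed and a TPU is actually reachable.
# The probe result is cached on the class, so repeated calls are cheap.
tpu_cores = 8 if XLADeviceUtils.tpu_device_exists() else None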
pytorch_lightning/utilities/xla_device_utils.py (new file)
@@ -0,0 +1,74 @@
import functools
import importlib
from multiprocessing import Process, Queue

import torch

TORCHXLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
if TORCHXLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
else:
    xm = None


def inner_f(queue, func, **kwargs):  # pragma: no cover
    try:
        queue.put(func(**kwargs))
    except Exception as _e:
        import traceback

        traceback.print_exc()
        queue.put(None)


def pl_multi_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        queue = Queue()
        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
        proc.start()
        proc.join()
        return queue.get()

    return wrapper


class XLADeviceUtils:
    """Used to detect the type of XLA device"""

    TPU_AVAILABLE = None

    @staticmethod
    def _fetch_xla_device_type(device: torch.device) -> str:
        """
        Returns XLA device type

        Args:
            device: (:class:`~torch.device`): Accepts a torch.device type with a XLA device format i.e. xla:0

        Return:
            Returns a str of the device hardware type. i.e. TPU
        """
        if xm is not None:
            return xm.xla_device_hw(device)

    @staticmethod
    def _is_device_tpu() -> bool:
        """
        Check if device is TPU

        Return:
            A boolean value indicating if the xla device is a TPU device or not
        """
        if xm is not None:
            device = xm.xla_device()
            device_type = XLADeviceUtils._fetch_xla_device_type(device)
            return device_type == "TPU"

    @staticmethod
    def tpu_device_exists() -> bool:
        """
        Public method to check if TPU is available

        Return:
            A boolean value indicating if a TPU device exists on the system
        """
        if XLADeviceUtils.TPU_AVAILABLE is None and TORCHXLA_AVAILABLE:
            XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
        return XLADeviceUtils.TPU_AVAILABLE
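Why the extra process? The apparent motivation (inferred from the code, not stated in the commit message) is that probing xm.xla_device() can hang or abort on a host without a TPU, so tpu_device_exists runs the probe through pl_multi_process and caches whatever comes back on the queue. A minimal sketch of the decorator in isolation, assuming the module above is importable as pytorch_lightning.utilities.xla_device_utils:

from pytorch_lightning.utilities.xla_device_utils import pl_multi_process

def probe():
    # Stand-in for XLADeviceUtils._is_device_tpu; any module-level callable works.
    return 42

if __name__ == "__main__":
    # The call runs in a child process; an exception inside probe() would be
    # printed there and surface here as a None result instead of propagating.
    print(pl_multi_process(probe)())  # -> 42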
New test module for XLADeviceUtils (path not shown in this capture)
@@ -0,0 +1,31 @@
import pytest

from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
from tests.base.develop_utils import pl_multi_process_test

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError as e:
    XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xdu.tpu_device_exists() is None


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
    assert xdu.tpu_device_exists() is True

@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_xla_device_is_a_tpu():
    """Check that the XLA device is a TPU"""
    device = xm.xla_device()
    device_type = xm.xla_device_hw(device)
    # Assert (rather than return) the result so the test can actually fail
    # when the XLA device is not a TPU.
    assert device_type == "TPU"
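Note that test_xla_device_is_a_tpu is additionally wrapped in pl_multi_process_test, presumably for the same reason as the production probe: a hung XLA initialization stays confined to a child process. To run just this new module locally, pytest's Python entry point works; the path below is an assumption about where the tests live (the capture above does not show the filename), so adjust it to the actual location:

import pytest

# Hypothetical path; substitute the real location of the test module shown above.
pytest.main(["-v", "tests/utilities/test_xla_device_utils.py"])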