From c08eb85329a61142c9c270cb296681f56f06f485 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sat, 11 Jul 2020 09:26:31 -0700
Subject: [PATCH] Added MP test decorator.

---
 docs/source/index.rst        | 12 ++++++++++++
 test/test_operations.py      |  9 +++++++++
 torch_xla/test/test_utils.py | 33 ++++++++++++++++++++++++++++++++-
 3 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index dff725f1e37..09091237e31 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -85,3 +85,15 @@ utils
 
 .. automodule:: torch_xla.utils.cached_dataset
 .. autoclass:: CachedDataset
+
+test
+----------------------------------
+
+.. automodule:: torch_xla.test.test_utils
+.. autofunction:: mp_test
+.. autofunction:: write_to_summary
+.. autofunction:: close_summary_writer
+.. autofunction:: get_summary_writer
+.. autofunction:: print_training_update
+.. autofunction:: print_test_update
+
diff --git a/test/test_operations.py b/test/test_operations.py
index 706f134799f..d18b6ef5766 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -35,6 +35,7 @@
 import torch_xla.debug.metrics as met
 import torch_xla.debug.model_comparator as mc
 import torch_xla.distributed.parallel_loader as pl
+import torch_xla.test.test_utils as xtu
 import torch_xla.utils.utils as xu
 import torch_xla.utils.serialization as xser
 import torch_xla.core.xla_model as xm
@@ -1799,6 +1800,14 @@ def aten_fn(a, b, limit=None):
       kwargs={'limit': 10})
 
 
+class MpDecoratorTest(XlaTestCase):
+
+  @xtu.mp_test
+  def test_mp_decorator(self):
+    xla_device = xm.xla_device()
+    self.assertTrue(xla_device.type == 'xla')
+
+
 class TestGeneric(XlaTestCase):
 
   def test_zeros_like_patch(self):
diff --git a/torch_xla/test/test_utils.py b/torch_xla/test/test_utils.py
index a9c3eb738d5..1a1d071d664 100644
--- a/torch_xla/test/test_utils.py
+++ b/torch_xla/test/test_utils.py
@@ -1,7 +1,9 @@
+from datetime import datetime
+import multiprocessing
 import os
 import sys
 import time
-from datetime import datetime
+import unittest
 
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
@@ -9,6 +11,35 @@
 import torch_xla.utils.utils as xu
 
 
+def mp_test(func):
+  """Wraps a `unittest.TestCase` function running it within an isolated process.
+
+  Example::
+
+    import torch_xla.test.test_utils as xtu
+    import unittest
+
+    class MyTest(unittest.TestCase):
+
+      @xtu.mp_test
+      def test_basic(self):
+        ...
+
+  Args:
+    func (callable): The `unittest.TestCase` function to be wrapped.
+  """
+
+  def wrapper(*args, **kwargs):
+    proc = multiprocessing.Process(target=func, args=args, kwargs=kwargs)
+    proc.start()
+    proc.join()
+    if isinstance(args[0], unittest.TestCase):
+      args[0].assertEqual(proc.exitcode, 0)
+    return proc.exitcode
+
+  return wrapper
+
+
 def _get_device_spec(device):
   ordinal = xm.get_ordinal(defval=-1)
   return str(device) if ordinal < 0 else '{}/{}'.format(device, ordinal)