Added MP test decorator. #2343

Merged 1 commit on Jul 13, 2020
12 changes: 12 additions & 0 deletions docs/source/index.rst
@@ -85,3 +85,15 @@ utils
 .. automodule:: torch_xla.utils.cached_dataset
 .. autoclass:: CachedDataset
 
+
+test
+----------------------------------
+
+.. automodule:: torch_xla.test.test_utils
+.. autofunction:: mp_test
+.. autofunction:: write_to_summary
+.. autofunction:: close_summary_writer
+.. autofunction:: get_summary_writer
+.. autofunction:: print_training_update
+.. autofunction:: print_test_update
+
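For orientation, a complete client-side test module using the decorator documented above might look like the sketch below (it assumes an environment with a working XLA device; the class and test names are illustrative):

import unittest

import torch_xla.core.xla_model as xm
import torch_xla.test.test_utils as xtu


class IsolationTest(unittest.TestCase):

  @xtu.mp_test
  def test_device_in_child_process(self):
    # The decorated body runs in its own process, so XLA state created
    # here cannot leak into other tests in the same suite.
    self.assertEqual(xm.xla_device().type, 'xla')


if __name__ == '__main__':
  unittest.main()
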
9 changes: 9 additions & 0 deletions test/test_operations.py
@@ -35,6 +35,7 @@
 import torch_xla.debug.metrics as met
 import torch_xla.debug.model_comparator as mc
 import torch_xla.distributed.parallel_loader as pl
+import torch_xla.test.test_utils as xtu
 import torch_xla.utils.utils as xu
 import torch_xla.utils.serialization as xser
 import torch_xla.core.xla_model as xm
@@ -1799,6 +1800,14 @@ def aten_fn(a, b, limit=None):
 kwargs={'limit': 10})
 
 
+class MpDecoratorTest(XlaTestCase):
+
+  @xtu.mp_test
+  def test_mp_decorator(self):
+    xla_device = xm.xla_device()
+    self.assertTrue(xla_device.type == 'xla')
+
+
 class TestGeneric(XlaTestCase):
 
   def test_zeros_like_patch(self):
33 changes: 32 additions & 1 deletion torch_xla/test/test_utils.py
@@ -1,14 +1,45 @@
+from datetime import datetime
+import multiprocessing
 import os
 import sys
 import time
-from datetime import datetime
+import unittest
 
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
 import torch_xla.debug.metrics_compare_utils as mcu
 import torch_xla.utils.utils as xu
 
 
+def mp_test(func):
+  """Wraps a `unittest.TestCase` test function, running it within an isolated process.
+
+  Example::
+
+    import torch_xla.test.test_utils as xtu
+    import unittest
+
+    class MyTest(unittest.TestCase):
+
+      @xtu.mp_test
+      def test_basic(self):
+        ...
+
+  Args:
+    func (callable): The `unittest.TestCase` function to be wrapped.
+  """
+
+  def wrapper(*args, **kwargs):
+    proc = multiprocessing.Process(target=func, args=args, kwargs=kwargs)
+    proc.start()
+    proc.join()
+    if isinstance(args[0], unittest.TestCase):
+      args[0].assertEqual(proc.exitcode, 0)
+    return proc.exitcode
+
+  return wrapper
+
+
 def _get_device_spec(device):
   ordinal = xm.get_ordinal(defval=-1)
   return str(device) if ordinal < 0 else '{}/{}'.format(device, ordinal)
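The failure propagation in `mp_test` relies on process exit codes: an uncaught exception in the child (for example a failed `unittest` assertion) makes `multiprocessing` report a nonzero `exitcode`, and the parent's `assertEqual(proc.exitcode, 0)` turns that back into an ordinary test failure. A standalone sketch of that round trip, with illustrative names that are not part of torch_xla:

import multiprocessing


def _failing_body():
  # Stands in for a decorated test body; the uncaught exception
  # terminates the child process with a nonzero exit code.
  raise AssertionError('failed in the child process')


if __name__ == '__main__':
  proc = multiprocessing.Process(target=_failing_body)
  proc.start()
  proc.join()
  # multiprocessing maps an uncaught exception to exitcode 1, which is
  # exactly the condition mp_test's assertEqual(proc.exitcode, 0) flags.
  print(proc.exitcode)  # -> 1

Returning `proc.exitcode` from the wrapper also lets the decorator be applied outside a `TestCase` (the `isinstance` guard skips the assertion), in which case the caller can inspect the exit code directly.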