diff --git a/CHANGELOG.md b/CHANGELOG.md
index 21ed0c97b2a4f..2f2d92166c2f8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added autogenerated helptext to `Trainer.add_argparse_args` ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))
 - Added support for string values in `Trainer`'s `profiler` parameter ([#3656](https://github.com/PyTorchLightning/pytorch-lightning/pull/3656))
 
+- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
+
 ### Changed
 
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
diff --git a/pytorch_lightning/utilities/xla_device_utils.py b/pytorch_lightning/utilities/xla_device_utils.py
index 5687992981ae6..14a59fd105c5a 100644
--- a/pytorch_lightning/utilities/xla_device_utils.py
+++ b/pytorch_lightning/utilities/xla_device_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import functools
 import importlib
+import queue as q
 from multiprocessing import Process, Queue
 
 import torch
@@ -24,10 +25,10 @@
     xm = None
 
 
-def inner_f(queue, func, **kwargs):  # pragma: no cover
+def inner_f(queue, func, *args, **kwargs):  # pragma: no cover
     try:
-        queue.put(func(**kwargs))
-    except Exception as _e:
+        queue.put(func(*args, **kwargs))
+    except Exception:
         import traceback
 
         traceback.print_exc()
@@ -38,10 +39,13 @@ def pl_multi_process(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         queue = Queue()
-        proc = Process(target=inner_f, args=(queue, func,), kwargs=kwargs)
+        proc = Process(target=inner_f, args=(queue, func, *args), kwargs=kwargs)
         proc.start()
-        proc.join()
-        return queue.get()
+        proc.join(10)
+        try:
+            return queue.get_nowait()
+        except q.Empty:
+            return False
 
     return wrapper
 
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index b0a3497a0f3be..10de63db049e7 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -11,13 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import time
+
 import pytest
 
-from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils as xdu
+import pytorch_lightning.utilities.xla_device_utils as xla_utils
 from tests.base.develop_utils import pl_multi_process_test
 
 try:
     import torch_xla.core.xla_model as xm
+
     XLA_AVAILABLE = True
 except ImportError as e:
     XLA_AVAILABLE = False
@@ -26,13 +29,13 @@
 @pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
-    assert xdu.tpu_device_exists() is None
+    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
 
 
 @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
 def test_tpu_device_presence():
     """Check tpu_device_exists returns True when TPU is available"""
-    assert xdu.tpu_device_exists() is True
+    assert xla_utils.XLADeviceUtils.tpu_device_exists() is True
 
 
 @pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla to be installed")
@@ -42,3 +45,14 @@ def test_xla_device_is_a_tpu():
     device = xm.xla_device()
     device_type = xm.xla_device_hw(device)
     return device_type == "TPU"
+
+
+def test_result_returns_within_10_seconds():
+    """Check that pl_multi_process returns within 10 seconds"""
+
+    start = time.time()
+    result = xla_utils.pl_multi_process(time.sleep)(25)
+    end = time.time()
+    elapsed_time = int(end - start)
+    assert elapsed_time <= 10
+    assert result is False
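
For context, here is a minimal standalone sketch of the timeout pattern this patch introduces in `pl_multi_process`: run the callable in a child process, join with a 10-second limit, read the result with `get_nowait`, and treat an empty queue as `False`. The `run_with_timeout` and `_worker` names below are illustrative only and are not part of the library.

```python
import queue as q
from multiprocessing import Process, Queue


def _worker(result_queue, func, *args, **kwargs):
    # Run func and push its result onto the queue; print (not raise) errors
    # so a failing child never leaves the parent blocked on queue.get().
    try:
        result_queue.put(func(*args, **kwargs))
    except Exception:
        import traceback

        traceback.print_exc()


def run_with_timeout(func, *args, timeout=10, **kwargs):
    """Illustrative helper: run `func` in a child process, give up after `timeout` seconds."""
    result_queue = Queue()
    proc = Process(target=_worker, args=(result_queue, func, *args), kwargs=kwargs)
    proc.start()
    proc.join(timeout)  # wait at most `timeout` seconds for the child
    try:
        # If the child finished in time, its result is already on the queue.
        return result_queue.get_nowait()
    except q.Empty:
        # Child is still running (or crashed without a result): report failure
        # instead of hanging, same as the patched wrapper.
        return False


if __name__ == "__main__":
    import time

    # Mirrors the new test: the call returns after ~10 s even though the
    # child sleeps for 25 s (the child is left to finish on its own).
    assert run_with_timeout(time.sleep, 25) is False
```

Using `proc.join(10)` plus `get_nowait()` instead of a blocking `queue.get()` is what bounds `tpu_device_exists` checks: a hung XLA probe no longer stalls the whole process.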