diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index c3747acd618e83..80a7db487dcb7b 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -1,5 +1,5 @@ import os -from multiprocessing import Process, Queue +from multiprocessing import Queue import pytest from torch.utils.data import DataLoader @@ -21,9 +21,7 @@ def tpu_device_exists(q): q.put(device_type == 'TPU') queue = Queue() - p = Process(target=tpu_device_exists, args=(queue, )) - p.start() - p.join() + xmp.spawn(lambda rank, queue: tpu_device_exists(queue), (queue, ), nprocs=1, start_method='fork') TPU_AVAILABLE = queue.get() except ImportError: TPU_AVAILABLE = False