Skip to content

Commit

Permalink
tpu device check
Browse files Browse the repository at this point in the history
  • Loading branch information
lezwon committed Aug 30, 2020
1 parent f46318e commit 4a1701c
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions tests/models/test_tpu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from multiprocessing import Process, Queue

import pytest
from torch.utils.data import DataLoader
Expand All @@ -14,15 +15,18 @@
import torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp

SERIAL_EXEC = xmp.MpSerialExecutor()
# TODO: The tests are aborted if the following lines are uncommented. Must be resolved with XLA team
# device = torch_xla.core.xla_model.xla_device()
# device_type = torch_xla.core.xla_model.xla_device_hw(device)
# TPU_AVAILABLE = device_type == 'TPU'
def tpu_device_exists(q):
device = torch_xla.core.xla_model.xla_device()
device_type = torch_xla.core.xla_model.xla_device_hw(device)
q.put(device_type == 'TPU')

queue = Queue()
p = Process(target=tpu_device_exists, args=(queue, ))
p.start()
p.join()
TPU_AVAILABLE = queue.get()
except ImportError:
TPU_AVAILABLE = False
else:
TPU_AVAILABLE = True


_LARGER_DATASET = TrialMNIST(download=True, num_samples=2000, digits=(0, 1, 2, 5, 8))
Expand Down

0 comments on commit 4a1701c

Please sign in to comment.