
Commit 838436a: typo
Borda committed Jul 25, 2020
1 parent 613f30c
Showing 1 changed file with 9 additions and 8 deletions.

tests/models/test_tpu.py (17 changes: 9 additions & 8 deletions)
@@ -59,20 +59,20 @@ def test_model_tpu_index_1(tmpdir):
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_tpu_index_7(tmpdir):
+def test_model_tpu_index_5(tmpdir):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
-        tpu_cores=[7],
+        tpu_cores=[5],
         limit_train_batches=0.4,
         limit_val_batches=0.4
     )
 
     model = EvalModelTemplate()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:8'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@@ -168,18 +168,19 @@ def long_train_loader():
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_16bit_tpu_index_7(tmpdir):
+def test_model_16bit_tpu_index_5(tmpdir):
     """Test if distributed TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
         default_root_dir=tmpdir,
         precision=16,
         max_epochs=1,
         train_percent_check=0.4,
-        tpu_cores=[7],
+        val_percent_check=0.2,
+        tpu_cores=[5],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:7'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
@@ -213,10 +214,10 @@ def test_early_stop_checkpoints_on_tpu(tmpdir):
         max_epochs=50,
         limit_train_batches=10,
         limit_val_batches=10,
-        tpu_cores=[8],
+        tpu_cores=[5],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:7'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
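
For reference, a minimal sketch of the single-core selection pattern these tests exercise: passing a one-element list such as tpu_cores=[5] pins training to that TPU core, after which torch_xla reports the matching default device, which is what the updated assertions check. This assumes a TPU runtime with torch_xla installed; the EvalModelTemplate import path reflects the Lightning test suite of this era and may differ.

    import torch_xla

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate  # test-suite model; path assumed

    # A one-element list pins training to a specific TPU core by index;
    # a plain int (e.g. tpu_cores=1) would instead let XLA pick the core.
    trainer = Trainer(
        tpu_cores=[5],
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.4,
    )
    trainer.fit(EvalModelTemplate())

    # The default XLA device should now match the pinned core,
    # mirroring the assertions in the diff above.
    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'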
