From 838436a04f5bd087238d7941464eb6ca75883ede Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sat, 25 Jul 2020 22:08:11 +0200
Subject: [PATCH] typo

---
 tests/models/test_tpu.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index d0e752c9018ba..92c336e0acf39 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -59,20 +59,20 @@ def test_model_tpu_index_1(tmpdir):
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_tpu_index_7(tmpdir):
+def test_model_tpu_index_5(tmpdir):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
-        tpu_cores=[7],
+        tpu_cores=[5],
         limit_train_batches=0.4,
         limit_val_batches=0.4
     )
 
     model = EvalModelTemplate()
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:8'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@@ -168,18 +168,19 @@ def long_train_loader():
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
-def test_model_16bit_tpu_index_7(tmpdir):
+def test_model_16bit_tpu_index_5(tmpdir):
     """Test if distributed TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
         default_root_dir=tmpdir,
+        precision=16,
         max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
-        tpu_cores=[7],
+        tpu_cores=[5],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:7'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
@@ -213,10 +214,10 @@ def test_early_stop_checkpoints_on_tpu(tmpdir):
         max_epochs=50,
         limit_train_batches=10,
         limit_val_batches=10,
-        tpu_cores=[8],
+        tpu_cores=[5],
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == 'xla:7'
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'
 
 
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
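
For reference, this is how the first patched test reads once the diff is applied. The function body follows the hunk above (with explanatory comments added); the imports, the TPU_AVAILABLE flag, and the pl_multi_process_test decorator are not visible in this diff, so the versions sketched here are assumptions based on how those identifiers are used in the file.

# Sketch: tests/models/test_tpu.py::test_model_tpu_index_5 after this patch.
# Everything outside the function body is assumed, since the diff does not
# show the top of the file.
import pytest
import torch_xla

import tests.base.develop_pipelines as tpipes            # assumed module path for `tpipes`
from tests.base import EvalModelTemplate                 # assumed import path
from tests.base.develop_utils import pl_multi_process_test  # assumed import path

TPU_AVAILABLE = True  # placeholder; the real file derives this from torch_xla


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_index_5(tmpdir):
    """Make sure model trains on TPU."""
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        tpu_cores=[5],            # pin training to TPU core index 5
        limit_train_batches=0.4,
        limit_val_batches=0.4
    )

    model = EvalModelTemplate()
    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
    # With tpu_cores=[5], XLA should report core 5 as the default device;
    # the stale 'xla:8' expectation removed above never matched this setup.
    assert torch_xla._XLAC._xla_get_default_device() == 'xla:5'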