diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b3aebda22cebe..d0e752c9018ba 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -179,7 +179,7 @@ def test_model_16bit_tpu_index_7(tmpdir): tpu_cores=[7], ) trainer.fit(model) - assert torch_xla._XLAC._xla_get_default_device() == 'xla:8' + assert torch_xla._XLAC._xla_get_default_device() == 'xla:7' assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables" @@ -216,7 +216,7 @@ def test_early_stop_checkpoints_on_tpu(tmpdir): tpu_cores=[8], ) trainer.fit(model) - assert torch_xla._XLAC._xla_get_default_device() == 'xla:8' + assert torch_xla._XLAC._xla_get_default_device() == 'xla:7' @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")