diff --git a/pytorch_lightning/accelerator_backends/tpu_backend.py b/pytorch_lightning/accelerator_backends/tpu_backend.py index 95203550eb37e1..8d1d1b271b7dc8 100644 --- a/pytorch_lightning/accelerator_backends/tpu_backend.py +++ b/pytorch_lightning/accelerator_backends/tpu_backend.py @@ -126,7 +126,7 @@ def __save_end_of_training_weights(self, model: LightningModule, trainer): def __setup_tpu_training(self, model: LightningModule, trainer): # use the default device from the process - tpu_device = xm.xla_device() + # tpu_device = xm.xla_device() # if given an ordinal device, use this as the device if trainer.tpu_id is not None: