finished dist (Lightning-AI#911)
williamFalcon authored and tullie committed Apr 3, 2020
commit c57fafa (parent: 4bd73e8)
Showing 3 changed files with 33 additions and 1 deletion.
pytorch_lightning/trainer/data_loading.py (13 additions, 1 deletion)
@@ -25,6 +25,15 @@
 except ImportError:
     APEX_AVAILABLE = False

+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
+    XLA_AVAILABLE = True
+except ImportError:
+    XLA_AVAILABLE = False
+

 class TrainerDataLoadingMixin(ABC):

@@ -217,12 +226,15 @@ def get_dataloaders(self, model):

         # on TPUs load each dataloader only on process 0
         # this will trigger the data downloads
-        if self.use_tpu:
+        if self.use_tpu and XLA_AVAILABLE:
             if self.tpu_local_core_rank == 0:
                 self.get_train_dataloader()
                 self.get_test_dataloaders()
                 self.get_val_dataloaders()

+            # wait for all processes to catch up
+            torch_xla.core.xla_model.rendezvous()
+
         # support IterableDataset for train data
         self.is_iterable_train_dataloader = (
             EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset))
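Note: the change above makes only the TPU process with local core rank 0 build the dataloaders (and therefore trigger any dataset download), after which every process blocks at an XLA rendezvous until that work is complete. A minimal standalone sketch of the same pattern, assuming torch_xla is installed and the function runs in a process started by xmp.spawn; the download_fn callable and the rendezvous tag are illustrative, not part of this commit:

import torch_xla.core.xla_model as xm

def prepare_data(download_fn):
    # only one process per host performs the (possibly slow) download
    if xm.get_local_ordinal() == 0:
        download_fn()
    # every process blocks here until all ranks have arrived, so the
    # non-zero ranks never read a partially downloaded dataset
    xm.rendezvous('download_done')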
pytorch_lightning/trainer/trainer.py (5 additions, 0 deletions)
@@ -983,6 +983,11 @@ def run_pretrain_routine(self, model):
         if self.use_ddp or self.use_ddp2:
             dist.barrier()

+        # wait for all models to restore weights
+        if self.on_tpu and XLA_AVAILABLE:
+            # wait for all processes to catch up
+            torch_xla.core.xla_model.rendezvous()
+
         # set up checkpoint callback
         self.configure_checkpoint_callback()

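Note: run_pretrain_routine now mirrors the existing DDP barrier with an XLA rendezvous, so no process proceeds to checkpoint setup before every process has finished restoring weights. A minimal sketch of that dual synchronization point, assuming torch.distributed has been initialized for the DDP branch and torch_xla is importable for the TPU branch; the flag arguments and tag are illustrative:

import torch.distributed as dist
import torch_xla.core.xla_model as xm

def wait_for_everyone(use_ddp, on_tpu):
    if use_ddp:
        dist.barrier()                      # GPU/DDP ranks sync here
    elif on_tpu:
        xm.rendezvous('weights_restored')   # TPU ranks sync here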
pytorch_lightning/trainer/training_io.py (15 additions, 0 deletions)
@@ -106,6 +106,15 @@
     LightningDataParallel,
 )

+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+
+    XLA_AVAILABLE = True
+except ImportError:
+    XLA_AVAILABLE = False
+

 class TrainerIOMixin(ABC):

@@ -125,6 +134,7 @@ def __init__(self):
         self.early_stop_callback = None
         self.lr_schedulers = None
         self.optimizers = None
+        self.on_tpu = None
         self.num_training_batches = None
         self.accumulate_grad_batches = None

@@ -170,6 +180,11 @@ def restore_weights(self, model):
             # wait for all processes to catch up
             dist.barrier()

+        # wait for all models to restore weights
+        if self.on_tpu and XLA_AVAILABLE:
+            # wait for all processes to catch up
+            torch_xla.core.xla_model.rendezvous()
+
         # clear cache after restore
         if self.on_gpu:
             torch.cuda.empty_cache()
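Note: both modified modules rely on the same optional-dependency guard, so the trainer still imports on machines without torch_xla and the rendezvous call is only reached when the package is present and TPUs are in use. A minimal sketch of that guard, with a hypothetical helper name:

try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

def maybe_rendezvous(on_tpu, tag='sync'):
    # no-op unless we are on TPU and torch_xla imported successfully
    if on_tpu and XLA_AVAILABLE:
        xm.rendezvous(tag)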
