add tests for single scalar return from training
williamFalcon authored and Borda committed Aug 11, 2020
1 parent f1a33cb commit 2ba7d73
Showing 2 changed files with 15 additions and 14 deletions.
23 changes: 12 additions & 11 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -39,7 +39,8 @@ def __init__(self, trainer):
         self.start_method = None
         self.mp_queue = None

-    def setup(self):
+    def setup(self, model):
+        self.trainer.model = model
         rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

         if not XLA_AVAILABLE:
@@ -52,7 +53,7 @@ def setup(self):
         smp = mp.get_context(self.start_method)
         self.mp_queue = smp.SimpleQueue()

-    def teardown(self, model):
+    def teardown(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()
         results = self.mp_queue.get()
@@ -65,28 +66,28 @@ def teardown(self, model):
         # load last weights
         if last_path and not self.trainer.testing:
             ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
+            self.trainer.model.load_state_dict(ckpt)

-        self.trainer.model = model
+        self.trainer.model = self.trainer.model

         # when training completes, load the weights back in main process
         self.__load_weights_on_main_process()
         return results

-    def train(self, model: LightningModule):
-        self.trainer.model = model
-
+    def train(self):
         if self.trainer.can_prepare_data():
-            model.prepare_data()
-            self.trainer._is_data_prepared = True
+            self.trainer.model.prepare_data()
+            self._is_data_prepared = True

         self.trainer.barrier()

         # train
         if self.trainer.tpu_id is not None:
-            self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
+            self.tpu_train_in_process(self.trainer.tpu_id, self.trainer.model, self.trainer, self.mp_queue)
         else:
             xmp.spawn(
                 self.tpu_train_in_process,
-                args=(model, self.trainer, self.mp_queue),
+                args=(self.trainer.model, self.trainer, self.mp_queue),
                 nprocs=self.trainer.tpu_cores,
                 start_method=self.start_method
             )
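The heart of this refactor is that the backend now receives the model once, in setup(model), stores it on the trainer, and lets train() and teardown() read it back from self.trainer.model. The spawned TPU processes hand their outputs (best checkpoint path, then results) back through the multiprocessing SimpleQueue created in setup(), which teardown() drains in the main process. Below is a minimal, Lightning-independent sketch of that queue handoff, using the standard multiprocessing module in place of torch_xla's xmp.spawn; the worker function, the 'spawn' start method, and the fake payloads are illustrative assumptions, not Lightning code.

import multiprocessing as mp


def worker(rank, queue):
    # Stand-in for tpu_train_in_process: each spawned process would train
    # here; only rank 0 reports back, mirroring the single-reader teardown.
    if rank == 0:
        queue.put('/tmp/best.ckpt')   # best checkpoint path
        queue.put({'loss': 0.25})     # training results


def main():
    # setup(): create the queue under an explicit start method
    # (the real backend keeps this as self.start_method).
    smp = mp.get_context('spawn')
    queue = smp.SimpleQueue()

    # train(): fan out worker processes, as xmp.spawn(...) does on TPU cores.
    procs = [smp.Process(target=worker, args=(rank, queue)) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

    # teardown(): the main process drains the queue in the order it was filled.
    best_path = queue.get()
    results = queue.get()
    return best_path, results


if __name__ == '__main__':
    print(main())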
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -1056,9 +1056,9 @@ def fit(

         elif self.use_tpu:
             self.accelerator_backend = TPUBackend(self)
-            self.accelerator_backend.setup()
-            self.accelerator_backend.train(model)
-            self.accelerator_backend.teardown(model)
+            self.accelerator_backend.setup(model)
+            self.accelerator_backend.train()
+            self.accelerator_backend.teardown()

         else:
             self.accelerator_backend = CPUBackend(self)
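On the trainer side, fit() now drives the TPU backend through the same argument-free lifecycle as the other backends: attach the model once in setup(), then call train() and teardown() with no arguments. A self-contained sketch of that calling convention, with stand-in classes rather than Lightning's real Trainer and TPUBackend:

class FakeTrainer:
    """Stand-in for pytorch_lightning.Trainer (illustrative only)."""
    def __init__(self):
        self.model = None


class FakeBackend:
    """Stand-in for TPUBackend, showing the post-commit lifecycle."""
    def __init__(self, trainer):
        self.trainer = trainer

    def setup(self, model):
        # the model is attached exactly once, up front
        self.trainer.model = model

    def train(self):
        # train() no longer takes the model; it reads it off the trainer
        print(f'training {self.trainer.model}')

    def teardown(self):
        # teardown() likewise takes no arguments and returns the results
        return {'status': 'ok'}


trainer = FakeTrainer()
backend = FakeBackend(trainer)
backend.setup('my_model')
backend.train()
results = backend.teardown()
print(results)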
