Skip to content

Commit

Permalink
add tests for single scalar return from training
Browse files · Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored and Borda committed Sep 7, 2020
1 parent 97b9d89 commit 7c7b7cc
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions pytorch_lightning/accelerators/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ def teardown(self):
# load last weights
if last_path and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)
self.trainer.model.load_state_dict(ckpt)

self.trainer.model = model
self.trainer.model = self.trainer.model

# when training completes, load the weights back in main process
self.__load_weights_on_main_process()
Expand All @@ -83,16 +83,18 @@ def train(self):
model = self.trainer.model

if self.trainer.can_prepare_data():
model.prepare_data()
self.trainer._is_data_prepared = True
self.trainer.model.prepare_data()
self._is_data_prepared = True

self.trainer.barrier()

# train
if self.trainer.tpu_id is not None:
self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
self.tpu_train_in_process(self.trainer.tpu_id, self.trainer.model, self.trainer, self.mp_queue)
else:
xmp.spawn(
self.tpu_train_in_process,
args=(model, self.trainer, self.mp_queue),
args=(self.trainer.model, self.trainer, self.mp_queue),
nprocs=self.trainer.tpu_cores,
start_method=self.start_method
)
Expand Down

0 comments on commit 7c7b7cc

Please sign in to comment.