Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use xm.save to save model on TPU #3044

Closed
wants to merge 44 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7bdc8ae
support xm save
lezwon Aug 15, 2020
131ec3c
or to and
lezwon Aug 15, 2020
df49e5c
removed rank_zero_only from on_train_start
lezwon Aug 15, 2020
43d55b6
delete file if exists
lezwon Aug 15, 2020
1313940
added rank_zero_only for _del_model
lezwon Aug 15, 2020
6d2eb42
load state_dict
lezwon Aug 15, 2020
1ee1b05
access state_dict from checkpoint
lezwon Aug 15, 2020
f62c9e5
bytesbuffer to filepath
lezwon Aug 15, 2020
5c696b9
added checkpoint_callback check
lezwon Aug 15, 2020
82f32b7
print function name
lezwon Aug 15, 2020
3449f93
return after saving xla tensor
lezwon Aug 15, 2020
9568af3
log global rank
lezwon Aug 15, 2020
52e4b7c
removed trainer from check_monitor_top_k
lezwon Aug 15, 2020
702f0b9
print current device
lezwon Aug 15, 2020
8746488
rank_zero_only for _do_check_save
lezwon Aug 15, 2020
a123088
log in atomic save
lezwon Aug 15, 2020
7a0b6e5
log in _do_check_save
lezwon Aug 15, 2020
13dcfa1
log in check_monitor_top_k
lezwon Aug 15, 2020
189cc3b
more logs inside check_monitor_top_k
lezwon Aug 15, 2020
ca1da3a
remove rank zero only form _do_check_save
lezwon Aug 15, 2020
aaecd95
added is_global_zero in _do_check_save
lezwon Aug 15, 2020
f302d24
fix del_list issue
lezwon Aug 15, 2020
04ce20e
added global condition check
lezwon Aug 17, 2020
2032128
added pdb
lezwon Aug 17, 2020
0d1611e
removed pdb
lezwon Aug 17, 2020
683a1b7
log best_k_models
lezwon Aug 17, 2020
d81e61a
more changes
lezwon Aug 17, 2020
99ee490
log after save
lezwon Aug 17, 2020
dfdcd76
fix error
lezwon Aug 17, 2020
ca303bc
fix error again
lezwon Aug 17, 2020
de4a117
fix error again and again
lezwon Aug 17, 2020
31e867d
remove device call
lezwon Aug 17, 2020
15eb2b8
ayee add this too. done saving part
lezwon Aug 17, 2020
8784d59
remove state dict
lezwon Aug 17, 2020
79882f0
remove inspect statements
lezwon Aug 17, 2020
b20855b
print stack trace if not is_xla_tensor
lezwon Aug 17, 2020
a4f2c6b
remove inspect reference
lezwon Aug 17, 2020
8f41335
Revert "print stack trace if not is_xla_tensor"
lezwon Aug 19, 2020
9dbfdbf
removed state dict
lezwon Aug 21, 2020
85c5b55
added is_xla_tensor into docstring
lezwon Aug 21, 2020
d488119
save only state dict
lezwon Aug 21, 2020
12d675d
save entire checkpoint
lezwon Aug 22, 2020
c1d8f74
load state dict
lezwon Aug 22, 2020
c883572
xla check inside atomic save
lezwon Sep 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,14 @@ def teardown(self):
last_path = self.mp_queue.get()

# transfer back the best path to the trainer
self.trainer.checkpoint_callback.best_model_path = best_path
if self.trainer.checkpoint_callback is not None:
self.trainer.checkpoint_callback.best_model_path = best_path
# todo, pass also bets score

# load last weights
if last_path and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)
model.load_state_dict(ckpt['state_dict'])

self.trainer.model = model

Expand Down
42 changes: 21 additions & 21 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def kth_best_model(self):
" and will be removed in v0.10.0", DeprecationWarning)
return self.kth_best_model_path

@rank_zero_only
def _del_model(self, filepath):
if self._fs.exists(filepath):
self._fs.rm(filepath)
Expand Down Expand Up @@ -261,7 +262,6 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
ckpt_name = f'{filename}.ckpt'
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

@rank_zero_only
def on_pretrain_routine_start(self, trainer, pl_module):
"""
Determines model checkpoint save directory at runtime. References attributes from the
Expand Down Expand Up @@ -315,11 +315,8 @@ def __warn_deprecated_monitor_key(self):
f" Remove `ModelCheckpoint(monitor='{self.monitor}')` to fix."
)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.global_rank != 0:
return

if trainer.running_sanity_check:
return
Expand Down Expand Up @@ -379,7 +376,6 @@ def on_validation_end(self, trainer, pl_module):
if self.verbose > 0:
log.info(f'Epoch {epoch:d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath, trainer, pl_module)

if self.save_last:
Expand All @@ -395,22 +391,26 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
# remove kth

del_list = []
if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
delpath = self.kth_best_model_path
self.best_k_models.pop(self.kth_best_model_path)
del_list.append(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
self.kth_best_model_path = _op(self.best_k_models,
key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model_path]

_op = min if self.mode == 'min' else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.best_model_score = self.best_k_models[self.best_model_path]

if trainer.is_global_zero:
if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
delpath = self.kth_best_model_path
self.best_k_models.pop(self.kth_best_model_path)
del_list.append(delpath)

self.best_k_models[filepath] = current
if len(self.best_k_models) == self.save_top_k:
# monitor dict has reached k elements
_op = max if self.mode == 'min' else min
self.kth_best_model_path = _op(self.best_k_models,
key=self.best_k_models.get)
self.kth_value = self.best_k_models[self.kth_best_model_path]

_op = min if self.mode == 'min' else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.best_model_score = self.best_k_models[self.best_model_path]

print(inspect.currentframe().f_code.co_name + f' Line 401 rank: {trainer.global_rank}')

if self.verbose > 0:
log.info(
Expand Down
23 changes: 19 additions & 4 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
import torch
import fsspec

# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all

try:
import torch_xla.core.xla_model as xm
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True

pathlike = Union[Path, str]

Expand Down Expand Up @@ -49,12 +59,17 @@ def atomic_save(checkpoint, filepath: str):
accepts.
filepath: The path to which the checkpoint will be saved.
This points to the file that the checkpoint will be stored in.
is_xla_tensor: If the tensor to be saved in an XLA Tensor
Is true if the model is being trained on a TPU
"""
bytesbuffer = io.BytesIO()
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:

if checkpoint.device.type == "xla" and XLA_AVAILABLE:
return xm.save(checkpoint, filepath, master_only=True, global_master=True)
elif LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
torch.save(checkpoint, bytesbuffer, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, bytesbuffer)
Expand Down