Fix (weights only) checkpoints loading without pl (#3287)
* cast pl AttributeDict to dict

* fix for omegaconf
s-rog authored Sep 2, 2020
1 parent f747cb6 commit 65e6687
Showing 1 changed file with 3 additions and 1 deletion.
pytorch_lightning/trainer/training_io.py
@@ -370,10 +370,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if hasattr(model, '_hparams_name'):
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
         # add arguments to the checkpoint
-        checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
         if OMEGACONF_AVAILABLE:
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
             if isinstance(model.hparams, Container):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+        else:
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
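In practical terms, storing `model.hparams` as a plain builtin `dict` (instead of Lightning's `AttributeDict`) means a weights-only checkpoint can be unpickled on a machine where pytorch_lightning is not importable. A minimal sketch of the effect, assuming the checkpoint key is "hyper_parameters" (the value of `LightningModule.CHECKPOINT_HYPER_PARAMS_KEY` at the time of this commit) and using a hypothetical file name:

# Minimal sketch: load a weights-only checkpoint with plain PyTorch.
# "weights_only.ckpt" is a hypothetical path; "hyper_parameters" is
# assumed to be LightningModule.CHECKPOINT_HYPER_PARAMS_KEY.
import torch

ckpt = torch.load("weights_only.ckpt", map_location="cpu")

# Before this fix, this entry was a pl AttributeDict, so unpickling it
# required pytorch_lightning to be installed. After the cast in the
# non-omegaconf branch above, it is a plain dict.
hparams = ckpt["hyper_parameters"]
print(type(hparams))  # <class 'dict'>

state_dict = ckpt["state_dict"]  # the weights themselves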
