RuntimeError: Error(s) in loading state_dict when adding/updating metrics to a trained model. #4666
Reading further about state_dicts, I noted that, based on the PyTorch docs, a state_dict should store the model's parameters and hyper-parameters. I'm not sure why there is metrics-related data stored in this dict.
https://github.com/PyTorchLightning/pytorch-lightning/blob/baa8558cc0e6d2a3e24f2669e6a59ffdb8138737/pytorch_lightning/metrics/metric.py#L88-L90
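For reference, the effect of the linked lines is easy to see on a bare LightningModule: the metric's internal states get saved together with the module. A minimal sketch (the attribute name train_acc is just an example):

```python
# Minimal sketch: attaching a class-based metric adds its states
# ('correct', 'total') to the module's own state_dict.
import pytorch_lightning as pl

class Dummy(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_acc = pl.metrics.Accuracy()

print(Dummy().state_dict().keys())
# expected to show something like:
# odict_keys(['train_acc.correct', 'train_acc.total'])
```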
Well, thanks @rohitgr7, I'll try changing this parameter and report back.
Maybe try after removing these keys from the checkpoint?
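For anyone following along, one quick way to see which keys the metrics added to a saved checkpoint is a sketch like this (the path is a placeholder; the .correct/.total suffixes come from the Accuracy states discussed above):

```python
# Sketch: list the metric-related keys inside a saved checkpoint file.
import torch

ckpt = torch.load('path/to/model.ckpt', map_location='cpu')  # placeholder path
metric_keys = [k for k in ckpt['state_dict'] if k.endswith(('.correct', '.total'))]
print(metric_keys)  # e.g. ['train_acc.correct', 'train_acc.total', ...]
```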
After removing each metric's data from the checkpoint, at least the main issue seems fixed; now I'm dealing with this one.
Surprisingly, I had to change line 99 of the file. And now the test loop is running, at least without the new metrics. Thank you for your help. PS: I have other issues that prevent me from evaluating Recall, Precision, F1Score, and the confusion matrix.
I'll keep debugging and report any news. In the worst case, I can log the
This is basically a duplicate (at least the initial problem) of #4361.

```python
metric = pl.metrics.Accuracy()
print(metric.state_dict())  # prints OrderedDict([('correct', tensor(0)), ('total', tensor(0))])
metric.persistent(False)
print(metric.state_dict())  # prints OrderedDict()
```
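Applied to this thread's setup, that would mean calling persistent(False) on each metric when the module is built, so the metric states never end up in future checkpoints. A sketch, with example attribute names:

```python
# Sketch: mark metric states as non-persistent at construction time so they
# are excluded from the checkpoint's state_dict (attribute names are examples).
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_acc = pl.metrics.Accuracy()
        self.val_acc = pl.metrics.Accuracy()
        self.test_acc = pl.metrics.Accuracy()
        for metric in (self.train_acc, self.val_acc, self.test_acc):
            metric.persistent(False)
```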
Thanks for the help! Anyway, here is the script I used to fix the checkpoints:

```python
"""
Remove metric records from a checkpoint's state_dict.
"""
import argparse
from pathlib import Path

import torch

# metric state keys added by the class-based Accuracy metrics
state_dict_keys_to_remove = ['test_acc.total', 'test_acc.correct',
                             'val_acc.total', 'val_acc.correct',
                             'train_acc.total', 'train_acc.correct']


def main():
    parser = argparse.ArgumentParser(description='Remove metric states from a checkpoint.')
    parser.add_argument('--src_path', help='path to the original checkpoint')
    parser.add_argument('--dest_path', help='path to save the cleaned checkpoint')
    args = parser.parse_args()

    src_path = Path(args.src_path)
    dest_path = Path(args.dest_path)

    ckpt = torch.load(src_path)
    print('info: state_dict keys: {}'.format(ckpt['state_dict'].keys()))
    for k in state_dict_keys_to_remove:
        del ckpt['state_dict'][k]
    torch.save(ckpt, dest_path)


if __name__ == '__main__':
    main()
```
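A slightly more generic variant of the same idea, in case it helps: drop every key under the metric attributes' prefixes instead of listing each one (paths and prefixes below are placeholders):

```python
# Sketch: strip metric entries by attribute prefix instead of naming each key.
# Paths and the prefix list are placeholders for your own setup.
import torch

METRIC_PREFIXES = ('train_acc.', 'val_acc.', 'test_acc.')

ckpt = torch.load('model.ckpt', map_location='cpu')
ckpt['state_dict'] = {
    k: v for k, v in ckpt['state_dict'].items()
    if not k.startswith(METRIC_PREFIXES)
}
torch.save(ckpt, 'model_clean.ckpt')
```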
@Vichoko no problem. Personally I was in favor of having metric states be part of the state dict, but when enough people raise the concern I am of course willing to change my opinion :]
@SkafteNicki That's very laudable of you ^^ I still wonder in which cases it would be useful to store the metric states in the state_dict.
Try
Not working
❓ Questions and Help
For context, I trained a lot of models for many weeks, tracking the loss and accuracy for train, validation, and test steps.
Now I wanted to evaluate more metrics on the test data set. More specifically, I added recall, confusion matrix, and precision metrics (from the ptl.metrics module) to the test_step and test_epoch_end methods in the lightning module. I also replaced my custom accuracy with the class-based Accuracy implemented in the ptl.metrics package. When I try to test my model and get the metrics for the trained model on the test set, I get this error loading the checkpoint (the RuntimeError in the title).
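For concreteness, the kind of change described above would look roughly like the sketch below; the class name, attribute names, and metric choices are assumptions, not taken verbatim from the thread.

```python
# Rough sketch of adding class-based metrics to the test loop
# (class and attribute names are assumptions, not from the thread).
# forward(), dataloaders, and the rest of the module are omitted.
import torch
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_acc = pl.metrics.Accuracy()
        self.test_recall = pl.metrics.Recall()
        self.test_precision = pl.metrics.Precision()

    def test_step(self, batch, batch_idx):
        x, y = batch
        preds = torch.argmax(self(x), dim=1)
        # update the metric states; these states are what end up in the state_dict
        self.test_acc(preds, y)
        self.test_recall(preds, y)
        self.test_precision(preds, y)

    def test_epoch_end(self, outputs):
        self.log('test_acc', self.test_acc.compute())
        self.log('test_recall', self.test_recall.compute())
        self.log('test_precision', self.test_precision.compute())
```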
What is your question?
In my case, it's impossible to train the models again because it takes many weeks. So I just wonder if there is a way to load the already trained model anyway and obtain the updated test metrics by running a test cycle.
Actually, I just care about loading the parameters of the model to run the test cycle. I can't understand why it's so important to load everything else; those old metrics don't seem so vital to me.
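One possible workaround along those lines (a sketch, not an official recipe; MyLightningModule and the checkpoint path are placeholders): restore only the matching weights with load_state_dict(strict=False) and run the test loop on the restored model instead of going through resume_from_checkpoint.

```python
# Sketch of a possible workaround (class name and checkpoint path are placeholders):
# load only the weights that still match, then run the test loop directly.
import torch
import pytorch_lightning as pl

model = MyLightningModule()  # the updated module with the new metrics
ckpt = torch.load('path/to/model.ckpt', map_location='cpu')

# strict=False skips keys that are missing or unexpected (e.g. new metric states)
missing, unexpected = model.load_state_dict(ckpt['state_dict'], strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)

trainer = pl.Trainer()
trainer.test(model)
```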
What have you tried?
I read that this exception is generated by torch's model.load_state_dict method and can be avoided with the strict=False parameter. In my case I load the trained model with the resume_from_checkpoint parameter of the PyTorch Lightning Trainer class, so I have no clue how to pass that parameter through when loading.

What's your environment?