-
Notifications
You must be signed in to change notification settings - Fork 11
/
checkpoints.py
56 lines (47 loc) · 1.82 KB
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import urllib
import torch
from torch.utils import model_zoo
class CheckpointNotFoundError(FileNotFoundError, FileExistsError):
    """Raised when a requested checkpoint file does not exist.

    Derives from ``FileNotFoundError`` because that is what the condition
    actually is, and from ``FileExistsError`` because that is the type this
    module historically raised, so existing ``except FileExistsError``
    handlers keep working.
    """


class CheckpointIO(object):
    """Save and restore the state of a set of named modules.

    Every registered object must provide ``state_dict()`` and
    ``load_state_dict()``. Relative filenames are resolved against
    ``checkpoint_dir``; absolute filenames are used unchanged.
    """

    def __init__(self, checkpoint_dir='./chkpts', **kwargs):
        """Create the checkpoint directory if needed and register modules.

        Args:
            checkpoint_dir: Directory used to resolve relative filenames.
            **kwargs: Initial ``name=module`` pairs to register.
        """
        self.module_dict = kwargs
        self.checkpoint_dir = checkpoint_dir
        if not os.path.exists(checkpoint_dir):
            # exist_ok guards against another process creating the directory
            # between the existence check and this call.
            os.makedirs(checkpoint_dir, exist_ok=True)

    def register_modules(self, **kwargs):
        """Add (or replace) ``name=module`` pairs tracked by this object."""
        self.module_dict.update(kwargs)

    def save(self, filename, **kwargs):
        """Write the state of all registered modules to ``filename``.

        Args:
            filename: Target path; relative paths are placed inside
                ``checkpoint_dir``.
            **kwargs: Extra entries stored alongside the module states.
                An entry whose key equals a registered module name is
                overwritten by that module's state.
        """
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)

        outdict = kwargs
        for name, module in self.module_dict.items():
            outdict[name] = module.state_dict()
        torch.save(outdict, filename)

    def load(self, filename):
        """Load a checkpoint; thin alias for :meth:`load_file`."""
        return self.load_file(filename)

    def load_file(self, filename):
        """Load a checkpoint from disk into the registered modules.

        Args:
            filename: Checkpoint path; relative paths are resolved against
                ``checkpoint_dir``.

        Returns:
            A dict with every checkpoint entry that does not belong to a
            registered module.

        Raises:
            CheckpointNotFoundError: If the file does not exist. This is
                both a ``FileNotFoundError`` and a ``FileExistsError``.
        """
        if not os.path.isabs(filename):
            filename = os.path.join(self.checkpoint_dir, filename)

        if not os.path.exists(filename):
            raise CheckpointNotFoundError(
                'Checkpoint file not found: %s' % filename)

        print(filename)
        print('=> Loading checkpoint from local file...')
        # NOTE(security): torch.load is pickle-based; only load checkpoints
        # that come from a trusted source.
        state_dict = torch.load(filename)
        return self.parse_state_dict(state_dict)

    def parse_state_dict(self, state_dict):
        """Restore registered modules from ``state_dict``.

        Keys containing ``"mano_layer"`` are removed (in place) from each
        module's entry before it is passed to ``load_state_dict``. A warning
        is printed for every registered module missing from ``state_dict``.

        Args:
            state_dict: Mapping as produced by :meth:`save`.

        Returns:
            A dict with the entries of ``state_dict`` whose keys are not
            registered module names.
        """
        for name, module in self.module_dict.items():
            if name not in state_dict:
                print('Warning: Could not find %s in checkpoint!' % name)
                continue

            module_state = state_dict[name]
            # Collect the keys first: the mapping must not change size while
            # it is being iterated.
            dropped = [key for key in module_state if "mano_layer" in key]
            for key in dropped:
                module_state.pop(key)
            module.load_state_dict(module_state)

        scalars = {k: v for k, v in state_dict.items()
                   if k not in self.module_dict}
        return scalars