forked from fepegar/highresnet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
hubconf.py
37 lines (32 loc) · 1.07 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
# Packages that torch.hub must be able to import before it loads any
# entrypoint from this hubconf (see the torch.hub documentation).
dependencies = [
    'torch',
]
def highres2dnet(*args, **kwargs):
    """Return a 2D HighResNet modelled on HighRes3DNet.

    HighRes3DNet was introduced by Li et al. (2017) for brain parcellation
    of T1-weighted MRI; ``HighRes2DNet`` is its two-dimensional
    counterpart.

    Every positional and keyword argument is passed straight through to
    the ``highresnet.HighRes2DNet`` constructor, and the new instance is
    returned.
    """
    # Deferred import: highresnet is only needed once this entrypoint runs.
    from highresnet import HighRes2DNet
    return HighRes2DNet(*args, **kwargs)
def highres3dnet(*args, pretrained=False, **kwargs):
    """
    HighRes3DNet by Li et al. 2017 for T1-MRI brain parcellation

    Args:
        *args: Positional arguments forwarded to ``highresnet.HighRes3DNet``.
        pretrained (bool): load parameters from pretrained model. When
            ``True`` the network is built with ``in_channels=1``,
            ``out_channels=160`` and ``add_dropout_layer=True``, and the
            published parameters are downloaded with
            ``torch.hub.load_state_dict_from_url`` (mapped to the CPU) and
            loaded into it.
        **kwargs: Keyword arguments forwarded to ``highresnet.HighRes3DNet``.
            With ``pretrained=True`` they may repeat the fixed settings
            listed above, but only with the same values.

    Returns:
        The constructed ``HighRes3DNet`` model.

    Raises:
        TypeError: If ``pretrained=True`` and ``kwargs`` gives one of the
            fixed settings a different value.
    """
    if not pretrained:
        # Deferred import: highresnet is only needed once this entrypoint runs.
        from highresnet import HighRes3DNet
        return HighRes3DNet(*args, **kwargs)

    # Settings the pretrained parameters are loaded against.
    # load_state_dict is strict by default, so the architecture presumably
    # has to match the checkpoint exactly.
    pretrained_config = {
        'in_channels': 1,
        'out_channels': 160,
        'add_dropout_layer': True,
    }
    # Validate before importing or downloading anything.
    # A caller repeating a fixed setting with the same value is accepted.
    # A conflicting value fails with a readable message instead of Python's
    # "got multiple values for keyword argument" error.
    for name, required in pretrained_config.items():
        if name in kwargs and kwargs[name] != required:
            raise TypeError(
                'highres3dnet(pretrained=True) requires {}={!r}, got {!r}'
                .format(name, required, kwargs[name]))
    model_kwargs = dict(kwargs, **pretrained_config)

    from highresnet import HighRes3DNet
    model = HighRes3DNet(*args, **model_kwargs)
    url_dir = 'https://github.com/fepegar/highresnet-models/raw/master'
    url = '{}/highres3dnet_li_parameters-7d297872.pth'.format(url_dir)
    state_dict = torch.hub.load_state_dict_from_url(
        url, progress=False, map_location='cpu')
    model.load_state_dict(state_dict)
    return model