-
Notifications
You must be signed in to change notification settings - Fork 70
/
Copy pathbackbones.py
62 lines (57 loc) · 3.51 KB
/
backbones.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import timm # noqa
import torch
import torchvision.models as models # noqa
def load_ref_wrn50():
    """Build the reference Wide-ResNet-50-2 from the local ``resnet`` module.

    The import is deferred to call time, so the local ``resnet`` module is only
    required when this particular backbone is requested. The single positional
    ``True`` is forwarded unchanged (presumably the ``pretrained`` flag of
    ``resnet.wide_resnet50_2`` — TODO confirm against that module).
    """
    import resnet as _resnet_module

    build_backbone = _resnet_module.wide_resnet50_2
    return build_backbone(True)
# Registry mapping a short backbone name to a Python expression that builds it.
# Each expression is evaluated by ``load`` below. Expressions rooted at a name
# that is not defined in this module (e.g. ``cait``, ``pretrainedmodels``) are
# resolved by importing that module lazily, on first use.
_BACKBONES = {
    "cait_s24_224": "cait.cait_S24_224(True)",
    "cait_xs24": "cait.cait_XS24(True)",
    "alexnet": "models.alexnet(pretrained=True)",
    "bninception": 'pretrainedmodels.__dict__["bninception"]'
    '(pretrained="imagenet", num_classes=1000)',
    "resnet18": "models.resnet18(pretrained=True)",
    "resnet50": "models.resnet50(pretrained=True)",
    # NOTE(review): ``load_mc3_rn50`` is not defined in this module, so
    # ``load("mc3_resnet50")`` raises NameError unless it is provided elsewhere.
    "mc3_resnet50": "load_mc3_rn50()",
    "resnet101": "models.resnet101(pretrained=True)",
    "resnext101": "models.resnext101_32x8d(pretrained=True)",
    "resnet200": 'timm.create_model("resnet200", pretrained=True)',
    "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)',
    "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)',
    "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)',
    "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)',
    "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)',
    "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)',
    "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)',
    "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)',
    "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)',
    "vgg11": "models.vgg11(pretrained=True)",
    "vgg19": "models.vgg19(pretrained=True)",
    "vgg19_bn": "models.vgg19_bn(pretrained=True)",
    "wideresnet50": "models.wide_resnet50_2(pretrained=True)",
    "ref_wideresnet50": "load_ref_wrn50()",
    "wideresnet101": "models.wide_resnet101_2(pretrained=True)",
    "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)',
    "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)',
    "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)',
    "densenet121": 'timm.create_model("densenet121", pretrained=True)',
    "densenet201": 'timm.create_model("densenet201", pretrained=True)',
    "inception_v4": 'timm.create_model("inception_v4", pretrained=True)',
    "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)',
    "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)',
    "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)',
    "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)',
    "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)',
    "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)',
    "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)',
    "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)',
    "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)',
    "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)',
    "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)',
    "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)',
    "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)',
    "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)',
    "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)',
}


def load(name):
    """Build the backbone registered under ``name`` in ``_BACKBONES``.

    The registered expression is evaluated against this module's namespace.
    If the expression's leading identifier (e.g. ``cait`` or
    ``pretrainedmodels``) is not defined here, it is imported as a module on
    demand. Such optional dependencies are therefore only needed when their
    backbone is actually requested.

    Args:
        name: A key of ``_BACKBONES``.

    Returns:
        Whatever the registered constructor expression evaluates to.

    Raises:
        KeyError: ``name`` is not a registered backbone.
        NameError: the expression's leading identifier is neither defined in
            this module nor importable as a module.
    """
    import importlib
    import re

    try:
        expression = _BACKBONES[name]
    except KeyError:
        choices = ", ".join(sorted(_BACKBONES))
        raise KeyError(f"Unknown backbone {name!r}; choose one of: {choices}") from None

    # Work on a copy so lazily imported modules never leak into module globals.
    namespace = dict(globals())
    leading = re.match(r"[A-Za-z_]\w*", expression)
    root = leading.group() if leading else None
    if root is not None and root not in namespace:
        try:
            namespace[root] = importlib.import_module(root)
        except ImportError as err:
            # Keep the historical exception type (eval would raise NameError),
            # but say which symbol is missing and chain the import failure.
            raise NameError(
                f"Backbone {name!r} needs {root!r}, which is neither defined "
                "in this module nor importable"
            ) from err

    # The expressions are fixed literals from _BACKBONES; ``name`` only selects
    # one of them and is never evaluated itself.
    return eval(expression, namespace)