1818
1919
2020class QuantizableSqueezeExcitation (SElayer ):
21- _version = 2
22-
2321 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
2422 kwargs ["scale_activation" ] = nn .Hardswish
2523 super ().__init__ (* args , ** kwargs )
@@ -31,42 +29,6 @@ def forward(self, input: Tensor) -> Tensor:
3129 def fuse_model (self ) -> None :
3230 fuse_modules (self , ['fc1' , 'activation' ], inplace = True )
3331
34- def _load_from_state_dict (
35- self ,
36- state_dict ,
37- prefix ,
38- local_metadata ,
39- strict ,
40- missing_keys ,
41- unexpected_keys ,
42- error_msgs ,
43- ):
44- version = local_metadata .get ("version" , None )
45-
46- if version is None or version < 2 :
47- default_state_dict = {
48- "scale_activation.activation_post_process.scale" : torch .tensor ([1. ]),
49- "scale_activation.activation_post_process.zero_point" : torch .tensor ([0 ], dtype = torch .int32 ),
50- "scale_activation.activation_post_process.fake_quant_enabled" : torch .tensor ([1 ]),
51- "scale_activation.activation_post_process.observer_enabled" : torch .tensor ([1 ]),
52- "scale_activation.activation_post_process.activation_post_process.min_val" : torch .tensor (float ('inf' )),
53- "scale_activation.activation_post_process.activation_post_process.max_val" : torch .tensor (- float ('inf' )),
54- }
55- for k , v in default_state_dict .items ():
56- full_key = prefix + k
57- if full_key not in state_dict :
58- state_dict [full_key ] = v
59-
60- super ()._load_from_state_dict (
61- state_dict ,
62- prefix ,
63- local_metadata ,
64- strict ,
65- missing_keys ,
66- unexpected_keys ,
67- error_msgs ,
68- )
69-
7032
7133class QuantizableInvertedResidual (InvertedResidual ):
7234 # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
@@ -118,12 +80,13 @@ def _load_weights(
11880 arch : str ,
11981 model : QuantizableMobileNetV3 ,
12082 model_url : Optional [str ],
121- progress : bool
83+ progress : bool ,
84+ strict : bool = True
12285) -> None :
12386 if model_url is None :
12487 raise ValueError ("No checkpoint is available for {}" .format (arch ))
12588 state_dict = load_state_dict_from_url (model_url , progress = progress )
126- model .load_state_dict (state_dict )
89+ model .load_state_dict (state_dict , strict = strict )
12790
12891
12992def _mobilenet_v3_model (
@@ -147,7 +110,7 @@ def _mobilenet_v3_model(
147110 torch .quantization .prepare_qat (model , inplace = True )
148111
149112 if pretrained :
150- _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress )
113+ _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress , strict = False )
151114
152115 torch .quantization .convert (model , inplace = True )
153116 model .eval ()
0 commit comments