1919
2020class QuantizableSqueezeExcitation (SElayer ):
2121 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
22+ kwargs ["scale_activation" ]= nn .Hardswish
2223 super ().__init__ (* args , ** kwargs )
2324 self .skip_mul = nn .quantized .FloatFunctional ()
2425
@@ -80,11 +81,12 @@ def _load_weights(
8081 model : QuantizableMobileNetV3 ,
8182 model_url : Optional [str ],
8283 progress : bool ,
84+ strict : bool
8385) -> None :
8486 if model_url is None :
8587 raise ValueError ("No checkpoint is available for {}" .format (arch ))
8688 state_dict = load_state_dict_from_url (model_url , progress = progress )
87- model .load_state_dict (state_dict )
89+ model .load_state_dict (state_dict , strict = strict )
8890
8991
9092def _mobilenet_v3_model (
@@ -108,13 +110,13 @@ def _mobilenet_v3_model(
108110 torch .quantization .prepare_qat (model , inplace = True )
109111
110112 if pretrained :
111- _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress )
113+ _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress , False )
112114
113115 torch .quantization .convert (model , inplace = True )
114116 model .eval ()
115117 else :
116118 if pretrained :
117- _load_weights (arch , model , model_urls .get (arch , None ), progress )
119+ _load_weights (arch , model , model_urls .get (arch , None ), progress , True )
118120
119121 return model
120122
0 commit comments