Commit b04fa9a

Revert "Set default params if missing."
This reverts commit b491fa2
1 parent: b491fa2 · commit: b04fa9a

File tree

1 file changed: +4, -41 lines


torchvision/models/quantization/mobilenetv3.py

Lines changed: 4 additions & 41 deletions
@@ -18,8 +18,6 @@
 
 
 class QuantizableSqueezeExcitation(SElayer):
-    _version = 2
-
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["scale_activation"] = nn.Hardswish
         super().__init__(*args, **kwargs)
@@ -31,42 +29,6 @@ def forward(self, input: Tensor) -> Tensor:
     def fuse_model(self) -> None:
         fuse_modules(self, ['fc1', 'activation'], inplace=True)
 
-    def _load_from_state_dict(
-        self,
-        state_dict,
-        prefix,
-        local_metadata,
-        strict,
-        missing_keys,
-        unexpected_keys,
-        error_msgs,
-    ):
-        version = local_metadata.get("version", None)
-
-        if version is None or version < 2:
-            default_state_dict = {
-                "scale_activation.activation_post_process.scale": torch.tensor([1.]),
-                "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32),
-                "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]),
-                "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]),
-                "scale_activation.activation_post_process.activation_post_process.min_val": torch.tensor(float('inf')),
-                "scale_activation.activation_post_process.activation_post_process.max_val": torch.tensor(-float('inf')),
-            }
-            for k, v in default_state_dict.items():
-                full_key = prefix + k
-                if full_key not in state_dict:
-                    state_dict[full_key] = v
-
-        super()._load_from_state_dict(
-            state_dict,
-            prefix,
-            local_metadata,
-            strict,
-            missing_keys,
-            unexpected_keys,
-            error_msgs,
-        )
-
 
 class QuantizableInvertedResidual(InvertedResidual):
     # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659
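The deletions in the first two hunks undo the checkpoint-versioning approach from b491fa2: the _version class attribute is written into the checkpoint metadata when state_dict() is called, and _load_from_state_dict reads it back through local_metadata to decide whether an older checkpoint needs the default observer entries seeded. A minimal sketch of that nn.Module mechanism, using a made-up Demo module that only inspects the version (not part of torchvision):

import torch.nn as nn

class Demo(nn.Module):
    # state_dict() records this class attribute in the checkpoint metadata,
    # so checkpoints saved before the attribute was bumped report an older
    # (or missing) version when they are loaded back.
    _version = 2

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # The reverted commit patched state_dict in place here, adding
            # default entries for keys that old checkpoints do not contain.
            pass
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)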
@@ -118,12 +80,13 @@ def _load_weights(
     arch: str,
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
-    progress: bool
+    progress: bool,
+    strict: bool = True
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
     state_dict = load_state_dict_from_url(model_url, progress=progress)
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict, strict=strict)
 
 
 def _mobilenet_v3_model(
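With the revert, _load_weights once again exposes a strict flag and forwards it to load_state_dict, so a checkpoint that lacks the newer scale_activation observer buffers loads without raising. For reference, PyTorch's load_state_dict with strict=False reports mismatched keys instead of rejecting them; a small self-contained sketch (the Sequential model and the deliberately incomplete checkpoint are made up for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())   # stand-in model
checkpoint = {"0.weight": torch.zeros(4, 4)}         # missing "0.bias" on purpose

# strict=False tolerates missing/unexpected keys and returns them instead of raising
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)      # ['0.bias']
print(result.unexpected_keys)   # []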
@@ -147,7 +110,7 @@ def _mobilenet_v3_model(
         torch.quantization.prepare_qat(model, inplace=True)
 
         if pretrained:
-            _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
+            _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, strict=False)
 
         torch.quantization.convert(model, inplace=True)
         model.eval()
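After the revert, _mobilenet_v3_model again passes strict=False when it loads the pretrained quantized checkpoint, just before converting the model and switching it to eval mode. A typical end-to-end call through the public builder, assuming a torchvision release that ships this file (the input size and batch are arbitrary):

import torch
from torchvision.models.quantization import mobilenet_v3_large

# Builds the QAT-prepared model, loads the quantized checkpoint non-strictly,
# converts it and puts it in eval mode, as in the diff above.
model = mobilenet_v3_large(pretrained=True, quantize=True)

with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])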
