66from torchvision .ops .feature_pyramid_network import ExtraFPNBlock , FeaturePyramidNetwork , LastLevelMaxPool
77
88from .. import mobilenet , resnet
9- from .._utils import IntermediateLayerGetter
9+ from .._api import WeightsEnum
10+ from .._utils import IntermediateLayerGetter , handle_legacy_interface
1011
1112
1213class BackboneWithFPN (nn .Module ):
@@ -55,9 +56,13 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
5556 return x
5657
5758
59+ @handle_legacy_interface (
60+ weights = ("pretrained" , True ), # type: ignore[arg-type]
61+ )
5862def resnet_fpn_backbone (
63+ * ,
5964 backbone_name : str ,
60- pretrained : bool ,
65+ weights : Optional [ WeightsEnum ] ,
6166 norm_layer : Callable [..., nn .Module ] = misc_nn_ops .FrozenBatchNorm2d ,
6267 trainable_layers : int = 3 ,
6368 returned_layers : Optional [List [int ]] = None ,
@@ -69,7 +74,7 @@ def resnet_fpn_backbone(
6974 Examples::
7075
7176 >>> from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
72- >>> backbone = resnet_fpn_backbone('resnet50', pretrained=True , trainable_layers=3)
77+ >>> backbone = resnet_fpn_backbone('resnet50', weights=ResNet50_Weights.DEFAULT , trainable_layers=3)
7378 >>> # get some dummy image
7479 >>> x = torch.rand(1,3,64,64)
7580 >>> # compute the output
@@ -85,7 +90,7 @@ def resnet_fpn_backbone(
8590 Args:
8691 backbone_name (string): resnet architecture. Possible values are 'resnet18', 'resnet34', 'resnet50',
8792 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'
88- pretrained (bool ): If True, returns a model with backbone pre-trained on Imagenet
93+ weights (WeightsEnum, optional ): The pretrained weights for the model
8994 norm_layer (callable): it is recommended to use the default value. For details visit:
9095 (https://github.com/facebookresearch/maskrcnn-benchmark/issues/267)
9196 trainable_layers (int): number of trainable (not frozen) layers starting from final block.
@@ -98,7 +103,7 @@ def resnet_fpn_backbone(
98103 a new list of feature maps and their corresponding names. By
99104 default a ``LastLevelMaxPool`` is used.
100105 """
101- backbone = resnet .__dict__ [backbone_name ](pretrained = pretrained , norm_layer = norm_layer )
106+ backbone = resnet .__dict__ [backbone_name ](weights = weights , norm_layer = norm_layer )
102107 return _resnet_fpn_extractor (backbone , trainable_layers , returned_layers , extra_blocks )
103108
104109
@@ -135,13 +140,13 @@ def _resnet_fpn_extractor(
135140
136141
137142def _validate_trainable_layers (
138- pretrained : bool ,
143+ is_trained : bool ,
139144 trainable_backbone_layers : Optional [int ],
140145 max_value : int ,
141146 default_value : int ,
142147) -> int :
143148 # don't freeze any layers if pretrained model or backbone is not used
144- if not pretrained :
149+ if not is_trained :
145150 if trainable_backbone_layers is not None :
146151 warnings .warn (
147152 "Changing trainable_backbone_layers has not effect if "
@@ -160,16 +165,20 @@ def _validate_trainable_layers(
160165 return trainable_backbone_layers
161166
162167
@handle_legacy_interface(
    weights=("pretrained", True),  # type: ignore[arg-type]
)
def mobilenet_backbone(
    *,
    backbone_name: str,
    weights: Optional[WeightsEnum],
    fpn: bool,
    norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d,
    trainable_layers: int = 2,
    returned_layers: Optional[List[int]] = None,
    extra_blocks: Optional[ExtraFPNBlock] = None,
) -> nn.Module:
    """Construct a MobileNet-based backbone for detection models.

    The legacy ``pretrained=True`` boolean argument is remapped onto ``weights``
    by the ``handle_legacy_interface`` decorator, so older call sites keep working.

    Args:
        backbone_name (string): mobilenet architecture name, resolved against
            ``torchvision.models.mobilenet``.
        weights (WeightsEnum, optional): the pretrained weights to load into the
            backbone, or ``None`` for random initialization.
        fpn (bool): whether the extractor should attach a feature pyramid on top
            of the backbone (forwarded to ``_mobilenet_extractor``).
        norm_layer (callable): normalization layer factory; the frozen default
            is the recommended choice for detection backbones.
        trainable_layers (int): number of trainable (not frozen) layers starting
            from the final block.
        returned_layers (List[int], optional): which layers' activations to
            return; ``None`` selects the extractor's default.
        extra_blocks (ExtraFPNBlock, optional): extra operations applied on the
            extracted feature maps; ``None`` selects the extractor's default.

    Returns:
        nn.Module: the backbone produced by ``_mobilenet_extractor``.
    """
    # Instantiate the requested mobilenet variant by name, forwarding the
    # (possibly legacy-remapped) weights and the normalization layer.
    model_builder = mobilenet.__dict__[backbone_name]
    backbone = model_builder(weights=weights, norm_layer=norm_layer)
    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)
174183
175184
0 commit comments