Skip to content

Commit 25f5f3d

Browse files
shaneahmedmeasty
andauthored
♻️ Replace deprecated input pretrained with weights (#621)
- Replace `pretrained` with `weights` as this is replaced in torch by the new API. --------- Co-authored-by: Mark Eastwood <[email protected]>
1 parent 971ef74 commit 25f5f3d

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

tiatoolbox/models/architecture/vanilla.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tiatoolbox.utils.misc import select_device
1010

1111

12-
def _get_architecture(arch_name, pretrained=True, **kwargs):
12+
def _get_architecture(arch_name, weights="DEFAULT", **kwargs):
1313
"""Get a model.
1414
1515
Model architectures are either already defined within torchvision or
@@ -48,7 +48,7 @@ def _get_architecture(arch_name, pretrained=True, **kwargs):
4848
raise ValueError(f"Backbone `{arch_name}` is not supported.")
4949

5050
creator = backbone_dict[arch_name]
51-
model = creator(pretrained=pretrained, **kwargs)
51+
model = creator(weights=weights, **kwargs)
5252

5353
# Unroll all the definition and strip off the final GAP and FCN
5454
if "resnet" in arch_name or "resnext" in arch_name:

tiatoolbox/tools/registration/wsi_registration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from skimage import exposure, filters
1111
from skimage.registration import phase_cross_correlation
1212
from skimage.util import img_as_float
13+
from torchvision.models import VGG16_Weights
1314

1415
from tiatoolbox import logger
1516
from tiatoolbox.tools.patchextraction import PatchExtractor
@@ -306,7 +307,7 @@ def __init__(self):
306307
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
307308
self.features: dict = dict.fromkeys(output_layers_key, None)
308309
self.pretrained: torch.nn.Sequential = torchvision.models.vgg16(
309-
pretrained=True
310+
weights=VGG16_Weights.IMAGENET1K_V1
310311
).features
311312
self.f_hooks = []
312313

0 commit comments

Comments
 (0)