diff --git a/flash/image/classification/backbones/resnet.py b/flash/image/classification/backbones/resnet.py index 0f136e9df5..b68c9d15d8 100644 --- a/flash/image/classification/backbones/resnet.py +++ b/flash/image/classification/backbones/resnet.py @@ -25,8 +25,12 @@ from torch.hub import load_state_dict_from_url from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TIMM_AVAILABLE from flash.core.utilities.url_error import catch_url_error +if _TIMM_AVAILABLE: + from timm.models.helpers import adapt_input_conv + def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding.""" @@ -351,6 +355,8 @@ def _resnet( ) if model_weights is not None: + in_chans = backbone.conv1.weight.shape[1] + model_weights["conv1.weight"] = adapt_input_conv(in_chans, model_weights["conv1.weight"]) backbone.load_state_dict(model_weights) return backbone, num_features