Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix Resnet pretrained weights with in_chans argument using timm adapt…
Browse files Browse the repository at this point in the history
…_input_conv (#743)
  • Loading branch information
ethanwharris authored Sep 7, 2021
1 parent 143112d commit 75afa79
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flash/image/classification/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
from torch.hub import load_state_dict_from_url

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.core.utilities.url_error import catch_url_error

if _TIMM_AVAILABLE:
from timm.models.helpers import adapt_input_conv


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding."""
Expand Down Expand Up @@ -351,6 +355,8 @@ def _resnet(
)

if model_weights is not None:
in_chans = backbone.conv1.weight.shape[1]
model_weights["conv1.weight"] = adapt_input_conv(in_chans, model_weights["conv1.weight"])
backbone.load_state_dict(model_weights)

return backbone, num_features
Expand Down

0 comments on commit 75afa79

Please sign in to comment.