Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into bugfix/audio_numpy_loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Sep 7, 2021
2 parents 9fbfab5 + 75afa79 commit f3e8aab
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flash/image/classification/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@
from torch.hub import load_state_dict_from_url

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TIMM_AVAILABLE
from flash.core.utilities.url_error import catch_url_error

if _TIMM_AVAILABLE:
from timm.models.helpers import adapt_input_conv


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding."""
Expand Down Expand Up @@ -351,6 +355,8 @@ def _resnet(
)

if model_weights is not None:
in_chans = backbone.conv1.weight.shape[1]
model_weights["conv1.weight"] = adapt_input_conv(in_chans, model_weights["conv1.weight"])
backbone.load_state_dict(model_weights)

return backbone, num_features
Expand Down

0 comments on commit f3e8aab

Please sign in to comment.