Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add in_chans arg to resnet (#673)
Browse files Browse the repository at this point in the history
* Add in_chans arg to resnet

* Update CHANGELOG.md
  • Loading branch information
ethanwharris authored Aug 17, 2021
1 parent e86bfd9 commit 18c5322
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for flash zero with the `InstanceSegmentation` and `KeypointDetector` tasks ([#672](https://github.com/PyTorchLightning/lightning-flash/pull/672))

- Added support for `in_chans` argument to the flash ResNet to control the expected number of input channels ([#673](https://github.com/PyTorchLightning/lightning-flash/pull/673))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
Expand Down
5 changes: 3 additions & 2 deletions flash/image/classification/backbones/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
norm_layer: Optional[Callable[..., nn.Module]] = None,
first_conv3x3: bool = False,
remove_first_maxpool: bool = False,
in_chans: int = 3,
) -> None:

super().__init__()
Expand All @@ -194,9 +195,9 @@ def __init__(
num_out_filters = width_per_group * widen

if first_conv3x3:
self.conv1 = nn.Conv2d(3, num_out_filters, kernel_size=3, stride=1, padding=1, bias=False)
self.conv1 = nn.Conv2d(in_chans, num_out_filters, kernel_size=3, stride=1, padding=1, bias=False)
else:
self.conv1 = nn.Conv2d(3, num_out_filters, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(in_chans, num_out_filters, kernel_size=7, stride=2, padding=3, bias=False)

self.bn1 = norm_layer(num_out_filters)
self.relu = nn.ReLU(inplace=True)
Expand Down

0 comments on commit 18c5322

Please sign in to comment.