Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jun 16, 2021
1 parent c867c92 commit f8a1432
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
7 changes: 2 additions & 5 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,8 @@ def _get_backbone_meta(backbone):
backbone = backbone.features
# Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# The first and last blocks are always included because they are the C0 (conv1) and Cn.
stage_indices = sum([
[0],
[i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)],
[len(backbone) - 1],
])
stage_indices = ([0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] +
[len(backbone) - 1])
out_pos = stage_indices[-1] # use C5 which has output_stride = 16
out_layer = str(out_pos)
out_inplanes = backbone[out_pos].out_channels
Expand Down
2 changes: 1 addition & 1 deletion tests/image/segmentation/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS

Expand Down

0 comments on commit f8a1432

Please sign in to comment.