diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 4444b5c3ab..eab47e0cf5 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -56,8 +56,8 @@ def _get_backbone_meta(backbone): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = ([0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + - [len(backbone) - 1]) + stage_indices = [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + stage_indices = [0] + stage_indices + [len(backbone) - 1] out_pos = stage_indices[-1] # use C5 which has output_stride = 16 out_layer = str(out_pos) out_inplanes = backbone[out_pos].out_channels