diff --git a/efficientnet_pytorch/model.py b/efficientnet_pytorch/model.py index 3a42217..ed14b19 100644 --- a/efficientnet_pytorch/model.py +++ b/efficientnet_pytorch/model.py @@ -139,7 +139,7 @@ def __init__(self, blocks_args=None, global_params=None): # The first block needs to take care of stride and filter size increase. self._blocks.append(MBConvBlock(block_args, self._global_params)) if block_args.num_repeat > 1: - block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) + block_args = block_args._replace(input_filters=block_args.output_filters, stride=[1]) for _ in range(block_args.num_repeat - 1): self._blocks.append(MBConvBlock(block_args, self._global_params))