diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 2d68e07341..916c7c2d90 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -104,11 +104,10 @@ def __init__( self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential( - nn.Flatten(), - nn.Linear(num_features, num_classes), - ) + self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) def forward(self, x) -> torch.Tensor: x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) return self.head(x)