From 941911a0d08154107e057749231cca0b3d968016 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 2 May 2021 17:22:01 +0100 Subject: [PATCH] update --- flash/vision/classification/model.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 2d68e07341..916c7c2d90 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -104,11 +104,10 @@ def __init__( self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) head = head(num_features, num_classes) if isinstance(head, FunctionType) else head - self.head = head or nn.Sequential( - nn.Flatten(), - nn.Linear(num_features, num_classes), - ) + self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), ) def forward(self, x) -> torch.Tensor: x = self.backbone(x) + if x.dim() == 4: + x = x.mean(-1).mean(-1) return self.head(x)