Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed May 2, 2021
1 parent eab18c4 commit 941911a
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,10 @@ def __init__(
self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

head = head(num_features, num_classes) if isinstance(head, FunctionType) else head
self.head = head or nn.Sequential(
nn.Flatten(),
nn.Linear(num_features, num_classes),
)
self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), )

def forward(self, x) -> torch.Tensor:
x = self.backbone(x)
if x.dim() == 4:
x = x.mean(-1).mean(-1)
return self.head(x)

0 comments on commit 941911a

Please sign in to comment.