diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 69a3fd8c85..5528cfc5d6 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -38,10 +38,9 @@ class ImageClassifier(ClassificationTask): def __init__( self, - num_classes, - backbone="resnet18", - num_features: int = None, - pretrained=True, + num_classes: int, + backbone: str = "resnet18", + pretrained: bool = True, loss_fn: Callable = F.cross_entropy, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),