Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Populate default_preprocess in embedding model
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Apr 14, 2021
1 parent 5e0d0ca commit 3ed7b82
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions flash/vision/embedding/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,6 @@ class ImageEmbedder(Task):

backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES

@property
def preprocess(self):
return ImageClassificationPreprocess(predict_transform=ImageClassificationData.default_val_transforms())

def __init__(
self,
embedding_dim: Optional[int] = None,
Expand All @@ -70,6 +66,9 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
default_preprocess=ImageClassificationPreprocess(
predict_transform=ImageClassificationData.default_val_transforms(),
)
)

self.save_hyperparameters()
Expand Down

0 comments on commit 3ed7b82

Please sign in to comment.