Description
Is your feature request related to a problem? Please describe.
Currently, saving both the feature embedding vector and the prediction vector requires two forward passes:
```python
predictions = model(inputs)
embeddings = model.forward_features(inputs)  # inefficient
```
This is technically unnecessary, since the embeddings are already computed as an intermediate step when you compute predictions.
In addition, once you have your embeddings, there is no standardized method to convert them into predictions. For example, with ResNet models I could do something like:
```python
predictions = model.fc(model.global_pool(embeddings))  # forward_features output is not yet pooled
```
...but that does not generalize to other models, since not every model has an `fc` layer.
Describe the solution you'd like
I'd like every model to have a forward method (I suggest the name `forward_predictions`) that takes an embedding as input and outputs a prediction.
For example, ResNet would go from:
```python
def forward(self, x):
    x = self.forward_features(x)
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    x = self.fc(x)
    return x
```
to:
```python
def forward_predictions(self, x):  # x is the embedding (forward_features output)
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    x = self.fc(x)
    return x

def forward(self, x):
    x = self.forward_features(x)
    x = self.forward_predictions(x)
    return x
```
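To check that this split preserves behavior, here is a framework-free sketch of the same structure. Plain Python lists stand in for tensors, and `SplitHeadModel`, its ReLU "backbone", and the toy weights are hypothetical, not timm code; only the method names `forward_features`, `global_pool`, `fc`, and `forward_predictions` follow the proposal.

```python
import random


class SplitHeadModel:
    """Hypothetical, framework-free stand-in for a ResNet-style model with a split forward."""

    def __init__(self, weights, bias, drop_rate=0.0):
        self.weights = weights      # one row of weights per output class
        self.bias = bias            # one bias per output class
        self.drop_rate = drop_rate
        self.training = False       # eval mode: dropout is a no-op

    def forward_features(self, x):
        # Toy backbone: x is a list of channels; apply ReLU to every value.
        return [[max(0.0, v) for v in channel] for channel in x]

    def global_pool(self, x):
        # Global average pooling: one value per channel.
        return [sum(channel) / len(channel) for channel in x]

    def dropout(self, x):
        # Inverted dropout, active only in training mode (mirrors F.dropout).
        if not self.training:
            return x
        keep = 1.0 - self.drop_rate
        return [v / keep if random.random() < keep else 0.0 for v in x]

    def fc(self, x):
        # Linear classifier: one logit per row of weights.
        return [sum(w * v for w, v in zip(row, x)) + b
                for row, b in zip(self.weights, self.bias)]

    def forward_predictions(self, x):  # x is the forward_features output
        x = self.global_pool(x)
        if self.drop_rate:
            x = self.dropout(x)
        return self.fc(x)

    def forward(self, x):
        return self.forward_predictions(self.forward_features(x))


model = SplitHeadModel(weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.5])
x = [[1.0, -1.0, 3.0], [2.0, 2.0, 2.0]]   # 2 channels x 3 values each
embeddings = model.forward_features(x)    # could be cached or stored
logits = model.forward(x)
assert logits == model.forward_predictions(embeddings)
```

Because `forward` is just `forward_predictions(forward_features(x))`, a stored embedding gives the same logits as a full forward pass, at least in eval mode, where dropout is inactive.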
Our inference code, then, would become:
```python
def infer(model, inputs):
    embeddings = model.forward_features(inputs)
    predictions = model.forward_predictions(embeddings)  # no redundant compute
    return embeddings, predictions
```
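One rough way to see the saving is to count backbone evaluations. The sketch below uses a hypothetical `CountingModel` whose `forward_features` increments a counter; `two_pass` and `single_pass` are illustrative helper names, not part of any library.

```python
class CountingModel:
    """Hypothetical toy model with the proposed split; counts backbone evaluations."""

    def __init__(self):
        self.feature_calls = 0

    def forward_features(self, x):
        self.feature_calls += 1          # stands in for the expensive backbone
        return [2 * v for v in x]

    def forward_predictions(self, feats):
        return sum(feats)                # stands in for the cheap head

    def __call__(self, x):
        return self.forward_predictions(self.forward_features(x))


def two_pass(model, inputs):
    predictions = model(inputs)                  # backbone run #1
    embeddings = model.forward_features(inputs)  # backbone run #2
    return embeddings, predictions


def single_pass(model, inputs):
    embeddings = model.forward_features(inputs)          # only backbone run
    predictions = model.forward_predictions(embeddings)  # reuse the features
    return embeddings, predictions


inputs = [1.0, 2.0, 3.0]
m1, m2 = CountingModel(), CountingModel()
assert two_pass(m1, inputs) == single_pass(m2, inputs)  # same outputs
print(m1.feature_calls, m2.feature_calls)  # 2 1
```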
...and it would also let us convert feature embedding vectors retrieved from a feature store into predictions:
```python
embedding = feature_store.query(interesting_image)  # placeholder for the image of interest
nearest_neighbor_embedding = feature_store.get_nearest_vector(embedding)
prediction = model.forward_predictions(nearest_neighbor_embedding)
```
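The feature-store workflow can be sketched with an in-memory store. `InMemoryFeatureStore`, its brute-force Euclidean `get_nearest_vector`, and `LinearHead` are hypothetical stand-ins for whatever store and model you actually use; the sketch assumes the stored embeddings are already pooled vectors, so `forward_predictions` reduces to a linear layer.

```python
import math


class InMemoryFeatureStore:
    """Hypothetical feature store mapping image ids to embedding vectors."""

    def __init__(self):
        self._vectors = {}

    def add(self, key, vector):
        self._vectors[key] = vector

    def query(self, key):
        return self._vectors[key]

    def get_nearest_vector(self, vector):
        # Brute-force Euclidean nearest neighbor, skipping exact self-matches.
        candidates = [v for v in self._vectors.values() if v != vector]
        return min(candidates, key=lambda v: math.dist(v, vector))


class LinearHead:
    """Hypothetical model whose forward_predictions maps a pooled embedding to logits."""

    def __init__(self, weights, bias):
        self.weights = weights
        self.bias = bias

    def forward_predictions(self, embedding):
        return [sum(w * e for w, e in zip(row, embedding)) + b
                for row, b in zip(self.weights, self.bias)]


store = InMemoryFeatureStore()
store.add("cat_1", [0.9, 0.1])
store.add("cat_2", [0.8, 0.2])
store.add("dog_1", [0.1, 0.9])

model = LinearHead(weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
embedding = store.query("cat_1")
nearest_neighbor_embedding = store.get_nearest_vector(embedding)
prediction = model.forward_predictions(nearest_neighbor_embedding)
print(nearest_neighbor_embedding, prediction)  # [0.8, 0.2] [0.8, 0.2]
```

No backbone pass is needed here: the head runs directly on a vector that was computed and stored earlier.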
Describe alternatives you've considered
Option 1: Run two separate forward passes (one for embeddings, one for predictions) for each image, and write a `forward_predictions` helper function for each architecture we use to convert embeddings into predictions.
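Option 1's per-architecture helpers could be organized as a small registry keyed by class name. This is a sketch only: `register_head`, `forward_predictions`, and the toy `ResNet`/`EfficientNet` classes are hypothetical, and the toy heads skip pooling and dropout. They only illustrate that each architecture needs its own hand-written entry because the classifier lives under a different attribute.

```python
from typing import Callable, Dict, List

Vector = List[float]

# Hypothetical registry: one embeddings -> predictions helper per architecture.
_HEAD_FNS: Dict[str, Callable[[object, Vector], Vector]] = {}


def register_head(arch_name):
    """Decorator that records a forward_predictions helper for one architecture."""
    def decorator(fn):
        _HEAD_FNS[arch_name] = fn
        return fn
    return decorator


def forward_predictions(model, embeddings):
    """Dispatch to the helper registered for type(model).__name__."""
    name = type(model).__name__
    if name not in _HEAD_FNS:
        raise NotImplementedError(f"no forward_predictions helper for {name}")
    return _HEAD_FNS[name](model, embeddings)


class ResNet:  # toy stand-in: classifier stored as `fc`
    def fc(self, v):
        return [2 * x for x in v]


class EfficientNet:  # toy stand-in: classifier stored as `classifier`
    def classifier(self, v):
        return [x + 1 for x in v]


@register_head("ResNet")
def _resnet_head(model, embeddings):
    return model.fc(embeddings)


@register_head("EfficientNet")
def _efficientnet_head(model, embeddings):
    return model.classifier(embeddings)


print(forward_predictions(ResNet(), [1.0, 2.0]))        # [2.0, 4.0]
print(forward_predictions(EfficientNet(), [1.0, 2.0]))  # [2.0, 3.0]
```

The drawback remains: every new architecture needs another registered helper, and a missing one only fails at call time.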
Option 2: Register a custom forward hook for every architecture that intercepts the embedding vector during the forward pass.