
[FEATURE] Method to convert feature embeddings into predictions #1141

@crypdick


Is your feature request related to a problem? Please describe.
Currently, saving both the feature embedding vector and the prediction vector requires two forward passes:

```python
predictions = model(inputs)
embeddings = model.forward_features(inputs)  # inefficient: second pass through the backbone
```

This second pass should not be necessary, since the embeddings are already computed as an intermediate step when you compute predictions.

In addition, once you have your embeddings, there is no standardized method to convert them into predictions. For example, with ResNet models I could do something like:

```python
# forward_features returns an unpooled feature map, so it must be pooled before fc
predictions = model.fc(model.global_pool(embeddings))
```

...but that does not generalize to other models, since not every model has global_pool and fc layers.

Describe the solution you'd like

Every model should have a method (I suggest the name forward_predictions) that takes an embedding (the output of forward_features) as input and returns a prediction.

For example, ResNet would go from:

```python
def forward(self, x):
    x = self.forward_features(x)
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    x = self.fc(x)
    return x
```

to:

```python
def forward_predictions(self, x):  # x is the unpooled feature map from forward_features
    x = self.global_pool(x)
    if self.drop_rate:
        x = F.dropout(x, p=float(self.drop_rate), training=self.training)
    x = self.fc(x)
    return x

def forward(self, x):
    x = self.forward_features(x)
    x = self.forward_predictions(x)
    return x
```
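A minimal runnable sketch of the proposed split, assuming PyTorch. ToyNet is a hypothetical stand-in for a timm ResNet (its tiny convolutional stem, 16 channels, and pooling module are illustrative choices, not timm code); the usage lines check that forward(x) matches forward_predictions(forward_features(x)).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    """Hypothetical stand-in for a timm-style classifier (not a real timm model)."""

    def __init__(self, num_classes: int = 10, drop_rate: float = 0.0):
        super().__init__()
        self.drop_rate = drop_rate
        # Tiny backbone: (B, 3, H, W) -> (B, 16, H, W) feature map.
        self.stem = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU())
        # Global average pooling plus flatten: (B, 16, H, W) -> (B, 16).
        self.global_pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))
        self.fc = nn.Linear(16, num_classes)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return self.stem(x)

    def forward_predictions(self, x: torch.Tensor) -> torch.Tensor:
        # x is the unpooled feature map returned by forward_features.
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        return self.fc(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_predictions(self.forward_features(x))


torch.manual_seed(0)
model = ToyNet(num_classes=10, drop_rate=0.2).eval()  # eval() disables dropout
x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    feats = model.forward_features(x)      # (2, 16, 8, 8)
    preds = model.forward_predictions(feats)  # (2, 10)
    print(torch.allclose(preds, model(x)))
```

The comparison is done in eval mode: in training mode each path would draw its own dropout mask, so the outputs would not match exactly.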

Our inference code, then, would become:

```python
def run_inference(model, inputs):  # hypothetical helper name
    embeddings = model.forward_features(inputs)
    predictions = model.forward_predictions(embeddings)  # no redundant compute
    return embeddings, predictions
```
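To make the saving concrete, here is a sketch that counts backbone passes under both approaches, assuming PyTorch. CountingNet and its backbone_calls counter are hypothetical instrumentation added only for this illustration.

```python
import torch
import torch.nn as nn


class CountingNet(nn.Module):
    """Hypothetical model that counts how often its backbone runs."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
        self.fc = nn.Linear(8, 3)
        self.backbone_calls = 0

    def forward_features(self, x):
        self.backbone_calls += 1
        return self.backbone(x)

    def forward_predictions(self, x):
        return self.fc(x)

    def forward(self, x):
        return self.forward_predictions(self.forward_features(x))


torch.manual_seed(0)
model = CountingNet().eval()
inputs = torch.randn(5, 4)

with torch.no_grad():
    # Current approach: the backbone runs twice.
    predictions_2 = model(inputs)
    embeddings_2 = model.forward_features(inputs)
    two_pass_calls = model.backbone_calls

    # Proposed approach: one backbone pass, reused by the head.
    model.backbone_calls = 0
    embeddings_1 = model.forward_features(inputs)
    predictions_1 = model.forward_predictions(embeddings_1)
    one_pass_calls = model.backbone_calls

print(two_pass_calls, one_pass_calls)  # prints: 2 1
```

Both approaches produce the same embeddings and predictions; only the number of backbone passes differs.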

...and it would let us convert feature embeddings retrieved from a feature store into predictions:

```python
embedding = feature_store.query(interesting_image)  # feature_store: placeholder for any vector store
nearest_neighbor_embedding = feature_store.get_nearest_vector(embedding)
prediction = model.forward_predictions(nearest_neighbor_embedding)
```
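A sketch of the feature-store workflow, assuming PyTorch and an in-memory brute-force nearest-neighbor search with torch.cdist. InMemoryFeatureStore and VectorNet are hypothetical; a real feature store would expose its own query API. Here forward_features already returns a flat vector, whereas in the ResNet split above forward_features returns an unpooled feature map, so a store would either hold those maps or the pooling step would have to move into forward_features.

```python
import torch
import torch.nn as nn


class VectorNet(nn.Module):
    """Hypothetical model whose forward_features already returns a flat vector."""

    def __init__(self, in_dim: int = 4, emb_dim: int = 8, num_classes: int = 3):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.Tanh())
        self.fc = nn.Linear(emb_dim, num_classes)

    def forward_features(self, x):
        return self.backbone(x)

    def forward_predictions(self, x):
        return self.fc(x)

    def forward(self, x):
        return self.forward_predictions(self.forward_features(x))


class InMemoryFeatureStore:
    """Hypothetical stand-in for a real vector database (brute-force search)."""

    def __init__(self):
        self._chunks = []

    def add(self, embeddings: torch.Tensor) -> None:
        self._chunks.append(embeddings.detach())

    def get_nearest_vector(self, query: torch.Tensor) -> torch.Tensor:
        bank = torch.cat(self._chunks, dim=0)          # (N, D)
        dists = torch.cdist(query.unsqueeze(0), bank)  # (1, N) Euclidean distances
        idx = int(dists.argmin())
        return bank[idx : idx + 1]                     # keep the batch dim: (1, D)


torch.manual_seed(0)
model = VectorNet().eval()
images = torch.randn(6, 4)  # stand-ins for real images
store = InMemoryFeatureStore()

with torch.no_grad():
    embeddings = model.forward_features(images)
    store.add(embeddings)
    query = embeddings[2]  # embedding of an "interesting" input
    neighbor = store.get_nearest_vector(query)
    prediction = model.forward_predictions(neighbor)  # no backbone pass needed
```

The last line is the point of the proposal: once a vector is in the store, turning it into a prediction only runs the head.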

Describe alternatives you've considered

Option 1: Run two separate forward passes (one for embeddings, one for predictions) for each image, and write a "forward_predictions" helper function for each architecture we use to convert embeddings into predictions.

Option 2: Register a custom forward hook for each architecture that intercepts the embedding during the forward pass.
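A sketch of Option 2 using PyTorch's nn.Module.register_forward_hook, which calls the hook with (module, input, output) after the module runs. PlainNet and capture_embeddings are hypothetical; for a real model, the submodule to hook (the one whose output counts as the embedding) has to be chosen per architecture, which is the maintenance cost this option carries.

```python
import torch
import torch.nn as nn


class PlainNet(nn.Module):
    """Hypothetical model that only exposes forward(), with no split API."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.Tanh())
        self.head = nn.Linear(8, 3)

    def forward(self, x):
        return self.head(self.backbone(x))


def capture_embeddings(model: nn.Module, module: nn.Module, inputs: torch.Tensor):
    """Run one forward pass and return (embeddings, predictions).

    `module` is the submodule whose output is treated as the embedding.
    """
    captured = {}

    def hook(mod, args, output):
        captured["embeddings"] = output.detach()

    handle = module.register_forward_hook(hook)
    try:
        with torch.no_grad():
            predictions = model(inputs)
    finally:
        handle.remove()  # avoid leaving the hook attached between calls
    return captured["embeddings"], predictions


torch.manual_seed(0)
model = PlainNet().eval()
inputs = torch.randn(5, 4)
embeddings, predictions = capture_embeddings(model, model.backbone, inputs)
```

This avoids the second pass, but unlike a forward_predictions method it gives no standard way to go from a stored embedding back to a prediction.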

Metadata

Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
