-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Better support for classification tasks with a large number of label classes (#561)
* for classification tasks with a large number of categories, filter the list of labels by similarity to the prompt * replace Chroma DB with autolabel's own VectorStoreWrapper. Remove debug prints * move label selection logic into its own class * allow for LabelSelector.k to be specified in config * clear up comment * remove default for embedding_func=OpenAIEmbeddings(), as this requires having OPENAI_API_KEY when importing autolabel * if task_selection=true, check that task_type=classification * remove unused imports
- Loading branch information
Showing
5 changed files
with
135 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from __future__ import annotations | ||
from collections.abc import Callable | ||
from typing import Dict, List | ||
import bisect | ||
|
||
from autolabel.few_shot.vector_store import cos_sim | ||
|
||
|
||
class LabelSelector:
    """Select the labels most similar to a given input.

    Used for classification tasks with a large number of possible classes:
    instead of presenting every label, only the ``k`` labels whose embeddings
    are closest (by ``cos_sim``) to the input's embedding are returned.

    Attributes:
        labels: The possible labels to choose from.
        k: Number of labels to select (capped at ``len(labels)``).
        embedding_func: Object used to embed labels and inputs. Despite the
            ``Callable`` annotation it is not called directly; it must expose
            an ``embed_query(text)`` method.
        labels_embeddings: Per-instance mapping of label -> embedding.
    """

    labels: List[str]
    k: int = 10
    embedding_func: Callable = None
    # Annotation only: the dict itself is created per instance in __init__.
    # A class-level ``{}`` default would be shared (and mutated) by every
    # LabelSelector, leaking labels from one selector into another.
    labels_embeddings: Dict

    def __init__(
        self, labels: List[str], embedding_func: Callable, k: int = 10
    ) -> None:
        """Embed every label once so later selections only embed the input.

        Args:
            labels: The possible labels to choose from.
            embedding_func: Object exposing ``embed_query(text)``.
            k: Maximum number of labels to return from ``select_labels``.
        """
        self.labels = labels
        # Never ask for more labels than exist.
        self.k = min(k, len(labels))
        self.embedding_func = embedding_func
        self.labels_embeddings = {
            label: self.embedding_func.embed_query(label)
            for label in self.labels
        }

    def select_labels(self, input: str) -> List[str]:
        """Select which labels to use based on their similarity to ``input``.

        Args:
            input: The text to compare the labels against.

        Returns:
            Up to ``k`` labels ordered by ascending similarity, i.e. the most
            similar label is last. Empty when ``k`` is zero or negative.
        """
        # Guard explicitly: ``scores[-0:]`` would return *every* label.
        if self.k <= 0:
            return []

        input_embedding = self.embedding_func.embed_query(input)
        # Sort (similarity, label) pairs ascending; ties fall back to the
        # label text, matching plain tuple ordering.
        scores = sorted(
            (cos_sim(embedding, input_embedding), label)
            for label, embedding in self.labels_embeddings.items()
        )
        return [label for _, label in scores[-self.k :]]

    @classmethod
    def from_examples(
        cls,
        labels: List[str],
        embedding_func,
        k: int = 10,
    ) -> LabelSelector:
        """Alternate constructor that forwards its arguments unchanged.

        Args:
            labels: The possible labels to choose from.
            embedding_func: Object exposing ``embed_query(text)``.
            k: Maximum number of labels to return from ``select_labels``.

        Returns:
            The LabelSelector instantiated.
        """
        return cls(labels=labels, k=k, embedding_func=embedding_func)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters