Added better functionality for label selection (#713)
* fixed label selection

* fixed label selection

* fixed few shot for label selection

* addressed comments

* addressed comments

* addressed comments
Vaibhav2001 authored Feb 16, 2024
1 parent 42f6cc2 commit 0cbc022
Showing 8 changed files with 197 additions and 138 deletions.
16 changes: 13 additions & 3 deletions src/autolabel/configs/config.py
@@ -53,6 +53,7 @@ class AutolabelConfig(BaseConfig):
     CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
     LABEL_SELECTION_KEY = "label_selection"
     LABEL_SELECTION_COUNT_KEY = "label_selection_count"
+    LABEL_SELECTION_THRESHOLD = "label_selection_threshold"
     ATTRIBUTES_KEY = "attributes"
     TRANSFORM_KEY = "transforms"

@@ -188,7 +189,7 @@ def labels_list(self) -> List[str]:
         if isinstance(self._prompt_config.get(self.VALID_LABELS_KEY, []), List):
             return self._prompt_config.get(self.VALID_LABELS_KEY, [])
         else:
-            return self._prompt_config.get(self.VALID_LABELS_KEY, {}).keys()
+            return list(self._prompt_config.get(self.VALID_LABELS_KEY, {}).keys())

     def label_descriptions(self) -> Dict[str, str]:
         """Returns a dict of label descriptions"""
@@ -236,9 +237,18 @@ def label_selection(self) -> bool:
         classification tasks with a large number of possible classes."""
         return self._prompt_config.get(self.LABEL_SELECTION_KEY, False)

-    def label_selection_count(self) -> int:
+    def max_selected_labels(self) -> int:
         """Returns the number of labels to select in LabelSelector"""
-        return self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)
+        k = self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)
+        if k < 1:
+            return len(self.labels_list())
+        return k
+
+    def label_selection_threshold(self) -> float:
+        """Returns the threshold for label selection in LabelSelector.
+        If the ratio of a label's similarity score to the top score is above
+        this threshold, the label is selected."""
+        return self._prompt_config.get(self.LABEL_SELECTION_THRESHOLD, 0.0)

     def attributes(self) -> List[Dict]:
         """Returns a list of attributes to extract from the text."""
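For reference, here is a minimal sketch of a prompt config exercising the new keys. The key names come from this diff; the surrounding "prompt" structure and the example labels are assumptions for illustration:

    config = {
        "prompt": {
            "labels": ["billing", "shipping", "returns", "other"],  # hypothetical labels
            "label_selection": True,  # enable LabelSelector for this task
            "label_selection_count": 2,  # max labels to keep; a value < 1 keeps all labels
            "label_selection_threshold": 0.8,  # keep a label only if score > 0.8 * top score
        },
    }

With these values, max_selected_labels() returns 2 and label_selection_threshold() returns 0.8.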
5 changes: 3 additions & 2 deletions src/autolabel/few_shot/__init__.py
@@ -1,8 +1,6 @@
 import logging
 from typing import Dict, List

-from autolabel.configs import AutolabelConfig
-from autolabel.schema import FewShotAlgorithm, ModelProvider
 from langchain.embeddings import (
     CohereEmbeddings,
     HuggingFaceEmbeddings,
@@ -16,6 +14,9 @@
 )
 from langchain.prompts.example_selector.base import BaseExampleSelector

+from autolabel.configs import AutolabelConfig
+from autolabel.schema import FewShotAlgorithm, ModelProvider
+
 from .fixed_example_selector import FixedExampleSelector
 from .label_diversity_example_selector import (
     LabelDiversityRandomExampleSelector,
28 changes: 24 additions & 4 deletions src/autolabel/few_shot/fixed_example_selector.py
@@ -1,9 +1,9 @@
 from __future__ import annotations

-from pydantic import BaseModel, Extra
-from typing import Dict, List
+from typing import Dict, List, Optional

 from langchain.prompts.example_selector.base import BaseExampleSelector
+from pydantic import BaseModel, Extra


 class FixedExampleSelector(BaseExampleSelector, BaseModel):
@@ -26,9 +26,29 @@ class Config:
     def add_example(self, example: Dict[str, str]) -> None:
         self.examples.append(example)

-    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
+    def select_examples(
+        self,
+        input_variables: Dict[str, str],
+        **kwargs,
+    ) -> List[dict]:
         """Select which examples to use based on the input lengths."""
-        return self.examples[: self.k]
+        label_column = kwargs.get("label_column")
+        selected_labels = kwargs.get("selected_labels")
+
+        if not selected_labels:
+            return self.examples[: self.k]
+
+        if not label_column:
+            print("No label column provided, returning all examples")
+            return self.examples[: self.k]
+
+        # get the examples where label matches the selected labels
+        valid_examples = [
+            example
+            for example in self.examples
+            if example.get(label_column) in selected_labels
+        ]
+        return valid_examples[: min(self.k, len(valid_examples))]

     @classmethod
     def from_examples(
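A minimal usage sketch for the new keyword arguments (the example rows and label values are hypothetical, and direct construction assumes the pydantic fields examples and k referenced above):

    selector = FixedExampleSelector(
        examples=[
            {"question": "Where is my package?", "label": "shipping"},
            {"question": "I was charged twice.", "label": "billing"},
        ],
        k=4,
    )
    few_shot = selector.select_examples(
        {"question": "Please refund my last invoice."},
        label_column="label",
        selected_labels=["billing"],
    )
    # Only the "billing" row passes the filter; with selected_labels omitted
    # or empty, select_examples falls back to the first k examples.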
86 changes: 50 additions & 36 deletions src/autolabel/few_shot/label_selector.py
@@ -1,9 +1,14 @@
 from __future__ import annotations
-from collections.abc import Callable
-from typing import Dict, List

 import bisect
+from collections.abc import Callable
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from sqlalchemy.sql import text as sql_text

-from autolabel.few_shot.vector_store import cos_sim
+from autolabel.configs import AutolabelConfig
+from autolabel.few_shot.vector_store import VectorStoreWrapper, cos_sim
+

 class LabelSelector:
@@ -13,45 +18,54 @@ class LabelSelector:
     labels: List[str]
     """A list of the possible labels to choose from."""

-    k: int = 10
-    """Number of labels to select"""
-
-    embedding_func: Callable = None
-    """Function used to generate embeddings of labels/input"""
+    label_descriptions: Optional[Dict[str, str]]
+    """A dictionary of label descriptions. If provided, the selector will
+    use these descriptions to find the most similar labels to the input."""

     labels_embeddings: Dict = {}
     """Dict used to store embeddings of each label"""

+    cache: bool = True
+    """Whether to cache the embeddings of labels"""
+
     def __init__(
-        self, labels: List[str], embedding_func: Callable, k: int = 10
+        self,
+        config: Union[AutolabelConfig, str, dict],
+        embedding_func: Callable,
+        cache: bool = True,
     ) -> None:
-        self.labels = labels
-        self.k = min(k, len(labels))
-        self.embedding_func = embedding_func
-        for l in self.labels:
-            self.labels_embeddings[l] = self.embedding_func.embed_query(l)
+        self.config = config
+        self.labels = self.config.labels_list()
+        self.label_descriptions = self.config.label_descriptions()
+        self.k = min(self.config.max_selected_labels(), len(self.labels))
+        self.threshold = self.config.label_selection_threshold()
+        self.cache = cache
+        self.vectorStore = VectorStoreWrapper(
+            embedding_function=embedding_func, cache=self.cache
+        )
+
+        # Get the embeddings of the labels
+        if self.label_descriptions is not None:
+            (labels, descriptions) = zip(*self.label_descriptions.items())
+            self.labels = list(labels)
+            self.labels_embeddings = torch.Tensor(
+                self.vectorStore._get_embeddings(descriptions)
+            )
+        else:
+            self.labels_embeddings = torch.Tensor(
+                self.vectorStore._get_embeddings(self.labels)
+            )

     def select_labels(self, input: str) -> List[str]:
         """Select which labels to use based on the similarity to input"""
-        input_embedding = self.embedding_func.embed_query(input)
-
-        scores = []
-        for label, embedding in self.labels_embeddings.items():
-            similarity = cos_sim(embedding, input_embedding)
-            # insert into scores, while maintaining sorted order
-            bisect.insort(scores, (similarity, label))
-        return [label for (_, label) in scores[-self.k :]]
-
-    @classmethod
-    def from_examples(
-        cls,
-        labels: List[str],
-        embedding_func,
-        k: int = 10,
-    ) -> LabelSelector:
-        """Create pass-through label selector using given list of labels
-        Returns:
-            The LabelSelector instantiated
-        """
-        return cls(labels=labels, k=k, embedding_func=embedding_func)
+        input_embedding = torch.Tensor(self.vectorStore._get_embeddings([input]))
+        scores = cos_sim(input_embedding, self.labels_embeddings).view(-1)
+        scores = list(zip(scores, self.labels))
+        scores.sort(key=lambda x: x[0])
+
+        # remove labels with similarity score less than self.threshold*topScore
+        return [
+            label
+            for (score, label) in scores[-self.k :]
+            if score > self.threshold * scores[-1][0]
+        ]
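To make the threshold concrete, here is a small worked sketch of the final comprehension (the scores and labels are made up for illustration):

    scores = [(0.40, "returns"), (0.72, "shipping"), (0.90, "billing")]  # sorted ascending
    k, threshold = 2, 0.8
    top_score = scores[-1][0]  # 0.90
    selected = [label for (score, label) in scores[-k:] if score > threshold * top_score]
    # threshold * top_score is approximately 0.72 and the comparison is strict,
    # so "shipping" (0.72) is dropped and selected == ["billing"]

With the 0.0 default threshold, every top-k label with positive similarity passes the filter, preserving the previous top-k behavior.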
