Added better functionality for label selection #713

Merged
merged 6 commits on Feb 16, 2024
Changes from 3 commits
14 changes: 12 additions & 2 deletions src/autolabel/configs/config.py
@@ -53,6 +53,7 @@ class AutolabelConfig(BaseConfig):
CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
LABEL_SELECTION_KEY = "label_selection"
LABEL_SELECTION_COUNT_KEY = "label_selection_count"
LABEL_SELECTION_THRESHOLD = "label_selection_threshold"
ATTRIBUTES_KEY = "attributes"
TRANSFORM_KEY = "transforms"

@@ -188,7 +189,7 @@ def labels_list(self) -> List[str]:
if isinstance(self._prompt_config.get(self.VALID_LABELS_KEY, []), List):
return self._prompt_config.get(self.VALID_LABELS_KEY, [])
else:
return self._prompt_config.get(self.VALID_LABELS_KEY, {}).keys()
return list(self._prompt_config.get(self.VALID_LABELS_KEY, {}).keys())

def label_descriptions(self) -> Dict[str, str]:
"""Returns a dict of label descriptions"""
@@ -238,7 +239,16 @@ def label_selection(self) -> bool:

def label_selection_count(self) -> int:
"""Returns the number of labels to select in LabelSelector"""
return self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)
k = self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)
if k < 1:
return len(self.labels_list())
return k

def label_selection_threshold(self) -> float:
"""Returns the threshold for label selection in LabelSelector
If the similarity score ratio with the top Score is above this threshold,
the label is selected."""
return self._prompt_config.get(self.LABEL_SELECTION_THRESHOLD, 0.95)

def attributes(self) -> List[Dict]:
"""Returns a list of attributes to extract from the text."""
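For reference, a sketch of how the new keys might appear in a labeling config; the nesting under a "prompt" section and the "labels" key name are assumptions inferred from the `_prompt_config` and `VALID_LABELS_KEY` lookups, not spelled out in this diff.

# Hypothetical config fragment: only the key names come from this diff;
# their placement under "prompt" is an assumption.
config = {
    "prompt": {
        "labels": ["billing", "shipping", "returns", "other"],
        "label_selection": True,
        # keep the 10 most similar labels; a count < 1 now falls back to all labels
        "label_selection_count": 10,
        # keep only labels scoring within 95% of the top similarity score
        "label_selection_threshold": 0.95,
    }
}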
5 changes: 3 additions & 2 deletions src/autolabel/few_shot/__init__.py
@@ -1,8 +1,6 @@
import logging
from typing import Dict, List

from autolabel.configs import AutolabelConfig
from autolabel.schema import FewShotAlgorithm, ModelProvider
from langchain.embeddings import (
CohereEmbeddings,
HuggingFaceEmbeddings,
@@ -16,6 +14,9 @@
)
from langchain.prompts.example_selector.base import BaseExampleSelector

from autolabel.configs import AutolabelConfig
from autolabel.schema import FewShotAlgorithm, ModelProvider

from .fixed_example_selector import FixedExampleSelector
from .label_diversity_example_selector import (
LabelDiversityRandomExampleSelector,
28 changes: 24 additions & 4 deletions src/autolabel/few_shot/fixed_example_selector.py
@@ -1,9 +1,9 @@
from __future__ import annotations

from pydantic import BaseModel, Extra
from typing import Dict, List, Optional

from typing import Dict, List
from langchain.prompts.example_selector.base import BaseExampleSelector
from pydantic import BaseModel, Extra


class FixedExampleSelector(BaseExampleSelector, BaseModel):
@@ -26,9 +26,29 @@ class Config:
def add_example(self, example: Dict[str, str]) -> None:
self.examples.append(example)

def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
def select_examples(
self,
input_variables: Dict[str, str],
**kwargs,
) -> List[dict]:
"""Select which examples to use based on the input lengths."""
return self.examples[: self.k]
label_column = kwargs.get("label_column")
selected_labels = kwargs.get("selected_labels")

if not selected_labels:
return self.examples[: self.k]

if not label_column:
print("No label column provided, returning all examples")
return self.examples[: self.k]

# get the examples where label matches the selected labels
valid_examples = [
example
for example in self.examples
if example.get(label_column) in selected_labels
]
return valid_examples[: min(self.k, len(valid_examples))]

@classmethod
def from_examples(
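A rough sketch of how the new keyword arguments could be exercised; the `examples` and `k` constructor fields and the "label" column name are assumptions based on how they are referenced in this hunk, not a confirmed API.

from autolabel.few_shot.fixed_example_selector import FixedExampleSelector

seed_examples = [
    {"example": "My parcel never arrived", "label": "shipping"},
    {"example": "I want to send this item back", "label": "returns"},
    {"example": "Why was my card charged twice?", "label": "billing"},
]

# Illustration only: passing `examples` and `k` as constructor fields is an
# assumption based on how self.examples / self.k are used above.
selector = FixedExampleSelector(examples=seed_examples, k=2)

few_shot = selector.select_examples(
    {"example": "Where is my package?"},
    label_column="label",                     # column holding each example's label
    selected_labels=["shipping", "returns"],  # e.g. output of LabelSelector.select_labels
)
# Only examples whose label is in selected_labels are returned, capped at k.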
75 changes: 45 additions & 30 deletions src/autolabel/few_shot/label_selector.py
@@ -1,9 +1,13 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Dict, List

import bisect
from collections.abc import Callable
from typing import Dict, List, Optional, Tuple, Union

from sqlalchemy.sql import text as sql_text

from autolabel.few_shot.vector_store import cos_sim
from autolabel.configs import AutolabelConfig
from autolabel.few_shot.vector_store import VectorStoreWrapper, cos_sim


class LabelSelector:
@@ -13,45 +17,56 @@ class LabelSelector:
labels: List[str]
"""A list of the possible labels to choose from."""

k: int = 10
"""Number of labels to select"""

embedding_func: Callable = None
"""Function used to generate embeddings of labels/input"""
label_descriptions: Optional[Dict[str, str]]
"""A dictionary of label descriptions. If provided, the selector will
use these descriptions to find the most similar labels to the input."""

labels_embeddings: Dict = {}
"""Dict used to store embeddings of each label"""

cache: bool = True
"""Whether to cache the embeddings of labels"""

def __init__(
self, labels: List[str], embedding_func: Callable, k: int = 10
self,
config: Union[AutolabelConfig, str, dict],
embedding_func: Callable,
cache: bool = True,
) -> None:
self.labels = labels
self.k = min(k, len(labels))
self.embedding_func = embedding_func
for l in self.labels:
self.labels_embeddings[l] = self.embedding_func.embed_query(l)
self.config = config
self.labels = self.config.labels_list()
self.label_descriptions = self.config.label_descriptions()
self.k = min(self.config.label_selection_count(), len(self.labels))
self.threshold = self.config.label_selection_threshold()
self.cache = cache
self.vectorStore = VectorStoreWrapper(
embedding_function=embedding_func, cache=self.cache
)

# Get the embeddings of the labels
if self.label_descriptions is not None:
(labels, descriptions) = zip(*self.label_descriptions.items())
embeddings = self.vectorStore._get_embeddings(descriptions)
for i, label in enumerate(labels):
self.labels_embeddings[label] = embeddings[i]
else:
embeddings = self.vectorStore._get_embeddings(self.labels)
for i, label in enumerate(self.labels):
self.labels_embeddings[label] = embeddings[i]

def select_labels(self, input: str) -> List[str]:
"""Select which labels to use based on the similarity to input"""
input_embedding = self.embedding_func.embed_query(input)
input_embedding = self.vectorStore._get_embeddings([input])

scores = []
for label, embedding in self.labels_embeddings.items():
similarity = cos_sim(embedding, input_embedding)
# insert into scores, while maintaining sorted order
bisect.insort(scores, (similarity, label))
return [label for (_, label) in scores[-self.k :]]

@classmethod
def from_examples(
cls,
labels: List[str],
embedding_func,
k: int = 10,
) -> LabelSelector:
"""Create pass-through label selector using given list of labels

Returns:
The LabelSelector instantiated
"""
return cls(labels=labels, k=k, embedding_func=embedding_func)

# remove labels with similarity score less than self.threshold*topScore
return [
label
for (score, label) in scores[-self.k :]
if score > self.threshold * scores[-1][0]
]
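A minimal end-to-end sketch of the new constructor, assuming an AutolabelConfig built from a placeholder path and any LangChain embedding object (OpenAIEmbeddings here) that VectorStoreWrapper accepts. With the default threshold of 0.95, a top similarity of 0.80 would drop every label scoring 0.76 or below.

from langchain.embeddings import OpenAIEmbeddings  # placeholder embedding object

from autolabel.configs import AutolabelConfig
from autolabel.few_shot.label_selector import LabelSelector

config = AutolabelConfig("config_example.json")     # placeholder config path
selector = LabelSelector(config=config, embedding_func=OpenAIEmbeddings())

# Returns at most label_selection_count labels whose similarity to the input
# is within label_selection_threshold of the top-scoring label.
candidate_labels = selector.select_labels("I was charged twice for the same order")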