Skip to content

Commit

Permalink
Better support for classification tasks with large number of label cl…
Browse files Browse the repository at this point in the history
…asses (#561)

* for classification tasks with a large number of categories, filter the list of labels by similarity to the prompt

* replace Chroma DB with autolabels own VectorStoreWrapper. Remove debug prints

* move label selection logic into its own class

* allow for LabelSelector.k to be specified in config

* clear up comment

* remove default for embedding_func=OpenAIEmbeddings() , as this requires having OPENAI_API_KEY when importing autolabel

* if task_selection=true, check that task_type=classification

* remove unnused imports
  • Loading branch information
iomap authored Sep 19, 2023
1 parent 233d390 commit acb8755
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/autolabel/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ class AutolabelConfig(BaseConfig):
OUTPUT_GUIDELINE_KEY = "output_guidelines"
OUTPUT_FORMAT_KEY = "output_format"
CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
LABEL_SELECTION_KEY = "label_selection"
LABEL_SELECTION_COUNT_KEY = "label_selection_count"
ATTRIBUTES_KEY = "attributes"

TRANSFORM_KEY = "transforms"

# Dataset generation config keys (config["dataset_generation"][<key>])
Expand Down Expand Up @@ -208,6 +209,16 @@ def chain_of_thought(self) -> bool:
"""Returns true if the model is able to perform chain of thought reasoning."""
return self._prompt_config.get(self.CHAIN_OF_THOUGHT_KEY, False)

def label_selection(self) -> bool:
"""Returns true if label selection is enabled. Label selection is the process of
narrowing down the list of possible labels by similarity to a given input. Useful for
classification tasks with a large number of possible classes."""
return self._prompt_config.get(self.LABEL_SELECTION_KEY, False)

def label_selection_count(self) -> int:
"""Returns the number of labels to select in LabelSelector"""
return self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)

def attributes(self) -> List[Dict]:
"""Returns a list of attributes to extract from the text."""
return self._prompt_config.get(self.ATTRIBUTES_KEY, [])
Expand Down
2 changes: 2 additions & 0 deletions src/autolabel/configs/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def populate_few_shot_selection() -> List[str]:
},
"few_shot_num": {"type": ["number", "null"]},
"chain_of_thought": {"type": ["boolean", "null"]},
"label_selection": {"type": ["boolean", "null"]},
"label_selection_count": {"type": ["number", "null"]},
"attributes": {
"anyOf": [
{"type": "array", "items": {"type": "object"}},
Expand Down
57 changes: 57 additions & 0 deletions src/autolabel/few_shot/label_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Dict, List
import bisect

from autolabel.few_shot.vector_store import cos_sim


class LabelSelector:
"""Returns the most similar labels to a given input. Used for
classification tasks with a large number of possible classes."""

labels: List[str]
"""A list of the possible labels to choose from."""

k: int = 10
"""Number of labels to select"""

embedding_func: Callable = None
"""Function used to generate embeddings of labels/input"""

labels_embeddings: Dict = {}
"""Dict used to store embeddings of each label"""

def __init__(
self, labels: List[str], embedding_func: Callable, k: int = 10
) -> None:
self.labels = labels
self.k = min(k, len(labels))
self.embedding_func = embedding_func
for l in self.labels:
self.labels_embeddings[l] = self.embedding_func.embed_query(l)

def select_labels(self, input: str) -> List[str]:
"""Select which labels to use based on the similarity to input"""
input_embedding = self.embedding_func.embed_query(input)

scores = []
for label, embedding in self.labels_embeddings.items():
similarity = cos_sim(embedding, input_embedding)
# insert into scores, while maintaining sorted order
bisect.insort(scores, (similarity, label))
return [label for (_, label) in scores[-self.k :]]

@classmethod
def from_examples(
cls,
labels: List[str],
embedding_func,
k: int = 10,
) -> LabelSelector:
"""Create pass-through label selector using given list of labels
Returns:
The LabelSelector instantiated
"""
return cls(labels=labels, k=k, embedding_func=embedding_func)
61 changes: 57 additions & 4 deletions src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
from autolabel.dataset import AutolabelDataset
from autolabel.data_models import AnnotationModel, TaskRunModel
from autolabel.database import StateManager
from autolabel.few_shot import ExampleSelectorFactory, BaseExampleSelector
from autolabel.few_shot import (
ExampleSelectorFactory,
BaseExampleSelector,
DEFAULT_EMBEDDING_PROVIDER,
PROVIDER_TO_MODEL,
)
from autolabel.few_shot.label_selector import LabelSelector
from autolabel.models import BaseModel, ModelFactory
from autolabel.metrics import BaseMetric
from autolabel.transforms import BaseTransform, TransformFactory
Expand All @@ -29,6 +35,7 @@
MetricResult,
TaskRun,
TaskStatus,
TaskType,
)
from autolabel.tasks import TaskFactory
from autolabel.utils import (
Expand Down Expand Up @@ -182,6 +189,20 @@ def run(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
if self.config.task_type() != TaskType.CLASSIFICATION:
self.console.print(
"Warning: label_selection only supported for classification tasks!"
)
else:
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
k=self.config.label_selection_count(),
)

current_index = self.task_run.current_index if self.create_task else 0
cost = 0.0
postfix_dict = {}
Expand All @@ -202,8 +223,17 @@ def run(
)
else:
examples = []
# Construct Prompt to pass to LLM
final_prompt = self.task.construct_prompt(chunk, examples)
# Construct Prompt to pass to LLM
if (
self.config.label_selection()
and self.config.task_type() == TaskType.CLASSIFICATION
):
selected_labels = self.label_selector.select_labels(chunk["example"])
final_prompt = self.task.construct_prompt(
chunk, examples, selected_labels
)
else:
final_prompt = self.task.construct_prompt(chunk, examples)

response = self.llm.label([final_prompt])
for i, generations, error in zip(
Expand Down Expand Up @@ -356,6 +386,20 @@ def plan(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
if self.config.task_type() != TaskType.CLASSIFICATION:
self.console.print(
"Warning: label_selection only supported for classification tasks!"
)
else:
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
k=self.config.label_selection_count(),
)

input_limit = min(len(dataset.inputs), 100)

for input_i in track(
Expand All @@ -370,7 +414,16 @@ def plan(
)
else:
examples = []
final_prompt = self.task.construct_prompt(input_i, examples)
if (
self.config.label_selection()
and self.config.task_type() == TaskType.CLASSIFICATION
):
selected_labels = self.label_selector.select_labels(input_i["example"])
final_prompt = self.task.construct_prompt(
input_i, examples, selected_labels
)
else:
final_prompt = self.task.construct_prompt(input_i, examples)
prompt_list.append(final_prompt)

# Calculate the number of tokens
Expand Down
9 changes: 7 additions & 2 deletions src/autolabel/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,18 @@ def __init__(self, config: AutolabelConfig) -> None:
if self.config.confidence():
self.metrics.append(AUROCMetric())

def construct_prompt(self, input: Dict, examples: List) -> str:
def construct_prompt(
self, input: Dict, examples: List, selected_labels: List[str] = None
) -> str:
# Copy over the input so that we can modify it
input = input.copy()

# prepare task guideline
labels_list = self.config.labels_list()
labels_list = (
self.config.labels_list() if not selected_labels else selected_labels
)
num_labels = len(labels_list)

fmt_task_guidelines = self.task_guidelines.format(
num_labels=num_labels, labels="\n".join(labels_list)
)
Expand Down

0 comments on commit acb8755

Please sign in to comment.