Better support for classification tasks with large number of label classes #561

Merged on Sep 19, 2023 (9 commits)
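In brief: when a classification task has hundreds or thousands of possible classes, prompting with the full label list is expensive and noisy. This PR adds an optional label-selection step: every label is embedded once up front, and then, per input, only the `label_selection_count` labels most similar to the input are kept before the prompt is built. A hedged sketch of how the new keys slot into a config follows; only `label_selection` and `label_selection_count` are added by this PR, the surrounding keys are illustrative of a typical autolabel config rather than taken from this diff.

# Hedged config sketch: only the two "new:" keys below come from this PR;
# all other keys and values are illustrative assumptions.
config = {
    "task_name": "TicketClassification",
    "task_type": "classification",
    "model": {"provider": "openai", "name": "gpt-3.5-turbo"},
    "prompt": {
        "task_guidelines": "Pick one of the following {num_labels} labels:\n{labels}",
        "labels": ["billing", "refunds", "shipping"],  # a large list in practice
        "label_selection": True,        # new: enable similarity-based narrowing
        "label_selection_count": 10,    # new: how many labels to keep (default 10)
    },
}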
12 changes: 12 additions & 0 deletions src/autolabel/configs/config.py
@@ -47,6 +47,8 @@ class AutolabelConfig(BaseConfig):
OUTPUT_GUIDELINE_KEY = "output_guidelines"
OUTPUT_FORMAT_KEY = "output_format"
CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
LABEL_SELECTION_KEY = "label_selection"
LABEL_SELECTION_COUNT_KEY = "label_selection_count"
TRANSFORM_KEY = "transforms"

# Dataset generation config keys (config["dataset_generation"][<key>])
@@ -201,6 +203,16 @@ def chain_of_thought(self) -> bool:
"""Returns true if the model is able to perform chain of thought reasoning."""
return self._prompt_config.get(self.CHAIN_OF_THOUGHT_KEY, False)

def label_selection(self) -> bool:
"""Returns true if label selection is enabled. Label selection is the process of
narrowing down the list of possible labels by similarity to a given input. Useful for
classification tasks with a large number of possible classes."""
return self._prompt_config.get(self.LABEL_SELECTION_KEY, False)

def label_selection_count(self) -> int:
"""Returns the number of labels to select in LabelSelector"""
return self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, 10)

def transforms(self) -> List[Dict]:
"""Returns a list of transforms to apply to the data before sending to the model."""
return self.config.get(self.TRANSFORM_KEY, [])
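A quick illustration of the defaults these accessors fall back to (a minimal sketch; passing an inline dict to AutolabelConfig is assumed here, and the non-prompt keys are illustrative):

from autolabel.configs import AutolabelConfig  # import path per this repo layout

# Sketch: the prompt section omits both new keys, so the defaults apply
config = AutolabelConfig({
    "task_name": "Demo",                     # illustrative values
    "task_type": "classification",
    "model": {"provider": "openai", "name": "gpt-3.5-turbo"},
    "prompt": {"task_guidelines": "..."},    # only the required prompt key
})
assert config.label_selection() is False     # the feature is opt-in
assert config.label_selection_count() == 10  # default pool size of 10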
2 changes: 2 additions & 0 deletions src/autolabel/configs/schema.py
@@ -120,6 +120,8 @@ def populate_few_shot_selection() -> List[str]:
},
"few_shot_num": {"type": ["number", "null"]},
"chain_of_thought": {"type": ["boolean", "null"]},
"label_selection": {"type": ["boolean", "null"]},
"label_selection_count": {"type": ["number", "null"]},
},
"required": ["task_guidelines"],
"additionalProperties": False,
57 changes: 57 additions & 0 deletions src/autolabel/few_shot/label_selector.py
@@ -0,0 +1,57 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Dict, List
import bisect

from autolabel.few_shot.vector_store import cos_sim


class LabelSelector:
"""Returns the most similar labels to a given input. Used for
classification tasks with a large number of possible classes."""

labels: List[str]
"""A list of the possible labels to choose from."""

k: int = 10
"""Number of labels to select"""

    embedding_func: Callable = None
    """Embedding model used to embed labels and inputs; must expose an
    embed_query method"""

    labels_embeddings: Dict
    """Per-instance dict caching the embedding of each label"""

    def __init__(
        self, labels: List[str], embedding_func: Callable, k: int = 10
    ) -> None:
        self.labels = labels
        self.k = min(k, len(labels))
        self.embedding_func = embedding_func
        # Build the cache per instance: a class-level `= {}` default would
        # be shared (and mutated) across every LabelSelector instance
        self.labels_embeddings: Dict = {}
        for label in self.labels:
            self.labels_embeddings[label] = self.embedding_func.embed_query(label)

def select_labels(self, input: str) -> List[str]:
"""Select which labels to use based on the similarity to input"""
input_embedding = self.embedding_func.embed_query(input)

scores = []
for label, embedding in self.labels_embeddings.items():
similarity = cos_sim(embedding, input_embedding)
# insert into scores, while maintaining sorted order
bisect.insort(scores, (similarity, label))
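        # scores stays in ascending similarity order thanks to insort, so
        # the last k entries are the k labels most similar to the input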
return [label for (_, label) in scores[-self.k :]]

@classmethod
def from_examples(
cls,
labels: List[str],
embedding_func,
k: int = 10,
) -> LabelSelector:
"""Create pass-through label selector using given list of labels

Returns:
The LabelSelector instantiated
"""
return cls(labels=labels, k=k, embedding_func=embedding_func)
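To make the selection behavior concrete, here is a small usage sketch. `TinyEmbeddings` is a hypothetical stand-in for a real embedding provider; LabelSelector only requires an object exposing embed_query:

# Hypothetical embedding stub: LabelSelector only calls embed_query(text)
class TinyEmbeddings:
    def embed_query(self, text: str) -> list:
        # toy 2-d "embedding" based on keyword counts
        return [float(text.count("sport")), float(text.count("finance"))]

selector = LabelSelector.from_examples(
    labels=["sports news", "finance news", "sports finance gossip"],
    embedding_func=TinyEmbeddings(),
    k=2,
)
# Returns the 2 labels most cosine-similar to the input, in ascending
# order of similarity (the least similar of the two comes first)
print(selector.select_labels("quarterly finance report"))
# expected: ['sports finance gossip', 'finance news']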
61 changes: 57 additions & 4 deletions src/autolabel/labeler.py
@@ -16,7 +16,13 @@
from autolabel.dataset import AutolabelDataset
from autolabel.data_models import AnnotationModel, TaskRunModel
from autolabel.database import StateManager
from autolabel.few_shot import ExampleSelectorFactory, BaseExampleSelector
from autolabel.few_shot import (
ExampleSelectorFactory,
BaseExampleSelector,
DEFAULT_EMBEDDING_PROVIDER,
PROVIDER_TO_MODEL,
)
from autolabel.few_shot.label_selector import LabelSelector
from autolabel.models import BaseModel, ModelFactory
from autolabel.metrics import BaseMetric
from autolabel.transforms import BaseTransform, TransformFactory
@@ -25,6 +31,7 @@
MetricResult,
TaskRun,
TaskStatus,
TaskType,
)
from autolabel.tasks import TaskFactory
from autolabel.utils import (
@@ -165,6 +172,20 @@ def run(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
if self.config.task_type() != TaskType.CLASSIFICATION:
self.console.print(
"Warning: label_selection only supported for classification tasks!"
)
else:
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
k=self.config.label_selection_count(),
)

current_index = self.task_run.current_index if self.create_task else 0
cost = 0.0
postfix_dict = {}
@@ -185,8 +206,17 @@
)
else:
examples = []
-            # Construct Prompt to pass to LLM
-            final_prompt = self.task.construct_prompt(chunk, examples)
+            # Construct Prompt to pass to LLM
+            if (
+                self.config.label_selection()
+                and self.config.task_type() == TaskType.CLASSIFICATION
+            ):
+                selected_labels = self.label_selector.select_labels(chunk["example"])
+                final_prompt = self.task.construct_prompt(
+                    chunk, examples, selected_labels
+                )
+            else:
+                final_prompt = self.task.construct_prompt(chunk, examples)

response = self.llm.label([final_prompt])
for i, generations, error in zip(
@@ -332,6 +362,20 @@ def plan(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
if self.config.task_type() != TaskType.CLASSIFICATION:
self.console.print(
"Warning: label_selection only supported for classification tasks!"
)
else:
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
k=self.config.label_selection_count(),
)

input_limit = min(len(dataset.inputs), 100)

for input_i in track(
@@ -346,7 +390,16 @@
)
else:
examples = []
-            final_prompt = self.task.construct_prompt(input_i, examples)
+            if (
+                self.config.label_selection()
+                and self.config.task_type() == TaskType.CLASSIFICATION
+            ):
+                selected_labels = self.label_selector.select_labels(input_i["example"])
+                final_prompt = self.task.construct_prompt(
+                    input_i, examples, selected_labels
+                )
+            else:
+                final_prompt = self.task.construct_prompt(input_i, examples)
prompt_list.append(final_prompt)

# Calculate the number of tokens
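From the caller's side, nothing changes beyond the config: both run() and plan() build the selector up front and narrow the label list per input. A hedged end-to-end sketch, assuming the standard autolabel entry points; the file name and its contents are illustrative:

from autolabel import LabelingAgent, AutolabelDataset  # assumed public API

agent = LabelingAgent(config)                      # config with label_selection on
dataset = AutolabelDataset("tickets.csv", config)  # hypothetical dataset file
agent.plan(dataset)  # dry run: prompts carry only the k selected labels
agent.run(dataset)   # labels are narrowed per chunk before prompting

Note that both run() and plan() read the text to embed from the chunk's "example" field, so datasets are expected to expose that column for selection to work.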
9 changes: 7 additions & 2 deletions src/autolabel/tasks/classification.py
@@ -48,13 +48,18 @@ def __init__(self, config: AutolabelConfig) -> None:
if self.config.confidence():
self.metrics.append(AUROCMetric())

-    def construct_prompt(self, input: Dict, examples: List) -> str:
+    def construct_prompt(
+        self, input: Dict, examples: List, selected_labels: List[str] = None
+    ) -> str:
# Copy over the input so that we can modify it
input = input.copy()

# prepare task guideline
-        labels_list = self.config.labels_list()
+        labels_list = (
+            self.config.labels_list() if not selected_labels else selected_labels
+        )
num_labels = len(labels_list)

fmt_task_guidelines = self.task_guidelines.format(
num_labels=num_labels, labels="\n".join(labels_list)
)
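The effect of the new selected_labels parameter is easiest to see in isolation (a minimal sketch outside the class; the template simply mirrors the {num_labels}/{labels} placeholders the task formats above):

# Sketch of the narrowing logic added to construct_prompt
task_guidelines = "Choose one of the following {num_labels} labels:\n{labels}"
all_labels = ["billing", "refunds", "shipping", "returns", "warranty"]
selected_labels = ["billing", "refunds"]  # e.g. from LabelSelector

labels_list = all_labels if not selected_labels else selected_labels
print(task_guidelines.format(num_labels=len(labels_list), labels="\n".join(labels_list)))
# The prompt now offers 2 candidate labels instead of all 5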