Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better support for classification tasks with large number of label classes #561

Merged
merged 9 commits into from
Sep 19, 2023
12 changes: 12 additions & 0 deletions src/autolabel/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class AutolabelConfig(BaseConfig):
OUTPUT_GUIDELINE_KEY = "output_guidelines"
OUTPUT_FORMAT_KEY = "output_format"
CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
LABEL_SELECTION_KEY = "label_selection"
LABEL_SELECTION_COUNT_KEY = "label_selection_count"
TRANSFORM_KEY = "transforms"

# Dataset generation config keys (config["dataset_generation"][<key>])
Expand Down Expand Up @@ -201,6 +203,16 @@ def chain_of_thought(self) -> bool:
"""Returns true if the model is able to perform chain of thought reasoning."""
return self._prompt_config.get(self.CHAIN_OF_THOUGHT_KEY, False)

def label_selection(self) -> bool:
    """Whether label selection is enabled for this config.

    Label selection narrows the candidate label list down to the labels most
    similar to a given input — useful for classification tasks with a large
    number of possible classes.
    """
    enabled = self._prompt_config.get(self.LABEL_SELECTION_KEY, False)
    return enabled

def label_selection_count(self) -> int:
    """Number of candidate labels the LabelSelector should keep (default 10)."""
    default_count = 10
    return self._prompt_config.get(self.LABEL_SELECTION_COUNT_KEY, default_count)

def transforms(self) -> List[Dict]:
"""Returns a list of transforms to apply to the data before sending to the model."""
return self.config.get(self.TRANSFORM_KEY, [])
Expand Down
2 changes: 2 additions & 0 deletions src/autolabel/configs/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def populate_few_shot_selection() -> List[str]:
},
"few_shot_num": {"type": ["number", "null"]},
"chain_of_thought": {"type": ["boolean", "null"]},
"label_selection": {"type": ["boolean", "null"]},
"label_selection_count": {"type": ["number", "null"]},
},
"required": ["task_guidelines"],
"additionalProperties": False,
Expand Down
59 changes: 59 additions & 0 deletions src/autolabel/few_shot/label_selector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

from typing import Dict, List
import bisect

from autolabel.few_shot.vector_store import cos_sim

from langchain.embeddings.openai import OpenAIEmbeddings


class LabelSelector:
    """Returns the most similar labels to a given input. Used for
    classification tasks with a large number of possible classes."""

    labels: List[str]
    """A list of the possible labels to choose from."""

    k: int = 10
    """Number of labels to select"""

    embedding_func: object
    """Function used to generate embeddings of labels/input"""

    labels_embeddings: Dict
    """Dict used to store embeddings of each label"""

    def __init__(
        self, labels: List[str], k: int = 10, embedding_func=None
    ) -> None:
        """Build a selector by embedding every candidate label up front.

        Args:
            labels: the full list of candidate labels.
            k: how many labels select_labels should return (clamped to
                len(labels) so we never ask for more labels than exist).
            embedding_func: object with an ``embed_query(text) -> vector``
                method; defaults to OpenAIEmbeddings.
        """
        self.labels = labels
        self.k = min(k, len(labels))
        # BUG FIX: the default used to be `embedding_func=OpenAIEmbeddings()`,
        # which constructed an OpenAI client at class-definition/import time.
        # Instantiate lazily, only when the caller did not supply one.
        self.embedding_func = (
            embedding_func if embedding_func is not None else OpenAIEmbeddings()
        )
        # BUG FIX: use a fresh per-instance dict. The previous class-level
        # `labels_embeddings: Dict = {}` default was a single dict shared by
        # every instance, so embeddings from one selector leaked into others.
        self.labels_embeddings = {}
        for label in self.labels:
            self.labels_embeddings[label] = self.embedding_func.embed_query(label)

    def select_labels(self, input: str) -> List[str]:
        """Select the k labels most similar (by cosine similarity) to input."""
        # BUG FIX: guard k == 0 explicitly — `scores[-0:]` is the whole list,
        # so the old code returned ALL labels when k was 0 (e.g. empty labels).
        if self.k <= 0:
            return []
        input_embedding = self.embedding_func.embed_query(input)

        scores = []
        for label, embedding in self.labels_embeddings.items():
            similarity = cos_sim(embedding, input_embedding)
            # insert into scores, while maintaining sorted order
            bisect.insort(scores, (similarity, label))
        # The k highest-similarity labels sit at the tail of the sorted list.
        return [label for (_, label) in scores[-self.k :]]

    @classmethod
    def from_examples(
        cls,
        labels: List[str],
        k: int = 10,
        embedding_func=None,
    ) -> LabelSelector:
        """Create pass-through label selector using given list of labels

        Returns:
            The LabelSelector instantiated
        """
        return cls(labels=labels, k=k, embedding_func=embedding_func)
44 changes: 40 additions & 4 deletions src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
from autolabel.dataset import AutolabelDataset
from autolabel.data_models import AnnotationModel, TaskRunModel
from autolabel.database import StateManager
from autolabel.few_shot import ExampleSelectorFactory, BaseExampleSelector
from autolabel.few_shot import (
ExampleSelectorFactory,
BaseExampleSelector,
DEFAULT_EMBEDDING_PROVIDER,
PROVIDER_TO_MODEL,
)
from autolabel.few_shot.label_selector import LabelSelector
from autolabel.models import BaseModel, ModelFactory
from autolabel.metrics import BaseMetric
from autolabel.transforms import BaseTransform, TransformFactory
Expand Down Expand Up @@ -165,6 +171,15 @@ def run(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
k=self.config.label_selection_count(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
)

current_index = self.task_run.current_index if self.create_task else 0
cost = 0.0
postfix_dict = {}
Expand All @@ -185,8 +200,14 @@ def run(
)
else:
examples = []
# Construct Prompt to pass to LLM
final_prompt = self.task.construct_prompt(chunk, examples)
# Construct Prompt to pass to LLM
if self.config.label_selection():
selected_labels = self.label_selector.select_labels(chunk["example"])
final_prompt = self.task.construct_prompt(
chunk, examples, selected_labels
)
else:
final_prompt = self.task.construct_prompt(chunk, examples)

response = self.llm.label([final_prompt])
for i, generations, error in zip(
Expand Down Expand Up @@ -332,6 +353,15 @@ def plan(
cache=self.generation_cache is not None,
)

if self.config.label_selection():
self.label_selector = LabelSelector.from_examples(
labels=self.config.labels_list(),
k=self.config.label_selection_count(),
embedding_func=PROVIDER_TO_MODEL.get(
self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
)(),
)

input_limit = min(len(dataset.inputs), 100)

for input_i in track(
Expand All @@ -346,7 +376,13 @@ def plan(
)
else:
examples = []
final_prompt = self.task.construct_prompt(input_i, examples)
if self.config.label_selection():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this would raise an error if label selection were set to true for any task other than classification, because `construct_prompt` has only been changed for the classification task. Is there a way to catch this, i.e., fail early with "label selection not supported for this task"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch. Will check that it is a classification task (if label_selection = true)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in 2a6ec29

selected_labels = self.label_selector.select_labels(input_i["example"])
final_prompt = self.task.construct_prompt(
input_i, examples, selected_labels
)
else:
final_prompt = self.task.construct_prompt(input_i, examples)
prompt_list.append(final_prompt)

# Calculate the number of tokens
Expand Down
14 changes: 12 additions & 2 deletions src/autolabel/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@

import json

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably don't need these imports now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed these imports in e5ff12a

from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from autolabel.few_shot.vector_store import VectorStoreWrapper


class ClassificationTask(BaseTask):
DEFAULT_OUTPUT_GUIDELINES = (
Expand Down Expand Up @@ -48,13 +53,18 @@ def __init__(self, config: AutolabelConfig) -> None:
if self.config.confidence():
self.metrics.append(AUROCMetric())

def construct_prompt(self, input: Dict, examples: List) -> str:
def construct_prompt(
self, input: Dict, examples: List, selected_labels: List[str] = None
) -> str:
# Copy over the input so that we can modify it
input = input.copy()

# prepare task guideline
labels_list = self.config.labels_list()
labels_list = (
self.config.labels_list() if not selected_labels else selected_labels
)
num_labels = len(labels_list)

fmt_task_guidelines = self.task_guidelines.format(
num_labels=num_labels, labels="\n".join(labels_list)
)
Expand Down
Loading