
Better support for classification tasks with large number of label classes #561

Merged (9 commits, Sep 19, 2023)
src/autolabel/tasks/classification.py (39 additions, 0 deletions)
@@ -21,6 +21,11 @@

import json

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
Contributor: probably don't need these imports now?

Contributor Author: removed these imports in e5ff12a

from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from autolabel.few_shot.vector_store import VectorStoreWrapper


class ClassificationTask(BaseTask):
    DEFAULT_OUTPUT_GUIDELINES = (
@@ -55,6 +60,40 @@ def construct_prompt(self, input: Dict, examples: List) -> str:
        # prepare task guideline
        labels_list = self.config.labels_list()
        num_labels = len(labels_list)

        # if large number of labels, filter labels_list by similarity of labels to input
        if num_labels >= 50:
Contributor: Can we do this based on a config setting? Just want to make sure that we have the ability to turn this on or off. We can fall back to num_labels > 50 if the corresponding config setting is not set.

Contributor (@nihit, Sep 5, 2023): Agree with previous comments. We should enable this "label selection" from a config parameter, not a hardcoded num_labels threshold.

Contributor Author: done in f97a82d

Label selection (and the number of labels to select) can now be turned on/off in the config, like so:
        "label_selection": true,
        "label_selection_count": 10

            example_prompt = PromptTemplate(
                input_variables=["input"],
Contributor: nit: call this label?

Contributor Author: removed in latest commit

template="{input}",
)
label_examples = [{"input": label} for label in labels_list]

            example_selector = SemanticSimilarityExampleSelector.from_examples(
Contributor: Instead of creating this each time, is it possible to construct this example selector once and then just call it to sample labels each time? That would make sure that we embed the label list only once.

Contributor Author: Good idea.

I now construct the LabelSelector once in agent.run() and agent.plan(). This way embeddings of labels are only computed once.
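The "construct once, embed once" idea can be sketched with a small memoized helper. Here embed() is a stand-in that counts calls to a fake provider, not a real embedding model; the memoization pattern, not the names, is the point.

```python
# Minimal sketch of embedding the label list a single time and reusing the
# result for every row; embed() is a fake provider that counts its calls.
from functools import lru_cache

provider_calls = []

def embed(texts):
    provider_calls.append(tuple(texts))  # record each "provider" round trip
    return [[float(len(t)), 1.0] for t in texts]

@lru_cache(maxsize=1)
def label_embeddings(labels: tuple):
    # Labels are passed as a tuple so the result can be cached once per run.
    return tuple(tuple(vec) for vec in embed(list(labels)))

labels = ("refund", "shipping", "billing")
for _row in range(3):  # three rows processed, but only one embed call
    cached = label_embeddings(labels)
```

Constructing the selector once in agent.plan()/agent.run() achieves the same effect: the expensive embedding step happens outside the per-row loop.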

                # This is the list of labels available to select from.
                label_examples,
                # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
                OpenAIEmbeddings(),
Contributor: Can we use the same embedding model as the one used for the seed examples? This can be read from the config.

Contributor (@nihit, Sep 5, 2023): +1 to reading this from the embedding model section in the Autolabel config

Contributor Author: done in latest commit.

It now chooses the embedding function based on config.embedding_provider():

            self.label_selector = LabelSelector.from_examples(
                labels=self.config.labels_list(),
                k=self.config.label_selection_count(),
                embedding_func=PROVIDER_TO_MODEL.get(
                    self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
                )(),
            )

                # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
                VectorStoreWrapper(cache=False),
Contributor: Went through the code; ideally we don't need to set this cache to False and can use the cache setting from the config, but if not, this would still be fine.

                # This is the number of examples to produce.
                k=10,
Contributor: We should make this value of k configurable. Maybe 10 is a reasonable default, but we might want workflows in the future that automatically test for the right value of k.

Contributor Author: done in f97a82d

The value of k can now be specified in the autolabel config file like so:

        "label_selection": true,
        "label_selection_count": 10

Contributor: let's make this configurable

Contributor Author: done in latest commit.

            )
            similar_prompt = FewShotPromptTemplate(
                example_selector=example_selector,
                example_prompt=example_prompt,
                prefix="Input: {example}",
Contributor: Use the example template from the config here? See how the seed examples are prepared.

Contributor Author: This was unnecessary and has been removed in the latest commit.

suffix="",
input_variables=["example"],
)
Contributor: Please revamp this implementation.

  1. No need to go via FewShotExampleTemplate and the semantic similarity example selector.
  2. The label selection should conceptually consist of 3 steps: (i) input row --> formatted example, (ii) compute the embedding of the formatted example, (iii) find nearest neighbors from among the label list.
  3. The embeddings for labels in the label list should be computed just once, not once per row.

Contributor Author: I have revamped the implementation.

The FewShotExampleTemplate and semantic similarity example selector have been removed entirely.

Embeddings for the labels in the label list are computed only once, in agent.plan() and agent.run().
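The three-step design discussed above could look roughly like the sketch below. The class and method names echo this thread's LabelSelector, but the body is an illustrative assumption (pure-Python cosine similarity), not the actual autolabel code.

```python
# Hypothetical sketch of the three-step label selection described above:
# (i) the caller formats the input row, (ii) we embed the formatted example,
# (iii) we keep the k nearest labels. Not the real autolabel implementation.
from math import sqrt
from typing import Callable, List


def _cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = sqrt(sum(x * x for x in a))
    nb = sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


class LabelSelectorSketch:
    def __init__(self, labels, embeddings, k, embedding_func):
        self.labels = labels
        self.label_embeddings = embeddings  # computed once, reused per row
        self.k = k
        self.embedding_func = embedding_func

    @classmethod
    def from_examples(cls, labels, k, embedding_func):
        # Embed the full label list exactly once, up front.
        return cls(labels, embedding_func(labels), k, embedding_func)

    def select_labels(self, formatted_example: str):
        # Embed the already-formatted row, then rank labels by similarity.
        query = self.embedding_func([formatted_example])[0]
        scored = sorted(
            zip(self.labels, self.label_embeddings),
            key=lambda pair: _cosine(query, pair[1]),
            reverse=True,
        )
        return [label for label, _ in scored[: self.k]]
```

A toy usage: with a lookup-table embedding function, `LabelSelectorSketch.from_examples(["cat", "dog"], k=1, embedding_func=emb).select_labels("a furry feline")` returns the label whose embedding is closest to the query.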

            sampled_labels = similar_prompt.format(example=input["example"])
            split_lines = sampled_labels.split("\n")
            labels_list = []
            for i in range(1, len(split_lines)):
                if split_lines[i]:
                    labels_list.append(split_lines[i])
Contributor Author: This might break for input examples that contain newline characters (\n).

Maybe I should check that split_lines[i] is in labels_list before appending.

Contributor: This should not be needed once the implementation is revamped.

Contributor Author: Correct, this has been removed.

        num_labels = len(labels_list)

        fmt_task_guidelines = self.task_guidelines.format(
            num_labels=num_labels, labels="\n".join(labels_list)
        )