Better support for classification tasks with large number of label classes #561
Commits in this pull request: a070422, e3c2126, 3cfca54, f97a82d, 16f769b, 547e811, 2a6ec29, e5ff12a, 921c096
```diff
@@ -21,6 +21,11 @@
 import json

+from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.prompts import FewShotPromptTemplate, PromptTemplate
+from autolabel.few_shot.vector_store import VectorStoreWrapper
+

 class ClassificationTask(BaseTask):
     DEFAULT_OUTPUT_GUIDELINES = (
```
```diff
@@ -55,6 +60,40 @@ def construct_prompt(self, input: Dict, examples: List) -> str:
         # prepare task guideline
         labels_list = self.config.labels_list()
         num_labels = len(labels_list)

+        # if large number of labels, filter labels_list by similarity of labels to input
+        if num_labels >= 50:
```
**Review comment:** Can we do this based on some config setting? We just want to make sure that we have the ability to turn this on or off; we can fall back to `num_labels > 50` if the corresponding config setting is not set.

**Review comment:** Agree with the previous comment. We should enable this "label selection" from a config parameter, not a hardcoded `num_labels` threshold.

**Author reply:** Done in f97a82d. Label selection can now be turned on and off in the config, along with the number of labels to select.
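The config snippet from that reply was not captured in this page. As an illustration only, a fragment of this shape could enable the feature; the `label_selection` and `label_selection_count` key names are assumptions, not taken from this page:

```python
# Hypothetical Autolabel config fragment; key names are assumptions.
config = {
    "prompt": {
        "label_selection": True,       # turn label selection on or off
        "label_selection_count": 10,   # number of candidate labels to keep
    },
}
```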
```diff
+            example_prompt = PromptTemplate(
+                input_variables=["input"],
```
**Review comment:** Nit: call this `label`?

**Author reply:** Removed in the latest commit.
```diff
+                template="{input}",
+            )
+            label_examples = [{"input": label} for label in labels_list]
+
+            example_selector = SemanticSimilarityExampleSelector.from_examples(
```
**Review comment:** Instead of creating this each time, is it possible to construct this example selector once and then just call it to sample labels each time? This would make sure that we embed the label list only once.

**Author reply:** Good idea. I now construct the `LabelSelector` once in `agent.run()` and `agent.plan()`. This way, embeddings of labels are only computed once.
```diff
+                # This is the list of labels available to select from.
+                label_examples,
+                # This is the embedding class used to produce embeddings which are used to measure semantic similarity.
+                OpenAIEmbeddings(),
```
**Review comment:** Can we use the same embedding model as the one used for the seed examples? This can be read from the config.

**Review comment:** +1 to read this from the embedding model section in the Autolabel config.

**Author reply:** Done in the latest commit. It now chooses the embedding function based on the embedding model section of the config.
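Choosing the embedding class from the config's embedding section can be sketched as below. The provider names, the `embeddings_from_config` helper, and the stand-in factory functions are all hypothetical; real code would map providers to actual embedding classes such as langchain's `OpenAIEmbeddings`:

```python
from typing import Any, Callable, Dict

# Stand-in factories so the sketch stays self-contained; real code
# would return embedding objects (e.g. OpenAIEmbeddings()).
def _openai_embeddings() -> str:
    return "openai-embeddings"

def _huggingface_embeddings() -> str:
    return "huggingface-embeddings"

# Hypothetical provider registry; the names are assumptions.
PROVIDER_TO_EMBEDDINGS: Dict[str, Callable[[], Any]] = {
    "openai": _openai_embeddings,
    "huggingface_pipeline": _huggingface_embeddings,
}

def embeddings_from_config(config: Dict[str, Any]) -> Any:
    """Read the provider from the config's embedding section and build
    the matching embedding object, defaulting to OpenAI."""
    provider = config.get("embedding", {}).get("provider", "openai")
    return PROVIDER_TO_EMBEDDINGS[provider]()
```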
```diff
+                # This is the VectorStore class that is used to store the embeddings and do a similarity search over.
+                VectorStoreWrapper(cache=False),
```
**Review comment:** Went through the code. Ideally we don't need to set this cache to `False` and can use the cache setting from the config, but if not, this would still be fine.
```diff
+                # This is the number of examples to produce.
+                k=10,
```
**Review comment:** We should have this value of `k` come from the config.

**Author reply:** Done in f97a82d. The value of `k` can now be specified in the Autolabel config file.

**Review comment:** Let's make this configurable.

**Author reply:** Done in the latest commit.
```diff
+            )
+            similar_prompt = FewShotPromptTemplate(
+                example_selector=example_selector,
+                example_prompt=example_prompt,
+                prefix="Input: {example}",
```
**Review comment:** Use the example template from the config here? See how the seed examples are prepared.

**Author reply:** This was unnecessary and has been removed in the latest commit.
```diff
+                suffix="",
+                input_variables=["example"],
+            )
```
**Review comment:** Please revamp this implementation.

**Author reply:** I have revamped the implementation. The `FewShotPromptTemplate` and the semantic-similarity example selector have been removed entirely. Embeddings for the labels in the label list are computed only once, in `agent.plan()` and `agent.run()`.
```diff
+            sampled_labels = similar_prompt.format(example=input["example"])
+            split_lines = sampled_labels.split("\n")
+            labels_list = []
+            for i in range(1, len(split_lines)):
+                if split_lines[i]:
+                    labels_list.append(split_lines[i])
```
**Review comment:** This might break for input examples that contain newline characters. Maybe I do a check that […]

**Review comment:** This should not be needed once the implementation is revamped.

**Author reply:** Correct, this has been removed.
```diff
+            num_labels = len(labels_list)

         fmt_task_guidelines = self.task_guidelines.format(
             num_labels=num_labels, labels="\n".join(labels_list)
         )
```
**Review comment:** Probably don't need these imports now?

**Author reply:** Removed these imports in e5ff12a.