
Better support for classification tasks with large number of label classes #561

Merged: 9 commits into main on Sep 19, 2023

Conversation


@iomap iomap commented Sep 1, 2023

New option to filter labels_list by similarity to input example.

Two new optional fields are now present in prompt_config schema:

        "label_selection": true,
        "label_selection_count": 10

@iomap iomap changed the title Better support for classification tasks with many label classes Better support for classification tasks with large number of label classes Sep 1, 2023
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
# This is the number of examples to produce.
k=10,
Contributor:
We should have this value of k be configurable. Maybe 10 is a reasonable default, but we might want workflows in the future that automatically test for the right value of k.

Contributor (author):
Done in f97a82d.

You can now specify the value of k in the autolabel config file like so:

        "label_selection": true,
        "label_selection_count": 10

@rajasbansal rajasbansal left a comment

Did we benchmark this on banking/ledgar? I hope there isn't a big drop in performance using this approach.

# This is the list of labels available to select from.
label_examples,
# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
OpenAIEmbeddings(),
Contributor:

Can we use the same embedding model as the one used for the seed examples? This can be read from the config.

@nihit nihit Sep 5, 2023

+1 to read this from the embedding model section in Autolabel config

Contributor (author):

Done in the latest commit.

It now chooses the embedding function based on config.embedding_provider():

            self.label_selector = LabelSelector.from_examples(
                labels=self.config.labels_list(),
                k=self.config.label_selection_count(),
                embedding_func=PROVIDER_TO_MODEL.get(
                    self.config.embedding_provider(), DEFAULT_EMBEDDING_PROVIDER
                )(),
            )
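The lookup above resolves the configured provider against a registry and falls back to a default when the provider is unknown. A minimal sketch of that pattern follows; the registry contents, class names, and the `embedding_class_for` helper are illustrative stand-ins, not the actual `PROVIDER_TO_MODEL` mapping in the codebase:

```python
# Illustrative stand-ins; the real code imports embedding classes
# from langchain / autolabel.
class OpenAIEmbeddings:
    pass

class HuggingFaceEmbeddings:
    pass

# Hypothetical registry mapping provider names to embedding classes.
PROVIDER_TO_MODEL = {
    "openai": OpenAIEmbeddings,
    "huggingface_pipeline": HuggingFaceEmbeddings,
}
DEFAULT_EMBEDDING_CLASS = OpenAIEmbeddings

def embedding_class_for(provider: str):
    # Unknown providers fall back to the default instead of raising a KeyError.
    return PROVIDER_TO_MODEL.get(provider, DEFAULT_EMBEDDING_CLASS)
```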

similar_prompt = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt,
prefix="Input: {example}",
Contributor:

use the example template from the config here? see how the seed examples are prepared

Contributor (author):

This was unnecessary and has been removed in the latest commit.

)
label_examples = [{"input": label} for label in labels_list]

example_selector = SemanticSimilarityExampleSelector.from_examples(
Contributor:

Instead of creating this each time, is it possible to construct this example selector once and then just call sample labels each time? This would make sure that we embed the label list only once.

Contributor (author):

Good idea.

I now construct the LabelSelector once in agent.run() and agent.plan(). This way embeddings of labels are only computed once.

# if large number of labels, filter labels_list by similarity of labels to input
if num_labels >= 50:
example_prompt = PromptTemplate(
input_variables=["input"],
Contributor:

nit: call this label?

Contributor (author):

removed in latest commit

# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
OpenAIEmbeddings(),
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
Contributor:

Went through the code. Ideally we don't need to set this cache to False and can use the cache setting from the config, but if not, this would still be fine.

@@ -55,6 +60,40 @@ def construct_prompt(self, input: Dict, examples: List) -> str:
# prepare task guideline
labels_list = self.config.labels_list()
num_labels = len(labels_list)

# if large number of labels, filter labels_list by similarity of labels to input
if num_labels >= 50:
Contributor:

Can we do this based on some config setting? Just want to make sure that we have the ability to turn this on or off. We can fall back to num_labels >= 50 if the corresponding config setting is not set.

@nihit nihit Sep 5, 2023

Agree with previous comments. We should enable this "label selection" from a config parameter, not a hardcoded num_labels threshold.

Contributor (author):

Done in f97a82d.

You can now turn label selection on/off in the config (as well as set the number of labels to select), like so:

        "label_selection": true,
        "label_selection_count": 10

# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
# This is the number of examples to produce.
k=10,
Contributor:

Let's make this configurable.

Contributor (author):

done in latest commit.

@iomap

iomap commented Sep 1, 2023

> Did we benchmark this on banking/ledgar? I hope there isn't a big drop in performance using this approach.

My initial test on Ledgar:
Previous accuracy: 71% (from Jupyter notebook in autolabel/examples/ledgar)
Accuracy with change (on 100 samples): 68%

Roughly the same, but need to continue testing. Will try out a full run on banking and ledgar datasets.

Comment on lines 90 to 94
split_lines = sampled_labels.split("\n")
labels_list = []
for i in range(1, len(split_lines)):
if split_lines[i]:
labels_list.append(split_lines[i])
Contributor (author):

This might break for input examples that contain newline characters (\n).

Maybe I should check that split_lines[i] is in labels_list before appending.

Contributor:

This should not be needed once the implementation is revamped

Contributor (author):

correct, this has been removed.

@nihit nihit left a comment

I suggest abstracting all this logic away in a "Label Selector" class, similar to what the Example Selectors do.

The Label Selector would be initialized once in the run/plan agent method by reading appropriate fields from the config (again, very similar to example selector)

The agent can then call this object's "select labels" function when labeling each example to get a list of K most likely labels, like https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/labeler.py#L183

And pass it to the task object (like https://github.com/refuel-ai/autolabel/blob/main/src/autolabel/labeler.py#L189)

Comment on lines 70 to 88
label_examples = [{"input": label} for label in labels_list]

example_selector = SemanticSimilarityExampleSelector.from_examples(
# This is the list of labels available to select from.
label_examples,
# This is the embedding class used to produce embeddings which are used to measure semantic similarity.
OpenAIEmbeddings(),
# This is the VectorStore class that is used to store the embeddings and do a similarity search over.
VectorStoreWrapper(cache=False),
# This is the number of examples to produce.
k=10,
)
similar_prompt = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt,
prefix="Input: {example}",
suffix="",
input_variables=["example"],
)
Contributor:

please revamp this implementation.

  1. No need to go via FewShotExampleTemplate and semantic similarity example selector.
  2. The label selection should conceptually consist of 3 steps: (i) input row --> formatted example (ii) compute embedding of the formatted example (iii) find nearest neighbors from among the label list
  3. The embeddings for labels in the label list should be computed just once, not once per row
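The three steps above can be sketched roughly as follows. This is an illustrative reimplementation, not the code that was merged: the class name mirrors the LabelSelector discussed elsewhere in this thread, but the method signatures and the plain-Python cosine-similarity helper are assumptions.

```python
from typing import Callable, List


def cos_sim(a: List[float], b: List[float]) -> float:
    # Plain cosine similarity; returns 0.0 for zero-norm vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm = (sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5)
    return dot / norm if norm else 0.0


class LabelSelector:
    """Selects the k labels most semantically similar to an input example."""

    def __init__(self, labels: List[str], k: int,
                 embedding_func: Callable[[List[str]], List[List[float]]]):
        self.labels = labels
        self.k = min(k, len(labels))
        self.embedding_func = embedding_func
        # Embed the full label list exactly once, at construction time,
        # not once per row (step 3's prerequisite).
        self.label_embeddings = embedding_func(labels)

    @classmethod
    def from_examples(cls, labels, k, embedding_func):
        return cls(labels, k, embedding_func)

    def select_labels(self, formatted_example: str) -> List[str]:
        # Steps (i)-(ii): the caller formats the input row; we embed it here.
        query = self.embedding_func([formatted_example])[0]
        # Step (iii): rank labels by cosine similarity and keep the top k.
        scored = sorted(
            zip(self.labels, (cos_sim(query, e) for e in self.label_embeddings)),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [label for label, _ in scored[: self.k]]
```

With a real embedding model plugged in as embedding_func, the agent can construct this once in plan()/run() and call select_labels() per row, which gives exactly the embed-the-label-list-once behavior requested here.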

Contributor (author):

I have revamped the implementation.

The FewShotExampleTemplate and semantic similarity example selector have been removed entirely.

Embeddings for labels in the label list are computed only once, in agent.plan() and agent.run().

@iomap

iomap commented Sep 7, 2023

Worth noting that after refactoring this PR, I am noticing a slight (~5%) impact on labeling accuracy on the Ledgar dataset.

Results with label_selection = true, k = 10

┏━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ accuracy ┃ support ┃ completion_rate ┃
┡━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ 0.6465 │ 100 │ 0.99 │
└──────────┴─────────┴─────────────────┘

Results with label_selection = false

┏━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ accuracy ┃ support ┃ completion_rate ┃
┡━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ 0.7071 │ 100 │ 0.99 │
└──────────┴─────────┴─────────────────┘

Prior to the revamp, accuracy was about equal in both cases.

Perhaps our cos_sim() function isn't quite as good as the LangChain similarity selector I was using previously? That, or the embedding function is configured differently. I have noticed that embedding generation takes longer now than it did before.

@iomap

iomap commented Sep 7, 2023

How I tested:

import os
# provide your own OpenAI API key here
os.environ['OPENAI_API_KEY'] = '...'

from autolabel import get_data

get_data('ledgar')

import json

from autolabel import LabelingAgent

# load the config
with open('config_ledgar.json', 'r') as f:
    config = json.load(f)

# create an agent for labeling
agent = LabelingAgent(config=config)

from autolabel import AutolabelDataset
ds = AutolabelDataset("test.csv", config=config)
agent.plan(ds)

# now, do the actual labeling
ds = agent.run(ds, output_name="output_test.csv", max_items=100)

config_ledgar.json

{
    "task_name": "LegalProvisionsClassification",
    "task_type": "classification",
    "dataset": {
        "label_column": "label",
        "delimiter": ","
    },
    "model": {
        "provider": "openai",
        "name": "gpt-3.5-turbo"
    },
    "prompt": {
        "task_guidelines": "You are an expert at understanding legal contracts. Your job is to correctly classify legal provisions in contracts into one of the following categories.\nCategories:{labels}\n",
        "labels": [
            "Agreements",
            "Amendments",
            "Adjustments",
            "Anti-Corruption Laws",
            "Applicable Laws",
            "Approvals",
            "Arbitration",
            "Assignments",
            "Assigns",
            "Authority",
            "Authorizations",
            "Base Salary",
            "Benefits",
            "Binding Effects",
            "Books",
            "Brokers",
            "Capitalization",
            "Change In Control",
            "Closings",
            "Compliance With Laws",
            "Confidentiality",
            "Consent To Jurisdiction",
            "Consents",
            "Construction",
            "Cooperation",
            "Costs",
            "Counterparts",
            "Death",
            "Defined Terms",
            "Definitions",
            "Disability",
            "Disclosures",
            "Duties",
            "Effective Dates",
            "Effectiveness",
            "Employment",
            "Enforceability",
            "Enforcements",
            "Entire Agreements",
            "Erisa",
            "Existence",
            "Expenses",
            "Fees",
            "Financial Statements",
            "Forfeitures",
            "Further Assurances",
            "General",
            "Governing Laws",
            "Headings",
            "Indemnifications",
            "Indemnity",
            "Insurances",
            "Integration",
            "Intellectual Property",
            "Interests",
            "Interpretations",
            "Jurisdictions",
            "Liens",
            "Litigations",
            "Miscellaneous",
            "Modifications",
            "No Conflicts",
            "No Defaults",
            "No Waivers",
            "Non-Disparagement",
            "Notices",
            "Organizations",
            "Participations",
            "Payments",
            "Positions",
            "Powers",
            "Publicity",
            "Qualifications",
            "Records",
            "Releases",
            "Remedies",
            "Representations",
            "Sales",
            "Sanctions",
            "Severability",
            "Solvency",
            "Specific Performance",
            "Submission To Jurisdiction",
            "Subsidiaries",
            "Successors",
            "Survival",
            "Tax Withholdings",
            "Taxes",
            "Terminations",
            "Terms",
            "Titles",
            "Transactions With Affiliates",
            "Use Of Proceeds",
            "Vacations",
            "Venues",
            "Vesting",
            "Waiver Of Jury Trials",
            "Waivers",
            "Warranties",
            "Withholdings"
        ],
        "example_template": "Example: {example}\nOutput: {label}",
        "few_shot_examples": "seed.csv",
        "few_shot_selection": "semantic_similarity",
        "few_shot_num": 4,
        "label_selection": true,
        "label_selection_count": 10
    }
}

@rajasbansal rajasbansal left a comment

lgtm!

@@ -346,7 +376,13 @@ def plan(
)
else:
examples = []
final_prompt = self.task.construct_prompt(input_i, examples)
if self.config.label_selection():
Contributor:

nit: This would give an error if label selection was set to true for any task other than classification, because construct_prompt has been changed only for the classification task. Any way to catch this, i.e., "label selection not supported for this task"?

Contributor (author):

Good catch. Will check that it is a classification task (if label_selection = true).
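A minimal sketch of such a guard follows. The function name, the error type, and where the check is invoked are all illustrative assumptions; the merged commit (2a6ec29) may implement the check differently:

```python
# Tasks whose construct_prompt can consume a filtered label list.
# As of this PR, that is only the classification task.
SUPPORTED_LABEL_SELECTION_TASKS = {"classification"}


def validate_label_selection(task_type: str, label_selection: bool) -> None:
    # Fail fast at config-validation time instead of erroring mid-run.
    if label_selection and task_type not in SUPPORTED_LABEL_SELECTION_TASKS:
        raise ValueError(
            f"label_selection is not supported for task type '{task_type}'"
        )
```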

Contributor (author):

done in 2a6ec29

@@ -21,6 +21,11 @@

import json

from langchain.prompts.example_selector import SemanticSimilarityExampleSelector
Contributor:

probably don't need these imports now?

Contributor (author):

removed these imports in e5ff12a

@nihit

nihit commented Sep 8, 2023

@iomap to follow up with any updates to documentation.

@iomap iomap merged commit acb8755 into main Sep 19, 2023
2 checks passed
@iomap iomap deleted the many_classes_support branch September 19, 2023 20:35