Add an attribute extraction task (#566)
* started on the attribute extraction task

* added metrics for each attribute

* updated attribute extraction task with method

* removed attribute extraction f1 score

* added systematic approach to attribute metrics

* deleted debug print

* Add attribute extraction task

* Allow dataset exploration with attribute extraction

* Fix the case where there is no label column

---------

Co-authored-by: Tyler <[email protected]>
Co-authored-by: Rajas Bansal <[email protected]>
3 people authored Sep 10, 2023
1 parent a32d3f7 commit 328b8ac
Showing 14 changed files with 1,276 additions and 29 deletions.
42 changes: 42 additions & 0 deletions examples/ethos/config_ethos.json
@@ -0,0 +1,42 @@
{
"task_name": "EthosAttributeExtraction",
"task_type": "attribute_extraction",
"dataset": {
"text_column": "text",
"delimiter": ","
},
"model": {
"provider": "openai",
"name": "gpt-3.5-turbo"
},
"prompt": {
"task_guidelines": "You are an expert at classifying hate speech and identifying the type of hate speech. Read the following tweets and extract the following attributes from the text.",
"attributes": [
{
"name": "violence",
"options": ["not_violent", "violent"],
"description": "If the tweet mentions violence towards a person or a group."
},
{
"name": "directed_vs_generalized",
"options": [
"generalized",
"directed"
],
"description": "If the hate speech is generalized towards a group or directed towards a specific person."
},
{
"name": "gender",
"options": [
"true",
"false"
],
"description": "If the hate speech uses gendered language and attacks a particular gender."
}
],
"few_shot_examples": "seed.csv",
"few_shot_selection": "fixed",
"few_shot_num": 5,
"example_template": "Text: {text}\nOutput: {output_dict}"
}
}
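
The config above is what the new `attribute_extraction` task type consumes: one entry per attribute, each with a name, an optional list of allowed options, and a description. A minimal usage sketch, assuming the library's standard `LabelingAgent`/`AutolabelDataset` entry points and hypothetical local files (`ethos.csv` with a `text` column, plus the `seed.csv` referenced above):

```python
# Sketch only: file paths and dataset names are illustrative, not from this commit.
from autolabel import LabelingAgent, AutolabelDataset

config_path = "examples/ethos/config_ethos.json"
agent = LabelingAgent(config=config_path)
ds = AutolabelDataset("ethos.csv", config=config_path)

agent.plan(ds)      # dry run: preview prompts and estimated cost
ds = agent.run(ds)  # label; one <attribute>_label column per configured attribute is added to ds.df
```

The per-attribute output columns come from the `generate_label_name` changes in `dataset.py` below.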
916 changes: 916 additions & 0 deletions examples/ethos/example_ethos.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/autolabel/configs/config.py
@@ -48,6 +48,8 @@ class AutolabelConfig(BaseConfig):
OUTPUT_GUIDELINE_KEY = "output_guidelines"
OUTPUT_FORMAT_KEY = "output_format"
CHAIN_OF_THOUGHT_KEY = "chain_of_thought"
ATTRIBUTES_KEY = "attributes"

TRANSFORM_KEY = "transforms"

# Dataset generation config keys (config["dataset_generation"][<key>])
@@ -206,6 +208,10 @@ def chain_of_thought(self) -> bool:
"""Returns true if the model is able to perform chain of thought reasoning."""
return self._prompt_config.get(self.CHAIN_OF_THOUGHT_KEY, False)

def attributes(self) -> List[Dict]:
"""Returns a list of attributes to extract from the text."""
return self._prompt_config.get(self.ATTRIBUTES_KEY, [])

def transforms(self) -> List[Dict]:
"""Returns a list of transforms to apply to the data before sending to the model."""
return self.config.get(self.TRANSFORM_KEY, [])
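The new `attributes()` accessor follows the same pattern as the other prompt-config getters, returning the raw list of attribute dicts (or an empty list when the key is absent). A rough sketch of consuming it, assuming `AutolabelConfig` can be constructed straight from the JSON file:

```python
from autolabel.configs import AutolabelConfig  # assumed import path

config = AutolabelConfig("examples/ethos/config_ethos.json")
for attr in config.attributes():
    # Each entry is a plain dict such as {"name": "violence", "options": [...], "description": "..."}
    print(attr["name"], attr.get("options", []))
```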
6 changes: 6 additions & 0 deletions src/autolabel/configs/schema.py
@@ -121,6 +121,12 @@ def populate_few_shot_selection() -> List[str]:
},
"few_shot_num": {"type": ["number", "null"]},
"chain_of_thought": {"type": ["boolean", "null"]},
"attributes": {
"anyOf": [
{"type": "array", "items": {"type": "object"}},
{"type": "null"},
]
},
},
"required": ["task_guidelines"],
"additionalProperties": False,
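The schema accepts either a list of attribute objects or `null`, so existing configs without the key keep validating. A standalone sketch of that rule using the `jsonschema` package (not autolabel's own validation code):

```python
from jsonschema import ValidationError, validate

attributes_schema = {
    "anyOf": [
        {"type": "array", "items": {"type": "object"}},
        {"type": "null"},
    ]
}

validate([{"name": "violence", "options": ["violent", "not_violent"]}], attributes_schema)  # passes
validate(None, attributes_schema)  # passes
try:
    validate("violence", attributes_schema)  # a bare string is rejected
except ValidationError as err:
    print("rejected:", err.message)
```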
73 changes: 48 additions & 25 deletions src/autolabel/dataset/dataset.py
@@ -10,6 +10,7 @@
import json
import pickle
from autolabel.tasks import TaskFactory
from autolabel.schema import TaskType

logger = logging.getLogger(__name__)

@@ -68,11 +69,18 @@ def __init__(

inputs = df.to_dict(orient="records")
label_column = self.config.label_column()
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else df[label_column].tolist()
)
if not self.config.task_type() == TaskType.ATTRIBUTE_EXTRACTION:
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else df[label_column].tolist()
)
else:
attributes = [attr["name"] for attr in self.config.attributes()]
gt_labels = {
name: df[name].tolist() if name in df.keys() else None
for name in attributes
}

self.df = df
self.inputs = inputs
@@ -106,6 +114,12 @@ def process_labels(
# Add the LLM labels to the dataframe
self.df[self.generate_label_name("label")] = [x.label for x in llm_labels]

if self.config.task_type() == TaskType.ATTRIBUTE_EXTRACTION:
for attr in self.config.attributes():
self.df[self.generate_label_name("label", attr["name"])] = [
x.label[attr["name"]] for x in llm_labels
]

# Add the LLM errors to the dataframe
self.df[self.generate_label_name("error")] = [x.error for x in llm_labels]

@@ -165,7 +179,11 @@ def save(self, output_file_name: str):
raise ValueError(f"Unsupported output file format: {output_file_name}")

def filter(
self, label: str = None, ground_truth: str = None, filter_func: Callable = None
self,
label: str = None,
ground_truth: str = None,
filter_func: Callable = None,
label_column: str = None,
):
"""
Filter the dataset based on the label, ground truth or a custom filter function.
@@ -175,17 +193,19 @@
label: The llm label to filter on.
ground_truth: The ground truth label to filter on.
filter_func: A custom filter function to filter on.
label_column: The column to filter on. This is only used for attribute extraction tasks.
"""
filtered_df = self.df

if label:
filtered_df = filtered_df[
filtered_df[self.generate_label_name("label")] == label
filtered_df[self.generate_label_name("label", label_column)] == label
]

if ground_truth:
filtered_df = filtered_df[
filtered_df[self.config.label_column()] == ground_truth
filtered_df[(label_column or self.config.label_column())]
== ground_truth
]

if filter_func:
@@ -213,47 +233,54 @@ def completed(self):
filtered_df = self.df[self.df[self.generate_label_name("error")].isnull()]
return AutolabelDataset(filtered_df, self.config)

def incorrect(self, label: str = None, ground_truth: str = None):
def incorrect(
self, label: str = None, ground_truth: str = None, label_column: str = None
):
"""
Filter the dataset to only include incorrect items. This means the labels
where the llm label was incorrect.
Args:
label: The llm label to filter on.
ground_truth: The ground truth label to filter on.
label_column: The column to filter on. This is only used for attribute extraction tasks.
"""
gt_label_column = self.config.label_column()
gt_label_column = label_column or self.config.label_column()

if gt_label_column is None:
raise ValueError(
"Cannot compute mistakes without ground truth label column"
)

filtered_df = self.df[
self.df[self.generate_label_name("label")] != self.df[gt_label_column]
self.df[self.generate_label_name("label", label_column)]
!= self.df[gt_label_column]
]

if label:
filtered_df = filtered_df[
filtered_df[self.generate_label_name("label")] == label
filtered_df[self.generate_label_name("label", label_column)] == label
]

if ground_truth:
filtered_df = filtered_df[filtered_df[gt_label_column] == ground_truth]

return AutolabelDataset(filtered_df, self.config)

def correct(self):
def correct(self, label_column: str = None):
"""
Filter the dataset to only include correct items. This means the labels
where the llm label was correct.
Args:
label_column: The column to filter on. This is only used for attribute extraction tasks.
"""
gt_label_column = self.config.label_column()
gt_label_column = label_column or self.config.label_column()

if gt_label_column is None:
raise ValueError("Cannot compute correct without ground truth label column")

filtered_df = self.df[
self.df[self.generate_label_name("label")] == self.df[gt_label_column]
self.df[self.generate_label_name("label", label_column)]
== self.df[gt_label_column]
]
return AutolabelDataset(filtered_df, self.config)

@@ -278,18 +305,11 @@ def eval(self):
Evaluate the dataset based on the task. We run the metrics that were
specified by the task being run.
"""
gt_label_column = self.config.label_column()

if gt_label_column is None:
raise ValueError("Cannot compute eval without ground truth label column")

gt_labels = self.df[gt_label_column]

llm_labels = self.df[self.generate_label_name("annotation")].tolist()

task = TaskFactory.from_config(self.config)

metrics = task.eval(llm_labels, gt_labels)
metrics = task.eval(llm_labels, self.gt_labels)

table = {}
for metric in metrics:
@@ -335,8 +355,11 @@ def _validate(self):
f"Validation failed for {len(self.__malformed_records)} rows."
)

def generate_label_name(self, col_name: str):
return f"{self.config.task_name()}_{col_name}"
def generate_label_name(self, col_name: str, label_column: str = None):
label_column = (
label_column or self.config.label_column() or self.config.task_name()
)
return f"{label_column}_{col_name}"


class DataValidationFailed(Exception):
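With `label_column` threaded through `filter`, `incorrect`, and `correct`, a labeled attribute extraction dataset can be explored one attribute at a time; the comparison column for attribute `X` is the `X_label` column produced by `generate_label_name`. A usage sketch, where `ds` is assumed to be the `AutolabelDataset` returned by a completed run of the Ethos config above:

```python
# Rows where the predicted "violence" attribute disagrees with the ground-truth "violence" column
mistakes = ds.incorrect(label_column="violence")

# Rows the model labeled "directed" for the directed_vs_generalized attribute
directed = ds.filter(label="directed", label_column="directed_vs_generalized")

# Rows where the predicted "gender" attribute matches ground truth
gender_correct = ds.correct(label_column="gender")
```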
19 changes: 15 additions & 4 deletions src/autolabel/labeler.py
@@ -230,7 +230,12 @@ def run(
if dataset.gt_labels:
eval_result = self.task.eval(
llm_labels,
dataset.gt_labels[: len(llm_labels)],
dataset.gt_labels[: len(llm_labels)]
if isinstance(dataset.gt_labels, list)
else {
k: v[: len(llm_labels)]
for k, v in dataset.gt_labels.items()
},
additional_metrics=additional_metrics,
)

@@ -259,7 +264,9 @@ def run(
if not skip_eval and dataset.gt_labels:
eval_result = self.task.eval(
llm_labels,
dataset.gt_labels[: len(llm_labels)],
dataset.gt_labels[: len(llm_labels)]
if isinstance(dataset.gt_labels, list)
else {k: v[: len(llm_labels)] for k, v in dataset.gt_labels.items()},
additional_metrics=additional_metrics,
)
# TODO: serialize and write to file
@@ -415,8 +422,12 @@ def handle_existing_task_run(
)
llm_labels = self.get_all_annotations()
if gt_labels and len(llm_labels) > 0:
self.console.print("Evaluating the existing task...")
gt_labels = gt_labels[: len(llm_labels)]
pprint("Evaluating the existing task...")
gt_labels = (
gt_labels[: len(llm_labels)]
if isinstance(gt_labels, list)
else {k: v[: len(llm_labels)] for k, v in gt_labels.items()}
)
eval_result = self.task.eval(
llm_labels, gt_labels, additional_metrics=additional_metrics
)
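Ground-truth labels now come in two shapes: a flat list for single-label tasks and a per-attribute dict for attribute extraction, so every truncation to `len(llm_labels)` has to branch on the type. The repeated expression is roughly equivalent to a small helper along these lines (a sketch, not code from this commit; it additionally passes through `None` entries for attributes that have no ground-truth column):

```python
from typing import Dict, List, Optional, Union

def truncate_gt(
    gt_labels: Union[List, Dict[str, Optional[List]]], n: int
) -> Union[List, Dict[str, Optional[List]]]:
    """Keep only the first n ground-truth labels, whichever shape they arrive in."""
    if isinstance(gt_labels, list):
        return gt_labels[:n]
    # Dict case: one list (or None, if the column was absent) per attribute.
    return {
        name: values[:n] if values is not None else None
        for name, values in gt_labels.items()
    }

# e.g. truncate_gt({"violence": [...], "gender": [...]}, len(llm_labels))
```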
11 changes: 11 additions & 0 deletions src/autolabel/metrics/accuracy.py
@@ -1,18 +1,29 @@
from typing import List
import logging

from sklearn.metrics import accuracy_score

from autolabel.metrics import BaseMetric
from autolabel.schema import LLMAnnotation, MetricResult, MetricType


logger = logging.getLogger(__name__)


class AccuracyMetric(BaseMetric):
def __init__(self) -> None:
super().__init__()

def compute(
self, llm_labels: List[LLMAnnotation], gt_labels: List[str]
) -> List[MetricResult]:
# If there are no ground truth labels, return an empty list
if not gt_labels:
logger.warning(
"No ground truth labels were provided. Skipping accuracy metric."
)
return []

filtered_llm_labels = []
filtered_gt_labels = []
for llm_label, gt_label in zip(llm_labels, gt_labels):
9 changes: 9 additions & 0 deletions src/autolabel/metrics/auroc.py
@@ -1,11 +1,14 @@
from typing import List
import logging

from sklearn.metrics import roc_auc_score
import numpy as np

from autolabel.metrics import BaseMetric
from autolabel.schema import LLMAnnotation, MetricResult, MetricType

logger = logging.getLogger(__name__)


class AUROCMetric(BaseMetric):
def __init__(self) -> None:
@@ -14,6 +17,12 @@ def __init__(self) -> None:
def compute(
self, llm_labels: List[LLMAnnotation], gt_labels: List[str]
) -> List[MetricResult]:
if not gt_labels:
logger.warning(
"No ground truth labels were provided. Skipping AUROC metric."
)
return []

filtered_llm_labels = []
filtered_gt_labels = []
for llm_label, gt_label in zip(llm_labels, gt_labels):
10 changes: 10 additions & 0 deletions src/autolabel/metrics/classification_report.py
@@ -1,10 +1,13 @@
from typing import List
import logging

from sklearn.metrics import classification_report

from autolabel.metrics import BaseMetric
from autolabel.schema import LLMAnnotation, MetricResult, MetricType

logger = logging.getLogger(__name__)


class ClassificationReportMetric(BaseMetric):
def __init__(self) -> None:
@@ -13,6 +16,13 @@ def __init__(self) -> None:
def compute(
self, llm_labels: List[LLMAnnotation], gt_labels: List[str]
) -> List[MetricResult]:
# If there are no ground truth labels, return an empty list
if not gt_labels:
logger.warning(
"No ground truth labels were provided. Skipping classification report metric."
)
return []

filtered_llm_labels = []
filtered_gt_labels = []
for llm_label, gt_label in zip(llm_labels, gt_labels):
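All three metrics now short-circuit when no ground truth is available, which is exactly what happens for an attribute whose column is missing from the dataset (its entry in the `gt_labels` dict is `None`). A sketch of the resulting behavior, assuming the metric classes are importable from `autolabel.metrics` and `llm_labels` is a list of `LLMAnnotation` objects:

```python
from autolabel.metrics import AccuracyMetric  # assumed export; BaseMetric lives in the same package

metric = AccuracyMetric()
print(metric.compute(llm_labels, gt_labels=None))  # logs a warning and returns []
```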