Fix metric calculation for unlabeled datapoints (#468)
Calculate accuracy and F1 only on labeled datapoints

Co-authored-by: Rajas Bansal <[email protected]>
rajasbansal authored Jul 20, 2023
1 parent b3fa7b8 commit 740122c
Showing 5 changed files with 57 additions and 13 deletions.
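Before the per-file diffs, here is a minimal, self-contained sketch (not part of the commit) of why the change matters. It assumes labels are plain strings and that an unlabeled datapoint carries the ground-truth placeholder "nan", mirroring the helper added in src/autolabel/tasks/utils.py below; the sample label values are made up.

from sklearn.metrics import accuracy_score

def filter_unlabeled_examples(gt_labels, llm_labels):
    # Same idea as the new helper: drop every pair whose ground-truth
    # label is the string "nan" (i.e. the datapoint was never labeled).
    pairs = [(gt, llm) for gt, llm in zip(gt_labels, llm_labels) if gt != "nan"]
    return [gt for gt, _ in pairs], [llm for _, llm in pairs]

gt_labels = ["positive", "nan", "negative", "nan"]  # two rows are unlabeled
llm_labels = ["positive", "negative", "negative", "positive"]

# Before this commit: unlabeled rows count as mismatches -> accuracy 0.5
print(accuracy_score(gt_labels, llm_labels))

# After this commit: only the two labeled rows are scored -> accuracy 1.0
filtered_gt, filtered_llm = filter_unlabeled_examples(gt_labels, llm_labels)
print(accuracy_score(filtered_gt, filtered_llm) if filtered_gt else 0.0)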
9 changes: 7 additions & 2 deletions src/autolabel/tasks/classification.py
@@ -9,6 +9,7 @@
 from autolabel.schema import LLMAnnotation, Metric, MetricResult
 from autolabel.tasks import BaseTask
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples

 import json
@@ -167,9 +168,13 @@ def eval(
             )

             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
9 changes: 7 additions & 2 deletions src/autolabel/tasks/entity_matching.py
@@ -10,6 +10,7 @@
 from autolabel.schema import LLMAnnotation, Metric, MetricResult
 from autolabel.tasks import BaseTask
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples


 class EntityMatchingTask(BaseTask):
@@ -165,9 +166,13 @@ def eval(
                 len(curr_gt_labels) / float(len(gt_labels))
             )
             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
17 changes: 11 additions & 6 deletions src/autolabel/tasks/multilabel_classification.py
@@ -10,6 +10,7 @@
 from autolabel.tasks import BaseTask
 from autolabel.tasks.utils import compute_f1
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples

 import json
@@ -168,9 +169,13 @@ def eval(
             )

             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
@@ -180,8 +185,8 @@ def eval(

             eval_metrics_map[Metric.F1_MACRO].append(
                 compute_f1(
-                    curr_gt_labels,
-                    curr_llm_labels,
+                    filtered_curr_gt_labels,
+                    filtered_curr_llm_labels,
                     average="macro",
                     labels=self.config.labels_list(),
                     sep=self.config.label_separator(),
@@ -190,8 +195,8 @@ def eval(

             eval_metrics_map[Metric.F1_WEIGHTED].append(
                 compute_f1(
-                    curr_gt_labels,
-                    curr_llm_labels,
+                    filtered_curr_gt_labels,
+                    filtered_curr_llm_labels,
                     average="weighted",
                     labels=self.config.labels_list(),
                     sep=self.config.label_separator(),
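With the change above, the multilabel task also feeds the filtered labels into compute_f1. As a rough standalone illustration of macro and weighted F1 over separator-joined label strings (this is not autolabel's compute_f1; the label set, the ";" separator, and the sample rows are hypothetical):

from sklearn.metrics import f1_score
from sklearn.preprocessing import MultiLabelBinarizer

labels_list = ["a", "b", "c"]            # hypothetical label set
sep = ";"                                # hypothetical label separator
gt = ["a;b", "nan", "b;c"]               # "nan" marks an unlabeled row
llm = ["a", "a;c", "b;c"]

# Drop the unlabeled row first, as the patched eval() now does.
pairs = [(g, p) for g, p in zip(gt, llm) if g != "nan"]
gt_filtered, llm_filtered = zip(*pairs)

mlb = MultiLabelBinarizer(classes=labels_list)
y_true = mlb.fit_transform([s.split(sep) for s in gt_filtered])
y_pred = mlb.transform([s.split(sep) for s in llm_filtered])

print(f1_score(y_true, y_pred, average="macro"))     # ~0.89 for these made-up rows
print(f1_score(y_true, y_pred, average="weighted"))  # ~0.83 for these made-up rows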
13 changes: 10 additions & 3 deletions src/autolabel/tasks/question_answering.py
@@ -11,6 +11,7 @@
 from autolabel.tasks import BaseTask
 from autolabel.tasks.utils import normalize_text, compute_f1
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples


 class QuestionAnsweringTask(BaseTask):
@@ -156,9 +157,15 @@ def eval(
             )

             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
@@ -167,7 +174,7 @@
             eval_metrics_map[Metric.THRESHOLD].append(threshold)

             eval_metrics_map[Metric.F1].append(
-                compute_f1(curr_gt_labels, curr_llm_labels)
+                compute_f1(filtered_curr_gt_labels, filtered_curr_llm_labels)
             )

         eval_metrics.extend(
22 changes: 22 additions & 0 deletions src/autolabel/tasks/utils.py
@@ -4,6 +4,7 @@

 from sklearn.metrics import f1_score
 from sklearn.preprocessing import MultiLabelBinarizer
+from autolabel.schema import LLMAnnotation


 def normalize_text(s: str) -> str:
@@ -81,3 +82,24 @@ def compute_f1(
         f1_scores.append(2 * (prec * rec) / (prec + rec))

     return sum(f1_scores) / len(f1_scores)
+
+
+def filter_unlabeled_examples(gt_labels: List[str], llm_labels: List[LLMAnnotation]):
+    """Filter out unlabeled examples from the ground truth and LLM generated labels.
+    This is done by checking the ground truth labels which have nan values.
+    The corresponding ground truth and LLM labels are removed from the filtered labels lists.
+
+    Args:
+        gt_labels (List[str]): ground truth labels
+        llm_labels (List[LLMAnnotation]): llm labels
+
+    Returns:
+        Tuple[List[str], List[LLMAnnotation]]: filtered ground truth and LLM generated labels
+    """
+    filtered_gt_labels = []
+    filtered_llm_labels = []
+    for gt_label, llm_label in zip(gt_labels, llm_labels):
+        if gt_label != "nan":
+            filtered_gt_labels.append(gt_label)
+            filtered_llm_labels.append(llm_label)
+    return filtered_gt_labels, filtered_llm_labels

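As a quick sanity check of the new helper (hypothetical values, not a test from the repository; plain strings stand in for the LLM labels even though the type hint says List[LLMAnnotation], since only the ground-truth side is inspected):

from autolabel.tasks.utils import filter_unlabeled_examples

gt = ["yes", "nan", "no"]
llm = ["yes", "no", "no"]

filtered_gt, filtered_llm = filter_unlabeled_examples(gt, llm)
assert filtered_gt == ["yes", "no"]
assert filtered_llm == ["yes", "no"]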