diff --git a/src/autolabel/tasks/classification.py b/src/autolabel/tasks/classification.py
index f1d8f74b..8d13e9ad 100644
--- a/src/autolabel/tasks/classification.py
+++ b/src/autolabel/tasks/classification.py
@@ -9,6 +9,7 @@
 from autolabel.schema import LLMAnnotation, Metric, MetricResult
 from autolabel.tasks import BaseTask
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples
 
 import json
 
@@ -167,9 +168,13 @@ def eval(
             )
             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
 
diff --git a/src/autolabel/tasks/entity_matching.py b/src/autolabel/tasks/entity_matching.py
index 13e507f9..8c280bee 100644
--- a/src/autolabel/tasks/entity_matching.py
+++ b/src/autolabel/tasks/entity_matching.py
@@ -10,6 +10,7 @@
 from autolabel.schema import LLMAnnotation, Metric, MetricResult
 from autolabel.tasks import BaseTask
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples
 
 
 class EntityMatchingTask(BaseTask):
@@ -165,9 +166,13 @@ def eval(
                     len(curr_gt_labels) / float(len(gt_labels))
                 )
             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
diff --git a/src/autolabel/tasks/multilabel_classification.py b/src/autolabel/tasks/multilabel_classification.py
index f7830dec..aadbcbd9 100644
--- a/src/autolabel/tasks/multilabel_classification.py
+++ b/src/autolabel/tasks/multilabel_classification.py
@@ -10,6 +10,7 @@
 from autolabel.tasks import BaseTask
 from autolabel.tasks.utils import compute_f1
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples
 
 import json
 
@@ -168,9 +169,13 @@ def eval(
             )
             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
 
@@ -180,8 +185,8 @@
 
             eval_metrics_map[Metric.F1_MACRO].append(
                 compute_f1(
-                    curr_gt_labels,
-                    curr_llm_labels,
+                    filtered_curr_gt_labels,
+                    filtered_curr_llm_labels,
                     average="macro",
                     labels=self.config.labels_list(),
                     sep=self.config.label_separator(),
@@ -190,8 +195,8 @@
 
             eval_metrics_map[Metric.F1_WEIGHTED].append(
                 compute_f1(
-                    curr_gt_labels,
-                    curr_llm_labels,
+                    filtered_curr_gt_labels,
+                    filtered_curr_llm_labels,
                     average="weighted",
                     labels=self.config.labels_list(),
                     sep=self.config.label_separator(),
diff --git a/src/autolabel/tasks/question_answering.py b/src/autolabel/tasks/question_answering.py
index 2e99511d..70011557 100644
--- a/src/autolabel/tasks/question_answering.py
+++ b/src/autolabel/tasks/question_answering.py
@@ -11,6 +11,7 @@
 from autolabel.tasks import BaseTask
 from autolabel.tasks.utils import normalize_text, compute_f1
 from autolabel.utils import get_format_variables
+from autolabel.tasks.utils import filter_unlabeled_examples
 
 
 class QuestionAnsweringTask(BaseTask):
@@ -156,9 +157,15 @@ def eval(
             )
             eval_metrics_map[Metric.SUPPORT].append(len(curr_gt_labels))
-            if len(curr_gt_labels) > 0:
+
+            (
+                filtered_curr_gt_labels,
+                filtered_curr_llm_labels,
+            ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
+
+            if len(filtered_curr_gt_labels) > 0:
                 eval_metrics_map[Metric.ACCURACY].append(
-                    accuracy_score(curr_gt_labels, curr_llm_labels)
+                    accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels)
                 )
             else:
                 eval_metrics_map[Metric.ACCURACY].append(0.0)
 
@@ -167,7 +174,7 @@
             eval_metrics_map[Metric.THRESHOLD].append(threshold)
 
             eval_metrics_map[Metric.F1].append(
-                compute_f1(curr_gt_labels, curr_llm_labels)
+                compute_f1(filtered_curr_gt_labels, filtered_curr_llm_labels)
             )
 
             eval_metrics.extend(
diff --git a/src/autolabel/tasks/utils.py b/src/autolabel/tasks/utils.py
index c4842efe..4038b1e9 100644
--- a/src/autolabel/tasks/utils.py
+++ b/src/autolabel/tasks/utils.py
@@ -4,6 +4,7 @@
 
 from sklearn.metrics import f1_score
 from sklearn.preprocessing import MultiLabelBinarizer
+from autolabel.schema import LLMAnnotation
 
 
 def normalize_text(s: str) -> str:
@@ -81,3 +82,24 @@
             f1_scores.append(2 * (prec * rec) / (prec + rec))
 
     return sum(f1_scores) / len(f1_scores)
+
+
+def filter_unlabeled_examples(gt_labels: List[str], llm_labels: List[LLMAnnotation]):
+    """Filter out unlabeled examples from the ground truth and LLM-generated labels.
+    This is done by checking for ground truth labels that have nan values.
+    The corresponding ground truth and LLM labels are removed from the filtered labels lists.
+
+    Args:
+        gt_labels (List[str]): ground truth labels
+        llm_labels (List[LLMAnnotation]): llm labels
+
+    Returns:
+        Tuple[List[str], List[LLMAnnotation]]: filtered ground truth and LLM-generated labels
+    """
+    filtered_gt_labels = []
+    filtered_llm_labels = []
+    for gt_label, llm_label in zip(gt_labels, llm_labels):
+        if gt_label != "nan":
+            filtered_gt_labels.append(gt_label)
+            filtered_llm_labels.append(llm_label)
+    return filtered_gt_labels, filtered_llm_labels
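A minimal usage sketch of the new helper, not part of the diff itself: it assumes only the import added above, and plain strings stand in for LLMAnnotation objects (safe here because filter_unlabeled_examples never inspects the LLM-side values). The example labels and the resulting score are illustrative.

    # Sketch: how the eval() paths above use the new helper before scoring.
    # Plain strings stand in for LLMAnnotation values for brevity.
    from sklearn.metrics import accuracy_score

    from autolabel.tasks.utils import filter_unlabeled_examples

    curr_gt_labels = ["positive", "nan", "negative", "positive"]
    curr_llm_labels = ["positive", "negative", "negative", "negative"]

    (
        filtered_curr_gt_labels,
        filtered_curr_llm_labels,
    ) = filter_unlabeled_examples(curr_gt_labels, curr_llm_labels)
    # filtered_curr_gt_labels  == ["positive", "negative", "positive"]
    # filtered_curr_llm_labels == ["positive", "negative", "negative"]

    # Accuracy is computed only over rows that actually have a ground truth label;
    # 0.0 is reported when nothing remains after filtering.
    if len(filtered_curr_gt_labels) > 0:
        print(accuracy_score(filtered_curr_gt_labels, filtered_curr_llm_labels))  # 2/3 ≈ 0.67
    else:
        print(0.0)

The comparison against the literal string "nan" presumably mirrors how missing ground truth values end up stringified when the seed data is loaded; with this change, rows that have no ground truth label drop out of the accuracy and F1 calculations instead of counting as automatic mismatches.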