class DatasetLoader:
    """Read a dataset through the matching reader and optionally validate it.

    The input may be a path to a ``.csv`` / ``.jsonl`` file, a HuggingFace
    ``Dataset`` or a pandas ``DataFrame``. The parsed result is stored as an
    ``AutolabelDataset`` and exposed through the read-only properties below.
    """

    # TODO: add support for reading from SQL databases
    # TODO: add support for reading and loading datasets in chunks

    # Upper bound on the number of validation errors rendered in the log table.
    MAX_ERROR_DISPLAYED = 100

    def __init__(
        self,
        dataset: Union[str, pd.DataFrame, Dataset],
        config: AutolabelConfig,
        max_items: int = 0,
        start_index: int = 0,
        validate: bool = True,
    ) -> None:
        """DatasetLoader class to read and load datasets.

        Args:
            dataset (Union[str, pd.DataFrame, Dataset]): path to the dataset,
                a dataframe, or a HuggingFace dataset
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to 0.
            start_index (int, optional): start index to read from. Defaults to 0.
            validate (bool, optional): run column and per-row validation right
                after reading. Defaults to True.

        Raises:
            ValueError: if the file extension or the dataset type is not supported
            DataValidationFailed: if ``validate`` is True and any record is malformed
        """
        self.dataset = dataset
        self.config = config
        self.max_items = max_items
        self.start_index = start_index

        self.__al_dataset: AutolabelDataset = None
        self.__malformed_records = None
        self._read()
        if validate:
            self._validate()

    @property
    def dat(self) -> Union[pd.DataFrame, Dataset]:
        """Underlying tabular data held by the loaded ``AutolabelDataset``."""
        return self.__al_dataset.dataset

    @property
    def inputs(self) -> list:
        """Row records (one dict per row) of the loaded dataset."""
        return self.__al_dataset.inputs

    @property
    def gt_labels(self) -> list:
        """Ground-truth labels as stored by the reader on the loaded dataset."""
        return self.__al_dataset.gt_labels

    @property
    def columns(self) -> list:
        """Column names of the loaded dataset."""
        return self.__al_dataset.columns

    def _read(self):
        """Pick the reader matching ``self.dataset`` and load the data.

        Raises:
            ValueError: if a path has an unsupported extension, or the dataset
                is neither a path, a HuggingFace ``Dataset`` nor a ``DataFrame``
        """
        if isinstance(self.dataset, str):
            if self.dataset.endswith(".csv"):
                reader = CSVReader
            elif self.dataset.endswith(".jsonl"):
                reader = JsonlReader
            else:
                raise ValueError(f"Unsupported file format: {self.dataset}")
        elif isinstance(self.dataset, Dataset):
            reader = HuggingFaceDatasetReader
        elif isinstance(self.dataset, pd.DataFrame):
            reader = DataframeReader
        else:
            # Fail here instead of leaving the internal dataset as None and
            # crashing later with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported dataset type: {type(self.dataset).__name__}"
            )

        # Keyword arguments on purpose: every reader declares
        # (source, config, max_items, start_index), and passing the two
        # integers positionally makes it easy to swap them silently.
        self.__al_dataset = reader.read(
            self.dataset,
            self.config,
            max_items=self.max_items,
            start_index=self.start_index,
        )

    def _validate(self):
        """Validate columns and per-row data of the loaded dataset.

        Raises:
            DataValidationFailed: if at least one validation error was collected
        """
        data_validation = TaskDataValidation(config=self.config)

        # Validate columns
        data_validation.validate_dataset_columns(
            dataset_columns=self.__al_dataset.columns
        )

        # Validate datatype and data format
        self.__malformed_records = data_validation.validate(
            data=self.__al_dataset.inputs
        )

        # Nothing malformed: do not emit a misleading "failed for 0 records"
        # warning and skip rendering the table altogether.
        if not self.__malformed_records:
            return

        table = tabulate(
            self.__malformed_records[0 : self.MAX_ERROR_DISPLAYED],
            headers="keys",
            tablefmt="fancy_grid",
            numalign="center",
            stralign="left",
        )

        logger.warning(
            f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}"
        )

        raise DataValidationFailed(
            f"Validation failed for {len(self.__malformed_records)} rows."
        )
+ """ + logger.debug(f"reading the csv from: {start_index}") + delimiter = config.delimiter() + label_column = config.label_column() + + dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:] + + dat = dat.astype(str) + if max_items and max_items > 0: + max_items = min(max_items, len(dat)) + dat = dat[:max_items] + + inputs = dat.to_dict(orient="records") + gt_labels = ( + None + if not label_column or not len(inputs) or label_column not in inputs[0] + else dat[label_column].tolist() + ) + return AutolabelDataset( + columns=list(dat.columns), + dataset=dat, + inputs=inputs, + gt_labels=gt_labels, + ) + + +class JsonlReader: + @staticmethod + def read( + jsonl_file: str, + config: AutolabelConfig, + max_items: int = None, + start_index: int = 0, + ) -> AutolabelDataset: + """Read the jsonl file and sets dat, inputs and gt_labels + + Args: + jsonl_file (str): path to the jsonl file + config (AutolabelConfig): config object + max_items (int, optional): max number of items to read. Defaults to None. + start_index (int, optional): start index to read from. Defaults to 0. 
class HuggingFaceDatasetReader:
    """Reader that turns a HuggingFace ``Dataset`` into an ``AutolabelDataset``."""

    @staticmethod
    def read(
        dataset: Dataset,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the huggingface dataset and sets dat, inputs and gt_labels

        Args:
            dataset (Dataset): dataset object to read from
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.

        Returns:
            AutolabelDataset: columns, dataframe, row records and ground-truth
            labels (None when no label column is configured or present)
        """
        label_column = config.label_column()

        # with_format returns a new Dataset object, so the caller's dataset
        # keeps whatever output format it had before this call.
        pandas_view = dataset.with_format("pandas")

        # The window is `max_items` rows starting at `start_index`, as in the
        # other readers; using `max_items` itself as the slice end would drop
        # rows (or return nothing) whenever start_index > 0.
        total = len(pandas_view)
        if max_items and max_items > 0:
            stop = min(start_index + max_items, total)
        else:
            stop = total
        dat = pandas_view[start_index:stop]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not label_column or not len(inputs) or label_column not in inputs[0]
            else dat[label_column].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )
+ start_index (int, optional): start index to read from. Defaults to 0. + """ + logger.debug(f"reading the sql from: {start_index}") + label_column = config.label_column() + + dat = pd.read_sql(sql, connection)[start_index:] + dat = dat.astype(str) + if max_items and max_items > 0: + max_items = min(max_items, len(dat)) + dat = dat[:max_items] + + inputs = dat.to_dict(orient="records") + gt_labels = ( + None + if not label_column or not len(inputs) or label_column not in inputs[0] + else dat[label_column].tolist() + ) + return AutolabelDataset( + columns=list(dat.columns), + dataset=dat, + inputs=inputs, + gt_labels=gt_labels, + ) + + +class DataframeReader: + @staticmethod + def read( + df: pd.DataFrame, + config: AutolabelConfig, + max_items: int = None, + start_index: int = 0, + ) -> AutolabelDataset: + """Read the csv file and sets dat, inputs and gt_labels + + Args: + df (pd.DataFrame): dataframe to read + config (AutolabelConfig): config object + max_items (int, optional): max number of items to read. Defaults to None. + start_index (int, optional): start index to read from. Defaults to 0. 
# Regex pattern to extract expected columns from config.example_template()
EXPECTED_COLUMN_PATTERN = r"\{([^}]*)\}"


class NERTaskValidate(BaseModel):
    """Validate NER Task

    The label column is expected to hold a json object serialized as a string
    """

    label_column: str
    labels_set: set  # A NER Task should have a unique set of labels in config

    def validate(self, value: str):
        """Validate NER

        A NER label can only be a dictionary

        Raises:
            ValueError: if the value does not look like a json object, is not
                valid json (``JSONDecodeError`` is a ``ValueError`` subclass),
                or uses keys that are not in ``labels_set``
        """
        # TODO: This can be made better
        if not (value.startswith("{") and value.endswith("}")):
            # A bare `raise` here has no active exception and would surface as
            # RuntimeError; raise a ValueError so the failure is reported as a
            # regular row-level validation error.
            raise ValueError(f"labels: '{value}' is not a valid json dictionary ")

        # JSONDecodeError is deliberately left to propagate unchanged.
        seed_labels = json.loads(value)
        unmatched_label = set(seed_labels.keys()) - self.labels_set
        if len(unmatched_label) != 0:
            raise ValueError(
                f"labels: '{unmatched_label}' not in promt/labels provided in config "
            )


class ClassificationTaskValidate(BaseModel):
    """Validate Classification Task

    The label column can either be a string or a string of list
    """

    label_column: str
    labels_set: set  # A classification Task should have a unique set of labels in config

    def validate(self, value: str):
        """Validate classification

        A classification label(ground_truth) could either be a list or string

        Raises:
            ValueError: if a list-looking value cannot be parsed as a python
                list literal, or if any label is not in ``labels_set``
        """
        # Function-scope import keeps this block self-contained.
        import ast

        # TODO: This can be made better
        if value.startswith("[") and value.endswith("]"):
            try:
                # literal_eval only accepts python literals; dataset cells are
                # external input and must never be executed with eval().
                seed_labels = ast.literal_eval(value)
            except (ValueError, SyntaxError) as err:
                raise ValueError(
                    f"labels: '{value}' is not a valid list of labels "
                ) from err
            if not isinstance(seed_labels, list):
                raise ValueError(f"labels: '{value}' is not a valid list of labels ")
            unmatched_label = set(seed_labels) - self.labels_set
            if len(unmatched_label) != 0:
                raise ValueError(
                    f"labels: '{unmatched_label}' not in promt/labels provided in config "
                )
        else:
            if value not in self.labels_set:
                raise ValueError(
                    f"labels: '{value}' not in promt/labels provided in config "
                )


class EMTaskValidate(BaseModel):
    """Validate Entity Matching Task

    As of now we assume that the input label_column is a string
    """

    label_column: str
    labels_set: set  # An EntityMatching Task should have a unique set of labels in config

    def validate(self, value: str):
        """Check that the label is one of the labels listed in the config.

        Raises:
            ValueError: if ``value`` is not in ``labels_set``
        """
        if value not in self.labels_set:
            raise ValueError(
                f"labels: '{value}' not in promt/labels provided in config "
            )


class QATaskValidate(BaseModel):
    """Validate Question Answering Task

    As of now we assume that the input label_column is a string
    """

    label_column: str
    labels_set: Optional[
        set
    ]  # A QA task may or may not have a unique set of label list

    def validate(self, value: str):
        """Since question answering is an arbitrary task we have no validation"""
        pass


TaskTypeValidate = Union[
    NERTaskValidate,
    ClassificationTaskValidate,
    EMTaskValidate,
    QATaskValidate,
]


class DataValidationTasks(BaseModel):
    # Maps each supported task_type name to its validator class; the class is
    # looked up by name in TaskDataValidation.__init__.
    classification: TaskTypeValidate = ClassificationTaskValidate
    named_entity_recognition: TaskTypeValidate = NERTaskValidate
    entity_matching: TaskTypeValidate = EMTaskValidate
    question_answering: TaskTypeValidate = QATaskValidate


class TaskDataValidation:
    """Task Validation"""

    def __init__(self, config: AutolabelConfig):
        """Task Validation

        Args:
            config: AutolabelConfig = User passed parsed configuration
        """
        # the type of task, classification, named_entity_recognition, etc.., "config/task_type"
        task_type: str = config.task_type()
        # the label column as specified in config, "config/dataset/label_column"
        label_column: str = config.label_column()
        # list of valid labels provided in config "config/prompt/labels"
        labels_list: Optional[List] = config.labels_list()
        # example template from config "config/prompt/example_template"
        self.example_template: str = config.example_template()

        # Every column referenced in the example template is a required strict string.
        self.__schema = {col: (StrictStr, ...) for col in self.expected_columns}

        # labels_list is Optional: fall back to an empty set rather than
        # crashing on set(None).
        labels_set = set(labels_list) if labels_list else set()
        self.__validation_task = DataValidationTasks.__dict__[task_type](
            label_column=label_column, labels_set=labels_set
        )
        self.__data_validation = self.data_validation_and_schema_check(
            self.__validation_task
        )

    @cached_property
    def expected_columns(self) -> List:
        """Fetch expected columns

        Returns the ``{placeholder}`` names found in the example template,
        scanned line by line and kept in order of appearance.
        """
        return [
            column
            for text in self.example_template.split("\n")
            for column in re.findall(EXPECTED_COLUMN_PATTERN, text)
        ]

    @property
    def schema(self) -> Dict:
        """Fetch Schema"""
        return self.__schema

    @property
    def validation_task(
        self,
    ) -> TaskTypeValidate:
        """Fetch validation task"""
        return self.__validation_task

    def data_validation_and_schema_check(self, validation_task: BaseModel):
        """Validate data format and datatype

        Args:
            validation_task (TaskTypeValidate): validation task

        Returns:
            DataValidation: Pydantic Model for validation. Its ``validate``
            classmethod raises ``ValidationError`` when a row does not match
            the schema or when ``validation_task`` rejects the label value.
        """
        Model = create_model("Model", **self.__schema)

        class DataValidation(BaseModel):
            """Data Validation"""

            # We define validate as a classmethod such that a dynamic `data` can be passed
            # iteratively to the validate method using `DataValidation.validate`
            @classmethod
            def validate(cls, data):
                """Validate data types"""
                # Schema check first: missing columns / non-string values
                # raise ValidationError here.
                model = Model(**data)
                # We perform the normal pydantic validation here
                # This checks both the Schema and also calls check_fields
                cls(**model.dict())

            @root_validator(pre=True, allow_reuse=True)
            def check_fields(cls, values):
                """Validate data format"""
                # The label column is only present when the example template
                # references it; there is nothing to check otherwise.
                label_column_value = values.get(validation_task.label_column)
                if label_column_value is not None:
                    validation_task.validate(label_column_value)
                # A root validator must hand the values back to pydantic.
                return values

        return DataValidation

    def validate(self, data: List[dict]) -> List[Dict]:
        """Validate Data

        Args:
            data (List[dict]): row records to validate

        Returns:
            List[Dict]: one entry per validation error with the keys
            ``row_num``, ``loc``, ``msg`` and ``type``; empty if all rows pass
        """
        error_messages = []
        for index, item in enumerate(data):
            try:
                self.__data_validation.validate(item)
            except ValidationError as e:
                for err in e.errors():
                    error_messages.append(
                        {
                            "row_num": index,
                            # loc parts are joined as text so a non-string
                            # part cannot break the join
                            "loc": ".".join(map(str, err["loc"])),
                            "msg": err["msg"],
                            "type": err["type"],
                        }
                    )
        return error_messages

    def validate_dataset_columns(self, dataset_columns: List):
        """Validate columns

        Validate that the columns mentioned in example_template are contained
        within the columns of the dataset (seed.csv)

        Raises:
            AssertionError: if any expected column is missing from the dataset
        """
        missing_columns = set(self.expected_columns) - set(dataset_columns)
        # Raised explicitly (same exception type and message as before) so the
        # check still runs when python is started with -O, which strips
        # `assert` statements.
        if len(missing_columns) != 0:
            raise AssertionError(f"columns={missing_columns} missing in seed.csv file")
add support for reading and loading datasets in chunks - - def __init__( - self, - dataset: Union[str, pd.DataFrame], - config: AutolabelConfig, - max_items: int = 0, - start_index: int = 0, - ) -> None: - """DatasetLoader class to read and load datasets. - - Args: - dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to 0. - start_index (int, optional): start index to read from. Defaults to 0. - """ - self.dataset = dataset - self.config = config - self.max_items = max_items - self.start_index = start_index - - if isinstance(dataset, str): - self._read_file(dataset, config, max_items, start_index) - elif isinstance(dataset, Dataset): - self._read_hf_dataset(dataset, config, max_items, start_index) - elif isinstance(dataset, pd.DataFrame): - self._read_dataframe(dataset, config, start_index, max_items) - - def _read_csv( - self, - csv_file: str, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the csv file and sets dat, inputs and gt_labels - - Args: - csv_file (str): path to the csv file - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. 
- """ - logger.debug(f"reading the csv from: {start_index}") - delimiter = config.delimiter() - label_column = config.label_column() - - self.dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:] - self.dat = self.dat.astype(str) - if max_items and max_items > 0: - max_items = min(max_items, len(self.dat)) - self.dat = self.dat[:max_items] - - self.inputs = self.dat.to_dict(orient="records") - self.gt_labels = ( - None - if not label_column - or not len(self.inputs) - or label_column not in self.inputs[0] - else self.dat[label_column].tolist() - ) - - def _read_dataframe( - self, - df: pd.DataFrame, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the csv file and sets dat, inputs and gt_labels - - Args: - df (pd.DataFrame): dataframe to read - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. - """ - label_column = config.label_column() - - self.dat = df[start_index:].astype(str) - if max_items and max_items > 0: - max_items = min(max_items, len(self.dat)) - self.dat = self.dat[:max_items] - - self.inputs = self.dat.to_dict(orient="records") - self.gt_labels = ( - None - if not label_column - or not len(self.inputs) - or label_column not in self.inputs[0] - else self.dat[label_column].tolist() - ) - - def _read_jsonl( - self, - jsonl_file: str, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the jsonl file and sets dat, inputs and gt_labels - - Args: - jsonl_file (str): path to the jsonl file - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. 
- """ - logger.debug(f"reading the jsonl from: {start_index}") - label_column = config.label_column() - - self.dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:] - self.dat = self.dat.astype(str) - if max_items and max_items > 0: - max_items = min(max_items, len(self.dat)) - self.dat = self.dat[:max_items] - - self.inputs = self.dat.to_dict(orient="records") - self.gt_labels = ( - None - if not label_column - or not len(self.inputs) - or label_column not in self.inputs[0] - else self.dat[label_column].tolist() - ) - - def _read_sql( - self, - sql: Union[str, Selectable], - connection: str, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the sql query and sets dat, inputs and gt_labels - - Args: - connection (str): connection string - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. - """ - logger.debug(f"reading the sql from: {start_index}") - label_column = config.label_column() - - self.dat = pd.read_sql(sql, connection)[start_index:] - self.dat = self.dat.astype(str) - if max_items and max_items > 0: - max_items = min(max_items, len(self.dat)) - self.dat = self.dat[:max_items] - - self.inputs = self.dat.to_dict(orient="records") - self.gt_labels = ( - None - if not label_column - or not len(self.inputs) - or label_column not in self.inputs[0] - else self.dat[label_column].tolist() - ) - - def _read_file( - self, - file: str, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the file and sets dat, inputs and gt_labels - - Args: - file (str): path to the file - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. 
- - Raises: - ValueError: if the file format is not supported - """ - if file.endswith(".csv"): - return self._read_csv( - file, config, max_items=max_items, start_index=start_index - ) - elif file.endswith(".jsonl"): - return self._read_jsonl( - file, config, max_items=max_items, start_index=start_index - ) - else: - raise ValueError(f"Unsupported file format: {file}") - - def _read_hf_dataset( - self, - dataset: Dataset, - config: AutolabelConfig, - max_items: int = None, - start_index: int = 0, - ) -> None: - """Read the huggingface dataset and sets dat, inputs and gt_labels - - Args: - dataset (Dataset): dataset object to read from - config (AutolabelConfig): config object - max_items (int, optional): max number of items to read. Defaults to None. - start_index (int, optional): start index to read from. Defaults to 0. - """ - dataset.set_format("pandas") - self.dat = dataset[ - start_index : max_items if max_items and max_items > 0 else len(dataset) - ] - - self.inputs = self.dat.to_dict(orient="records") - self.gt_labels = ( - None - if not config.label_column() - or not len(self.inputs) - or config.label_column() not in self.inputs[0] - else self.dat[config.label_column()].tolist() - ) diff --git a/src/autolabel/labeler.py b/src/autolabel/labeler.py index 2b286a05..54a91a8b 100644 --- a/src/autolabel/labeler.py +++ b/src/autolabel/labeler.py @@ -15,7 +15,7 @@ from autolabel.configs import AutolabelConfig from autolabel.data_models import AnnotationModel, TaskRunModel from autolabel.database import StateManager -from autolabel.dataset_loader import DatasetLoader +from autolabel.data_loaders import DatasetLoader from autolabel.few_shot import ExampleSelectorFactory from autolabel.models import BaseModel, ModelFactory from autolabel.schema import LLMAnnotation, MetricResult, TaskRun, TaskStatus @@ -306,7 +306,9 @@ def plan( "REFUEL_API_KEY environment variable must be set to compute confidence scores. 
You can request an API key at https://refuel-ai.typeform.com/llm-access." ) - dataset_loader = DatasetLoader(dataset, self.config, max_items, start_index) + dataset_loader = DatasetLoader( + dataset, self.config, max_items, start_index, validate=True + ) prompt_list = [] total_cost = 0 @@ -316,7 +318,7 @@ def plan( # If this dataset config is a string, read the corrresponding csv file if isinstance(seed_examples, str): - seed_loader = DatasetLoader(seed_examples, self.config) + seed_loader = DatasetLoader(seed_examples, self.config, validate=True) seed_examples = seed_loader.inputs # Check explanations are present in data if explanation_column is passed in diff --git a/tests/unit/data_loaders/test_validation.py b/tests/unit/data_loaders/test_validation.py new file mode 100644 index 00000000..7e174a38 --- /dev/null +++ b/tests/unit/data_loaders/test_validation.py @@ -0,0 +1,209 @@ +"""Test Validation""" +import pytest +from autolabel.configs import AutolabelConfig +from autolabel.data_loaders.validation import TaskDataValidation + + +CLASSIFICATION_CONFIG_SAMPLE_DICT = { + "task_name": "LegalProvisionsClassification", + "task_type": "classification", + "dataset": {"label_column": "label", "delimiter": ","}, + "model": {"provider": "openai", "name": "gpt-3.5-turbo"}, + "prompt": { + "task_guidelines": "You are an expert at understanding legal contracts. 
Your job is to correctly classify legal provisions in contracts into one of the following categories.\nCategories:{labels}\n", + "labels": [ + "Agreements", + "Argument", + ], + "example_template": "Example: {example}\nOutput: {label}", + "few_shot_examples": "seed.csv", + "few_shot_selection": "semantic_similarity", + "few_shot_num": 4, + }, +} + +NER_CONFIG_SAMPLE_DICT = { + "task_name": "PersonLocationOrgMiscNER", + "task_type": "named_entity_recognition", + "dataset": { + "label_column": "CategorizedLabels", + "text_column": "example", + "delimiter": ",", + }, + "model": {"provider": "openai", "name": "gpt-3.5-turbo"}, + "prompt": { + "task_guidelines": "You are an expert at extracting Person, Organization, Location, and Miscellaneous entities from text. Your job is to extract named entities mentioned in text, and classify them into one of the following categories.\nCategories:\n{labels}\n ", + "labels": [ + "Location", + ], + "example_template": "Example: {example}\nOutput:\n{CategorizedLabels}", + "few_shot_examples": "seed.csv", + "few_shot_selection": "semantic_similarity", + "few_shot_num": 5, + }, +} + +EM_CONFIG_SAMPLE_DICT = { + "task_name": "ProductCatalogEntityMatch", + "task_type": "entity_matching", + "dataset": {"label_column": "label", "delimiter": ","}, + "model": {"provider": "openai", "name": "gpt-3.5-turbo"}, + "prompt": { + "task_guidelines": "You are an expert at identifying duplicate products from online product catalogs.\nYou will be given information about two product entities, and your job is to tell if they are the same (duplicate) or different (not duplicate). 
Your answer must be from one of the following options:\n{labels}", + "labels": ["duplicate", "not duplicate"], + "example_template": "Title of entity1: {Title_entity1}; \nDuplicate or not: {label}", + "few_shot_selection": "fixed", + "few_shot_num": 2, + }, +} + + +def test_validate_classification_task(): + """Test Validate classification""" + data = [ + {"question": "s", "answer": "ee"}, # Wrong column names + {"example": "s", "label": 1}, # int value not accepted + {"example": "s", "label": "['w']"}, # w, d are incorrect labels + { + "example": "s", + "label": "['Agreements', 'Random']", + }, # Random not valid label + {"example": "s", "label": "['Agreements', 'Argument']"}, # Correct + {"example": "s", "label": "['Agreements']"}, # Correct + ] + expected_output = [ + { + "loc": "example", + "msg": "field required", + "row_num": 0, + "type": "value_error.missing", + }, + { + "loc": "label", + "msg": "field required", + "row_num": 0, + "type": "value_error.missing", + }, + { + "loc": "label", + "msg": "str type expected", + "row_num": 1, + "type": "type_error.str", + }, + { + "loc": "__root__", + "msg": "labels: '{'w'}' not in promt/labels provided in config ", + "row_num": 2, + "type": "value_error", + }, + { + "loc": "__root__", + "msg": "labels: '{'Random'}' not in promt/labels provided in config ", + "row_num": 3, + "type": "value_error", + }, + ] + data_validation = TaskDataValidation( + config=AutolabelConfig(CLASSIFICATION_CONFIG_SAMPLE_DICT) + ) + error_table = data_validation.validate(data=data) + + for exp_out, err_out in zip(expected_output, error_table): + assert exp_out == err_out + + +def test_validate_ner_task(): + """Test Validate NamedEntityRecognition""" + data = [ + { + # Miscellaneous is not a valid label mentioned in NER_CONFIG_SAMPLE_DICT + "example": "example1", + "CategorizedLabels": '{"Location": ["Okla"], "Miscellaneous": []}', + }, + { + # Not a valid Json + "example": "example2", + "CategorizedLabels": '{"Location":["Texas"], 
""Miscellaneous"": []}', + }, + { + # label is not the correct column name + "example": "example3", + "label": '{"Location":["Texas"], "Miscellaneous": ["USDA", "PPAS"]}', + }, + { + # Correct + "example": "example2", + "CategorizedLabels": '{"Location":["Texas"]}', + }, + ] + expected_output = [ + { + "loc": "__root__", + "msg": "labels: '{'Miscellaneous'}' not in promt/labels provided in config ", + "row_num": 0, + "type": "value_error", + }, + { + "loc": "__root__", + "msg": "Expecting ':' delimiter: line 1 column 26 (char 25)", + "row_num": 1, + "type": "value_error.jsondecode", + }, + { + "loc": "CategorizedLabels", + "msg": "field required", + "row_num": 2, + "type": "value_error.missing", + }, + ] + + data_validation = TaskDataValidation(config=AutolabelConfig(NER_CONFIG_SAMPLE_DICT)) + + error_table = data_validation.validate(data=data) + + for exp_out, err_out in zip(expected_output, error_table): + assert exp_out == err_out + + +def test_validate_EM_task(): + """Test Validate NamedEntityRecognition""" + data = [ + {"Title_entity1": "example1", "label": "duplicate"}, + {"Title_entity1": "example2", "label": "not duplicate"}, + {"ErrorColumn": "example2", "label": '{"Location":["Texas"]}'}, + {"Title_entity1": "example2", "label": "duplicate not duplicate"}, + ] + + expected_output = [ + { + "row_num": 2, + "loc": "Title_entity1", + "msg": "field required", + "type": "value_error.missing", + }, + { + "row_num": 3, + "loc": "__root__", + "msg": "labels: 'duplicate not duplicate' not in promt/labels provided in config ", + "type": "value_error", + }, + ] + + data_validation = TaskDataValidation(config=AutolabelConfig(EM_CONFIG_SAMPLE_DICT)) + + error_table = data_validation.validate(data=data) + + for exp_out, err_out in zip(expected_output, error_table): + assert exp_out == err_out + + +def test_columns(): + """Test Validate NamedEntityRecognition""" + data_validation = TaskDataValidation(config=AutolabelConfig(NER_CONFIG_SAMPLE_DICT)) + + with 
pytest.raises( + AssertionError, match=r"columns={'example'} missing in seed.csv file" + ): + data_validation.validate_dataset_columns( + dataset_columns=["input", "CategorizedLabels"] + ) diff --git a/tests/unit/test_data_loading.py b/tests/unit/test_data_loading.py index 482797a6..3fc86a76 100644 --- a/tests/unit/test_data_loading.py +++ b/tests/unit/test_data_loading.py @@ -1,5 +1,5 @@ from autolabel import LabelingAgent -from autolabel.dataset_loader import DatasetLoader +from autolabel.data_loaders import DatasetLoader from pandas import DataFrame csv_path = "tests/assets/banking/test.csv" diff --git a/tests/unit/test_few_shot.py b/tests/unit/test_few_shot.py index 20e45e93..9ff67ffd 100644 --- a/tests/unit/test_few_shot.py +++ b/tests/unit/test_few_shot.py @@ -1,7 +1,7 @@ import json from autolabel.configs import AutolabelConfig -from autolabel.dataset_loader import DatasetLoader +from autolabel.data_loaders import DatasetLoader from autolabel.few_shot import ExampleSelectorFactory from langchain.embeddings import HuggingFaceEmbeddings from pytest import approx