Fix 178: Add data type checking and validation while performing agent plan action (#428)
Showing 9 changed files with 829 additions and 239 deletions.
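For context, here is a minimal, hypothetical sketch of the loading-plus-validation flow this commit introduces. The config and CSV file names below are invented for illustration; DatasetLoader, DataValidationFailed, and AutolabelConfig come from the diff, but the module path for DatasetLoader is an assumption since the changed file's path is not shown here.

# Hypothetical usage sketch of the new validation path. The module path
# "autolabel.dataset_loader" and the file names are assumptions.
from autolabel.configs import AutolabelConfig
from autolabel.dataset_loader import DatasetLoader, DataValidationFailed

config = AutolabelConfig("example_task_config.json")  # hypothetical config file

try:
    # _read() picks a reader from the input type (.csv here), then
    # _validate() checks columns and per-row data types against the config.
    loader = DatasetLoader("reviews.csv", config, max_items=100)
    print(loader.columns)
except DataValidationFailed as err:
    # Raised when any rows fail validation; up to MAX_ERROR_DISPLAYED
    # malformed rows are logged as a table first.
    print(f"Dataset rejected: {err}")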
@@ -0,0 +1,141 @@
import logging
import pandas as pd
from tabulate import tabulate
from typing import Dict, List, Union
from datasets import Dataset
from autolabel.data_loaders.read_datasets import (
    AutolabelDataset,
    CSVReader,
    JsonlReader,
    HuggingFaceDatasetReader,
    # SqlDatasetReader,
    DataframeReader,
)
from autolabel.data_loaders.validation import TaskDataValidation
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)


class DataValidationFailed(Exception):
    def __init__(self, message):
        self.message = message
        super().__init__(self.message)


class DatasetLoader:
    # TODO: add support for reading from SQL databases
    # TODO: add support for reading and loading datasets in chunks
    MAX_ERROR_DISPLAYED = 100

    def __init__(
        self,
        dataset: Union[str, pd.DataFrame],
        config: AutolabelConfig,
        max_items: int = 0,
        start_index: int = 0,
        validate: bool = True,
    ) -> None:
        """DatasetLoader class to read and load datasets.
        Args:
            dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to 0.
            start_index (int, optional): start index to read from. Defaults to 0.
            validate (bool, optional): whether to validate the dataset against the
                task schema after reading. Defaults to True.
        """
        self.dataset = dataset
        self.config = config
        self.max_items = max_items
        self.start_index = start_index

        self.__al_dataset: AutolabelDataset = None
        self.__malformed_records = None
        self._read()
        if validate:
            self._validate()

    @property
    def dat(
        self,
    ) -> Union[pd.DataFrame, Dataset]:
        return self.__al_dataset.dataset

    @property
    def inputs(
        self,
    ) -> List[Dict]:
        return self.__al_dataset.inputs

    @property
    def gt_labels(
        self,
    ) -> List:
        return self.__al_dataset.gt_labels

    @property
    def columns(
        self,
    ) -> List:
        return self.__al_dataset.columns

    def _read(
        self,
    ):
        # Dispatch to the reader that matches the dataset's type or file extension.
        if isinstance(self.dataset, str):
            if self.dataset.endswith(".csv"):
                self.__al_dataset = CSVReader.read(
                    self.dataset,
                    self.config,
                    max_items=self.max_items,
                    start_index=self.start_index,
                )
            elif self.dataset.endswith(".jsonl"):
                self.__al_dataset = JsonlReader.read(
                    self.dataset,
                    self.config,
                    max_items=self.max_items,
                    start_index=self.start_index,
                )
            else:
                raise ValueError(f"Unsupported file format: {self.dataset}")
        elif isinstance(self.dataset, Dataset):
            self.__al_dataset: AutolabelDataset = HuggingFaceDatasetReader.read(
                self.dataset, self.config, self.max_items, self.start_index
            )
        elif isinstance(self.dataset, pd.DataFrame):
            # DataframeReader.read expects max_items before start_index, so pass
            # keyword arguments to avoid swapping the two.
            self.__al_dataset: AutolabelDataset = DataframeReader.read(
                self.dataset,
                self.config,
                max_items=self.max_items,
                start_index=self.start_index,
            )

    def _validate(self):
        """Validate the dataset's columns and data types against the task schema."""
        data_validation = TaskDataValidation(config=self.config)

        # Validate columns
        data_validation.validate_dataset_columns(
            dataset_columns=self.__al_dataset.columns
        )

        # Validate datatype and data format
        self.__malformed_records = data_validation.validate(
            data=self.__al_dataset.inputs
        )

        if len(self.__malformed_records) > 0:
            table = tabulate(
                self.__malformed_records[0 : self.MAX_ERROR_DISPLAYED],
                headers="keys",
                tablefmt="fancy_grid",
                numalign="center",
                stralign="left",
            )
            logger.warning(
                f"Data Validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}"
            )
            raise DataValidationFailed(
                f"Validation failed for {len(self.__malformed_records)} rows."
            )
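As a side note on the error table above: tabulate renders a list of dicts directly when headers="keys" is passed, one row per dict. A standalone sketch with made-up malformed records (the record fields shown are illustrative; the diff does not pin down their exact shape):

from tabulate import tabulate

# Made-up malformed records; the validator returns a list of dicts,
# one per failing row, but the exact keys here are illustrative.
malformed_records = [
    {"row_num": 3, "loc": "label", "msg": "field required"},
    {"row_num": 7, "loc": "question", "msg": "str type expected"},
]

print(
    tabulate(
        malformed_records[0:100],  # mirror the MAX_ERROR_DISPLAYED cap
        headers="keys",            # dict keys become column headers
        tablefmt="fancy_grid",
        numalign="center",
        stralign="left",
    )
)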
@@ -0,0 +1,224 @@
from typing import Dict, List, Union

import logging
import pandas as pd

from pydantic import BaseModel, validator
from datasets import Dataset
from sqlalchemy.sql.selectable import Selectable
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)


class AutolabelDataset(BaseModel):
    """Data Attributes"""

    columns: List
    dataset: Union[pd.DataFrame, None]
    inputs: List[Dict]
    gt_labels: Union[List, None]

    @validator("dataset", allow_reuse=True)
    def validate_dataframe(cls, value):
        # The annotation allows None, so only reject non-None values that
        # are not pandas DataFrames.
        if value is not None and not isinstance(value, pd.DataFrame):
            raise ValueError("Value must be a pandas DataFrame")
        return value

    class Config:
        arbitrary_types_allowed = True


class CSVReader:
    @staticmethod
    def read(
        csv_file: str,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the csv file and return an AutolabelDataset (columns, dataset, inputs, gt_labels)
        Args:
            csv_file (str): path to the csv file
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.
        """
        logger.debug(f"reading the csv from: {start_index}")
        delimiter = config.delimiter()
        label_column = config.label_column()

        dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]

        dat = dat.astype(str)
        if max_items and max_items > 0:
            max_items = min(max_items, len(dat))
            dat = dat[:max_items]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not label_column or not len(inputs) or label_column not in inputs[0]
            else dat[label_column].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )


class JsonlReader:
    @staticmethod
    def read(
        jsonl_file: str,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the jsonl file and return an AutolabelDataset (columns, dataset, inputs, gt_labels)
        Args:
            jsonl_file (str): path to the jsonl file
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.
        """
        logger.debug(f"reading the jsonl from: {start_index}")
        label_column = config.label_column()

        dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
        dat = dat.astype(str)
        if max_items and max_items > 0:
            max_items = min(max_items, len(dat))
            dat = dat[:max_items]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not label_column or not len(inputs) or label_column not in inputs[0]
            else dat[label_column].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )


class HuggingFaceDatasetReader:
    @staticmethod
    def read(
        dataset: Dataset,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the huggingface dataset and return an AutolabelDataset (columns, dataset, inputs, gt_labels)
        Args:
            dataset (Dataset): dataset object to read from
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.
        """
        dataset.set_format("pandas")
        # Take up to max_items rows starting at start_index, clamped to the
        # dataset size, matching the slicing semantics of the other readers.
        end_index = (
            min(start_index + max_items, len(dataset))
            if max_items and max_items > 0
            else len(dataset)
        )
        dat = dataset[start_index:end_index]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not config.label_column()
            or not len(inputs)
            or config.label_column() not in inputs[0]
            else dat[config.label_column()].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )


class SqlDatasetReader:
    @staticmethod
    def read(
        sql: Union[str, Selectable],
        connection: str,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the sql query result and return an AutolabelDataset (columns, dataset, inputs, gt_labels)
        Args:
            sql (Union[str, Selectable]): sql query to run
            connection (str): connection string
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.
        """
        logger.debug(f"reading the sql from: {start_index}")
        label_column = config.label_column()

        dat = pd.read_sql(sql, connection)[start_index:]
        dat = dat.astype(str)
        if max_items and max_items > 0:
            max_items = min(max_items, len(dat))
            dat = dat[:max_items]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not label_column or not len(inputs) or label_column not in inputs[0]
            else dat[label_column].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )


class DataframeReader:
    @staticmethod
    def read(
        df: pd.DataFrame,
        config: AutolabelConfig,
        max_items: int = None,
        start_index: int = 0,
    ) -> AutolabelDataset:
        """Read the dataframe and return an AutolabelDataset (columns, dataset, inputs, gt_labels)
        Args:
            df (pd.DataFrame): dataframe to read
            config (AutolabelConfig): config object
            max_items (int, optional): max number of items to read. Defaults to None.
            start_index (int, optional): start index to read from. Defaults to 0.
        """
        label_column = config.label_column()

        dat = df[start_index:].astype(str)
        if max_items and max_items > 0:
            max_items = min(max_items, len(dat))
            dat = dat[:max_items]

        inputs = dat.to_dict(orient="records")
        gt_labels = (
            None
            if not label_column or not len(inputs) or label_column not in inputs[0]
            else dat[label_column].tolist()
        )
        return AutolabelDataset(
            columns=list(dat.columns),
            dataset=dat,
            inputs=inputs,
            gt_labels=gt_labels,
        )
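To show how these readers are meant to be called directly, a small hypothetical example using DataframeReader. The config file and column names are invented; whether gt_labels is populated depends on the config's label_column().

import pandas as pd
from autolabel.configs import AutolabelConfig
from autolabel.data_loaders.read_datasets import DataframeReader

# Hypothetical config; DataframeReader only consults label_column() from it.
config = AutolabelConfig("example_task_config.json")

df = pd.DataFrame(
    {
        "question": ["What is 2 + 2?", "What is the capital of France?"],
        "label": ["4", "Paris"],
    }
)

al_dataset = DataframeReader.read(df, config, max_items=2, start_index=0)
print(al_dataset.columns)    # ['question', 'label']
print(al_dataset.inputs[0])  # {'question': 'What is 2 + 2?', 'label': '4'}
# gt_labels is ['4', 'Paris'] if the config names 'label' as its label
# column, otherwise None.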