Commit
Fix 178: Add data type checking and validation while performing agent plan action (#428)
Sardhendu authored Jul 7, 2023
1 parent d568318 commit 6ed14a5
Showing 9 changed files with 829 additions and 239 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -36,7 +36,8 @@ dependencies = [
"matplotlib >= 3.5.0",
"wget >= 3.2",
"ipywidgets == 8.0.6",
"jsonschema >= 4.17.3"
"jsonschema >= 4.17.3",
"tabulate >= 0.9.0"
]
requires-python = ">=3.6"

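The new tabulate dependency is used by the data loader below to render malformed records as a readable grid when validation fails. A minimal sketch of that rendering, with hypothetical error records (the field names here are illustrative, not the validator's actual schema):

import pandas as pd  # noqa: F401  (pandas is already a project dependency)
from tabulate import tabulate

# Hypothetical entries shaped like per-row validation errors.
errors = [
    {"row_num": 3, "loc": "label", "msg": "value is not a valid string"},
    {"row_num": 7, "loc": "text", "msg": "field required"},
]
print(tabulate(errors, headers="keys", tablefmt="fancy_grid"))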
141 changes: 141 additions & 0 deletions src/autolabel/data_loaders/__init__.py
@@ -0,0 +1,141 @@
import logging
import pandas as pd
from tabulate import tabulate
from typing import Dict, List, Union
from datasets import Dataset
from autolabel.data_loaders.read_datasets import (
AutolabelDataset,
CSVReader,
JsonlReader,
HuggingFaceDatasetReader,
# SqlDatasetReader,
DataframeReader,
)
from autolabel.data_loaders.validation import TaskDataValidation
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)


class DataValidationFailed(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


class DatasetLoader:
# TODO: add support for reading from SQL databases
# TODO: add support for reading and loading datasets in chunks
MAX_ERROR_DISPLAYED = 100

def __init__(
self,
dataset: Union[str, pd.DataFrame],
config: AutolabelConfig,
max_items: int = 0,
start_index: int = 0,
validate: bool = True,
) -> None:
"""DatasetLoader class to read and load datasets.
Args:
dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to 0.
start_index (int, optional): start index to read from. Defaults to 0.
"""
self.dataset = dataset
self.config = config
self.max_items = max_items
self.start_index = start_index

self.__al_dataset: AutolabelDataset = None
self.__malformed_records = None
self._read()
if validate:
self._validate()

@property
def dat(
self,
) -> Union[pd.DataFrame, Dataset]:
return self.__al_dataset.dataset

    @property
    def inputs(
        self,
    ) -> List[Dict]:
        return self.__al_dataset.inputs

    @property
    def gt_labels(
        self,
    ) -> List:
        return self.__al_dataset.gt_labels

    @property
    def columns(
        self,
    ) -> List:
        return self.__al_dataset.columns

def _read(
self,
):
if isinstance(self.dataset, str):
if self.dataset.endswith(".csv"):
self.__al_dataset = CSVReader.read(
self.dataset,
self.config,
max_items=self.max_items,
start_index=self.start_index,
)

elif self.dataset.endswith(".jsonl"):
self.__al_dataset = JsonlReader.read(
self.dataset,
self.config,
max_items=self.max_items,
start_index=self.start_index,
)
else:
raise ValueError(f"Unsupported file format: {self.dataset}")
elif isinstance(self.dataset, Dataset):
self.__al_dataset: AutolabelDataset = HuggingFaceDatasetReader.read(
self.dataset, self.config, self.max_items, self.start_index
)
        elif isinstance(self.dataset, pd.DataFrame):
            self.__al_dataset: AutolabelDataset = DataframeReader.read(
                self.dataset,
                self.config,
                max_items=self.max_items,
                start_index=self.start_index,
            )
        else:
            raise ValueError(f"Unsupported dataset type: {type(self.dataset)}")

    def _validate(self):
        """Validate the dataset columns and the datatype/format of each record."""
        data_validation = TaskDataValidation(config=self.config)

        # Validate columns
        data_validation.validate_dataset_columns(
            dataset_columns=self.__al_dataset.columns
        )

        # Validate datatype and data format
        self.__malformed_records = data_validation.validate(
            data=self.__al_dataset.inputs
        )

        if len(self.__malformed_records) > 0:
            table = tabulate(
                self.__malformed_records[: self.MAX_ERROR_DISPLAYED],
                headers="keys",
                tablefmt="fancy_grid",
                numalign="center",
                stralign="left",
            )
            logger.warning(
                f"Data validation failed for {len(self.__malformed_records)} records:\n{table}"
            )
            raise DataValidationFailed(
                f"Validation failed for {len(self.__malformed_records)} rows."
            )
224 changes: 224 additions & 0 deletions src/autolabel/data_loaders/read_datasets.py
@@ -0,0 +1,224 @@
from typing import Dict, List, Union

import logging
import pandas as pd

from pydantic import BaseModel, validator
from datasets import Dataset
from sqlalchemy.sql.selectable import Selectable
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)


class AutolabelDataset(BaseModel):
"""Data Attributes"""

columns: List
dataset: Union[pd.DataFrame, None]
inputs: List[Dict]
gt_labels: List

@validator("dataset", allow_reuse=True)
def validate_dataframe(cls, value):
if not isinstance(value, pd.DataFrame):
raise ValueError("Value must be a pandas DataFrame")
return value

class Config:
arbitrary_types_allowed = True

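For illustration (not part of the committed file): the validator above rejects any dataset value that is not a pandas DataFrame, so a malformed construction fails fast. A sketch, assuming pydantic v1 semantics:

import pandas as pd
from pydantic import ValidationError
from autolabel.data_loaders.read_datasets import AutolabelDataset

df = pd.DataFrame({"text": ["a"], "label": ["x"]})
ok = AutolabelDataset(
    columns=list(df.columns),
    dataset=df,
    inputs=df.to_dict(orient="records"),
    gt_labels=df["label"].tolist(),
)

try:
    AutolabelDataset(columns=[], dataset="not a dataframe", inputs=[], gt_labels=[])
except ValidationError as err:
    print(err)  # pydantic reports the invalid "dataset" field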

class CSVReader:
@staticmethod
def read(
csv_file: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the csv file and sets dat, inputs and gt_labels
Args:
csv_file (str): path to the csv file
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the csv from: {start_index}")
delimiter = config.delimiter()
label_column = config.label_column()

dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]

dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)


class JsonlReader:
@staticmethod
def read(
jsonl_file: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the jsonl file and sets dat, inputs and gt_labels
Args:
jsonl_file (str): path to the jsonl file
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the jsonl from: {start_index}")
label_column = config.label_column()

dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)


class HuggingFaceDatasetReader:
@staticmethod
def read(
dataset: Dataset,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the huggingface dataset and sets dat, inputs and gt_labels
Args:
dataset (Dataset): dataset object to read from
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
        dataset.set_format("pandas")
        end_index = (
            start_index + max_items if max_items and max_items > 0 else len(dataset)
        )
        dat = dataset[start_index : min(end_index, len(dataset))]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not config.label_column()
or not len(inputs)
or config.label_column() not in inputs[0]
else dat[config.label_column()].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)


class SqlDatasetReader:
@staticmethod
def read(
sql: Union[str, Selectable],
connection: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the sql query and sets dat, inputs and gt_labels
Args:
connection (str): connection string
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the sql from: {start_index}")
label_column = config.label_column()

dat = pd.read_sql(sql, connection)[start_index:]
dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)


class DataframeReader:
@staticmethod
def read(
df: pd.DataFrame,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the csv file and sets dat, inputs and gt_labels
Args:
df (pd.DataFrame): dataframe to read
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
label_column = config.label_column()

dat = df[start_index:].astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)
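Every reader returns the same AutolabelDataset container, so downstream code is source-agnostic. A sketch of calling DataframeReader directly (not part of the commit; the dataframe and config are hypothetical, and the config's label_column is assumed to be "label"):

import pandas as pd
from autolabel.configs import AutolabelConfig
from autolabel.data_loaders.read_datasets import DataframeReader

df = pd.DataFrame({"text": ["good movie", "bad movie"], "label": ["pos", "neg"]})
config = AutolabelConfig("config.json")  # hypothetical config path
ds = DataframeReader.read(df, config, max_items=2)
print(ds.columns)    # ['text', 'label']
print(ds.gt_labels)  # ['pos', 'neg'], assuming label_column is "label"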