
Fix 178: Add data type checking and validation while performing agent plan action #428

Merged: 12 commits, Jul 7, 2023
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -36,7 +36,8 @@ dependencies = [
"matplotlib >= 3.5.0",
"wget >= 3.2",
"ipywidgets == 8.0.6",
"jsonschema >= 4.17.3"
"jsonschema >= 4.17.3",
"tabulate >= 0.9.0"
]
requires-python = ">=3.6"

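The new tabulate dependency is what DatasetLoader._validate (below) uses to render malformed records in its warning. A minimal sketch of that rendering, with a hypothetical error payload (the exact dict shape is an assumption, not taken from this PR):

from tabulate import tabulate

# Hypothetical malformed-record entries; the field names are illustrative only.
malformed = [
    {"row_num": 3, "loc": "label", "msg": "field required"},
    {"row_num": 7, "loc": "text", "msg": "str type expected"},
]

# headers="keys" pulls the column names from the dict keys.
print(tabulate(malformed, headers="keys", tablefmt="fancy_grid",
               numalign="center", stralign="left"))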
141 changes: 141 additions & 0 deletions src/autolabel/data_loaders/__init__.py
@@ -0,0 +1,141 @@
import logging
import pandas as pd
from tabulate import tabulate
from typing import Dict, List, Union
from datasets import Dataset
from autolabel.data_loaders.read_datasets import (
AutolabelDataset,
CSVReader,
JsonlReader,
HuggingFaceDatasetReader,
# SqlDatasetReader,
DataframeReader,
)
from autolabel.data_loaders.validation import TaskDataValidation
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)


class DataValidationFailed(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


class DatasetLoader:
# TODO: add support for reading from SQL databases
# TODO: add support for reading and loading datasets in chunks
MAX_ERROR_DISPLAYED = 100

def __init__(
self,
dataset: Union[str, pd.DataFrame],
config: AutolabelConfig,
max_items: int = 0,
start_index: int = 0,
validate: bool = True,
) -> None:
"""DatasetLoader class to read and load datasets.
Args:
dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to 0.
            start_index (int, optional): start index to read from. Defaults to 0.
            validate (bool, optional): whether to validate the loaded data. Defaults to True.
        """
self.dataset = dataset
self.config = config
self.max_items = max_items
self.start_index = start_index

self.__al_dataset: AutolabelDataset = None
self.__malformed_records = None
self._read()
if validate:
self._validate()

@property
def dat(
self,
) -> Union[pd.DataFrame, Dataset]:
return self.__al_dataset.dataset

    @property
    def inputs(
        self,
    ) -> List[Dict]:
        return self.__al_dataset.inputs

    @property
    def gt_labels(
        self,
    ) -> List:
        return self.__al_dataset.gt_labels

    @property
    def columns(
        self,
    ) -> List:
        return self.__al_dataset.columns

def _read(
self,
):
if isinstance(self.dataset, str):
if self.dataset.endswith(".csv"):
self.__al_dataset = CSVReader.read(
self.dataset,
self.config,
max_items=self.max_items,
start_index=self.start_index,
)

elif self.dataset.endswith(".jsonl"):
self.__al_dataset = JsonlReader.read(
self.dataset,
self.config,
max_items=self.max_items,
start_index=self.start_index,
)
else:
raise ValueError(f"Unsupported file format: {self.dataset}")
elif isinstance(self.dataset, Dataset):
self.__al_dataset: AutolabelDataset = HuggingFaceDatasetReader.read(
self.dataset, self.config, self.max_items, self.start_index
)
        elif isinstance(self.dataset, pd.DataFrame):
            # DataframeReader.read takes (df, config, max_items, start_index);
            # pass keywords so the offset and cap don't get swapped.
            self.__al_dataset: AutolabelDataset = DataframeReader.read(
                self.dataset,
                self.config,
                max_items=self.max_items,
                start_index=self.start_index,
            )
        else:
            raise ValueError(f"Unsupported dataset type: {type(self.dataset)}")

def _validate(self):
"""Validate Data"""
data_validation = TaskDataValidation(config=self.config)

# Validate columns
data_validation.validate_dataset_columns(
dataset_columns=self.__al_dataset.columns
)

# Validate datatype and data format
self.__malformed_records = data_validation.validate(
data=self.__al_dataset.inputs
)

        if len(self.__malformed_records) > 0:
            table = tabulate(
                self.__malformed_records[: self.MAX_ERROR_DISPLAYED],
                headers="keys",
                tablefmt="fancy_grid",
                numalign="center",
                stralign="left",
            )
            logger.warning(
                f"Data validation failed for {len(self.__malformed_records)} records: \n Stats: \n {table}"
            )
            raise DataValidationFailed(
                f"Validation failed for {len(self.__malformed_records)} rows."
            )
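For context, a hedged usage sketch of DatasetLoader; the config construction and file paths are placeholders, not taken from this PR:

from autolabel.configs import AutolabelConfig
from autolabel.data_loaders import DatasetLoader, DataValidationFailed

config = AutolabelConfig("config.json")  # placeholder config path

try:
    # Reads the first 100 rows, then runs column and type/format validation.
    loader = DatasetLoader("seed.csv", config, max_items=100, validate=True)
    print(loader.columns)        # column names from the file
    print(len(loader.inputs))    # row dicts fed to the labeling task
except DataValidationFailed as err:
    print(f"Fix the dataset before labeling: {err}")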
224 changes: 224 additions & 0 deletions src/autolabel/data_loaders/read_datasets.py
@@ -0,0 +1,224 @@
from typing import Dict, List, Union

import logging
import pandas as pd

from pydantic import BaseModel, validator
from datasets import Dataset
from sqlalchemy.sql.selectable import Selectable
from autolabel.configs import AutolabelConfig

logger = logging.getLogger(__name__)
from typing import Union
Contributor review comment: redundant import, this is already imported on line 1



class AutolabelDataset(BaseModel):
"""Data Attributes"""

    columns: List
    dataset: Union[pd.DataFrame, None]
    inputs: List[Dict]
    gt_labels: Union[List, None]

@validator("dataset", allow_reuse=True)
def validate_dataframe(cls, value):
if not isinstance(value, pd.DataFrame):
raise ValueError("Value must be a pandas DataFrame")
return value

class Config:
arbitrary_types_allowed = True
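A quick sketch of how this model behaves, assuming pydantic v1 semantics: arbitrary_types_allowed lets the DataFrame field through, and anything that is neither a DataFrame nor None fails validation.

import pandas as pd
from pydantic import ValidationError
from autolabel.data_loaders.read_datasets import AutolabelDataset

df = pd.DataFrame({"text": ["a", "b"], "label": ["x", "y"]})

ok = AutolabelDataset(
    columns=list(df.columns),
    dataset=df,
    inputs=df.to_dict(orient="records"),
    gt_labels=df["label"].tolist(),
)

try:
    # A plain dict is neither a DataFrame nor None, so validation fails.
    AutolabelDataset(columns=[], dataset={"oops": 1}, inputs=[], gt_labels=[])
except ValidationError as err:
    print(err)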


class CSVReader:
@staticmethod
def read(
csv_file: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the csv file and sets dat, inputs and gt_labels
Args:
csv_file (str): path to the csv file
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the csv from: {start_index}")
delimiter = config.delimiter()
label_column = config.label_column()

dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]

dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)
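The start_index/max_items handling here (and in the readers below) offsets first, then caps the remainder. A standalone sketch of the same windowing, with illustrative names:

import pandas as pd

def window(df: pd.DataFrame, start_index: int = 0, max_items: int = None) -> pd.DataFrame:
    # Offset first, then keep at most max_items rows, mirroring the readers.
    out = df[start_index:]
    if max_items and max_items > 0:
        out = out[: min(max_items, len(out))]
    return out

df = pd.DataFrame({"text": list("abcdef")})
print(len(window(df, start_index=2, max_items=3)))  # 3 rows: c, d, e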


class JsonlReader:
@staticmethod
def read(
jsonl_file: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the jsonl file and sets dat, inputs and gt_labels
Args:
jsonl_file (str): path to the jsonl file
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the jsonl from: {start_index}")
label_column = config.label_column()

dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)


class HuggingFaceDatasetReader:
@staticmethod
def read(
dataset: Dataset,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the huggingface dataset and sets dat, inputs and gt_labels
Args:
dataset (Dataset): dataset object to read from
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
dataset.set_format("pandas")
dat = dataset[
start_index : max_items if max_items and max_items > 0 else len(dataset)
]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not config.label_column()
or not len(inputs)
or config.label_column() not in inputs[0]
else dat[config.label_column()].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)
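The pandas format switch is what makes the slice above come back as a DataFrame. A small standalone check of that behavior:

from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "c"], "label": ["x", "y", "z"]})
ds.set_format("pandas")

# With the pandas format set, slicing returns a pandas DataFrame.
frame = ds[0:2]
print(type(frame).__name__)  # DataFrame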


class SqlDatasetReader:
@staticmethod
def read(
sql: Union[str, Selectable],
connection: str,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the sql query and sets dat, inputs and gt_labels
Args:
connection (str): connection string
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
logger.debug(f"reading the sql from: {start_index}")
label_column = config.label_column()

dat = pd.read_sql(sql, connection)[start_index:]
dat = dat.astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)
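SqlDatasetReader is not wired into DatasetLoader yet (its import at the top of __init__.py is commented out), but since it wraps pd.read_sql, any DB-API or SQLAlchemy connection should work. A minimal sketch using an in-memory sqlite database:

import sqlite3
import pandas as pd

conn = sqlite3.connect(":memory:")
pd.DataFrame({"text": ["a", "b"], "label": ["x", "y"]}).to_sql("seed", conn, index=False)

dat = pd.read_sql("SELECT * FROM seed", conn)
print(dat.columns.tolist())  # ['text', 'label']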


class DataframeReader:
@staticmethod
def read(
df: pd.DataFrame,
config: AutolabelConfig,
max_items: int = None,
start_index: int = 0,
) -> AutolabelDataset:
"""Read the csv file and sets dat, inputs and gt_labels
Args:
df (pd.DataFrame): dataframe to read
config (AutolabelConfig): config object
max_items (int, optional): max number of items to read. Defaults to None.
start_index (int, optional): start index to read from. Defaults to 0.
"""
label_column = config.label_column()

dat = df[start_index:].astype(str)
if max_items and max_items > 0:
max_items = min(max_items, len(dat))
dat = dat[:max_items]

inputs = dat.to_dict(orient="records")
gt_labels = (
None
if not label_column or not len(inputs) or label_column not in inputs[0]
else dat[label_column].tolist()
)
return AutolabelDataset(
columns=list(dat.columns),
dataset=dat,
inputs=inputs,
gt_labels=gt_labels,
)