
Commit

Moving returned items to dataset loader fields (#317)
* Moving returned items to dataset loader fields

* updated comments, added underscores, updated tests

---------

Co-authored-by: Rajas Bansal <[email protected]>
Co-authored-by: Tyler <[email protected]>
3 people authored Jun 21, 2023
1 parent 192e1e5 commit cf272a3
Showing 3 changed files with 150 additions and 144 deletions.
157 changes: 85 additions & 72 deletions src/autolabel/dataset_loader.py
@@ -14,155 +14,171 @@ class DatasetLoader:
     # TODO: add support for reading from SQL databases
     # TODO: add support for reading and loading datasets in chunks
 
-    @staticmethod
-    def read_csv(
+    def __init__(
+        self,
+        dataset: Union[str, pd.DataFrame],
+        config: AutolabelConfig,
+        max_items: int = 0,
+        start_index: int = 0,
+    ) -> None:
+        """DatasetLoader class to read and load datasets.
+        Args:
+            dataset (Union[str, pd.DataFrame]): path to the dataset or the dataframe
+            config (AutolabelConfig): config object
+            max_items (int, optional): max number of items to read. Defaults to 0.
+            start_index (int, optional): start index to read from. Defaults to 0.
+        """
+        self.dataset = dataset
+        self.config = config
+        self.max_items = max_items
+        self.start_index = start_index
+
+        if isinstance(dataset, str):
+            self._read_file(dataset, config, max_items, start_index)
+        elif isinstance(dataset, pd.DataFrame):
+            self._read_dataframe(dataset, config, start_index, max_items)
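For reviewers, a minimal before/after sketch of the call-site change this refactor implies (the file name and max_items value are illustrative, and config is assumed to be an AutolabelConfig built elsewhere):

# Before #317: static helpers returned a tuple.
# dat, inputs, gt_labels = DatasetLoader.read_csv("reviews.csv", config, max_items=100)

# After #317: construct the loader; results live on its fields.
loader = DatasetLoader("reviews.csv", config, max_items=100)
dat = loader.dat              # windowed pandas DataFrame
inputs = loader.inputs        # list of row dicts
gt_labels = loader.gt_labels  # list of labels, or None without a label column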

+    def _read_csv(
+        self,
         csv_file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the csv file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the csv file and set dat, inputs and gt_labels
         Args:
             csv_file (str): path to the csv file
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the csv from: {start_index}")
         delimiter = config.delimiter()
         label_column = config.label_column()
 
-        dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
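The windowing logic above is unchanged apart from the move to self.dat; a standalone sketch of how start_index and the max_items clamp interact (values are illustrative):

import pandas as pd

full = pd.DataFrame({"example": [f"row {i}" for i in range(10)]})
start_index, max_items = 4, 100

dat = full[start_index:]                  # rows 4..9 -> 6 rows
if max_items and max_items > 0:
    max_items = min(max_items, len(dat))  # clamps 100 down to 6
    dat = dat[:max_items]                 # no-op after the clamp
assert len(dat) == 6
# With the new constructor default max_items=0, the guard is falsy
# and the window is left untruncated.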

-    @staticmethod
-    def read_dataframe(
+    def _read_dataframe(
+        self,
         df: pd.DataFrame,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the csv file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the dataframe and set dat, inputs and gt_labels
         Args:
             df (pd.DataFrame): dataframe to read
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         label_column = config.label_column()
 
-        dat = df[start_index:].astype(str)
+        self.dat = df[start_index:].astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
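The gt_labels expression was only re-wrapped for line length; its behavior is unchanged. A self-contained sketch of the guard, with hypothetical column names:

import pandas as pd

dat = pd.DataFrame({"example": ["good movie", "bad movie"], "label": ["pos", "neg"]})
inputs = dat.to_dict(orient="records")

label_column = "label"
gt_labels = (
    None
    if not label_column or not len(inputs) or label_column not in inputs[0]
    else dat[label_column].tolist()
)
assert gt_labels == ["pos", "neg"]

# A label column absent from the rows short-circuits to None instead of raising:
label_column = "category"
gt_labels = None if label_column not in inputs[0] else dat[label_column].tolist()
assert gt_labels is None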

-    @staticmethod
-    def read_jsonl(
+    def _read_jsonl(
+        self,
         jsonl_file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the jsonl file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the jsonl file and set dat, inputs and gt_labels
         Args:
             jsonl_file (str): path to the jsonl file
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the jsonl from: {start_index}")
         label_column = config.label_column()
 
-        dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
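For reviewers unfamiliar with pandas' JSON-lines mode used above, a self-contained sketch (the records are made up):

import io
import pandas as pd

# Two records, one JSON object per line, as in a .jsonl file.
buf = io.StringIO('{"example": "good movie", "label": "pos"}\n'
                  '{"example": "bad movie", "label": "neg"}\n')
dat = pd.read_json(buf, lines=True, dtype="str").astype(str)
print(dat.to_dict(orient="records"))
# [{'example': 'good movie', 'label': 'pos'}, {'example': 'bad movie', 'label': 'neg'}]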

-    @staticmethod
-    def read_sql(
+    def _read_sql(
+        self,
         sql: Union[str, Selectable],
         connection: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the sql query and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the sql query and set dat, inputs and gt_labels
         Args:
             connection (str): connection string
             config (AutolabelConfig): config object
             max_items (int, optional): max number of items to read. Defaults to None.
             start_index (int, optional): start index to read from. Defaults to 0.
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         logger.debug(f"reading the sql from: {start_index}")
         label_column = config.label_column()
 
-        dat = pd.read_sql(sql, connection)[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_sql(sql, connection)[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
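Reading from SQL databases is still flagged as a TODO at the top of the class, but _read_sql already wraps pd.read_sql; a minimal sketch with an in-memory SQLite stand-in (table and columns are hypothetical):

import sqlite3
import pandas as pd

conn = sqlite3.connect(":memory:")  # stand-in for a real connection string
conn.execute("CREATE TABLE reviews (example TEXT, label TEXT)")
conn.execute("INSERT INTO reviews VALUES ('good movie', 'pos')")
conn.commit()

dat = pd.read_sql("SELECT * FROM reviews", conn).astype(str)
print(dat.to_dict(orient="records"))  # [{'example': 'good movie', 'label': 'pos'}]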

-    @staticmethod
-    def read_file(
+    def _read_file(
+        self,
         file: str,
         config: AutolabelConfig,
         max_items: int = None,
         start_index: int = 0,
-    ) -> Tuple[pd.DataFrame, List[Dict], List]:
-        """Read the file and return the dataframe, inputs and gt_labels
+    ) -> None:
+        """Read the file and set dat, inputs and gt_labels
         Args:
             file (str): path to the file
@@ -172,16 +188,13 @@ def read_file(
         Raises:
             ValueError: if the file format is not supported
-        Returns:
-            Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         if file.endswith(".csv"):
-            return DatasetLoader.read_csv(
+            return self._read_csv(
                 file, config, max_items=max_items, start_index=start_index
             )
         elif file.endswith(".jsonl"):
-            return DatasetLoader.read_jsonl(
+            return self._read_jsonl(
                 file, config, max_items=max_items, start_index=start_index
            )
         else:
… (remainder of this file's diff collapsed)

(2 other changed files not shown)
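Per the Raises note in the docstring above, the collapsed else branch rejects unsupported extensions. A hypothetical sketch of that dispatch shape (the error message is assumed, not quoted from the commit):

def detect_format(file: str) -> str:
    # Hypothetical helper mirroring the visible branches of _read_file.
    if file.endswith(".csv"):
        return "csv"
    elif file.endswith(".jsonl"):
        return "jsonl"
    # The real else branch is collapsed above; per the docstring it raises:
    raise ValueError(f"unsupported file format: {file}")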

0 comments on commit cf272a3
