Moving returned items to dataset loader fields #317

Merged · 3 commits · Jun 21, 2023
Changes from 1 commit
104 changes: 62 additions & 42 deletions src/autolabel/dataset_loader.py
@@ -12,8 +12,25 @@ class DatasetLoader:
     # TODO: add support for reading from SQL databases
     # TODO: add support for reading and loading datasets in chunks
 
-    @staticmethod
+    def __init__(
+        self,
+        dataset: Union[str, pd.DataFrame],
+        config: AutolabelConfig,
+        max_items: int = 0,
+        start_index: int = 0,
+    ) -> None:
+        self.dataset = dataset
+        self.config = config
+        self.max_items = max_items
+        self.start_index = start_index
+
+        if isinstance(dataset, str):
+            self.read_file(dataset, config, max_items, start_index)
+        elif isinstance(dataset, pd.DataFrame):
+            self.read_dataframe(dataset, config, start_index, max_items)
+
     def read_csv(
+        self,
         csv_file: str,
         config: AutolabelConfig,
         max_items: int = None,
@@ -34,22 +51,23 @@ def read_csv(
         delimiter = config.delimiter()
         label_column = config.label_column()
 
-        dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_csv(csv_file, sep=delimiter, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
 
-    @staticmethod
     def read_dataframe(
+        self,
         df: pd.DataFrame,
         config: AutolabelConfig,
         max_items: int = None,
@@ -68,21 +86,22 @@ def read_dataframe(
         """
         label_column = config.label_column()
 
-        dat = df[start_index:].astype(str)
+        self.dat = df[start_index:].astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
 
-    @staticmethod
     def read_jsonl(
+        self,
         jsonl_file: str,
         config: AutolabelConfig,
         max_items: int = None,
@@ -102,21 +121,21 @@ def read_jsonl(
         logger.debug(f"reading the jsonl from: {start_index}")
         label_column = config.label_column()
 
-        dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_json(jsonl_file, lines=True, dtype="str")[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
        )
-        return (dat, inputs, gt_labels)
 
-    @staticmethod
     def read_sql(
+        self,
         sql: Union[str, Selectable],
@@ -139,22 +158,23 @@
         logger.debug(f"reading the sql from: {start_index}")
         label_column = config.label_column()
 
-        dat = pd.read_sql(sql, connection)[start_index:]
-        dat = dat.astype(str)
+        self.dat = pd.read_sql(sql, connection)[start_index:]
+        self.dat = self.dat.astype(str)
         if max_items and max_items > 0:
-            max_items = min(max_items, len(dat))
-            dat = dat[:max_items]
+            max_items = min(max_items, len(self.dat))
+            self.dat = self.dat[:max_items]
 
-        inputs = dat.to_dict(orient="records")
-        gt_labels = (
+        self.inputs = self.dat.to_dict(orient="records")
+        self.gt_labels = (
             None
-            if not label_column or not len(inputs) or label_column not in inputs[0]
-            else dat[label_column].tolist()
+            if not label_column
+            or not len(self.inputs)
+            or label_column not in self.inputs[0]
+            else self.dat[label_column].tolist()
         )
-        return (dat, inputs, gt_labels)
 
-    @staticmethod
     def read_file(
+        self,
         file: str,
         config: AutolabelConfig,
         max_items: int = None,
@@ -175,11 +195,11 @@
             Tuple[pd.DataFrame, List[Dict], List]: dataframe, inputs and gt_labels
         """
         if file.endswith(".csv"):
-            return DatasetLoader.read_csv(
+            return self.read_csv(
                 file, config, max_items=max_items, start_index=start_index
             )
         elif file.endswith(".jsonl"):
-            return DatasetLoader.read_jsonl(
+            return self.read_jsonl(
                 file, config, max_items=max_items, start_index=start_index
             )
         else:
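A minimal sketch of how the refactored loader is meant to be used after this change: instead of unpacking a (df, inputs, gt_labels) tuple from a static method, callers construct a DatasetLoader and read its fields. The file names and the config import path below are illustrative assumptions, not taken from this diff.

# Sketch only: config import path and file names are hypothetical.
from autolabel.configs import AutolabelConfig
from autolabel.dataset_loader import DatasetLoader

config = AutolabelConfig("example_config.json")

# __init__ dispatches to read_file (for a path) or read_dataframe (for a
# pandas DataFrame) and stores the results on the instance.
loader = DatasetLoader("example_dataset.csv", config, max_items=100)

loader.dat        # pd.DataFrame: rows sliced by start_index/max_items, as str
loader.inputs     # List[Dict]: one record per row (DataFrame.to_dict("records"))
loader.gt_labels  # List of labels, or None when no label column is configured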
61 changes: 27 additions & 34 deletions src/autolabel/labeler.py
@@ -93,22 +93,16 @@ def run(
         csv_file_name = (
             output_name if output_name else f"{dataset.replace('.csv','')}_labeled.csv"
         )
-        if isinstance(dataset, str):
-            df, inputs, gt_labels = DatasetLoader.read_file(
-                dataset, self.config, max_items, start_index
-            )
-        elif isinstance(dataset, pd.DataFrame):
-            df, inputs, gt_labels = DatasetLoader.read_dataframe(
-                dataset, self.config, max_items, start_index
-            )
+
+        dataset_loader = DatasetLoader(dataset, self.config, max_items, start_index)
 
         # Initialize task run and check if it already exists
         self.task_run = self.db.get_task_run(self.task_object.id, self.dataset.id)
         # Resume/Delete the task if it already exists or create a new task run
         if self.task_run:
             logger.info("Task run already exists.")
             self.task_run = self.handle_existing_task_run(
-                self.task_run, csv_file_name, gt_labels=gt_labels
+                self.task_run, csv_file_name, gt_labels=dataset_loader.gt_labels
             )
         else:
             self.task_run = self.db.create_task_run(
@@ -120,7 +114,8 @@ def run(
 
         # If this dataset config is a string, read the corrresponding csv file
         if isinstance(seed_examples, str):
-            _, seed_examples, _ = DatasetLoader.read_csv(seed_examples, self.config)
+            seed_loader = DatasetLoader(seed_examples, self.config)
+            seed_examples = seed_loader.inputs
 
         # Check explanations are present in data if explanation_column is passed in
         if (
@@ -133,23 +128,25 @@ def run(
             )
 
         self.example_selector = ExampleSelectorFactory.initialize_selector(
-            self.config, seed_examples, df.keys().tolist()
+            self.config, seed_examples, dataset_loader.dat.keys().tolist()
         )
 
         num_failures = 0
         current_index = self.task_run.current_index
         cost = 0.0
         postfix_dict = {}
 
-        indices = range(current_index, len(inputs), self.CHUNK_SIZE)
+        indices = range(current_index, len(dataset_loader.inputs), self.CHUNK_SIZE)
         for current_index in track_with_stats(
             indices,
             postfix_dict,
-            total=len(inputs),
+            total=len(dataset_loader.inputs),
             advance=self.CHUNK_SIZE,
             console=console,
         ):
-            chunk = inputs[current_index : current_index + self.CHUNK_SIZE]
+            chunk = dataset_loader.inputs[
+                current_index : current_index + self.CHUNK_SIZE
+            ]
             final_prompts = []
             for i, input_i in enumerate(chunk):
                 # Fetch few-shot seed examples
@@ -224,9 +221,9 @@ def run(
                 self.db.session, self.task_run.id
             )
             llm_labels = [LLMAnnotation(**a.llm_annotation) for a in db_result]
-            if gt_labels:
+            if dataset_loader.gt_labels:
                 eval_result = self.task.eval(
-                    llm_labels, gt_labels[: len(llm_labels)]
+                    llm_labels, dataset_loader.gt_labels[: len(llm_labels)]
                 )
 
                 for m in eval_result:
@@ -250,8 +247,10 @@ def run(
         llm_labels = [LLMAnnotation(**a.llm_annotation) for a in db_result]
         eval_result = None
         # if true labels are provided, evaluate accuracy of predictions
-        if gt_labels:
-            eval_result = self.task.eval(llm_labels, gt_labels[: len(llm_labels)])
+        if dataset_loader.gt_labels:
+            eval_result = self.task.eval(
+                llm_labels, dataset_loader.gt_labels[: len(llm_labels)]
+            )
         table = {}
         # TODO: serialize and write to file
         for m in eval_result:
@@ -263,7 +262,7 @@ def run(
         print_table(table, console=console, default_style=METRIC_TABLE_STYLE)
 
         # Write output to CSV
-        output_df = df.copy()
+        output_df = dataset_loader.dat.copy()
         output_df[self.config.task_name() + "_llm_labeled_successfully"] = [
             l.successfully_labeled for l in llm_labels
         ]
@@ -316,14 +315,7 @@ def plan(
             dataset: path to a CSV dataset
         """
 
-        if isinstance(dataset, str):
-            df, inputs, _ = DatasetLoader.read_file(
-                dataset, self.config, max_items, start_index
-            )
-        elif isinstance(dataset, pd.DataFrame):
-            df, inputs, _ = DatasetLoader.read_dataframe(
-                dataset, self.config, max_items, start_index
-            )
+        dataset_loader = DatasetLoader(dataset, self.config, max_items, start_index)
 
         prompt_list = []
         total_cost = 0
@@ -333,7 +325,8 @@ def plan(
 
         # If this dataset config is a string, read the corrresponding csv file
         if isinstance(seed_examples, str):
-            _, seed_examples, _ = DatasetLoader.read_file(seed_examples, self.config)
+            seed_loader = DatasetLoader(seed_examples, self.config)
+            seed_examples = seed_loader.inputs
 
         # Check explanations are present in data if explanation_column is passed in
         if (
@@ -346,12 +339,12 @@ def plan(
             )
 
         self.example_selector = ExampleSelectorFactory.initialize_selector(
-            self.config, seed_examples, df.keys().tolist()
+            self.config, seed_examples, dataset_loader.dat.keys().tolist()
         )
 
-        input_limit = min(len(inputs), 100)
+        input_limit = min(len(dataset_loader.inputs), 100)
         for input_i in track(
-            inputs[:input_limit],
+            dataset_loader.inputs[:input_limit],
             description="Generating Prompts...",
             console=console,
         ):
@@ -367,11 +360,11 @@ def plan(
             curr_cost = self.llm.get_cost(prompt=final_prompt, label="")
             total_cost += curr_cost
 
-        total_cost = total_cost * (len(inputs) / input_limit)
+        total_cost = total_cost * (len(dataset_loader.inputs) / input_limit)
         table = {
             "Total Estimated Cost": f"${maybe_round(total_cost)}",
-            "Number of Examples": len(inputs),
-            "Average cost per example": f"${maybe_round(total_cost / len(inputs))}",
+            "Number of Examples": len(dataset_loader.inputs),
+            "Average cost per example": f"${maybe_round(total_cost / len(dataset_loader.inputs))}",
         }
         table = {"parameter": list(table.keys()), "value": list(table.values())}
         print_table(table, show_header=False, console=console, styles=COST_TABLE_STYLES)
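For reference, a small self-contained sketch of the chunked-consumption pattern that run() applies to the new loader fields; the chunk size and helper name here are illustrative, not part of this PR.

from typing import Dict, List

CHUNK_SIZE = 5  # illustrative; Labeler defines its own chunk size

def iter_chunks(inputs: List[Dict], chunk_size: int = CHUNK_SIZE):
    """Yield fixed-size slices, mirroring how run() walks dataset_loader.inputs."""
    for start in range(0, len(inputs), chunk_size):
        yield inputs[start : start + chunk_size]

# Usage (names hypothetical):
# loader = DatasetLoader("example_dataset.csv", config)
# for chunk in iter_chunks(loader.inputs):
#     ...  # build prompts and label each chunk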