optionalize evaluation and console progress printing #537

Merged 3 commits on Aug 15, 2023
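In short: the labeler's Rich console output and its evaluation passes are now opt-out, via a console_output flag on the constructor and a skip_eval argument to run(). Below is a minimal usage sketch of the two new flags; the LabelingAgent class name, the config.json path, and the data.csv file are assumptions for illustration and do not come from this diff.

import pandas as pd

# Assumed public exports; this diff only shows labeler.py internals.
from autolabel import LabelingAgent, AutolabelDataset

# console_output=False leaves self.console as None, so progress bars and
# summary tables are skipped (useful for non-interactive batch jobs).
agent = LabelingAgent(config="config.json", console_output=False)

# Build the dataset from a DataFrame plus the agent's parsed config,
# mirroring how AutolabelDataset is constructed inside labeler.py.
ds = AutolabelDataset(pd.read_csv("data.csv"), agent.config)

# skip_eval=True bypasses both the periodic eval (every 100 rows) and the
# final eval against dataset.gt_labels.
agent.run(ds, skip_eval=True)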
105 changes: 65 additions & 40 deletions src/autolabel/labeler.py
@@ -35,7 +35,6 @@
in_notebook,
)

console = Console()
logger = logging.getLogger(__name__)

COST_TABLE_STYLES = {
@@ -54,11 +53,13 @@ def __init__(
cache: Optional[bool] = True,
example_selector: Optional[BaseExampleSelector] = None,
create_task: Optional[bool] = True,
console_output: Optional[bool] = True,
) -> None:
self.create_task = create_task
self.db = StateManager() if self.create_task else None
self.generation_cache = SQLAlchemyGenerationCache() if cache else None
self.transform_cache = SQLAlchemyTransformCache() if cache else None
self.console = Console() if console_output else None

self.config = (
config if isinstance(config, AutolabelConfig) else AutolabelConfig(config)
@@ -84,10 +85,10 @@ def run(
self,
dataset: AutolabelDataset,
output_name: Optional[str] = None,
eval_every: Optional[int] = 50,
additional_metrics: Optional[List[BaseMetric]] = [],
max_items: Optional[int] = None,
start_index: int = 0,
additional_metrics: Optional[List[BaseMetric]] = [],
skip_eval: Optional[bool] = False,
) -> Tuple[pd.Series, pd.DataFrame, List[MetricResult]]:
"""Labels data in a given dataset. Output written to new CSV file.

@@ -167,12 +168,17 @@ def run(

indices = range(current_index, len(dataset.inputs))

for current_index in track_with_stats(
indices,
postfix_dict,
total=len(dataset.inputs) - current_index,
console=console,
):
if self.console:
tracker = track_with_stats(
indices,
postfix_dict,
total=len(dataset.inputs) - current_index,
console=self.console,
)
else:
tracker = indices

for current_index in tracker:
chunk = dataset.inputs[current_index]

if self.example_selector:
@@ -219,7 +225,7 @@ def run(
postfix_dict[self.COST_KEY] = f"{cost:.2f}"

# Evaluate the task every 100 examples, unless skip_eval is set
if (current_index + 1) % eval_every == 0:
if not skip_eval and (current_index + 1) % 100 == 0:
llm_labels = self.get_all_annotations()
if dataset.gt_labels:
eval_result = self.task.eval(
@@ -247,14 +253,15 @@

llm_labels = self.get_all_annotations()
eval_result = None
table = {}

# if true labels are provided, evaluate accuracy of predictions
if dataset.gt_labels:
if not skip_eval and dataset.gt_labels:
eval_result = self.task.eval(
llm_labels,
dataset.gt_labels[: len(llm_labels)],
additional_metrics=additional_metrics,
)
table = {}
# TODO: serialize and write to file
for m in eval_result:
if isinstance(m.value, list):
@@ -263,11 +270,13 @@ def run(
table[m.name] = m.value
else:
print(f"{m.name}:\n{m.value}")

# print cost
if self.console:
print(f"Actual Cost: {maybe_round(cost)}")
print_table(table, console=console, default_style=METRIC_TABLE_STYLE)
print_table(table, console=self.console, default_style=METRIC_TABLE_STYLE)

dataset.process_labels(llm_labels, eval_result)

# Only save to csv if output_name is provided or dataset is a string
if not output_name and isinstance(dataset, str):
output_name = (
@@ -325,11 +334,17 @@ def plan(
)

input_limit = min(len(dataset.inputs), 100)
for input_i in track(
dataset.inputs[:input_limit],
description="Generating Prompts...",
console=console,
):

if self.console:
tracker = track(
dataset.inputs[:input_limit],
description="Generating Prompts...",
console=self.console,
)
else:
tracker = dataset.inputs[:input_limit]

for input_i in tracker:
# TODO: Check if this needs to use the example selector
if self.example_selector:
examples = self.example_selector.select_examples(input_i)
@@ -349,23 +364,31 @@ def plan(
"Average cost per example": f"${maybe_round(total_cost / len(dataset.inputs))}",
}
table = {"parameter": list(table.keys()), "value": list(table.values())}
print_table(table, show_header=False, console=console, styles=COST_TABLE_STYLES)

console.rule("Prompt Example")
print(f"{prompt_list[0]}")
console.rule()
if self.console:
print_table(
table, show_header=False, console=self.console, styles=COST_TABLE_STYLES
)
self.console.rule("Prompt Example")
print(f"{prompt_list[0]}")
self.console.rule()

async def async_run_transform(
self, transform: BaseTransform, dataset: AutolabelDataset
):
transform_outputs = [
transform.apply(input_dict) for input_dict in dataset.inputs
]
outputs = await gather_async_tasks_with_progress(
transform_outputs,
description=f"Running transform {transform.name()}...",
console=console,
)

if self.console:
outputs = await gather_async_tasks_with_progress(
transform_outputs,
description=f"Running transform {transform.name()}...",
console=self.console,
)
else:
outputs = await asyncio.gather(*transform_outputs)

output_df = pd.DataFrame.from_records(outputs)
final_df = pd.concat([dataset.df, output_df], axis=1)
dataset = AutolabelDataset(final_df, self.config)
@@ -413,16 +436,11 @@ def handle_existing_task_run(
else:
print(f"{m.name}:\n{m.value}")

print_table(table, console=console, default_style=METRIC_TABLE_STYLE)
if self.console:
print_table(
table, console=self.console, default_style=METRIC_TABLE_STYLE
)
pprint(f"{task_run.current_index} examples labeled so far.")
if len(llm_labels) > 0:
console.rule("Last Annotated Example")
pprint("[bold blue]Prompt[/bold blue]: ", end="")
print(llm_labels[-1].prompt)
pprint("[bold blue]Annotation[/bold blue]: ", end="")
print(llm_labels[-1].label)
console.rule()

if not Confirm.ask("Do you want to resume the task?"):
TaskRunModel.delete_by_id(self.db.session, task_run.id)
pprint("Deleted the existing task and starting a new one...")
@@ -477,9 +495,16 @@ def generate_explanations(
"The explanation column needs to be specified in the dataset config."
)

for seed_example in track(
seed_examples, description="Generating explanations", console=console
):
if self.console:
tracker = track(
seed_examples,
description="Generating explanations",
console=self.console,
)
else:
tracker = seed_examples

for seed_example in tracker:
explanation_prompt = self.task.get_explanation_prompt(seed_example)
explanation = self.llm.label([explanation_prompt])
explanation = explanation.generations[0][0].text