Skip to content

Commit 33ad5db

Browse files
YannDubs authored and LLM-Alignment-sh committed
[ENH] add metadata to completion: date, version,... (tatsu-lab#402)
* [ENH] add date, version, and cleaner metadata col
* [ENH] allow to show multiple versions of packages.
* [ENH] forward kwargs to single annotator
* [ENH] add annotator to column name
* [ENH] cache and test completion
* [ENH] better use of relative directories for paths
* [BUG] pass tests
1 parent d74acf6 commit 33ad5db

File tree

7 files changed

+186
-38
lines changed

7 files changed

+186
-38
lines changed

Diff for: src/alpaca_eval/annotators/base.py

+55-18
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import json
33
import logging
44
import os
5+
from datetime import datetime
56
from functools import partial
67
from pathlib import Path
78
from typing import Any, Callable, Optional, Sequence, Type, Union
89

910
import numpy as np
1011
import pandas as pd
1112

13+
import alpaca_eval
14+
1215
from .. import completion_parsers, constants, processors, types, utils
1316
from ..decoders import get_fn_completions
1417

@@ -50,10 +53,12 @@ class BaseAnnotator(abc.ABC):
5053
Keys use to distinguish the example.
5154
5255
other_output_keys_to_keep : sequence of str, optional
53-
Other output columns to store besides the annotations.
56+
Other output columns to store besides the annotations. You can use `{annotation_key}` to refer to the name
57+
of the annotation column.
5458
5559
other_input_keys_to_keep : sequence of str, optional
56-
Other columns to keep from the input dataframe besides the primary keys.
60+
Other columns to keep from the input dataframe besides the primary keys. You can use `{annotation_key}` to refer
61+
to the name of the annotation column.
5762
5863
is_store_missing_annotations : bool, optional
5964
Whether to store missing annotations. If True it avoids trying to reannotate examples that have errors.
@@ -90,16 +95,19 @@ def __init__(
9095
seed: Optional[int] = 0,
9196
is_avoid_reannotations: bool = True,
9297
other_output_keys_to_keep: Sequence[str] = (
93-
"price_per_example",
94-
"time_per_example",
95-
"raw_completion",
98+
"{annotation_key}_price_per_example",
99+
"{annotation_key}_time_per_example",
100+
"{annotation_key}_version",
101+
"{annotation_key}_date",
102+
"{annotation_key}_raw_completion",
96103
),
97104
other_input_keys_to_keep: Sequence[str] = (),
98105
is_store_missing_annotations: bool = True,
99106
base_dir: Optional[Union[types.AnyPath, Sequence[types.AnyPath]]] = None,
100107
is_raise_if_missing_primary_keys: bool = True,
101108
annotation_type: Optional[Type] = None,
102109
is_reapply_parsing: bool = False,
110+
**single_annotator_kwargs,
103111
):
104112
logging.info(f"Creating the annotator from `{annotators_config}`.")
105113
base_dir = base_dir or self.DEFAULT_BASE_DIR
@@ -123,9 +131,11 @@ def __init__(
123131
if self.annotators_config.exists():
124132
break
125133

126-
self.annotators = self._initialize_annotators()
134+
self.annotators = self._initialize_annotators(**single_annotator_kwargs)
127135
self.df_annotations = None
128136

137+
other_output_keys_to_keep = [c.format(annotation_key=self.annotation_key) for c in other_output_keys_to_keep]
138+
other_input_keys_to_keep = [c.format(annotation_key=self.annotation_key) for c in other_input_keys_to_keep]
129139
self.other_input_keys_to_keep = self._get_other_input_keys_to_keep(other_input_keys_to_keep)
130140
self.other_output_keys_to_keep = self._get_other_output_keys_to_keep(other_output_keys_to_keep)
131141
self.other_keys_to_keep = self.other_output_keys_to_keep + self.other_input_keys_to_keep
@@ -148,6 +158,11 @@ def annotation_key(self) -> str:
148158
"""How to refer to the annotations, this will be the key for annotations in the output."""
149159
return "annotation"
150160

161+
@property
def completion_key(self) -> str:
    """Name of the output column that stores the raw completions.

    Built by appending ``_raw_completion`` to ``self.annotation_key``.
    """
    return "{}_raw_completion".format(self.annotation_key)
165+
151166
@property
152167
def random_seed_keys(self) -> list[str]:
153168
"""What key / column to seed on for the random generator."""
@@ -227,7 +242,7 @@ def _initialize_annotators_config(self, annotators_config):
227242

228243
return annotators_config
229244

230-
def _initialize_annotators(self) -> dict[str, "SingleAnnotator"]:
245+
def _initialize_annotators(self, **kwargs) -> dict[str, "SingleAnnotator"]:
231246
"""Load all the configs and prompts if necessary."""
232247
annotators_config = utils.load_configs(self.annotators_config)
233248
try:
@@ -241,7 +256,9 @@ def _initialize_annotators(self) -> dict[str, "SingleAnnotator"]:
241256
seed=self.seed,
242257
base_dir=base_dir,
243258
annotation_column=self.annotation_key,
259+
completion_column=self.completion_key,
244260
**annotator_config,
261+
**kwargs,
245262
)
246263
for name, annotator_config in annotators_config.items()
247264
}
@@ -311,8 +328,8 @@ def _annotate(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataF
311328
]
312329
# if df_to_annotate "raw_completion" is a dict, put it back to a json string so that you can reparse it
313330
# TODO: this is for backward compatibility, remove in the future
314-
if "raw_completion" in df_to_annotate.columns:
315-
df_to_annotate["raw_completion"] = df_to_annotate["raw_completion"].apply(
331+
if self.completion_key in df_to_annotate.columns:
332+
df_to_annotate[self.completion_key] = df_to_annotate[self.completion_key].apply(
316333
lambda x: json.dumps(x) if isinstance(x, dict) else x
317334
)
318335

@@ -583,11 +600,11 @@ class SingleAnnotator:
583600
annotation_column : str, optional
584601
Name of the annotation column in the output dataframe.
585602
586-
is_store_raw_completions : bool, optional
587-
Whether to store raw completions at `"raw_completion"` column in the output dataframe. Note that raw_completion
588-
will not be modified by the postprocessors. E.g. if we switch the columns output_1 and output_2 in the prompt
589-
then the raw completion will show the switched order, which makes interpretation harder. This should
590-
nevertheless not be an issue when using reapply_parsing because of seeding.
603+
completion_column : str, optional
604+
Name of the raw completion column in the output dataframe. If None will not store the raw completions. Note that
605+
raw_completion will not be modified by the postprocessors. E.g. if we switch the columns output_1 and output_2
606+
in the prompt then the raw completion will show the switched order, which makes interpretation harder. This
607+
should nevertheless not be an issue when using reapply_parsing because of seeding.
591608
592609
processors_to_kwargs : Sequence[dict(str, dict)], optional
593610
A dictionary of BaseProcessor objects to apply for preprocessing the dataframe before making the prompts and
@@ -599,6 +616,9 @@ class SingleAnnotator:
599616
600617
completion_key : str, optional
601618
Key of the output of `fn_completions` to use for parsing the completions into annotations.
619+
620+
packages_for_which_to_show_version : Sequence[str], optional
621+
List of packages for which to show the version in the metadata of the completions.
602622
"""
603623

604624
def __init__(
@@ -613,10 +633,12 @@ def __init__(
613633
batch_size: int = 1,
614634
base_dir: types.AnyPath = constants.EVALUATORS_CONFIG_DIR,
615635
annotation_column: str = "annotation",
616-
is_store_raw_completions: bool = True,
636+
completion_column: Optional[str] = "raw_completion",
617637
processors_to_kwargs: Optional[dict[str, dict]] = None,
618638
is_add_default_processors: bool = True,
619639
completion_key: str = "completions",
640+
packages_for_which_to_show_version: Optional[Sequence[str]] = ("alpaca_eval",),
641+
prfx_to_completion_cols: Optional[str] = "{annotation_column}_",
620642
# The following two keys are only for the documentation
621643
pretty_name: Optional[str] = None,
622644
link: Optional[str] = None,
@@ -637,7 +659,11 @@ def __init__(
637659
self.is_shuffle = is_shuffle
638660
self.batch_size = batch_size
639661
self.annotation_column = annotation_column
640-
self.completion_column = "raw_completion" if is_store_raw_completions else None
662+
self.completion_column = completion_column
663+
self.packages_for_which_to_show_version = packages_for_which_to_show_version
664+
if prfx_to_completion_cols is None:
665+
prfx_to_completion_cols = ""
666+
self.prfx_to_completion_cols = prfx_to_completion_cols.format(annotation_column=annotation_column)
641667

642668
self.is_add_default_processors = is_add_default_processors
643669
self.processors = []
@@ -690,9 +716,14 @@ def __call__(self, df_to_annotate: pd.DataFrame, **decoding_kwargs) -> pd.DataFr
690716
# prompts and completions here will not be the same length as the dataframe due to batching
691717
prompts, df_to_annotate = self._make_prompts(df_to_annotate)
692718
completions = self.fn_completions(prompts=prompts, **self.completions_kwargs, **decoding_kwargs)
719+
self._add_metadata_to_completions_(completions)
720+
completions = {
721+
f"{self.prfx_to_completion_cols}{k}" if k != self.completion_key else k: v
722+
for k, v in completions.items()
723+
}
693724

694725
for k, v in completions.items():
695-
if k != "completions":
726+
if k != self.completion_key:
696727
if self.batch_size != 1 and (len(df_to_annotate) == len(v) * self.batch_size):
697728
v = [el for el in v for _ in range(self.batch_size)]
698729
df_to_annotate[k] = v
@@ -735,7 +766,7 @@ def _search_processor(self, name: Union[str, Type["processors.BaseProcessor"]])
735766
return name
736767

737768
def _get_prompt_template(self, prompt_template: types.AnyPath):
738-
return utils.read_or_return(self.base_dir / prompt_template)
769+
return utils.read_or_return(prompt_template, relative_to=self.base_dir)
739770

740771
def _make_prompts(
741772
self, df_to_annotate: pd.DataFrame, prompt_template: Optional[str] = None
@@ -762,6 +793,12 @@ def _make_prompts(
762793
prompt_template = self.prompt_template
763794
return utils.make_prompts(df=df_to_annotate, template=prompt_template, batch_size=self.batch_size)
764795

796+
def _add_metadata_to_completions_(self, completions: dict[str, Any]):
    """Attach metadata entries to ``completions``, modifying it in place.

    Always sets ``completions["date"]`` to the current local time as an ISO-8601 string. Sets
    ``completions["version"]`` from ``utils.get_multi_package_version`` unless
    ``self.packages_for_which_to_show_version`` is None.
    """
    completions["date"] = datetime.now().isoformat()

    packages = self.packages_for_which_to_show_version
    if packages is None:
        # version reporting is disabled => only the date is added
        return

    completions["version"] = utils.get_multi_package_version(packages)
801+
765802
def _preprocess(self, df_to_annotate: pd.DataFrame) -> pd.DataFrame:
766803
"""Preprocess the examples before annotating. In particular, takes care of all the randomization."""
767804

Diff for: src/alpaca_eval/decoders/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,15 @@ def get_fn_completions(name: Union[str, Callable]) -> Callable:
102102
logging.exception(f"You need {packages} to use bedrock_anthropic. Error:")
103103
raise e
104104

105+
elif name == "cache_completions":
106+
from .cache import cache_completions
107+
108+
return cache_completions
109+
110+
elif name == "test_completions":
111+
from .test import test_completions
112+
113+
return test_completions
114+
105115
else:
106116
raise ValueError(f"Unknown decoder: {name}")

Diff for: src/alpaca_eval/decoders/cache.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import json
2+
from pathlib import Path
3+
from typing import Sequence
4+
5+
from alpaca_eval.decoders import get_fn_completions
6+
from alpaca_eval.types import AnyPath
7+
8+
__all__ = ["cache_completions"]
9+
10+
11+
def cache_completions(prompts: Sequence[str], fn_completions: str, cache_path: AnyPath, **completions_kwargs):
    """Simple wrapper around a completion function to cache the results to JSON on disk.

    Each prompt is decoded separately (one call to the underlying completion function per uncached prompt) and
    cached under a key made of the prompt, the name of the completion function and its kwargs.

    Parameters
    ----------
    prompts : list of str
        Prompts to get completions for.

    fn_completions : str
        Function in `decoders.py` to use for decoding the output.

    cache_path : str
        Path to the cache file. Created (with its parent directories) if it does not exist.

    completions_kwargs : dict
        kwargs for fn_completions. E.g. model_name, max_tokens, temperature, top_p, top_k, stop_seq. They are part of
        the cache key and thus have to be JSON serializable.

    Returns
    -------
    outs : list
        One entry per prompt, in the order of `prompts`. Each entry is what the underlying completion function
        returned when called with that single prompt (or the JSON-loaded equivalent if it was cached).

    Notes
    -----
    NOTE(review): the other completion functions appear to return a single dict (callers index the result with
    `["completions"]`), whereas this returns a list of per-prompt results. Confirm whether the per-prompt results
    should be merged before being returned.
    """
    assert isinstance(fn_completions, str), "fn_completions must be a string to be hashable."
    all_args = [dict(prompt=p, fn_completions=fn_completions, completions_kwargs=completions_kwargs) for p in prompts]

    cache_path = Path(cache_path)

    try:
        with open(cache_path, "r") as f:
            cache = json.load(f)
    except FileNotFoundError:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        cache = {}

    outs = []
    fn_completions = get_fn_completions(fn_completions)
    try:
        for args in all_args:
            hashable_args = json.dumps(args, sort_keys=True)
            if hashable_args not in cache:
                cache[hashable_args] = fn_completions(prompts=[args["prompt"]], **args["completions_kwargs"])
            outs.append(cache[hashable_args])
    finally:
        # Persist even if a call failed midway, so that the completions already obtained are not lost.
        # Serialize before opening the file: a serialization error then leaves the previous cache file intact
        # instead of truncating it.
        serialized_cache = json.dumps(cache)
        cache_path.write_text(serialized_cache)

    return outs

Diff for: src/alpaca_eval/decoders/test.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import logging
2+
from typing import Sequence
3+
4+
from .. import utils
5+
6+
__all__ = ["test_completions"]
7+
8+
9+
def test_completions(
    prompts: Sequence[str],
    model_name="test",
    value: str = "{'name': 'test'}",
    **decoding_kwargs,
) -> dict[str, list]:
    """Completion function for testing purposes. Returns the same value for all prompts.

    Parameters
    ----------
    prompts : sequence of str
        Prompts to "complete". Only the number of prompts is used.

    model_name : str, optional
        Not used besides being logged.

    value : str, optional
        Completion to return for every prompt.

    decoding_kwargs :
        Additional kwargs. Not used besides being logged.

    Returns
    -------
    completions : dict of lists
        Dictionary with keys `completions`, `price_per_example`, `time_per_example` and `completions_all`, each
        holding one element per prompt. All lists are empty if `prompts` is empty.
    """

    n_examples = len(prompts)

    kwargs = dict(model_name=model_name, **decoding_kwargs)
    logging.info(f"Kwargs to completion: {kwargs}")
    with utils.Timer() as t:
        responses = [value for _ in prompts]

    # an empty list of prompts would otherwise raise a ZeroDivisionError when averaging the time
    avg_time = [t.duration / n_examples] * len(responses) if n_examples > 0 else []
    price_per_example = [0] * len(responses)
    return dict(
        completions=responses, price_per_example=price_per_example, time_per_example=avg_time, completions_all=responses
    )

Diff for: src/alpaca_eval/main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def get_completions(configs, df: pd.DataFrame, old_output_path: Optional[Path] =
322322
if len(curr_outputs) > 0:
323323
prompts, _ = utils.make_prompts(
324324
curr_outputs,
325-
template=utils.read_or_return(base_dir / configs["prompt_template"]),
325+
template=utils.read_or_return(configs["prompt_template"], relative_to=base_dir),
326326
)
327327
fn_completions = decoders.get_fn_completions(configs["fn_completions"])
328328
completions = fn_completions(prompts=prompts, **configs["completions_kwargs"])["completions"]

0 commit comments

Comments
 (0)