eval.py: 2 changes (1 addition & 1 deletion)
@@ -277,7 +277,7 @@ def main(
with open(args_fn, "w") as fd:
# Convert Path objects to strings
cache_kwargs_json = {
-            k: str(v) if type(v) == Path else v for k, v in cache_kwargs.items()
+            k: str(v) if isinstance(v, Path) else v for k, v in cache_kwargs.items()
}
json.dump(cache_kwargs_json, fd, indent=2)
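For context, the switch to isinstance matters because Path(...) actually returns a platform-specific subclass, so an exact type comparison never matches; a minimal illustration (not part of the diff):

from pathlib import Path

p = Path("/tmp/cache")        # concretely a PosixPath (or WindowsPath) instance
print(type(p) == Path)        # False: the object's type is a subclass of Path
print(isinstance(p, Path))    # True: isinstance also accepts subclasses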

metric.py: 54 changes (52 additions & 2 deletions)
@@ -1,8 +1,9 @@
import os
-from claudette import models, Chat

import numpy as np
-from evaluate import load
+import regex as re
+from claudette import Chat, models
+from evaluate import load


class Metric:
@@ -73,6 +74,54 @@ def compute(self, prompts, predictions, references):
return self.metric(references, predictions)


class RulerStringMatch(Metric):
"""
Metric used in RULER.
Reference: https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

@staticmethod
def postprocess_pred(predict_str: str):
predict_str = predict_str.strip()

# Remove all non-printable characters
np_pattern = re.compile(r"[\x00-\x1f]")
predict_str = np_pattern.sub("\n", predict_str).strip()

return predict_str

@staticmethod
def string_match_part(refs, preds):
scores = [
max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
for pred, ref in zip(preds, refs)
]
score = sum(scores) / len(preds) * 100
return {"score": round(score, 4)}

@staticmethod
def string_match_all(refs, preds):
scores = [
sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
for pred, ref in zip(preds, refs)
]
score = sum(scores) / len(preds) * 100
return {"score": round(score, 4)}

def _load_metric(self, **kwargs):
if kwargs.get("match_part", False):
self.metric = self.string_match_part
else:
self.metric = self.string_match_all

def compute(self, prompts, predictions, references):
predictions = [self.postprocess_pred(pred) for pred in predictions]
return self.metric(references, predictions)
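A quick sanity check of the two scoring modes on toy inputs (illustrative only; each reference entry is the list of acceptable answers for one prediction):

refs = [["alpha", "beta"]]                       # two acceptable answers for a single prompt
preds = ["The answer is Alpha."]

RulerStringMatch.string_match_part(refs, preds)  # {'score': 100.0} -- any one reference suffices
RulerStringMatch.string_match_all(refs, preds)   # {'score': 50.0}  -- only 1 of the 2 references appears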


REFERENCE_TEMPLATE = """You are shown ground-truth answer(s) and asked to judge the quality of an LLM-generated answer.
Assign it a score from 1-5 where 1 is the worst and 5 is the best based on how similar it is to the ground-truth(s).
Do NOT explain your choice. Simply return a number from 1-5.
@@ -184,6 +233,7 @@ def compute(self, prompts, predictions, labels):
"llm-rouge": LLMRouge,
"llm-as-a-judge": LLMJudge,
"rouge": Rouge,
"ruler-string-match": RulerStringMatch,
}
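With this entry registered, the task classes below can construct the metric by name. A sketch of the expected dispatch, assuming AutoMetric.from_name forwards keyword arguments to the metric constructor (so they reach _load_metric), as the task code below implies:

metric = AutoMetric.from_name("ruler-string-match", match_part=True)
metric.compute(prompts=None, predictions=["... the key is 7041 ..."], references=[["7041"]])
# -> {'score': 100.0}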


task.py: 164 changes (163 additions & 1 deletion)
@@ -1,7 +1,7 @@
-import random
from abc import ABC, abstractmethod
from string import ascii_uppercase

+import random
import numpy as np
from datasets import load_dataset

@@ -506,6 +506,164 @@ def _process_logits(self, logits, split):
return preds


class RulerQA(EvaluationTask):
"""
    RULER HotpotQA task with a 4k context length (the context length can be adjusted as needed).
"""

DEFAULT_PROMPT_TEMPLATE = "{task_input}"

def __init__(
self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=32, **kwargs
):
super().__init__(
prompt_template,
max_tokens,
hf_args=["rbiswasfc/ruler", "qa_2_4k"],
**kwargs,
)

self.metrics = {
"StringMatch": AutoMetric.from_name("ruler-string-match", match_part=True),
}
self.test_split = "validation"

def prepare_row(self, row: dict):
task_input = row["input"]

question = task_input.split("Question:")[-1].split("Answer:")[0].strip()
context = task_input.split("Question:")[0].strip()

prompt = self.prompt_template.format(task_input=task_input)
answer = row["outputs"] # List[str]

return {
"context": context,
"question": question,
"prompt": prompt,
"labels": answer,
}
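The split-based parsing above assumes the packed documents are followed by a single "Question: ... Answer:" suffix; roughly, on a made-up row:

row = {
    "input": "Document 1: ...\n\nQuestion: Who wrote the report?\nAnswer:",
    "outputs": ["The committee"],
}
question = row["input"].split("Question:")[-1].split("Answer:")[0].strip()  # "Who wrote the report?"
context = row["input"].split("Question:")[0].strip()                        # "Document 1: ..."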


class RulerNIAH(EvaluationTask):
"""
    RULER multi-key needle-in-a-haystack (NIAH) task with a 4k context length (the context length can be adjusted as needed).
"""

DEFAULT_PROMPT_TEMPLATE = "{task_input}"

def __init__(
self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=128, **kwargs
):
super().__init__(
prompt_template,
max_tokens,
hf_args=["rbiswasfc/ruler", "niah_multikey_1_4k"],
**kwargs,
)

self.metrics = {
"StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False),
}
self.test_split = "validation"

def prepare_row(self, row: dict):
task_input = row["input"]

question = (
"The special magic number for fair-sprout mentioned in the provided text is"
)
context = task_input

prompt = self.prompt_template.format(task_input=task_input)
answer = row["outputs"] # List[str]

return {
"context": context,
"question": question,
"prompt": prompt,
"labels": answer,
}


class RulerVT(EvaluationTask):
"""
    RULER multi-hop tracing: variable tracking (VT) task with a 4k context length (the context length can be adjusted as needed).
"""

DEFAULT_PROMPT_TEMPLATE = "{task_input}"

def __init__(
self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=30, **kwargs
):
super().__init__(
prompt_template,
max_tokens,
hf_args=["rbiswasfc/ruler", "vt_4k"],
**kwargs,
)

self.metrics = {
"StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False),
}
self.test_split = "validation"

def prepare_row(self, row: dict):
task_input = row["input"]

question = task_input.split("Question:")[-1].split("Answer:")[0].strip()
context = task_input.split("Question:")[0].strip()

prompt = self.prompt_template.format(task_input=task_input)
answer = row["outputs"] # List[str]

return {
"context": context,
"question": question,
"prompt": prompt,
"labels": answer,
}


class RulerCWE(EvaluationTask):
"""
    RULER aggregation: common words extraction (CWE) task with a 4k context length (the context length can be adjusted as needed).
"""

DEFAULT_PROMPT_TEMPLATE = "{task_input}"

def __init__(
self, prompt_template=DEFAULT_PROMPT_TEMPLATE, max_tokens=120, **kwargs
):
super().__init__(
prompt_template,
max_tokens,
hf_args=["rbiswasfc/ruler", "cwe_4k"],
**kwargs,
)

self.metrics = {
"StringMatch": AutoMetric.from_name("ruler-string-match", match_part=False),
}
self.test_split = "validation"

def prepare_row(self, row: dict):
task_input = row["input"]

question = task_input.split("Question:")[-1].split("Answer:")[0].strip()
context = task_input.split("Question:")[0].strip()

prompt = self.prompt_template.format(task_input=task_input)
answer = row["outputs"] # List[str]

return {
"context": context,
"question": question,
"prompt": prompt,
"labels": answer,
}


TASK_MAPPING = {
"squality": Squality,
"triviaqa": TriviaQA,
@@ -514,6 +672,10 @@ def _process_logits(self, logits, split):
"musique": Musique,
"truthfulqa": TruthfulQA,
"scrollsquality": ScrollsQuality,
"rulerqa": RulerQA,
"rulerniah": RulerNIAH,
"rulervt": RulerVT,
"rulercwe": RulerCWE,
}
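With these entries in place, a RULER task can be looked up by name like the existing tasks; a hypothetical usage sketch (any constructor arguments beyond those shown above are assumptions):

task = TASK_MAPPING["rulerqa"]()        # defaults: DEFAULT_PROMPT_TEMPLATE, max_tokens=32
metric = task.metrics["StringMatch"]    # RulerStringMatch configured with match_part=True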

