Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions nemo_skills/evaluation/evaluator/mcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,56 @@
from tqdm import tqdm

from nemo_skills.evaluation.math_grader import extract_answer
from nemo_skills.utils import get_logger_name, unroll_files
from nemo_skills.utils import get_logger_name, nested_dataclass, unroll_files

LOG = logging.getLogger(get_logger_name(__file__))


@nested_dataclass(kw_only=True)
class MCQEvaluatorConfig:
extract_from_boxed: bool = True
# only used if extract_from_boxed is False
extract_regex: str = r"The final answer is (.+)$"


def eval_mcq(cfg):
# Create config from cfg.eval_config (following pattern from other evaluators)
eval_config = MCQEvaluatorConfig(**cfg.eval_config)

def extract_letter(text, extract_from_boxed: bool = True, extract_regex: str = r"The final answer is (.+)$"):
# extract prediction from boxed{}
parsed = extract_answer(text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex)
if parsed is not None and len(parsed) != 1:
match = re.findall(r"\b[A-J]\b(?!.*\b[A-J]\b)", parsed, re.DOTALL)
if len(match) > 0:
parsed = match[-1].strip()
# extract prediction from boxed{} or regex
extracted_answer = extract_answer(text, extract_from_boxed=extract_from_boxed, extract_regex=extract_regex)
parsed_letter = None

if extracted_answer is not None:
if len(extracted_answer) == 1:
parsed_letter = extracted_answer
elif len(extracted_answer) > 1:
# try to extract the letter from extracted answer, useful to match <A>, {A}, *A*, etc.
match = re.findall(r"\b[A-Z]\b(?!.*\b[A-Z]\b)", extracted_answer, re.DOTALL)
if len(match) > 0:
parsed_letter = match[-1].strip()

# adapted from https://artificialanalysis.ai/methodology/intelligence-benchmarking#intelligence-index-evaluation-suite-overview
if parsed is None:
if parsed_letter is None:
match = re.findall(r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-Z])(?![a-zA-Z0-9])", text)
if match:
parsed = match[-1].strip()
return parsed
parsed_letter = match[-1].strip().upper()

LOG.info(
f"Final parsed letter: {parsed_letter}, extract_from_boxed: {extract_from_boxed}, extract_regex: {extract_regex}, extracted_answer: {extracted_answer}"
)

return parsed_letter

for file in unroll_files(cfg.input_files):
with open(file, "rt", encoding="utf-8") as fin:
data = [json.loads(line) for line in fin]
with open(file, "wt", encoding="utf-8") as fout:
for sample in tqdm(data):
extract_from_boxed = sample.get("extract_from_boxed", True)
extract_regex = sample.get("extract_regex", r"The final answer is (.+)$")
# Per-sample values override config defaults for backward compatibility
extract_from_boxed = sample.get("extract_from_boxed", eval_config.extract_from_boxed)
extract_regex = sample.get("extract_regex", eval_config.extract_regex)
sample["predicted_answer"] = extract_letter(
sample["generation"], extract_from_boxed=extract_from_boxed, extract_regex=extract_regex
)
Expand Down