Skip to content

Commit

Permalink
Add MMLU prompt variants (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
OyvindTafjord authored Mar 5, 2024
1 parent cb711e2 commit 493c0b8
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 50 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states.
- Added MMLU downstream evaluation tasks.
- Added MMLU downstream evaluation tasks, with prompt variations.
- Added support for PyTorch v2.2.

## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02
Expand Down
161 changes: 112 additions & 49 deletions olmo/eval/downstream.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union

Expand All @@ -10,6 +11,8 @@

from ..tokenizer import Tokenizer

log = logging.getLogger(__name__)


class ICLMetric(Metric):
# update method does not require access to global metric state
Expand Down Expand Up @@ -152,13 +155,17 @@ def __init__(
dataset_name: Union[str, Sequence[str], None] = None,
model_ctx_len: int = 2048,
split="validation",
prompts=[None], # List of prompt variants to use
):
super().__init__()

self.tokenizer = tokenizer
self.dataset_path = dataset_path
self.dataset_name = dataset_name
self.model_ctx_len = model_ctx_len
self.prompts = prompts
self.current_prompt = None
self.log_instances = 5 # Log the first few instances as a sanity check

self.samples: List[Dict[str, Any]] = []
dataset_names: Sequence[Optional[str]]
Expand All @@ -174,6 +181,7 @@ def __init__(
path=self.dataset_path,
name=ds_name,
split=split,
trust_remote_code=True,
)
)
self.dataset = datasets.concatenate_datasets(dataset_list)
Expand All @@ -191,51 +199,65 @@ def prep_examples(self):
"""Append doc_ids to each example so that they are processed together in the metric"""
doc_id = 0
for doc in self.dataset:
# from EAI harness
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

continuations = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
ctx = self.token_encode(self.doc_to_text(doc))
dc = self.token_encode(self.doc_to_domain_conditional(doc))

for cont_id, continuation_str in enumerate(continuations):
cont_str_len = len(continuation_str) - 1 # continuation contain leading blank
continuation = self.token_encode(continuation_str)

# query, remove last token from continuation, truncate from left is longer than model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]

# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]

# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": len(ctx),
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)

doc_id += 1
for prompt in self.prompts:
self.current_prompt = prompt
# from EAI harness
# how this all works:
# CTX CONT
# inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1]
# gpt2 \ \
# logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the
# cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice

continuations = self.doc_to_continuations(doc)
label_id = self.doc_to_label(doc)
doc_text = self.doc_to_text(doc)
ctx = self.token_encode(doc_text)
dc = self.token_encode(self.doc_to_domain_conditional(doc))
if self.log_instances > 0:
self.log_instances -= 1
ds_name = self.dataset_name
if isinstance(ds_name, list):
ds_name = ds_name[0]
log.info(
f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):"
+ f"\ndoc_text: {doc_text}\ncontinuations: {continuations}"
)

for cont_id, continuation_str in enumerate(continuations):
cont_str_len = len(continuation_str) - 1 # continuation contain leading blank
continuation = self.token_encode(continuation_str)

# query, remove last token from continuation, truncate from left is longer than model ctx length
query = ctx + continuation[:-1]
query = query[-self.model_ctx_len :]
# this will be different from len(ctx) when truncated by model_ctx_len
actual_ctx_len = len(query) - len(continuation) + 1

# get domain conditional query
# we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left
dc_query = dc + continuation[:-1]

# form a sample
self.samples.append(
{
"doc_id": doc_id,
"cont_id": cont_id,
"ctx": ctx,
"continuation": continuation,
"ctx_len": actual_ctx_len,
"dc_len": len(dc),
"cont_len": len(
continuation
), # even if query has last token removed, LM will output same cont len
"cont_str_len": cont_str_len,
"query": query, # remove last token from continuation
"dc_query": dc_query,
"label_id": label_id,
}
)

doc_id += 1

def pad_tokens_until_max(self, tokens, max_len=2048):
"""truncate from left if len(tokens) > model_ctx_len, max_len is not considered then
Expand Down Expand Up @@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None):
)

def doc_to_text(self, doc):
return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip()
return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:"

def doc_to_continuations(self, doc):
# add spaces in front of continuation
Expand Down Expand Up @@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset):
"other": ["other", "business", "health"],
}

def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"):
def __init__(
self,
tokenizer,
dataset_path="hails/mmlu_no_train",
dataset_name=None,
split="validation",
prompt_variations=None,
):
dataset_names = []
# Collect the relevant categories
if dataset_name in MMLU._categories:
Expand All @@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=N
for name, cats in MMLU._subcategories.items():
if dataset_name in cats:
dataset_names.append(name)
super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split)
self.dev_set = {}
if prompt_variations == 1:
prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"]
# Need to grab the dev set for the few-shot prompts
for name in dataset_names:
self.dev_set[name] = datasets.load_dataset(
path=dataset_path, name=name, split="dev", trust_remote_code=True
)
super().__init__(
tokenizer=tokenizer,
dataset_path=dataset_path,
dataset_name=dataset_names,
split=split,
prompts=prompts,
)

def doc_to_text(self, doc):
return "Question: " + doc["question"] + "\nAnswer:"
output_text = "Question: " + doc["question"] + "\nAnswer:"
if self.current_prompt is not None:
prefix = ""
if "inst" in self.current_prompt:
subject = doc.get("subject").replace("_", " ")
prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n"
num_shots = re.findall("\\+(\\d+)", self.current_prompt)
if num_shots:
dev_set = self.dev_set.get(doc.get("subject"), [])
num_shots_int = int(num_shots[0])
for idx, dev_doc in enumerate(dev_set):
if idx >= num_shots_int:
break
answer = dev_doc["choices"][dev_doc["answer"]]
prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n"
output_text = prefix + output_text
return output_text

def doc_to_continuations(self, doc):
# add spaces in front of continuation
Expand Down Expand Up @@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc):
"mmlu_humanities": (MMLU, {"dataset_name": "humanities"}),
"mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}),
"mmlu_other": (MMLU, {"dataset_name": "other"}),
"mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}),
"mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}),
"mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}),
"mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}),
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"numpy",
"torch>=2.0,<2.3",
"omegaconf",
"fsspec==2023.5.0", # temporary fix for HF dataset downloads
"rich",
"boto3",
"google-cloud-storage",
Expand Down

0 comments on commit 493c0b8

Please sign in to comment.