Skip to content

Commit

Permalink
Merge pull request #777 from allenai/in-loop-gsm
Browse files Browse the repository at this point in the history
Add GSM8K in-loop
  • Loading branch information
dirkgr authored Feb 4, 2025
2 parents 8527ffe + 7dde292 commit a209343
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add GSM8K to in-loop evals (BPB over correct continuation)
- Support for specifying custom dataset objects in the `data` section of the config file.


## [v0.6.0](https://github.com/allenai/OLMo/releases/tag/v0.6.0) - 2024-12-17

### Added
Expand Down
17 changes: 14 additions & 3 deletions olmo/eval/downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@
log = logging.getLogger(__name__)

# Map from oe-eval metrics to metrics used here
METRIC_FROM_OE_EVAL = {"acc_raw": "acc", "acc_per_char": "len_norm", "acc_uncond": "pmi_dc"}
METRIC_FROM_OE_EVAL = {
"acc_raw": "acc",
"acc_per_char": "len_norm",
"acc_uncond": "pmi_dc",
"logits_per_byte": "bpb",
}
LOG_2_OF_E = 1.44269504089


Expand Down Expand Up @@ -361,9 +366,11 @@ def collate_fn(self, data):
"cont_byte_len": torch.LongTensor(cont_byte_lens),
"input_ids": torch.stack(queries),
"dc_input_ids": torch.stack(dc_queries),
"label_id": torch.LongTensor(label_ids),
}

if not isinstance(label_ids, str):
batch["label_id"] = torch.LongTensor(label_ids)

return batch

def token_encode(self, string: str) -> List[int]:
Expand Down Expand Up @@ -1538,7 +1545,7 @@ def prep_examples(self):
label_id = request["label"]
cont_id = request["idx"]
if self.metric_type in ["ce_loss", "bpb"]:
if label_id != cont_id:
if label_id != cont_id and not isinstance(label_id, str):
# Skip non-target continuations for ce_loss and bpb
continue
else:
Expand Down Expand Up @@ -1758,6 +1765,10 @@ def doc_to_label(self, doc) -> int:
"csqa_rc_0shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_0shot", "metric_type": "bpb"}),
"csqa_rc_5shot": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "len_norm"}),
"csqa_rc_5shot_bpb": (OEEvalTask, {"dataset_path": "csqa", "dataset_name": "rc_5shot", "metric_type": "bpb"}),
"gsm8k_gold_bpb_5shot": (
OEEvalTask,
{"dataset_path": "gsm8k", "dataset_name": "gold_bpb_5shot", "metric_type": "bpb"},
),
"hellaswag_mc_5shot": (
OEEvalTask,
{"dataset_path": "hellaswag", "dataset_name": "mc_5shot", "metric_type": "acc"},
Expand Down
1 change: 1 addition & 0 deletions olmo_data/oe_eval_tasks/gsm8k/gold_bpb_5shot/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"task_name": "gsm8k", "task_hash": "c9a8b5bfa866f678c3ea4ef06729f149", "task_config": {"task_name": "gsm8k", "task_core": "gsm8k", "limit": null, "split": "test", "num_shots": 8, "fewshot_seed": 1234, "primary_metric": "logits_per_byte", "random_subsample_seed": 1234, "context_kwargs": {"no_cot": false}, "generation_kwargs": {"max_gen_toks": 512, "do_sample": false, "temperature": 0.0, "stop_sequences": ["Question:", "</s>", "<|im_end|>", "\n\n"], "repeats": 1}, "metric_kwargs": {"regexes_to_ignore": [",", "\\$", "(?s).*#### ", "\\.$"]}, "native_id_field": "id", "fewshot_source": "STD:GSM8k", "dataset_path": "gsm8k", "dataset_name": "main", "use_chat_format": null, "version": 0.1, "revision": null, "compute_gold_bpb": true, "metadata": {"alias": "gsm8k::bpb"}}, "current_date": "2025-01-08 21:30:11 UTC", "num_instances": 1319}
Binary file not shown.

0 comments on commit a209343

Please sign in to comment.