From 298b77acbe94b40a721ed0ca59f7940a77f62b97 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 4 Mar 2024 12:00:14 -0800 Subject: [PATCH 1/8] Adds ability to show all logs --- olmo/util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/olmo/util.py b/olmo/util.py index 71ee67e60..933640a41 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -59,6 +59,7 @@ def __repr__(self) -> str: class LogFilterType(StrEnum): rank0_only = "rank0_only" local_rank0_only = "local_rank0_only" + firehose = "firehose" def log_extra_field(field_name: str, field_value: Any) -> None: @@ -126,11 +127,12 @@ def local_rank0_filter(record: logging.LogRecord) -> int: else: return 0 - filter = None if log_filter_type == LogFilterType.rank0_only: filter = rank0_filter elif log_filter_type == LogFilterType.local_rank0_only: filter = local_rank0_filter # type: ignore + elif log_filter_type == LogFilterType.firehose: + filter = None else: raise ValueError(log_filter_type) From f5951ff3e2142f97a803fb42e2620659c44f0c6e Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 4 Mar 2024 12:02:07 -0800 Subject: [PATCH 2/8] Changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3724bbffe..3d001e6a9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states. - Added MMLU downstream evaluation tasks. - Added support for PyTorch v2.2. +- Added ability to show logs from all ranks + ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 From f55e7d5c21b576096348608b3ce198939aa8310f Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 26 Feb 2024 17:11:27 -0800 Subject: [PATCH 3/8] Set start method right away --- scripts/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/train.py b/scripts/train.py index de97e31be..19292d268 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -8,6 +8,7 @@ import torch import torch.distributed as dist +import multiprocessing as mp import wandb from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -240,6 +241,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: if __name__ == "__main__": + mp.set_start_method("spawn") # Initialize process group. dist.init_process_group(backend="nccl") From 9d47c23623ae685bdb8662cdc9e409e002fcacc6 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 4 Mar 2024 12:44:27 -0800 Subject: [PATCH 4/8] Changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3724bbffe..59adc1313 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,12 +13,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the option to directly pass input embeddings to `OLMo` and `OLMoForCausalLM`. - Added support for Python 3.8. - Added code to throw an error if `output_attentions` is set to `True` in forward call to `OLMoForCausalLM`. This functionality hasn't been implemented yet. +- Fixed running with data loading workers on LUMI ### Added - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states. - Added MMLU downstream evaluation tasks. - Added support for PyTorch v2.2. + ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 ### Fixed From dd454bdaced49f9809162a1f7c7d667c45a8091f Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Mon, 4 Mar 2024 12:46:44 -0800 Subject: [PATCH 5/8] Call it "all_ranks" instead --- olmo/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/util.py b/olmo/util.py index 933640a41..c0519e9a1 100644 --- a/olmo/util.py +++ b/olmo/util.py @@ -59,7 +59,7 @@ def __repr__(self) -> str: class LogFilterType(StrEnum): rank0_only = "rank0_only" local_rank0_only = "local_rank0_only" - firehose = "firehose" + all_ranks = "all_ranks" def log_extra_field(field_name: str, field_value: Any) -> None: @@ -131,7 +131,7 @@ def local_rank0_filter(record: logging.LogRecord) -> int: filter = rank0_filter elif log_filter_type == LogFilterType.local_rank0_only: filter = local_rank0_filter # type: ignore - elif log_filter_type == LogFilterType.firehose: + elif log_filter_type == LogFilterType.all_ranks: filter = None else: raise ValueError(log_filter_type) From 493c0b83e131676e36cd932f6960d78d0987b689 Mon Sep 17 00:00:00 2001 From: Oyvind Tafjord Date: Tue, 5 Mar 2024 12:13:27 -0800 Subject: [PATCH 6/8] Add MMLU prompt variants (#484) --- CHANGELOG.md | 2 +- olmo/eval/downstream.py | 161 ++++++++++++++++++++++++++++------------ pyproject.toml | 1 + 3 files changed, 114 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3724bbffe..36fce791e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,7 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `output_hidden_states` argument and associated functionality to `OLMo` and `OLMoForCausalLM` to return model intermediate hidden states. -- Added MMLU downstream evaluation tasks. +- Added MMLU downstream evaluation tasks, with prompt variations. - Added support for PyTorch v2.2. ## [v0.2.4](https://github.com/allenai/OLMo/releases/tag/v0.2.4) - 2024-02-02 diff --git a/olmo/eval/downstream.py b/olmo/eval/downstream.py index 815a24956..fc823cb3e 100644 --- a/olmo/eval/downstream.py +++ b/olmo/eval/downstream.py @@ -1,4 +1,5 @@ import abc +import logging import re from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union @@ -10,6 +11,8 @@ from ..tokenizer import Tokenizer +log = logging.getLogger(__name__) + class ICLMetric(Metric): # update method does not require access to global metric state @@ -152,6 +155,7 @@ def __init__( dataset_name: Union[str, Sequence[str], None] = None, model_ctx_len: int = 2048, split="validation", + prompts=[None], # List of prompt variants to use ): super().__init__() @@ -159,6 +163,9 @@ def __init__( self.dataset_path = dataset_path self.dataset_name = dataset_name self.model_ctx_len = model_ctx_len + self.prompts = prompts + self.current_prompt = None + self.log_instances = 5 # Log the first few instances as a sanity check self.samples: List[Dict[str, Any]] = [] dataset_names: Sequence[Optional[str]] @@ -174,6 +181,7 @@ def __init__( path=self.dataset_path, name=ds_name, split=split, + trust_remote_code=True, ) ) self.dataset = datasets.concatenate_datasets(dataset_list) @@ -191,51 +199,65 @@ def prep_examples(self): """Append doc_ids to each example so that they are processed together in the metric""" doc_id = 0 for doc in self.dataset: - # from EAI harness - # how this all works: - # CTX CONT - # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] - # gpt2 \ \ - # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the - # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice - - continuations = self.doc_to_continuations(doc) - label_id = self.doc_to_label(doc) - ctx = self.token_encode(self.doc_to_text(doc)) - dc = self.token_encode(self.doc_to_domain_conditional(doc)) - - for cont_id, continuation_str in enumerate(continuations): - cont_str_len = len(continuation_str) - 1 # continuation contain leading blank - continuation = self.token_encode(continuation_str) - - # query, remove last token from continuation, truncate from left is longer than model ctx length - query = ctx + continuation[:-1] - query = query[-self.model_ctx_len :] - - # get domain conditional query - # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left - dc_query = dc + continuation[:-1] - - # form a sample - self.samples.append( - { - "doc_id": doc_id, - "cont_id": cont_id, - "ctx": ctx, - "continuation": continuation, - "ctx_len": len(ctx), - "dc_len": len(dc), - "cont_len": len( - continuation - ), # even if query has last token removed, LM will output same cont len - "cont_str_len": cont_str_len, - "query": query, # remove last token from continuation - "dc_query": dc_query, - "label_id": label_id, - } - ) - - doc_id += 1 + for prompt in self.prompts: + self.current_prompt = prompt + # from EAI harness + # how this all works: + # CTX CONT + # inp 0 1 2 3|4 5 6 7 8 9 <- last token is deleted by inp[:, :-1] + # gpt2 \ \ + # logits 1 2 3|4 5 6 7 8 9 <- the ctx half gets tossed out by the + # cont_toks 4 5 6 7 8 9 [:, -len(continuation_enc):, :self.vocab_size] slice + + continuations = self.doc_to_continuations(doc) + label_id = self.doc_to_label(doc) + doc_text = self.doc_to_text(doc) + ctx = self.token_encode(doc_text) + dc = self.token_encode(self.doc_to_domain_conditional(doc)) + if self.log_instances > 0: + self.log_instances -= 1 + ds_name = self.dataset_name + if isinstance(ds_name, list): + ds_name = ds_name[0] + log.info( + f"Sample doc from ({self.dataset_path}, {ds_name}, {self.current_prompt}):" + + f"\ndoc_text: {doc_text}\ncontinuations: {continuations}" + ) + + for cont_id, continuation_str in enumerate(continuations): + cont_str_len = len(continuation_str) - 1 # continuation contain leading blank + continuation = self.token_encode(continuation_str) + + # query, remove last token from continuation, truncate from left is longer than model ctx length + query = ctx + continuation[:-1] + query = query[-self.model_ctx_len :] + # this will be different from len(ctx) when truncated by model_ctx_len + actual_ctx_len = len(query) - len(continuation) + 1 + + # get domain conditional query + # we don't expect this to be longer than self.model_ctx_len and it won't make sense to truncate from left + dc_query = dc + continuation[:-1] + + # form a sample + self.samples.append( + { + "doc_id": doc_id, + "cont_id": cont_id, + "ctx": ctx, + "continuation": continuation, + "ctx_len": actual_ctx_len, + "dc_len": len(dc), + "cont_len": len( + continuation + ), # even if query has last token removed, LM will output same cont len + "cont_str_len": cont_str_len, + "query": query, # remove last token from continuation + "dc_query": dc_query, + "label_id": label_id, + } + ) + + doc_id += 1 def pad_tokens_until_max(self, tokens, max_len=2048): """truncate from left if len(tokens) > model_ctx_len, max_len is not considered then @@ -655,7 +677,7 @@ def __init__(self, tokenizer, dataset_path="sciq", dataset_name=None): ) def doc_to_text(self, doc): - return doc["support"] + "\nQuestion: " + doc["question"] + "\nAnswer:".strip() + return doc["support"].strip() + "\nQuestion: " + doc["question"] + "\nAnswer:" def doc_to_continuations(self, doc): # add spaces in front of continuation @@ -1055,7 +1077,14 @@ class MMLU(ICLMultiChoiceTaskDataset): "other": ["other", "business", "health"], } - def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=None, split="validation"): + def __init__( + self, + tokenizer, + dataset_path="hails/mmlu_no_train", + dataset_name=None, + split="validation", + prompt_variations=None, + ): dataset_names = [] # Collect the relevant categories if dataset_name in MMLU._categories: @@ -1069,10 +1098,40 @@ def __init__(self, tokenizer, dataset_path="hails/mmlu_no_train", dataset_name=N for name, cats in MMLU._subcategories.items(): if dataset_name in cats: dataset_names.append(name) - super().__init__(tokenizer=tokenizer, dataset_path=dataset_path, dataset_name=dataset_names, split=split) + self.dev_set = {} + if prompt_variations == 1: + prompts = [None, "inst", "inst+1", "inst+2", "inst+3", "inst+4", "inst+5"] + # Need to grab the dev set for the few-shot prompts + for name in dataset_names: + self.dev_set[name] = datasets.load_dataset( + path=dataset_path, name=name, split="dev", trust_remote_code=True + ) + super().__init__( + tokenizer=tokenizer, + dataset_path=dataset_path, + dataset_name=dataset_names, + split=split, + prompts=prompts, + ) def doc_to_text(self, doc): - return "Question: " + doc["question"] + "\nAnswer:" + output_text = "Question: " + doc["question"] + "\nAnswer:" + if self.current_prompt is not None: + prefix = "" + if "inst" in self.current_prompt: + subject = doc.get("subject").replace("_", " ") + prefix = f"The following are multiple choice questions (with answers) about {subject}:\n\n" + num_shots = re.findall("\\+(\\d+)", self.current_prompt) + if num_shots: + dev_set = self.dev_set.get(doc.get("subject"), []) + num_shots_int = int(num_shots[0]) + for idx, dev_doc in enumerate(dev_set): + if idx >= num_shots_int: + break + answer = dev_doc["choices"][dev_doc["answer"]] + prefix += "Question: " + dev_doc["question"] + "\nAnswer: " + answer + "\n\n" + output_text = prefix + output_text + return output_text def doc_to_continuations(self, doc): # add spaces in front of continuation @@ -1108,4 +1167,8 @@ def doc_to_domain_conditional(self, doc): "mmlu_humanities": (MMLU, {"dataset_name": "humanities"}), "mmlu_social_sciences": (MMLU, {"dataset_name": "social_sciences"}), "mmlu_other": (MMLU, {"dataset_name": "other"}), + "mmlu_stem_var": (MMLU, {"dataset_name": "stem", "prompt_variations": 1}), + "mmlu_humanities_var": (MMLU, {"dataset_name": "humanities", "prompt_variations": 1}), + "mmlu_social_sciences_var": (MMLU, {"dataset_name": "social_sciences", "prompt_variations": 1}), + "mmlu_other_var": (MMLU, {"dataset_name": "other", "prompt_variations": 1}), } diff --git a/pyproject.toml b/pyproject.toml index 6ff1c2f91..07177dbdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "numpy", "torch>=2.0,<2.3", "omegaconf", + "fsspec==2023.5.0", # temporary fix for HF dataset downloads "rich", "boto3", "google-cloud-storage", From 364e90c2e3a38779180a80a821d546c00e4f282b Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 6 Mar 2024 10:43:45 -0800 Subject: [PATCH 7/8] Fix dependencies --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 07177dbdf..75c6b2827 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,12 @@ dependencies = [ "numpy", "torch>=2.0,<2.3", "omegaconf", - "fsspec==2023.5.0", # temporary fix for HF dataset downloads "rich", "boto3", "google-cloud-storage", "tokenizers", "packaging", - "cached_path", + "cached_path>=1.6.2", "transformers", ] From a49d4619e209f6bd9bd4d9e899a1cb06b9ae2a61 Mon Sep 17 00:00:00 2001 From: Dirk Groeneveld Date: Wed, 6 Mar 2024 11:05:21 -0800 Subject: [PATCH 8/8] Productivity through formatting --- scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index 19292d268..fca309a7f 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -2,13 +2,13 @@ import gzip import logging +import multiprocessing as mp import sys from pathlib import Path from typing import Optional, TextIO import torch import torch.distributed as dist -import multiprocessing as mp import wandb from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP