diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py index 643880a661..fd7398d634 100644 --- a/open_instruct/dataset_transformation.py +++ b/open_instruct/dataset_transformation.py @@ -47,8 +47,9 @@ import os from dataclasses import asdict, dataclass, field from functools import cached_property -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import numpy as np import torch import transformers from datasets import Dataset, concatenate_datasets, load_dataset @@ -205,6 +206,132 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai "{% endif %}" "{% endfor %}" ), + # olmo-core-compatible chat templates: + # TODO: unify these 3 chat templates and send variables through the tokenizer's apply_chat_template kwargs + "olmo": ( + "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}" + "{% if not has_system %}" + "{{ '<|im_start|>system\nYou are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] }}" + "{% if message.get('functions', none) is not none %}" + "{{ ' ' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ ' You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "olmo_thinker": ( + "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}" + "{% if not has_system %}" + "{{ '<|im_start|>system\nYou are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] }}" + "{% if message.get('functions', none) is not none %}" + "{{ ' ' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ ' You do not currently have access to any functions. 
<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "olmo_thinker_r1_style": ( + "A conversation between user and assistant. " + "The user asks a question, and the assistant solves it. " + "The assistant first thinks and reasons about the question " + "and after thinking provides the user with the answer. " + "The reasoning process is enclosed in tags " + "and the answer are enclosed in tags " + "so the full response is reasoning process here " + " answer here ." + "\n\n" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>system\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), "tulu": ( "{% for message in messages %}" "{% if message['role'] == 'system' %}" @@ -273,6 +400,132 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai "{% endif %}" "{% endfor %}" ), + # olmo-core-compatible chat templates: + # TODO: unify these 3 chat templates and send variables through the tokenizer's apply_chat_template kwargs + "olmo": ( + "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}" + "{% if not has_system %}" + "{{ '<|im_start|>system\nYou are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. 
<|im_end|>\n' }}" + "{% endif %}" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] }}" + "{% if message.get('functions', none) is not none %}" + "{{ ' ' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ ' You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "olmo_thinker": ( + "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}" + "{% if not has_system %}" + "{{ '<|im_start|>system\nYou are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{{ '<|im_start|>system\n' + message['content'] }}" + "{% if message.get('functions', none) is not none %}" + "{{ ' ' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ ' You do not currently have access to any functions. <|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), + "olmo_thinker_r1_style": ( + "A conversation between user and assistant. " + "The user asks a question, and the assistant solves it. " + "The assistant first thinks and reasons about the question " + "and after thinking provides the user with the answer. " + "The reasoning process is enclosed in tags " + "and the answer is enclosed in tags " + "so the full response is reasoning process here " + " answer here ." 
+ "\n\n" + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>system\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'user' %}" + "{% if message.get('functions', none) is not none %}" + "{{ '<|im_start|>user\n' + message['content'] + '\n' + '' + message['functions'] + '<|im_end|>\n' }}" + "{% else %}" + "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% elif message['role'] == 'assistant' %}" + "{{ '<|im_start|>assistant\n' }}" + "{% if message.get('content', none) is not none %}" + "{{ message['content'] }}" + "{% endif %}" + "{% if message.get('function_calls', none) is not none %}" + "{{ '' + message['function_calls'] + '' }}" + "{% endif %}" + "{% if not loop.last %}" + "{{ '<|im_end|>' + '\n' }}" + "{% else %}" + "{{ eos_token }}" + "{% endif %}" + "{% elif message['role'] == 'environment' %}" + "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}" + "{% endif %}" + "{% if loop.last and add_generation_prompt %}" + "{{ '<|im_start|>assistant\n' }}" + "{% endif %}" + "{% endfor %}" + ), # template is taken from https://arxiv.org/abs/2501.12948. "r1_simple_chat": ( "A conversation between User and Assistant. " @@ -511,7 +764,10 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"): config = AutoConfig.from_pretrained(tc.tokenizer_name_or_path, revision=tc.tokenizer_revision) # @vwxyzjn: "olmo" handles both `olmo2` and `olmoe`. if "olmo" in config.model_type: - assert tc.add_bos, "For OLMo, you must run with `--add_bos`." + if "olmo" in tc.chat_template_name: + assert not tc.add_bos, "For newer OLMo chat templates, you must *not* run with `--add_bos`." + else: + assert tc.add_bos, "For OLMo, you must run with `--add_bos`." assert tc.use_fast, "For OLMo, you must use fast tokenizer." tokenizer = AutoTokenizer.from_pretrained( @@ -533,9 +789,11 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token - assert tc.add_bos, ( - "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." - ) + if "olmo" not in tc.chat_template_name: + assert tc.add_bos, ( + "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence " + "if using an older chat template." + ) # else, pythia / other models else: num_added_tokens = tokenizer.add_special_tokens({"pad_token": ""}) @@ -597,7 +855,7 @@ class TokenizerConfig: tokenizer_revision: Optional[str] = None trust_remote_code: bool = False use_fast: bool = True - chat_template_name: str = "tulu" # TODO: should I give an option to force override? 
+ chat_template_name: str = "olmo" add_bos: bool = False get_tokenizer_fn: str = "get_tokenizer_tulu_v2_2" @@ -639,7 +897,9 @@ def tokenizer(self): INPUT_IDS_KEY = "input_ids" ATTENTION_MASK_KEY = "attention_mask" LABELS_KEY = "labels" +DATASET_SOURCE_KEY = "dataset_source" TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY] +TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_SOURCE_KEY] # Preference dataset # NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only @@ -1124,6 +1384,9 @@ class DatasetConfig: # for tracking purposes dataset_commit_hash: Optional[str] = None + frac_or_num_samples: Optional[Union[int, float]] = None + original_dataset_size: Optional[int] = None + is_upsampled: bool = False def __post_init__(self): # if the file exists locally, use the local file @@ -1145,9 +1408,40 @@ def __post_init__(self): def update_range(self, dataset_range: int): self.dataset_range = dataset_range - if self.dataset_range > len(self.dataset): - raise ValueError("Dataset range exceeds dataset length") - self.dataset = self.dataset.select(range(self.dataset_range)) + original_size = len(self.dataset) + self.original_dataset_size = original_size + + self.dataset = self.select_samples(self.dataset_range) + self.is_upsampled = dataset_range > original_size + + def select_samples(self, target_size: int): + """Upsample dataset to target_size by repeating samples.""" + original_size = len(self.dataset) + + # Calculate how many full repeats and how many extra samples + full_repeats = target_size // original_size + extra_samples = target_size % original_size + + # Create indices for upsampling + indices = [] + + # Add full repeats + for _ in range(full_repeats): + indices.extend(range(original_size)) + + # Add randomly sampled extra samples + if extra_samples > 0: + # Use numpy for reproducible random sampling + rng = np.random.RandomState(42) # Fixed seed for reproducibility + extra_indices = rng.choice(original_size, size=extra_samples, replace=False) + indices.extend(extra_indices.tolist()) + + print( + f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples " + f"({full_repeats} full repeats + {extra_samples} random samples)" + ) + + return self.dataset.select(indices) def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): @@ -1159,6 +1453,13 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): tokenizer = tc.tokenizer dataset = dc.dataset + + # Add dataset source field to track origin after shuffling + dataset = dataset.map( + lambda example: {**example, DATASET_SOURCE_KEY: dc.dataset_name}, + num_proc=num_proc, + desc=f"Adding dataset source field for {dc.dataset_name}", + ) for fn_name, fn_args in zip(dc.transform_fn, dc.transform_fn_args): fn, fn_type = TRANSFORM_FNS[fn_name] # always pass in tokenizer and other args if needed @@ -1167,6 +1468,10 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig): # perform the transformation target_columns = dataset.column_names if dc.target_columns is None else dc.target_columns + # Always preserve dataset_source if it exists + if DATASET_SOURCE_KEY in dataset.column_names and DATASET_SOURCE_KEY not in target_columns: + target_columns = target_columns + [DATASET_SOURCE_KEY] + if fn_type == "map": dataset = dataset.map( fn, @@ -1300,35 +1605,104 @@ def save_config(self, config_hash: str, dcs: List[DatasetConfig], tc: TokenizerC json.dump(config_dict, f, indent=2) def load_or_transform_dataset( - self, dcs: 
List[DatasetConfig], tc: TokenizerConfig, dataset_skip_cache: bool = False - ) -> Dataset: + self, + dcs: List[DatasetConfig], + tc: TokenizerConfig, + dataset_skip_cache: bool = False, + return_statistics: bool = False, + ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]: """Load dataset from local cache if it exists, otherwise transform and cache it locally.""" cache_path = self.get_cache_path() # Check if the cache exists if os.path.exists(cache_path) and not dataset_skip_cache: print(f"✅ Found cached dataset at {cache_path}") - return Dataset.load_from_disk(cache_path, keep_in_memory=True) + dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True) + if return_statistics: + # Load statistics from cache if available + stats_path = os.path.join(cache_path, "dataset_statistics.json") + if os.path.exists(stats_path): + with open(stats_path, "r") as f: + statistics = json.load(f) + return dataset, statistics + else: + # Return empty statistics if not cached + return dataset, {"per_dataset_stats": [], "dataset_order": []} + return dataset print(f"Cache not found or invalid, transforming datasets...") - # Transform each dataset + # Transform each dataset and collect statistics transformed_datasets = [] + dataset_statistics = [] + dataset_order = [] + for dc in dcs: + # Get initial dataset info + initial_size = len(dc.dataset) if dc.dataset else 0 + dataset = get_dataset_v1(dc, tc) transformed_datasets.append(dataset) + # Collect statistics for this dataset + stats = { + "dataset_name": dc.dataset_name, + "dataset_split": dc.dataset_split, + "initial_instances": initial_size, + "final_instances": len(dataset), + "instances_filtered": initial_size - len(dataset), + "frac_or_num_samples": dc.frac_or_num_samples, + "original_dataset_size": dc.original_dataset_size, + "is_upsampled": dc.is_upsampled, + "upsampling_factor": dc.dataset_range / dc.original_dataset_size + if dc.original_dataset_size and dc.original_dataset_size > 0 + else 1.0, + } + + # Count tokens if the dataset has been tokenized + if INPUT_IDS_KEY in dataset.column_names: + total_tokens = 0 + trainable_tokens = 0 + for sample in dataset: + tokens = len(sample[INPUT_IDS_KEY]) + total_tokens += tokens + if LABELS_KEY in sample: + trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100) + + stats["total_tokens"] = total_tokens + stats["trainable_tokens"] = trainable_tokens + stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0 + + dataset_statistics.append(stats) + dataset_order.append(dc.dataset_name) + # Combine datasets combined_dataset = concatenate_datasets(transformed_datasets) + + # Prepare return statistics + all_statistics = {"per_dataset_stats": dataset_statistics, "dataset_order": dataset_order} + if dataset_skip_cache: + if return_statistics: + return combined_dataset, all_statistics return combined_dataset # Save to local cache combined_dataset.save_to_disk(cache_path) self.save_config(self.config_hash, dcs, tc) + + # Save statistics to cache + stats_path = os.path.join(cache_path, "dataset_statistics.json") + with open(stats_path, "w") as f: + json.dump(all_statistics, f, indent=2) + print(f"🚀 Saved transformed dataset to {cache_path}") print(f"✅ Found cached dataset at {cache_path}") - return Dataset.load_from_disk(cache_path, keep_in_memory=True) + + loaded_dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True) + if return_statistics: + return loaded_dataset, all_statistics + return loaded_dataset def get_cached_dataset( @@ -1337,12 +1711,15 
@@ def get_cached_dataset( hf_entity: Optional[str] = None, dataset_local_cache_dir: Optional[str] = None, dataset_skip_cache: bool = False, -) -> Dataset: + return_statistics: bool = False, +) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]: if dataset_local_cache_dir is not None: cache = LocalDatasetTransformationCache(dataset_local_cache_dir=dataset_local_cache_dir) else: cache = DatasetTransformationCache(hf_entity=hf_entity) - return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache) + return cache.load_or_transform_dataset( + dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics + ) def get_cached_dataset_tulu( @@ -1357,7 +1734,8 @@ def get_cached_dataset_tulu( hf_entity: Optional[str] = None, dataset_local_cache_dir: str = "local_dataset_cache", dataset_skip_cache: bool = False, -) -> Dataset: + return_statistics: bool = False, +) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]: dcs = [] if dataset_config_hash is None: if len(dataset_mixer_list_splits) == 1: @@ -1384,11 +1762,22 @@ def get_cached_dataset_tulu( transform_fn=dataset_transform_fn, transform_fn_args=transform_fn_args, target_columns=target_columns, + frac_or_num_samples=frac_or_num_samples, ) - if frac_or_num_samples > 1.0: - new_range = int(frac_or_num_samples) + + # Calculate target size properly handling fractional upsampling + original_size = len(dataset_config.dataset) + if isinstance(frac_or_num_samples, int) and frac_or_num_samples > original_size: + # Absolute number larger than dataset size - use as-is for upsampling + new_range = frac_or_num_samples + elif isinstance(frac_or_num_samples, float): + # Fractional sampling (can be > 1.0 for upsampling) + new_range = int(frac_or_num_samples * original_size) else: - new_range = int(frac_or_num_samples * len(dataset_config.dataset)) + # Integer <= dataset size, use as absolute count + new_range = int(frac_or_num_samples) + + print(f"Dataset {dataset_name}: {original_size} -> {new_range} samples (factor: {frac_or_num_samples})") dataset_config.update_range(new_range) dcs.append(dataset_config) dataset_config_hash = compute_config_hash(dcs, tc) @@ -1398,7 +1787,9 @@ def get_cached_dataset_tulu( ) elif dataset_cache_mode == "hf": cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity) - return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache) + return cache.load_or_transform_dataset( + dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics + ) def test_sft_dpo_same_tokenizer(): diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py index 9dda8a62d8..ce25078c44 100644 --- a/open_instruct/dpo_tune_cache.py +++ b/open_instruct/dpo_tune_cache.py @@ -988,7 +988,7 @@ def load_model(): accelerator.wait_for_everyone() if args.output_dir is not None: - save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora) + save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora, tc.chat_template_name) # remove all checkpoints to save space if accelerator.is_local_main_process: diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index c3d6e9bf59..581e440352 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -974,7 +974,7 @@ def main(args: FlatArguments, tc: TokenizerConfig): accelerator.wait_for_everyone() if args.output_dir is not None: - save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora) + 
save_with_accelerate(accelerator, model, tokenizer, args.output_dir, args.use_lora, tc.chat_template_name) # remove all checkpoints to save space if args.clean_checkpoints_at_end and accelerator.is_local_main_process: diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7f644ec925..8a413e71fb 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -94,6 +94,7 @@ apply_verifiable_reward, disable_dropout_in_model, entropy_from_logits, + get_olmo3_generation_config, log_softmax_and_gather, print_rich_single_line_metrics, print_rich_table, @@ -983,8 +984,12 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: Dict[st checkpoint_state_dir, args.gs_checkpoint_state_dir ) - def save_model(self, output_dir: str) -> None: + def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None: model_to_save = self.model + if "olmo" in chat_template_name: + # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|> + model_to_save.generation_config = get_olmo3_generation_config(tokenizer) + if self.rank == 0: os.makedirs(output_dir, exist_ok=True) @@ -1774,6 +1779,7 @@ def one_training_step( train_dataset, writer, wandb_url, + chat_template_name, ): """Train the model for one step.""" update_ref_policy_future = [] @@ -1820,7 +1826,12 @@ def one_training_step( checkpoint_dir = f"{args.output_dir}_checkpoints" step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") logger.info(f"Saving model at step {training_step} to {step_dir}") - ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)]) + ray.get( + [ + policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer) + for i in range(args.world_size) + ] + ) if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" for i in range(args.world_size): @@ -1917,11 +1928,23 @@ def maybe_evaluate( logger.warning("[Main Thread] 🙈 Evaluation responses not received") -def save_final_model(args: Args, policy_group: ModelGroup, training_step: int, wandb_url: str): +def save_final_model( + args: Args, + policy_group: ModelGroup, + tokenizer: PreTrainedTokenizer, + training_step: int, + wandb_url: str, + chat_template_name: str, +): """Save the final model and launch evaluation jobs if configured.""" logger.info(f"Saving final model at step {training_step} to {args.output_dir}") with Timer("[Main Thread] 🗡️ Saving model"): - ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)]) + ray.get( + [ + policy_group.models[i].save_model.remote(args.output_dir, chat_template_name, tokenizer) + for i in range(args.world_size) + ] + ) if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): leaderboard_name = args.hf_repo_revision for i in range(args.world_size): @@ -2189,6 +2212,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa train_dataset, writer, wandb_url, + tc.chat_template_name, ) maybe_evaluate( @@ -2204,7 +2228,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa writer, ) - save_final_model(args, policy_group, training_step, wandb_url) + save_final_model(args, policy_group, tokenizer, training_step, wandb_url, tc.chat_template_name) except Exception as e: logger.error(f"Training error occurred: {str(e)}\n{traceback.format_exc()}") diff --git 
a/open_instruct/grpo_vllm_thread_ray_gtrl.py b/open_instruct/grpo_vllm_thread_ray_gtrl.py index be8121baca..81d11110c7 100644 --- a/open_instruct/grpo_vllm_thread_ray_gtrl.py +++ b/open_instruct/grpo_vllm_thread_ray_gtrl.py @@ -102,6 +102,7 @@ disable_dropout_in_model, exact_div, first_true_indices, + get_olmo3_generation_config, get_reward, log_softmax_and_gather, print_rich_single_line_metrics, @@ -791,6 +792,7 @@ def train( train_dataset: Dataset, eval_dataset: Dataset, tokenizer: PreTrainedTokenizer, + tc: TokenizerConfig, vllm_engines: List[ray.actor.ActorHandle], metrics_queue: RayQueue, data_collator: Callable, @@ -1378,7 +1380,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar checkpoint_dir = f"{args.output_dir}_checkpoints" step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") print(f"Saving model at step {training_step} to {step_dir}") - self.save_model(self.model, step_dir) + self.save_model(self.model, tc.chat_template_name, tokenizer, step_dir) if args.try_launch_beaker_eval_jobs_on_weka: leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" if self.rank == 0 and is_beaker_job(): @@ -1404,7 +1406,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar print(f"Eval future {eval_futures[0]} is done") eval_futures.popleft() print(f"Saving final model at step {training_step} to {args.output_dir}") - self.save_model(self.model, args.output_dir) + self.save_model(self.model, tc.chat_template_name, tokenizer, args.output_dir) if args.try_launch_beaker_eval_jobs_on_weka: leaderboard_name = args.hf_repo_revision if self.rank == 0 and is_beaker_job(): @@ -1438,7 +1440,9 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") - def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None: + def save_model( + self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str + ) -> None: if self.rank == 0: os.makedirs(output_dir, exist_ok=True) @@ -1446,6 +1450,10 @@ def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None: if hasattr(model_to_save, "module"): model_to_save = model_to_save.module + if "olmo" in chat_template_name: + # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|> + model_to_save.generation_config = get_olmo3_generation_config(tokenizer) + # gather parameters output_state_dict = {} for k, v in model_to_save.named_parameters(): diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 90d76f8da4..52e9b93ad5 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -403,6 +403,14 @@ def batch_generation( return torch.cat(query_responses, 0), torch.cat(logitss, 0) +def get_olmo3_generation_config(tokenizer): + return transformers.GenerationConfig( + temperature=None, + top_p=None, + eos_token_id=[tokenizer.convert_tokens_to_ids("<|im_end|>"), tokenizer.convert_tokens_to_ids("<|endoftext|>")], + ) + + def save_with_accelerate( accelerator: Accelerator, model: torch.nn.Module, @@ -410,14 +418,20 @@ def save_with_accelerate( output_dir: str, use_lora: bool = False, model_attribute_to_save: Optional[str] = None, + chat_template_name: str = "tulu", ) -> None: """`model_attribute_to_save` is for used to save PPO's policy instead of the full model""" # set the generation config to an empty setting to be safe. 
# we usually do greedy decoding for generation, so this should be okay. # otherwise, we get an error thrown at save time. - model.generation_config = transformers.GenerationConfig( - temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id - ) + if "olmo" in chat_template_name: + # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|> + logger.log(f"Detected olmo chat template: {chat_template_name}, updating model generation config.") + model.generation_config = get_olmo3_generation_config(tokenizer) + else: + model.generation_config = transformers.GenerationConfig( + temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id + ) unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model) if model_attribute_to_save is not None: diff --git a/open_instruct/ppo_fast.py b/open_instruct/ppo_fast.py index 3ecb262e44..b27921b0a4 100644 --- a/open_instruct/ppo_fast.py +++ b/open_instruct/ppo_fast.py @@ -101,6 +101,7 @@ apply_verifiable_reward, disable_dropout_in_model, entropy_from_logits, + get_olmo3_generation_config, log_softmax_and_gather, print_rich_single_line_metrics, print_rich_table, @@ -1074,7 +1075,7 @@ def train( self.offload_to_cpu(self.model) return metrics_list - def save_model(self, output_dir: str) -> None: + def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None: model_to_save = self.model if self.rank == 0: os.makedirs(output_dir, exist_ok=True) @@ -1083,6 +1084,10 @@ def save_model(self, output_dir: str) -> None: if hasattr(model_to_save, "module"): model_to_save = model_to_save.module + if "olmo" in chat_template_name: + # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|> + model_to_save.generation_config = get_olmo3_generation_config(tokenizer) + # gather parameters output_state_dict = {} for k, v in model_to_save.named_parameters(): @@ -1819,7 +1824,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn: checkpoint_dir = f"{args.output_dir}_checkpoints" step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") print(f"Saving model at step {training_step} to {step_dir}") - ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)]) + ray.get( + [ + policy_group.models[i].save_model.remote(step_dir, tc.chat_template_name, tokenizer) + for i in range(args.world_size) + ] + ) if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" for i in range(args.world_size): @@ -1889,7 +1899,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn: print(f"Saving final model at step {training_step} to {args.output_dir}") with Timer("[Main Thread] 🗡️ Saving model"): - ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)]) + ray.get( + [ + policy_group.models[i].save_model.remote(args.output_dir, tc.chat_template_name, tokenizer) + for i in range(args.world_size) + ] + ) if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job(): leaderboard_name = args.hf_repo_revision for i in range(args.world_size): diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 5e2dc91ac3..63e433d3c1 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -100,6 +100,7 @@ 
disable_dropout_in_model, exact_div, first_true_indices, + get_olmo3_generation_config, get_reward, log_softmax_and_gather, print_rich_single_line_metrics, @@ -1513,7 +1514,9 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") - def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None: + def save_model( + self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str + ) -> None: if self.rank == 0: os.makedirs(output_dir, exist_ok=True) @@ -1521,6 +1524,10 @@ def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None: if hasattr(model_to_save, "module"): model_to_save = model_to_save.module + if "olmo" in chat_template_name: + # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|> + model_to_save.generation_config = get_olmo3_generation_config(tokenizer) + # gather parameters output_state_dict = {} for k, v in model_to_save.named_parameters(): diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py index 51e51d3a19..31c80ea86f 100644 --- a/open_instruct/reward_modeling.py +++ b/open_instruct/reward_modeling.py @@ -424,7 +424,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): # save model os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) - save_with_accelerate(accelerator, model, tokenizer, args.output_dir) + save_with_accelerate(accelerator, model, tokenizer, args.output_dir, tc.chat_template_name) if args.push_to_hub: push_folder_to_hub(accelerator, args.output_dir, args.hf_repo_id, args.hf_repo_revision) diff --git a/scripts/data/convert_sft_data_for_olmocore.py b/scripts/data/convert_sft_data_for_olmocore.py index 9bb8ccbed1..5fce1a5d3c 100644 --- a/scripts/data/convert_sft_data_for_olmocore.py +++ b/scripts/data/convert_sft_data_for_olmocore.py @@ -3,45 +3,56 @@ implementation of the OLMo models (espeically for MoE), and so it can be preferable to use it for training on next-token prediction tasks (e.g. SFT). -OLMoCore accepts data in numpy mmap format. One file is for the input tokens, one for the labels, and one for the attention mask. +OLMoCore accepts data in numpy mmap format. One file is for the input tokens and one for the labels mask. 
Usage: python scripts/data/convert_sft_data_for_olmocore.py \ --tokenizer_name_or_path allenai/OLMo-2-1124-7B \ - --add_bos \ --dataset_mixer_list allenai/tulu-3-sft-olmo-2-mixture-0225 1.0 \ - --output_dir ./data/tulu-3-sft-olmo-2-mixture-0225-olmocore + --output_dir ./data/tulu-3-sft-olmo-2-mixture-0225-olmocore \ + --chat_template_name olmo Ai2 Internal Usage: - gantry run --cluster ai2/phobos-cirrascale --timeout -1 -y --budget ai2/oe-training \ + gantry run --cluster ai2/neptune-cirrascale --timeout -1 -y --budget ai2/oe-training --workspace ai2/jacobm \ --install "curl -LsSf https://astral.sh/uv/install.sh | sh && /root/.local/bin/uv sync" \ --weka=oe-training-default:/weka/oe-training-default \ - -- \ - /root/.local/bin/uv run python scripts/data/convert_sft_data_for_olmocore.py \ + --env-secret HF_TOKEN=HF_TOKEN \ + --gpus 1 \ + --priority high \ + -- /root/.local/bin/uv run python scripts/data/convert_sft_data_for_olmocore.py \ + --dataset_mixer_list allenai/tulu-3-sft-olmo-2-mixture 1.0 \ --tokenizer_name_or_path allenai/OLMo-2-1124-7B \ - --add_bos \ - --output_dir /weka/oe-training-default/ai2-llm/tylerr/data/sft/tulu-3-sft-olmo-2-mixture-0225-olmocore + --output_dir /weka/oe-training-default/ai2-llm/tylerr/data/sft/tulu-3-sft-olmo-2-mixture-0225-olmocore \ + --visualize True \ + --chat_template_name olmo \ + --max_seq_length 16384 NOTE: allenai/OLMo-2-1124-7B tokenizer is the same as allenai/dolma2-tokenizer, but allenai/OLMo-2-1124-7B has additional metadata required for this script. Recommendations: - * Don't use max-seq-length, keep full sequences and allow Olmo-core to truncate if needed. + * Set max_seq_length, and use the same length you use during SFT """ +import gzip +import json import os +import sys from collections.abc import Mapping from dataclasses import dataclass, field -from typing import Any, List, Literal, Optional +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional import numpy as np from tqdm import tqdm from open_instruct.dataset_transformation import ( ATTENTION_MASK_KEY, + DATASET_SOURCE_KEY, INPUT_IDS_KEY, LABELS_KEY, TOKENIZED_SFT_DATASET_KEYS, + TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE, TokenizerConfig, get_cached_dataset_tulu, visualize_token, @@ -72,14 +83,11 @@ class ConvertSFTDataArguments: """The list of transform functions to apply to the dataset.""" dataset_transform_fn: list[str] = field( - default_factory=lambda: [ - "sft_tulu_tokenize_and_truncate_v1", - "sft_tulu_filter_v1", - ] + default_factory=lambda: ["sft_tulu_tokenize_and_truncate_v1", "sft_tulu_filter_v1"] ) """The columns to use for the dataset.""" - dataset_target_columns: List[str] = field(default_factory=lambda: TOKENIZED_SFT_DATASET_KEYS) + dataset_target_columns: List[str] = field(default_factory=lambda: TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE) """The mode to use for caching the dataset.""" dataset_cache_mode: Literal["hf", "local"] = "local" @@ -102,6 +110,9 @@ class ConvertSFTDataArguments: """Visualize first token sequence""" visualize: bool = field(default=False) + """Only write the tokenizer config to the output directory""" + tokenizer_config_only: bool = field(default=False) + def main(args: ConvertSFTDataArguments, tc: TokenizerConfig): args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir) @@ -110,16 +121,47 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig): if os.path.exists(beaker_cache_dir): args.dataset_local_cache_dir = beaker_cache_dir + print("Verify these values match the tokenizer config used 
in Olmo-core:") + print(f"Tokenizer vocab_size: {tc.tokenizer.vocab_size}") + print(f"Tokenizer bos_token_id: {tc.tokenizer.bos_token_id}") + print(f"Tokenizer pad_token_id: {tc.tokenizer.pad_token_id}") + print(f"Tokenizer eos_token_id: {tc.tokenizer.eos_token_id}") + print(f"Tokenizer chat_template: {tc.tokenizer.chat_template}") + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + tokenizer_output_dir = os.path.join(output_dir, "tokenizer") + os.makedirs(tokenizer_output_dir, exist_ok=True) + print(f"Saving tokenizer to {tokenizer_output_dir}...") + tc.tokenizer.save_pretrained(tokenizer_output_dir) + + # Check if chat_template.jinja exists and add it to tokenizer_config.json + chat_template_path = os.path.join(tokenizer_output_dir, "chat_template.jinja") + tokenizer_config_path = os.path.join(tokenizer_output_dir, "tokenizer_config.json") + if os.path.exists(chat_template_path) and os.path.exists(tokenizer_config_path): + with open(chat_template_path, "r") as f: + chat_template_content = f.read() + with open(tokenizer_config_path, "r") as f: + tokenizer_config = json.load(f) + if "chat_template" not in tokenizer_config: + tokenizer_config["chat_template"] = chat_template_content + with open(tokenizer_config_path, "w") as f: + json.dump(tokenizer_config, f, indent=2) + print(f"Added chat_template from {chat_template_path} to tokenizer_config.json") + + print("Tokenizer saved successfully!") + + if args.tokenizer_config_only: + return + # TODO: improve configurability of transform factory transform_functions_and_args = [ - ( - "sft_tulu_tokenize_and_truncate_v1", - {"max_seq_length": args.max_seq_length}, - ), + ("sft_tulu_tokenize_and_truncate_v1", {"max_seq_length": args.max_seq_length}), ("sft_tulu_filter_v1", {}), # remove examples that don't have any labels ] - train_dataset = get_cached_dataset_tulu( + result = get_cached_dataset_tulu( dataset_mixer_list=args.dataset_mixer_list, dataset_mixer_list_splits=args.dataset_mixer_list_splits, tc=tc, @@ -130,7 +172,13 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig): dataset_config_hash=args.dataset_config_hash, dataset_local_cache_dir=args.dataset_local_cache_dir, dataset_skip_cache=args.dataset_skip_cache, + return_statistics=True, ) + + # Unpack the result + train_dataset, dataset_statistics = result + + train_dataset = train_dataset.shuffle() if args.visualize: print("Visualizing first example...") @@ -144,66 +192,293 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig): print("Collecting tokens from dataset...") token_ids = [] - labels = [] - attention_mask = [] + labels_mask = [] sample: Mapping[str, Any] - for sample in tqdm(train_dataset, desc="Collecting tokens"): # type: ignore - token_ids.extend(sample[INPUT_IDS_KEY]) - labels.extend(sample[LABELS_KEY]) - attention_mask.extend(sample[ATTENTION_MASK_KEY]) - - print(f"Total sequences: {len(train_dataset)}") - print(f"Total tokens: {len(token_ids)}") + num_samples_skipped = 0 + document_boundaries = [] + current_position = 0 + + # Track per-dataset statistics using dataset_source field + per_dataset_counts = {} + per_dataset_tokens = {} + per_dataset_trainable_tokens = {} + per_dataset_filtered = {} + + for idx, sample in enumerate(tqdm( # type: ignore + train_dataset, + desc="Collecting tokens", + file=sys.stdout, + bar_format="{l_bar}{bar}{r_bar}\n", # better printing in beaker + mininterval=10.0, + )): + sample_length = len(sample[INPUT_IDS_KEY]) + sample_tokens = sample[INPUT_IDS_KEY] + sample_labels = sample[LABELS_KEY] + 
dataset_source = sample.get(DATASET_SOURCE_KEY, "unknown") + + # Initialize counters for new datasets + if dataset_source not in per_dataset_counts: + per_dataset_counts[dataset_source] = 0 + per_dataset_tokens[dataset_source] = 0 + per_dataset_trainable_tokens[dataset_source] = 0 + per_dataset_filtered[dataset_source] = 0 + + # Update per-dataset statistics + per_dataset_counts[dataset_source] += 1 + per_dataset_tokens[dataset_source] += sample_length + trainable_tokens_in_sample = sum(1 for label in sample_labels if label != -100) + per_dataset_trainable_tokens[dataset_source] += trainable_tokens_in_sample + + token_ids.extend(sample_tokens) + labels_mask.extend([1 if label != -100 else 0 for label in sample_labels]) + + # Record document boundary (start, end) + document_boundaries.append((current_position, current_position + sample_length)) + current_position += sample_length + + if all(label == -100 for label in sample_labels): + num_samples_skipped += 1 + per_dataset_filtered[dataset_source] += 1 + + # Assert that attention mask is all 1s + assert all(mask == 1 for mask in sample[ATTENTION_MASK_KEY]), ( + f"Expected all attention mask values to be 1, but found: {sample[ATTENTION_MASK_KEY]}" + ) + + # Calculate final statistics + total_instances = len(train_dataset) + total_tokens = len(token_ids) + total_trainable_tokens = sum(labels_mask) + + print(f"Total sequences: {total_instances}") + print(f"Total tokens: {total_tokens}") print(f"Maximum token ID: {max(token_ids)}") + print(f"Labels mask sum (trainable tokens): {total_trainable_tokens}") print("Writing data to numpy files...") - - # Create output directory with tokenizer name - output_dir = args.output_dir - os.makedirs(output_dir, exist_ok=True) + print(f"Number of samples that should be skipped: {num_samples_skipped}") def write_memmap_chunked(base_filename, data, dtype, max_size_gb=1): """Write data to multiple memmap files if size exceeds max_size_gb.""" # Calculate size in bytes item_size = np.dtype(dtype).itemsize - total_size_bytes = len(data) * item_size max_size_bytes = max_size_gb * 1024**3 - if total_size_bytes <= max_size_bytes: # record in single file - mmap = np.memmap(f"{base_filename}.npy", mode="w+", dtype=dtype, shape=(len(data),)) - mmap[:] = data + chunk_size = max_size_bytes // item_size + chunks = [] + chunk_boundaries = [] + + for i in range(0, len(data), chunk_size): + chunk_data = data[i : i + chunk_size] + filename = f"{base_filename}_part_{i // chunk_size:04d}.npy" + mmap = np.memmap(filename, mode="w+", dtype=dtype, shape=(len(chunk_data),)) + mmap[:] = chunk_data mmap.flush() - print(f"Written {base_filename}.npy ({total_size_bytes / 1024**3:.2f} GB)") - return [mmap] - else: # record in multiple files (if too large) - chunk_size = max_size_bytes // item_size - chunks = [] - for i in range(0, len(data), chunk_size): - chunk_data = data[i : i + chunk_size] - filename = f"{base_filename}_part_{i // chunk_size:04d}.npy" - mmap = np.memmap(filename, mode="w+", dtype=dtype, shape=(len(chunk_data),)) - mmap[:] = chunk_data - mmap.flush() - chunks.append(mmap) - print(f"Written {filename} ({len(chunk_data) * item_size / 1024**3:.2f} GB)") - return chunks + chunks.append(mmap) + chunk_boundaries.append((i, i + len(chunk_data))) + print(f"Written {filename} ({len(chunk_data) * item_size / 1024**3:.2f} GB)") + + return chunks, chunk_boundaries + + def write_metadata_for_chunks(base_filename, document_boundaries, chunk_boundaries): + """Write metadata files for each chunk with document boundaries.""" + + for 
chunk_idx, (chunk_start, chunk_end) in enumerate(chunk_boundaries): + metadata_filename = f"{base_filename}_part_{chunk_idx:04d}.csv.gz" + + with gzip.open(metadata_filename, "wt") as f: + # Find all documents that overlap with this chunk + for doc_start, doc_end in document_boundaries: + # Check if document overlaps with chunk + if doc_end > chunk_start and doc_start < chunk_end: + # Adjust boundaries relative to chunk start + adjusted_start = max(0, doc_start - chunk_start) + adjusted_end = min(chunk_end - chunk_start, doc_end - chunk_start) + + # Only write if there's actual content in this chunk + if adjusted_end > adjusted_start: + f.write(f"{adjusted_start},{adjusted_end}\n") + + print(f"Written metadata {metadata_filename}") + + # Choose dtype based on vocab size - Olmo-core does the + # same operation to infer the dtype of the token_ids array. + vocab_size = tc.tokenizer.vocab_size + token_dtype = None + for dtype in (np.uint8, np.uint16, np.uint32, np.uint64): + if (vocab_size - 1) <= np.iinfo(dtype).max: + token_dtype = dtype + print(f"Using dtype '{dtype}' for token_ids based on vocab size {vocab_size}") + break + if token_dtype is None: + raise ValueError(f"Vocab size {vocab_size} is too big for any numpy integer dtype!") print(f"Writing converted data to {output_dir}") - write_memmap_chunked(f"{output_dir}/token_ids", token_ids, np.uint32) - write_memmap_chunked(f"{output_dir}/labels", labels, np.int32) - write_memmap_chunked(f"{output_dir}/attention_mask", attention_mask, np.int32) + _, token_chunk_boundaries = write_memmap_chunked(f"{output_dir}/token_ids", token_ids, token_dtype) + write_metadata_for_chunks(f"{output_dir}/token_ids", document_boundaries, token_chunk_boundaries) + + # Write labels_mask using the same chunk boundaries as token_ids + for i, (start, end) in enumerate(token_chunk_boundaries): + chunk_data = labels_mask[start:end] + filename = f"{output_dir}/labels_mask_part_{i:04d}.npy" + mmap = np.memmap(filename, mode="w+", dtype=np.bool_, shape=(len(chunk_data),)) + mmap[:] = chunk_data + mmap.flush() + print(f"Written {filename} ({len(chunk_data) * np.dtype(np.bool_).itemsize / 1024**3:.2f} GB)") + print("Data conversion completed successfully!") + + # Write dataset statistics + write_dataset_statistics( + output_dir=output_dir, + dataset_statistics=dataset_statistics, + total_instances=total_instances, + total_tokens=total_tokens, + total_trainable_tokens=total_trainable_tokens, + num_samples_skipped=num_samples_skipped, + tokenizer_name=tc.tokenizer_name_or_path, + max_seq_length=args.max_seq_length, + chat_template_name=tc.chat_template_name, + per_dataset_counts=per_dataset_counts, + per_dataset_tokens=per_dataset_tokens, + per_dataset_trainable_tokens=per_dataset_trainable_tokens, + per_dataset_filtered=per_dataset_filtered, + ) - tokenizer_output_dir = os.path.join(output_dir, "tokenizer") - os.makedirs(tokenizer_output_dir, exist_ok=True) - print(f"Saving tokenizer to {tokenizer_output_dir}...") - tc.tokenizer.save_pretrained(tokenizer_output_dir) - print("Tokenizer saved successfully!") - print("Verify these values match the tokenizer config used in Olmo-core:") - print(f"Tokenizer vocab_size: {tc.tokenizer.vocab_size}") - print(f"Tokenizer bos_token_id: {tc.tokenizer.bos_token_id}") - print(f"Tokenizer pad_token_id: {tc.tokenizer.pad_token_id}") - print(f"Tokenizer eos_token_id: {tc.tokenizer.eos_token_id}") +def write_dataset_statistics( + output_dir: str, + dataset_statistics: Dict[str, Any], + total_instances: int, + total_tokens: int, + 
total_trainable_tokens: int, + num_samples_skipped: int, + tokenizer_name: str, + max_seq_length: Optional[int], + chat_template_name: Optional[str], + per_dataset_counts: Dict[str, int], + per_dataset_tokens: Dict[str, int], + per_dataset_trainable_tokens: Dict[str, int], + per_dataset_filtered: Dict[str, int], +): + """Write dataset statistics to both text and JSON files.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Merge pre-transformation stats with post-shuffle actual counts + merged_stats = [] + pre_transform_stats = {stat["dataset_name"]: stat for stat in dataset_statistics.get("per_dataset_stats", [])} + + for dataset_name in per_dataset_counts: + pre_stat = pre_transform_stats.get(dataset_name, {}) + merged_stat = { + "dataset_name": dataset_name, + "dataset_split": pre_stat.get("dataset_split", "unknown"), + "initial_instances": pre_stat.get("initial_instances", "N/A"), + "instances_after_transformation": pre_stat.get("final_instances", "N/A"), + "instances_filtered_during_transformation": pre_stat.get("instances_filtered", "N/A"), + "frac_or_num_samples": pre_stat.get("frac_or_num_samples"), + # Upsampling information + "original_dataset_size": pre_stat.get("original_dataset_size"), + "is_upsampled": pre_stat.get("is_upsampled", False), + "upsampling_factor": pre_stat.get("upsampling_factor", 1.0), + # Post-shuffle actual statistics + "final_instances_in_output": per_dataset_counts[dataset_name], + "final_tokens_in_output": per_dataset_tokens[dataset_name], + "final_trainable_tokens_in_output": per_dataset_trainable_tokens[dataset_name], + "instances_filtered_after_tokenization": per_dataset_filtered[dataset_name], + "avg_tokens_per_instance": per_dataset_tokens[dataset_name] / per_dataset_counts[dataset_name] if per_dataset_counts[dataset_name] > 0 else 0, + "percentage_of_total_tokens": (per_dataset_tokens[dataset_name] / total_tokens * 100) if total_tokens > 0 else 0, + "percentage_of_total_instances": (per_dataset_counts[dataset_name] / total_instances * 100) if total_instances > 0 else 0, + } + merged_stats.append(merged_stat) + + # Prepare statistics data + stats_data = { + "timestamp": timestamp, + "output_directory": output_dir, + "configuration": { + "tokenizer": tokenizer_name, + "max_sequence_length": max_seq_length, + "chat_template": chat_template_name, + }, + "per_dataset_statistics": merged_stats, + "overall_statistics": { + "total_datasets": len(per_dataset_counts), + "total_instances": total_instances, + "total_tokens": total_tokens, + "trainable_tokens": total_trainable_tokens, + "trainable_percentage": (total_trainable_tokens / total_tokens * 100) if total_tokens > 0 else 0, + "instances_filtered": num_samples_skipped, + "average_sequence_length": total_tokens / total_instances if total_instances > 0 else 0, + } + } + + # Write JSON file + json_path = os.path.join(output_dir, "dataset_statistics.json") + with open(json_path, "w") as f: + json.dump(stats_data, f, indent=2) + print(f"Written dataset statistics to {json_path}") + + # Write human-readable text file + text_path = os.path.join(output_dir, "dataset_statistics.txt") + with open(text_path, "w") as f: + f.write("Dataset Statistics Report\n") + f.write("=" * 80 + "\n") + f.write(f"Generated: {timestamp}\n") + f.write(f"Output Directory: {output_dir}\n\n") + + f.write("Configuration:\n") + f.write("-" * 40 + "\n") + f.write(f"- Tokenizer: {tokenizer_name}\n") + f.write(f"- Max Sequence Length: {max_seq_length}\n") + f.write(f"- Chat Template: {chat_template_name}\n\n") + + 
f.write("Per-Dataset Statistics:\n") + f.write("=" * 80 + "\n") + + for stat in stats_data["per_dataset_statistics"]: + f.write(f"\nDataset: {stat['dataset_name']}\n") + f.write(f"- Split: {stat['dataset_split']}\n") + + # Pre-transformation statistics + f.write("\nPre-transformation:\n") + f.write(f" - Instances loaded: {stat.get('initial_instances', 'N/A')}\n") + f.write(f" - Instances after transformation: {stat.get('instances_after_transformation', 'N/A')}\n") + f.write(f" - Instances filtered during transformation: {stat.get('instances_filtered_during_transformation', 'N/A')}\n") + + if stat.get('frac_or_num_samples') is not None: + if isinstance(stat['frac_or_num_samples'], float): + f.write(f" - Sampling fraction: {stat['frac_or_num_samples']}\n") + else: + f.write(f" - Sample count: {stat['frac_or_num_samples']}\n") + + # Show upsampling information if applicable + if stat.get('is_upsampled', False): + f.write(f" - Original dataset size: {stat.get('original_dataset_size', 'N/A')}\n") + f.write(f" - Upsampling factor: {stat.get('upsampling_factor', 1.0):.2f}x\n") + f.write(f" - Upsampled to: {stat.get('instances_after_transformation', 'N/A')} instances\n") + + # Post-shuffle statistics (actual output) + f.write("\nFinal output statistics (after shuffling):\n") + f.write(f" - Instances in output: {stat['final_instances_in_output']:,}\n") + f.write(f" - Total tokens: {stat['final_tokens_in_output']:,}\n") + f.write(f" - Trainable tokens: {stat['final_trainable_tokens_in_output']:,}\n") + f.write(f" - Instances with no labels: {stat['instances_filtered_after_tokenization']}\n") + f.write(f" - Average tokens per instance: {stat['avg_tokens_per_instance']:.1f}\n") + f.write(f" - Percentage of total tokens: {stat['percentage_of_total_tokens']:.1f}%\n") + f.write(f" - Percentage of total instances: {stat['percentage_of_total_instances']:.1f}%\n") + + f.write("\n" + "=" * 80 + "\n") + f.write("Overall Statistics:\n") + f.write("=" * 80 + "\n") + f.write(f"- Total datasets: {stats_data['overall_statistics']['total_datasets']}\n") + f.write(f"- Total instances: {stats_data['overall_statistics']['total_instances']:,}\n") + f.write(f"- Total tokens: {stats_data['overall_statistics']['total_tokens']:,}\n") + f.write(f"- Trainable tokens: {stats_data['overall_statistics']['trainable_tokens']:,} ") + f.write(f"({stats_data['overall_statistics']['trainable_percentage']:.1f}%)\n") + f.write(f"- Instances filtered out: {stats_data['overall_statistics']['instances_filtered']}\n") + f.write(f"- Average sequence length: {stats_data['overall_statistics']['average_sequence_length']:.1f}\n") + + print(f"Written human-readable statistics to {text_path}") if __name__ == "__main__":