diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index f77defefdc..c39779e46a 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -47,8 +47,9 @@
import os
from dataclasses import asdict, dataclass, field
from functools import cached_property
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+import numpy as np
import torch
import transformers
from datasets import Dataset, concatenate_datasets, load_dataset
@@ -205,6 +206,132 @@ def visualize_token_role(tokens: list[int], masks: list[int], tokenizer: PreTrai
"{% endif %}"
"{% endfor %}"
),
+ # olmo-core-compatible chat templates:
+ # TODO: unify these 3 chat templates and send variables through the tokenizer's apply_chat_template kwargs
+ "olmo": (
+ "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
+ "{% if not has_system %}"
+ "{{ '<|im_start|>system\nYou are OLMo 2, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n' }}"
+ "{% endif %}"
+ "{% for message in messages %}"
+ "{% if message['role'] == 'system' %}"
+ "{{ '<|im_start|>system\n' + message['content'] }}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ ' <functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ ' You do not currently have access to any functions. <|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'user' %}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% if message.get('content', none) is not none %}"
+ "{{ message['content'] }}"
+ "{% endif %}"
+ "{% if message.get('function_calls', none) is not none %}"
+ "{{ '<function_calls>' + message['function_calls'] + '</function_calls>' }}"
+ "{% endif %}"
+ "{% if not loop.last %}"
+ "{{ '<|im_end|>' + '\n' }}"
+ "{% else %}"
+ "{{ eos_token }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'environment' %}"
+ "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% if loop.last and add_generation_prompt %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% endif %}"
+ "{% endfor %}"
+ ),
+ "olmo_thinker": (
+ "{% set has_system = messages|selectattr('role', 'equalto', 'system')|list|length > 0 %}"
+ "{% if not has_system %}"
+ "{{ '<|im_start|>system\nYou are OLMo 2, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <|im_end|>\n' }}"
+ "{% endif %}"
+ "{% for message in messages %}"
+ "{% if message['role'] == 'system' %}"
+ "{{ '<|im_start|>system\n' + message['content'] }}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ ' <functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ ' You do not currently have access to any functions. <|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'user' %}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% if message.get('content', none) is not none %}"
+ "{{ message['content'] }}"
+ "{% endif %}"
+ "{% if message.get('function_calls', none) is not none %}"
+ "{{ '<function_calls>' + message['function_calls'] + '</function_calls>' }}"
+ "{% endif %}"
+ "{% if not loop.last %}"
+ "{{ '<|im_end|>' + '\n' }}"
+ "{% else %}"
+ "{{ eos_token }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'environment' %}"
+ "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% if loop.last and add_generation_prompt %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% endif %}"
+ "{% endfor %}"
+ ),
+ "olmo_thinker_r1_style": (
+ "A conversation between user and assistant. "
+ "The user asks a question, and the assistant solves it. "
+ "The assistant first thinks and reasons about the question "
+ "and after thinking provides the user with the answer. "
+ "The reasoning process is enclosed in <think> </think> tags "
+ "and the answer is enclosed in <answer> </answer> tags "
+ "so the full response is <think> reasoning process here </think> "
+ "<answer> answer here </answer>."
+ "\n\n"
+ "{% for message in messages %}"
+ "{% if message['role'] == 'system' %}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ '<|im_start|>system\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'user' %}"
+ "{% if message.get('functions', none) is not none %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '\n' + '<functions>' + message['functions'] + '</functions><|im_end|>\n' }}"
+ "{% else %}"
+ "{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% if message.get('content', none) is not none %}"
+ "{{ message['content'] }}"
+ "{% endif %}"
+ "{% if message.get('function_calls', none) is not none %}"
+ "{{ '<function_calls>' + message['function_calls'] + '</function_calls>' }}"
+ "{% endif %}"
+ "{% if not loop.last %}"
+ "{{ '<|im_end|>' + '\n' }}"
+ "{% else %}"
+ "{{ eos_token }}"
+ "{% endif %}"
+ "{% elif message['role'] == 'environment' %}"
+ "{{ '<|im_start|>environment\n' + message['content'] + '<|im_end|>\n' }}"
+ "{% endif %}"
+ "{% if loop.last and add_generation_prompt %}"
+ "{{ '<|im_start|>assistant\n' }}"
+ "{% endif %}"
+ "{% endfor %}"
+ ),
"tulu": (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
@@ -770,7 +897,9 @@ def tokenizer(self):
INPUT_IDS_KEY = "input_ids"
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
+DATASET_SOURCE_KEY = "dataset_source"
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
+TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_SOURCE_KEY]
# Preference dataset
# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only
@@ -1255,6 +1384,9 @@ class DatasetConfig:
# for tracking purposes
dataset_commit_hash: Optional[str] = None
+ frac_or_num_samples: Optional[Union[int, float]] = None
+ original_dataset_size: Optional[int] = None
+ is_upsampled: bool = False
def __post_init__(self):
# if the file exists locally, use the local file
@@ -1276,9 +1408,40 @@ def __post_init__(self):
def update_range(self, dataset_range: int):
self.dataset_range = dataset_range
- if self.dataset_range > len(self.dataset):
- raise ValueError("Dataset range exceeds dataset length")
- self.dataset = self.dataset.select(range(self.dataset_range))
+ original_size = len(self.dataset)
+ self.original_dataset_size = original_size
+
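+ # Ranges larger than the dataset no longer raise; select_samples repeats examples to reach the target size.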
+ self.dataset = self.select_samples(self.dataset_range)
+ self.is_upsampled = dataset_range > original_size
+
+ def select_samples(self, target_size: int):
+ """Select target_size samples, repeating examples (upsampling) when target_size exceeds the dataset size."""
+ original_size = len(self.dataset)
+
+ # Calculate how many full repeats and how many extra samples
+ full_repeats = target_size // original_size
+ extra_samples = target_size % original_size
+
+ # Create indices for upsampling
+ indices = []
+
+ # Add full repeats
+ for _ in range(full_repeats):
+ indices.extend(range(original_size))
+
+ # Add randomly sampled extra samples
+ if extra_samples > 0:
+ # Use numpy for reproducible random sampling
+ rng = np.random.RandomState(42) # Fixed seed for reproducibility
+ extra_indices = rng.choice(original_size, size=extra_samples, replace=False)
+ indices.extend(extra_indices.tolist())
+
+ print(
+ f"Sampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
+ f"({full_repeats} full repeats + {extra_samples} random samples)"
+ )
+
+ return self.dataset.select(indices)
def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
@@ -1290,6 +1453,13 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
tokenizer = tc.tokenizer
dataset = dc.dataset
+
+ # Add dataset source field to track origin after shuffling
+ dataset = dataset.map(
+ lambda example: {**example, DATASET_SOURCE_KEY: dc.dataset_name},
+ num_proc=num_proc,
+ desc=f"Adding dataset source field for {dc.dataset_name}",
+ )
for fn_name, fn_args in zip(dc.transform_fn, dc.transform_fn_args):
fn, fn_type = TRANSFORM_FNS[fn_name]
# always pass in tokenizer and other args if needed
@@ -1298,6 +1468,10 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
# perform the transformation
target_columns = dataset.column_names if dc.target_columns is None else dc.target_columns
+ # Always preserve dataset_source if it exists
+ if DATASET_SOURCE_KEY in dataset.column_names and DATASET_SOURCE_KEY not in target_columns:
+ target_columns = target_columns + [DATASET_SOURCE_KEY]
+
if fn_type == "map":
dataset = dataset.map(
fn,
@@ -1431,35 +1605,104 @@ def save_config(self, config_hash: str, dcs: List[DatasetConfig], tc: TokenizerC
json.dump(config_dict, f, indent=2)
def load_or_transform_dataset(
- self, dcs: List[DatasetConfig], tc: TokenizerConfig, dataset_skip_cache: bool = False
- ) -> Dataset:
+ self,
+ dcs: List[DatasetConfig],
+ tc: TokenizerConfig,
+ dataset_skip_cache: bool = False,
+ return_statistics: bool = False,
+ ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
"""Load dataset from local cache if it exists, otherwise transform and cache it locally."""
cache_path = self.get_cache_path()
# Check if the cache exists
if os.path.exists(cache_path) and not dataset_skip_cache:
print(f"✅ Found cached dataset at {cache_path}")
- return Dataset.load_from_disk(cache_path, keep_in_memory=True)
+ dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
+ if return_statistics:
+ # Load statistics from cache if available
+ stats_path = os.path.join(cache_path, "dataset_statistics.json")
+ if os.path.exists(stats_path):
+ with open(stats_path, "r") as f:
+ statistics = json.load(f)
+ return dataset, statistics
+ else:
+ # Return empty statistics if not cached
+ return dataset, {"per_dataset_stats": [], "dataset_order": []}
+ return dataset
print(f"Cache not found or invalid, transforming datasets...")
- # Transform each dataset
+ # Transform each dataset and collect statistics
transformed_datasets = []
+ dataset_statistics = []
+ dataset_order = []
+
for dc in dcs:
+ # Get initial dataset info
+ initial_size = len(dc.dataset) if dc.dataset else 0
+
dataset = get_dataset_v1(dc, tc)
transformed_datasets.append(dataset)
+ # Collect statistics for this dataset
+ stats = {
+ "dataset_name": dc.dataset_name,
+ "dataset_split": dc.dataset_split,
+ "initial_instances": initial_size,
+ "final_instances": len(dataset),
+ "instances_filtered": initial_size - len(dataset),
+ "frac_or_num_samples": dc.frac_or_num_samples,
+ "original_dataset_size": dc.original_dataset_size,
+ "is_upsampled": dc.is_upsampled,
+ "upsampling_factor": dc.dataset_range / dc.original_dataset_size
+ if dc.original_dataset_size and dc.original_dataset_size > 0
+ else 1.0,
+ }
+
+ # Count tokens if the dataset has been tokenized
+ if INPUT_IDS_KEY in dataset.column_names:
+ total_tokens = 0
+ trainable_tokens = 0
+ for sample in dataset:
+ tokens = len(sample[INPUT_IDS_KEY])
+ total_tokens += tokens
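+ # Labels equal to -100 are ignored by the loss, so only the remaining positions count as trainable.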
+ if LABELS_KEY in sample:
+ trainable_tokens += sum(1 for label in sample[LABELS_KEY] if label != -100)
+
+ stats["total_tokens"] = total_tokens
+ stats["trainable_tokens"] = trainable_tokens
+ stats["avg_tokens_per_instance"] = total_tokens / len(dataset) if len(dataset) > 0 else 0
+
+ dataset_statistics.append(stats)
+ dataset_order.append(dc.dataset_name)
+
# Combine datasets
combined_dataset = concatenate_datasets(transformed_datasets)
+
+ # Prepare return statistics
+ all_statistics = {"per_dataset_stats": dataset_statistics, "dataset_order": dataset_order}
+
if dataset_skip_cache:
+ if return_statistics:
+ return combined_dataset, all_statistics
return combined_dataset
# Save to local cache
combined_dataset.save_to_disk(cache_path)
self.save_config(self.config_hash, dcs, tc)
+
+ # Save statistics to cache
+ stats_path = os.path.join(cache_path, "dataset_statistics.json")
+ with open(stats_path, "w") as f:
+ json.dump(all_statistics, f, indent=2)
+
print(f"🚀 Saved transformed dataset to {cache_path}")
print(f"✅ Found cached dataset at {cache_path}")
- return Dataset.load_from_disk(cache_path, keep_in_memory=True)
+
+ loaded_dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
+ if return_statistics:
+ return loaded_dataset, all_statistics
+ return loaded_dataset
def get_cached_dataset(
@@ -1468,12 +1711,15 @@ def get_cached_dataset(
hf_entity: Optional[str] = None,
dataset_local_cache_dir: Optional[str] = None,
dataset_skip_cache: bool = False,
-) -> Dataset:
+ return_statistics: bool = False,
+) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
if dataset_local_cache_dir is not None:
cache = LocalDatasetTransformationCache(dataset_local_cache_dir=dataset_local_cache_dir)
else:
cache = DatasetTransformationCache(hf_entity=hf_entity)
- return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)
+ return cache.load_or_transform_dataset(
+ dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
+ )
def get_cached_dataset_tulu(
@@ -1488,7 +1734,8 @@ def get_cached_dataset_tulu(
hf_entity: Optional[str] = None,
dataset_local_cache_dir: str = "local_dataset_cache",
dataset_skip_cache: bool = False,
-) -> Dataset:
+ return_statistics: bool = False,
+) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
dcs = []
if dataset_config_hash is None:
if len(dataset_mixer_list_splits) == 1:
@@ -1515,11 +1762,22 @@ def get_cached_dataset_tulu(
transform_fn=dataset_transform_fn,
transform_fn_args=transform_fn_args,
target_columns=target_columns,
+ frac_or_num_samples=frac_or_num_samples,
)
- if frac_or_num_samples > 1.0:
- new_range = int(frac_or_num_samples)
+
+ # Calculate target size properly handling fractional upsampling
+ original_size = len(dataset_config.dataset)
+ if isinstance(frac_or_num_samples, int) and frac_or_num_samples > original_size:
+ # Absolute number larger than dataset size - use as-is for upsampling
+ new_range = frac_or_num_samples
+ elif isinstance(frac_or_num_samples, float):
+ # Fractional sampling (can be > 1.0 for upsampling)
+ new_range = int(frac_or_num_samples * original_size)
else:
- new_range = int(frac_or_num_samples * len(dataset_config.dataset))
+ # Integer <= dataset size, use as absolute count
+ new_range = int(frac_or_num_samples)
+
+ print(f"Dataset {dataset_name}: {original_size} -> {new_range} samples (factor: {frac_or_num_samples})")
dataset_config.update_range(new_range)
dcs.append(dataset_config)
dataset_config_hash = compute_config_hash(dcs, tc)
@@ -1529,7 +1787,9 @@ def get_cached_dataset_tulu(
)
elif dataset_cache_mode == "hf":
cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity)
- return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)
+ return cache.load_or_transform_dataset(
+ dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
+ )
def test_sft_dpo_same_tokenizer():
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 7f644ec925..51abe922c8 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -72,7 +72,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from rich.pretty import pprint
from torch.utils.tensorboard import SummaryWriter
-from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler
+from transformers import AutoModelForCausalLM, GenerationConfig, PreTrainedModel, PreTrainedTokenizer, get_scheduler
from transformers.integrations import HfDeepSpeedConfig
from vllm import SamplingParams
@@ -983,8 +983,19 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: Dict[st
checkpoint_state_dir, args.gs_checkpoint_state_dir
)
- def save_model(self, output_dir: str) -> None:
+ def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None:
model_to_save = self.model
+ if "olmo" in chat_template_name:
+ # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
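+ # As in model_utils.save_with_accelerate, temperature/top_p are unset so the saved generation config does not error at save time.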
+ model_to_save.generation_config = GenerationConfig(
+ temperature=None,
+ top_p=None,
+ eos_token_id=[
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+ ],
+ )
+
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
@@ -1774,6 +1785,7 @@ def one_training_step(
train_dataset,
writer,
wandb_url,
+ chat_template_name,
):
"""Train the model for one step."""
update_ref_policy_future = []
@@ -1820,7 +1832,12 @@ def one_training_step(
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
logger.info(f"Saving model at step {training_step} to {step_dir}")
- ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
+ ray.get(
+ [
+ policy_group.models[i].save_model.remote(step_dir, chat_template_name, tokenizer)
+ for i in range(args.world_size)
+ ]
+ )
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
for i in range(args.world_size):
@@ -1917,11 +1934,23 @@ def maybe_evaluate(
logger.warning("[Main Thread] 🙈 Evaluation responses not received")
-def save_final_model(args: Args, policy_group: ModelGroup, training_step: int, wandb_url: str):
+def save_final_model(
+ args: Args,
+ policy_group: ModelGroup,
+ tokenizer: PreTrainedTokenizer,
+ training_step: int,
+ wandb_url: str,
+ chat_template_name: str,
+):
"""Save the final model and launch evaluation jobs if configured."""
logger.info(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
- ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
+ ray.get(
+ [
+ policy_group.models[i].save_model.remote(args.output_dir, chat_template_name, tokenizer)
+ for i in range(args.world_size)
+ ]
+ )
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = args.hf_repo_revision
for i in range(args.world_size):
@@ -2189,6 +2218,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
train_dataset,
writer,
wandb_url,
+ tc.chat_template_name,
)
maybe_evaluate(
@@ -2204,7 +2234,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
writer,
)
- save_final_model(args, policy_group, training_step, wandb_url)
+ save_final_model(args, policy_group, tokenizer, training_step, wandb_url, tc.chat_template_name)
except Exception as e:
logger.error(f"Training error occurred: {str(e)}\n{traceback.format_exc()}")
diff --git a/open_instruct/grpo_vllm_thread_ray_gtrl.py b/open_instruct/grpo_vllm_thread_ray_gtrl.py
index be8121baca..22ce5bbfae 100644
--- a/open_instruct/grpo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/grpo_vllm_thread_ray_gtrl.py
@@ -79,6 +79,7 @@
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
+ GenerationConfig,
PreTrainedModel,
PreTrainedTokenizer,
get_scheduler,
@@ -791,6 +792,7 @@ def train(
train_dataset: Dataset,
eval_dataset: Dataset,
tokenizer: PreTrainedTokenizer,
+ tc: TokenizerConfig,
vllm_engines: List[ray.actor.ActorHandle],
metrics_queue: RayQueue,
data_collator: Callable,
@@ -1378,7 +1380,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
- self.save_model(self.model, step_dir)
+ self.save_model(self.model, tc.chat_template_name, tokenizer, step_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
if self.rank == 0 and is_beaker_job():
@@ -1404,7 +1406,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
print(f"Eval future {eval_futures[0]} is done")
eval_futures.popleft()
print(f"Saving final model at step {training_step} to {args.output_dir}")
- self.save_model(self.model, args.output_dir)
+ self.save_model(self.model, tc.chat_template_name, tokenizer, args.output_dir)
if args.try_launch_beaker_eval_jobs_on_weka:
leaderboard_name = args.hf_repo_revision
if self.rank == 0 and is_beaker_job():
@@ -1438,7 +1440,9 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")
- def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
+ def save_model(
+ self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str
+ ) -> None:
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
@@ -1446,6 +1450,17 @@ def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module
+ if "olmo" in chat_template_name:
+ # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
+ model_to_save.generation_config = GenerationConfig(
+ temperature=None,
+ top_p=None,
+ eos_token_id=[
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+ ],
+ )
+
# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py
index f3400d72c8..435bacd2b7 100644
--- a/open_instruct/model_utils.py
+++ b/open_instruct/model_utils.py
@@ -418,6 +418,7 @@ def save_with_accelerate(
# otherwise, we get an error thrown at save time.
if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
+ logger.info(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
model.generation_config = transformers.GenerationConfig(
temperature=None,
top_p=None,
diff --git a/open_instruct/ppo_fast.py b/open_instruct/ppo_fast.py
index 3ecb262e44..35df7e8662 100644
--- a/open_instruct/ppo_fast.py
+++ b/open_instruct/ppo_fast.py
@@ -76,6 +76,7 @@
AutoConfig,
AutoModelForCausalLM,
AutoModelForSequenceClassification,
+ GenerationConfig,
PreTrainedModel,
PreTrainedTokenizer,
get_scheduler,
@@ -1074,7 +1075,7 @@ def train(
self.offload_to_cpu(self.model)
return metrics_list
- def save_model(self, output_dir: str) -> None:
+ def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None:
model_to_save = self.model
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
@@ -1083,6 +1084,17 @@ def save_model(self, output_dir: str) -> None:
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module
+ if "olmo" in chat_template_name:
+ # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
+ model_to_save.generation_config = GenerationConfig(
+ temperature=None,
+ top_p=None,
+ eos_token_id=[
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+ ],
+ )
+
# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
@@ -1819,7 +1831,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
checkpoint_dir = f"{args.output_dir}_checkpoints"
step_dir = os.path.join(checkpoint_dir, f"step_{training_step}")
print(f"Saving model at step {training_step} to {step_dir}")
- ray.get([policy_group.models[i].save_model.remote(step_dir) for i in range(args.world_size)])
+ ray.get(
+ [
+ policy_group.models[i].save_model.remote(step_dir, tc.chat_template_name, tokenizer)
+ for i in range(args.world_size)
+ ]
+ )
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}"
for i in range(args.world_size):
@@ -1889,7 +1906,12 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
print(f"Saving final model at step {training_step} to {args.output_dir}")
with Timer("[Main Thread] 🗡️ Saving model"):
- ray.get([policy_group.models[i].save_model.remote(args.output_dir) for i in range(args.world_size)])
+ ray.get(
+ [
+ policy_group.models[i].save_model.remote(args.output_dir, tc.chat_template_name, tokenizer)
+ for i in range(args.world_size)
+ ]
+ )
if args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job():
leaderboard_name = args.hf_repo_revision
for i in range(args.world_size):
diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py
index 5e2dc91ac3..5aefe743e1 100644
--- a/open_instruct/ppo_vllm_thread_ray_gtrl.py
+++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py
@@ -77,6 +77,7 @@
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
+ GenerationConfig,
PreTrainedModel,
PreTrainedTokenizer,
get_scheduler,
@@ -1513,7 +1514,9 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True)
print("finished training")
- def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
+ def save_model(
+ self, model_to_save: PreTrainedModel, chat_template_name: str, tokenizer: PreTrainedTokenizer, output_dir: str
+ ) -> None:
if self.rank == 0:
os.makedirs(output_dir, exist_ok=True)
@@ -1521,6 +1524,17 @@ def save_model(self, model_to_save: PreTrainedModel, output_dir: str) -> None:
if hasattr(model_to_save, "module"):
model_to_save = model_to_save.module
+ if "olmo" in chat_template_name:
+ # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
+ model_to_save.generation_config = GenerationConfig(
+ temperature=None,
+ top_p=None,
+ eos_token_id=[
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
+ tokenizer.convert_tokens_to_ids("<|endoftext|>"),
+ ],
+ )
+
# gather parameters
output_state_dict = {}
for k, v in model_to_save.named_parameters():
diff --git a/scripts/data/convert_sft_data_for_olmocore.py b/scripts/data/convert_sft_data_for_olmocore.py
index 8c8fdf8478..5fce1a5d3c 100644
--- a/scripts/data/convert_sft_data_for_olmocore.py
+++ b/scripts/data/convert_sft_data_for_olmocore.py
@@ -3,7 +3,7 @@
implementation of the OLMo models (especially for MoE), and so it can be preferable
to use it for training on next-token prediction tasks (e.g. SFT).
-OLMoCore accepts data in numpy mmap format. One file is for the input tokens, one for the labels, and one for the attention mask.
+OLMoCore accepts data in numpy mmap format: one set of files holds the input tokens and one the labels mask.
Usage:
python scripts/data/convert_sft_data_for_olmocore.py \
@@ -13,35 +13,46 @@
--chat_template_name olmo
Ai2 Internal Usage:
- gantry run --cluster ai2/phobos-cirrascale --timeout -1 -y --budget ai2/oe-training \
+ gantry run --cluster ai2/neptune-cirrascale --timeout -1 -y --budget ai2/oe-training --workspace ai2/jacobm \
--install "curl -LsSf https://astral.sh/uv/install.sh | sh && /root/.local/bin/uv sync" \
--weka=oe-training-default:/weka/oe-training-default \
- -- \
- /root/.local/bin/uv run python scripts/data/convert_sft_data_for_olmocore.py \
+ --env-secret HF_TOKEN=HF_TOKEN \
+ --gpus 1 \
+ --priority high \
+ -- /root/.local/bin/uv run python scripts/data/convert_sft_data_for_olmocore.py \
+ --dataset_mixer_list allenai/tulu-3-sft-olmo-2-mixture 1.0 \
--tokenizer_name_or_path allenai/OLMo-2-1124-7B \
--output_dir /weka/oe-training-default/ai2-llm/tylerr/data/sft/tulu-3-sft-olmo-2-mixture-0225-olmocore \
- --chat_template_name olmo
+ --visualize True \
+ --chat_template_name olmo \
+ --max_seq_length 16384
NOTE: allenai/OLMo-2-1124-7B tokenizer is the same as allenai/dolma2-tokenizer, but allenai/OLMo-2-1124-7B
has additional metadata required for this script.
Recommendations:
- * Don't use max-seq-length, keep full sequences and allow Olmo-core to truncate if needed.
+ * Set max_seq_length to the same value you will use during SFT.
"""
+import gzip
+import json
import os
+import sys
from collections.abc import Mapping
from dataclasses import dataclass, field
-from typing import Any, List, Literal, Optional
+from datetime import datetime
+from typing import Any, Dict, List, Literal, Optional
import numpy as np
from tqdm import tqdm
from open_instruct.dataset_transformation import (
ATTENTION_MASK_KEY,
+ DATASET_SOURCE_KEY,
INPUT_IDS_KEY,
LABELS_KEY,
TOKENIZED_SFT_DATASET_KEYS,
+ TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
@@ -72,14 +83,11 @@ class ConvertSFTDataArguments:
"""The list of transform functions to apply to the dataset."""
dataset_transform_fn: list[str] = field(
- default_factory=lambda: [
- "sft_tulu_tokenize_and_truncate_v1",
- "sft_tulu_filter_v1",
- ]
+ default_factory=lambda: ["sft_tulu_tokenize_and_truncate_v1", "sft_tulu_filter_v1"]
)
"""The columns to use for the dataset."""
- dataset_target_columns: List[str] = field(default_factory=lambda: TOKENIZED_SFT_DATASET_KEYS)
+ dataset_target_columns: List[str] = field(default_factory=lambda: TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE)
"""The mode to use for caching the dataset."""
dataset_cache_mode: Literal["hf", "local"] = "local"
@@ -102,6 +110,9 @@ class ConvertSFTDataArguments:
"""Visualize first token sequence"""
visualize: bool = field(default=False)
+ """Only write the tokenizer config to the output directory"""
+ tokenizer_config_only: bool = field(default=False)
+
def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
@@ -110,16 +121,47 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
if os.path.exists(beaker_cache_dir):
args.dataset_local_cache_dir = beaker_cache_dir
+ print("Verify these values match the tokenizer config used in Olmo-core:")
+ print(f"Tokenizer vocab_size: {tc.tokenizer.vocab_size}")
+ print(f"Tokenizer bos_token_id: {tc.tokenizer.bos_token_id}")
+ print(f"Tokenizer pad_token_id: {tc.tokenizer.pad_token_id}")
+ print(f"Tokenizer eos_token_id: {tc.tokenizer.eos_token_id}")
+ print(f"Tokenizer chat_template: {tc.tokenizer.chat_template}")
+
+ output_dir = args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+
+ tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
+ os.makedirs(tokenizer_output_dir, exist_ok=True)
+ print(f"Saving tokenizer to {tokenizer_output_dir}...")
+ tc.tokenizer.save_pretrained(tokenizer_output_dir)
+
+ # Check if chat_template.jinja exists and add it to tokenizer_config.json
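+ # (newer transformers releases may write the chat template to a standalone chat_template.jinja file; folding it back into tokenizer_config.json keeps consumers that only read the config working)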
+ chat_template_path = os.path.join(tokenizer_output_dir, "chat_template.jinja")
+ tokenizer_config_path = os.path.join(tokenizer_output_dir, "tokenizer_config.json")
+ if os.path.exists(chat_template_path) and os.path.exists(tokenizer_config_path):
+ with open(chat_template_path, "r") as f:
+ chat_template_content = f.read()
+ with open(tokenizer_config_path, "r") as f:
+ tokenizer_config = json.load(f)
+ if "chat_template" not in tokenizer_config:
+ tokenizer_config["chat_template"] = chat_template_content
+ with open(tokenizer_config_path, "w") as f:
+ json.dump(tokenizer_config, f, indent=2)
+ print(f"Added chat_template from {chat_template_path} to tokenizer_config.json")
+
+ print("Tokenizer saved successfully!")
+
+ if args.tokenizer_config_only:
+ return
+
# TODO: improve configurability of transform factory
transform_functions_and_args = [
- (
- "sft_tulu_tokenize_and_truncate_v1",
- {"max_seq_length": args.max_seq_length},
- ),
+ ("sft_tulu_tokenize_and_truncate_v1", {"max_seq_length": args.max_seq_length}),
("sft_tulu_filter_v1", {}), # remove examples that don't have any labels
]
- train_dataset = get_cached_dataset_tulu(
+ result = get_cached_dataset_tulu(
dataset_mixer_list=args.dataset_mixer_list,
dataset_mixer_list_splits=args.dataset_mixer_list_splits,
tc=tc,
@@ -130,7 +172,13 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
dataset_config_hash=args.dataset_config_hash,
dataset_local_cache_dir=args.dataset_local_cache_dir,
dataset_skip_cache=args.dataset_skip_cache,
+ return_statistics=True,
)
+
+ # Unpack the result
+ train_dataset, dataset_statistics = result
+
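+ # Shuffle so documents from different source datasets are interleaved in the flat token stream; the dataset_source column still lets us attribute per-dataset statistics below.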
+ train_dataset = train_dataset.shuffle()
if args.visualize:
print("Visualizing first example...")
@@ -144,66 +192,293 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
print("Collecting tokens from dataset...")
token_ids = []
- labels = []
- attention_mask = []
+ labels_mask = []
sample: Mapping[str, Any]
- for sample in tqdm(train_dataset, desc="Collecting tokens"): # type: ignore
- token_ids.extend(sample[INPUT_IDS_KEY])
- labels.extend(sample[LABELS_KEY])
- attention_mask.extend(sample[ATTENTION_MASK_KEY])
-
- print(f"Total sequences: {len(train_dataset)}")
- print(f"Total tokens: {len(token_ids)}")
+ num_samples_skipped = 0
+ document_boundaries = []
+ current_position = 0
+
+ # Track per-dataset statistics using dataset_source field
+ per_dataset_counts = {}
+ per_dataset_tokens = {}
+ per_dataset_trainable_tokens = {}
+ per_dataset_filtered = {}
+
+ for idx, sample in enumerate(tqdm( # type: ignore
+ train_dataset,
+ desc="Collecting tokens",
+ file=sys.stdout,
+ bar_format="{l_bar}{bar}{r_bar}\n", # better printing in beaker
+ mininterval=10.0,
+ )):
+ sample_length = len(sample[INPUT_IDS_KEY])
+ sample_tokens = sample[INPUT_IDS_KEY]
+ sample_labels = sample[LABELS_KEY]
+ dataset_source = sample.get(DATASET_SOURCE_KEY, "unknown")
+
+ # Initialize counters for new datasets
+ if dataset_source not in per_dataset_counts:
+ per_dataset_counts[dataset_source] = 0
+ per_dataset_tokens[dataset_source] = 0
+ per_dataset_trainable_tokens[dataset_source] = 0
+ per_dataset_filtered[dataset_source] = 0
+
+ # Update per-dataset statistics
+ per_dataset_counts[dataset_source] += 1
+ per_dataset_tokens[dataset_source] += sample_length
+ trainable_tokens_in_sample = sum(1 for label in sample_labels if label != -100)
+ per_dataset_trainable_tokens[dataset_source] += trainable_tokens_in_sample
+
+ token_ids.extend(sample_tokens)
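+ # Build a boolean labels mask for the converted output: 1 where the token is trained on, 0 where the label is -100 (ignored).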
+ labels_mask.extend([1 if label != -100 else 0 for label in sample_labels])
+
+ # Record document boundary (start, end)
+ document_boundaries.append((current_position, current_position + sample_length))
+ current_position += sample_length
+
+ if all(label == -100 for label in sample_labels):
+ num_samples_skipped += 1
+ per_dataset_filtered[dataset_source] += 1
+
+ # Assert that attention mask is all 1s
+ assert all(mask == 1 for mask in sample[ATTENTION_MASK_KEY]), (
+ f"Expected all attention mask values to be 1, but found: {sample[ATTENTION_MASK_KEY]}"
+ )
+
+ # Calculate final statistics
+ total_instances = len(train_dataset)
+ total_tokens = len(token_ids)
+ total_trainable_tokens = sum(labels_mask)
+
+ print(f"Total sequences: {total_instances}")
+ print(f"Total tokens: {total_tokens}")
print(f"Maximum token ID: {max(token_ids)}")
+ print(f"Labels mask sum (trainable tokens): {total_trainable_tokens}")
print("Writing data to numpy files...")
-
- # Create output directory with tokenizer name
- output_dir = args.output_dir
- os.makedirs(output_dir, exist_ok=True)
+ print(f"Number of samples with no trainable labels (should have been filtered): {num_samples_skipped}")
def write_memmap_chunked(base_filename, data, dtype, max_size_gb=1):
"""Write data to multiple memmap files if size exceeds max_size_gb."""
# Calculate size in bytes
item_size = np.dtype(dtype).itemsize
- total_size_bytes = len(data) * item_size
max_size_bytes = max_size_gb * 1024**3
- if total_size_bytes <= max_size_bytes: # record in single file
- mmap = np.memmap(f"{base_filename}.npy", mode="w+", dtype=dtype, shape=(len(data),))
- mmap[:] = data
+ chunk_size = max_size_bytes // item_size
+ chunks = []
+ chunk_boundaries = []
+
+ for i in range(0, len(data), chunk_size):
+ chunk_data = data[i : i + chunk_size]
+ filename = f"{base_filename}_part_{i // chunk_size:04d}.npy"
+ mmap = np.memmap(filename, mode="w+", dtype=dtype, shape=(len(chunk_data),))
+ mmap[:] = chunk_data
mmap.flush()
- print(f"Written {base_filename}.npy ({total_size_bytes / 1024**3:.2f} GB)")
- return [mmap]
- else: # record in multiple files (if too large)
- chunk_size = max_size_bytes // item_size
- chunks = []
- for i in range(0, len(data), chunk_size):
- chunk_data = data[i : i + chunk_size]
- filename = f"{base_filename}_part_{i // chunk_size:04d}.npy"
- mmap = np.memmap(filename, mode="w+", dtype=dtype, shape=(len(chunk_data),))
- mmap[:] = chunk_data
- mmap.flush()
- chunks.append(mmap)
- print(f"Written {filename} ({len(chunk_data) * item_size / 1024**3:.2f} GB)")
- return chunks
+ chunks.append(mmap)
+ chunk_boundaries.append((i, i + len(chunk_data)))
+ print(f"Written {filename} ({len(chunk_data) * item_size / 1024**3:.2f} GB)")
+
+ return chunks, chunk_boundaries
+
+ def write_metadata_for_chunks(base_filename, document_boundaries, chunk_boundaries):
+ """Write metadata files for each chunk with document boundaries."""
+
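+ # Each row is "start,end": a document's token offsets relative to the start of its chunk.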
+ for chunk_idx, (chunk_start, chunk_end) in enumerate(chunk_boundaries):
+ metadata_filename = f"{base_filename}_part_{chunk_idx:04d}.csv.gz"
+
+ with gzip.open(metadata_filename, "wt") as f:
+ # Find all documents that overlap with this chunk
+ for doc_start, doc_end in document_boundaries:
+ # Check if document overlaps with chunk
+ if doc_end > chunk_start and doc_start < chunk_end:
+ # Adjust boundaries relative to chunk start
+ adjusted_start = max(0, doc_start - chunk_start)
+ adjusted_end = min(chunk_end - chunk_start, doc_end - chunk_start)
+
+ # Only write if there's actual content in this chunk
+ if adjusted_end > adjusted_start:
+ f.write(f"{adjusted_start},{adjusted_end}\n")
+
+ print(f"Written metadata {metadata_filename}")
+
+ # Choose dtype based on vocab size - Olmo-core does the
+ # same operation to infer the dtype of the token_ids array.
+ vocab_size = tc.tokenizer.vocab_size
+ token_dtype = None
+ for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
+ if (vocab_size - 1) <= np.iinfo(dtype).max:
+ token_dtype = dtype
+ print(f"Using dtype '{dtype}' for token_ids based on vocab size {vocab_size}")
+ break
+ if token_dtype is None:
+ raise ValueError(f"Vocab size {vocab_size} is too big for any numpy integer dtype!")
print(f"Writing converted data to {output_dir}")
- write_memmap_chunked(f"{output_dir}/token_ids", token_ids, np.uint32)
- write_memmap_chunked(f"{output_dir}/labels", labels, np.int32)
- write_memmap_chunked(f"{output_dir}/attention_mask", attention_mask, np.int32)
+ _, token_chunk_boundaries = write_memmap_chunked(f"{output_dir}/token_ids", token_ids, token_dtype)
+ write_metadata_for_chunks(f"{output_dir}/token_ids", document_boundaries, token_chunk_boundaries)
+
+ # Write labels_mask using the same chunk boundaries as token_ids
+ for i, (start, end) in enumerate(token_chunk_boundaries):
+ chunk_data = labels_mask[start:end]
+ filename = f"{output_dir}/labels_mask_part_{i:04d}.npy"
+ mmap = np.memmap(filename, mode="w+", dtype=np.bool_, shape=(len(chunk_data),))
+ mmap[:] = chunk_data
+ mmap.flush()
+ print(f"Written {filename} ({len(chunk_data) * np.dtype(np.bool_).itemsize / 1024**3:.2f} GB)")
+
print("Data conversion completed successfully!")
+
+ # Write dataset statistics
+ write_dataset_statistics(
+ output_dir=output_dir,
+ dataset_statistics=dataset_statistics,
+ total_instances=total_instances,
+ total_tokens=total_tokens,
+ total_trainable_tokens=total_trainable_tokens,
+ num_samples_skipped=num_samples_skipped,
+ tokenizer_name=tc.tokenizer_name_or_path,
+ max_seq_length=args.max_seq_length,
+ chat_template_name=tc.chat_template_name,
+ per_dataset_counts=per_dataset_counts,
+ per_dataset_tokens=per_dataset_tokens,
+ per_dataset_trainable_tokens=per_dataset_trainable_tokens,
+ per_dataset_filtered=per_dataset_filtered,
+ )
- tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
- os.makedirs(tokenizer_output_dir, exist_ok=True)
- print(f"Saving tokenizer to {tokenizer_output_dir}...")
- tc.tokenizer.save_pretrained(tokenizer_output_dir)
- print("Tokenizer saved successfully!")
- print("Verify these values match the tokenizer config used in Olmo-core:")
- print(f"Tokenizer vocab_size: {tc.tokenizer.vocab_size}")
- print(f"Tokenizer bos_token_id: {tc.tokenizer.bos_token_id}")
- print(f"Tokenizer pad_token_id: {tc.tokenizer.pad_token_id}")
- print(f"Tokenizer eos_token_id: {tc.tokenizer.eos_token_id}")
+def write_dataset_statistics(
+ output_dir: str,
+ dataset_statistics: Dict[str, Any],
+ total_instances: int,
+ total_tokens: int,
+ total_trainable_tokens: int,
+ num_samples_skipped: int,
+ tokenizer_name: str,
+ max_seq_length: Optional[int],
+ chat_template_name: Optional[str],
+ per_dataset_counts: Dict[str, int],
+ per_dataset_tokens: Dict[str, int],
+ per_dataset_trainable_tokens: Dict[str, int],
+ per_dataset_filtered: Dict[str, int],
+):
+ """Write dataset statistics to both text and JSON files."""
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+ # Merge pre-transformation stats with post-shuffle actual counts
+ merged_stats = []
+ pre_transform_stats = {stat["dataset_name"]: stat for stat in dataset_statistics.get("per_dataset_stats", [])}
+
+ for dataset_name in per_dataset_counts:
+ pre_stat = pre_transform_stats.get(dataset_name, {})
+ merged_stat = {
+ "dataset_name": dataset_name,
+ "dataset_split": pre_stat.get("dataset_split", "unknown"),
+ "initial_instances": pre_stat.get("initial_instances", "N/A"),
+ "instances_after_transformation": pre_stat.get("final_instances", "N/A"),
+ "instances_filtered_during_transformation": pre_stat.get("instances_filtered", "N/A"),
+ "frac_or_num_samples": pre_stat.get("frac_or_num_samples"),
+ # Upsampling information
+ "original_dataset_size": pre_stat.get("original_dataset_size"),
+ "is_upsampled": pre_stat.get("is_upsampled", False),
+ "upsampling_factor": pre_stat.get("upsampling_factor", 1.0),
+ # Post-shuffle actual statistics
+ "final_instances_in_output": per_dataset_counts[dataset_name],
+ "final_tokens_in_output": per_dataset_tokens[dataset_name],
+ "final_trainable_tokens_in_output": per_dataset_trainable_tokens[dataset_name],
+ "instances_filtered_after_tokenization": per_dataset_filtered[dataset_name],
+ "avg_tokens_per_instance": per_dataset_tokens[dataset_name] / per_dataset_counts[dataset_name]
+ if per_dataset_counts[dataset_name] > 0
+ else 0,
+ "percentage_of_total_tokens": (per_dataset_tokens[dataset_name] / total_tokens * 100)
+ if total_tokens > 0
+ else 0,
+ "percentage_of_total_instances": (per_dataset_counts[dataset_name] / total_instances * 100)
+ if total_instances > 0
+ else 0,
+ }
+ merged_stats.append(merged_stat)
+
+ # Prepare statistics data
+ stats_data = {
+ "timestamp": timestamp,
+ "output_directory": output_dir,
+ "configuration": {
+ "tokenizer": tokenizer_name,
+ "max_sequence_length": max_seq_length,
+ "chat_template": chat_template_name,
+ },
+ "per_dataset_statistics": merged_stats,
+ "overall_statistics": {
+ "total_datasets": len(per_dataset_counts),
+ "total_instances": total_instances,
+ "total_tokens": total_tokens,
+ "trainable_tokens": total_trainable_tokens,
+ "trainable_percentage": (total_trainable_tokens / total_tokens * 100) if total_tokens > 0 else 0,
+ "instances_filtered": num_samples_skipped,
+ "average_sequence_length": total_tokens / total_instances if total_instances > 0 else 0,
+ }
+ }
+
+ # Write JSON file
+ json_path = os.path.join(output_dir, "dataset_statistics.json")
+ with open(json_path, "w") as f:
+ json.dump(stats_data, f, indent=2)
+ print(f"Written dataset statistics to {json_path}")
+
+ # Write human-readable text file
+ text_path = os.path.join(output_dir, "dataset_statistics.txt")
+ with open(text_path, "w") as f:
+ f.write("Dataset Statistics Report\n")
+ f.write("=" * 80 + "\n")
+ f.write(f"Generated: {timestamp}\n")
+ f.write(f"Output Directory: {output_dir}\n\n")
+
+ f.write("Configuration:\n")
+ f.write("-" * 40 + "\n")
+ f.write(f"- Tokenizer: {tokenizer_name}\n")
+ f.write(f"- Max Sequence Length: {max_seq_length}\n")
+ f.write(f"- Chat Template: {chat_template_name}\n\n")
+
+ f.write("Per-Dataset Statistics:\n")
+ f.write("=" * 80 + "\n")
+
+ for stat in stats_data["per_dataset_statistics"]:
+ f.write(f"\nDataset: {stat['dataset_name']}\n")
+ f.write(f"- Split: {stat['dataset_split']}\n")
+
+ # Pre-transformation statistics
+ f.write("\nPre-transformation:\n")
+ f.write(f" - Instances loaded: {stat.get('initial_instances', 'N/A')}\n")
+ f.write(f" - Instances after transformation: {stat.get('instances_after_transformation', 'N/A')}\n")
+ f.write(f" - Instances filtered during transformation: {stat.get('instances_filtered_during_transformation', 'N/A')}\n")
+
+ if stat.get('frac_or_num_samples') is not None:
+ if isinstance(stat['frac_or_num_samples'], float):
+ f.write(f" - Sampling fraction: {stat['frac_or_num_samples']}\n")
+ else:
+ f.write(f" - Sample count: {stat['frac_or_num_samples']}\n")
+
+ # Show upsampling information if applicable
+ if stat.get('is_upsampled', False):
+ f.write(f" - Original dataset size: {stat.get('original_dataset_size', 'N/A')}\n")
+ f.write(f" - Upsampling factor: {stat.get('upsampling_factor', 1.0):.2f}x\n")
+ f.write(f" - Upsampled to: {stat.get('instances_after_transformation', 'N/A')} instances\n")
+
+ # Post-shuffle statistics (actual output)
+ f.write("\nFinal output statistics (after shuffling):\n")
+ f.write(f" - Instances in output: {stat['final_instances_in_output']:,}\n")
+ f.write(f" - Total tokens: {stat['final_tokens_in_output']:,}\n")
+ f.write(f" - Trainable tokens: {stat['final_trainable_tokens_in_output']:,}\n")
+ f.write(f" - Instances with no labels: {stat['instances_filtered_after_tokenization']}\n")
+ f.write(f" - Average tokens per instance: {stat['avg_tokens_per_instance']:.1f}\n")
+ f.write(f" - Percentage of total tokens: {stat['percentage_of_total_tokens']:.1f}%\n")
+ f.write(f" - Percentage of total instances: {stat['percentage_of_total_instances']:.1f}%\n")
+
+ f.write("\n" + "=" * 80 + "\n")
+ f.write("Overall Statistics:\n")
+ f.write("=" * 80 + "\n")
+ f.write(f"- Total datasets: {stats_data['overall_statistics']['total_datasets']}\n")
+ f.write(f"- Total instances: {stats_data['overall_statistics']['total_instances']:,}\n")
+ f.write(f"- Total tokens: {stats_data['overall_statistics']['total_tokens']:,}\n")
+ f.write(f"- Trainable tokens: {stats_data['overall_statistics']['trainable_tokens']:,} ")
+ f.write(f"({stats_data['overall_statistics']['trainable_percentage']:.1f}%)\n")
+ f.write(f"- Instances filtered out: {stats_data['overall_statistics']['instances_filtered']}\n")
+ f.write(f"- Average sequence length: {stats_data['overall_statistics']['average_sequence_length']:.1f}\n")
+
+ print(f"Written human-readable statistics to {text_path}")
if __name__ == "__main__":