diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 5ca23e4b55..7f4f8565d3 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -901,6 +901,18 @@ def tokenizer(self):
 TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
 TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]
 
+
+def _remove_dataset_source_field(dataset: Dataset) -> Dataset:
+    """Remove dataset_source field from dataset if it exists.
+
+    This should be called after statistics collection but before returning
+    the final dataset to avoid storing unnecessary metadata in cached datasets.
+    """
+    if DATASET_ORIGIN_KEY in dataset.column_names:
+        return dataset.remove_columns([DATASET_ORIGIN_KEY])
+    return dataset
+
+
 # Preference dataset
 # NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only
 # also we don't really need `CHOSEN_ATTENTION_MASK_KEY` and `REJECTED_ATTENTION_MASK_KEY`
@@ -1436,10 +1448,11 @@ def select_samples(self, target_size: int):
             extra_indices = rng.choice(original_size, size=extra_samples, replace=False)
             indices.extend(extra_indices.tolist())
 
-        print(
-            f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
-            f"({full_repeats} full repeats + {extra_samples} random samples)"
-        )
+        if target_size > original_size:
+            print(
+                f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
+                f"({full_repeats} full repeats + {extra_samples} random samples)"
+            )
 
         return self.dataset.select(indices)
 
@@ -1605,12 +1618,8 @@ def save_config(self, config_hash: str, dcs: List[DatasetConfig], tc: TokenizerC
             json.dump(config_dict, f, indent=2)
 
     def load_or_transform_dataset(
-        self,
-        dcs: List[DatasetConfig],
-        tc: TokenizerConfig,
-        dataset_skip_cache: bool = False,
-        return_statistics: bool = False,
-    ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
+        self, dcs: List[DatasetConfig], tc: TokenizerConfig, dataset_skip_cache: bool = False
+    ) -> Tuple[Dataset, Dict[str, Any]]:
         """Load dataset from local cache if it exists, otherwise transform and cache it locally."""
         cache_path = self.get_cache_path()
 
@@ -1618,17 +1627,15 @@
         if os.path.exists(cache_path) and not dataset_skip_cache:
             print(f"✅ Found cached dataset at {cache_path}")
             dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
-            if return_statistics:
-                # Load statistics from cache if available
-                stats_path = os.path.join(cache_path, "dataset_statistics.json")
-                if os.path.exists(stats_path):
-                    with open(stats_path, "r") as f:
-                        statistics = json.load(f)
-                    return dataset, statistics
-                else:
-                    # Return empty statistics if not cached
-                    return dataset, {"per_dataset_stats": [], "dataset_order": []}
-            return dataset, None
+            # Load statistics from cache if available
+            stats_path = os.path.join(cache_path, "dataset_statistics.json")
+            if os.path.exists(stats_path):
+                with open(stats_path, "r") as f:
+                    statistics = json.load(f)
+                return dataset, statistics
+            else:
+                # Return empty statistics if not cached
+                return dataset, {"per_dataset_stats": [], "dataset_order": []}
 
         print(f"Cache not found or invalid, transforming datasets...")
 
@@ -1683,9 +1690,7 @@
 
         all_statistics = {"per_dataset_stats": dataset_statistics, "dataset_order": dataset_order}
         if dataset_skip_cache:
-            if return_statistics:
-                return combined_dataset, all_statistics
-            return combined_dataset, None
+            return combined_dataset, all_statistics
 
         # Save to local cache
         combined_dataset.save_to_disk(cache_path)
@@ -1700,9 +1705,7 @@
         print(f"✅ Found cached dataset at {cache_path}")
         loaded_dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
 
-        if return_statistics:
-            return loaded_dataset, all_statistics
-        return loaded_dataset, None
+        return loaded_dataset, all_statistics
 
 
 def get_cached_dataset(
@@ -1711,15 +1714,12 @@
     hf_entity: Optional[str] = None,
     dataset_local_cache_dir: Optional[str] = None,
     dataset_skip_cache: bool = False,
-    return_statistics: bool = False,
 ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
     if dataset_local_cache_dir is not None:
         cache = LocalDatasetTransformationCache(dataset_local_cache_dir=dataset_local_cache_dir)
     else:
         cache = DatasetTransformationCache(hf_entity=hf_entity)
-    return cache.load_or_transform_dataset(
-        dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
-    )[0]
+    return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)
 
 
 def get_cached_dataset_tulu_with_statistics(
@@ -1734,7 +1734,6 @@
     hf_entity: Optional[str] = None,
     dataset_local_cache_dir: str = "local_dataset_cache",
     dataset_skip_cache: bool = False,
-    return_statistics: bool = False,
 ) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
     dcs = []
     if dataset_config_hash is None:
@@ -1787,9 +1786,10 @@
         )
     elif dataset_cache_mode == "hf":
         cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity)
-    return cache.load_or_transform_dataset(
-        dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
-    )
+
+    dataset, statistics = cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)
+
+    return _remove_dataset_source_field(dataset), statistics
 
 
 def get_cached_dataset_tulu(
@@ -1817,7 +1817,6 @@
         hf_entity,
         dataset_local_cache_dir,
         dataset_skip_cache,
-        return_statistics=False,
     )[0]
 
 
diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 2c3e08a618..9164d63157 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -331,7 +331,7 @@ class FlatArguments:
         metadata={"help": "Whether to use packing/padding-free collation via TensorDataCollatorWithFlattening"},
     )
     verbose: bool = field(
-        default=True, metadata={"help": "Optionally print additional statistics at each reporting period"}
+        default=False, metadata={"help": "Optionally print additional statistics at each reporting period"}
     )
 
     def __post_init__(self):
@@ -455,6 +455,8 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             },
         )
         wandb_tracker = accelerator.get_tracker("wandb")
+    else:
+        wandb_tracker = None  # for later eval launching
 
     if accelerator.is_main_process:
         pprint([args, tc])
@@ -996,7 +998,7 @@
                 path=args.output_dir,
                 leaderboard_name=args.hf_repo_revision,
                 oe_eval_max_length=args.oe_eval_max_length,
-                wandb_url=wandb_tracker.run.get_url(),
+                wandb_url=wandb_tracker.run.get_url() if wandb_tracker is not None else None,
                 oe_eval_tasks=args.oe_eval_tasks,
                 gs_bucket_path=args.gs_bucket_path,
             )
diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py
index 52e9b93ad5..49cfa64baa 100644
--- a/open_instruct/model_utils.py
+++ b/open_instruct/model_utils.py
@@ -426,7 +426,7 @@ def save_with_accelerate(
     # otherwise, we get an error thrown at save time.
     if "olmo" in chat_template_name:
         # New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
-        logger.log(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
+        logger.info(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
         model.generation_config = get_olmo3_generation_config(tokenizer)
     else:
         model.generation_config = transformers.GenerationConfig(
diff --git a/scripts/data/convert_sft_data_for_olmocore.py b/scripts/data/convert_sft_data_for_olmocore.py
index 039dceb045..c8667377b1 100644
--- a/scripts/data/convert_sft_data_for_olmocore.py
+++ b/scripts/data/convert_sft_data_for_olmocore.py
@@ -161,7 +161,7 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
         ("sft_tulu_filter_v1", {}),  # remove examples that don't have any labels
     ]
 
-    result = get_cached_dataset_tulu_with_statistics(
+    train_dataset, dataset_statistics = get_cached_dataset_tulu_with_statistics(
         dataset_mixer_list=args.dataset_mixer_list,
         dataset_mixer_list_splits=args.dataset_mixer_list_splits,
         tc=tc,
@@ -172,11 +172,7 @@
         dataset_config_hash=args.dataset_config_hash,
         dataset_local_cache_dir=args.dataset_local_cache_dir,
         dataset_skip_cache=args.dataset_skip_cache,
-        return_statistics=True,
     )
-
-    # Unpack the result
-    train_dataset, dataset_statistics = result
 
     train_dataset = train_dataset.shuffle()
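
Note for reviewers: below is a minimal sketch of the calling convention after this change. It is illustrative only: it assumes `tc` is an already-constructed `TokenizerConfig`, uses a made-up mixer entry, and elides the remaining keyword arguments shown in the diff above.

    from open_instruct.dataset_transformation import get_cached_dataset_tulu_with_statistics

    # After this change the helper always returns a (dataset, statistics) pair;
    # the old return_statistics flag is gone.
    train_dataset, dataset_statistics = get_cached_dataset_tulu_with_statistics(
        dataset_mixer_list=["allenai/tulu-3-sft-mixture", "1.0"],  # illustrative entry
        dataset_mixer_list_splits=["train"],
        tc=tc,  # assumption: an existing TokenizerConfig
        dataset_skip_cache=False,
        # ... transform fns, cache dir, and other arguments elided
    )

    # _remove_dataset_source_field() strips the dataset_source bookkeeping column
    # before the dataset is returned, so downstream consumers never see it.
    assert "dataset_source" not in train_dataset.column_names
    print(dataset_statistics["per_dataset_stats"], dataset_statistics["dataset_order"])

Callers that only want the dataset are unaffected: `get_cached_dataset_tulu` still indexes `[0]` into the returned pair.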