71 changes: 35 additions & 36 deletions open_instruct/dataset_transformation.py
@@ -901,6 +901,18 @@ def tokenizer(self):
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]


def _remove_dataset_source_field(dataset: Dataset) -> Dataset:
"""Remove dataset_source field from dataset if it exists.

This should be called after statistics collection but before returning
the final dataset to avoid storing unnecessary metadata in cached datasets.
"""
if DATASET_ORIGIN_KEY in dataset.column_names:
return dataset.remove_columns([DATASET_ORIGIN_KEY])
return dataset
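
A minimal sketch of the helper's behavior on a toy dataset (standalone; it assumes DATASET_ORIGIN_KEY resolves to "dataset_source", as elsewhere in this file):

    from datasets import Dataset

    # Toy dataset carrying the per-example origin column.
    ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4]], "dataset_source": ["tulu", "wildchat"]})

    # After the helper runs, only the training columns remain.
    ds = ds.remove_columns(["dataset_source"])  # what _remove_dataset_source_field does on a hit
    print(ds.column_names)  # ['input_ids']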


# Preference dataset
# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only
# also we don't really need `CHOSEN_ATTENTION_MASK_KEY` and `REJECTED_ATTENTION_MASK_KEY`
@@ -1436,10 +1448,11 @@ def select_samples(self, target_size: int):
extra_indices = rng.choice(original_size, size=extra_samples, replace=False)
indices.extend(extra_indices.tolist())

print(
f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
f"({full_repeats} full repeats + {extra_samples} random samples)"
)
if target_size > original_size:
print(
f"Upsampling dataset {self.dataset_name} from {original_size} to {target_size} samples "
f"({full_repeats} full repeats + {extra_samples} random samples)"
)

return self.dataset.select(indices)
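
With the new guard, the message is printed only when the dataset is actually grown. A worked sketch of the index construction mirrored from select_samples above (numpy's default_rng stands in for whatever rng the class uses):

    import numpy as np

    original_size, target_size = 100, 250
    rng = np.random.default_rng(0)  # seeding here is an assumption for reproducibility

    full_repeats = target_size // original_size   # 2 full passes over the data
    extra_samples = target_size % original_size   # 50 rows sampled without replacement
    indices = list(range(original_size)) * full_repeats
    indices.extend(rng.choice(original_size, size=extra_samples, replace=False).tolist())
    assert len(indices) == target_size  # 250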

@@ -1605,30 +1618,24 @@ def save_config(self, config_hash: str, dcs: List[DatasetConfig], tc: TokenizerC
json.dump(config_dict, f, indent=2)

def load_or_transform_dataset(
self,
dcs: List[DatasetConfig],
tc: TokenizerConfig,
dataset_skip_cache: bool = False,
return_statistics: bool = False,
) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
self, dcs: List[DatasetConfig], tc: TokenizerConfig, dataset_skip_cache: bool = False
) -> Tuple[Dataset, Dict[str, Any]]:
"""Load dataset from local cache if it exists, otherwise transform and cache it locally."""
cache_path = self.get_cache_path()

# Check if the cache exists
if os.path.exists(cache_path) and not dataset_skip_cache:
print(f"✅ Found cached dataset at {cache_path}")
dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
if return_statistics:
# Load statistics from cache if available
stats_path = os.path.join(cache_path, "dataset_statistics.json")
if os.path.exists(stats_path):
with open(stats_path, "r") as f:
statistics = json.load(f)
return dataset, statistics
else:
# Return empty statistics if not cached
return dataset, {"per_dataset_stats": [], "dataset_order": []}
return dataset, None
# Load statistics from cache if available
stats_path = os.path.join(cache_path, "dataset_statistics.json")
if os.path.exists(stats_path):
with open(stats_path, "r") as f:
statistics = json.load(f)
return dataset, statistics
else:
# Return empty statistics if not cached
return dataset, {"per_dataset_stats": [], "dataset_order": []}

print(f"Cache not found or invalid, transforming datasets...")

@@ -1683,9 +1690,7 @@ def load_or_transform_dataset
all_statistics = {"per_dataset_stats": dataset_statistics, "dataset_order": dataset_order}

if dataset_skip_cache:
if return_statistics:
return combined_dataset, all_statistics
return combined_dataset, None
return combined_dataset, all_statistics

# Save to local cache
combined_dataset.save_to_disk(cache_path)
@@ -1700,9 +1705,7 @@ def load_or_transform_dataset
print(f"✅ Found cached dataset at {cache_path}")

loaded_dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
if return_statistics:
return loaded_dataset, all_statistics
return loaded_dataset, None
return loaded_dataset, all_statistics
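
Taken together, every return path of load_or_transform_dataset now yields a (dataset, statistics) pair, so callers unpack unconditionally. A standalone sketch of the statistics fallback mirrored from the cache-hit branch above (cache_path is a stand-in for self.get_cache_path()):

    import json
    import os

    def load_statistics(cache_path: str) -> dict:
        stats_path = os.path.join(cache_path, "dataset_statistics.json")
        if os.path.exists(stats_path):
            with open(stats_path, "r") as f:
                return json.load(f)
        # Fallback shape for caches written before statistics were collected.
        return {"per_dataset_stats": [], "dataset_order": []}

    print(load_statistics("/tmp/nonexistent-cache"))  # empty-statistics fallback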


def get_cached_dataset(
@@ -1711,15 +1714,12 @@
hf_entity: Optional[str] = None,
dataset_local_cache_dir: Optional[str] = None,
dataset_skip_cache: bool = False,
return_statistics: bool = False,
) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
if dataset_local_cache_dir is not None:
cache = LocalDatasetTransformationCache(dataset_local_cache_dir=dataset_local_cache_dir)
else:
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(
dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
)[0]
return cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)


def get_cached_dataset_tulu_with_statistics(
@@ -1734,7 +1734,6 @@
hf_entity: Optional[str] = None,
dataset_local_cache_dir: str = "local_dataset_cache",
dataset_skip_cache: bool = False,
return_statistics: bool = False,
) -> Union[Dataset, Tuple[Dataset, Dict[str, Any]]]:
dcs = []
if dataset_config_hash is None:
@@ -1787,9 +1786,10 @@
)
elif dataset_cache_mode == "hf":
cache = DatasetTransformationCache(config_hash=dataset_config_hash, hf_entity=hf_entity)
return cache.load_or_transform_dataset(
dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
)

dataset, statistics = cache.load_or_transform_dataset(dcs, tc, dataset_skip_cache=dataset_skip_cache)

return _remove_dataset_source_field(dataset), statistics
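
A post-condition sketch for this return path: statistics can still attribute examples to their source datasets even though the rows no longer carry the bookkeeping column (toy values; the real column name comes from DATASET_ORIGIN_KEY):

    from datasets import Dataset

    ds = Dataset.from_dict({"input_ids": [[1], [2]], "dataset_source": ["tulu", "wildchat"]})
    statistics = {"per_dataset_stats": [{"dataset": "tulu"}, {"dataset": "wildchat"}],
                  "dataset_order": ["tulu", "wildchat"]}

    ds = ds.remove_columns(["dataset_source"])  # as _remove_dataset_source_field does
    assert "dataset_source" not in ds.column_names
    assert statistics["dataset_order"] == ["tulu", "wildchat"]  # provenance survives in stats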


def get_cached_dataset_tulu(
@@ -1817,7 +1817,6 @@
hf_entity,
dataset_local_cache_dir,
dataset_skip_cache,
return_statistics=False,
)[0]


6 changes: 4 additions & 2 deletions open_instruct/finetune.py
@@ -331,7 +331,7 @@ class FlatArguments:
metadata={"help": "Whether to use packing/padding-free collation via TensorDataCollatorWithFlattening"},
)
verbose: bool = field(
default=True, metadata={"help": "Optionally print additional statistics at each reporting period"}
default=False, metadata={"help": "Optionally print additional statistics at each reporting period"}
)

def __post_init__(self):
@@ -455,6 +455,8 @@ def main(args: FlatArguments, tc: TokenizerConfig):
},
)
wandb_tracker = accelerator.get_tracker("wandb")
else:
wandb_tracker = None # for later eval launching

if accelerator.is_main_process:
pprint([args, tc])
@@ -996,7 +998,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
path=args.output_dir,
leaderboard_name=args.hf_repo_revision,
oe_eval_max_length=args.oe_eval_max_length,
wandb_url=wandb_tracker.run.get_url(),
wandb_url=wandb_tracker.run.get_url() if wandb_tracker is not None else None,
oe_eval_tasks=args.oe_eval_tasks,
gs_bucket_path=args.gs_bucket_path,
)
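
A small sketch of the guard pattern introduced here: wandb_tracker is now bound on every path, so the eval launch can read it unconditionally (the tracker type is a stand-in; only the None case is exercised):

    from typing import Any, Optional

    def get_wandb_url(wandb_tracker: Optional[Any]) -> Optional[str]:
        # Mirrors the call site above: degrade to None when tracking is disabled.
        return wandb_tracker.run.get_url() if wandb_tracker is not None else None

    print(get_wandb_url(None))  # None, instead of a NameError when tracking was never set up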
2 changes: 1 addition & 1 deletion open_instruct/model_utils.py
@@ -426,7 +426,7 @@ def save_with_accelerate(
# otherwise, we get an error thrown at save time.
if "olmo" in chat_template_name:
# New chat template has no bos token, and two eos tokens: <|im_end|> and <|endoftext|>
logger.log(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
logger.info(f"Detected olmo chat template: {chat_template_name}, updating model generation config.")
model.generation_config = get_olmo3_generation_config(tokenizer)
else:
model.generation_config = transformers.GenerationConfig(
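
Why the one-line logging fix above matters: logging.Logger.log takes a numeric level as its first positional argument, so the old call shape raises a TypeError instead of logging. A quick demonstration:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("demo")

    logger.info("detected olmo chat template")               # correct
    logger.log(logging.INFO, "detected olmo chat template")  # also correct
    try:
        logger.log("detected olmo chat template")            # old call shape
    except TypeError:
        print("TypeError: the message string was consumed as the level argument")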
6 changes: 1 addition & 5 deletions scripts/data/convert_sft_data_for_olmocore.py
@@ -161,7 +161,7 @@ def main(args: ConvertSFTDataArguments, tc: TokenizerConfig):
("sft_tulu_filter_v1", {}), # remove examples that don't have any labels
]

result = get_cached_dataset_tulu_with_statistics(
train_dataset, dataset_statistics = get_cached_dataset_tulu_with_statistics(
dataset_mixer_list=args.dataset_mixer_list,
dataset_mixer_list_splits=args.dataset_mixer_list_splits,
tc=tc,
@@ -172,11 +172,7 @@
dataset_config_hash=args.dataset_config_hash,
dataset_local_cache_dir=args.dataset_local_cache_dir,
dataset_skip_cache=args.dataset_skip_cache,
return_statistics=True,
)

# Unpack the result
train_dataset, dataset_statistics = result

train_dataset = train_dataset.shuffle()
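
One note on the shuffle step: datasets.Dataset.shuffle returns a new dataset and, without a seed, uses a fresh RNG each run. A sketch with an explicit seed for reproducible conversions (a suggestion, not what the script above does):

    from datasets import Dataset

    ds = Dataset.from_dict({"input_ids": [[1], [2], [3]]})
    shuffled = ds.shuffle(seed=42)  # deterministic order across runs
    print(shuffled["input_ids"])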
