Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions open_instruct/dataset_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
GROUND_TRUTHS_KEY = "ground_truth"
DATASET_SOURCE_KEY = "dataset"
VERIFIER_SOURCE_KEY = "dataset"

# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is for visualization purposes only
# also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY`
Expand Down Expand Up @@ -186,7 +186,7 @@ class DatasetConfig:
ground_truths_key: str = GROUND_TRUTHS_KEY

# column name for the dataset source
dataset_source_key: str = DATASET_SOURCE_KEY
dataset_source_key: str = VERIFIER_SOURCE_KEY

# column names for binary dataset
binary_messages_key: str = SFT_MESSAGE_KEY
Expand Down Expand Up @@ -434,7 +434,7 @@ def tokenize_fn(row):
labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY])
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[self.config.ground_truths_key]
row[DATASET_SOURCE_KEY] = row[self.config.dataset_source_key]
row[VERIFIER_SOURCE_KEY] = row[self.config.dataset_source_key]
return row

return dataset.map(
Expand Down Expand Up @@ -561,9 +561,13 @@ def __call__(self, batch: list[dict]):
ground_truths = [x[GROUND_TRUTHS_KEY] for x in batch]

# datasets
datasets = [x[DATASET_SOURCE_KEY] for x in batch]
datasets = [x[VERIFIER_SOURCE_KEY] for x in batch]

return {INPUT_IDS_PROMPT_KEY: padded_sequences, GROUND_TRUTHS_KEY: ground_truths, DATASET_SOURCE_KEY: datasets}
return {
INPUT_IDS_PROMPT_KEY: padded_sequences,
GROUND_TRUTHS_KEY: ground_truths,
VERIFIER_SOURCE_KEY: datasets,
}


if __name__ == "__main__":
Expand Down
65 changes: 47 additions & 18 deletions open_instruct/dataset_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):

DEFAULT_SFT_MESSAGES_KEY = "messages"
GROUND_TRUTHS_KEY = "ground_truth"
DATASET_SOURCE_KEY = "dataset"
VERIFIER_SOURCE_KEY = "dataset"


@dataclass
Expand All @@ -855,7 +855,7 @@ class TokenizerConfig:
tokenizer_revision: Optional[str] = None
trust_remote_code: bool = False
use_fast: bool = True
chat_template_name: str = "olmo"
chat_template_name: str = "tulu"
add_bos: bool = False
get_tokenizer_fn: str = "get_tokenizer_tulu_v2_2"

Expand Down Expand Up @@ -897,9 +897,9 @@ def tokenizer(self):
INPUT_IDS_KEY = "input_ids"
ATTENTION_MASK_KEY = "attention_mask"
LABELS_KEY = "labels"
DATASET_SOURCE_KEY = "dataset_source"
DATASET_ORIGIN_KEY = "dataset_source" # just 'dataset' clashes with RLVR stuff (see VERIFIER_SOURCE_KEY)
TOKENIZED_SFT_DATASET_KEYS = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_SOURCE_KEY]
TOKENIZED_SFT_DATASET_KEYS_WITH_SOURCE = [INPUT_IDS_KEY, ATTENTION_MASK_KEY, LABELS_KEY, DATASET_ORIGIN_KEY]

# Preference dataset
# NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is for visualization purposes only
Expand Down Expand Up @@ -1244,7 +1244,7 @@ def rlvr_tokenize_v1(
tokenizer: PreTrainedTokenizer,
sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY,
ground_truths_key: str = GROUND_TRUTHS_KEY,
dataset_source_key: str = DATASET_SOURCE_KEY,
verifier_source_key: str = VERIFIER_SOURCE_KEY,
):
if len(row[sft_messages_key]) == 1:
prompt = row[sft_messages_key]
Expand All @@ -1256,7 +1256,7 @@ def rlvr_tokenize_v1(
labels = copy.deepcopy(row[INPUT_IDS_KEY])
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[DATASET_SOURCE_KEY] = row[dataset_source_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
return row


Expand All @@ -1265,7 +1265,7 @@ def rlvr_tokenize_v2(
tokenizer: PreTrainedTokenizer,
sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY,
ground_truths_key: str = GROUND_TRUTHS_KEY,
dataset_source_key: str = DATASET_SOURCE_KEY,
verifier_source_key: str = VERIFIER_SOURCE_KEY,
):
if len(row[sft_messages_key]) == 1:
prompt = row[sft_messages_key]
Expand All @@ -1283,14 +1283,14 @@ def rlvr_tokenize_v2(
labels = copy.deepcopy(row[INPUT_IDS_KEY])
row[LABELS_KEY] = labels
row[GROUND_TRUTHS_KEY] = row[ground_truths_key]
row[DATASET_SOURCE_KEY] = row[dataset_source_key]
row[VERIFIER_SOURCE_KEY] = row[verifier_source_key]
# some basic transformations:
# if ground truths is a string, make it a list
if isinstance(row[ground_truths_key], str):
row[ground_truths_key] = [row[ground_truths_key]]
# if dataset source is a string, make it a list
if isinstance(row[dataset_source_key], str):
row[dataset_source_key] = [row[dataset_source_key]]
if isinstance(row[verifier_source_key], str):
row[verifier_source_key] = [row[verifier_source_key]]
# drop the messages field as it often causes issues.
row.pop(sft_messages_key)
return row
Expand Down Expand Up @@ -1456,7 +1456,7 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):

# Add dataset source field to track origin after shuffling
dataset = dataset.map(
lambda example: {**example, DATASET_SOURCE_KEY: dc.dataset_name},
lambda example: {**example, DATASET_ORIGIN_KEY: dc.dataset_name},
num_proc=num_proc,
desc=f"Adding dataset source field for {dc.dataset_name}",
)
Expand All @@ -1469,8 +1469,8 @@ def get_dataset_v1(dc: DatasetConfig, tc: TokenizerConfig):
# perform the transformation
target_columns = dataset.column_names if dc.target_columns is None else dc.target_columns
# Always preserve dataset_source if it exists
if DATASET_SOURCE_KEY in dataset.column_names and DATASET_SOURCE_KEY not in target_columns:
target_columns = target_columns + [DATASET_SOURCE_KEY]
if DATASET_ORIGIN_KEY in dataset.column_names and DATASET_ORIGIN_KEY not in target_columns:
target_columns = target_columns + [DATASET_ORIGIN_KEY]

if fn_type == "map":
dataset = dataset.map(
Expand Down Expand Up @@ -1628,7 +1628,7 @@ def load_or_transform_dataset(
else:
# Return empty statistics if not cached
return dataset, {"per_dataset_stats": [], "dataset_order": []}
return dataset
return dataset, None

print(f"Cache not found or invalid, transforming datasets...")

Expand Down Expand Up @@ -1685,7 +1685,7 @@ def load_or_transform_dataset(
if dataset_skip_cache:
if return_statistics:
return combined_dataset, all_statistics
return combined_dataset
return combined_dataset, None

# Save to local cache
combined_dataset.save_to_disk(cache_path)
Expand All @@ -1702,7 +1702,7 @@ def load_or_transform_dataset(
loaded_dataset = Dataset.load_from_disk(cache_path, keep_in_memory=True)
if return_statistics:
return loaded_dataset, all_statistics
return loaded_dataset
return loaded_dataset, None


def get_cached_dataset(
Expand All @@ -1719,10 +1719,10 @@ def get_cached_dataset(
cache = DatasetTransformationCache(hf_entity=hf_entity)
return cache.load_or_transform_dataset(
dcs, tc, dataset_skip_cache=dataset_skip_cache, return_statistics=return_statistics
)
)[0]


def get_cached_dataset_tulu(
def get_cached_dataset_tulu_with_statistics(
dataset_mixer_list: List[str],
dataset_mixer_list_splits: List[str],
tc: TokenizerConfig,
Expand Down Expand Up @@ -1792,6 +1792,35 @@ def get_cached_dataset_tulu(
)


def get_cached_dataset_tulu(
    dataset_mixer_list: List[str],
    dataset_mixer_list_splits: List[str],
    tc: TokenizerConfig,
    dataset_transform_fn: List[str],
    transform_fn_args: List[Dict[str, Any]],
    target_columns: Optional[List[str]] = None,
    dataset_cache_mode: Literal["hf", "local"] = "local",
    dataset_config_hash: Optional[str] = None,
    hf_entity: Optional[str] = None,
    dataset_local_cache_dir: str = "local_dataset_cache",
    dataset_skip_cache: bool = False,
) -> Dataset:
    """Return only the dataset from `get_cached_dataset_tulu_with_statistics`.

    Every argument is forwarded unchanged, positionally and in declaration
    order, together with `return_statistics=False`. The first element of
    whatever that call returns is handed back to the caller.

    NOTE(review): the delegate presumably returns a `(dataset, statistics)`
    pair -- confirm against its definition.
    """
    # Positional forwarding is kept deliberately: it mirrors the delegate's
    # leading parameters one-to-one, so the order below must not be changed.
    result = get_cached_dataset_tulu_with_statistics(
        dataset_mixer_list,
        dataset_mixer_list_splits,
        tc,
        dataset_transform_fn,
        transform_fn_args,
        target_columns,
        dataset_cache_mode,
        dataset_config_hash,
        hf_entity,
        dataset_local_cache_dir,
        dataset_skip_cache,
        return_statistics=False,
    )
    # Callers of this wrapper only want the dataset; the rest is dropped.
    dataset = result[0]
    return dataset


def test_sft_dpo_same_tokenizer():
base_to_sft_tc = TokenizerConfig(
tokenizer_name_or_path="meta-llama/Llama-3.1-8B", tokenizer_revision="main", chat_template_name="tulu"
Expand Down
9 changes: 4 additions & 5 deletions open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
except Exception:
pass
# isort: on

import asyncio
import json
import logging
Expand Down Expand Up @@ -77,9 +76,9 @@
from vllm import SamplingParams

from open_instruct.dataset_transformation import (
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
Expand Down Expand Up @@ -1729,7 +1728,7 @@ def sync_weights_and_prepare_prompts(
data_next = train_dataset[dataset_indices]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
datasets_next = data_next[VERIFIER_SOURCE_KEY]
with Timer(
"[Main Thread] 🔄 Loading weights using shared memory"
if args.async_mode
Expand Down Expand Up @@ -2119,7 +2118,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
if eval_dataset is not None:
eval_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY]
eval_ground_truths = eval_dataset[:num_eval_samples][GROUND_TRUTHS_KEY]
eval_dataset_names = eval_dataset[:num_eval_samples][DATASET_SOURCE_KEY]
eval_dataset_names = eval_dataset[:num_eval_samples][VERIFIER_SOURCE_KEY]
reward_fn = make_reward_fn(args)

# Start vLLM engines to process from queues
Expand Down Expand Up @@ -2153,7 +2152,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
data_next = train_dataset[dataset_indices]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
datasets_next = data_next[VERIFIER_SOURCE_KEY]

# Split the initial batch using the split_and_insert_batch function
split_and_insert_batch(
Expand Down
8 changes: 4 additions & 4 deletions open_instruct/grpo_vllm_thread_ray_gtrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@

from open_instruct.dataset_processor import SimpleGenerateCollatorWithGroundTruth
from open_instruct.dataset_transformation import (
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
Expand Down Expand Up @@ -1027,7 +1027,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
].tolist() # can be simplified since we `remove_padding` later anyway
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
datasets_next = data[VERIFIER_SOURCE_KEY]
if self.rank == 0:
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))

Expand Down Expand Up @@ -1066,7 +1066,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
datasets_next = data[VERIFIER_SOURCE_KEY]
with Timer("🔥🔥🔥 Loading weights using shared memory", noop=self.rank != 0):
broadcast_to_vllm()
if self.rank == 0:
Expand All @@ -1084,7 +1084,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
datasets_next = data[VERIFIER_SOURCE_KEY]
with Timer("🔥🔥🔥 Loading weights using shared memory", noop=self.rank != 0):
broadcast_to_vllm()
if self.rank == 0:
Expand Down
6 changes: 3 additions & 3 deletions open_instruct/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@

from open_instruct.dataset_processor import SimpleGenerateCollatorWithGroundTruth
from open_instruct.dataset_transformation import (
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
Expand Down Expand Up @@ -1022,7 +1022,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
].tolist() # can be simplified since we `remove_padding` later anyway
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
datasets_next = data[VERIFIER_SOURCE_KEY]
if self.rank == 0:
param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id)))

Expand Down Expand Up @@ -1060,7 +1060,7 @@ def generate_with_engines(prompts: List[List[int]], sampling_params: SamplingPar
global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist()
queries_next = data[INPUT_IDS_PROMPT_KEY].to(device)
ground_truths_next = data[GROUND_TRUTHS_KEY]
datasets_next = data[DATASET_SOURCE_KEY]
datasets_next = data[VERIFIER_SOURCE_KEY]
with Timer("🔥🔥🔥 Loading weights using shared memory", noop=self.rank != 0):
broadcast_to_vllm()
if self.rank == 0:
Expand Down
10 changes: 5 additions & 5 deletions open_instruct/ppo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@
from vllm import SamplingParams

from open_instruct.dataset_transformation import (
DATASET_SOURCE_KEY,
GROUND_TRUTHS_KEY,
INPUT_IDS_PROMPT_KEY,
VERIFIER_SOURCE_KEY,
TokenizerConfig,
get_cached_dataset_tulu,
visualize_token,
Expand Down Expand Up @@ -1685,7 +1685,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
if eval_dataset is not None:
eval_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY]
eval_ground_truths = eval_dataset[:num_eval_samples][GROUND_TRUTHS_KEY]
eval_dataset_names = eval_dataset[:num_eval_samples][DATASET_SOURCE_KEY]
eval_dataset_names = eval_dataset[:num_eval_samples][VERIFIER_SOURCE_KEY]
thread = threading.Thread(
target=vllm_generate_thread,
args=(
Expand Down Expand Up @@ -1723,7 +1723,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
datasets_next = data_next[VERIFIER_SOURCE_KEY]
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
param_prompt_Q.put((None, queries_next))

Expand All @@ -1743,7 +1743,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
datasets_next = data_next[VERIFIER_SOURCE_KEY]
with Timer("[Main Thread] 🔄 Loading weights using shared memory"):
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
Expand All @@ -1755,7 +1755,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
data_next = train_dataset[next(iter_dataloader)]
queries_next = data_next[INPUT_IDS_PROMPT_KEY]
ground_truths_next = data_next[GROUND_TRUTHS_KEY]
datasets_next = data_next[DATASET_SOURCE_KEY]
datasets_next = data_next[VERIFIER_SOURCE_KEY]
with Timer("🔄 Loading weights using shared memory"):
ray.get([m.broadcast_to_vllm.remote() for m in policy_group.models])
queries_prompt_Q.put((queries_next, ground_truths_next, datasets_next))
Expand Down
Loading