Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions open_instruct/dataset_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,40 +979,6 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize
return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])


def preference_span_search_mask_out(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
max_seq_length: int,
chosen_key: str = DEFAULT_CHOSEN_KEY,
rejected_key: str = DEFAULT_REJECTED_KEY,
):
"""
Here we assume each example has a rejected and chosen field, both of which are a list of messages.
Each message is a dict with 'role' and 'content' fields.
We assume only the last message is different, and the prompt is contained in the list of messages.
"""
chosen_messages = row[chosen_key]
rejected_messages = row[rejected_key]
if len(chosen_messages) == 0:
raise ValueError("chosen messages field is empty.")
if len(rejected_messages) == 0:
raise ValueError("rejected messages field is empty.")

chosen_encoded = sft_span_seach_mask_out({DEFAULT_SFT_MESSAGES_KEY: chosen_messages}, tokenizer, max_seq_length)
rejected_encoded = sft_span_seach_mask_out(
{DEFAULT_SFT_MESSAGES_KEY: rejected_messages}, tokenizer, max_seq_length
)

return {
CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
CHOSEN_LABELS_KEY: chosen_encoded["labels"],
CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
REJECTED_LABELS_KEY: rejected_encoded["labels"],
REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
}


def rlvr_tokenize_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
Expand Down Expand Up @@ -1098,7 +1064,6 @@ def rlvr_filter_v1(
"preference_tokenize_v1": (preference_tokenize_v1, "map"),
"preference_filter_v1": (preference_filter_v1, "filter"),
"preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1_2, "map"),
"preference_span_search_mask_out": (preference_span_search_mask_out, "map"),
"preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
"rlvr_tokenize_v1": (rlvr_tokenize_v2, "map"),
"rlvr_filter_v1": (rlvr_filter_v1, "filter"),
Expand Down