Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 73 additions & 6 deletions unsloth_zoo/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
"train_on_responses_only",
]


# From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
# Longest Common Substring in an Array of Strings
def _longest_common_substring(arr):
def _old_longest_common_substring(arr):
n = len(arr)
s = arr[0]
l = len(s)
Expand All @@ -39,6 +38,71 @@ def _longest_common_substring(arr):
pass


def _longest_common_sublist(lists):
"""
Finds the longest common sublist among multiple lists.

Parameters:
lists (List[List[int]]): A list of lists.

Returns:
List[int]: The longest common sublist. If multiple sublists have the same maximum length,
one of them is returned. If there's no common sublist, an empty list is returned.
"""
if not lists: return []

# Find the minimum length among all lists
min_len = min(len(lst) for lst in lists)
if min_len == 0: return []

def has_common_sublist(length):
"""
Checks if there's a common sublist of the given length across all lists.

Returns:
(bool, List): Tuple of whether such a sublist exists and the sublist itself.
"""
common = set()
first = lists[0]
# Generate all possible sublists of the given length from the first list
for i in range(len(first) - length + 1):
sub = tuple(first[i:i + length])
common.add(sub)
pass

# Iterate over the remaining lists and retain only the common sublists
for lst in lists[1:]:
current = set()
for i in range(len(lst) - length + 1):
sub = tuple(lst[i:i + length])
if sub in common:
current.add(sub)
common = current
if not common:
return False, []
pass

# If common is not empty, return one of the common sublists
return True, list(common.pop())
pass

left, right = 1, min_len
result = []

while left <= right:
mid = left + (right - left) // 2
exists, sublist = has_common_sublist(mid)
if exists:
result = sublist # Update result with the latest found sublist
left = mid + 1 # Try to find a longer sublist
else:
right = mid - 1 # Try with a shorter length
pass

return result
pass


def _find_common_token_ids(component, tokenizer):
"""
\n### User:\n\n
Expand All @@ -54,7 +118,7 @@ def _find_common_token_ids(component, tokenizer):
if component.startswith (" "): left_text = " "
elif component.startswith("\n"): left_text = "\n"
stripped = component.strip()

# Add current pieces and also newlines
all_input_ids = []
for left in range(3):
Expand All @@ -68,10 +132,13 @@ def _find_common_token_ids(component, tokenizer):
all_input_ids.append(x)
pass
pass
substring = _longest_common_substring([str(x + [0]) for x in all_input_ids])
substring = substring.split(", ")[:-1]
substring = [int(x) for x in substring]

# The old string-based longest common substring is replaced with a true longest common sublist computed directly on the token IDs
# substring = _old_longest_common_substring([str(x + [0]) for x in all_input_ids])
# substring = substring.split(", ")[:-1]
# substring = [int(x) for x in substring if x.isdigit()]
substring = _longest_common_sublist([x + [0] for x in all_input_ids])

# Also get rest of tokenized string
original = tokenizer(component, add_special_tokens = False).input_ids
# Get optional left and right
Expand Down