diff --git a/unsloth_zoo/dataset_utils.py b/unsloth_zoo/dataset_utils.py index 27bd27d33..da464f3bb 100644 --- a/unsloth_zoo/dataset_utils.py +++ b/unsloth_zoo/dataset_utils.py @@ -18,10 +18,9 @@ "train_on_responses_only", ] - # From https://www.geeksforgeeks.org/longest-common-substring-array-strings/ # Longest Common Substring in an Array of Strings -def _longest_common_substring(arr): +def _old_longest_common_substring(arr): n = len(arr) s = arr[0] l = len(s) @@ -39,6 +38,71 @@ def _longest_common_substring(arr): pass +def _longest_common_sublist(lists): + """ + Finds the longest common sublist among multiple lists. + + Parameters: + lists (List[List[int]]): A list of lists. + + Returns: + List[int]: The longest common sublist. If multiple sublists have the same maximum length, + one of them is returned. If there's no common sublist, an empty list is returned. + """ + if not lists: return [] + + # Find the minimum length among all lists + min_len = min(len(lst) for lst in lists) + if min_len == 0: return [] + + def has_common_sublist(length): + """ + Checks if there's a common sublist of the given length across all lists. + + Returns: + (bool, List): Tuple of whether such a sublist exists and the sublist itself. + """ + common = set() + first = lists[0] + # Generate all possible sublists of the given length from the first list + for i in range(len(first) - length + 1): + sub = tuple(first[i:i + length]) + common.add(sub) + pass + + # Iterate over the remaining lists and retain only the common sublists + for lst in lists[1:]: + current = set() + for i in range(len(lst) - length + 1): + sub = tuple(lst[i:i + length]) + if sub in common: + current.add(sub) + common = current + if not common: + return False, [] + pass + + # If common is not empty, return one of the common sublists + return True, list(common.pop()) + pass + + left, right = 1, min_len + result = [] + + while left <= right: + mid = left + (right - left) // 2 + exists, sublist = has_common_sublist(mid) + if exists: + result = sublist # Update result with the latest found sublist + left = mid + 1 # Try to find a longer sublist + else: + right = mid - 1 # Try with a shorter length + pass + + return result +pass + + def _find_common_token_ids(component, tokenizer): """ \n### User:\n\n @@ -54,7 +118,7 @@ def _find_common_token_ids(component, tokenizer): if component.startswith (" "): left_text = " " elif component.startswith("\n"): left_text = "\n" stripped = component.strip() - + # Add current pieces and also newlines all_input_ids = [] for left in range(3): @@ -68,10 +132,13 @@ def _find_common_token_ids(component, tokenizer): all_input_ids.append(x) pass pass - substring = _longest_common_substring([str(x + [0]) for x in all_input_ids]) - substring = substring.split(", ")[:-1] - substring = [int(x) for x in substring] + # Old longest common substring is replaced with actual longest common list of numbers + # substring = _old_longest_common_substring([str(x + [0]) for x in all_input_ids]) + # substring = substring.split(", ")[:-1] + # substring = [int(x) for x in substring if x.isdigit()] + substring = _longest_common_sublist([x + [0] for x in all_input_ids]) + # Also get rest of tokenized string original = tokenizer(component, add_special_tokens = False).input_ids # Get optional left and right