diff --git a/pipeline/common/datasets.py b/pipeline/common/datasets.py
index 54c30d9f5..24594fa98 100644
--- a/pipeline/common/datasets.py
+++ b/pipeline/common/datasets.py
@@ -3,7 +3,7 @@
 from collections import deque
 from io import TextIOWrapper
 from random import Random
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional
 
 
 class Dataset:
@@ -56,7 +56,7 @@ def shuffle_with_max_lines(
     max_lines: int,
     max_words_in_sentence,
     total_byte_size: int,
-) -> Iterator[str]:
+) -> Iterable[str]:
     """
     Shuffle a line stream, but only retain up to a maximum number of lines in memory.
     Note that the final ordering is determined by the seed and the contents of the file. So
@@ -90,14 +90,21 @@ def shuffle_with_max_lines(
         if len(lines) == max_lines:
             break
 
-    random.shuffle(lines)
+    # random.shuffle requires random access via indexing
+    # deque supports fast adding/removing from its ends with O(1)
+    # but indexing is O(N) which is too slow for shuffling large arrays
+    lines_list = list(lines)
+    lines = None
+    random.shuffle(lines_list)
 
     # Consume the rest of the line stream, but sample based on the probability that adding
     # something to the collection will be representative.
-
     i = 0
     for line in line_stream:
         i = i + 1
+        if lines is None:
+            lines = deque(lines_list)
+            lines_list = None
         # Continuously adjust this estimation in case the first sampled data is not representative.
         total_bytes = total_bytes + len(line.encode("utf-8"))
         average_bytes_per_line = total_bytes / (max_lines + i)
@@ -109,11 +116,14 @@
             lines.popleft()
             lines.append(line)
 
-    # Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
-    # set of shuffled lines.
-    random.shuffle(lines)
+    if i != 0:
+        # Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
+        # set of shuffled lines.
+        lines_list = list(lines)
+        del lines
+        random.shuffle(lines_list)
 
-    return lines
+    return lines_list
 
 
 def shuffle_in_temp_files(
diff --git a/tests/test_data_importer.py b/tests/test_data_importer.py
index 771fa7e8b..ecf1d3da0 100644
--- a/tests/test_data_importer.py
+++ b/tests/test_data_importer.py
@@ -127,10 +127,10 @@ def make_url_dataset(lang: str):
 
 
 mono_params = [
-    ("news-crawl", "en", "news_2021", [0, 1, 4, 6, 3, 7, 5, 2]),
-    ("news-crawl", "ru", "news_2021", [0, 1, 4, 6, 3, 7, 5, 2]),
-    ("url", "en", make_url_dataset("en"), [2, 1, 5, 4, 0, 7, 6, 3]),
-    ("url", "ru", make_url_dataset("ru"), [5, 4, 2, 0, 7, 1, 3, 6]),
+    ("news-crawl", "en", "news_2021", [2, 5, 3, 7, 0, 6, 4, 1]),
+    ("news-crawl", "ru", "news_2021", [2, 5, 3, 7, 0, 6, 4, 1]),
+    ("url", "en", make_url_dataset("en"), [3, 4, 5, 0, 1, 6, 2, 7]),
+    ("url", "ru", make_url_dataset("ru"), [5, 6, 2, 4, 7, 1, 3, 0]),
 ]
 
 
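A note on why the patch converts the `deque` to a `list` before shuffling: `random.shuffle` swaps elements at random indices, and while list indexing is O(1), `deque` indexing is O(n), so shuffling a deque in place degrades to O(n²) overall. A minimal timing sketch of the difference (this harness and its sizes are illustrative only, not part of the patch):

```python
import random
import time
from collections import deque

n = 10_000  # arbitrary; the gap widens quadratically as n grows

data_list = list(range(n))
data_deque = deque(range(n))

start = time.perf_counter()
random.shuffle(data_list)  # O(n): list index access is O(1)
print(f"list:  {time.perf_counter() - start:.4f}s")

start = time.perf_counter()
random.shuffle(data_deque)  # O(n^2): each deque index access is O(n)
print(f"deque: {time.perf_counter() - start:.4f}s")
```

The patch therefore keeps the `deque` only for the streaming phase, where it needs just `popleft`/`append` at the ends (both O(1)), and converts to a `list` around the two shuffles.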
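For review, the overall algorithm may be easier to follow outside the diff. Below is a condensed standalone sketch of the same strategy under one simplification: the real `shuffle_with_max_lines` estimates the keep probability from average bytes per line, while this sketch uses a plain line count. The function and variable names here are mine, not from `pipeline/common/datasets.py`:

```python
import random
from collections import deque
from typing import Iterable, Iterator


def sample_and_shuffle(line_stream: Iterator[str], max_lines: int, seed: str) -> Iterable[str]:
    """Shuffle an arbitrarily long line stream while holding at most max_lines lines."""
    rng = random.Random(seed)

    # Phase 1: fill the reservoir from the head of the stream.
    lines: deque[str] = deque()
    for line in line_stream:
        lines.append(line)
        if len(lines) == max_lines:
            break

    # Shuffle as a list: random.shuffle indexes repeatedly, which is O(1)
    # per access on a list but O(n) on a deque.
    shuffled = list(lines)
    rng.shuffle(shuffled)

    # Phase 2: stream the rest, replacing the oldest retained line with
    # probability max_lines / lines_seen. The FIFO eviction makes this an
    # approximation of uniform reservoir sampling, not the exact algorithm.
    # A deque fits here: this phase only needs O(1) popleft/append.
    reservoir = deque(shuffled)
    extra = 0
    for line in line_stream:
        extra += 1
        if rng.random() < max_lines / (max_lines + extra):
            reservoir.popleft()
            reservoir.append(line)

    # As in the patch, skip the final shuffle when the stream was already
    # exhausted in phase 1 and nothing new was sampled.
    if extra == 0:
        return shuffled
    final = list(reservoir)
    rng.shuffle(final)
    return final
```

The deque→list→deque conversions mirror the patch's intent: pay O(n) once at each phase boundary instead of O(n²) inside a shuffle.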