From 2cb8f45af83ae625e1181fb56566fba59cb3fdff Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Tue, 5 Dec 2023 19:50:27 +0000
Subject: [PATCH] Improve StreamingDataset Speed (#19114)

Co-authored-by: thomas
(cherry picked from commit 4d154685557881b4ff47083267b5a6328b465d61)
---
 src/lightning/data/streaming/reader.py    | 126 +++++++++++-----------
 tests/tests_data/streaming/test_reader.py |  44 ++++----
 2 files changed, 82 insertions(+), 88 deletions(-)

diff --git a/src/lightning/data/streaming/reader.py b/src/lightning/data/streaming/reader.py
index 1f9d6d7ffc1d8..d97f577869578 100644
--- a/src/lightning/data/streaming/reader.py
+++ b/src/lightning/data/streaming/reader.py
@@ -11,10 +11,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import multiprocessing
 import os
 import shutil
 import warnings
-from threading import Lock, Thread
+from queue import Empty
+from threading import Thread
 from time import sleep
 from typing import Any, Dict, List, Optional, Tuple
 
@@ -34,35 +36,25 @@
 class PrepareChunksThread(Thread):
     """This thread is responsible for downloading the chunks associated with a given worker."""
 
-    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None:
+    def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
         super().__init__(daemon=True)
         self._config = config
         self._chunks_index_to_be_downloaded: List[int] = []
         self._chunks_index_to_be_deleted: List[int] = []
-        self._lock = Lock()
         self._max_cache_size = max_cache_size
-        self._downloaded_chunks = 0
-        self._processed_chunks = 0
-        self._processed_chunks_counter = 0
-        self._delete_chunks = 0
-        self._pre_download = pre_download
-        self._should_stop = False
-
-    def download(self, chunk_indices: List[int]) -> None:
-        """Receive the list of the chunk indices to download for the current epoch."""
-        with self._lock:
-            for chunk_indice in chunk_indices:
-                if chunk_indice not in self._chunks_index_to_be_downloaded:
-                    self._chunks_index_to_be_downloaded.append(chunk_indice)
+        self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
+        self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()
 
-    def delete(self, chunk_indices: List[int]) -> None:
+    def download(self, chunk_indexes: List[int]) -> None:
         """Receive the list of the chunk indices to download for the current epoch."""
-        with self._lock:
-            for chunk_indice in chunk_indices:
-                if chunk_indice not in self._chunks_index_to_be_deleted:
-                    self._chunks_index_to_be_deleted.append(chunk_indice)
-                    self._processed_chunks += 1
-                    self._processed_chunks_counter += 1
+        for chunk_index in chunk_indexes:
+            self._to_download_queue.put(chunk_index)
+
+    def delete(self, chunk_indexes: List[int]) -> None:
+        """Receive the list of the chunk indices to delete for the current epoch."""
+        for chunk_index in chunk_indexes:
+            self._to_delete_queue.put(chunk_index)
 
     def _delete(self, chunk_index: int) -> None:
         chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
@@ -72,54 +64,56 @@ def _delete(self, chunk_index: int) -> None:
 
     def stop(self) -> None:
         """Receive a signal to stop the thread."""
-        with self._lock:
-            self._should_stop = True
+        self._to_stop_queue.put(None)
 
     def run(self) -> None:
         while True:
-            with self._lock:
-                if self._should_stop:
-                    if (
-                        self._max_cache_size
-                        and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total
-                    ):
-                        for chunk_index in self._chunks_index_to_be_deleted:
-                            if chunk_index not in self._chunks_index_to_be_downloaded:
-                                self._delete(chunk_index)
-                                self._delete_chunks += 1
-                                self._processed_chunks_counter = 0
-                    return
-
-                # Wait for something to do
-                if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
-                    continue
-
-                # Delete the chunks if we are missing disk space.
-                if self._max_cache_size and self._processed_chunks_counter >= self._pre_download:
+            try:
+                chunk_index = self._to_download_queue.get(timeout=0.01)
+                self._config.download_chunk_from_index(chunk_index)
+            except Empty:
+                pass
+            except OSError as e:
+                # Handle a closed queue before the thread terminates.
+                if "handle is closed" in str(e):
+                    pass
+                else:
+                    raise e
+
+            try:
+                chunk_index = self._to_delete_queue.get(timeout=0.01)
+                if self._max_cache_size:
                     if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
-                        for chunk_index in self._chunks_index_to_be_deleted:
-                            if chunk_index not in self._chunks_index_to_be_downloaded:
-                                self._delete(chunk_index)
-                                self._delete_chunks += 1
-                                self._processed_chunks_counter = 0
-                        self._chunks_index_to_be_deleted = []
-
-                # If there is no chunks to download, go back to waiting
-                if len(self._chunks_index_to_be_downloaded) == 0:
-                    continue
-
-                # If we have already downloaded too many chunks, let's wait for processed chunks to catch up
-                if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
-                    sleep(0.1)
-                    continue
-
-                chunk_index = self._chunks_index_to_be_downloaded.pop(0)
-
-            self._config.download_chunk_from_index(chunk_index)
-            self._downloaded_chunks += 1
+                        self._chunks_index_to_be_deleted.append(chunk_index)
+
+                        # Delete two chunks at a time to free enough space without blocking downloads.
+                        for chunk_index in self._chunks_index_to_be_deleted[:2]:
+                            self._delete(chunk_index)
+
+                        self._chunks_index_to_be_deleted = self._chunks_index_to_be_deleted[2:]
+                    else:
+                        self._chunks_index_to_be_deleted.append(chunk_index)
+            except Empty:
+                pass
+            except OSError as e:
+                # Handle a closed queue before the thread terminates.
+                if "handle is closed" in str(e):
+                    pass
+                else:
+                    raise e
+
+            try:
+                self._to_stop_queue.get(timeout=0.01)
+                return
+            except Empty:
+                pass
+            except OSError as e:
+                # Handle a closed queue before the thread terminates.
+                if "handle is closed" in str(e):
+                    return
+                raise e
 
-            # Sleep to release the lock
-            sleep(0.1)
+            sleep(0.01)
 
 
 class BinaryReader:
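For orientation before the test changes: the patch replaces the lock-protected bookkeeping with three queues that run() polls in turn, so callers never block on the thread's internal state. Below is a minimal usage sketch of the new interface; it assumes an existing ChunksConfig instance named config and an illustrative 10 GB cache limit, neither of which comes from the patch itself:

    # Hypothetical driver for the queue-based thread introduced above.
    from lightning.data.streaming.reader import PrepareChunksThread

    thread = PrepareChunksThread(config, max_cache_size=10 * 1024**3)
    thread.start()

    # Enqueue chunk indexes to fetch; run() pops each one and calls
    # config.download_chunk_from_index(chunk_index) in the background.
    thread.download([0, 1, 2])

    # Hand back consumed chunks; they are only removed from disk once the
    # measured usage crosses max_cache_size, two chunks per pass.
    thread.delete([0])

    # Put a sentinel on the stop queue; run() returns on a subsequent poll.
    thread.stop()
    thread.join()

The test below exercises exactly this deletion policy: with a mocked disk_usage always above the cache limit, the cache directory settles at a steady size instead of growing with each chunk.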
diff --git a/tests/tests_data/streaming/test_reader.py b/tests/tests_data/streaming/test_reader.py
index 67f884924fb02..dde0928e2d15e 100644
--- a/tests/tests_data/streaming/test_reader.py
+++ b/tests/tests_data/streaming/test_reader.py
@@ -45,38 +45,38 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
     shutil_mock.disk_usage.return_value = disk_usage
     monkeypatch.setattr(reader, "shutil", shutil_mock)
 
-    expected = []
+    generated = []
     for i in range(25):
-        expected.append([i, len(os.listdir(cache_dir))])
+        generated.append([i, len(os.listdir(cache_dir))])
         index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
         assert cache[index] == i
 
-    assert expected == [
+    assert generated == [
         [0, 0],
         [1, 1],
         [2, 1],
         [3, 2],
         [4, 2],
-        [5, 3],
-        [6, 3],
-        [7, 4],
-        [8, 4],
-        [9, 5],
-        [10, 5],
-        [11, 6],
-        [12, 6],
-        [13, 7],
-        [14, 7],
-        [15, 8],
-        [16, 8],
-        [17, 9],
-        [18, 9],
-        [19, 10],
-        [20, 10],
+        [5, 2],
+        [6, 2],
+        [7, 2],
+        [8, 2],
+        [9, 2],
+        [10, 2],
+        [11, 2],
+        [12, 2],
+        [13, 2],
+        [14, 2],
+        [15, 2],
+        [16, 2],
+        [17, 2],
+        [18, 2],
+        [19, 2],
+        [20, 2],
         [21, 2],
         [22, 2],
-        [23, 3],
-        [24, 3],
+        [23, 2],
+        [24, 2],
     ]
 
-    assert len(os.listdir(cache_dir)) in [3, 4]
+    assert len(os.listdir(cache_dir)) == 2
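Why this is faster: the old run() loop held a Lock shared with download() and delete() and slept 0.1s per iteration, so every caller contended with the worker thread; the new loop owns its state and communicates only through queues, polled with short 0.01s timeouts. A self-contained sketch of that polling pattern follows. It uses queue.Queue for simplicity where the patch uses multiprocessing.Queue, and a print as a placeholder for the chunk download, so it illustrates the structure rather than the patch's exact code:

    import threading
    from queue import Empty, Queue


    def poll_loop(to_do: Queue, to_stop: Queue) -> None:
        # Service each queue with a short timeout so neither starves the
        # other, mirroring the structure of PrepareChunksThread.run().
        while True:
            try:
                item = to_do.get(timeout=0.01)
                print(f"processing chunk {item}")  # stand-in for the download call
            except Empty:
                pass
            try:
                to_stop.get(timeout=0.01)
                return  # a sentinel on the stop queue ends the loop
            except Empty:
                pass


    to_do: Queue = Queue()
    to_stop: Queue = Queue()
    worker = threading.Thread(target=poll_loop, args=(to_do, to_stop), daemon=True)
    worker.start()
    to_do.put(0)
    to_do.put(1)
    to_stop.put(None)  # as in the patch, stopping may preempt still-queued work
    worker.join()

The trade-off is that a stop sentinel can arrive while work is still queued; the patch accepts the same behavior, which is why stop() is only called once the epoch's chunks have been consumed.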