Skip to content

Commit

Permalink
Improve StreamingDataset Speed (#19114)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas <[email protected]>
(cherry picked from commit 4d15468)
  • Loading branch information
tchaton authored and Borda committed Dec 19, 2023
1 parent 64e2a76 commit 2cb8f45
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 88 deletions.
126 changes: 60 additions & 66 deletions src/lightning/data/streaming/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
import os
import shutil
import warnings
from threading import Lock, Thread
from queue import Empty
from threading import Thread
from time import sleep
from typing import Any, Dict, List, Optional, Tuple

Expand All @@ -34,35 +36,25 @@
class PrepareChunksThread(Thread):
"""This thread is responsible to download the chunks associated to a given worker."""

def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None, pre_download: int = 10) -> None:
def __init__(self, config: ChunksConfig, max_cache_size: Optional[int] = None) -> None:
super().__init__(daemon=True)
self._config = config
self._chunks_index_to_be_downloaded: List[int] = []
self._chunks_index_to_be_deleted: List[int] = []
self._lock = Lock()
self._max_cache_size = max_cache_size
self._downloaded_chunks = 0
self._processed_chunks = 0
self._processed_chunks_counter = 0
self._delete_chunks = 0
self._pre_download = pre_download
self._should_stop = False

def download(self, chunk_indices: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
for chunk_indice in chunk_indices:
if chunk_indice not in self._chunks_index_to_be_downloaded:
self._chunks_index_to_be_downloaded.append(chunk_indice)
self._to_download_queue: multiprocessing.Queue = multiprocessing.Queue()
self._to_delete_queue: multiprocessing.Queue = multiprocessing.Queue()
self._to_stop_queue: multiprocessing.Queue = multiprocessing.Queue()

def delete(self, chunk_indices: List[int]) -> None:
def download(self, chunk_indexes: List[int]) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
for chunk_indice in chunk_indices:
if chunk_indice not in self._chunks_index_to_be_deleted:
self._chunks_index_to_be_deleted.append(chunk_indice)
self._processed_chunks += 1
self._processed_chunks_counter += 1
for chunk_index in chunk_indexes:
self._to_download_queue.put(chunk_index)

def delete(self, chunk_indexes: List[int]) -> None:
"""Receive the list of the chunk indices to delete for the current epoch."""
for chunk_index in chunk_indexes:
self._to_delete_queue.put(chunk_index)

def _delete(self, chunk_index: int) -> None:
chunk_filepath, begin, _ = self._config[ChunkedIndex(index=-1, chunk_index=chunk_index)]
Expand All @@ -72,54 +64,56 @@ def _delete(self, chunk_index: int) -> None:

def stop(self) -> None:
"""Receive the list of the chunk indices to download for the current epoch."""
with self._lock:
self._should_stop = True
self._to_stop_queue.put(None)

def run(self) -> None:
while True:
with self._lock:
if self._should_stop:
if (
self._max_cache_size
and self._max_cache_size <= shutil.disk_usage(self._config._cache_dir).total
):
for chunk_index in self._chunks_index_to_be_deleted:
if chunk_index not in self._chunks_index_to_be_downloaded:
self._delete(chunk_index)
self._delete_chunks += 1
self._processed_chunks_counter = 0
return

# Wait for something to do
if len(self._chunks_index_to_be_downloaded) == 0 and len(self._chunks_index_to_be_deleted) == 0:
continue

# Delete the chunks if we are missing disk space.
if self._max_cache_size and self._processed_chunks_counter >= self._pre_download:
try:
chunk_index = self._to_download_queue.get(timeout=0.01)
self._config.download_chunk_from_index(chunk_index)
except Empty:
pass
except OSError as e:
# handle closed queue before the thread terminates
if "handle is closed" in str(e):
pass
else:
raise e

try:
chunk_index = self._to_delete_queue.get(timeout=0.01)
if self._max_cache_size:
if shutil.disk_usage(self._config._cache_dir).total >= self._max_cache_size:
for chunk_index in self._chunks_index_to_be_deleted:
if chunk_index not in self._chunks_index_to_be_downloaded:
self._delete(chunk_index)
self._delete_chunks += 1
self._processed_chunks_counter = 0
self._chunks_index_to_be_deleted = []

# If there is no chunks to download, go back to waiting
if len(self._chunks_index_to_be_downloaded) == 0:
continue

# If we have already downloaded too many chunks, let's wait for processed chunks to catch up
if self._max_cache_size and (self._downloaded_chunks - self._processed_chunks) > self._pre_download:
sleep(0.1)
continue

chunk_index = self._chunks_index_to_be_downloaded.pop(0)

self._config.download_chunk_from_index(chunk_index)
self._downloaded_chunks += 1
self._chunks_index_to_be_deleted.append(chunk_index)

# Delete 2 chunk at the time to give enough space while not blocking downloads
for chunk_index in self._chunks_index_to_be_deleted[:2]:
self._delete(chunk_index)

self._chunks_index_to_be_deleted = self._chunks_index_to_be_deleted[2:]
else:
self._chunks_index_to_be_deleted.append(chunk_index)
except Empty:
pass
except OSError as e:
# handle closed queue before the thread terminates
if "handle is closed" in str(e):
pass
else:
raise e

try:
self._to_stop_queue.get(timeout=0.01)
return
except Empty:
pass
except OSError as e:
# handle closed queue before the thread terminates
if "handle is closed" in str(e):
return
raise e

# Sleep to release the lock
sleep(0.1)
sleep(0.01)


class BinaryReader:
Expand Down
44 changes: 22 additions & 22 deletions tests/tests_data/streaming/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,38 @@ def test_reader_chunk_removal(tmpdir, monkeypatch):
shutil_mock.disk_usage.return_value = disk_usage
monkeypatch.setattr(reader, "shutil", shutil_mock)

expected = []
generated = []
for i in range(25):
expected.append([i, len(os.listdir(cache_dir))])
generated.append([i, len(os.listdir(cache_dir))])
index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
assert cache[index] == i

assert expected == [
assert generated == [
[0, 0],
[1, 1],
[2, 1],
[3, 2],
[4, 2],
[5, 3],
[6, 3],
[7, 4],
[8, 4],
[9, 5],
[10, 5],
[11, 6],
[12, 6],
[13, 7],
[14, 7],
[15, 8],
[16, 8],
[17, 9],
[18, 9],
[19, 10],
[20, 10],
[5, 2],
[6, 2],
[7, 2],
[8, 2],
[9, 2],
[10, 2],
[11, 2],
[12, 2],
[13, 2],
[14, 2],
[15, 2],
[16, 2],
[17, 2],
[18, 2],
[19, 2],
[20, 2],
[21, 2],
[22, 2],
[23, 3],
[24, 3],
[23, 2],
[24, 2],
]

assert len(os.listdir(cache_dir)) in [3, 4]
assert len(os.listdir(cache_dir)) == 2

0 comments on commit 2cb8f45

Please sign in to comment.