18
18
from logging import Logger
19
19
from queue import Empty
20
20
from threading import Thread
21
- from time import sleep
22
21
from typing import Any , Dict , List , Optional , Tuple , Union
23
22
24
23
from lightning .data .streaming .config import ChunksConfig
30
29
31
30
warnings .filterwarnings ("ignore" , message = ".*The given buffer is not writable.*" )
32
31
33
-
34
32
if _TORCH_GREATER_EQUAL_2_1_0 :
35
33
pass
36
34
37
35
38
36
logger = Logger (__name__ )
39
37
40
38
39
+ _END_TOKEN = "END"
40
+
41
+ # Note: The timeout here should not be too short. We need to prevent the caller from aggressively
42
+ # querying the queue and consuming too many CPU cycles.
43
+ _DEFAULT_TIMEOUT = 0.1
44
+ _LONG_DEFAULT_TIMEOUT = 5
45
+
46
+
41
47
class PrepareChunksThread (Thread ):
42
48
"""This thread is responsible for downloading the chunks associated with a given worker."""
43
49
@@ -59,22 +65,7 @@ def __init__(
59
65
self ._parent_cache_dir = os .path .dirname (self ._config ._cache_dir )
60
66
self ._to_download_queue : multiprocessing .Queue = multiprocessing .Queue ()
61
67
self ._to_delete_queue : multiprocessing .Queue = multiprocessing .Queue ()
62
- self ._to_stop_queue : multiprocessing .Queue = multiprocessing .Queue ()
63
-
64
- # populate back the queues with existing items. As they already exists, this is almost a no-op
65
- for chunk_index in self ._collect_ordered_chunk_indexes_from_cache ():
66
- self ._to_download_queue .put (chunk_index )
67
- self ._to_delete_queue .put (chunk_index )
68
-
69
- def _collect_ordered_chunk_indexes_from_cache (self ) -> List [int ]:
70
- """List the chunks available in the cache, order them based on their creation time and retrieves their
71
- indexes."""
72
- chunk_indexes = [
73
- [self ._config ._get_chunk_index_from_filename (f ), os .path .getctime (os .path .join (self ._config ._cache_dir , f ))]
74
- for f in os .listdir (self ._config ._cache_dir )
75
- if f .endswith (".bin" )
76
- ]
77
- return [int (x [0 ]) for x in sorted (chunk_indexes , key = lambda x : x [1 ])]
68
+ self ._delete_chunks_when_processed = self ._config .num_bytes > max_cache_size if max_cache_size else False
78
69
79
70
def download (self , chunk_indexes : List [int ]) -> None :
80
71
"""Receive the list of the chunk indices to download for the current epoch."""
@@ -93,10 +84,15 @@ def _delete(self, chunk_index: int) -> None:
93
84
94
85
def stop (self ) -> None :
95
86
"""Ask the thread's run loop to exit."""
96
- self ._to_stop_queue .put (True )
87
# The end token travels through the download queue; `run` returns as soon as it reads it from there.
+ self ._to_download_queue .put (_END_TOKEN )
97
88
98
89
def _maybe_delete_chunks (self ) -> None :
99
- chunk_index = _get_from_queue (self ._to_delete_queue )
90
+ reached_pre_download = self ._pre_download_counter == self ._max_pre_download
91
+
92
+ # we have already pre-downloaded some chunks, we just need to wait for them to be processed.
93
+ chunk_index = _get_from_queue (
94
+ self ._to_delete_queue , timeout = _LONG_DEFAULT_TIMEOUT if reached_pre_download else _DEFAULT_TIMEOUT
95
+ )
100
96
101
97
if chunk_index is not None :
102
98
self ._pre_download_counter -= 1
@@ -105,14 +101,17 @@ def _maybe_delete_chunks(self) -> None:
105
101
self ._chunks_index_to_be_deleted .append (chunk_index )
106
102
107
103
# Get the current cache size and decide whether we need to start cleanup. Otherwise, keep track of it
108
- while (
109
- self ._max_cache_size
110
- and self ._chunks_index_to_be_deleted
111
- and _get_folder_size (self ._parent_cache_dir ) >= self ._max_cache_size
112
- ):
104
+ while self ._max_cache_size and self ._chunks_index_to_be_deleted and self ._can_delete_chunk ():
113
105
# Delete the oldest chunk
114
106
self ._delete (self ._chunks_index_to_be_deleted .pop (0 ))
115
107
108
+ return
109
+
110
+ def _can_delete_chunk (self ) -> bool :
# Decide whether the oldest chunk queued for deletion may be removed now.
111
# When `_delete_chunks_when_processed` is set, deletion is allowed only while the pre-download
# counter sits exactly one below `_max_pre_download`.
+ if self ._delete_chunks_when_processed :
112
+ return self ._pre_download_counter == self ._max_pre_download - 1
113
# Otherwise, deletion is allowed once a max cache size is configured and the size of the
# parent cache folder has reached it.
+ return self ._max_cache_size is not None and _get_folder_size (self ._parent_cache_dir ) >= self ._max_cache_size
114
+
116
115
def _pre_load_chunk (self , chunk_index : int ) -> None :
# Look up the file path of chunk `chunk_index` and hand it to the item loader's `pre_load_chunk`.
117
116
# Only the first element of the config lookup (the chunk file path) is used here.
# NOTE(review): index=-1 is presumably a placeholder because no specific item is requested - confirm.
chunk_filepath , _ , _ = self ._config [ChunkedIndex (index = - 1 , chunk_index = chunk_index )]
118
117
self ._item_loader .pre_load_chunk (chunk_index , chunk_filepath )
@@ -121,6 +120,9 @@ def run(self) -> None:
121
120
while True :
122
121
if self ._pre_download_counter <= self ._max_pre_download :
123
122
chunk_index = _get_from_queue (self ._to_download_queue )
123
+ if chunk_index == _END_TOKEN :
124
+ return
125
+
124
126
if chunk_index is not None :
125
127
self ._config .download_chunk_from_index (chunk_index )
126
128
@@ -135,11 +137,6 @@ def run(self) -> None:
135
137
if self ._max_cache_size :
136
138
self ._maybe_delete_chunks ()
137
139
138
- if _get_from_queue (self ._to_stop_queue ):
139
- return
140
-
141
- sleep (0.05 )
142
-
143
140
144
141
class BinaryReader :
145
142
def __init__ (
@@ -238,6 +235,9 @@ def read(self, index: ChunkedIndex) -> Any:
238
235
assert self ._prepare_thread
239
236
self ._prepare_thread .download ([index .chunk_index ])
240
237
238
+ if self ._last_chunk_index is None :
239
+ self ._last_chunk_index = index .chunk_index
240
+
241
241
# Fetch the element
242
242
chunk_filepath , begin , _ = self .config [index ]
243
243
item = self ._item_loader .load_item_from_chunk (index .index , index .chunk_index , chunk_filepath , begin )
@@ -246,9 +246,10 @@ def read(self, index: ChunkedIndex) -> Any:
246
246
# Otherwise, this could trigger segmentation fault error depending on the item loader used.
247
247
if self ._config and self ._config ._remote_dir and index .chunk_index != self ._last_chunk_index :
248
248
assert self ._prepare_thread
249
- if self ._last_chunk_index is not None :
250
- # inform the chunk has been completely consumed
251
- self ._prepare_thread .delete ([self ._last_chunk_index ])
249
+ assert self ._last_chunk_index is not None
250
+
251
+ # inform the chunk has been completely consumed
252
+ self ._prepare_thread .delete ([self ._last_chunk_index ])
252
253
253
254
# track the new chunk index as the latest one
254
255
self ._last_chunk_index = index .chunk_index
@@ -294,11 +295,9 @@ def _get_folder_size(path: str) -> int:
294
295
return size
295
296
296
297
297
- def _get_from_queue (queue : multiprocessing .Queue ) -> Optional [Any ]:
298
+ def _get_from_queue (queue : multiprocessing .Queue , timeout : float = _DEFAULT_TIMEOUT ) -> Optional [Any ]:
298
299
try :
299
- # Note: The timeout here should not be too short. We need to prevent the caller from aggressively
300
- # querying the queue and consuming too many CPU cycles.
301
- return queue .get (timeout = 0.1 )
300
+ return queue .get (timeout = timeout )
302
301
except Empty :
303
302
pass
304
303
except OSError as e :
0 commit comments