import itertools
import math
import os
-
+ import time
import numpy as np
import torch
-
+ import queue
+ import logging
+ from threading import Thread
from . import data_utils

+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ # Object used by BackgroundConsumer to signal the source is exhausted
+ # to the main thread.
+ _sentinel = object()
+

class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.
@@ -178,11 +187,14 @@ class EpochBatchIterator(EpochBatchIterating):
            (default: 0).
        epoch (int, optional): the epoch to start the iterator from
            (default: 1).
+         buffer_size (int, optional): the number of batches to keep ready in the
+             queue. Helps speed up data loading. When buffer_size is zero, the
+             default torch.utils.data.DataLoader prefetching is used.
    """

    def __init__(
        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-         num_workers=0, epoch=1,
+         num_workers=0, epoch=1, buffer_size=0
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
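
To make the new `buffer_size` argument documented above concrete, here is a minimal usage sketch. It is not part of this patch: the toy dataset, the precomputed batches, and the buffer size are invented for illustration, and `next_epoch_itr` is the usual fairseq entry point for obtaining the per-epoch iterator.

# Illustrative sketch only -- dataset, batches, and buffer_size are placeholders.
import torch
from torch.utils.data.dataloader import default_collate
from fairseq.data.iterators import EpochBatchIterator


class ToyDataset(torch.utils.data.Dataset):
    """Toy map-style dataset whose i-th item is the tensor [i]."""

    def __len__(self):
        return 6

    def __getitem__(self, index):
        return torch.tensor([index])


epoch_itr = EpochBatchIterator(
    dataset=ToyDataset(),
    collate_fn=default_collate,
    batch_sampler=[[0, 1], [2, 3], [4, 5]],  # precomputed batches of sample indices
    num_workers=0,
    buffer_size=2,  # keep up to 2 collated batches ready in a background queue
)
for batch in epoch_itr.next_epoch_itr(shuffle=False):
    print(batch)  # a 2x1 tensor per batch
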
@@ -192,6 +204,7 @@ def __init__(
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.num_workers = num_workers
+         self.buffer_size = buffer_size

        self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
        self.shuffle = True
@@ -307,16 +320,22 @@ def shuffle_batches(batches, seed):
        if self.num_workers > 0:
            os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'

-         return CountingIterator(
-             torch.utils.data.DataLoader(
-                 self.dataset,
-                 collate_fn=self.collate_fn,
-                 batch_sampler=batches[offset:],
-                 num_workers=self.num_workers,
-             ),
-             start=offset,
+         # Create the data loader
+         itr = torch.utils.data.DataLoader(
+             self.dataset,
+             collate_fn=self.collate_fn,
+             batch_sampler=batches[offset:],
+             num_workers=self.num_workers,
        )

+         # Wrap with a BufferedIterator if needed
+         if self.buffer_size > 0:
+             itr = BufferedIterator(self.buffer_size, itr)
+
+         # Wrap with CountingIterator
+         itr = CountingIterator(itr, start=offset)
+         return itr
+

class GroupedIterator(object):
    """Wrapper around an iterable that returns groups (chunks) of items.
@@ -382,3 +401,55 @@ def __iter__(self):

    def __next__(self):
        return next(self.itr)[1]
+
+
+ class BackgroundConsumer(Thread):
+     def __init__(self, queue, source):
+         Thread.__init__(self)
+
+         self._queue = queue
+         self._source = source
+
+     def run(self):
+         for item in self._source:
+             self._queue.put(item)
+
+         # Signal the consumer we are done.
+         self._queue.put(_sentinel)
+
+
+ class BufferedIterator(object):
+     def __init__(self, size, iterable):
+         self._queue = queue.Queue(size)
+         self._iterable = iterable
+
+         self._consumer = BackgroundConsumer(self._queue, iterable)
+         self._consumer.daemon = True
+         self._consumer.start()
+
+         self.start_time = time.time()
+         self.warning_time = None
+
+     def __iter__(self):
+         return self
+
+     def __len__(self):
+         return len(self._iterable)
+
+     def __next__(self):
+         # Notify the user if there is a data loading bottleneck
+         if self._queue.qsize() < 2:
+             if time.time() - self.start_time > 5 * 60:
+                 if self.warning_time is None or time.time() - self.warning_time > 15 * 60:
+                     logger.info(
+                         "Data loading buffer is empty or nearly empty. This may "
+                         "indicate a data loading bottleneck, and increasing the "
+                         "number of workers may help."
+                     )
+                     self.warning_time = time.time()
+
+         # Get next example
+         item = self._queue.get(True)
+         if item is _sentinel:
+             raise StopIteration()
+         return item
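
Finally, a self-contained sketch of the two new classes in isolation (the slow generator below is invented to stand in for an expensive data source): the `BackgroundConsumer` daemon thread eagerly pulls items into the bounded queue, and the main thread's `__next__` blocks on the queue until an item or the sentinel arrives.

# Illustrative only: a slow generator stands in for a costly DataLoader.
import time
from fairseq.data.iterators import BufferedIterator


def slow_source(n, delay=0.05):
    for i in range(n):
        time.sleep(delay)  # simulate per-item loading cost
        yield i


buffered = BufferedIterator(3, slow_source(10))
time.sleep(0.3)  # by now the background thread has filled the 3-slot queue

for item in buffered:
    print(item)  # prints 0..9; the sentinel then raises StopIteration internally

# Note: len(buffered) would fail here, since a generator has no length;
# in the trainer the wrapped object is a DataLoader, which does.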