
Commit 4115317

Gil Keren authored and facebook-github-bot committed
adding a buffered iterator
Summary: Torch's DataLoader keeps a buffer of only 2 ready batches, and this cannot be changed. That becomes a data loading bottleneck when data preparation time fluctuates. This commit adds BufferedIterator, a generic wrapper around an iterator that implements a buffer using a queue filled by a background thread. Support is added in FairseqTask, and by default in BatchSamplerIterator.

Reviewed By: myleott
Differential Revision: D21261026
fbshipit-source-id: 23d4bc6181fe1f9a7ee7ad7d18491594725c0f53
1 parent dd518ef commit 4115317
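As a rough illustration of the idea (not part of the commit), the new wrapper can buffer any iterable: a background thread fills a bounded queue while the consumer drains it, so occasional slow batches no longer stall the consumer immediately. In the sketch below, slow_batches is a made-up stand-in for a data pipeline whose preparation time fluctuates; only BufferedIterator comes from this commit.

import time

from fairseq.data.iterators import BufferedIterator

def slow_batches(n):
    # Hypothetical pipeline whose per-item preparation time fluctuates.
    for i in range(n):
        time.sleep(0.3 if i % 4 == 0 else 0.05)
        yield i

# Keep up to 8 ready items in a background queue, instead of the fixed
# lookahead of 2 that torch's DataLoader provides.
for batch in BufferedIterator(8, slow_batches(20)):
    time.sleep(0.1)  # stand-in for the training step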

4 files changed: +88 −11 lines (fairseq/data/iterators.py, fairseq/options.py, fairseq/tasks/fairseq_task.py, fairseq/trainer.py)

fairseq/data/iterators.py

+82 −11

@@ -6,12 +6,21 @@
 import itertools
 import math
 import os
-
+import time
 import numpy as np
 import torch
-
+import queue
+import logging
+from threading import Thread
 from . import data_utils
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# Object used by _background_consumer to signal the source is exhausted
+# to the main thread.
+_sentinel = object()
+
 
 class CountingIterator(object):
     """Wrapper around an iterable that maintains the iteration count.

@@ -178,11 +187,14 @@ class EpochBatchIterator(EpochBatchIterating):
             (default: 0).
         epoch (int, optional): the epoch to start the iterator from
             (default: 1).
+        buffer_size (int, optional): the number of batches to keep ready in the
+            queue. Helps speeding up dataloading. When buffer_size is zero, the
+            default torch.utils.data.DataLoader preloading is used.
     """
 
     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0, epoch=1,
+        num_workers=0, epoch=1, buffer_size=0
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset

@@ -192,6 +204,7 @@ def __init__(
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.num_workers = num_workers
+        self.buffer_size = buffer_size
 
         self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self.shuffle = True

@@ -307,16 +320,22 @@ def shuffle_batches(batches, seed):
         if self.num_workers > 0:
             os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'
 
-        return CountingIterator(
-            torch.utils.data.DataLoader(
-                self.dataset,
-                collate_fn=self.collate_fn,
-                batch_sampler=batches[offset:],
-                num_workers=self.num_workers,
-            ),
-            start=offset,
+        # Create data loader
+        itr = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_sampler=batches[offset:],
+            num_workers=self.num_workers,
         )
 
+        # Wrap with a BufferedIterator if needed
+        if self.buffer_size > 0:
+            itr = BufferedIterator(self.buffer_size, itr)
+
+        # Wrap with CountingIterator
+        itr = CountingIterator(itr, start=offset)
+        return itr
+
 
 class GroupedIterator(object):
     """Wrapper around an iterable that returns groups (chunks) of items.

@@ -382,3 +401,55 @@ def __iter__(self):
 
     def __next__(self):
         return next(self.itr)[1]
+
+
+class BackgroundConsumer(Thread):
+    def __init__(self, queue, source):
+        Thread.__init__(self)
+
+        self._queue = queue
+        self._source = source
+
+    def run(self):
+        for item in self._source:
+            self._queue.put(item)
+
+        # Signal the consumer we are done.
+        self._queue.put(_sentinel)
+
+
+class BufferedIterator(object):
+    def __init__(self, size, iterable):
+        self._queue = queue.Queue(size)
+        self._iterable = iterable
+
+        self._consumer = BackgroundConsumer(self._queue, iterable)
+        self._consumer.daemon = True
+        self._consumer.start()
+
+        self.start_time = time.time()
+        self.warning_time = None
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return len(self._iterable)
+
+    def __next__(self):
+        # Notify the user if there is a data loading bottleneck
+        if self._queue.qsize() < 2:
+            if time.time() - self.start_time > 5 * 60:
+                if self.warning_time is None or time.time() - self.warning_time > 15 * 60:
+                    logger.info(
+                        "Data loading buffer is empty or nearly empty. This may "
+                        "indicate a data loading bottleneck, and increasing the "
+                        "number of workers may help."
+                    )
+                    self.warning_time = time.time()
+
+        # Get next example
+        item = self._queue.get(True)
+        if item is _sentinel:
+            raise StopIteration()
+        return item
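With this change, EpochBatchIterator builds each epoch's iterator in three layers: the torch DataLoader, an optional BufferedIterator when buffer_size > 0, and a CountingIterator on the outside. Below is a minimal usage sketch of the new buffer_size argument; the toy dataset and batch sampler are made up for illustration and assume fairseq at or after this commit is installed.

import torch

from fairseq.data.iterators import EpochBatchIterator

class ToyDataset(torch.utils.data.Dataset):
    # Made-up dataset, purely for illustration.
    def __len__(self):
        return 16

    def __getitem__(self, index):
        return index

dataset = ToyDataset()
batch_sampler = [[i, i + 1] for i in range(0, len(dataset), 2)]  # batches of 2 indices

epoch_itr = EpochBatchIterator(
    dataset=dataset,
    collate_fn=lambda items: torch.tensor(items),
    batch_sampler=batch_sampler,
    num_workers=0,
    buffer_size=8,  # wrap the DataLoader in a BufferedIterator holding up to 8 batches
)
for batch in epoch_itr.next_epoch_itr(shuffle=False):
    print(batch)  # tensor([0, 1]), tensor([2, 3]), ...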
fairseq/options.py

+2 −0

@@ -317,6 +317,8 @@ def add_dataset_args(parser, train=False, gen=False):
     parser.add_argument('--dataset-impl', metavar='FORMAT',
                         choices=get_available_dataset_impl(),
                         help='output dataset implementation')
+    group.add_argument('--data-buffer-size', default=0, type=int, metavar='N',
+                       help='Number of batches to preload')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            help='data subset to use for training (e.g. train, valid, test)')
fairseq/tasks/fairseq_task.py

+2 −0

@@ -117,6 +117,7 @@ def get_batch_iterator(
         shard_id=0,
         num_workers=0,
         epoch=1,
+        buffer_size=0
     ):
         """
         Get an iterator that yields batches of data from the given dataset.

@@ -191,6 +192,7 @@ def get_batch_iterator(
             shard_id=shard_id,
             num_workers=num_workers,
             epoch=epoch,
+            buffer_size=buffer_size,
         )
         self.dataset_to_epoch_iter[dataset] = epoch_iter
         return epoch_iter
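FairseqTask.get_batch_iterator now simply forwards the new argument to EpochBatchIterator (and still caches the result in dataset_to_epoch_iter). A hedged sketch of the call site follows; task and dataset are placeholders for an already-constructed FairseqTask and one of its loaded splits, not defined here.

# `task` and `dataset` are hypothetical placeholders for an existing
# FairseqTask instance and a loaded dataset split.
epoch_itr = task.get_batch_iterator(
    dataset=dataset,
    max_tokens=4096,
    num_workers=4,
    buffer_size=10,  # forwarded to EpochBatchIterator
)
itr = epoch_itr.next_epoch_itr(shuffle=True)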

fairseq/trainer.py

+2 −0

@@ -294,6 +294,7 @@ def get_train_iterator(
             shard_id=self.data_parallel_rank if shard_batch_itr else 0,
             num_workers=self.args.num_workers,
             epoch=epoch,
+            buffer_size=self.args.data_buffer_size,
         )
 
     def get_valid_iterator(

@@ -315,6 +316,7 @@ def get_valid_iterator(
             num_shards=self.data_parallel_world_size,
             shard_id=self.data_parallel_rank,
             num_workers=self.args.num_workers,
+            buffer_size=self.args.data_buffer_size,
         )
 
     def begin_epoch(self, epoch):
