Concurrent dynamic bucketing (#1373)
* Concurrent reads in dynamic bucketing for faster start time.

* Don't exceed the buffer_size; eliminate some race conditions

* Add a missing flag

* Use a proper queue for concurrency

* Disable concurrency by default

* Add a test for the concurrent implementation
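
For reference, here is a minimal usage sketch of the new flag. The manifest path and the
surrounding loading code are illustrative assumptions, not part of this commit; only the
`concurrent=True` argument is new here.

from lhotse import CutSet
from lhotse.dataset.sampling.dynamic_bucketing import DynamicBucketingSampler

# Hypothetical manifest path, for illustration only.
cuts = CutSet.from_file("data/cuts_train.jsonl.gz")

# `concurrent=True` enables the background producer thread added in this
# commit: it pre-populates the duration buckets while batches are drawn,
# at the cost of a non-deterministic sampling order.
sampler = DynamicBucketingSampler(
    cuts,
    max_duration=600.0,
    num_buckets=30,
    concurrent=True,  # new flag; defaults to False
)

for batch in sampler:
    ...  # each `batch` is a CutSet ready to be collated by a dataset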
pzelasko authored Jul 22, 2024
1 parent fa8cbfe commit bd12d5d
Showing 2 changed files with 89 additions and 17 deletions.
99 changes: 84 additions & 15 deletions lhotse/dataset/sampling/dynamic_bucketing.py
@@ -1,9 +1,12 @@
import random
import threading
import time
import warnings
from bisect import bisect_right
from collections import deque
from dataclasses import asdict, dataclass
from itertools import islice
from queue import Queue
from typing import (
Any,
Callable,
@@ -95,6 +98,7 @@ def __init__(
rank: Optional[int] = None,
seed: Union[int, Literal["randomized", "trng"]] = 0,
sync_buckets: bool = True,
concurrent: bool = False,
strict=None,
shuffle_buffer_size=None,
) -> None:
@@ -131,6 +135,10 @@ def __init__(
when models have quadratic input complexity. Set between 15 and 40 for transformers.
:param sync_buckets: When set, we'll try to make each DDP rank sample from as close
duration buckets as possible to minimize the tail worker effect.
:param concurrent: Enabling concurrency eliminates most of the waiting needed to pre-populate
the bucketing buffers before the sampler starts yielding examples. For tarred/Lhotse Shar data
this can speed up the start of training. Note that enabling concurrency makes the sampling
results non-deterministic. This feature is experimental.
:param world_size: Total number of distributed nodes. We will try to infer it by default.
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
@@ -154,6 +162,7 @@ def __init__(
self.buffer_size = buffer_size
self.quadratic_duration = quadratic_duration
self.sync_buckets = sync_buckets
self.concurrent = concurrent
self.rng = None
check_constraint(constraint, max_duration, max_cuts)

@@ -282,6 +291,7 @@ def __iter__(self) -> "DynamicBucketingSampler":
shuffle=self.shuffle,
rng=self.rng,
bucket_rng=bucket_rng,
concurrent=self.concurrent,
diagnostics=self.diagnostics,
)
self.cuts_iter = iter(cuts_iter)
@@ -516,6 +526,7 @@ def __init__(
shuffle: bool = False,
rng: random.Random = None,
bucket_rng: random.Random = None,
concurrent: bool = False,
diagnostics: Optional[SamplingDiagnostics] = None,
) -> None:
self.cuts = cuts
@@ -533,6 +544,7 @@ def __init__(
self.rng = rng
self.bucket_rng = bucket_rng
self.shuffle = shuffle
self.concurrent = concurrent

assert duration_bins == sorted(duration_bins), (
f"Argument list for 'duration_bins' is expected to be in "
@@ -561,14 +573,19 @@ def __init__(
)

# Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`).
self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [
deque() for _ in range(len(duration_bins) + 1)
]
self.buckets: List[Queue] = [Queue() for _ in range(len(duration_bins) + 1)]

self._producer_thread = None

def __iter__(self) -> Generator[CutSet, None, None]:
# Init: sample `buffer_size` cuts and assign them to the right buckets.
self.cuts_iter = iter(self.cuts)
self._collect_cuts_in_buckets(self.buffer_size)

if self.concurrent:
self._start_data_producer_thread()
self._maybe_wait_for_producer()
else:
self._collect_cuts_in_buckets(self.buffer_size)

state = BucketSelectionState(
bucket_rng=self.bucket_rng,
@@ -588,6 +605,9 @@ def __iter__(self) -> Generator[CutSet, None, None]:
maybe_shuffled = pick_at_random(
maybe_shuffled, rng=self.rng, out_indexes_used=indexes_used
)
else:
with sampling_bucket.mutex:
maybe_shuffled = list(sampling_bucket.queue)
# Sample one batch from that bucket and yield it to the caller.
batcher = DurationBatcher(
maybe_shuffled,
@@ -604,21 +624,26 @@ def __init__(
if indexes_used:
# Shuffling, sort indexes of yielded elements largest -> smallest and remove them
indexes_used.sort(reverse=True)
for idx in indexes_used:
del sampling_bucket[idx]
with sampling_bucket.mutex:
_q = sampling_bucket.queue
for idx in indexes_used:
del _q[idx]
else:
# No shuffling, remove first N
for _ in range(batch_size):
sampling_bucket.popleft()
sampling_bucket.get()
# Fetch new cuts and add them to appropriate buckets.
self._collect_cuts_in_buckets(batch_size)
if self.concurrent:
self._maybe_wait_for_producer()
else:
self._collect_cuts_in_buckets(batch_size)
except StopIteration:
pass

# Cleanup.
self.cuts_iter = None

def _select_bucket(self, state: BucketSelectionState) -> Deque[Cut]:
def _select_bucket(self, state: BucketSelectionState) -> Queue:
if self.bucket_rng is None:
# Bucket selection algo 1:
# * there is just one RNG for choosing buckets and choosing samples randomly from the buckets
@@ -646,7 +671,7 @@ def _select_bucket(self, state: BucketSelectionState) -> Deque[Cut]:
# it will scan the neighbouring buckets until it finds one that's ready
# * if no bucket is ready, we end iteration

def scan_buckets(predicate: Callable[[Deque[Cut]], bool]) -> int:
def scan_buckets(predicate: Callable[[Queue], bool]) -> int:
bucket_idx = state.select_bucket_idx()

def valid_idx() -> bool:
@@ -689,43 +714,87 @@ def valid_idx() -> bool:
# which may yield partial batches.
try:
state.restore(ckpt)
selected_bucket_idx = scan_buckets(lambda b: len(b) > 0)
selected_bucket_idx = scan_buckets(lambda b: b.qsize() > 0)
except BucketsDontHaveEnoughData:
# We exhausted the full dataset.
raise StopIteration()

return self.buckets[selected_bucket_idx]

def _is_ready(self, bucket: Deque[Cut]) -> bool:
def _is_ready(self, bucket: Queue) -> bool:
tot = self.constraint.copy()
for c in bucket:
with bucket.mutex:
contents = list(bucket.queue)
for c in contents:
tot.add(c[0] if isinstance(c, tuple) else c)
if tot.close_to_exceeding():
return True
return False

def _start_data_producer_thread(self):
"""Start concurrent filling of the bucket buffer in a background thread."""

def producer():
try:
self._source_exhausted = False
while not self._source_exhausted:
if sum(b.qsize() for b in self.buckets) == self.buffer_size:
time.sleep(0.1)
continue
cuts = next(self.cuts_iter)
duration = self.constraint.measure_length(
cuts[0] if isinstance(cuts, tuple) else cuts
)
bucket_idx = bisect_right(self.duration_bins, duration)
self.buckets[bucket_idx].put(cuts)
except StopIteration:
self._source_exhausted = True

self._producer_thread = threading.Thread(target=producer)
self._producer_thread.start()

def _maybe_wait_for_producer(self):
"""Triggers wait for producer if the bucket buffers are less than 10% utilized."""
while (
sum(b.qsize() for b in self.buckets) < self.buffer_size / 10
and not self._source_exhausted
):
time.sleep(1.0)

def _collect_cuts_in_buckets(self, n_cuts: int) -> None:
"""Fetches ``n_cuts`` from the input data iterable. Doesn't use concurrency."""
try:
for _ in range(n_cuts):
cuts = next(self.cuts_iter)
duration = self.constraint.measure_length(
cuts[0] if isinstance(cuts, tuple) else cuts
)
bucket_idx = bisect_right(self.duration_bins, duration)
self.buckets[bucket_idx].append(cuts)
self.buckets[bucket_idx].put(cuts)
except StopIteration:
pass

def __del__(self):
if (
self.concurrent
and self._producer_thread is not None
and self._producer_thread.is_alive()
):
self._source_exhausted = True
self._producer_thread.join()


def pick_at_random(
bucket: Sequence[Union[Cut, Tuple[Cut, ...]]],
bucket: Queue,
rng: random.Random,
out_indexes_used: list,
) -> Generator[Union[Cut, Tuple[Cut, ...]], None, None]:
"""
Generator that yields the items of a sequence in random order.
It appends the indexes of yielded items to ``out_indexes_used`` during iteration.
"""
with bucket.mutex:
bucket = list(bucket.queue)
indexes = list(range(len(bucket)))
rng.shuffle(indexes)
for idx in indexes:
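
To see the concurrency pattern from this diff in isolation: below is a simplified,
self-contained sketch of the producer thread idea (per-bucket queue.Queue, bisect_right
bucket assignment, backing off when the shared buffer is full). All names and values are
illustrative; this is not lhotse's actual code.

import threading
import time
from bisect import bisect_right
from queue import Queue

duration_bins = [2.0, 5.0, 10.0]  # upper bin edges, in seconds
buckets = [Queue() for _ in range(len(duration_bins) + 1)]
buffer_size = 16
source = iter([1.5, 3.0, 7.5, 12.0, 0.5, 4.2, 9.9, 6.1])

def producer() -> None:
    try:
        while True:
            # Back off instead of overshooting the global buffer size.
            if sum(b.qsize() for b in buckets) >= buffer_size:
                time.sleep(0.1)
                continue
            duration = next(source)
            # bisect_right returns the index of the first bin edge strictly
            # greater than `duration`, i.e. the bucket it belongs to.
            buckets[bisect_right(duration_bins, duration)].put(duration)
    except StopIteration:
        pass  # the source iterator is exhausted

thread = threading.Thread(target=producer, daemon=True)
thread.start()
thread.join()  # the real sampler consumes concurrently instead of joining

for i, b in enumerate(buckets):
    # Queue.mutex guards a snapshot of the underlying deque, mirroring the
    # `with bucket.mutex: list(bucket.queue)` idiom used in the diff above.
    with b.mutex:
        print(f"bucket {i}: {list(b.queue)}")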
7 changes: 5 additions & 2 deletions test/dataset/sampling/test_dynamic_bucketing.py
@@ -115,15 +115,18 @@ def test_dynamic_bucketing_drop_last_true():
assert sum(c.duration for c in batches[2]) == 5


def test_dynamic_bucketing_sampler():
@pytest.mark.parametrize("concurrent", [False, True])
def test_dynamic_bucketing_sampler(concurrent):
cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
for i, c in enumerate(cuts):
if i < 5:
c.duration = 1
else:
c.duration = 2

sampler = DynamicBucketingSampler(cuts, max_duration=5, num_buckets=2, seed=0)
sampler = DynamicBucketingSampler(
cuts, max_duration=5, num_buckets=2, seed=0, concurrent=concurrent
)
batches = [b for b in sampler]
sampled_cuts = [c for b in batches for c in b]

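
To exercise both code paths locally, the parametrized test above can be run with pytest's
keyword filter; the exact invocation below is a suggestion, not part of this commit:

pytest test/dataset/sampling/test_dynamic_bucketing.py -k test_dynamic_bucketing_sampler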
