From c286f28fe5a883ccc64fffb087b46ce124a8efe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 18 Jul 2024 18:35:00 -0400 Subject: [PATCH] Support for pre-determined batch sizes in DynamicBucketingSampler (#1372) --- lhotse/dataset/sampling/dynamic_bucketing.py | 117 ++++++++++++++++-- .../sampling/test_dynamic_bucketing.py | 58 +++++++++ 2 files changed, 168 insertions(+), 7 deletions(-) diff --git a/lhotse/dataset/sampling/dynamic_bucketing.py b/lhotse/dataset/sampling/dynamic_bucketing.py index cf4da23f2..9b9a41f1c 100644 --- a/lhotse/dataset/sampling/dynamic_bucketing.py +++ b/lhotse/dataset/sampling/dynamic_bucketing.py @@ -2,7 +2,7 @@ import warnings from bisect import bisect_right from collections import deque -from dataclasses import dataclass +from dataclasses import asdict, dataclass from itertools import islice from typing import ( Any, @@ -104,7 +104,7 @@ def __init__( Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet. :param max_cuts: The maximum total number of ``cuts`` per batch. When only ``max_duration`` is specified, this sampler yields static batch sizes. - :param num_buckets: how many buckets to create. + :param num_buckets: how many buckets to create. Ignored if duration_bins are provided. :param shuffle: When ``True``, the cuts will be shuffled dynamically with a reservoir-sampling-based algorithm. Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.: @@ -169,15 +169,11 @@ def __init__( self.buffer_size += shuffle_buffer_size if duration_bins is not None: - if num_buckets is not None: - assert len(duration_bins) == num_buckets - 1, ( - f"num_buckets=={num_buckets} but len(duration_bins)=={len(duration_bins)} " - f"(bins are the boundaries, it should be one less than the number of buckets)." - ) assert list(duration_bins) == sorted( duration_bins ), "Duration bins must be sorted ascendingly." self.duration_bins = duration_bins + self.num_buckets = len(duration_bins) + 1 else: if constraint is None: constraint = TimeConstraint( @@ -316,6 +312,113 @@ def num_cuts(self) -> Optional[int]: return None +@dataclass +class FixedBucketBatchSizeConstraint(SamplingConstraint): + """ + Sampling constraint that accepts a pre-defined batch size for each bucket. + It uses the example's sequence length to determine which bucket we're sampling for, + and otherwise the batch size is locally static for each bucket. + + This constraint doesn't support samples longer than the upper bound of the last bucket; + if such sample is provided, we will raise an exception. + """ + + max_seq_len_buckets: List[float] + batch_sizes: List[int] + current_bucket: Union[int, None] = None + num_cuts: int = 0 + + def __post_init__(self): + assert sorted(self.max_seq_len_buckets) == list(self.max_seq_len_buckets) + + def is_active(self) -> bool: + return True + + def add(self, example: Cut) -> None: + """ + Increment the internal counter for the time constraint, + selecting the right property from the input ``cut`` object. + """ + seqlen = self.measure_length(example) + bucket_idx = bisect_right(self.max_seq_len_buckets, seqlen) + assert bucket_idx < len(self.max_seq_len_buckets), ( + f"Received example with sequence length {seqlen} that exceeds " + f"the highest allowed length {self.max_seq_len_buckets[-1]}." + ) + if self.current_bucket is None: + self.current_bucket = bucket_idx + else: + assert self.current_bucket == bucket_idx, ( + f"User error: FixedBucketBatchSizeConstraint is supposed to be used only on one bucket. " + f"The example we received has sequence length {seqlen} which is outside of the allowed bounds " + f"for bucket index {bucket_idx} in buckets {self.max_seq_len_buckets}." + ) + self.num_cuts += 1 + + def exceeded(self) -> bool: + """Is the constraint exceeded or not.""" + return self.num_cuts > self.batch_sizes[self.current_bucket] + + def close_to_exceeding(self) -> bool: + """ + Check if the batch is close to satisfying the constraints. + We define "closeness" as: if we added one more cut that has + duration/num_frames/num_samples equal to the longest seen cut + in the current batch, then the batch would have exceeded the constraints. + """ + return self.num_cuts >= self.batch_sizes[self.current_bucket] + + def reset(self) -> None: + """ + Reset the internal counter (to be used after a batch was created, + to start collecting a new one). + """ + self.current_bucket = None + self.num_cuts = 0 + + def measure_length(self, example: Cut) -> float: + return example.duration + + def state_dict(self) -> Dict[str, Any]: + return asdict(self) + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.max_seq_len_buckets = state_dict.pop("max_seq_len_buckets") + self.batch_sizes = state_dict.pop("batch_sizes") + self.current_bucket = state_dict.pop("current_bucket") + self.num_cuts = state_dict.pop("num_cuts") + assert len(state_dict) == 0, ( + "Error in FixedBucketBatchSizeConstraint.load_state_dict(): Unexpected keys:\n- " + + "\n- ".join(state_dict.keys()) + ) + + def __add__( + self, other: "FixedBucketBatchSizeConstraint" + ) -> "FixedBucketBatchSizeConstraint": + for key in ("max_seq_len_buckets", "batch_sizes", "current_bucket"): + self_attr = getattr(self, key) + other_attr = getattr(other, key) + is_none = self_attr is None and other_attr is None + assert is_none or self_attr == other_attr, ( + f"To add two TimeConstraint objects, they need to represent the same constraint " + f"(got self.{key}={self_attr} != other.{key}={other_attr})." + ) + return FixedBucketBatchSizeConstraint( + max_seq_len_buckets=self.max_seq_len_buckets, + batch_sizes=self.batch_sizes, + current_bucket=self.current_bucket, + num_cuts=self.num_cuts + other.num_cuts, + ) + + def __eq__(self, other: "TimeConstraint") -> bool: + return ( + isinstance(other, FixedBucketBatchSizeConstraint) + and self.max_seq_len_buckets == other.max_seq_len_buckets + and self.batch_sizes == other.batch_sizes + and self.current_bucket == other.current_bucket + ) + + def estimate_duration_buckets( cuts: Iterable[Cut], num_buckets: int, diff --git a/test/dataset/sampling/test_dynamic_bucketing.py b/test/dataset/sampling/test_dynamic_bucketing.py index bb3253b30..454bff58f 100644 --- a/test/dataset/sampling/test_dynamic_bucketing.py +++ b/test/dataset/sampling/test_dynamic_bucketing.py @@ -7,6 +7,7 @@ from lhotse.dataset.sampling.dynamic_bucketing import ( DynamicBucketer, DynamicBucketingSampler, + FixedBucketBatchSizeConstraint, estimate_duration_buckets, ) from lhotse.testing.dummies import DummyManifest, dummy_cut @@ -670,3 +671,60 @@ def test_dynamic_bucketing_sampler_sync_buckets_map_dataset_usage(sync_buckets): # some shapes will be mismatched because different buckets were selected. matching_shapes = [len(b0) == len(b1) for b0, b1 in zip(batches0, batches1)] assert not all(matching_shapes) + + +def test_dynamic_bucketing_sampler_fixed_batch_constraint(): + cuts = DummyManifest(CutSet, begin_id=0, end_id=10) + for i, c in enumerate(cuts): + if i < 5: + c.duration = 1 + else: + c.duration = 2 + + duration_bins = [1.5, 2.5] + sampler = DynamicBucketingSampler( + cuts, + duration_bins=duration_bins, + constraint=FixedBucketBatchSizeConstraint( + max_seq_len_buckets=duration_bins, batch_sizes=[2, 1] + ), + seed=0, + shuffle=True, + ) + + batches = [b for b in sampler] + sampled_cuts = [c for b in batches for c in b] + + # Invariant: no duplicated cut IDs + assert len(set(c.id for b in batches for c in b)) == len(sampled_cuts) + + # Same number of sampled and source cuts. + assert len(sampled_cuts) == len(cuts) + + # We sampled the follwoing batches with this RNG: + assert len(batches) == 8 + print([len(b) for b in batches]) + + assert len(batches[0]) == 1 + assert sum(c.duration for c in batches[0]) == 2 + + assert len(batches[1]) == 2 + assert sum(c.duration for c in batches[1]) == 2 + + assert len(batches[2]) == 2 + assert sum(c.duration for c in batches[2]) == 2 + + assert len(batches[3]) == 1 + assert sum(c.duration for c in batches[3]) == 2 + + assert len(batches[4]) == 1 + assert sum(c.duration for c in batches[4]) == 2 + + assert len(batches[5]) == 1 + assert sum(c.duration for c in batches[5]) == 2 + + assert len(batches[6]) == 1 + assert sum(c.duration for c in batches[6]) == 2 + + assert len(batches[7]) == 1 + assert sum(c.duration for c in batches[7]) == 1