Support for pre-determined batch sizes in DynamicBucketingSampler (lh…
pzelasko authored and Your Name committed Jan 7, 2025
1 parent 769daeb commit 321ff23
Showing 2 changed files with 168 additions and 7 deletions.
117 changes: 110 additions & 7 deletions lhotse/dataset/sampling/dynamic_bucketing.py
@@ -2,7 +2,7 @@
import warnings
from bisect import bisect_right
from collections import deque
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from itertools import islice
from typing import (
Any,
@@ -104,7 +104,7 @@ def __init__(
Note: with multiple CutSets, the ``max_duration`` constraint applies only to the first CutSet.
:param max_cuts: The maximum total number of ``cuts`` per batch.
When only ``max_duration`` is specified, this sampler yields static batch sizes.
:param num_buckets: how many buckets to create.
:param num_buckets: how many buckets to create. Ignored if duration_bins are provided.
:param shuffle: When ``True``, the cuts will be shuffled dynamically with
a reservoir-sampling-based algorithm.
Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
@@ -169,15 +169,11 @@ def __init__(
self.buffer_size += shuffle_buffer_size

if duration_bins is not None:
if num_buckets is not None:
assert len(duration_bins) == num_buckets - 1, (
f"num_buckets=={num_buckets} but len(duration_bins)=={len(duration_bins)} "
f"(bins are the boundaries, it should be one less than the number of buckets)."
)
assert list(duration_bins) == sorted(
duration_bins
), "Duration bins must be sorted ascendingly."
self.duration_bins = duration_bins
self.num_buckets = len(duration_bins) + 1
else:
if constraint is None:
constraint = TimeConstraint(
@@ -316,6 +312,113 @@ def num_cuts(self) -> Optional[int]:
return None


@dataclass
class FixedBucketBatchSizeConstraint(SamplingConstraint):
"""
Sampling constraint that accepts a pre-defined batch size for each bucket.
It uses the example's sequence length to determine which bucket we're sampling for,
and otherwise the batch size is locally static for each bucket.
This constraint doesn't support samples longer than the upper bound of the last bucket;
if such sample is provided, we will raise an exception.
"""

max_seq_len_buckets: List[float]
batch_sizes: List[int]
current_bucket: Union[int, None] = None
num_cuts: int = 0

def __post_init__(self):
assert sorted(self.max_seq_len_buckets) == list(self.max_seq_len_buckets)

def is_active(self) -> bool:
return True

def add(self, example: Cut) -> None:
"""
        Select the bucket for the input example based on its sequence length,
        and increment the internal cut counter.
"""
seqlen = self.measure_length(example)
bucket_idx = bisect_right(self.max_seq_len_buckets, seqlen)
assert bucket_idx < len(self.max_seq_len_buckets), (
f"Received example with sequence length {seqlen} that exceeds "
f"the highest allowed length {self.max_seq_len_buckets[-1]}."
)
if self.current_bucket is None:
self.current_bucket = bucket_idx
else:
assert self.current_bucket == bucket_idx, (
f"User error: FixedBucketBatchSizeConstraint is supposed to be used only on one bucket. "
f"The example we received has sequence length {seqlen} which is outside of the allowed bounds "
f"for bucket index {bucket_idx} in buckets {self.max_seq_len_buckets}."
)
self.num_cuts += 1

def exceeded(self) -> bool:
"""Is the constraint exceeded or not."""
return self.num_cuts > self.batch_sizes[self.current_bucket]

def close_to_exceeding(self) -> bool:
"""
        Check if the batch is close to satisfying the constraints,
        i.e. the current bucket has reached its pre-defined batch size.
"""
return self.num_cuts >= self.batch_sizes[self.current_bucket]

def reset(self) -> None:
"""
Reset the internal counter (to be used after a batch was created,
to start collecting a new one).
"""
self.current_bucket = None
self.num_cuts = 0

def measure_length(self, example: Cut) -> float:
return example.duration

def state_dict(self) -> Dict[str, Any]:
return asdict(self)

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.max_seq_len_buckets = state_dict.pop("max_seq_len_buckets")
self.batch_sizes = state_dict.pop("batch_sizes")
self.current_bucket = state_dict.pop("current_bucket")
self.num_cuts = state_dict.pop("num_cuts")
assert len(state_dict) == 0, (
"Error in FixedBucketBatchSizeConstraint.load_state_dict(): Unexpected keys:\n- "
+ "\n- ".join(state_dict.keys())
)

def __add__(
self, other: "FixedBucketBatchSizeConstraint"
) -> "FixedBucketBatchSizeConstraint":
for key in ("max_seq_len_buckets", "batch_sizes", "current_bucket"):
self_attr = getattr(self, key)
other_attr = getattr(other, key)
is_none = self_attr is None and other_attr is None
assert is_none or self_attr == other_attr, (
f"To add two TimeConstraint objects, they need to represent the same constraint "
f"(got self.{key}={self_attr} != other.{key}={other_attr})."
)
return FixedBucketBatchSizeConstraint(
max_seq_len_buckets=self.max_seq_len_buckets,
batch_sizes=self.batch_sizes,
current_bucket=self.current_bucket,
num_cuts=self.num_cuts + other.num_cuts,
)

    def __eq__(self, other: "FixedBucketBatchSizeConstraint") -> bool:
return (
isinstance(other, FixedBucketBatchSizeConstraint)
and self.max_seq_len_buckets == other.max_seq_len_buckets
and self.batch_sizes == other.batch_sizes
and self.current_bucket == other.current_bucket
)


def estimate_duration_buckets(
cuts: Iterable[Cut],
num_buckets: int,
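For reference, here is a minimal sketch (not part of this commit) of how the new constraint behaves on its own. It assumes `dummy_cut` from `lhotse.testing.dummies`, which the tests below also use:

```python
from lhotse.dataset.sampling.dynamic_bucketing import FixedBucketBatchSizeConstraint
from lhotse.testing.dummies import dummy_cut

# Two buckets with upper bounds of 5s and 10s, and fixed batch sizes 16 and 8.
constraint = FixedBucketBatchSizeConstraint(
    max_seq_len_buckets=[5.0, 10.0], batch_sizes=[16, 8]
)
constraint.add(dummy_cut(0, duration=3.0))  # falls into bucket 0 (duration <= 5s)
assert constraint.current_bucket == 0
assert not constraint.close_to_exceeding()  # only 1 of 16 cuts collected so far
constraint.reset()  # start collecting a new batch; bucket selection is cleared

# The constraint's state can be checkpointed and restored:
state = constraint.state_dict()
restored = FixedBucketBatchSizeConstraint(
    max_seq_len_buckets=[5.0, 10.0], batch_sizes=[16, 8]
)
restored.load_state_dict(state)
assert restored == constraint
```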
58 changes: 58 additions & 0 deletions test/dataset/sampling/test_dynamic_bucketing.py
@@ -7,6 +7,7 @@
from lhotse.dataset.sampling.dynamic_bucketing import (
DynamicBucketer,
DynamicBucketingSampler,
FixedBucketBatchSizeConstraint,
estimate_duration_buckets,
)
from lhotse.testing.dummies import DummyManifest, dummy_cut
@@ -670,3 +671,60 @@ def test_dynamic_bucketing_sampler_sync_buckets_map_dataset_usage(sync_buckets):
# some shapes will be mismatched because different buckets were selected.
matching_shapes = [len(b0) == len(b1) for b0, b1 in zip(batches0, batches1)]
assert not all(matching_shapes)


def test_dynamic_bucketing_sampler_fixed_batch_constraint():
cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
for i, c in enumerate(cuts):
if i < 5:
c.duration = 1
else:
c.duration = 2

duration_bins = [1.5, 2.5]
sampler = DynamicBucketingSampler(
cuts,
duration_bins=duration_bins,
constraint=FixedBucketBatchSizeConstraint(
max_seq_len_buckets=duration_bins, batch_sizes=[2, 1]
),
seed=0,
shuffle=True,
)

batches = [b for b in sampler]
sampled_cuts = [c for b in batches for c in b]

# Invariant: no duplicated cut IDs
assert len(set(c.id for b in batches for c in b)) == len(sampled_cuts)

# Same number of sampled and source cuts.
assert len(sampled_cuts) == len(cuts)

    # We sampled the following batches with this RNG: the five 1s cuts fall into
    # bucket 0 (batch size 2) and the five 2s cuts into bucket 1 (batch size 1),
    # giving 8 batches in total.
    assert len(batches) == 8

assert len(batches[0]) == 1
assert sum(c.duration for c in batches[0]) == 2

assert len(batches[1]) == 2
assert sum(c.duration for c in batches[1]) == 2

assert len(batches[2]) == 2
assert sum(c.duration for c in batches[2]) == 2

assert len(batches[3]) == 1
assert sum(c.duration for c in batches[3]) == 2

assert len(batches[4]) == 1
assert sum(c.duration for c in batches[4]) == 2

assert len(batches[5]) == 1
assert sum(c.duration for c in batches[5]) == 2

assert len(batches[6]) == 1
assert sum(c.duration for c in batches[6]) == 2

assert len(batches[7]) == 1
assert sum(c.duration for c in batches[7]) == 1
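For context, a hedged sketch of how the new constraint can be wired into a data loading setup. It assumes an existing CutSet named `cuts` with precomputed features; `K2SpeechRecognitionDataset`, the bin boundaries, and the batch sizes are illustrative choices, not part of this commit:

```python
from torch.utils.data import DataLoader

from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.sampling.dynamic_bucketing import (
    DynamicBucketingSampler,
    FixedBucketBatchSizeConstraint,
)

# Upper bounds (in seconds) for each bucket. The last bound must cover the
# longest cut, otherwise FixedBucketBatchSizeConstraint.add() raises.
duration_bins = [5.0, 10.0, 20.0]
sampler = DynamicBucketingSampler(
    cuts,
    duration_bins=duration_bins,  # num_buckets is inferred from the bins
    constraint=FixedBucketBatchSizeConstraint(
        max_seq_len_buckets=duration_bins,
        batch_sizes=[48, 24, 12],  # shorter cuts -> larger batches
    ),
    shuffle=True,
)
# Lhotse samplers yield CutSet mini-batches, hence batch_size=None.
dloader = DataLoader(K2SpeechRecognitionDataset(), sampler=sampler, batch_size=None)
```

Pairing larger batch sizes with shorter-duration buckets keeps the total batch duration roughly constant, mirroring the `[2, 1]` batch sizes used against the 1s and 2s cuts in the test above.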
