Skip to content

Commit

Permalink
Fix consecutive same sampler selection in round robin sampler with num_workers>1 (#1432)
Browse files Browse the repository at this point in the history

* Fix consecutive same sampler selection in round robin sampler with num_workers>1

Signed-off-by: Piotr Żelasko <[email protected]>

* update github actions version

Signed-off-by: Piotr Żelasko <[email protected]>

* fix sampler restore unit tests

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko authored Dec 7, 2024
1 parent a13c084 commit faa3599
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ jobs:
fail-fast: false

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
Expand Down
21 changes: 18 additions & 3 deletions lhotse/dataset/sampling/round_robin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from lhotse import CutSet
from lhotse.cut import Cut
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(

self._nondepleted_samplers_indices = list(range(len(self.samplers)))
self._cur_sampler_idx = 0
self._num_dl_workers = 1

if isinstance(randomize, list):
assert len(randomize) == len(self.samplers)
Expand Down Expand Up @@ -124,6 +126,7 @@ def state_dict(self) -> Dict[str, Any]:
"stop_early": self.stop_early,
"randomize": self.randomize,
"_cur_sampler_idx": self._cur_sampler_idx,
"_num_dl_workers": self._num_dl_workers,
# Explicit list copy below allows to restore within the same process.
"_nondepleted_samplers_indices": list(
self._nondepleted_samplers_indices
Expand Down Expand Up @@ -153,6 +156,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.stop_early = state_dict.pop("stop_early")
self.randomize = state_dict.pop("randomize")
self._cur_sampler_idx = state_dict.pop("_cur_sampler_idx")
self._num_dl_workers = state_dict.pop("_num_dl_workers")
self._nondepleted_samplers_indices = state_dict.pop(
"_nondepleted_samplers_indices"
)
Expand All @@ -171,7 +175,18 @@ def __iter__(self):
if self._just_restored_state:
return self
self._nondepleted_samplers_indices = list(range(len(self.samplers)))
# In case this sampler lives in the dataloading worker subprocess,
# set the starting index to a different value on each dataloading worker.
# This helps avoid situations where the round robin sampler chooses
# the same underlying sampler for N consecutive mini-batches, where N = num_workers (>1).
self._cur_sampler_idx = 0
self._num_dl_workers = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self._cur_sampler_idx = worker_info.id % len(
self._nondepleted_samplers_indices
)
self._num_dl_workers = worker_info.num_workers
return self

def _next_batch(self) -> Union[CutSet, Tuple[CutSet]]:
Expand Down Expand Up @@ -202,9 +217,9 @@ def _set_next_idx(self) -> None:
p = [x / sum(p) for x in p]
self._cur_sampler_idx = self.rng.choice(N, size=1, replace=False, p=p)[0]
else:
self._cur_sampler_idx = (self._cur_sampler_idx + 1) % len(
self._nondepleted_samplers_indices
)
self._cur_sampler_idx = (
self._cur_sampler_idx + self._num_dl_workers
) % len(self._nondepleted_samplers_indices)

def set_epoch(self, epoch: int) -> None:
"""
Expand Down
22 changes: 21 additions & 1 deletion test/dataset/sampling/test_sampling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import random
import re
from collections import Counter
Expand Down Expand Up @@ -856,6 +855,27 @@ def test_round_robin_sampler(randomize):
# ... and so on


@pytest.mark.parametrize("num_workers", [0, 1, 2, 3])
def test_nonrandomized_round_robin_sampler_keeps_round_robin_property_in_iterable_dataset(
    num_workers,
):
    """Check that batches alternate between sub-samplers for any worker count.

    Each sub-sampler is configured with a different ``max_cuts`` (1, 2, 3), so
    the length of a mini-batch identifies which sub-sampler produced it. The
    first 15 batches read from the dataloader must cycle 1, 2, 3, 1, 2, 3, ...
    regardless of ``num_workers``.
    """
    # Three disjoint manifests of 100 dummy cuts each.
    manifests = [
        DummyManifest(CutSet, begin_id=first_id, end_id=first_id + 100)
        for first_id in (0, 500, 1000)
    ]
    # The k-th sub-sampler emits batches of exactly k cuts (no shuffling).
    sub_samplers = [
        SimpleCutSampler(cuts, max_cuts=batch_len, shuffle=False)
        for batch_len, cuts in enumerate(manifests, start=1)
    ]
    sampler = RoundRobinSampler(*sub_samplers)

    wrapped_dataset = IterableDatasetWrapper(IdentityDataset(), sampler)
    dloader = DataLoader(
        dataset=wrapped_dataset,
        batch_size=None,
        num_workers=num_workers,
    )

    # ``range`` goes first in ``zip`` so that no 16th batch is ever requested.
    observed_lens = [len(batch) for _, batch in zip(range(15), dloader)]
    expected_lens = [1, 2, 3] * 5
    assert observed_lens == expected_lens


@pytest.mark.parametrize("sampler_cls", [SimpleCutSampler, DynamicCutSampler])
def test_single_cut_sampler_drop_last(sampler_cls):
# The dummy cuts have a duration of 1 second each
Expand Down

0 comments on commit faa3599

Please sign in to comment.