Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multiprocessing to meeting simulation workflow #972

Merged
merged 9 commits into from
Feb 9, 2023
7 changes: 7 additions & 0 deletions lhotse/bin/modes/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,13 @@ def align_with_torchaudio(
default=1234,
help="Random seed for reproducibility.",
)
@click.option(
"--num-jobs",
"-j",
type=int,
default=1,
help="Number of parallel jobs to run.",
)
def simulate_meetings(
in_cuts: str,
out_cuts: str,
Expand Down
2 changes: 1 addition & 1 deletion lhotse/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
self.use_threads = threads

def run(self) -> None:
executor = ProcessPoolExecutor if self.use_threads else ThreadPoolExecutor
executor = ThreadPoolExecutor if self.use_threads else ProcessPoolExecutor
with executor(self.num_jobs) as ex:
for args in zip(*self.iterables):
future = ex.submit(self.fn, *args)
Expand Down
176 changes: 129 additions & 47 deletions lhotse/workflows/meeting_simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,20 @@
a CutSet containing MonoCut objects.
"""
import abc
from collections import defaultdict
from typing import Optional
import random
from itertools import groupby
from typing import Dict, List, Optional, Union

import numpy as np
from tqdm import tqdm

from lhotse import RecordingSet, SupervisionSet
from lhotse.cut import CutSet
from lhotse.dataset.sampling import DynamicCutSampler, RoundRobinSampler
from lhotse.dataset.sampling import DynamicCutSampler
from lhotse.utils import fastcopy, is_module_available

MAX_TASKS_WAITING = 1000


class BaseMeetingSimulator(abc.ABC):
"""
Expand Down Expand Up @@ -83,6 +89,126 @@ def reverberate(self, cuts: CutSet, *rirs: RecordingSet) -> CutSet:
...


class MeetingSampler:
    """
    Iterator that samples groups of utterances ("meetings") from the source cuts.

    The cuts are partitioned into speaker-wise buckets, and a DynamicCutSampler is
    created for each bucket. To draw one meeting, we first sample the number of
    speakers, then pick that many speakers at random, and finally take one batch from
    each selected speaker's sampler. A speaker whose sampler is exhausted is removed
    from the pool. Iteration stops once ``num_meetings`` meetings have been returned,
    or when fewer speakers remain than the smallest allowed meeting size.

    :param cuts: a CutSet containing MonoCut objects.
    :param num_repeats: the number of times each cut will be repeated (by default, they
        are repeated infinitely).
    :param num_meetings: the number of meetings to simulate (``None`` means no limit).
    :param num_speakers_per_meeting: the allowed number(s) of speakers per meeting;
        a single int is treated as a one-element list.
    :param speaker_count_probs: the probability of each entry in
        ``num_speakers_per_meeting`` (uniform when ``None``).
    :param max_duration_per_speaker: the maximum duration of a speaker in a meeting.
    :param max_utterances_per_speaker: the maximum number of utterances of a speaker in a meeting.
    :param seed: the random seed.
    """

    def __init__(
        self,
        cuts: "CutSet",
        num_repeats: Optional[int] = None,
        num_meetings: Optional[int] = None,
        num_speakers_per_meeting: Union[int, List[int]] = 2,
        speaker_count_probs: Optional[List[float]] = None,
        max_duration_per_speaker: Optional[float] = 20.0,
        max_utterances_per_speaker: Optional[int] = 5,
        seed: int = 0,
    ):
        # Normalize first so that the documented defaults (an int and None) are usable.
        (
            self.num_speakers_per_meeting,
            self.speaker_count_probs,
        ) = self._resolve_speaker_counts(num_speakers_per_meeting, speaker_count_probs)

        def speaker_of(cut):
            return cut.supervisions[0].speaker

        # One sampler per speaker, stored in a dict so that an exhausted speaker can
        # be dropped by key in constant time.
        # NOTE: the cuts must be sorted by speaker before groupby(); groupby only
        # merges *consecutive* items, so unsorted input yields many tiny groups and
        # later groups overwrite earlier ones for the same speaker.
        self.samplers = {}
        for spk, spk_cuts in tqdm(
            groupby(sorted(cuts, key=speaker_of), key=speaker_of),
            desc="Creating samplers for each speaker...",
        ):
            self.samplers[spk] = DynamicCutSampler(
                CutSet.from_cuts(list(spk_cuts)).repeat(
                    times=num_repeats, preserve_id=False
                ),
                max_duration=max_duration_per_speaker,
                max_cuts=max_utterances_per_speaker,
                shuffle=True,
                seed=seed,
            )

        self.npr = np.random.RandomState(seed)
        self.rng = random.Random(seed)
        self._remaining_meetings = num_meetings

    @staticmethod
    def _resolve_speaker_counts(num_speakers_per_meeting, speaker_count_probs):
        """Normalize and validate the speaker-count options; return ``(counts, probs)``."""
        if isinstance(num_speakers_per_meeting, int):
            counts = [num_speakers_per_meeting]
        else:
            counts = list(num_speakers_per_meeting)
        assert len(counts) > 0, (
            "At least one value must be given for the number of speakers per meeting."
        )

        if speaker_count_probs is None:
            probs = [1.0 / len(counts)] * len(counts)
        else:
            probs = list(speaker_count_probs)

        assert all(n > 1 for n in counts), (
            "The number of speakers per meeting must be greater than 1. "
            f"Got: {counts}"
        )
        assert all(p > 0.0 for p in probs), (
            "The probabilities of the number of speakers per meeting must be greater than 0. "
            f"Got: {probs}"
        )
        # Tolerant comparison: e.g. sum([0.1] * 10) is 0.9999999999999999, not 1.0.
        assert np.isclose(sum(probs), 1.0), (
            "The probabilities of the number of speakers per meeting must sum to 1. "
            f"Got: {probs}"
        )
        assert len(counts) == len(
            probs
        ), "The number of speakers per meeting and the number of probabilities must be the same."
        return counts, probs

    def __iter__(self):
        # (Re-)initialize every per-speaker sampler before iteration starts.
        for sampler in self.samplers.values():
            iter(sampler)
        return self

    def __next__(self):
        # Loop instead of recursing: an attempt that yields no utterances (all picked
        # speakers were exhausted) is retried without consuming the meeting budget.
        # Every empty attempt removes at least one speaker, so the loop terminates.
        while True:
            # If we have sampled enough meetings, stop.
            if self._remaining_meetings is not None and self._remaining_meetings <= 0:
                raise StopIteration()

            # If we don't have enough speakers, stop.
            if len(self.samplers) < min(self.num_speakers_per_meeting):
                raise StopIteration()

            # Sample the number of speakers for this meeting.
            num_speakers = min(
                int(
                    self.npr.choice(
                        self.num_speakers_per_meeting, p=self.speaker_count_probs
                    )
                ),
                len(self.samplers),
            )

            # Sample speakers, then take one batch from each of them.
            selected_speakers = self.rng.sample(list(self.samplers), num_speakers)
            utterances = CutSet.from_cuts([])
            for spk_id in selected_speakers:
                try:
                    this_batch = next(self.samplers[spk_id])
                except StopIteration:
                    # This speaker has no utterances left; drop it from the pool.
                    del self.samplers[spk_id]
                    continue
                utterances += this_batch

            if len(utterances) == 0:
                continue

            if self._remaining_meetings is not None:
                self._remaining_meetings -= 1
            # Shuffle the utterances so speakers are interleaved.
            return utterances.shuffle()


def reverberate_cuts(cuts: CutSet, *rirs: RecordingSet) -> CutSet:
"""
Use provided RIRs to convolve each track of the input CutSet. The cuts here are
Expand Down Expand Up @@ -114,47 +240,3 @@ def reverberate_cuts(cuts: CutSet, *rirs: RecordingSet) -> CutSet:
out_cuts.append(cut.reverb_rir())

return CutSet.from_cuts(out_cuts)


def create_sampler(
    cuts: CutSet,
    num_repeats: Optional[int] = None,
    max_duration: Optional[float] = None,
    max_cuts: Optional[int] = None,
    seed: int = 0,
) -> RoundRobinSampler:
    """
    Create a sampler that will be used to sample cuts from the input CutSet. The cuts
    are partitioned into speaker-wise buckets (keyed by the speaker of each cut's first
    supervision), and a DynamicCutSampler is created for each bucket. The samplers are
    then combined into a RoundRobinSampler, which will sample cuts from each bucket in
    a round-robin fashion.

    :param cuts: a CutSet containing MonoCut objects.
    :param num_repeats: the number of times each cut will be repeated (by default, they
        are repeated infinitely).
    :param max_duration: the maximum duration of the cuts in each batch.
    :param max_cuts: the maximum number of cuts in each batch.
    :param seed: the random seed.
    :return: a RoundRobinSampler wrapping one DynamicCutSampler per speaker.
    """
    # Group the cuts by speaker.
    cuts_by_speaker = defaultdict(list)
    for cut in cuts:
        cuts_by_speaker[cut.supervisions[0].speaker].append(cut)

    # Build one sampler per speaker bucket; every sampler shares the same seed.
    samplers = []
    for speaker_cuts in cuts_by_speaker.values():
        bucket = CutSet.from_cuts(speaker_cuts)
        samplers.append(
            DynamicCutSampler(
                bucket.repeat(times=num_repeats),
                max_duration=max_duration,
                max_cuts=max_cuts,
                shuffle=True,
                seed=seed,
            )
        )

    # Combine samplers into a round-robin sampler.
    return RoundRobinSampler(*samplers, randomize=True, seed=seed)
107 changes: 41 additions & 66 deletions lhotse/workflows/meeting_simulation/conversational.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import logging
from collections import defaultdict
from multiprocessing import Pool
from typing import Any, List, Optional, Union

import numpy as np
from tqdm import tqdm

from lhotse import RecordingSet, SupervisionSet
from lhotse.cut import CutSet, MixedCut, MixTrack, MonoCut
from lhotse.cut import CutSet, MixedCut, MixTrack
from lhotse.cut.set import mix
from lhotse.parallel import parallel_map
from lhotse.utils import uuid4
from lhotse.workflows.meeting_simulation.base import (
MAX_TASKS_WAITING,
BaseMeetingSimulator,
create_sampler,
MeetingSampler,
reverberate_cuts,
)

Expand Down Expand Up @@ -252,6 +255,7 @@ def simulate(
max_utterances_per_speaker: Optional[int] = 5,
allow_3fold_overlap: bool = False,
seed: int = 0,
num_jobs: int = 1,
) -> CutSet:
"""
Simulate the desired number of multi-speaker meetings.
Expand All @@ -271,6 +275,8 @@ def simulate(
:param allow_3fold_overlap: if True, allow 3-fold overlaps between speakers.
[Default: False]
:param seed: the random seed to be used for simulation. [Default: 0]
:param num_jobs: the number of jobs to use for simulation. Use more jobs to speed up
simulation when you have a large number of source utterances. [Default: 1]
"""
from scipy.stats import bernoulli

Expand All @@ -288,85 +294,54 @@ def simulate(
num_speakers_per_meeting
)

# Some basic checks
assert all(n > 1 for n in num_speakers_per_meeting), (
"The number of speakers per meeting must be greater than 1. "
f"Got: {num_speakers_per_meeting}"
)
assert all(p > 0.0 for p in speaker_count_probs), (
"The probabilities of the number of speakers per meeting must be greater than 0. "
f"Got: {speaker_count_probs}"
)
assert sum(speaker_count_probs) == 1.0, (
"The probabilities of the number of speakers per meeting must sum to 1. "
f"Got: {speaker_count_probs}"
)
assert len(num_speakers_per_meeting) == len(
speaker_count_probs
), "The number of speakers per meeting and the number of probabilities must be the same."
assert all(
isinstance(cut, MonoCut) for cut in cuts
), "Only MonoCuts are supported."

# Initialize default distributions if not provided.
if getattr(self, "same_spk_pause_dist", None) is None:
self._init_defaults()

# Create random number generators with the given seed.
self.bernoulli = bernoulli

# Create cuts sampler
sampler = create_sampler(
sampler = MeetingSampler(
cuts,
num_repeats=num_repeats,
max_duration=max_duration_per_speaker,
max_cuts=max_utterances_per_speaker,
num_meetings=num_meetings,
max_duration_per_speaker=max_duration_per_speaker,
max_utterances_per_speaker=max_utterances_per_speaker,
num_speakers_per_meeting=num_speakers_per_meeting,
speaker_count_probs=speaker_count_probs,
seed=seed,
)
# Create an iterator from the sampler
sampler_iter = iter(sampler)

# Create random number generators with the given seed.
npr = np.random.RandomState(seed)

self.bernoulli = bernoulli

mixtures = []
N = len(cuts.speakers)

pbar = tqdm(total=num_meetings)
while True:
# If the number of meetings is provided, stop when we reach that number.
if num_meetings is not None and len(mixtures) >= num_meetings:
break

# Sample the number of speakers for this meeting.
num_speakers = min(
npr.choice(num_speakers_per_meeting, p=speaker_count_probs), N
)
global _simulate_worker

# Sample from the sampler to get 1 batch per desired number of speakers.
utterances = CutSet.from_cuts([])
finished = False
while len(utterances.speakers) < num_speakers:
try:
this_batch = next(sampler_iter)
except StopIteration:
# If we run out of data, finish simulation.
finished = True
break
utterances += this_batch

if finished:
break

# Randomly shuffle the utterances
utterances = utterances.shuffle()

# Create the meeting.
mixture = self._create_mixture(
def _simulate_worker(utterances):
return self._create_mixture(
utterances, allow_3fold_overlap=allow_3fold_overlap
)

mixtures.append(mixture)
pbar.update(1)
mixtures = []
if num_jobs == 1:
# Don't use multiprocessing if num_jobs == 1.
for mixture in tqdm(
map(_simulate_worker, sampler_iter),
total=num_meetings,
desc="Simulating meetings",
):
mixtures.append(mixture)
else:
for mixture in tqdm(
parallel_map(
_simulate_worker,
sampler_iter,
num_jobs=num_jobs,
queue_size=num_jobs * MAX_TASKS_WAITING,
),
total=num_meetings,
desc="Simulating meetings",
):
mixtures.append(mixture)

return CutSet.from_cuts(mixtures)

Expand Down
Loading