Batch extraction for kaldi features #947

Merged: 15 commits, Jan 19, 2023
135 changes: 74 additions & 61 deletions lhotse/cut/set.py
@@ -1983,6 +1983,7 @@ def compute_and_store_features_batch(
manifest_path: Optional[Pathlike] = None,
batch_duration: Seconds = 600.0,
num_workers: int = 4,
collate: bool = False,
augment_fn: Optional[AugmentFn] = None,
storage_type: Type[FW] = LilcomChunkyWriter,
overwrite: bool = False,
@@ -2024,6 +2025,10 @@ def compute_and_store_features_batch(
Determines batch size dynamically.
:param num_workers: How many background dataloading workers should be used
for reading the audio.
:param collate: If ``True``, the waveforms will be collated into a single
padded tensor before being passed to the feature extractor. Some extractors
can be faster this way (see, e.g., ``lhotse.features.kaldi.extractors``).
If you are using ``kaldifeat`` extractors, you should set this to ``False``.
:param augment_fn: an optional callable used for audio augmentation.
Be careful with the types of augmentations used: if they modify
the start/end/duration times of the cut and its supervisions,
@@ -2036,10 +2041,12 @@ def compute_and_store_features_batch(
By default, this method will append to these files if they exist.
:return: Returns a new ``CutSet`` with ``Features`` manifests attached to the cuts.
"""
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.utils.data import DataLoader

from lhotse.dataset import SingleCutSampler, UnsupervisedWaveformDataset
from lhotse.dataset import SimpleCutSampler, UnsupervisedWaveformDataset
from lhotse.qa import validate_features

frame_shift = extractor.frame_shift
@@ -2052,23 +2059,84 @@ def compute_and_store_features_batch(

# We tell the sampler to ignore cuts that were already processed.
# It will avoid I/O for reading them in the DataLoader.
sampler = SingleCutSampler(self, max_duration=batch_duration)
sampler = SimpleCutSampler(self, max_duration=batch_duration)
sampler.filter(lambda cut: cut.id not in cuts_writer.ignore_ids)
dataset = UnsupervisedWaveformDataset(collate=False)
dataset = UnsupervisedWaveformDataset(collate=collate)
dloader = DataLoader(
dataset, batch_size=None, sampler=sampler, num_workers=num_workers
)

# Background worker to save features to disk.
def _save_worker(cuts: List[Cut], features: List[np.ndarray]) -> None:
for cut, feat_mat in zip(cuts, features):
if isinstance(cut, PaddingCut):
# For padding cuts, just fill out the fields in the manifest
# and don't store anything.
cuts_writer.write(
fastcopy(
cut,
num_frames=feat_mat.shape[0],
num_features=feat_mat.shape[1],
frame_shift=frame_shift,
)
)
continue
# Store the computed features and describe them in a manifest.
if isinstance(feat_mat, torch.Tensor):
feat_mat = feat_mat.cpu().numpy()
storage_key = feats_writer.write(cut.id, feat_mat)
feat_manifest = Features(
start=cut.start,
duration=cut.duration,
type=extractor.name,
num_frames=feat_mat.shape[0],
num_features=feat_mat.shape[1],
frame_shift=frame_shift,
sampling_rate=cut.sampling_rate,
channels=cut.channel,
storage_type=feats_writer.name,
storage_path=str(feats_writer.storage_path),
storage_key=storage_key,
)
validate_features(feat_manifest, feats_data=feat_mat)

# Update the cut manifest.
if isinstance(cut, DataCut):
feat_manifest.recording_id = cut.recording_id
cut = fastcopy(cut, features=feat_manifest)
if isinstance(cut, MixedCut):
# If this was a mixed cut, we will just discard its
# recordings and create a new mono cut that has just
# the features attached.
feat_manifest.recording_id = cut.id
cut = MonoCut(
id=cut.id,
start=0,
duration=cut.duration,
channel=0,
# Update supervisions recording_id for consistency
supervisions=[
fastcopy(s, recording_id=cut.id) for s in cut.supervisions
],
features=feat_manifest,
recording=None,
)
cuts_writer.write(cut, flush=True)

futures = []
with cuts_writer, storage_type(
storage_path, mode="w" if overwrite else "a"
) as feats_writer, tqdm(
desc="Computing features in batches", total=sampler.num_cuts
) as progress:
) as progress, ThreadPoolExecutor(
max_workers=1 # We only want one background worker so that serialization is deterministic.
) as executor:
# Display progress bar correctly.
progress.update(len(cuts_writer.ignore_ids))
for batch in dloader:
cuts = batch["cuts"]
waves = batch["audio"]
wave_lens = batch["audio_lens"] if collate else None

if len(cuts) == 0:
# Fault-tolerant audio loading filtered out everything.
@@ -2087,65 +2155,10 @@ def compute_and_store_features_batch(
# Note: chunk_size option limits the memory consumption
# for very long cuts.
features = extractor.extract_batch(
waves, sampling_rate=cuts[0].sampling_rate
)

for cut, feat_mat in zip(cuts, features):
if isinstance(cut, PaddingCut):
# For padding cuts, just fill out the fields in the manifest
# and don't store anything.
cuts_writer.write(
fastcopy(
cut,
num_frames=feat_mat.shape[0],
num_features=feat_mat.shape[1],
frame_shift=frame_shift,
)
)
continue
# Store the computed features and describe them in a manifest.
if isinstance(feat_mat, torch.Tensor):
feat_mat = feat_mat.cpu().numpy()
storage_key = feats_writer.write(cut.id, feat_mat)
feat_manifest = Features(
start=cut.start,
duration=cut.duration,
type=extractor.name,
num_frames=feat_mat.shape[0],
num_features=feat_mat.shape[1],
frame_shift=frame_shift,
sampling_rate=cut.sampling_rate,
channels=cut.channel,
storage_type=feats_writer.name,
storage_path=str(feats_writer.storage_path),
storage_key=storage_key,
waves, sampling_rate=cuts[0].sampling_rate, lengths=wave_lens
)
validate_features(feat_manifest, feats_data=feat_mat)

# Update the cut manifest.
if isinstance(cut, DataCut):
feat_manifest.recording_id = cut.recording_id
cut = fastcopy(cut, features=feat_manifest)
if isinstance(cut, MixedCut):
# If this was a mixed cut, we will just discard its
# recordings and create a new mono cut that has just
# the features attached.
feat_manifest.recording_id = cut.id
cut = MonoCut(
id=cut.id,
start=0,
duration=cut.duration,
channel=0,
# Update supervisions recording_id for consistency
supervisions=[
fastcopy(s, recording_id=cut.id)
for s in cut.supervisions
],
features=feat_manifest,
recording=None,
)
cuts_writer.write(cut, flush=True)

futures.append(executor.submit(_save_worker, cuts, features))
progress.update(len(cuts))

Collaborator Author:
Something I find weird is that this works even though I have not called future.result() anywhere?

Collaborator:
When the executor is used as a context manager, it blocks execution on __exit__ until all threads have finished (it calls .join()).

Collaborator Author:
Makes sense, thanks!
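
For reference, a minimal standalone sketch of that behavior (not part of the PR): the executor's __exit__ calls shutdown(wait=True), so the with-block does not exit until every submitted task has run. Note, however, that without calling future.result(), any exception raised inside the worker is silently swallowed.

from concurrent.futures import ThreadPoolExecutor
import time

results = []

def work(i):
    time.sleep(0.1)
    results.append(i)

with ThreadPoolExecutor(max_workers=1) as executor:
    for i in range(3):
        executor.submit(work, i)
# shutdown(wait=True) has run by this point: all three tasks finished.
assert results == [0, 1, 2]  # max_workers=1 preserves submission order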

# If ``manifest_path`` was provided, this is a lazy manifest;
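A minimal usage sketch for the new ``collate`` flag (the manifest and storage paths here are hypothetical, for illustration only):

from lhotse import CutSet
from lhotse.features.kaldi.extractors import Fbank

cuts = CutSet.from_file("data/cuts.jsonl.gz")  # hypothetical manifest path

# With collate=True, each sampled batch is padded into one (batch, samples)
# tensor, letting the Kaldi-style extractors process it in a single call.
cuts = cuts.compute_and_store_features_batch(
    extractor=Fbank(),
    storage_path="data/feats",  # hypothetical storage directory
    batch_duration=600.0,
    num_workers=4,
    collate=True,  # keep False for kaldifeat extractors
)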
11 changes: 9 additions & 2 deletions lhotse/features/__init__.py
@@ -27,7 +27,14 @@
available_storage_backends,
close_cached_file_handles,
)
from .kaldi.extractors import Fbank, FbankConfig, Mfcc, MfccConfig
from .kaldi.extractors import (
Fbank,
FbankConfig,
Mfcc,
MfccConfig,
Spectrogram,
SpectrogramConfig,
)
from .kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
@@ -38,5 +45,5 @@
from .mfcc import TorchaudioMfcc, TorchaudioMfccConfig
from .mixer import FeatureMixer
from .opensmile import OpenSmileConfig, OpenSmileExtractor
from .spectrogram import Spectrogram, SpectrogramConfig
from .spectrogram import TorchaudioSpectrogram, TorchaudioSpectrogramConfig
from .ssl import S3PRLSSL, S3PRLSSLConfig
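
Note that after this change, ``Spectrogram`` and ``SpectrogramConfig`` exported from ``lhotse.features`` refer to the Kaldi-style implementation, while the torchaudio-based extractor is re-exported under new names. A short sketch of the resulting import paths (assuming only the re-exports shown in this diff):

# Kaldi-style extractors, which support the new batched extraction path:
from lhotse.features import Fbank, FbankConfig, Spectrogram, SpectrogramConfig

# The torchaudio-based spectrogram extractor now lives under a new name:
from lhotse.features import TorchaudioSpectrogram, TorchaudioSpectrogramConfig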
40 changes: 28 additions & 12 deletions lhotse/features/base.py
@@ -24,6 +24,7 @@
Seconds,
asdict_nonull,
compute_num_frames,
compute_num_frames_from_samples,
exactly_one_not_null,
fastcopy,
ifnone,
@@ -139,12 +140,15 @@ def extract_batch(
np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
],
sampling_rate: int,
lengths: Optional[Union[np.ndarray, torch.Tensor]] = None,
) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
"""
Performs batch extraction. It is not guaranteed to be faster
than :meth:`FeatureExtractor.extract` -- it depends on whether
the implementation of a particular feature extractor supports
accelerated batch computation.
accelerated batch computation. If ``lengths`` is provided, it is
assumed that the input is a batch of padded sequences, so we will
not perform any further collation.

.. note::
Unless overridden by child classes, it defaults to sequentially
@@ -156,22 +160,34 @@ def extract_batch(
input_is_list = False
input_is_torch = False

if isinstance(samples, list):
input_is_list = True
pass # nothing to do with `samples`
elif samples.ndim > 1:
samples = list(samples)
if lengths is not None:
feat_lens = [
compute_num_frames_from_samples(l, self.frame_shift, sampling_rate)
for l in lengths
]
assert isinstance(
samples, torch.Tensor
), "If `lengths` is provided, `samples` must be a batched and padded torch.Tensor."
else:
# The user passed an array/tensor of shape (num_samples,)
samples = [samples.reshape(1, -1)]
if isinstance(samples, list):
input_is_list = True
pass # nothing to do with `samples`
elif samples.ndim > 1:
samples = list(samples)
else:
# The user passed an array/tensor of shape (num_samples,)
samples = [samples.reshape(1, -1)]

if any(isinstance(x, torch.Tensor) for x in samples):
samples = [x.numpy() for x in samples]
input_is_torch = True
if any(isinstance(x, torch.Tensor) for x in samples):
samples = [x.numpy() for x in samples]
input_is_torch = True

result = []
for item in samples:
result.append(self.extract(item, sampling_rate=sampling_rate))
res = self.extract(item, sampling_rate=sampling_rate)
if lengths is not None:
res = res[: feat_lens[len(result)]]
result.append(res)

if input_is_torch:
result = [torch.from_numpy(x) for x in result]
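To illustrate the new ``lengths`` argument to ``extract_batch``, a sketch under the assumption that the Kaldi-style ``Fbank`` honors it as added in this PR: with a 16 kHz input and the default 10 ms frame shift, ``compute_num_frames_from_samples`` maps 16000 samples to roughly 100 frames and 8000 samples to roughly 50, so each padded output is trimmed back to its cut's true length.

import torch
from lhotse.features.kaldi.extractors import Fbank

extractor = Fbank()  # default frame_shift == 0.01 (10 ms)

# Two waveforms padded to a common length of 16000 samples (1 s at 16 kHz);
# the second one is really only 8000 samples (0.5 s) long.
waves = torch.zeros(2, 16000)
lengths = torch.tensor([16000, 8000])

feats = extractor.extract_batch(waves, sampling_rate=16000, lengths=lengths)
# feats[0] has ~100 frames; feats[1] is trimmed back to ~50 frames.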