Skip to content

Batch extraction for kaldi features #947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 19, 2023
11 changes: 9 additions & 2 deletions lhotse/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@
available_storage_backends,
close_cached_file_handles,
)
from .kaldi.extractors import Fbank, FbankConfig, Mfcc, MfccConfig
from .kaldi.extractors import (
Fbank,
FbankConfig,
Mfcc,
MfccConfig,
Spectrogram,
SpectrogramConfig,
)
from .kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
Expand All @@ -38,5 +45,5 @@
from .mfcc import TorchaudioMfcc, TorchaudioMfccConfig
from .mixer import FeatureMixer
from .opensmile import OpenSmileConfig, OpenSmileExtractor
from .spectrogram import Spectrogram, SpectrogramConfig
from .spectrogram import TorchaudioSpectrogram, TorchaudioSpectrogramConfig
from .ssl import S3PRLSSL, S3PRLSSLConfig
220 changes: 213 additions & 7 deletions lhotse/features/kaldi/extractors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import warnings
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.features.kaldi.layers import Wav2LogFilterBank, Wav2MFCC
from lhotse.utils import EPSILON, Seconds, asdict_nonull
from lhotse.features.kaldi.layers import Wav2LogFilterBank, Wav2MFCC, Wav2Spec
from lhotse.utils import (
EPSILON,
Seconds,
asdict_nonull,
compute_num_frames_from_samples,
)


@dataclass
Expand All @@ -29,6 +35,7 @@ class FbankConfig:
num_filters: int = 80
num_mel_bins: Optional[int] = None # do not use
norm_filters: bool = False
device: str = "cpu"

def __post_init__(self):
# This is to help users transition to a different Fbank implementation
Expand All @@ -37,6 +44,11 @@ def __post_init__(self):
self.num_filters = self.num_mel_bins
self.num_mel_bins = None

if self.snip_edges:
warnings.warn(
"`snip_edges` is set to True, which may cause issues in duration to num-frames conversion in Lhotse."
)

def to_dict(self) -> Dict[str, Any]:
return asdict_nonull(self)

Expand All @@ -52,7 +64,13 @@ class Fbank(FeatureExtractor):

def __init__(self, config: Optional[FbankConfig] = None):
super().__init__(config=config)
self.extractor = Wav2LogFilterBank(**self.config.to_dict()).eval()
config_dict = self.config.to_dict()
config_dict.pop("device")
self.extractor = Wav2LogFilterBank(**config_dict).to(self.device).eval()

@property
def device(self) -> Union[str, torch.device]:
return self.config.device

@property
def frame_shift(self) -> Seconds:
Expand All @@ -79,13 +97,24 @@ def extract(
if samples.ndim == 1:
samples = samples.unsqueeze(0)

feats = self.extractor(samples)[0]
feats = self.extractor(samples.to(self.device))[0]

if is_numpy:
return feats.cpu().numpy()
else:
return feats

def extract_batch(
    self,
    samples: Union[
        np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
    ],
    sampling_rate: int,
) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
    """
    Extract log-Mel filter-bank features for several utterances at once.

    Accepts a single array/tensor (1-D or batched 2-D) or a sequence of
    per-utterance arrays/tensors. Delegates to the module-level
    ``_extract_batch`` helper, which pads the inputs to a common length,
    runs one forward pass of ``self.extractor`` on ``self.device``, and
    strips the padding frames again.

    :param samples: audio samples; numpy input yields numpy output and
        torch input yields torch output.
    :param sampling_rate: sampling rate of ``samples`` (used to compute
        the number of valid frames per utterance).
    :return: a single array/tensor when shapes are uniform, otherwise a
        list with one feature matrix per utterance.
    """
    return _extract_batch(
        self.extractor, samples, sampling_rate, device=self.device
    )

@staticmethod
def mix(
features_a: np.ndarray, features_b: np.ndarray, energy_scaling_factor_b: float
Expand Down Expand Up @@ -125,6 +154,7 @@ class MfccConfig:
norm_filters: bool = False
num_ceps: int = 13
cepstral_lifter: int = 22
device: str = "cpu"

def __post_init__(self):
# This is to help users transition to a different Mfcc implementation
Expand All @@ -133,6 +163,11 @@ def __post_init__(self):
self.num_filters = self.num_mel_bins
self.num_mel_bins = None

if self.snip_edges:
warnings.warn(
"`snip_edges` is set to True, which may cause issues in duration to num-frames conversion in Lhotse."
)

def to_dict(self) -> Dict[str, Any]:
return asdict_nonull(self)

Expand All @@ -148,7 +183,13 @@ class Mfcc(FeatureExtractor):

def __init__(self, config: Optional[MfccConfig] = None):
super().__init__(config=config)
self.extractor = Wav2MFCC(**self.config.to_dict()).eval()
config_dict = self.config.to_dict()
config_dict.pop("device")
self.extractor = Wav2MFCC(**config_dict).to(self.device).eval()

@property
def device(self) -> Union[str, torch.device]:
return self.config.device

@property
def frame_shift(self) -> Seconds:
Expand All @@ -175,9 +216,174 @@ def extract(
if samples.ndim == 1:
samples = samples.unsqueeze(0)

feats = self.extractor(samples)[0]
feats = self.extractor(samples.to(self.device))[0]

if is_numpy:
return feats.cpu().numpy()
else:
return feats

def extract_batch(
    self,
    samples: Union[
        np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
    ],
    sampling_rate: int,
) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
    """
    Extract MFCC features for several utterances at once.

    Accepts a single array/tensor (1-D or batched 2-D) or a sequence of
    per-utterance arrays/tensors. Delegates to the module-level
    ``_extract_batch`` helper, which pads the inputs to a common length,
    runs one forward pass of ``self.extractor`` on ``self.device``, and
    strips the padding frames again.

    :param samples: audio samples; numpy input yields numpy output and
        torch input yields torch output.
    :param sampling_rate: sampling rate of ``samples`` (used to compute
        the number of valid frames per utterance).
    :return: a single array/tensor when shapes are uniform, otherwise a
        list with one feature matrix per utterance.
    """
    return _extract_batch(
        self.extractor, samples, sampling_rate, device=self.device
    )


@dataclass
class SpectrogramConfig:
    """
    Configuration for the kaldi-compatible :class:`Spectrogram` extractor.

    Every field except ``device`` is forwarded verbatim to the ``Wav2Spec``
    layer (``Spectrogram.__init__`` pops ``device`` before construction), so
    the exact semantics of each option are defined by that layer.
    """

    # Expected sampling rate of the input audio, in Hz.
    sampling_rate: int = 16000
    # Analysis window length in seconds.
    frame_length: Seconds = 0.025
    # Hop between successive windows in seconds (exposed as
    # Spectrogram.frame_shift).
    frame_shift: Seconds = 0.01
    # Kaldi-style option: round the FFT window size up to a power of two
    # — presumably affects the output feature dimension; see Wav2Spec.
    round_to_power_of_two: bool = True
    remove_dc_offset: bool = True
    # Pre-emphasis filter coefficient (Kaldi default 0.97).
    preemph_coeff: float = 0.97
    window_type: str = "povey"
    # Dithering strength; 0.0 disables dithering (deterministic output).
    dither: float = 0.0
    # True would only emit frames fully inside the signal; kept False because
    # it breaks Lhotse's duration -> num-frames conversion (see __post_init__).
    snip_edges: bool = False
    # Floor value used to avoid log(0) on silent frames.
    energy_floor: float = EPSILON
    raw_energy: bool = True
    use_energy: bool = False
    use_fft_mag: bool = False
    # Torch device string used to place the extractor and inputs
    # (e.g. "cpu", "cuda:0").
    device: str = "cpu"

    def __post_init__(self):
        # Lhotse relies on a fixed relationship between duration and frame
        # count, which snip_edges=True would break — warn loudly.
        if self.snip_edges:
            warnings.warn(
                "`snip_edges` is set to True, which may cause issues in duration to num-frames conversion in Lhotse."
            )

    def to_dict(self) -> Dict[str, Any]:
        # asdict_nonull drops None-valued fields from the serialized config.
        return asdict_nonull(self)

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "SpectrogramConfig":
        return SpectrogramConfig(**data)


@register_extractor
class Spectrogram(FeatureExtractor):
    """
    Kaldi-compatible spectrogram feature extractor implemented with the
    ``Wav2Spec`` torch layer. Supports extraction on an arbitrary device
    (see ``SpectrogramConfig.device``).
    """

    name = "kaldi-spectrogram"
    config_type = SpectrogramConfig

    def __init__(self, config: Optional[SpectrogramConfig] = None):
        super().__init__(config=config)
        # `device` is an extractor-level concern, not a Wav2Spec kwarg —
        # remove it before constructing the layer.
        config_dict = self.config.to_dict()
        config_dict.pop("device")
        self.extractor = Wav2Spec(**config_dict).to(self.device).eval()

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    @property
    def frame_shift(self) -> Seconds:
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        """
        Return the per-frame feature dimension (number of FFT bins).

        BUG FIX: the previous implementation returned ``self.config.num_ceps``,
        but ``num_ceps`` is an MFCC parameter that does not exist on
        :class:`SpectrogramConfig`, so this method always raised
        ``AttributeError``. A real FFT over a window of ``fft_size`` samples
        yields ``fft_size // 2 + 1`` frequency bins; Kaldi optionally rounds
        the window size up to the next power of two first.
        NOTE(review): assumes ``Wav2Spec`` follows this convention — confirm
        against ``lhotse.features.kaldi.layers.Wav2Spec``.
        """
        window_size = int(self.config.frame_length * sampling_rate)
        if self.config.round_to_power_of_two and window_size > 0:
            # Next power of two >= window_size.
            fft_size = 1 << (window_size - 1).bit_length()
        else:
            fft_size = window_size
        return fft_size // 2 + 1

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Compute spectrogram features for a single utterance.

        :param samples: audio of shape ``(num_samples,)`` or
            ``(1, num_samples)``; numpy input yields numpy output, torch
            input yields torch (CPU) output.
        :param sampling_rate: must match ``config.sampling_rate``.
        :return: a feature matrix of shape ``(num_frames, feature_dim)``.
        """
        assert sampling_rate == self.config.sampling_rate, (
            f"Spectrogram was instantiated for sampling_rate "
            f"{self.config.sampling_rate}, but "
            f"sampling_rate={sampling_rate} was passed to extract(). "
            "Note you can use CutSet/RecordingSet.resample() to change the audio sampling rate."
        )

        is_numpy = False
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
            is_numpy = True

        if samples.ndim == 1:
            # Wav2Spec expects a (batch, num_samples) input.
            samples = samples.unsqueeze(0)

        feats = self.extractor(samples.to(self.device))[0]

        if is_numpy:
            return feats.cpu().numpy()
        else:
            return feats.cpu()

    def extract_batch(
        self,
        samples: Union[
            np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
        ],
        sampling_rate: int,
    ) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
        """
        Extract features for several utterances via a single padded forward
        pass; see the module-level ``_extract_batch`` helper for details.
        """
        return _extract_batch(
            self.extractor, samples, sampling_rate, device=self.device
        )

    @staticmethod
    def mix(
        features_a: np.ndarray, features_b: np.ndarray, energy_scaling_factor_b: float
    ) -> np.ndarray:
        # Mixing happens in the linear-energy domain, hence exp().
        # NOTE(review): unlike typical log-domain mixers, no log() is taken
        # on the sum, so the result is linear-domain — confirm this matches
        # how callers consume mixed spectrogram features.
        return np.exp(features_a) + energy_scaling_factor_b * np.exp(features_b)


def _extract_batch(
    extractor: FeatureExtractor,
    samples: Union[
        np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
    ],
    sampling_rate: int,
    device: Union[str, torch.device] = "cpu",
) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
    """
    Shared batched-extraction helper for the kaldi-layer extractors.

    Normalizes the input to a list of 1-D tensors, pads them to a common
    length, runs one forward pass of ``extractor`` on ``device``, and strips
    the padding frames using the expected frame count per utterance.

    :param extractor: a callable feature-extraction layer; its
        ``frame_shift`` attribute is used to compute per-utterance frame
        counts.
    :param samples: a single 1-D array/tensor, a batched 2-D array/tensor,
        or a sequence of per-utterance arrays/tensors (mixed types allowed).
    :param sampling_rate: sampling rate of all inputs.
    :return: mirrors the input container — numpy in -> numpy out, torch
        in -> torch out; a single item is returned bare (unless a list was
        passed); uniform shapes are stacked into one batched array/tensor;
        otherwise a list is returned.
    """
    input_is_list = False

    if isinstance(samples, list):
        input_is_list = True
        pass  # nothing to do with `samples`
    elif samples.ndim > 1:
        # A (batch, num_samples) array/tensor: split into per-utterance rows.
        samples = list(samples)
    else:
        # The user passed an array/tensor of shape (num_samples,)
        samples = [samples.reshape(1, -1)]

    # Remember whether the caller worked in torch so the output matches.
    input_is_torch = any(isinstance(x, torch.Tensor) for x in samples)

    # BUG FIX: previously, if *any* element was a tensor, *every* element was
    # converted with `x.numpy()`, which crashes on mixed lists of ndarrays and
    # tensors (ndarray has no `.numpy()`), and also fails for CUDA or
    # grad-requiring tensors. Convert each element according to its own type;
    # tensors are detached and moved to CPU so padding below is uniform.
    samples = [
        (
            x.detach().cpu() if isinstance(x, torch.Tensor) else torch.from_numpy(x)
        ).squeeze()
        for x in samples
    ]

    # Number of valid (non-padding) frames expected for each utterance.
    feat_lens = [
        compute_num_frames_from_samples(
            num_samples=len(x),
            frame_shift=extractor.frame_shift,
            sampling_rate=sampling_rate,
        )
        for x in samples
    ]

    # Pad to a common length, extract in one pass, then strip padding frames.
    batch = torch.nn.utils.rnn.pad_sequence(samples, batch_first=True)
    feats = extractor(batch.to(device)).cpu()
    result = [feat[:feat_len] for feat, feat_len in zip(feats, feat_lens)]

    if not input_is_torch:
        result = [x.numpy() for x in result]

    if len(result) == 1:
        # Single utterance: return it bare unless the caller passed a list.
        return result if input_is_list else result[0]
    elif all(item.shape == result[0].shape for item in result[1:]):
        # All items share one shape: stack into a single batched output.
        if input_is_torch:
            return torch.stack(result, dim=0)
        else:
            return np.stack(result, axis=0)
    else:
        return result
10 changes: 5 additions & 5 deletions lhotse/features/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@dataclass
class SpectrogramConfig:
class TorchaudioSpectrogramConfig:
# Note that `snip_edges` parameter is missing from config: in order to simplify the relationship between
# the duration and the number of frames, we are always setting `snip_edges` to False.
dither: float = 0.0
Expand All @@ -27,16 +27,16 @@ def to_dict(self) -> Dict[str, Any]:
return asdict(self)

@staticmethod
def from_dict(data: Dict[str, Any]) -> "SpectrogramConfig":
return SpectrogramConfig(**data)
def from_dict(data: Dict[str, Any]) -> "TorchaudioSpectrogramConfig":
return TorchaudioSpectrogramConfig(**data)


@register_extractor
class Spectrogram(TorchaudioFeatureExtractor):
class TorchaudioSpectrogram(TorchaudioFeatureExtractor):
"""Log spectrogram feature extractor based on ``torchaudio.compliance.kaldi.spectrogram`` function."""

name = "spectrogram"
config_type = SpectrogramConfig
config_type = TorchaudioSpectrogramConfig

def _feature_fn(self, *args, **kwargs):
from torchaudio.compliance.kaldi import spectrogram
Expand Down
Loading