Skip to content

Commit

Permalink
Adding log spectrogram (#1094)
Browse files Browse the repository at this point in the history
* Adding log spectrogram feat extractor.

* Removed test.

* Blacked.
  • Loading branch information
Tomiinek authored Jul 17, 2023
1 parent 8abbf9f commit cb0b7d3
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 3 deletions.
2 changes: 2 additions & 0 deletions lhotse/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from .kaldi.extractors import (
Fbank,
FbankConfig,
LogSpectrogram,
LogSpectrogramConfig,
Mfcc,
MfccConfig,
Spectrogram,
Expand Down
11 changes: 10 additions & 1 deletion lhotse/features/kaldi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
from .extractors import Fbank, FbankConfig, Mfcc, MfccConfig
from .extractors import (
Fbank,
FbankConfig,
LogSpectrogram,
LogSpectrogramConfig,
Mfcc,
MfccConfig,
Spectrogram,
SpectrogramConfig,
)
from .layers import Wav2FFT, Wav2LogFilterBank, Wav2LogSpec, Wav2MFCC, Wav2Spec, Wav2Win
113 changes: 112 additions & 1 deletion lhotse/features/kaldi/extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import torch

from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.features.kaldi.layers import Wav2LogFilterBank, Wav2MFCC, Wav2Spec
from lhotse.features.kaldi.layers import (
Wav2LogFilterBank,
Wav2LogSpec,
Wav2MFCC,
Wav2Spec,
)
from lhotse.utils import (
EPSILON,
Seconds,
Expand Down Expand Up @@ -357,6 +362,112 @@ def compute_energy(features: np.ndarray) -> float:
return float(np.sum(features))


@dataclass
class LogSpectrogramConfig:
    """
    Settings for the :class:`LogSpectrogram` feature extractor.

    Every field except ``device`` is passed as a keyword argument to
    ``Wav2LogSpec`` when the extractor is constructed; ``device`` selects
    where that layer is placed.
    """

    # Expected sampling rate of the input audio; ``LogSpectrogram.extract``
    # asserts that the rate it is given matches this value.
    sampling_rate: int = 16000
    # Analysis window length and hop, both expressed in seconds.
    frame_length: Seconds = 0.025
    frame_shift: Seconds = 0.01
    # Remaining front-end options, forwarded to ``Wav2LogSpec`` unchanged.
    round_to_power_of_two: bool = True
    remove_dc_offset: bool = True
    preemph_coeff: float = 0.97
    window_type: str = "povey"
    dither: float = 0.0
    # Enabling this triggers the warning issued in ``__post_init__``.
    snip_edges: bool = False
    energy_floor: float = EPSILON
    raw_energy: bool = True
    use_energy: bool = False
    use_fft_mag: bool = False
    # Device the extractor layer is moved to.
    device: str = "cpu"

    def __post_init__(self):
        """Warn when ``snip_edges`` is enabled."""
        if not self.snip_edges:
            return
        warnings.warn(
            "`snip_edges` is set to True, which may cause issues in duration to num-frames conversion in Lhotse."
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this config to a plain dict via ``asdict_nonull``."""
        serialized = asdict_nonull(self)
        return serialized

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "LogSpectrogramConfig":
        """Build a config from a dict whose keys are field names."""
        config = LogSpectrogramConfig(**data)
        return config


@register_extractor
class LogSpectrogram(FeatureExtractor):
    """
    Log-spectrogram feature extractor built on the ``Wav2LogSpec`` layer.

    The extractor is configured with :class:`LogSpectrogramConfig`. All config
    fields except ``device`` are passed to ``Wav2LogSpec``; the resulting layer
    is moved to ``config.device`` and put in eval mode.
    """

    name = "kaldi-log-spectrogram"
    config_type = LogSpectrogramConfig

    def __init__(self, config: Optional[LogSpectrogramConfig] = None):
        """
        :param config: extractor settings; passed to the base class, which
            provides ``self.config``.
        """
        super().__init__(config=config)
        config_dict = self.config.to_dict()
        # ``device`` is dropped from the layer kwargs; placement is done
        # explicitly through ``.to()`` below.
        config_dict.pop("device")
        self.extractor = Wav2LogSpec(**config_dict).to(self.device).eval()

    @property
    def device(self) -> Union[str, torch.device]:
        """Device on which the extractor layer runs (``config.device``)."""
        return self.config.device

    @property
    def frame_shift(self) -> Seconds:
        """Hop between consecutive frames, in seconds (``config.frame_shift``)."""
        return self.config.frame_shift

    def feature_dim(self, sampling_rate: int) -> int:
        """
        Return the number of frequency bins in each output frame.

        ``LogSpectrogramConfig`` has no ``num_ceps`` field (that belongs to
        the MFCC config), so the dimension is derived from the analysis
        window instead.

        :param sampling_rate: sampling rate of the audio, in Hz. ``extract``
            requires it to equal ``config.sampling_rate``.
        :return: ``fft_size // 2 + 1``, where ``fft_size`` is the window
            length in samples, rounded up to a power of two when
            ``config.round_to_power_of_two`` is set.
        """
        # NOTE(review): assumes Wav2LogSpec emits the one-sided FFT
        # (fft_size // 2 + 1 bins) and that ``use_energy`` does not add a
        # column -- confirm against lhotse/features/kaldi/layers.py.
        window_size = int(self.config.frame_length * sampling_rate)
        if self.config.round_to_power_of_two:
            # Smallest power of two >= window_size (1 for an empty window).
            fft_size = 1 if window_size == 0 else 1 << (window_size - 1).bit_length()
        else:
            fft_size = window_size
        return fft_size // 2 + 1

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Compute log-spectrogram features for a single recording.

        :param samples: audio samples as a numpy array or torch tensor; a 1-D
            input gets a leading dimension added before extraction.
        :param sampling_rate: must equal ``config.sampling_rate``.
        :return: features for the first item of the batch, moved to CPU; a
            numpy array if ``samples`` was not a torch tensor, otherwise a
            torch tensor.
        :raises AssertionError: if ``sampling_rate`` differs from the
            configured one.
        """
        assert sampling_rate == self.config.sampling_rate, (
            f"LogSpectrogram was instantiated for sampling_rate "
            f"{self.config.sampling_rate}, but "
            f"sampling_rate={sampling_rate} was passed to extract(). "
            "Note you can use CutSet/RecordingSet.resample() to change the audio sampling rate."
        )

        # Remember the input container type so the output can mirror it.
        is_numpy = False
        if not isinstance(samples, torch.Tensor):
            samples = torch.from_numpy(samples)
            is_numpy = True

        if samples.ndim == 1:
            samples = samples.unsqueeze(0)

        # Only the first item of the batch dimension is returned.
        feats = self.extractor(samples.to(self.device))[0]

        if is_numpy:
            return feats.cpu().numpy()
        else:
            return feats.cpu()

    def extract_batch(
        self,
        samples: Union[
            np.ndarray, torch.Tensor, Sequence[np.ndarray], Sequence[torch.Tensor]
        ],
        sampling_rate: int,
        lengths: Optional[Union[np.ndarray, torch.Tensor]] = None,
    ) -> Union[np.ndarray, torch.Tensor, List[np.ndarray], List[torch.Tensor]]:
        """
        Compute features for a batch by delegating to ``_extract_batch``.

        :param samples: a single array/tensor or a sequence of them.
        :param sampling_rate: sampling rate of the audio, in Hz.
        :param lengths: optional per-item lengths, forwarded unchanged.
        :return: whatever ``_extract_batch`` returns for these inputs.
        """
        return _extract_batch(
            self.extractor,
            samples,
            sampling_rate,
            frame_shift=self.frame_shift,
            lengths=lengths,
            device=self.device,
        )

    @staticmethod
    def mix(
        features_a: np.ndarray, features_b: np.ndarray, energy_scaling_factor_b: float
    ) -> np.ndarray:
        """
        Combine two feature matrices as ``features_a + factor * features_b``.

        :param features_a: first feature matrix.
        :param features_b: second feature matrix.
        :param energy_scaling_factor_b: multiplier applied to ``features_b``.
        :return: the element-wise weighted sum.
        """
        # NOTE(review): this is a linear-domain sum applied to what the
        # extractor name suggests are log-domain features; verify whether
        # mixing should happen after exponentiation. It cannot be decided
        # here because a static method has no access to ``use_fft_mag``.
        return features_a + energy_scaling_factor_b * features_b

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        """
        Return the sum of all entries of ``features`` as a Python float.

        :param features: feature matrix.
        """
        # NOTE(review): sums the values as-is; for log-domain features this
        # is presumably not a physical energy -- confirm intended semantics.
        return float(np.sum(features))


def _extract_batch(
extractor: FeatureExtractor,
samples: Union[
Expand Down
2 changes: 2 additions & 0 deletions test/cut/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LibrosaFbank,
LibrosaFbankConfig,
LilcomChunkyWriter,
LogSpectrogram,
Mfcc,
MonoCut,
Recording,
Expand Down Expand Up @@ -197,6 +198,7 @@ def is_python_311_or_higher() -> bool:
Fbank,
Mfcc,
Spectrogram,
LogSpectrogram,
TorchaudioFbank,
TorchaudioMfcc,
pytest.param(
Expand Down
6 changes: 5 additions & 1 deletion test/features/test_kaldi_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from lhotse.features.kaldi.extractors import (
Fbank,
FbankConfig,
LogSpectrogram,
LogSpectrogramConfig,
Mfcc,
MfccConfig,
Spectrogram,
Expand Down Expand Up @@ -141,6 +143,7 @@ def test_kaldi_spectrogram_extractor_vs_torchaudio(recording):
lambda: Fbank(FbankConfig(snip_edges=True)),
lambda: Mfcc(MfccConfig(snip_edges=True)),
lambda: Spectrogram(SpectrogramConfig(snip_edges=True)),
lambda: LogSpectrogram(LogSpectrogramConfig(snip_edges=True)),
],
)
def test_kaldi_extractors_snip_edges_warning(extractor_type):
Expand All @@ -149,7 +152,8 @@ def test_kaldi_extractors_snip_edges_warning(extractor_type):


@pytest.mark.parametrize(
"feature_type", ["kaldi-fbank", "kaldi-mfcc", "kaldi-spectrogram"]
"feature_type",
["kaldi-fbank", "kaldi-mfcc", "kaldi-spectrogram", "kaldi-log-spectrogram"],
)
def test_feature_extractor_serialization(feature_type):
fe = create_default_feature_extractor(feature_type)
Expand Down

0 comments on commit cb0b7d3

Please sign in to comment.