Loudness normalization with pyloudnorm #1016

Merged · 5 commits · Apr 6, 2023

Changes from all commits:
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
@@ -51,7 +51,7 @@ jobs:
           ${{ matrix.torch-install-cmd }}
           pip install '.[tests]'
           # Enable some optional tests
-          pip install h5py dill smart_open[http] kaldifeat kaldi_native_io webdataset==0.2.5 s3prl scipy nara_wpe
+          pip install h5py dill smart_open[http] kaldifeat kaldi_native_io webdataset==0.2.5 s3prl scipy nara_wpe pyloudnorm
       - name: Install sph2pipe
         run: |
           lhotse install-sph2pipe # Handle sphere files.
18 changes: 18 additions & 0 deletions lhotse/audio.py
@@ -36,6 +36,7 @@
 from lhotse.augmentation import (
     AudioTransform,
     DereverbWPE,
+    LoudnessNormalization,
     Resample,
     ReverbWithImpulseResponse,
     Speed,

@@ -707,6 +708,23 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording":
             transforms=transforms,
         )
+
+    def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recording":
+        """
+        Return a new ``Recording`` that will lazily apply loudness normalization.
+
+        :param target: The target loudness (in LUFS) to normalize to.
+        :param affix_id: When true, we will modify the ``Recording.id`` field
+            by affixing it with "_ln{target}".
+        :return: a modified copy of the current ``Recording``.
+        """
+        transforms = self.transforms.copy() if self.transforms is not None else []
+        transforms.append(LoudnessNormalization(target=target).to_dict())
Comment (Contributor): `sampling_rate` is missing here.

+        return fastcopy(
+            self,
+            id=f"{self.id}_ln{target}" if affix_id else self.id,
+            transforms=transforms,
+        )
+
     def dereverb_wpe(self, affix_id: bool = True) -> "Recording":
         """
         Return a new ``Recording`` that will lazily apply WPE dereverberation.
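For context, a minimal sketch of how the new method might be used (the file path is hypothetical; note that, per the review comment above, `sampling_rate` is not forwarded to the transform and stays at its default of 16000):

```python
from lhotse import Recording

# Hypothetical input file; any audio format supported by lhotse works.
recording = Recording.from_file("audio/session1.wav")

# Attaches the transform lazily; no audio is processed yet.
normalized = recording.normalize_loudness(target=-23.0, affix_id=True)
print(normalized.id)  # e.g. "session1_ln-23.0"

# The normalization is applied when samples are actually loaded.
samples = normalized.load_audio()
```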
2 changes: 2 additions & 0 deletions lhotse/augmentation/__init__.py
@@ -1,4 +1,6 @@
 from .common import AugmentFn
+from .loudness import LoudnessNormalization
+from .rir import ReverbWithImpulseResponse
 from .torchaudio import *
 from .transform import AudioTransform
 from .utils import FastRandomRIRGenerator, convolve1d
72 changes: 72 additions & 0 deletions lhotse/augmentation/loudness.py
@@ -0,0 +1,72 @@
import warnings
from dataclasses import asdict, dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from lhotse.augmentation.transform import AudioTransform
from lhotse.utils import Seconds, is_module_available


@dataclass
class LoudnessNormalization(AudioTransform):
    """
    Loudness normalization based on pyloudnorm: https://github.com/csteinmetz1/pyloudnorm.
    """

    target: float
    sampling_rate: int = 16000
Comment (Contributor): @pzelasko, should we make sampling_rate a member? For reference, the base class:

    class AudioTransform:
        def __call__(self, samples: np.ndarray, sampling_rate: int) -> np.ndarray:
            """
            Apply transform.

            To be implemented in derived classes.
            """
            raise NotImplementedError

    def __call__(
        self, samples: Union[np.ndarray, torch.Tensor], *args, **kwargs
    ) -> np.ndarray:
        if torch.is_tensor(samples):
            samples = samples.cpu().numpy()
        augmented = normalize_loudness(samples, **asdict(self))
        return augmented

    def reverse_timestamps(
        self, offset: Seconds, duration: Optional[Seconds], sampling_rate: int
    ) -> Tuple[Seconds, Optional[Seconds]]:
        return offset, duration
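As a quick illustration, a minimal sketch of applying the transform above to a synthetic signal (the array values are arbitrary; internally pyloudnorm applies a pure gain of `10 ** ((target - measured_loudness) / 20)`, so only the scale of the waveform changes):

```python
import numpy as np

from lhotse.augmentation import LoudnessNormalization

# Synthetic mono signal: 2 seconds at 16 kHz, deliberately quiet.
audio = 0.01 * np.random.randn(1, 32000).astype(np.float32)

ln = LoudnessNormalization(target=-23.0, sampling_rate=16000)
normalized = ln(audio)

# The transform rescales the waveform without changing its shape.
assert normalized.shape == audio.shape
```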


def normalize_loudness(
    audio: np.ndarray,
    target: float,
    sampling_rate: int = 16000,
) -> np.ndarray:
    """
    Applies pyloudnorm based loudness normalization to the input audio. The input audio
    can have up to 5 channels, with the following order:
    [left, right, center, left_surround, right_surround]

    :param audio: the input audio, expected to be 2D with shape (channels, samples).
    :param target: the target loudness in LUFS.
    :param sampling_rate: the sampling rate of the audio.
    :return: the loudness normalized audio.
    """
    if not is_module_available("pyloudnorm"):
        raise ImportError(
            "Please install pyloudnorm first using 'pip install pyloudnorm'"
        )

    import pyloudnorm as pyln

    assert audio.ndim == 2, f"Expected 2D audio shape, got: {audio.shape}"

    duration = audio.shape[1] / sampling_rate

    # measure the loudness first
    meter = pyln.Meter(
        sampling_rate, block_size=min(0.4, duration)
    )  # create BS.1770 meter
    loudness = meter.integrated_loudness(audio.T)

    # loudness normalize audio to target LUFS. We will ignore the warnings related to
    # clipping the audio.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        loudness_normalized_audio = pyln.normalize.loudness(audio.T, loudness, target)
Comment (Collaborator): If audio clipping is an issue here, you can add a limiter as a post-processing step, e.g. https://github.com/pzelasko/cylimiter

Comment (Collaborator, author): Sure, I can add it. But I don't have enough familiarity with limiters (or "loudness", for that matter) to know what to do exactly.

Comment (Collaborator, @pzelasko, Apr 5, 2023): The defaults should "just work" with pretty much anything; it basically keeps track of the signal's loudness with a small lookahead and reduces the gain if it crosses some threshold. Think of it as soft clipping that doesn't introduce as much distortion as hard clipping.

Comment (Collaborator): ...but again, I don't know if that's a real problem with this approach and worth the extra dependency, so it's your call :)

Comment (Collaborator, author): So far I have only used it to make the LibriCSS distant-mic audio louder (it is at around -52 dB originally), and it sounds okay even with the clipping. I suppose we can let it be for now and add the limiter later if someone needs it?

    return loudness_normalized_audio.T
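To make the limiter suggestion from the thread concrete, here is an illustrative numpy sketch of the "small lookahead, reduce the gain over a threshold" idea described above. This is not cylimiter's API; the function name, threshold, and lookahead values are arbitrary assumptions:

```python
import numpy as np


def limit_peaks(
    audio: np.ndarray, threshold: float = 0.95, lookahead: int = 64
) -> np.ndarray:
    """Toy lookahead limiter (illustrative only, not cylimiter's API).

    :param audio: 2D array of shape (channels, samples).
    """
    num_samples = audio.shape[1]
    gain = np.ones(num_samples)
    for t in range(num_samples):
        # Look ahead at the peak over the next `lookahead` samples.
        peak = np.abs(audio[:, t : t + lookahead]).max()
        if peak > threshold:
            gain[t] = threshold / peak
    # Smooth the gain curve so it does not jump abruptly (moving average).
    gain = np.convolve(gain, np.ones(lookahead) / lookahead, mode="same")
    return audio * np.minimum(gain, 1.0)
```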
143 changes: 143 additions & 0 deletions lhotse/augmentation/rir.py
@@ -0,0 +1,143 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from lhotse.augmentation.transform import AudioTransform
from lhotse.augmentation.utils import FastRandomRIRGenerator, convolve1d
from lhotse.utils import Seconds


@dataclass
class ReverbWithImpulseResponse(AudioTransform):
    """
    Reverberation effect by convolving with a room impulse response.
    This code is based on Kaldi's wav-reverberate utility:
    https://github.com/kaldi-asr/kaldi/blob/master/src/featbin/wav-reverberate.cc

    If no ``rir_recording`` is provided, we will generate an impulse response using a fast
    random generator (https://arxiv.org/abs/2208.04101).

    The impulse response can possibly be multi-channel, in which case multi-channel
    reverberated audio can be obtained by appropriately setting ``rir_channels``. For
    example, ``rir_channels=[0, 1]`` will convolve using the first two channels of the
    impulse response, generating stereo reverberated audio.

    Note that we enforce the --shift-output option in Kaldi's wav-reverberate utility,
    which means that the output length will be equal to the input length.
    """

    rir: Optional[dict] = None
    normalize_output: bool = True
    early_only: bool = False
    rir_channels: List[int] = field(default_factory=lambda: [0])
    rir_generator: Optional[Union[dict, Callable]] = None

    RIR_SCALING_FACTOR: float = 0.5**15

    def __post_init__(self):
        if isinstance(self.rir, dict):
            from lhotse import Recording

            # Pass a shallow copy of the RIR dict since `from_dict()` pops the `sources` key.
            self.rir = Recording.from_dict(self.rir.copy())

        assert (
            self.rir is not None or self.rir_generator is not None
        ), "Either `rir` or `rir_generator` must be provided."

        if self.rir is not None:
            assert all(
                c < self.rir.num_channels for c in self.rir_channels
            ), "Invalid channel index in `rir_channels`"

        if self.rir_generator is not None and isinstance(self.rir_generator, dict):
            self.rir_generator = FastRandomRIRGenerator(**self.rir_generator)

    def __call__(
        self,
        samples: np.ndarray,
        sampling_rate: int,
    ) -> np.ndarray:
        """
        :param samples: The audio samples to reverberate.
        :param sampling_rate: The sampling rate of the audio samples.
        """
        sampling_rate = int(sampling_rate)  # paranoia mode

        D_in, N_in = samples.shape
        input_is_mono = D_in == 1

        # The following cases are possible:
        # Case 1: input is mono, RIR is mono -> mono output.
        #   We will generate a random mono RIR if not provided explicitly.
        # Case 2: input is mono, RIR is multi-channel -> multi-channel output.
        #   This requires a user-provided RIR, since we cannot simulate an array microphone.
        # Case 3: input is multi-channel, RIR is mono -> multi-channel output.
        #   This does not make much sense, but we will apply the same RIR to all channels.
        # Case 4: input is multi-channel, RIR is multi-channel -> multi-channel output.
        #   This also requires a user-provided RIR. Also, the number of channels in the RIR
        #   must match the number of channels in the input.

        # Let us make some assertions based on the above.
        if input_is_mono:
            assert (
                self.rir is not None or len(self.rir_channels) == 1
            ), "For mono input, either provide an RIR explicitly or set rir_channels to [0]."
        else:
            assert len(self.rir_channels) == 1 or len(self.rir_channels) == D_in, (
                "For multi-channel input, we only support mono RIR or RIR with the same number "
                "of channels as the input."
            )

        # Generate a random RIR if not provided.
        if self.rir is None:
            rir_ = self.rir_generator(nsource=1)
        else:
            rir_ = (
                self.rir.load_audio(channels=self.rir_channels)
                if not self.early_only
                else self.rir.load_audio(channels=self.rir_channels, duration=0.05)
            )

        D_rir, N_rir = rir_.shape
        N_out = N_in  # Enforce shift output.
        # The output is multi-channel if either the input or the RIR is multi-channel.
        D_out = D_rir if input_is_mono else D_in

        # If the RIR is mono, repeat it to match the number of channels in the input.
        rir_ = rir_.repeat(D_out, axis=0) if D_rir == 1 else rir_

        # Initialize the output matrix with the specified input channel.
        augmented = np.zeros((D_out, N_out), dtype=samples.dtype)

        for d in range(D_out):
            d_in = 0 if input_is_mono else d
            augmented[d, :N_in] = samples[d_in]
            power_before_reverb = np.sum(np.abs(samples[d_in]) ** 2) / N_in
            rir_d = rir_[d, :] * self.RIR_SCALING_FACTOR

            # Convolve the signal with the impulse response.
            aug_d = convolve1d(
                torch.from_numpy(samples[d_in]), torch.from_numpy(rir_d)
            ).numpy()
            shift_index = np.argmax(rir_d)
            augmented[d, :] = aug_d[shift_index : shift_index + N_out]

            if self.normalize_output:
                power_after_reverb = np.sum(np.abs(augmented[d, :]) ** 2) / N_out
                if power_after_reverb > 0:
                    augmented[d, :] *= np.sqrt(power_before_reverb / power_after_reverb)

        return augmented

    def reverse_timestamps(
        self,
        offset: Seconds,
        duration: Optional[Seconds],
        sampling_rate: Optional[int],  # Not used, kept for compatibility purposes.
    ) -> Tuple[Seconds, Optional[Seconds]]:
        """
        This method just returns the original offset and duration, since we have
        implemented output shifting which preserves these properties.
        """
        return offset, duration
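Finally, a minimal sketch exercising this transform with a randomly generated RIR. It assumes FastRandomRIRGenerator is constructible with all-default parameters (via an empty config dict); the input signal is synthetic:

```python
import numpy as np

from lhotse.augmentation import ReverbWithImpulseResponse

# Synthetic mono input: 1 second at 16 kHz.
audio = 0.1 * np.random.randn(1, 16000).astype(np.float32)

# The empty dict is forwarded as FastRandomRIRGenerator(**{}), which assumes
# the generator has usable defaults for all of its parameters.
transform = ReverbWithImpulseResponse(rir_generator={})
reverberant = transform(audio, sampling_rate=16000)

# --shift-output semantics: the output length equals the input length.
assert reverberant.shape == audio.shape
```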