Skip to content

Commit

Permalink
[TTS][ASR] customize arguments for trimming the leading/trailing sile…
Browse files Browse the repository at this point in the history
…nce (NVIDIA#4582)

[TTS][ASR] enabled overriding arguments for trimming the leading and trailing silence using librosa.effects.trim

Signed-off-by: Xuesong Yang <[email protected]>
Signed-off-by: David Mosallanezhad <[email protected]>
  • Loading branch information
XuesongYang authored and Davood-M committed Aug 9, 2022
1 parent c736067 commit 6992abc
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 7 deletions.
18 changes: 17 additions & 1 deletion nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import random

import librosa
import numpy as np
import torch
import torch.nn as nn

Expand Down Expand Up @@ -100,14 +101,29 @@ def __init__(self, sample_rate=16000, int_values=False, augmentor=None):
def max_augmentation_length(self, length):
return self.augmentor.max_augmentation_length(length)

def process(self, file_path, offset=0, duration=0, trim=False, orig_sr=None):
def process(
self,
file_path,
offset=0,
duration=0,
trim=False,
trim_ref=np.max,
trim_top_db=60,
trim_frame_length=2048,
trim_hop_length=512,
orig_sr=None,
):
audio = AudioSegment.from_file(
file_path,
target_sr=self.sample_rate,
int_values=self.int_values,
offset=offset,
duration=duration,
trim=trim,
trim_ref=trim_ref,
trim_top_db=trim_top_db,
trim_frame_length=trim_frame_length,
trim_hop_length=trim_hop_length,
orig_sr=orig_sr,
)
return self.process_segment(audio)
Expand Down
49 changes: 45 additions & 4 deletions nemo/collections/asr/parts/preprocessing/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,18 @@ class AudioSegment(object):
:raises TypeError: If the sample data type is not float or int.
"""

def __init__(self, samples, sample_rate, target_sr=None, trim=False, trim_db=60, orig_sr=None):
def __init__(
self,
samples,
sample_rate,
target_sr=None,
trim=False,
trim_ref=np.max,
trim_top_db=60,
trim_frame_length=2048,
trim_hop_length=512,
orig_sr=None,
):
"""Create audio segment from samples.
Samples are convert float32 internally, with int scaled to [-1, 1].
"""
Expand All @@ -73,7 +84,9 @@ def __init__(self, samples, sample_rate, target_sr=None, trim=False, trim_db=60,
samples = librosa.core.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
sample_rate = target_sr
if trim:
samples, _ = librosa.effects.trim(samples, top_db=trim_db)
samples, _ = librosa.effects.trim(
samples, top_db=trim_top_db, ref=trim_ref, frame_length=trim_frame_length, hop_length=trim_hop_length
)
self._samples = samples
self._sample_rate = sample_rate
if self._samples.ndim >= 2:
Expand Down Expand Up @@ -125,7 +138,18 @@ def _convert_samples_to_float32(samples):

@classmethod
def from_file(
cls, audio_file, target_sr=None, int_values=False, offset=0, duration=0, trim=False, orig_sr=None,
cls,
audio_file,
target_sr=None,
int_values=False,
offset=0,
duration=0,
trim=False,
trim_ref=np.max,
trim_top_db=60,
trim_frame_length=2048,
trim_hop_length=512,
orig_sr=None,
):
"""
Load a file supported by librosa and return as an AudioSegment.
Expand All @@ -134,6 +158,13 @@ def from_file(
:param int_values: if true, load samples as 32-bit integers
:param offset: offset in seconds when loading audio
:param duration: duration in seconds when loading audio
:param trim: if true, trim leading and trailing silence from an audio signal
:param trim_ref: the reference amplitude. By default, it uses `np.max` and compares to the peak amplitude in
the signal
:param trim_top_db: the threshold (in decibels) below reference to consider as silence
:param trim_frame_length: the number of samples per analysis frame
:param trim_hop_length: the number of samples between analysis frames
:param orig_sr: the original sample rate
:return: numpy array of samples
"""
samples = None
Expand Down Expand Up @@ -174,7 +205,17 @@ def from_file(
libs = "soundfile, and pydub" if HAVE_PYDUB else "soundfile"
raise Exception(f"Your audio file {audio_file} could not be decoded. We tried using {libs}.")

return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
return cls(
samples,
sample_rate,
target_sr=target_sr,
trim=trim,
trim_ref=trim_ref,
trim_top_db=trim_top_db,
trim_frame_length=trim_frame_length,
trim_hop_length=trim_hop_length,
orig_sr=orig_sr,
)

@classmethod
def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None):
Expand Down
26 changes: 24 additions & 2 deletions nemo/collections/tts/torch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def __init__(
min_duration: Optional[float] = None,
ignore_file: Optional[Union[str, Path]] = None,
trim: bool = False,
trim_ref: Optional[float] = None,
trim_top_db: Optional[int] = None,
trim_frame_length: Optional[int] = None,
trim_hop_length: Optional[int] = None,
n_fft: int = 1024,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
Expand Down Expand Up @@ -119,7 +123,14 @@ def __init__(
audio to compute duration. Defaults to None which does not prune.
ignore_file (Optional[Union[str, Path]]): The location of a pickle-saved list of audio paths
that will be pruned prior to training. Defaults to None which does not prune.
trim (Optional[bool]): Whether to apply librosa.effects.trim to the audio file. Defaults to False.
trim (bool): Whether to apply `librosa.effects.trim` to trim leading and trailing silence from an audio
signal. Defaults to False.
trim_ref (Optional[float]): the reference amplitude. By default, it uses `np.max` and compares to the peak
amplitude in the signal.
trim_top_db (Optional[int]): the threshold (in decibels) below reference to consider as silence.
Defaults to 60.
trim_frame_length (Optional[int]): the number of samples per analysis frame. Defaults to 2048.
trim_hop_length (Optional[int]): the number of samples between analysis frames. Defaults to 512.
n_fft (int): The number of fft samples. Defaults to 1024
win_length (Optional[int]): The length of the stft windows. Defaults to None which uses n_fft.
hop_length (Optional[int]): The hope length between fft computations. Defaults to None which uses n_fft//4.
Expand Down Expand Up @@ -229,6 +240,10 @@ def __init__(
self.sample_rate = sample_rate
self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate)
self.trim = trim
self.trim_ref = trim_ref if trim_ref is not None else np.max
self.trim_top_db = trim_top_db if trim_top_db is not None else 60
self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048
self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512

self.n_fft = n_fft
self.n_mels = n_mels
Expand Down Expand Up @@ -438,7 +453,14 @@ def __getitem__(self, index):
rel_audio_path_as_text_id += "_phoneme"

# Load audio
features = self.featurizer.process(sample["audio_filepath"], trim=self.trim)
features = self.featurizer.process(
sample["audio_filepath"],
trim=self.trim,
trim_ref=self.trim_ref,
trim_top_db=self.trim_top_db,
trim_frame_length=self.trim_frame_length,
trim_hop_length=self.trim_hop_length,
)
audio, audio_length = features, torch.tensor(features.shape[0]).long()

if "text_tokens" in sample:
Expand Down

0 comments on commit 6992abc

Please sign in to comment.