Skip to content

Commit

Permalink
AudioToAudio datasets and related test
Browse files Browse the repository at this point in the history
Signed-off-by: Ante Jukić <[email protected]>
  • Loading branch information
anteju committed Nov 15, 2022
1 parent 5665f14 commit 7ea9b7b
Show file tree
Hide file tree
Showing 9 changed files with 2,306 additions and 12 deletions.
932 changes: 932 additions & 0 deletions nemo/collections/asr/data/audio_to_audio.py

Large diffs are not rendered by default.

91 changes: 91 additions & 0 deletions nemo/collections/asr/data/audio_to_audio_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.asr.data import audio_to_audio


def get_audio_to_target_dataset(config: dict) -> audio_to_audio.AudioToTargetDataset:
"""Instantiates an audio-to-audio dataset.
Args:
config: Config of AudioToTargetDataset.
Returns:
An instance of AudioToTargetDataset
"""
dataset = audio_to_audio.AudioToTargetDataset(
manifest_filepath=config['manifest_filepath'],
sample_rate=config['sample_rate'],
input_key=config['input_key'],
target_key=config['target_key'],
audio_duration=config.get('audio_duration', None),
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
input_channel_selector=config.get('input_channel_selector', None),
target_channel_selector=config.get('target_channel_selector', None),
)
return dataset


def get_audio_to_target_with_reference_dataset(config: dict) -> audio_to_audio.AudioToTargetWithReferenceDataset:
"""Instantiates an audio-to-audio dataset.
Args:
config: Config of AudioToTargetWithReferenceDataset.
Returns:
An instance of AudioToTargetWithReferenceDataset
"""
dataset = audio_to_audio.AudioToTargetWithReferenceDataset(
manifest_filepath=config['manifest_filepath'],
sample_rate=config['sample_rate'],
input_key=config['input_key'],
target_key=config['target_key'],
reference_key=config['reference_key'],
audio_duration=config.get('audio_duration', None),
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
input_channel_selector=config.get('input_channel_selector', None),
target_channel_selector=config.get('target_channel_selector', None),
reference_channel_selector=config.get('reference_channel_selector', None),
reference_is_synchronized=config.get('reference_is_synchronized', True),
)
return dataset


def get_audio_to_target_with_embedding_dataset(config: dict) -> audio_to_audio.AudioToTargetWithEmbeddingDataset:
"""Instantiates an audio-to-audio dataset.
Args:
config: Config of AudioToTargetWithEmbeddingDataset.
Returns:
An instance of AudioToTargetWithEmbeddingDataset
"""
dataset = audio_to_audio.AudioToTargetWithEmbeddingDataset(
manifest_filepath=config['manifest_filepath'],
sample_rate=config['sample_rate'],
input_key=config['input_key'],
target_key=config['target_key'],
embedding_key=config['embedding_key'],
audio_duration=config.get('audio_duration', None),
max_duration=config.get('max_duration', None),
min_duration=config.get('min_duration', None),
max_utts=config.get('max_utts', 0),
input_channel_selector=config.get('input_channel_selector', None),
target_channel_selector=config.get('target_channel_selector', None),
)
return dataset
32 changes: 24 additions & 8 deletions nemo/collections/asr/parts/preprocessing/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,22 @@ def from_file(

@classmethod
def segment_from_file(
cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None,
cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None, channel_selector=None, offset=None
):
"""Grabs n_segments number of samples from audio_file randomly from the
file as opposed to at a specified offset.
"""Grabs n_segments number of samples from audio_file.
If offset is not provided, n_segments are selected randomly.
If offset is provided, it is used to calculate the starting sample.
Note that audio_file can be either the file path, or a file-like object.
:param audio_file: path to a file or a file-like object
:param target_sr: sample rate for the output samples
:param n_segments: desired number of samples
:param trim: if true, trim leading and trailing silence from an audio signal
:param orig_sr: the original sample rate
:param channel selector: select a subset of channels. If set to `None`, the original signal will be used.
:param offset: fixed offset in seconds
:return: numpy array of samples
"""
is_segmented = False
try:
Expand All @@ -275,13 +285,20 @@ def segment_from_file(

if 0 < n_segments_at_original_sr < len(f):
max_audio_start = len(f) - n_segments_at_original_sr
audio_start = random.randint(0, max_audio_start)
if offset is None:
audio_start = random.randint(0, max_audio_start)
else:
audio_start = math.floor(offset * sample_rate)
if audio_start > max_audio_start:
raise RuntimeError(
f'Provided audio start ({audio_start_seconds} seconds = {audio_start} samples) is larger than the maximum possible ({max_audio_start})'
)
f.seek(audio_start)
samples = f.read(n_segments_at_original_sr, dtype='float32')
is_segmented = True
elif n_segments_at_original_sr >= len(f):
elif n_segments_at_original_sr > len(f):
logging.warning(
f"Number of segments is greater than the length of the audio file {audio_file}. This may lead to shape mismatch errors."
f"Number of segments ({n_segments_at_original_sr}) is greater than the length ({len(f)}) of the audio file {audio_file}. This may lead to shape mismatch errors."
)
samples = f.read(dtype='float32')
else:
Expand Down Expand Up @@ -363,8 +380,7 @@ def subsegment(self, start_time=None, end_time=None):
:param end_time: End of subsegment in seconds.
:type end_time: float
:raise ValueError: If start_time or end_time is incorrectly set,
e.g. out
of bounds in time.
e.g. out of bounds in time.
"""
start_time = 0.0 if start_time is None else start_time
end_time = self.duration if end_time is None else end_time
Expand Down
20 changes: 20 additions & 0 deletions nemo/collections/asr/parts/utils/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import librosa
import numpy as np
import numpy.typing as npt
import scipy
import soundfile as sf
from scipy.spatial.distance import pdist, squareform

Expand Down Expand Up @@ -378,3 +379,22 @@ def pow2db(power: float, eps: Optional[float] = 1e-16) -> float:
Power in dB.
"""
return 10 * np.log10(power + eps)


def get_segment_start(signal: np.ndarray, segment: np.ndarray) -> int:
"""Get starting point of `segment` in `signal`.
We assume that `segment` is a part of `signal`.
Args:
signal: numpy array with shape (num_samples,)
segment: numpy array with shape (num_samples,)
Returns:
Index of the start of `segment` in `signal`.
"""
if len(signal) <= len(segment):
raise ValueError(
f'segment must be shorter than signal: len(segment) = {len(segment)}, len(signal) = {len(signal)}'
)
cc = scipy.signal.correlate(signal, segment, mode='valid')
return np.argmax(cc)
196 changes: 195 additions & 1 deletion nemo/collections/common/parts/preprocessing/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import json
import os
from itertools import combinations
from typing import Any, Dict, List, Optional, Union
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas as pd

Expand Down Expand Up @@ -784,3 +785,196 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]:
offset=item.get('offset', None),
)
return item


class Audio(_Collection):
"""Prepare a list of all audio items, filtered by duration.
"""

OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text')

def __init__(
self,
audio_files_list: List[Dict[str, str]],
duration_list: List[float],
offset_list: List[float],
text_list: List[str],
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
max_number: Optional[int] = None,
do_sort_by_duration: bool = False,
):
"""Instantiantes an list of audio files.
Args:
audio_files_list: list of dictionaries with mapping from audio_key to audio_filepath
duration_list: list of durations of input files
offset_list: list of offsets
text_list: list of texts
min_duration: Minimum duration to keep entry with (default: None).
max_duration: Maximum duration to keep entry with (default: None).
max_number: Maximum number of samples to collect.
do_sort_by_duration: True if sort samples list by duration.
"""

output_type = self.OUTPUT_TYPE
data, total_duration = [], 0.0
num_filtered, duration_filtered = 0, 0.0

for audio_files, duration, offset, text in zip(audio_files_list, duration_list, offset_list, text_list):
# Duration filters
if min_duration is not None and duration < min_duration:
duration_filtered += duration
num_filtered += 1
continue

if max_duration is not None and duration > max_duration:
duration_filtered += duration
num_filtered += 1
continue

total_duration += duration
data.append(output_type(audio_files, duration, offset, text))

# Max number of entities filter
if len(data) == max_number:
break

if do_sort_by_duration:
data.sort(key=lambda entity: entity.duration)

logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600)
logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600)

super().__init__(data)


class AudioCollection(Audio):
"""List of audio files from a manifest file.
"""

def __init__(
self, manifest_files: Union[str, List[str]], audio_to_manifest_key: Dict[str, str], *args, **kwargs,
):
"""Instantiates a list of audio files loaded from a manifest file.
Args:
manifest_files: path to a single manifest file or a list of paths
audio_to_manifest_key: dictionary mapping audio signals to keys of the manifest
"""
# Keys from manifest which contain audio
self.audio_to_manifest_key = audio_to_manifest_key

# Initialize data
audio_files_list, duration_list, offset_list, text_list = [], [], [], []

# Parse manifest files
for item in manifest.item_iter(manifest_files, parse_func=self.__parse_item):
audio_files_list.append(item['audio_files'])
duration_list.append(item['duration'])
offset_list.append(item['offset'])
text_list.append(item['text'])

super().__init__(audio_files_list, duration_list, offset_list, text_list, *args, **kwargs)

def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]:
"""Parse a single line from a manifest file.
Args:
line: a string representing a line from a manifest file in JSON format
manifest_file: path to the manifest file. Used to resolve relative paths.
Returns:
Dictionary with audio_files, duration, and offset.
"""
# Utility function
def get_full_path(audio_file: str, manifest_file: str) -> str:
""" # TODO move to some utility module, since this is
relatively general general
Get full path to audio_file.
If path in `audio_file` is not pointing to a valid file and it
is a relative path, we assume that the path is relative to
manifest_dir.
"""
audio_file = Path(audio_file)
manifest_dir = Path(manifest_file).parent

if (len(str(audio_file)) < 255) and not audio_file.is_file() and not audio_file.is_absolute():
# assume the path in manifest is relative to manifest_dir
audio_file_path = manifest_dir / audio_file
if audio_file_path.is_file():
audio_file = str(audio_file_path.absolute())
else:
audio_file = os.path.expanduser(audio_file)
else:
audio_file = os.path.expanduser(audio_file)
return audio_file

# Local utility function
def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]):
"""Get item[key] if key is string, or a list
of strings by combining item[key[0]], item[key[1]], etc.
"""
# Prepare audio file(s)
if manifest_key is None:
# Support for inference, when a target key is None
audio_file = None
elif isinstance(manifest_key, str):
# Load files from a single manifest key
audio_file = item[manifest_key]
elif isinstance(manifest_key, Iterable):
# Load files from multiple manifest keys
audio_file = []
for key in manifest_key:
item_key = item[key]
if isinstance(item_key, str):
audio_file.append(item_key)
elif isinstance(item_key, list):
audio_file += item_key
else:
raise ValueError(f'Unexpected type {type(item_key)} of item for key {key}: {item_key}')
else:
raise ValueError(f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}')

return audio_file

# Convert JSON line to a dictionary
item = json.loads(line)

# Handle all audio files
audio_files = {}
for audio_key, manifest_key in self.audio_to_manifest_key.items():

audio_file = get_audio_file(item, manifest_key)

# Get full path to audio file(s)
if isinstance(audio_file, str):
# This dictionary entry points to a single file
audio_files[audio_key] = get_full_path(audio_file, manifest_file)
elif isinstance(audio_file, Iterable):
# This dictionary entry points to multiple files
# Get the files and keep the list structure for this key
audio_files[audio_key] = [get_full_path(f, manifest_file) for f in audio_file]
elif audio_key == 'target' and audio_file is None:
# For inference, we don't need the target
audio_files[audio_key] = None
else:
raise ValueError(f'Unexpected type {type(audio_file)} of audio_file: {audio_file}')
item['audio_files'] = audio_files

# Handle duration
if 'duration' not in item:
raise ValueError(f'Duration not available in line: {line}. Manifest file: {manifest_file}')

# Handle offset
if 'offset' not in item:
item['offset'] = 0.0

# Handle text
if 'text' not in item:
item['text'] = None

return dict(
audio_files=item['audio_files'], duration=item['duration'], offset=item['offset'], text=item['text']
)
Loading

0 comments on commit 7ea9b7b

Please sign in to comment.