-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Lhotse integration squashed PR Signed-off-by: Piotr Żelasko <[email protected]> * Code review - Som Signed-off-by: Piotr Żelasko <[email protected]> * Update copyright headers to 2024 Signed-off-by: Piotr Żelasko <[email protected]> * Fix NLP imports Signed-off-by: Piotr Żelasko <[email protected]> * Code review - Vahid Signed-off-by: Piotr Żelasko <[email protected]> --------- Signed-off-by: Piotr Żelasko <[email protected]> Signed-off-by: Piotr Żelasko <[email protected]> Signed-off-by: Pablo Garay <[email protected]>
- Loading branch information
1 parent
f975b9c
commit 8c7bed2
Showing
16 changed files
with
1,696 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Dict, Optional, Tuple | ||
|
||
import torch.utils.data | ||
from lhotse.dataset import AudioSamples | ||
from lhotse.dataset.collation import collate_vectors | ||
|
||
from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer | ||
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec | ||
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType | ||
|
||
|
||
class LhotseSpeechToTextBpeDataset(torch.utils.data.Dataset): | ||
""" | ||
This dataset is based on BPE datasets from audio_to_text.py. | ||
Unlike native NeMo datasets, Lhotse dataset defines only the mapping from | ||
a CutSet (meta-data) to a mini-batch with PyTorch tensors. | ||
Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). | ||
Managing data, sampling, de-duplication across workers/nodes etc. is all handled | ||
by Lhotse samplers instead. | ||
""" | ||
|
||
@property | ||
def output_types(self) -> Optional[Dict[str, NeuralType]]: | ||
return { | ||
'audio_signal': NeuralType(('B', 'T'), AudioSignal()), | ||
'a_sig_length': NeuralType(tuple('B'), LengthsType()), | ||
'transcripts': NeuralType(('B', 'T'), LabelsType()), | ||
'transcript_length': NeuralType(tuple('B'), LengthsType()), | ||
'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), | ||
} | ||
|
||
def __init__(self, tokenizer): | ||
super().__init__() | ||
self.tokenizer = TokenizerWrapper(tokenizer) | ||
self.load_audio = AudioSamples(fault_tolerant=True) | ||
|
||
def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: | ||
audio, audio_lens, cuts = self.load_audio(cuts) | ||
tokens = [torch.as_tensor(self.tokenizer(c.supervisions[0].text, c.supervisions[0].language)) for c in cuts] | ||
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) | ||
tokens = collate_vectors(tokens, padding_value=0) | ||
return audio, audio_lens, tokens, token_lens | ||
|
||
|
||
class TokenizerWrapper: | ||
""" | ||
Provide a unified interface for NeMo Tokenizer, AggregateTokenizer, and (char) Parser. | ||
""" | ||
|
||
def __init__(self, tokenizer): | ||
self._tokenizer = tokenizer | ||
if isinstance(tokenizer, AggregateTokenizer): | ||
self._impl = self._call_agg_tokenizer | ||
elif isinstance(tokenizer, TokenizerSpec): | ||
self._impl = self._call_tokenizer | ||
else: | ||
self._impl = self._call_parser | ||
|
||
def __call__(self, text: str, lang: str | None = None): | ||
return self._impl(text, lang) | ||
|
||
def _call_agg_tokenizer(self, text: str, lang: str | None = None): | ||
assert lang is not None, "Expected 'lang' to be set for AggregateTokenizer." | ||
return self._tokenizer.text_to_ids(text, lang) | ||
|
||
def _call_tokenizer(self, text: str, lang: str | None = None): | ||
return self._tokenizer.text_to_ids(text) | ||
|
||
def _call_parser(self, text: str, lang: str | None = None): | ||
return self._tokenizer(text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.