-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support dataloader as input to audio
for transcription
#9201
Merged
+139
−44
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
e557318
Support dataloader as input to `audio` for transcription
titu1994 0201ec2
Apply isort and black reformatting
titu1994 3be450f
Support dataloader as input to `audio` for transcription
titu1994 403c0ae
Update transcribe signatures
titu1994 9a2dae5
Apply isort and black reformatting
titu1994 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ | |
import copy | ||
import json | ||
import os | ||
import tempfile | ||
from abc import abstractmethod | ||
from dataclasses import dataclass, field | ||
from math import ceil, floor | ||
|
@@ -24,6 +23,7 @@ | |
import torch | ||
from omegaconf import DictConfig, ListConfig, OmegaConf | ||
from pytorch_lightning import Trainer | ||
from torch.utils.data import DataLoader | ||
from torchmetrics import Accuracy | ||
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError | ||
|
||
|
@@ -169,7 +169,8 @@ def forward( | |
|
||
if not has_processed_signal: | ||
processed_signal, processed_signal_length = self.preprocessor( | ||
input_signal=input_signal, length=input_signal_length, | ||
input_signal=input_signal, | ||
length=input_signal_length, | ||
) | ||
# Crop or pad is always applied | ||
if self.crop_or_pad is not None: | ||
|
@@ -355,7 +356,7 @@ def _setup_feature_label_dataloader(self, config: DictConfig) -> torch.utils.dat | |
@torch.no_grad() | ||
def transcribe( | ||
self, | ||
audio: List[str], | ||
audio: Union[List[str], DataLoader], | ||
batch_size: int = 4, | ||
logprobs=None, | ||
override_config: Optional[ClassificationInferConfig] | Optional[RegressionInferConfig] = None, | ||
|
@@ -364,7 +365,8 @@ def transcribe( | |
Generate class labels for provided audio files. Use this method for debugging and prototyping. | ||
|
||
Args: | ||
audio: (a single or list) of paths to audio files or a np.ndarray audio sample. \ | ||
audio: (a single or list) of paths to audio files or a np.ndarray audio array. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
Can also be a dataloader object that provides values that can be consumed by the model. | ||
Recommended length per file is approximately 1 second. | ||
batch_size: (int) batch size to use during inference. \ | ||
Bigger will result in better throughput performance but would use more memory. | ||
|
@@ -952,7 +954,10 @@ def _setup_dataloader_from_config(self, config: DictConfig): | |
|
||
shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 | ||
dataset = audio_to_label_dataset.get_tarred_audio_multi_label_dataset( | ||
cfg=config, shuffle_n=shuffle_n, global_rank=self.global_rank, world_size=self.world_size, | ||
cfg=config, | ||
shuffle_n=shuffle_n, | ||
global_rank=self.global_rank, | ||
world_size=self.world_size, | ||
) | ||
shuffle = False | ||
if hasattr(dataset, 'collate_fn'): | ||
|
@@ -1022,7 +1027,8 @@ def forward( | |
|
||
if not has_processed_signal: | ||
processed_signal, processed_signal_length = self.preprocessor( | ||
input_signal=input_signal, length=input_signal_length, | ||
input_signal=input_signal, | ||
length=input_signal_length, | ||
) | ||
|
||
# Crop or pad is always applied | ||
|
@@ -1124,7 +1130,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): | |
def reshape_labels(self, logits, labels, logits_len, labels_len): | ||
""" | ||
Reshape labels to match logits shape. For example, each label is expected to cover a 40ms frame, while each frme prediction from the | ||
model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain | ||
model covers 20ms. If labels are shorter than logits, labels are repeated, otherwise labels are folded and argmax is applied to obtain | ||
the label of each frame. When lengths of labels and logits are not factors of each other, labels are truncated or padded with zeros. | ||
The ratio_threshold=0.2 is used to determine whether to pad or truncate labels, where the value 0.2 is not important as in real cases the ratio | ||
is very close to either ceil(ratio) or floor(ratio). We use 0.2 here for easier unit-testing. This implementation does not allow frame length | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add torch.dataloader to
audio
types in the func signature.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done