Skip to content

Commit

Permalink
Add Frame-VAD model and datasets (#6441)
Browse files Browse the repository at this point in the history
* add model, dataset, necessary utils and tests

Signed-off-by: stevehuang52 <[email protected]>

* fix tarred data

Signed-off-by: stevehuang52 <[email protected]>

* fix typo

Signed-off-by: stevehuang52 <[email protected]>

* update docstring

Signed-off-by: stevehuang52 <[email protected]>

* update doc

Signed-off-by: stevehuang52 <[email protected]>

* update doc

Signed-off-by: stevehuang52 <[email protected]>

* update pretrained model info

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: stevehuang52 <[email protected]>
  • Loading branch information
stevehuang52 committed May 2, 2023
1 parent 4942dcf commit 1ecb8b6
Show file tree
Hide file tree
Showing 9 changed files with 1,100 additions and 30 deletions.
452 changes: 446 additions & 6 deletions nemo/collections/asr/data/audio_to_label.py

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions nemo/collections/asr/data/audio_to_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
# limitations under the License.
import copy

from omegaconf import DictConfig

from nemo.collections.asr.data import audio_to_label
from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list, get_chain_dataset
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
from nemo.collections.common.data.dataset import ConcatDataset


Expand Down Expand Up @@ -217,3 +220,87 @@ def get_tarred_speech_label_dataset(
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)


def get_audio_multi_label_dataset(cfg: DictConfig) -> audio_to_label.AudioToMultiLabelDataset:
if "augmentor" in cfg:
augmentor = process_augmentations(cfg.augmentor)
else:
augmentor = None

dataset = audio_to_label.AudioToMultiLabelDataset(
manifest_filepath=cfg.get("manifest_filepath"),
sample_rate=cfg.get("sample_rate"),
labels=cfg.get("labels", None),
int_values=cfg.get("int_values", False),
augmentor=augmentor,
min_duration=cfg.get("min_duration", None),
max_duration=cfg.get("max_duration", None),
trim_silence=cfg.get("trim_silence", False),
is_regression_task=cfg.get("is_regression_task", False),
cal_labels_occurrence=cfg.get("cal_labels_occurrence", False),
delimiter=cfg.get("delimiter", None),
normalize_audio_db=cfg.get("normalize_audio_db", False),
normalize_audio_db_target=cfg.get("normalize_audio_db_target", -20),
)
return dataset


def get_tarred_audio_multi_label_dataset(
cfg: DictConfig, shuffle_n: int, global_rank: int, world_size: int
) -> audio_to_label.TarredAudioToMultiLabelDataset:

if "augmentor" in cfg:
augmentor = process_augmentations(cfg.augmentor)
else:
augmentor = None

tarred_audio_filepaths = cfg['tarred_audio_filepaths']
manifest_filepaths = cfg['manifest_filepath']
datasets = []
tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
manifest_filepaths = convert_to_config_list(manifest_filepaths)

bucketing_weights = cfg.get('bucketing_weights', None) # For upsampling buckets
if bucketing_weights:
for idx, weight in enumerate(bucketing_weights):
if not isinstance(weight, int) or weight <= 0:
raise ValueError(f"bucket weights must be positive integers")

if len(manifest_filepaths) != len(tarred_audio_filepaths):
raise ValueError(
f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
)

for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
zip(tarred_audio_filepaths, manifest_filepaths)
):
if len(tarred_audio_filepath) == 1:
tarred_audio_filepath = tarred_audio_filepath[0]

dataset = audio_to_label.TarredAudioToMultiLabelDataset(
audio_tar_filepaths=tarred_audio_filepath,
manifest_filepath=manifest_filepath,
sample_rate=cfg["sample_rate"],
labels=cfg['labels'],
shuffle_n=shuffle_n,
int_values=cfg.get("int_values", False),
augmentor=augmentor,
min_duration=cfg.get('min_duration', None),
max_duration=cfg.get('max_duration', None),
trim_silence=cfg.get('trim_silence', False),
is_regression_task=cfg.get('is_regression_task', False),
delimiter=cfg.get("delimiter", None),
shard_strategy=cfg.get('tarred_shard_strategy', 'scatter'),
global_rank=global_rank,
world_size=world_size,
normalize_audio_db=cfg.get("normalize_audio_db", False),
normalize_audio_db_target=cfg.get("normalize_audio_db_target", -20),
)

if bucketing_weights:
[datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])]
else:
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=cfg, rank=global_rank)
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.audio_to_audio_model import AudioToAudioModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel
from nemo.collections.asr.models.classification_models import EncDecClassificationModel, EncDecFrameClassificationModel
from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
Expand Down
Loading

0 comments on commit 1ecb8b6

Please sign in to comment.