Skip to content

Commit

Permalink
fixing the ability to use temp sampling with concat datasets (NVIDIA#…
Browse files Browse the repository at this point in the history
…6423)

* fixing the ability to use temp sampling with concat datasets

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: hsiehjackson <[email protected]>
  • Loading branch information
2 people authored and hsiehjackson committed Jun 2, 2023
1 parent 9455c31 commit a3316d3
Showing 1 changed file with 14 additions and 17 deletions.
31 changes: 14 additions & 17 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,16 +518,14 @@ def get_audio_to_text_char_dataset_from_config(
f"Concat dataset requires `concat_sampling_technique` but it was not provided. Config: {config}"
)
return None

if not 'concat_sampling_probabilities' in config:
logging.warning(
f"Concat dataset requires `concat_sampling_probabilities` list but it was not provided. Config: {config}"
)
return None
else:
if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
if config['concat_sampling_technique'] == 'random':
if not 'concat_sampling_probabilities' in config:
logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
return None
else:
if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
return None

shuffle = config['shuffle']
device = 'gpu' if torch.cuda.is_available() else 'cpu'
Expand Down Expand Up @@ -618,15 +616,14 @@ def get_audio_to_text_bpe_dataset_from_config(
)
return None

if not 'concat_sampling_probabilities' in config:
logging.warning(
f"Concat dataset requires `concat_sampling_probabilities` list but it was not provided. Config: {config}"
)
return None
else:
if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
if config['concat_sampling_technique'] == 'random':
if not 'concat_sampling_probabilities' in config:
logging.warning(f"Concat dataset requires `concat_sampling_probabilities` list. Config: {config}")
return None
else:
if not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6):
logging.warning(f"`concat_sampling_probabilities` need to sum to 1. Config: {config}")
return None

shuffle = config['shuffle']
device = 'gpu' if torch.cuda.is_available() else 'cpu'
Expand Down

0 comments on commit a3316d3

Please sign in to comment.