Skip to content

Commit

Permalink
Fix buckeing seeding (#6254) (#6255)
Browse files Browse the repository at this point in the history
* fixed the seeding bug of bucketing.



* fixed the seeding bug of bucketing.



---------

Signed-off-by: Vahid <[email protected]>
Co-authored-by: Vahid Noroozi <[email protected]>
  • Loading branch information
github-actions[bot] and VahidooX committed Mar 19, 2023
1 parent a0f584a commit c6f480c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/asr/data/audio_to_label_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_tarred_classification_label_dataset(
else:
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=config)
return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)


def get_concat_tarred_speech_label_dataset(
Expand Down Expand Up @@ -216,4 +216,4 @@ def get_tarred_speech_label_dataset(
else:
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=config)
return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def get_tarred_dataset(
else:
datasets.append(dataset)

return get_chain_dataset(datasets=datasets, ds_config=config)
return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)


def get_dali_char_dataset(
Expand Down Expand Up @@ -741,7 +741,7 @@ def convert_to_config_list(initial_list):
return initial_list


def get_chain_dataset(datasets, ds_config):
def get_chain_dataset(datasets, ds_config, rank=0):
if len(datasets) > 1:
if ds_config.get('bucketing_batch_size', None) is not None:
bucketing_batch_sizes = calc_bucketing_batch_sizes(ds_config, len(datasets))
Expand All @@ -765,7 +765,7 @@ def get_chain_dataset(datasets, ds_config):
elif bucketing_strategy == 'synced_randomized':
return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=0)
elif bucketing_strategy == 'fully_randomized':
return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=random.randint(0, 30000))
return audio_to_text.RandomizedChainDataset(datasets=datasets, rnd_seed=random.randint(0, 30000) + rank)
else:
raise ValueError(
f'bucketing_strategy={bucketing_strategy} is not supported! Supported strategies are [fixed_order, fully_randomized, synced_randomized].'
Expand Down

0 comments on commit c6f480c

Please sign in to comment.