Sharded manifests for tarred datasets #6395

Merged Apr 18, 2023 · 30 commits (changes shown from 29 commits)

Commits:
37b43f5
testing sharded manifests
bmwshop Apr 6, 2023
955bb0d
compatibility
bmwshop Apr 6, 2023
7be7860
proper fixes
bmwshop Apr 6, 2023
ce027f5
adding flag to convert_to_tarred_audio_dataset
bmwshop Apr 6, 2023
a4fd990
shard_manifests conf param
bmwshop Apr 7, 2023
eac0324
propagating the shard_manifests param
bmwshop Apr 7, 2023
fc4ccba
propagating the shard_manifests param
bmwshop Apr 7, 2023
d5f4898
distributed checks
bmwshop Apr 7, 2023
cc762e7
typo
bmwshop Apr 7, 2023
1f78a49
typo
bmwshop Apr 7, 2023
483901d
fixes
bmwshop Apr 7, 2023
ac3f5ad
fixes
bmwshop Apr 7, 2023
9dacbdd
fixes
bmwshop Apr 7, 2023
dc81d26
fixes
bmwshop Apr 7, 2023
18e8b99
fixes
bmwshop Apr 7, 2023
bd5cc3b
fixes
bmwshop Apr 7, 2023
0f572b1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 7, 2023
f3cd8ff
fixes based on PR comments and tests
bmwshop Apr 8, 2023
b1aac87
fixes based on PR comments and tests
bmwshop Apr 8, 2023
ded5462
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2023
a37f794
fixes to convert_to_tarred_audio_dataset.py
bmwshop Apr 13, 2023
f788c30
reversing manifest shards flag
bmwshop Apr 13, 2023
e2ac42a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2023
4fba9fc
tests
bmwshop Apr 13, 2023
c27c57a
Merge branch 'shrdmnf' of https://github.com/nvidia/nemo into shrdmnf
bmwshop Apr 13, 2023
15a5d5a
excluding manifests from webdataset url expansion
bmwshop Apr 14, 2023
559581c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2023
b712738
expand manifest paths before attempting to cache from datastore
bmwshop Apr 14, 2023
55a3ace
Merge branch 'shrdmnf' of https://github.com/nvidia/nemo into shrdmnf
bmwshop Apr 14, 2023
e66b216
explicit use of UTF-8 for manifest i/o
bmwshop Apr 18, 2023
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_label.py
@@ -19,7 +19,7 @@
import torch
import webdataset as wd

- from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_audio_filepaths
+ from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths
from nemo.collections.asr.parts.preprocessing.segment import available_formats as valid_sf_formats
from nemo.collections.common.parts.preprocessing import collections
from nemo.core.classes import Dataset, IterableDataset
@@ -560,8 +560,8 @@ def __init__(
for idx in range(len(self.labels[:5])):
logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

- audio_tar_filepaths = expand_audio_filepaths(
- audio_tar_filepaths=audio_tar_filepaths,
+ audio_tar_filepaths = expand_sharded_filepaths(
+ sharded_filepaths=audio_tar_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
global_rank=global_rank,
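For reference — the call above only changes with the helper's rename. A minimal sketch of the new signature in use (not part of the diff; the shard count and rank values are illustrative):

from nemo.collections.asr.data.audio_to_text import expand_sharded_filepaths

# With 16 shards scattered across 4 workers, rank 0 keeps shards [0, 4).
tar_paths = expand_sharded_filepaths(
    sharded_filepaths="audio_{0..15}.tar",
    shard_strategy="scatter",
    world_size=4,
    global_rank=0,
)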
110 changes: 91 additions & 19 deletions nemo/collections/asr/data/audio_to_text.py
@@ -171,47 +171,48 @@ def process_text_by_sample(self, sample: collections.ASRAudioText.OUTPUT_TYPE) -
return t, tl


- def expand_audio_filepaths(audio_tar_filepaths, shard_strategy: str, world_size: int, global_rank: int):
+ def expand_sharded_filepaths(sharded_filepaths, shard_strategy: str, world_size: int, global_rank: int):
valid_shard_strategies = ['scatter', 'replicate']
if shard_strategy not in valid_shard_strategies:
raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

- if isinstance(audio_tar_filepaths, str):
+ if isinstance(sharded_filepaths, str):
# Replace '(' and '[' with '{'
brace_keys_open = ['(', '[', '<', '_OP_']
for bkey in brace_keys_open:
- if bkey in audio_tar_filepaths:
- audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "{")
+ if bkey in sharded_filepaths:
+ sharded_filepaths = sharded_filepaths.replace(bkey, "{")

# Replace ')' and ']' with '}'
brace_keys_close = [')', ']', '>', '_CL_']
for bkey in brace_keys_close:
- if bkey in audio_tar_filepaths:
- audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "}")
+ if bkey in sharded_filepaths:
+ sharded_filepaths = sharded_filepaths.replace(bkey, "}")

- if isinstance(audio_tar_filepaths, str):
+ if isinstance(sharded_filepaths, str):
# Brace expand
- audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))
+ sharded_filepaths = list(braceexpand.braceexpand(sharded_filepaths))

# Expand store paths into WebDataset URLs
- audio_tar_filepaths = [
- datastore_path_to_webdataset_url(p) if is_datastore_path(p) else p for p in audio_tar_filepaths
+ sharded_filepaths = [
+ datastore_path_to_webdataset_url(p) if is_datastore_path(p) and is_tarred_path(p) else p
+ for p in sharded_filepaths
]

# Check for distributed and partition shards accordingly
if world_size > 1:
if shard_strategy == 'scatter':
logging.info("All tarred dataset shards will be scattered evenly across all nodes.")

- if len(audio_tar_filepaths) % world_size != 0:
+ if len(sharded_filepaths) % world_size != 0:
logging.warning(
f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
f"Number of shards in tarred dataset ({len(sharded_filepaths)}) is not divisible "
f"by number of distributed workers ({world_size})."
)

- begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
- end_idx = begin_idx + len(audio_tar_filepaths) // world_size
- audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
+ begin_idx = (len(sharded_filepaths) // world_size) * global_rank
+ end_idx = begin_idx + len(sharded_filepaths) // world_size
+ sharded_filepaths = sharded_filepaths[begin_idx:end_idx]
logging.info(
"Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
)
@@ -221,7 +222,7 @@ def expand_audio_filepaths(audio_tar_filepaths, shard_strategy: str, world_size:
else:
raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

- return audio_tar_filepaths
+ return sharded_filepaths
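To make the path normalization above concrete — a minimal sketch of a call and its result (not part of the diff; the expansion assumes standard braceexpand semantics):

# _OP_/_CL_ (and (, [, <) are rewritten to { and } before brace expansion.
paths = expand_sharded_filepaths(
    sharded_filepaths="audio__OP_0..3_CL_.tar",
    shard_strategy="replicate",
    world_size=2,
    global_rank=1,
)
# Every rank gets the full list: ['audio_0.tar', ..., 'audio_3.tar'].
# An 'ais://' entry becomes a WebDataset URL only if it ends in '.tar'
# (the new is_tarred_path guard), so manifest paths pass through untouched.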


def cache_datastore_manifests(
@@ -345,6 +346,47 @@ def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
)


"""Optionally expand / shard the list of manifests
This uses the same notation as the sharded audio files

Args:
manifest_filepaths: list of manifest files (the sharded notation)
shard_strategy: scatter or replicate (scatter by default)
shard_manifests: bool, if False, no sharding / manifest filepath expansion will be attempted
global_rank: int, the rank of this worker
world_size: int, total number of workers
"""


def shard_manifests_if_needed(
manifest_filepaths: Union[str, List[str]],
shard_strategy: str,
shard_manifests: bool,
global_rank: int,
world_size: int,
):
if shard_manifests:
if not torch.distributed.is_available():
logging.warning("Not running in torch.distributed mode. Manifest sharding not available")
return manifest_filepaths

if not torch.distributed.is_initialized():
logging.warning(
'Manifest sharding was requested but torch.distributed is not initialized. '
'Did you intend to set the defer_setup flag?'
)
return manifest_filepaths

manifest_filepaths = expand_sharded_filepaths(
sharded_filepaths=manifest_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
global_rank=global_rank,
)

return manifest_filepaths
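A hedged usage sketch of the helper above, with per-shard manifests named the way the conversion script in this PR writes them; under the scatter strategy each rank keeps only its slice:

# Assumes an initialized torch.distributed job with 4 ranks; otherwise
# the function warns and returns the paths unsharded.
manifests = shard_manifests_if_needed(
    manifest_filepaths="sharded_manifests/manifest_{0..15}.json",
    shard_strategy="scatter",
    shard_manifests=True,
    global_rank=2,
    world_size=4,
)
# 16 manifests over 4 ranks -> rank 2 gets manifest_8.json .. manifest_11.json.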


class _AudioTextDataset(Dataset):
"""
Dataset that loads tensors via a json file containing paths to audio files, transcripts, and durations (in seconds).
@@ -748,6 +790,7 @@ class _TarredAudioToTextDataset(IterableDataset):
occasions (when the number of shards is not divisible with ``world_size``), will not sample
the entire dataset. For these reasons it is not advisable to use tarred datasets as validation
or test datasets.
shard_manifests (bool): Whether or not to shard manifests. Defaults to False.
global_rank (int): Worker rank, used for partitioning shards. Defaults to 0.
world_size (int): Total number of processes, used for partitioning shards. Defaults to 0.
return_sample_id (bool): whether to return the sample_id as a part of each sample
@@ -769,10 +812,22 @@ def __init__(
eos_id: Optional[int] = None,
pad_id: int = 0,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
):
self.shard_manifests = shard_manifests

# Shard the manifests if requested and possible, then expand the paths
manifest_filepath = shard_manifests_if_needed(
shard_manifests=shard_manifests,
shard_strategy=shard_strategy,
manifest_filepaths=manifest_filepath,
world_size=world_size,
global_rank=global_rank,
)

# If necessary, cache manifests from object store
cache_datastore_manifests(manifest_filepaths=manifest_filepath)

@@ -788,15 +843,17 @@ def __init__(
index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID
)

self.len = self._compute_len()

self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor)
self.trim = trim
self.eos_id = eos_id
self.bos_id = bos_id
self.pad_id = pad_id
self.return_sample_id = return_sample_id

- audio_tar_filepaths = expand_audio_filepaths(
- audio_tar_filepaths=audio_tar_filepaths,
+ audio_tar_filepaths = expand_sharded_filepaths(
+ sharded_filepaths=audio_tar_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
global_rank=global_rank,
@@ -928,8 +985,19 @@ def get_manifest_sample(self, sample_id):
def __iter__(self):
return self._dataset.__iter__()

def _compute_len(self):
if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized():
my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda()
torch.distributed.all_reduce(my_len)
my_len = my_len.int()
logging.info(f'Sharded manifests: Total length: {my_len}')
else:
my_len = len(self.manifest_processor.collection)

return my_len
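Each rank now loads only its manifest shards, so the local collection undercounts the dataset; the all_reduce above (default op: SUM) restores the global count. A sketch of the semantics (not part of the diff; assumes an initialized process group and a CUDA device, since the tensor is moved with .cuda()):

local_len = torch.tensor(2500, dtype=torch.int32).cuda()
torch.distributed.all_reduce(local_len)  # default ReduceOp.SUM
# With world_size == 4 and 2,500 entries per rank, every rank now
# holds 10,000 -- the true global length returned by __len__.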

def __len__(self):
- return len(self.manifest_processor.collection)
+ return self.len


class TarredAudioToCharDataset(_TarredAudioToTextDataset):
@@ -1042,6 +1110,7 @@ def __init__(
parser: Optional[str] = 'en',
pad_id: int = 0,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
@@ -1067,6 +1136,7 @@ def __init__(
eos_id=eos_id,
pad_id=pad_id,
shard_strategy=shard_strategy,
shard_manifests=shard_manifests,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
@@ -1167,6 +1237,7 @@ def __init__(
trim: bool = False,
use_start_end_token: bool = True,
shard_strategy: str = "scatter",
shard_manifests: bool = False,
global_rank: int = 0,
world_size: int = 0,
return_sample_id: bool = False,
@@ -1219,6 +1290,7 @@ def __call__(self, *args):
eos_id=eos_id,
pad_id=pad_id,
shard_strategy=shard_strategy,
shard_manifests=shard_manifests,
global_rank=global_rank,
world_size=world_size,
return_sample_id=return_sample_id,
6 changes: 3 additions & 3 deletions nemo/collections/asr/data/audio_to_text_dali.py
@@ -22,7 +22,7 @@
import torch
from omegaconf import DictConfig

- from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_audio_filepaths
+ from nemo.collections.asr.data.audio_to_text import ASRManifestProcessor, expand_sharded_filepaths
from nemo.collections.common.parts.preprocessing import parsers
from nemo.utils import logging, model_utils

@@ -345,10 +345,10 @@ def __init__(
self.is_tarred_dataset = False

elif audio_tar_filepaths is not None and audio_tar_index_filepaths is not None:
- audio_tar_filepaths = expand_audio_filepaths(
+ audio_tar_filepaths = expand_sharded_filepaths(
audio_tar_filepaths, shard_strategy=shard_strategy, world_size=world_size, global_rank=global_rank
)
- audio_tar_index_filepaths = expand_audio_filepaths(
+ audio_tar_index_filepaths = expand_sharded_filepaths(
audio_tar_index_filepaths,
shard_strategy=shard_strategy,
world_size=world_size,
5 changes: 5 additions & 0 deletions nemo/collections/asr/data/audio_to_text_dataset.py
@@ -346,6 +346,9 @@ def get_tarred_dataset(
):
if len(tarred_audio_filepath) == 1:
tarred_audio_filepath = tarred_audio_filepath[0]
if len(manifest_filepath) == 1:
manifest_filepath = manifest_filepath[0]

if tokenizer is None:
dataset = audio_to_text.TarredAudioToCharDataset(
audio_tar_filepaths=tarred_audio_filepath,
@@ -363,6 +366,7 @@
trim=config.get('trim_silence', False),
parser=config.get('parser', 'en'),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
shard_manifests=config.get('shard_manifests', False),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
@@ -381,6 +385,7 @@
trim=config.get('trim_silence', False),
use_start_end_token=config.get('use_start_end_token', True),
shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
shard_manifests=config.get('shard_manifests', False),
global_rank=global_rank,
world_size=world_size,
return_sample_id=config.get('return_sample_id', False),
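Downstream, the new flag arrives through the dataset config. A hedged sketch of a training-dataset config that enables it (shard_manifests and tarred_shard_strategy are read exactly as in this diff; the remaining fields and values are illustrative):

from omegaconf import OmegaConf

train_ds = OmegaConf.create(
    {
        "manifest_filepath": "tarred/sharded_manifests/manifest__OP_0..511_CL_.json",
        "tarred_audio_filepaths": "tarred/audio__OP_0..511_CL_.tar",
        "is_tarred": True,
        "shard_manifests": True,  # picked up via config.get('shard_manifests', False)
        "tarred_shard_strategy": "scatter",
        "batch_size": 32,
    }
)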
1 change: 1 addition & 0 deletions nemo/collections/asr/models/configs/asr_models_config.py
@@ -38,6 +38,7 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
is_tarred: bool = False
tarred_audio_filepaths: Optional[Any] = None
tarred_shard_strategy: str = "scatter"
shard_manifests: bool = False
shuffle_n: int = 0

# Optional
6 changes: 6 additions & 0 deletions nemo/utils/data_utils.py
@@ -49,6 +49,12 @@ def is_datastore_path(path) -> bool:
return path.startswith('ais://')


def is_tarred_path(path) -> bool:
"""Check if a path is for a tarred file.
"""
return path.endswith('.tar')


def is_datastore_cache_shared() -> bool:
"""Check if store cache is shared.
"""
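The new predicate is deliberately narrow, so datastore manifests are cached locally rather than treated as WebDataset shards. A small sketch of its effect on the expansion step:

# Only '.tar' entries become WebDataset URLs during expansion:
assert is_tarred_path("ais://bucket/audio_0.tar")
assert not is_tarred_path("ais://bucket/manifest_0.json")  # cached from the datastore instead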
24 changes: 24 additions & 0 deletions scripts/speech_recognition/convert_to_tarred_audio_dataset.py
@@ -174,6 +174,11 @@
"and it must be filled out by the user."
),
)
parser.add_argument(
"--no_shard_manifests",
action='store_true',
help="Do not write sharded manifests along with the aggregated manifest.",
)
parser.add_argument('--workers', type=int, default=1, help='Number of worker processes')
args = parser.parse_args()

Expand All @@ -186,6 +191,7 @@ class ASRTarredDatasetConfig:
min_duration: Optional[float] = None
shuffle_seed: Optional[int] = None
sort_in_shards: bool = True
shard_manifests: bool = True
keep_files_together: bool = False


@@ -322,6 +328,19 @@ def create_new_dataset(self, manifest_path: str, target_dir: str = "./tarred/",
for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices))
)

if config.shard_manifests:
sharded_manifests_dir = target_dir + '/sharded_manifests'
if not os.path.exists(sharded_manifests_dir):
os.makedirs(sharded_manifests_dir)

for manifest in new_entries_list:
shard_id = manifest[0]['shard_id']
new_manifest_shard_path = os.path.join(sharded_manifests_dir, f'manifest_{shard_id}.json')
with open(new_manifest_shard_path, 'w') as m2:
Collaborator review comment: "Add encoding of utf-8 for any file open (read or write)" — addressed by the final commit e66b216, which is not included in this 29-commit diff.
for entry in manifest:
json.dump(entry, m2)
m2.write('\n')

# Flatten the list of list of entries to a list of entries
new_entries = [sample for manifest in new_entries_list for sample in manifest]
del new_entries_list
@@ -626,6 +645,8 @@ def main():
def create_tar_datasets(min_duration: float, max_duration: float, target_dir: str):
builder = ASRTarredDatasetBuilder()

shard_manifests = False if args.no_shard_manifests else True

if args.write_metadata:
metadata = ASRTarredDatasetMetadata()
dataset_cfg = ASRTarredDatasetConfig(
@@ -635,6 +656,7 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
min_duration=min_duration,
shuffle_seed=args.shuffle_seed,
sort_in_shards=args.sort_in_shards,
shard_manifests=shard_manifests,
keep_files_together=args.keep_files_together,
)
metadata.dataset_config = dataset_cfg
@@ -655,6 +677,7 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
min_duration=min_duration,
shuffle_seed=args.shuffle_seed,
sort_in_shards=args.sort_in_shards,
shard_manifests=shard_manifests,
keep_files_together=args.keep_files_together,
)
builder.configure(config)
@@ -682,6 +705,7 @@ def create_tar_datasets(min_duration: float, max_duration: float, target_dir: st
metadata.dataset_config.shuffle = args.shuffle
metadata.dataset_config.shuffle_seed = args.shuffle_seed
metadata.dataset_config.sort_in_shards = args.sort_in_shards
metadata.dataset_config.shard_manifests = shard_manifests

builder.configure(metadata.dataset_config)

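With the flag inverted by commit f788c30, sharded manifests are now written by default; a hedged invocation sketch (every flag except --no_shard_manifests predates this PR, and the paths and shard count are illustrative):

# Default run: also writes tarred/sharded_manifests/manifest_<shard_id>.json
python scripts/speech_recognition/convert_to_tarred_audio_dataset.py \
    --manifest_path=train_manifest.json \
    --target_dir=./tarred \
    --num_shards=512 \
    --workers=8

# Opt out of the per-shard manifests:
python scripts/speech_recognition/convert_to_tarred_audio_dataset.py \
    --manifest_path=train_manifest.json \
    --target_dir=./tarred \
    --num_shards=512 \
    --no_shard_manifests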
1 change: 1 addition & 0 deletions tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
@@ -291,6 +291,7 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
'pin_memory',
'drop_last',
'tarred_shard_strategy',
'shard_manifests',
'shuffle_n',
'parser',
'normalize',