From ce96af552a1ab167828e7859baea1fb43b9664d7 Mon Sep 17 00:00:00 2001 From: Iztok Lebar Bajec Date: Tue, 26 Jul 2022 15:51:16 +0200 Subject: [PATCH] fix tarred dataset len when num shards is not divisible by workers (#4553) * fix tarred dataset len when num shards is not divisible by workers Signed-off-by: Iztok Lebar Bajec * update error reporting on invalid `shard_strategy` * update NLP/PC tarred dataset docstring * add `shard_strategy` to NLP/PC `@dataclass` * update NLP/PC tarred dataset docstring * add `shard_strategy` to NLP/PC docs * revert test with Dataloader returning the actual data length * make dataloader return actual num of samples, set `limit_train_batches` on `setup_*` * update `shard_strategy` docstrings Signed-off-by: Iztok Lebar Bajec * update `tarred_dataset` documentation Signed-off-by: Iztok Lebar Bajec * fix style * update documentation Signed-off-by: Iztok Lebar Bajec * updated docstrings Signed-off-by: Iztok Lebar Bajec Co-authored-by: PeganovAnton Signed-off-by: Hainan Xu --- docs/source/asr/datasets.rst | 66 ++++++++++--------- .../nlp/punctuation_and_capitalization.rst | 28 ++++++-- nemo/collections/asr/data/audio_to_label.py | 36 ++++++---- nemo/collections/asr/data/audio_to_text.py | 35 ++++++---- .../data/language_modeling/l2r_lm_dataset.py | 23 +++++-- .../language_modeling/sentence_dataset.py | 29 +++++--- .../machine_translation_dataset.py | 27 +++++--- .../text_normalization/decoder_dataset.py | 28 +++++--- .../punctuation_capitalization_dataset.py | 20 +++++- ...nctuation_capitalization_tarred_dataset.py | 65 +++++++++++++++--- .../duplex_decoder.py | 35 ++++++++++ .../language_modeling/transformer_lm_model.py | 34 ++++++++++ .../machine_translation/mt_enc_dec_model.py | 36 ++++++++++ .../punctuation_capitalization_model.py | 37 +++++++++++ 14 files changed, 395 insertions(+), 104 deletions(-) diff --git a/docs/source/asr/datasets.rst b/docs/source/asr/datasets.rst index 364c7fea1926..8878d66ae739 100644 --- 
a/docs/source/asr/datasets.rst +++ b/docs/source/asr/datasets.rst @@ -1,7 +1,7 @@ Datasets ======== -NeMo has scripts to convert several common ASR datasets into the format expected by the ``nemo_asr`` collection. You can get started +NeMo has scripts to convert several common ASR datasets into the format expected by the ``nemo_asr`` collection. You can get started with those datasets by following the instructions to run those scripts in the section appropriate to each dataset below. If the user has their own data and want to preprocess it to use with NeMo ASR models, refer to the `Preparing Custom ASR Data`_ section. @@ -13,8 +13,8 @@ If the user already has a dataset that you want to convert to a tarred format, r LibriSpeech ----------- -Run the following scripts to download the LibriSpeech data and convert it into the format expected by `nemo_asr`. At least 250GB free -space is required. +Run the following scripts to download the LibriSpeech data and convert it into the format expected by `nemo_asr`. At least 250GB free +space is required. .. code-block:: bash @@ -37,18 +37,18 @@ Fisher English Training Speech Run these scripts to convert the Fisher English Training Speech data into a format expected by the ``nemo_asr`` collection. -In brief, the following scripts convert the ``.sph`` files to ``.wav``, slices those files into smaller audio samples, matches the -smaller slices with their corresponding transcripts, and splits the resulting audio segments into train, validation, and test sets +In brief, the following scripts convert the ``.sph`` files to ``.wav``, slices those files into smaller audio samples, matches the +smaller slices with their corresponding transcripts, and splits the resulting audio segments into train, validation, and test sets (with one manifest each). .. 
note:: - 106 GB of space is required to run the ``.wav`` conversion - additional 105 GB is required for the slicing and matching - - ``sph2pipe`` is required in order to run the ``.wav`` conversion + - ``sph2pipe`` is required in order to run the ``.wav`` conversion **Instructions** -The following scripts assume that you already have the Fisher dataset from the Linguistic Data Consortium, with a directory structure +The following scripts assume that you already have the Fisher dataset from the Linguistic Data Consortium, with a directory structure that looks similar to the following: .. code-block:: bash @@ -67,7 +67,7 @@ that looks similar to the following: ├── fe_03_p2_sph3 └── ... -The transcripts that will be used are located in the ``fe_03_p<1,2>_transcripts/data/trans`` directory. The audio files (``.sph``) +The transcripts that will be used are located in the ``fe_03_p<1,2>_transcripts/data/trans`` directory. The audio files (``.sph``) are located in the remaining directories in an ``audio`` subdirectory. #. Convert the audio files from ``.sph`` to ``.wav`` by running: @@ -78,7 +78,7 @@ are located in the remaining directories in an ``audio`` subdirectory. python fisher_audio_to_wav.py \ --data_root= --dest_root= - This will place the unsliced ``.wav`` files in ``/LDC200[4,5]S13-Part[1,2]/audio-wav/``. It will take several + This will place the unsliced ``.wav`` files in ``/LDC200[4,5]S13-Part[1,2]/audio-wav/``. It will take several minutes to run. #. Process the transcripts and slice the audio data. @@ -90,7 +90,7 @@ are located in the remaining directories in an ``audio`` subdirectory. --dest_root= \ --remove_noises - This script splits the full dataset into train, validation, test sets, and places the audio slices in the corresponding folders + This script splits the full dataset into train, validation, test sets, and places the audio slices in the corresponding folders in the destination directory. 
One manifest is written out per set, which includes each slice's transcript, duration, and path. This will likely take around 20 minutes to run. Once finished, delete the 10 minute long ``.wav`` files. @@ -100,8 +100,8 @@ are located in the remaining directories in an ``audio`` subdirectory. Run the following script to convert the HUB5 data into a format expected by the ``nemo_asr`` collection. -Similarly, to the Fisher dataset processing scripts, this script converts the ``.sph`` files to ``.wav``, slices the audio files and -transcripts into utterances, and combines them into segments of some minimum length (default is 10 seconds). The resulting segments +Similarly, to the Fisher dataset processing scripts, this script converts the ``.sph`` files to ``.wav``, slices the audio files and +transcripts into utterances, and combines them into segments of some minimum length (default is 10 seconds). The resulting segments are all written out to an audio directory and the corresponding transcripts are written to a manifest JSON file. .. note:: @@ -123,7 +123,7 @@ You can optionally include ``--min_slice_duration=`` if you would l AN4 Dataset ----------- -This is a small dataset recorded and distributed by Carnegie Mellon University. It consists of recordings of people spelling out +This is a small dataset recorded and distributed by Carnegie Mellon University. It consists of recordings of people spelling out addresses, names, etc. Information about this dataset can be found on the `official CMU site `_. #. `Download and extract the dataset `_ (which is labeled "NIST's Sphere audio (.sph) format (64M)". @@ -153,14 +153,14 @@ After the script finishes, the ``data`` folder should contain a ``data_aishell`` Aishell-2 --------- -To process the AIShell-2 dataset, in the command below, set the data folder of AIShell-2 using ``--audio_folder`` and where to push -these files using ``--dest_folder``. 
In order to generate files in the supported format of ``nemo_asr``, run: +To process the AIShell-2 dataset, in the command below, set the data folder of AIShell-2 using ``--audio_folder`` and where to push +these files using ``--dest_folder``. In order to generate files in the supported format of ``nemo_asr``, run: .. code-block:: bash python process_aishell2_data.py --audio_folder= --dest_folder= -After the script finishes, the ``train.json``, ``dev.json``, ``test.json``, and ``vocab.txt`` files can be found in the ``dest_folder`` directory. +After the script finishes, the ``train.json``, ``dev.json``, ``test.json``, and ``vocab.txt`` files can be found in the ``dest_folder`` directory. Preparing Custom ASR Data ------------------------- @@ -171,7 +171,7 @@ The audio files can be of any format supported by `Pydub `` and the special + For brace expansion, there may be cases where ``{x..y}`` syntax cannot be used due to shell interference. This occurs most commonly + inside SLURM scripts. Therefore, we provide a few equivalent replacements. Supported opening braces (equivalent to ``{``) are ``(``, + ``[``, ``<`` and the special tag ``_OP_``. Supported closing braces (equivalent to ``}``) are ``)``, ``]``, ``>`` and the special tag ``_CL_``. For SLURM based tasks, we suggest the use of the special tags for ease of use. -As with non-tarred datasets, the manifest file should be passed in ``manifest_filepath``. The dataloader assumes that the length +As with non-tarred datasets, the manifest file should be passed in ``manifest_filepath``. The dataloader assumes that the length of the manifest after filtering is the correct size of the dataset for reporting training progress. -The ``tarred_shard_strategy`` field of the config file can be set if you have multiple shards and are running an experiment with +The ``tarred_shard_strategy`` field of the config file can be set if you have multiple shards and are running an experiment with multiple workers. 
It defaults to ``scatter``, which preallocates a set of shards per worker which do not change during runtime. +Note that this strategy, on specific occasions (when the number of shards is not divisible with ``world_size``), will not sample +the entire dataset. As an alternative the ``replicate`` strategy, will preallocate the entire set of shards to every worker and not +change it during runtime. The benefit of this strategy is that it allows each worker to sample data points from the entire dataset +independently of others. Note, though, that more than one worker may sample the same shard, and even sample the same data points! +As such, there is no assured guarantee that all samples in the dataset will be sampled at least once during 1 epoch. Note that +for these reasons it is not advisable to use tarred datasets as validation and test datasets. For more information about the individual tarred datasets and the parameters available, including shuffling options, see the corresponding class APIs in the `Datasets <./api.html#Datasets>`__ section. @@ -228,7 +234,7 @@ see the corresponding class APIs in the `Datasets <./api.html#Datasets>`__ secti If using multiple workers, the number of shards should be divisible by the world size to ensure an even split among workers. If it is not divisible, logging will give a warning but training will proceed, but likely hang at the last epoch. In addition, if using distributed processing, each shard must have the same number of entries after filtering is - applied such that each worker ends up with the same number of files. We currently do not check for this in any dataloader, but the user's + applied such that each worker ends up with the same number of files. We currently do not check for this in any dataloader, but the user's program may hang if the shards are uneven. 
Conversion to Tarred Datasets @@ -262,9 +268,9 @@ The files in the target directory should look similar to the following: ├── metadata.yaml └── tarred_audio_manifest.json -Note that file structures are flattened such that all audio files are at the top level in each tarball. This ensures that -filenames are unique in the tarred dataset and the filepaths do not contain "-sub" and forward slashes in each ``audio_filepath`` are -simply converted to underscores. For example, a manifest entry for ``/data/directory1/file.wav`` would be ``_data_directory1_file.wav`` +Note that file structures are flattened such that all audio files are at the top level in each tarball. This ensures that +filenames are unique in the tarred dataset and the filepaths do not contain "-sub" and forward slashes in each ``audio_filepath`` are +simply converted to underscores. For example, a manifest entry for ``/data/directory1/file.wav`` would be ``_data_directory1_file.wav`` in the tarred dataset manifest, and ``/data/directory2/file.wav`` would be converted to ``_data_directory2_file.wav``. Bucketing Datasets @@ -325,9 +331,9 @@ Currently bucketing feature is just supported for tarred datasets. Upsampling Datasets ------------------ -Buckets may also be 'weighted' to allow multiple runs through a target dataset during each training epoch. This can be beneficial in cases when a dataset is composed of several component sets of unequal sizes and one desires to mitigate bias towards the larger sets through oversampling. +Buckets may also be 'weighted' to allow multiple runs through a target dataset during each training epoch. This can be beneficial in cases when a dataset is composed of several component sets of unequal sizes and one desires to mitigate bias towards the larger sets through oversampling. -Weighting is managed with the `bucketing_weights` parameter. 
After passing your composite tarred datasets in the format described above for bucketing, pass a list of integers (one per bucket) to indicate how many times a manifest should be read during training. +Weighting is managed with the `bucketing_weights` parameter. After passing your composite tarred datasets in the format described above for bucketing, pass a list of integers (one per bucket) to indicate how many times a manifest should be read during training. For example, by passing `[2,1,1,3]` to the code below: @@ -363,7 +369,7 @@ If using adaptive bucketing, note that the same batch size will be assigned to e model.train_ds.bucketing_weights=[2,1,1,3] model.train_ds.bucketing_batch_size=[4,4,4,2] -All instances of data from `bucket4` will still be trained with a batch size of 2 while all others would have a batch size of 4. As with standard bucketing, this requires `batch_size`` to be set to 1. +All instances of data from `bucket4` will still be trained with a batch size of 2 while all others would have a batch size of 4. As with standard bucketing, this requires `batch_size`` to be set to 1. If `bucketing_batch_size` is not specified, all datasets will be passed with the same fixed batch size as specified by the `batch_size` parameter. -It is recommended to set bucketing strategies to `fully_randomized` during multi-GPU training to prevent possible dataset bias during training. \ No newline at end of file +It is recommended to set bucketing strategies to `fully_randomized` during multi-GPU training to prevent possible dataset bias during training. 
\ No newline at end of file diff --git a/docs/source/nlp/punctuation_and_capitalization.rst b/docs/source/nlp/punctuation_and_capitalization.rst index be2890943acb..16a1e6856703 100755 --- a/docs/source/nlp/punctuation_and_capitalization.rst +++ b/docs/source/nlp/punctuation_and_capitalization.rst @@ -3,7 +3,7 @@ Punctuation and Capitalization Model ==================================== -Automatic Speech Recognition (ASR) systems typically generate text with no punctuation and capitalization of the words. +Automatic Speech Recognition (ASR) systems typically generate text with no punctuation and capitalization of the words. There are two issues with non-punctuated ASR output: - it could be difficult to read and understand @@ -35,7 +35,7 @@ For each word in the input text, the Punctuation and Capitalization model: - predicts a punctuation mark that should follow the word (if any). By default, the model supports commas, periods, and question marks. - predicts if the word should be capitalized or not -In the Punctuation and Capitalization model, we are jointly training two token-level classifiers on top of a pre-trained +In the Punctuation and Capitalization model, we are jointly training two token-level classifiers on top of a pre-trained language model, such as `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding `__ :cite:`nlp-punct-devlin2018bert`. .. note:: @@ -85,7 +85,7 @@ NeMo Data Format The Punctuation and Capitalization model expects the data in the following format: -The training and evaluation data is divided into 2 files: +The training and evaluation data is divided into 2 files: - ``text.txt`` - ``labels.txt`` @@ -108,10 +108,10 @@ spaces. 
Each label in ``labels.txt`` file consists of 2 symbols: - the second symbol determines if a word needs to be capitalized or not (where ``U`` indicates that the word should be upper cased, and ``O`` - no capitalization needed) -By default, the following punctuation marks are considered: commas, periods, and question marks; the remaining punctuation marks were +By default, the following punctuation marks are considered: commas, periods, and question marks; the remaining punctuation marks were removed from the data. This can be changed by introducing new labels in the ``labels.txt`` files. -Each line of the ``labels.txt`` should follow the format: ``[LABEL] [SPACE] [LABEL] [SPACE] [LABEL]`` (for ``labels.txt``). For example, +Each line of the ``labels.txt`` should follow the format: ``[LABEL] [SPACE] [LABEL] [SPACE] [LABEL]`` (for ``labels.txt``). For example, labels for the above ``text.txt`` file should be: :: @@ -120,7 +120,7 @@ labels for the above ``text.txt`` file should be: OU OO OO OO ... ... -The complete list of all possible labels used in this tutorial are: +The complete list of all possible labels used in this tutorial are: - ``OO`` - ``.O`` @@ -588,6 +588,22 @@ For convenience, items of data config are described in 4 tables: - ``1`` - The size of shuffle buffer of `webdataset `_. The number of batches which are permuted. + * - **shard_strategy** + - string + - ``scatter`` + - Tarred dataset shard distribution strategy chosen as a str value during ddp. Accepted values are ``scatter`` and ``replicate``. + ``scatter``: Each node gets a unique set of shards, which are permanently pre-allocated and never changed at runtime, when the total + number of shards is not divisible with ``world_size``, some shards (at max ``world_size-1``) will not be used. + ``replicate``: Each node gets the entire set of shards available in the tarred dataset, which are permanently pre-allocated and never + changed at runtime. 
The benefit of replication is that it allows each node to sample data points from the entire dataset independently + of other nodes, and reduces dependence on value of ``tar_shuffle_n``. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, and therefore more than one node may sample + the same tarfile, and even sample the same data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific occasions (when the number of shards is not + divisible with ``world_size``), will not sample the entire dataset. For these reasons it is not advisable to use tarred datasets as + validation or test datasets. .. _pytorch-dataloader-parameters-label: diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index d002c0c76717..92dc8873d9ec 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -485,10 +485,14 @@ class _TarredAudioLabelDataset(IterableDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. 
Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. is_regression_task (bool): Whether it is a regression task. Defualts to False. @@ -697,10 +701,14 @@ class TarredAudioToClassificationLabelDataset(_TarredAudioLabelDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. is_regression_task (bool): Whether it is a regression task. Defualts to False. 
@@ -771,10 +779,14 @@ class TarredAudioToSpeechLabelDataset(_TarredAudioLabelDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. """ diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 573dfd672fa2..7b7caa90d697 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -579,10 +579,14 @@ class _TarredAudioToTextDataset(IterableDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. 
- Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. return_sample_id (bool): whether to return the sample_id as a part of each sample @@ -840,10 +844,14 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! 
As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. return_sample_id (bool): whether to return the sample_id as a part of each sample @@ -967,10 +975,13 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + .. warning:: Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. 
return_sample_id (bool): whether to return the sample_id as a part of each sample diff --git a/nemo/collections/nlp/data/language_modeling/l2r_lm_dataset.py b/nemo/collections/nlp/data/language_modeling/l2r_lm_dataset.py index eff0737c7fec..adb6f126cd78 100644 --- a/nemo/collections/nlp/data/language_modeling/l2r_lm_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/l2r_lm_dataset.py @@ -31,7 +31,7 @@ class L2RLanguageModelingDataset(Dataset): """ Dataset for training and evaluating left-to-right language models. - + Args: tokenizer: tokenizer, such as WordTokenizer or CharTokenizer dataset: path to data @@ -73,7 +73,7 @@ def __getitem__(self, idx): class TarredL2RLanguageModelingDataset(IterableDataset): """ A similar Dataset to the L2RLanguageModelingDataset, but which loads tarred tokenized numpy files. - Accepts a single JSON metadata manifest file as well as the path(s) to the tarball(s) containing the wav files. + Accepts a single JSON metadata manifest file as well as the path(s) to the tarball(s) containing the wav files. The manifest should contain information such as the number of shards, the number of tokens in the corpus, and the number of tokens contained within each shard of the tarfile(s). @@ -114,10 +114,15 @@ class TarredL2RLanguageModelingDataset(IterableDataset): available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + + .. 
warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. """ @@ -142,7 +147,11 @@ def __init__( valid_shard_strategies = ['scatter', 'replicate'] if shard_strategy not in valid_shard_strategies: - raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}") + raise ValueError( + f"Invalid shard strategy of type {type(shard_strategy)} " + f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! " + f"Allowed values are: {valid_shard_strategies}." + ) with open(metadata_path, 'r') as f: metadata = json.load(f) diff --git a/nemo/collections/nlp/data/language_modeling/sentence_dataset.py b/nemo/collections/nlp/data/language_modeling/sentence_dataset.py index e677fbade383..26127bc3aa36 100644 --- a/nemo/collections/nlp/data/language_modeling/sentence_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/sentence_dataset.py @@ -142,7 +142,7 @@ class TarredSentenceDataset(IterableDataset): """ A similar Dataset to the SentenceDataset, but which loads tarred tokenized pickle files. Accepts a single JSON metadata file containing the total number of batches - as well as the path(s) to the tarball(s) containing the wav files. + as well as the path(s) to the tarball(s) containing the wav files. 
Valid formats for the text_tar_filepaths argument include: (1) a single string that can be brace-expanded, e.g. 'path/to/text.tar' or 'path/to/text_{1..100}.tar.gz', or (2) a list of file paths that will not be brace-expanded, e.g. ['text_1.tar', 'text_2.tar', ...]. @@ -172,10 +172,15 @@ class TarredSentenceDataset(IterableDataset): available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. reverse_lang_direction (bool): When True, swaps the source and target directions when returning minibatches. 
@@ -198,7 +203,11 @@ def __init__( valid_shard_strategies = ['scatter', 'replicate'] if shard_strategy not in valid_shard_strategies: - raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}") + raise ValueError( + f"Invalid shard strategy of type {type(shard_strategy)} " + f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! " + f"Allowed values are: {valid_shard_strategies}." + ) with open(metadata_path, 'r') as f: metadata = json.load(f) @@ -223,12 +232,14 @@ def __init__( text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths)) if shard_strategy == 'scatter': - logging.info("All tarred dataset shards will be scattered evenly across all nodes.") + logging.info("Tarred dataset shards will be scattered evenly across all nodes.") if len(text_tar_filepaths) % world_size != 0: logging.warning( f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible " - f"by number of distributed workers ({world_size})." + f"by number of distributed workers ({world_size}). " + f"Some shards will not be used ({len(text_tar_filepaths) % world_size})." 
) + batches_per_tar = self.metadata['num_batches'] // len(text_tar_filepaths) begin_idx = (len(text_tar_filepaths) // world_size) * global_rank end_idx = begin_idx + (len(text_tar_filepaths) // world_size) logging.info('Begin Index : %d' % (begin_idx)) @@ -237,7 +248,7 @@ def __init__( logging.info( "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx ) - self.length = self.metadata['num_batches'] // world_size + self.length = batches_per_tar * len(text_tar_filepaths) * world_size elif shard_strategy == 'replicate': logging.info("All tarred dataset shards will be replicated across all nodes.") diff --git a/nemo/collections/nlp/data/machine_translation/machine_translation_dataset.py b/nemo/collections/nlp/data/machine_translation/machine_translation_dataset.py index fb253df1c098..ac1db2123d99 100644 --- a/nemo/collections/nlp/data/machine_translation/machine_translation_dataset.py +++ b/nemo/collections/nlp/data/machine_translation/machine_translation_dataset.py @@ -326,10 +326,15 @@ class TarredTranslationDataset(IterableDataset): available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! 
As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 1. reverse_lang_direction (bool): When True, swaps the source and target directions when returning minibatches. @@ -360,7 +365,11 @@ def __init__( valid_shard_strategies = ['scatter', 'replicate'] if shard_strategy not in valid_shard_strategies: - raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}") + raise ValueError( + f"Invalid shard strategy of type {type(shard_strategy)} " + f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! " + f"Allowed values are: {valid_shard_strategies}." + ) with open(metadata_path, 'r') as f: metadata = json.load(f) @@ -385,12 +394,14 @@ def __init__( text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths)) if shard_strategy == 'scatter': - logging.info("All tarred dataset shards will be scattered evenly across all nodes.") + logging.info("Tarred dataset shards will be scattered evenly across all nodes.") if len(text_tar_filepaths) % world_size != 0: logging.warning( f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible " - f"by number of distributed workers ({world_size})." + f"by number of distributed workers ({world_size}). " + f"Some shards will not be used ({len(text_tar_filepaths) % world_size})." 
) + batches_per_tar = self.metadata['num_batches'] // len(text_tar_filepaths) begin_idx = (len(text_tar_filepaths) // world_size) * global_rank end_idx = begin_idx + (len(text_tar_filepaths) // world_size) logging.info('Begin Index : %d' % (begin_idx)) @@ -399,7 +410,7 @@ def __init__( logging.info( "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx ) - self.length = self.metadata['num_batches'] // world_size + self.length = batches_per_tar * len(text_tar_filepaths) * world_size elif shard_strategy == 'replicate': logging.info("All tarred dataset shards will be replicated across all nodes.") diff --git a/nemo/collections/nlp/data/text_normalization/decoder_dataset.py b/nemo/collections/nlp/data/text_normalization/decoder_dataset.py index 1110807dcb0b..724e2d3229e5 100644 --- a/nemo/collections/nlp/data/text_normalization/decoder_dataset.py +++ b/nemo/collections/nlp/data/text_normalization/decoder_dataset.py @@ -452,10 +452,15 @@ class TarredTextNormalizationDecoderDataset(IterableDataset): available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of `shuffle_n`. - Note: Replicated strategy allows every node to sample the entire set of available tarfiles, - and therefore more than one node may sample the same tarfile, and even sample the same - data points! As such, there is no assured guarantee that all samples in the dataset will be - sampled at least once during 1 epoch. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. 
Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible with ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. global_rank: Worker rank, used for partitioning shards. world_size: Total number of processes, used for partitioning shards. """ @@ -473,7 +478,11 @@ def __init__( valid_shard_strategies = ['scatter', 'replicate'] if shard_strategy not in valid_shard_strategies: - raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}") + raise ValueError( + f"Invalid shard strategy of type {type(shard_strategy)} " + f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! " + f"Allowed values are: {valid_shard_strategies}." + ) if isinstance(text_tar_filepaths, str): # Replace '(', '[', '<' and '_OP_' with '{' @@ -493,12 +502,14 @@ def __init__( text_tar_filepaths = list(braceexpand.braceexpand(text_tar_filepaths)) if shard_strategy == 'scatter': - logging.info("All tarred dataset shards will be scattered evenly across all nodes.") + logging.info("Tarred dataset shards will be scattered evenly across all nodes.") if len(text_tar_filepaths) % world_size != 0: logging.warning( f"Number of shards in tarred dataset ({len(text_tar_filepaths)}) is not divisible " - f"by number of distributed workers ({world_size})." + f"by number of distributed workers ({world_size}). " + f"Some shards will not be used ({len(text_tar_filepaths) % world_size})." 
) + batches_per_tar = num_batches // len(text_tar_filepaths) begin_idx = (len(text_tar_filepaths) // world_size) * global_rank end_idx = begin_idx + (len(text_tar_filepaths) // world_size) logging.info('Begin Index : %d' % (begin_idx)) @@ -507,16 +518,17 @@ def __init__( logging.info( "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx ) + self.length = batches_per_tar * len(text_tar_filepaths) * world_size elif shard_strategy == 'replicate': logging.info("All tarred dataset shards will be replicated across all nodes.") + self.length = num_batches else: raise ValueError(f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}") # Put together WebDataset self._dataset = wd.WebDataset(urls=text_tar_filepaths, nodesplitter=None) - self.length = num_batches // world_size if shuffle_n > 0: self._dataset = self._dataset.shuffle(shuffle_n) else: diff --git a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py index 493624de3f65..d38fdb45bbfa 100644 --- a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py +++ b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_dataset.py @@ -85,7 +85,7 @@ class PunctuationCapitalizationDataConfigBase: labels_file: Optional[str] = None """A path to a file with punctuation and capitalization labels in NeMo format. NeMo format is described in - `documentation + `documentation `_ """ @@ -112,7 +112,7 @@ class PunctuationCapitalizationDataConfigBase: """A path to a directory containing cache or directory where newly created cache is saved. By default, it is a directory containing ``text_file``. You may need this parameter if cache for a dataset is going to be created and the dataset directory is read-only. 
- + ``cache_dir`` and ``label_info_save_dir`` are separate parameters for the case when a cache is ready and this cache is stored in a read only directory. In this case you will separate ``label_info_save_dir``.""" @@ -142,6 +142,22 @@ class PunctuationCapitalizationDataConfigBase: tar_shuffle_n: int = 1 """The size of shuffle buffer of `webdataset`. The number of batches which are permuted.""" + shard_strategy: Optional[str] = 'scatter' + """Tarred dataset shard distribution strategy chosen as a str value during ddp. Accepted values are `scatter` and `replicate`. + `scatter`: The default shard strategy applied by WebDataset, where each node gets a unique set of shards, which are permanently + pre-allocated and never changed at runtime. `replicate` is an optional shard strategy, where each node gets the entire set of shards + available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of replication is that + it allows each node to sample data points from the entire dataset independently of other nodes, and reduces dependence on value of + ``tar_shuffle_n``. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, and therefore more than one node may sample + the same tarfile, and even sample the same data points! As such, there is no assured guarantee that all samples in the dataset + will be sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific occasions (when the number of + shards is not divisible by ``world_size``), will not sample the entire dataset. For these reasons it is not advisable to use + tarred datasets as validation or test datasets. 
+ """ + ################################################# # PYTORCH DATALOADER PARAMETERS ################################################# diff --git a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py index 2bfcb7969b6e..da63b20dc560 100644 --- a/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py +++ b/nemo/collections/nlp/data/token_classification/punctuation_capitalization_tarred_dataset.py @@ -536,7 +536,7 @@ def repack_tar_files_with_not_enough_batches(output_dir: Path, num_batches_per_t ``repack_tar_files_with_not_enough_batches`` function into tar files with correct ``num_batches_per_tarfile`` batches each. If there is no enough batches in repacked files, then up to ``num_batches_per_tarfile - 1`` remaining batches may be discarded. - + Args: output_dir: a path to the output directory which contains files to repack and where new files are saved num_batches_per_tarfile: a number of batches in 1 tar file. If number of batches in files matching a pattern @@ -685,10 +685,10 @@ def create_tarred_dataset( `examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py `_. - Tarred dataset is a directory which contains metadata file, tar files with batches, + Tarred dataset is a directory which contains metadata file, tar files with batches, ``punct_label_vocab.csv`` and ``capit_label_vocab.csv`` files. - Metadata file is a JSON file with 4 items: ``'num_batches'``, ``'tar_files'``, ``'punct_label_vocab_file'``, + Metadata file is a JSON file with 4 items: ``'num_batches'``, ``'tar_files'``, ``'punct_label_vocab_file'``, ``'capit_label_vocab_file'``. The item ``'num_batches'`` (``int``) is a total number of batches in tarred dataset. ``'tar_files'`` is a list of paths to tar files relative to directory containing the metadata file. 
The items ``'punct_label_vocab_file'`` and ``'capit_label_vocab_file'`` are correspondingly paths to punctuation and @@ -871,6 +871,23 @@ class BertPunctuationCapitalizationTarredDataset(IterableDataset): be used in the current process. shuffle_n (:obj:`int`, `optional`, defaults to :obj:`1`): a number of shuffled batches in a buffer. ``shuffle_n`` batches are loaded into memory, shuffled, and then yielded by a dataset instance. + shard_strategy (:obj:`str`, defaults to :obj:`'scatter'`): Tarred dataset shard distribution strategy chosen as + a str value during ddp. + - ``'scatter'``: The default shard strategy applied by WebDataset, where each node gets + a unique set of shards, which are permanently pre-allocated and never changed at runtime. + - ``'replicate'``: Optional shard strategy, where each node gets the entire set of shards + available in the tarred dataset, which are permanently pre-allocated and never changed at runtime. + The benefit of replication is that it allows each node to sample data points from the entire + dataset independently of other nodes, and reduces dependence on value of ``shuffle_n``. + + .. warning:: + Replicated strategy allows every node to sample the entire set of available tarfiles, + and therefore more than one node may sample the same tarfile, and even sample the same + data points! As such, there is no assured guarantee that all samples in the dataset will be + sampled at least once during 1 epoch. Scattered strategy, on the other hand, on specific + occasions (when the number of shards is not divisible by ``world_size``), will not sample + the entire dataset. For these reasons it is not advisable to use tarred datasets as validation + or test datasets. 
""" @property @@ -897,8 +914,18 @@ def __init__( world_size: int = 1, global_rank: int = 0, shuffle_n: int = 1, + shard_strategy: str = "scatter", ) -> None: super().__init__() + + valid_shard_strategies = ['scatter', 'replicate'] + if shard_strategy not in valid_shard_strategies: + raise ValueError( + f"Invalid shard strategy of type {type(shard_strategy)} " + f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! " + f"Allowed values are: {valid_shard_strategies}." + ) + self.tokenizer = tokenizer self.metadata_file = Path(metadata_file).expanduser() if label_info_save_dir is None: @@ -922,13 +949,31 @@ def __init__( self.capit_label_ids = load_label_ids(self.capit_label_vocab_file) self.pad_label = pad_label self._check_pad_label() - begin_idx = (len(self.tar_files) // world_size) * global_rank - end_idx = begin_idx + (len(self.tar_files) // world_size) - logging.info( - "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx - ) - self.tar_files = self.tar_files[begin_idx:end_idx] - self.length = self.metadata['num_batches'] // world_size + + if shard_strategy == 'scatter': + logging.info("Tarred dataset shards will be scattered evenly across all nodes.") + if len(self.tar_files) % world_size != 0: + logging.warning( + f"Number of shards in tarred dataset ({len(self.tar_files)}) is not divisible " + f"by number of distributed workers ({world_size}). " + f"Some shards will not be used ({len(self.tar_files) % world_size})." 
+ ) + begin_idx = (len(self.tar_files) // world_size) * global_rank + end_idx = begin_idx + (len(self.tar_files) // world_size) + logging.info( + "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx + ) + batches_per_tar = self.metadata['num_batches'] // len(self.tar_files) + self.tar_files = self.tar_files[begin_idx:end_idx] + self.length = batches_per_tar * len(self.tar_files) * world_size + + elif shard_strategy == 'replicate': + logging.info("All tarred dataset shards will be replicated across all nodes.") + self.length = self.metadata['num_batches'] + + else: + raise ValueError(f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}") + self._dataset = wds.WebDataset(urls=self.tar_files, nodesplitter=None).decode( wds.handle_extension('.pyd', decode_pyd) ) diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py index 4f602f90da8b..440640231664 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py @@ -15,6 +15,7 @@ import json import os from collections import defaultdict +from math import ceil from typing import Dict, List, Optional, Union import torch @@ -398,6 +399,23 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): cfg=train_data_config, data_split="train" ) + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in train_data_config and train_data_config['use_tarred_dataset']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. 
<= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches * ceil(len(self._train_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) + def setup_validation_data(self, val_data_config: Optional[DictConfig]): if not val_data_config or not val_data_config.data_path: logging.info( @@ -409,6 +427,23 @@ def setup_validation_data(self, val_data_config: Optional[DictConfig]): cfg=val_data_config, data_split="val" ) + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in val_data_config and val_data_config['use_tarred_dataset']: + # We also need to check if limit_val_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # validation batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_val_batches, float): + self._trainer.limit_val_batches = int( + self._trainer.limit_val_batches * ceil(len(self._validation_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "validation batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict] = None): if val_data_config is None: val_data_config = self._cfg.validation_ds diff --git a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py index e7c13e539dd0..180fdad9ddc3 100644 --- a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py @@ -211,9 +211,43 @@ def setup_tokenizer( def setup_training_data(self, train_data_config: Optional[DictConfig]): self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config) + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in train_data_config and train_data_config['use_tarred_dataset']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches * math.ceil(len(self._train_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config) + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in val_data_config and val_data_config['use_tarred_dataset']: + # We also need to check if limit_val_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # validation batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_val_batches, float): + self._trainer.limit_val_batches = int( + self._trainer.limit_val_batches * math.ceil(len(self._validation_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "validation batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + def setup_test_data(self, test_data_config: Optional[DictConfig]): self._test_dl = self._setup_dataloader_from_config(cfg=test_data_config) diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py index 7c27aeb1bb20..ebdcdecdab2b 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py @@ -16,6 +16,7 @@ import json import os import random +from math import ceil from pathlib import Path from typing import Dict, List, Optional, Union @@ -609,6 +610,23 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): ) self._train_dl = MTEncDecModel._setup_dataloader_from_config(cfg=train_data_config, dataset=self._train_ds,) + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in train_data_config and train_data_config['use_tarred_dataset']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches * ceil(len(self._train_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]): self.setup_validation_data(val_data_config) @@ -626,6 +644,24 @@ def setup_validation_data(self, val_data_config: Optional[DictConfig]): self._validation_dl = MTEncDecModel._setup_eval_dataloader_from_config( cfg=val_data_config, datasets=self._validation_ds ) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in val_data_config and val_data_config['use_tarred_dataset']: + # We also need to check if limit_val_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # validation batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_val_batches, float): + self._trainer.limit_val_batches = int( + self._trainer.limit_val_batches * ceil(len(self._validation_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "validation batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + # instantiate Torchmetric for each val dataloader if self._validation_dl is not None: for dataloader_idx in range(len(self._validation_dl)): diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py index 5f6fa7f6164f..d8d07fcee87d 100644 --- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py +++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py @@ -468,6 +468,24 @@ def setup_training_data(self, train_data_config: Optional[Union[Dict[str, Any], train_data_config = self._cfg.train_ds self._train_dl = self._setup_dataloader_from_config(cfg=train_data_config, train=True) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in train_data_config and train_data_config['use_tarred_dataset']: + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches * ceil(len(self._train_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + self.punct_label_ids = self._train_dl.dataset.punct_label_ids.copy() self.capit_label_ids = self._train_dl.dataset.capit_label_ids.copy() self.label_ids_are_set = True @@ -540,6 +558,24 @@ def setup_validation_data(self, val_data_config: Optional[Union[Dict[str, Any], val_data_config = self._cfg.validation_ds self._validation_dl = self._setup_dataloader_from_config(cfg=val_data_config, train=False) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if 'use_tarred_dataset' in val_data_config and val_data_config['use_tarred_dataset']: + # We also need to check if limit_val_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # validation batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_val_batches, float): + self._trainer.limit_val_batches = int( + self._trainer.limit_val_batches * ceil(len(self._validation_dl.dataset) / self.world_size) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "validation batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + loss_kw, punct_kw, capit_kw = self._get_eval_metrics_kwargs() self.metrics['val']['loss'].append(GlobalAverageLossMetric(**loss_kw)) self.metrics['val']['punct_class_report'].append(ClassificationReport(**punct_kw)) @@ -741,6 +777,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u world_size=self.world_size, global_rank=self.global_rank, shuffle_n=cfg.tar_shuffle_n, + shard_strategy=cfg.shard_strategy, label_info_save_dir=cfg.label_info_save_dir, ) dataset.check_for_label_consistency_with_model_config(