diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index cb107d69a697..e360c47a674a 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -15,7 +15,7 @@
 import warnings
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, List, Optional, TypeVar, Union
 
 import numpy as np
 import torch
@@ -155,6 +155,40 @@ def get_lhotse_dataloader_from_config(
     we can account for their number of tokens.
     Note: this behaviour might eventually be extended to audio datasets too.
 
+    Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
+    """
+    if config.get("multi_config"):
+        return get_lhotse_dataloader_from_multi_config(
+            configs=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer
+        )
+    else:
+        return get_lhotse_dataloader_from_single_config(
+            config=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer
+        )
+
+
+def get_lhotse_dataloader_from_single_config(
+    config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
+) -> torch.utils.data.DataLoader:
+    """
+    Set up a Lhotse training dataloader.
+
+    Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True".
+    Some fields in the original NeMo configuration may be ignored.
+
+    The ``dataset`` parameter should be an instance of a Lhotse-compatible PyTorch Dataset class.
+    It only needs to define the following method: ``__getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]``.
+    This dataset is not expected to hold a reference to any actual data; it may be interpreted as a function
+    mapping a Lhotse CutSet into a mini-batch of tensors.
+
+    For an example, see: :class:`nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`,
+    which is constructed from just a tokenizer and essentially loads and collates audio and tokenizes the transcript.
+
+    The ``tokenizer`` is used when text-only datasets are included in dataloading.
+    In these cases we will tokenize ``TextExample``s before sampling mini-batches so that
+    we can account for their number of tokens.
+    Note: this behaviour might eventually be extended to audio datasets too.
+    Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
     """
     logging.info("We will be using a Lhotse DataLoader.")
 
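[Editor's note, not part of the diff: the docstring above fully specifies the Dataset contract, so a minimal sketch may help. The class name ToyAudioDataset is hypothetical; only the ``__getitem__(self, cuts: CutSet)`` signature is required by the loader, and the dataset holds no data itself.]

from typing import Dict

import torch
from lhotse import CutSet
from lhotse.dataset.collation import collate_audio


class ToyAudioDataset(torch.utils.data.Dataset):  # hypothetical name, for illustration
    """Maps a sampled CutSet into a mini-batch of tensors; holds no data itself."""

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        # collate_audio loads the waveforms for the sampled cuts and pads them
        # to a uniform length, returning the padded batch and the true lengths.
        audio, audio_lens = collate_audio(cuts)
        return {"audio": audio, "audio_lens": audio_lens}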
@@ -167,46 +201,93 @@ def get_lhotse_dataloader_from_config(
     seed = resolve_seed(config.seed)
     fix_random_seed(seed)
 
-    if config.sampler_fusion == "mux":
-        # Default strategy: every dataset is treated as a stream that is stochastically multiplexed (interleaved).
-        # Supports all types of dataloader input specifications (manifest_filepath, cuts_path, input_cfg, etc.).
-        sampler, is_tarred = get_lhotse_sampler_from_config(
+    assert config.sampler_fusion == "mux", (
+        "In order to use a sampler_fusion strategy different from 'mux', "
+        "create your dataloader using 'get_lhotse_dataloader_from_multi_config' instead."
+    )
+    sampler, is_tarred = get_lhotse_sampler_from_config(
+        config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
+    )
+
+    # 4. Creating dataloader.
+    if is_tarred:
+        # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data,
+        # because then I/O happens upon sampler iteration. Normally, the sampler resides
+        # in the training loop process, but when we use an iterable dataset, we can move it to
+        # the dataloading worker process.
+        # We use lhotse's own worker_init_fn which leverages information such as rank, world_size,
+        # worker_id, etc. to set a different random seed for each (node, worker) combination.
+        # This together with infinite datasets removes the need to split data across nodes/workers.
+        dloader_kwargs = dict(
+            dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler),
+            worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed),
+            persistent_workers=config.num_workers > 0,  # helps Lhotse Shar maintain shuffling state
+        )
+    else:
+        # For non-tarred data, the sampler resides in the training loop process and
+        # reads only light-weight JSON objects; it samples mini-batches and passes
+        # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
+        dloader_kwargs = dict(dataset=dataset, sampler=sampler)
+    dloader = torch.utils.data.DataLoader(
+        **dloader_kwargs, batch_size=None, num_workers=config.num_workers, pin_memory=config.pin_memory,
+    )
+
+    return dloader
+
+
+def get_lhotse_dataloader_from_multi_config(
+    configs: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
+) -> torch.utils.data.DataLoader:
+    """
+    Set up a Lhotse training dataloader.
+
+    It works similarly to :func:`get_lhotse_dataloader_from_config`, except that you can provide multiple configs
+    to set up different sampling, batching, and augmentation settings for every dataset and decide how to merge them.
+
+    The expected format is that ``configs`` is a dict of group name -> actual config.
+
+    The first config is treated as a "main" config that determines the RNG, CUDA allocator, and sampler fusion settings.
+    """
+    logging.info(f"We will be using a multi-config Lhotse DataLoader with groups: {list(configs.keys())}.")
+
+    configs = [make_structured_with_schema_warnings(c) for c in configs.values() if isinstance(c, DictConfig)]
+    main_config = configs[0]
+    maybe_set_cuda_expandable_segments(enabled=main_config.cuda_expandable_segments)
+    seed = resolve_seed(main_config.seed)
+    fix_random_seed(seed)
+
+    source_samplers, source_tarred = [], []
+    for config in configs:
+        # TODO(pzelasko): perhaps emit a warning in the unlikely case somebody defines different seeds explicitly.
+        config.seed = seed
+        config.shard_seed = main_config.shard_seed
+        s, t = get_lhotse_sampler_from_config(
             config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
         )
+        source_samplers.append(s)
+        source_tarred.append(t)
+
+    assert all(
+        st == source_tarred[0] for st in source_tarred[1:]
+    ), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)."
+    is_tarred = all(source_tarred)
+    if main_config.sampler_fusion == "zip":
+        sampler = ZipSampler(*source_samplers)
+    elif main_config.sampler_fusion == "round_robin":
+        sampler = RoundRobinSampler(*source_samplers)
+    elif main_config.sampler_fusion == "randomized_round_robin":
+        sampler = RoundRobinSampler(
+            *source_samplers,
+            randomize=True if main_config.sampler_weights is None else main_config.sampler_weights,
+            seed=seed,
+        )
+    elif main_config.sampler_fusion == "mux":
+        raise RuntimeError(
+            "In order to use the 'mux' sampler_fusion strategy, "
+            "create your dataloader using 'get_lhotse_dataloader_from_config' instead."
+        )
     else:
-        # Custom sampler fusion strategy: that means we will create a separate sampler for each entry in input_cfg list,
-        # and fuse the sampler later. Strategies supported at the moment are:
-        # * zip: ZipSampler iterates a step on each sub-sampler and merges the results into one mini-batch.
-        # * round_robin: with RoundRobinSampler, the sub-samplers take turns to yield their mini-batches.
-        # * randomized_round_robin: similar to round_robin, except we use RNG to choose which sub-sampler takes the current turn (weights can be provided via sampler_weights).
-        assert (
-            config.input_cfg is not None
-        ), "In order to use a different sampler fusion strategy than 'mux', you have to provide the dataloader inputs via input_cfg parameter."
-        source_samplers, source_tarred = [], []
-        for input_cfg in config.input_cfg:
-            source_config = config.copy()
-            source_config.input_cfg = input_cfg if isinstance(input_cfg, str) else [input_cfg]
-            s, t = get_lhotse_sampler_from_config(
-                config=source_config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
-            )
-            source_samplers.append(s)
-            source_tarred.append(t)
-        assert all(
-            st == source_tarred[0] for st in source_tarred[1:]
-        ), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)."
-        is_tarred = all(source_tarred)
-        if config.sampler_fusion == "zip":
-            sampler = ZipSampler(*source_samplers)
-        elif config.sampler_fusion == "round_robin":
-            sampler = RoundRobinSampler(*source_samplers)
-        elif config.sampler_fusion == "randomized_round_robin":
-            sampler = RoundRobinSampler(
-                *source_samplers,
-                randomize=True if config.sampler_weights is None else config.sampler_weights,
-                seed=seed,
-            )
-        else:
-            raise RuntimeError(f"Unsupported sampler fusion strategy: {config.sampler_fusion}")
+        raise RuntimeError(f"Unsupported sampler fusion strategy: {main_config.sampler_fusion}")
 
     # 4. Creating dataloader.
     if is_tarred:
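[Editor's note, not part of the diff: a minimal usage sketch of the multi-config path. The group names "audio" and "text" are arbitrary, each group body is an ordinary single-config dict with most fields elided, and my_dataset/my_tokenizer are hypothetical placeholders; only the field names referenced by the code above are shown.]

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "multi_config": True,  # routes get_lhotse_dataloader_from_config() to the multi-config path
        "audio": {
            # The first group doubles as the "main" config: its seed, shard_seed,
            # cuda_expandable_segments, and sampler_fusion settings apply globally.
            "sampler_fusion": "round_robin",
            "seed": 42,
            "shard_seed": 42,
        },
        "text": {},  # per-group sampling/batching fields elided in this sketch
    }
)
dloader = get_lhotse_dataloader_from_config(
    config, global_rank=0, world_size=1, dataset=my_dataset, tokenizer=my_tokenizer
)

With "round_robin" fusion, the fused sampler alternates whole mini-batches between the groups, so a given step carries a single modality; this pairs with the per-modality fwd_bwd_step split in the next file.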
diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py
index 32e1952bcbff..b8ff0b77067d 100644
--- a/nemo/collections/multimodal/speech_llm/models/modular_models.py
+++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -437,92 +437,106 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
         else:
             batch = next(dataloader_iter)
 
+        audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")}
+        text_batch = {k: v for k, v in batch.items() if k.startswith("text_")}
+
         # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
         # log_token_counts = self.cfg.get('log_token_counts', False)
         # if log_token_counts:
         #     token_count_avg = sum(batch['token_count']) / len(batch['token_count'])
 
-        # Pass only torch.Tensor to prevent errors when process get_iterator_k_split()
-        batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
-
-        # TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities,
-        # but I feel like this needs larger refactoring
-        if 'tokens' in batch and 'text_input_ids' in batch:
-            seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1])
-        elif 'tokens' in batch:
-            seq_length = batch['tokens'].shape[1]
-        elif 'text_input_ids' in batch:
-            seq_length = batch['text_input_ids'].shape[1]
-        else:
-            seq_length = None  # TODO(pzelasko): not sure if it is even needed ???
-
-        data_iter = get_iterator_k_split(batch, get_num_microbatches())
-
-        # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
-        # if log_token_counts:
-        #     self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
-        #     self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)
-
-        # handle asynchronous grad reduction
-        no_sync_func = None
-        grad_sync_func = None
-        param_sync_func = None
-        if not forward_only and self.with_distributed_adam:
-            no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
-            grad_sync_func = self.reduce_overlap_gradients
-            param_sync_func = self.sync_overlap_parameters
-
-        for module in self.get_model_module_list():
-            module.config.no_sync_func = no_sync_func
-            module.config.grad_sync_func = grad_sync_func
-            module.config.param_sync_func = param_sync_func
-
-        fwd_bwd_function = get_forward_backward_func()
-
-        losses_reduced_per_micro_batch = fwd_bwd_function(
-            forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only),
-            data_iterator=self._make_data_iterator_list(data_iter),
-            model=self.model,
-            num_microbatches=get_num_microbatches(),
-            forward_only=forward_only,
-            seq_length=seq_length,
-            micro_batch_size=get_micro_batch_size(),
-            first_val_step=first_val_step,
-        )
+        # Note: We want to perform full fwd+bwd separately for each modality,
+        #       as it allows us to save GPU memory. Otherwise, we'd have to
+        #       hold the activations from one modality in memory while running
+        #       forward for the other.
+        batch_losses = []
+        for batch in (audio_batch, text_batch):
+            if not batch:
+                continue
 
-        non_loss_tensors = {}
-        # only the last stages of the pipeline return losses
-        if losses_reduced_per_micro_batch:
-            for item in losses_reduced_per_micro_batch:
-                for k, v in item.items():
-                    if k != 'avg':
-                        av = non_loss_tensors.get(k, [])
-                        av.append(v)
-                        non_loss_tensors[k] = av
-            if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
-                # average loss across micro batches
-                loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
-                loss_tensor = torch.concat(loss_tensors_list)
-                loss_mean = loss_tensor.mean()
+            # Pass only torch.Tensor to prevent errors when processing get_iterator_k_split()
+            batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}
+
+            # TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities,
+            #                 but I feel like this needs larger refactoring
+            if 'tokens' in batch and 'text_input_ids' in batch:
+                seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1])
+            elif 'tokens' in batch:
+                seq_length = batch['tokens'].shape[1]
+            elif 'text_input_ids' in batch:
+                seq_length = batch['text_input_ids'].shape[1]
             else:
-                # Get the total loss since micro batches sizes are not uniform
-                loss_sum_tensors_list = [
-                    loss_sum['loss_sum_and_ub_size']
-                    for loss_sum in losses_reduced_per_micro_batch
-                    if loss_sum['loss_sum_and_ub_size'][1] > 0
-                ]
-                loss_sum = (
-                    torch.vstack(loss_sum_tensors_list).sum(axis=0)
-                    if len(loss_sum_tensors_list) > 0
-                    else torch.tensor([0.0, 0.0]).cuda()
-                )
-                return loss_sum
-        else:
-            # we're not on the last pipeline stage so no losses
-            if forward_only:
-                loss_mean = []
+                seq_length = None  # TODO(pzelasko): not sure if it is even needed ???
+
+            data_iter = get_iterator_k_split(batch, get_num_microbatches())
+
+            # TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
+            # if log_token_counts:
+            #     self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
+            #     self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)
+
+            # handle asynchronous grad reduction
+            no_sync_func = None
+            grad_sync_func = None
+            param_sync_func = None
+            if not forward_only and self.with_distributed_adam:
+                no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
+                grad_sync_func = self.reduce_overlap_gradients
+                param_sync_func = self.sync_overlap_parameters
+
+            for module in self.get_model_module_list():
+                module.config.no_sync_func = no_sync_func
+                module.config.grad_sync_func = grad_sync_func
+                module.config.param_sync_func = param_sync_func
+
+            fwd_bwd_function = get_forward_backward_func()
+
+            losses_reduced_per_micro_batch = fwd_bwd_function(
+                forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only),
+                data_iterator=self._make_data_iterator_list(data_iter),
+                model=self.model,
+                num_microbatches=get_num_microbatches(),
+                forward_only=forward_only,
+                seq_length=seq_length,
+                micro_batch_size=get_micro_batch_size(),
+                first_val_step=first_val_step,
+            )
+
+            non_loss_tensors = {}
+            # only the last stages of the pipeline return losses
+            if losses_reduced_per_micro_batch:
+                for item in losses_reduced_per_micro_batch:
+                    for k, v in item.items():
+                        if k != 'avg':
+                            av = non_loss_tensors.get(k, [])
+                            av.append(v)
+                            non_loss_tensors[k] = av
+                if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
+                    # average loss across micro batches
+                    loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
+                    loss_tensor = torch.concat(loss_tensors_list)
+                    loss_mean = loss_tensor.mean()
+                else:
+                    # Get the total loss since micro batches sizes are not uniform
+                    loss_sum_tensors_list = [
+                        loss_sum['loss_sum_and_ub_size']
+                        for loss_sum in losses_reduced_per_micro_batch
+                        if loss_sum['loss_sum_and_ub_size'][1] > 0
+                    ]
+                    loss_mean = (
+                        torch.vstack(loss_sum_tensors_list).sum(axis=0)
+                        if len(loss_sum_tensors_list) > 0
+                        else torch.tensor([0.0, 0.0]).cuda()
+                    )
            else:
-                loss_mean = torch.tensor(0.0).cuda()
+                # we're not on the last pipeline stage so no losses
+                if forward_only:
+                    loss_mean = []
+                else:
+                    loss_mean = torch.tensor(0.0).cuda()
+            batch_losses.append(loss_mean.unsqueeze(0))
+
+        loss_mean = torch.cat(batch_losses).mean()
 
         # if forward_only:
         #     return loss_mean
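[Editor's note, not part of the diff: the control flow that the rewritten fwd_bwd_step implements, reduced to a standalone sketch. Here run_fwd_bwd is a hypothetical stand-in for the Megatron fwd_bwd_function machinery; only the batch splitting, the per-modality loop, and the final averaging mirror the code above.]

from typing import Callable, Dict

import torch


def fwd_bwd_per_modality(
    batch: Dict[str, torch.Tensor],
    run_fwd_bwd: Callable[[Dict[str, torch.Tensor]], torch.Tensor],
) -> torch.Tensor:
    # Split the joint batch by key prefix, as in the diff above.
    audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")}
    text_batch = {k: v for k, v in batch.items() if k.startswith("text_")}

    batch_losses = []
    for sub_batch in (audio_batch, text_batch):
        if not sub_batch:  # the sampler may emit single-modality batches
            continue
        # Completing fwd+bwd for one modality before starting the other keeps
        # only one modality's activations alive at a time, saving GPU memory.
        batch_losses.append(run_fwd_bwd(sub_batch).unsqueeze(0))
    return torch.cat(batch_losses).mean()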