Working V2 version of multimodal dataloading. Each modality gets its own batch settings that can be merged with a zip sampler to enjoy max batch sizes for both modalities in a single training step. Each modality runs fwd+bwd in turn to save GPU memory (instead of running fwd separately and bwd together).

Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed Jun 5, 2024
1 parent bb97173 commit 9dcee82
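
The commit message describes per-modality batch settings that are fused by a zip sampler. For illustration, a multi-group configuration for the new code path might look roughly like the sketch below; the group names ("audio", "text") and the per-group batching fields are assumed placeholders, while multi_config, sampler_fusion, seed, shard_seed, num_workers, and pin_memory are fields visible in this diff.

from omegaconf import OmegaConf

# Hypothetical multi-group config consumed by get_lhotse_dataloader_from_multi_config.
# Group names and per-group batch settings are illustrative assumptions, not taken from this commit.
cfg = OmegaConf.create(
    {
        "multi_config": True,         # routes get_lhotse_dataloader_from_config to the multi-config path
        "audio": {                    # the first group acts as the "main" config (seed, allocator, sampler fusion)
            "sampler_fusion": "zip",  # merge each group's mini-batch into a single training step
            "seed": 0,
            "shard_seed": 0,
            "num_workers": 2,
            "pin_memory": True,
            # ... audio input specification and batch settings would go here (assumed) ...
        },
        "text": {
            # ... text input specification with its own, larger batch settings (assumed) ...
        },
    }
)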
Showing 2 changed files with 211 additions and 116 deletions.
157 changes: 119 additions & 38 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -15,7 +15,7 @@
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, TypeVar, Union
from typing import Any, List, Optional, TypeVar, Union

import numpy as np
import torch
@@ -155,6 +155,40 @@ def get_lhotse_dataloader_from_config(
we can account for their number of tokens.
Note: this behaviour might eventually be extended to audio datasets too.
Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
"""
if config.get("multi_config"):
return get_lhotse_dataloader_from_multi_config(
configs=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer
)
else:
return get_lhotse_dataloader_from_single_config(
config=config, global_rank=global_rank, world_size=world_size, dataset=dataset, tokenizer=tokenizer
)


def get_lhotse_dataloader_from_single_config(
config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
) -> torch.utils.data.DataLoader:
"""
Set up a Lhotse training dataloader.
Expects a typical NeMo dataset configuration format, with additional fields: "use_lhotse=True".
Some fields in the original NeMo configuration may be ignored.
The ``dataset`` parameter should be an instance of a Lhotse-compatible PyTorch Dataset class.
It only needs to define the following method ``__getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]``.
This dataset is not expected to hold a reference to any actual data; it may be interpreted as a function
mapping a Lhotse CutSet into a mini-batch of tensors.
For an example, see: :class:`nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`,
which is constructed from just a tokenizer and essentially loads and collates audio and tokenizes the transcript.
The ``tokenizer`` is used when text-only datasets are included in dataloading.
In these cases we will tokenize ``TextExample``s before sampling mini-batches so that
we can account for their number of tokens.
Note: this behaviour might eventually be extended to audio datasets too.
Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
"""
logging.info("We will be using a Lhotse DataLoader.")
@@ -167,46 +201,93 @@ def get_lhotse_dataloader_from_config(
seed = resolve_seed(config.seed)
fix_random_seed(seed)

if config.sampler_fusion == "mux":
# Default strategy: every dataset is treated as a stream that is stochastically multiplexed (interleaved).
# Supports all types of dataloader input specifications (manifest_filepath, cuts_path, input_cfg, etc.).
sampler, is_tarred = get_lhotse_sampler_from_config(
assert config.sampler_fusion == "mux", (
"In order to use a sampler_fusion strategy different than 'mux', "
"create your dataloader using 'get_lhotse_dataloader_from_multi_config' instead."
)
sampler, is_tarred = get_lhotse_sampler_from_config(
config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
)

# 4. Creating dataloader.
if is_tarred:
# Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data,
# because then I/O happens upon sampler iteration. Normally, the sampler resides
# in the training loop process, but when we use iterable dataset, we can move it to
# the dataloading worker process.
# We use lhotse's own worker_init_fn which leverages information such as rank, world_size,
# worker_id, etc. to set a different random seed for each (node, worker) combination.
# This together with infinite datasets removes the need to split data across nodes/workers.
dloader_kwargs = dict(
dataset=IterableDatasetWrapper(dataset=dataset, sampler=sampler),
worker_init_fn=make_worker_init_fn(rank=global_rank, world_size=world_size, seed=seed),
persistent_workers=config.num_workers > 0, # helps Lhotse Shar maintain shuffling state
)
else:
# For non-tarred data, the sampler resides in the training loop process and
# reads only light-weight JSON objects; it samples mini-batches and passes
# the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
dloader_kwargs = dict(dataset=dataset, sampler=sampler)
dloader = torch.utils.data.DataLoader(
**dloader_kwargs, batch_size=None, num_workers=config.num_workers, pin_memory=config.pin_memory,
)

return dloader


def get_lhotse_dataloader_from_multi_config(
configs: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
) -> torch.utils.data.DataLoader:
"""
Set up a Lhotse training dataloader.
It works similarly to :func:`get_lhotse_dataloader_from_config`, except that you can provide multiple configs
to set up different sampling, batching, and augmentation settings for every dataset and decide how to merge them.
The expected format is that ``configs`` is a dict mapping group name -> actual config.
The first config is treated as a "main" config that determines the RNG, CUDA allocator, and sampler fusion settings.
"""
logging.info(f"We will be using a multi config Lhotse DataLoader with groups: {list(configs.keys())}.")

configs = [make_structured_with_schema_warnings(c) for c in configs.values() if isinstance(c, DictConfig)]
main_config = configs[0]
maybe_set_cuda_expandable_segments(enabled=main_config.cuda_expandable_segments)
seed = resolve_seed(main_config.seed)
fix_random_seed(seed)

source_samplers, source_tarred = [], []
for config in configs:
# TODO(pzelasko): perhaps emit a warning in the unlikely case somebody defines different seeds explicitly.
config.seed = seed
config.shard_seed = main_config.shard_seed
s, t = get_lhotse_sampler_from_config(
config=config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
)
source_samplers.append(s)
source_tarred.append(t)

assert all(
st == source_tarred[0] for st in source_tarred[1:]
), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)."
is_tarred = all(source_tarred)
if main_config.sampler_fusion == "zip":
sampler = ZipSampler(*source_samplers)
elif main_config.sampler_fusion == "round_robin":
sampler = RoundRobinSampler(*source_samplers)
elif main_config.sampler_fusion == "randomized_round_robin":
sampler = RoundRobinSampler(
*source_samplers,
randomize=True if main_config.sampler_weights is None else main_config.sampler_weights,
seed=seed,
)
elif main_config.sampler_fusion == "mux":
raise RuntimeError(
"In order to use a sampler_fusion strategy 'mux', "
"create your dataloader using 'get_lhotse_dataloader_from_config' instead."
)
else:
# Custom sampler fusion strategy: that means we will create a separate sampler for each entry in input_cfg list,
# and fuse the sampler later. Strategies supported at the moment are:
# * zip: ZipSampler iterates a step on each sub-sampler and merges the results into one mini-batch.
# * round_robin: with RoundRobinSampler, the sub-samplers take turns to yield their mini-batches.
# * randomized_round_robin: similar to round_robin, except we use RNG to choose which sub-sampler takes the current turn (weights can be provided via sampler_weights).
assert (
config.input_cfg is not None
), "In order to use a different sampler fusion strategy than 'mux', you have to provide the dataloader inputs via input_cfg parameter."
source_samplers, source_tarred = [], []
for input_cfg in config.input_cfg:
source_config = config.copy()
source_config.input_cfg = input_cfg if isinstance(input_cfg, str) else [input_cfg]
s, t = get_lhotse_sampler_from_config(
config=source_config, global_rank=global_rank, world_size=world_size, tokenizer=tokenizer
)
source_samplers.append(s)
source_tarred.append(t)
assert all(
st == source_tarred[0] for st in source_tarred[1:]
), "When using multiple input_cfg sources ensure they are all tarred or non-tarred (can't mix)."
is_tarred = all(source_tarred)
if config.sampler_fusion == "zip":
sampler = ZipSampler(*source_samplers)
elif config.sampler_fusion == "round_robin":
sampler = RoundRobinSampler(*source_samplers)
elif config.sampler_fusion == "randomized_round_robin":
sampler = RoundRobinSampler(
*source_samplers,
randomize=True if config.sampler_weights is None else config.sampler_weights,
seed=seed,
)
else:
raise RuntimeError(f"Unsupported sampler fusion strategy: {config.sampler_fusion}")
raise RuntimeError(f"Unsupported sampler fusion strategy: {main_config.sampler_fusion}")

# 4. Creating dataloader.
if is_tarred:
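
A minimal sketch (not part of this commit) of the Dataset contract described in the docstring of get_lhotse_dataloader_from_single_config: the class holds no reference to any actual data and simply maps a sampled CutSet into a dict of tensors inside __getitem__. The output key names are illustrative.

from typing import Dict

import torch
from lhotse import CutSet
from lhotse.dataset.collation import collate_audio


class ToyAudioDataset(torch.utils.data.Dataset):
    """Maps a sampled CutSet to a mini-batch of tensors; holds no data of its own."""

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        audio, audio_lens = collate_audio(cuts)  # pad and stack audio into a [B, T] tensor
        return {"audio_signal": audio, "audio_signal_length": audio_lens}  # key names are illustrative

Such a class would then be passed as the ``dataset`` argument of get_lhotse_dataloader_from_single_config (or the multi-config variant), which drives it with sampled CutSets.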
170 changes: 92 additions & 78 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -437,92 +437,106 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None):
else:
batch = next(dataloader_iter)

audio_batch = {k: v for k, v in batch.items() if not k.startswith("text_")}
text_batch = {k: v for k, v in batch.items() if k.startswith("text_")}

# TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
# log_token_counts = self.cfg.get('log_token_counts', False)
# if log_token_counts:
# token_count_avg = sum(batch['token_count']) / len(batch['token_count'])

# Pass only torch.Tensor values to prevent errors when the batch is processed by get_iterator_k_split()
batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}

# TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities,
# but I feel like this needs larger refactoring
if 'tokens' in batch and 'text_input_ids' in batch:
seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1])
elif 'tokens' in batch:
seq_length = batch['tokens'].shape[1]
elif 'text_input_ids' in batch:
seq_length = batch['text_input_ids'].shape[1]
else:
seq_length = None # TODO(pzelasko): not sure if it is even needed ???

data_iter = get_iterator_k_split(batch, get_num_microbatches())

# TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
# if log_token_counts:
# self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
# self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)

# handle asynchronous grad reduction
no_sync_func = None
grad_sync_func = None
param_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
grad_sync_func = self.reduce_overlap_gradients
param_sync_func = self.sync_overlap_parameters

for module in self.get_model_module_list():
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func
module.config.param_sync_func = param_sync_func

fwd_bwd_function = get_forward_backward_func()

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only),
data_iterator=self._make_data_iterator_list(data_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
seq_length=seq_length,
micro_batch_size=get_micro_batch_size(),
first_val_step=first_val_step,
)
# Note: We want to perform full fwd+bwd separately for each modality,
# as it allows us to save GPU memory. Otherwise, we'd have to
# hold the activations from one modality in memory while running
# forward for the other.
batch_losses = []
for batch in (audio_batch, text_batch):
if not batch:
continue

non_loss_tensors = {}
# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
for item in losses_reduced_per_micro_batch:
for k, v in item.items():
if k != 'avg':
av = non_loss_tensors.get(k, [])
av.append(v)
non_loss_tensors[k] = av
if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
# Pass only torch.Tensor values to prevent errors when the batch is processed by get_iterator_k_split()
batch = {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}

# TODO(pzelasko): For the prototype, computing seq_length as a max from both modalities,
# but I feel like this needs larger refactoring
if 'tokens' in batch and 'text_input_ids' in batch:
seq_length = max(batch['tokens'].shape[1], batch['text_input_ids'].shape[1])
elif 'tokens' in batch:
seq_length = batch['tokens'].shape[1]
elif 'text_input_ids' in batch:
seq_length = batch['text_input_ids'].shape[1]
else:
# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
return loss_sum
else:
# we're not on the last pipeline stage so no losses
if forward_only:
loss_mean = []
seq_length = None # TODO(pzelasko): not sure if it is even needed ???

data_iter = get_iterator_k_split(batch, get_num_microbatches())

# TODO(pzelasko): restore this logging once we decide what's the right format for joint text-audio batches
# if log_token_counts:
# self.log('seq_length_padded', seq_length, prog_bar=True, batch_size=1)
# self.log('tokens_avg', token_count_avg, prog_bar=True, sync_dist=True, batch_size=1)

# handle asynchronous grad reduction
no_sync_func = None
grad_sync_func = None
param_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
grad_sync_func = self.reduce_overlap_gradients
param_sync_func = self.sync_overlap_parameters

for module in self.get_model_module_list():
module.config.no_sync_func = no_sync_func
module.config.grad_sync_func = grad_sync_func
module.config.param_sync_func = param_sync_func

fwd_bwd_function = get_forward_backward_func()

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(tuning=True, validation_step=forward_only),
data_iterator=self._make_data_iterator_list(data_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=forward_only,
seq_length=seq_length,
micro_batch_size=get_micro_batch_size(),
first_val_step=first_val_step,
)

non_loss_tensors = {}
# only the last stages of the pipeline return losses
if losses_reduced_per_micro_batch:
for item in losses_reduced_per_micro_batch:
for k, v in item.items():
if k != 'avg':
av = non_loss_tensors.get(k, [])
av.append(v)
non_loss_tensors[k] = av
if (not forward_only) or self.cfg.data.get('validation_drop_last', True):
# average loss across micro batches
loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
loss_mean = loss_tensor.mean()
else:
# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list = [
loss_sum['loss_sum_and_ub_size']
for loss_sum in losses_reduced_per_micro_batch
if loss_sum['loss_sum_and_ub_size'][1] > 0
]
loss_mean = (
torch.vstack(loss_sum_tensors_list).sum(axis=0)
if len(loss_sum_tensors_list) > 0
else torch.tensor([0.0, 0.0]).cuda()
)
else:
loss_mean = torch.tensor(0.0).cuda()
# we're not on the last pipeline stage so no losses
if forward_only:
loss_mean = []
else:
loss_mean = torch.tensor(0.0).cuda()
batch_losses.append(loss_mean.unsqueeze(0))

loss_mean = torch.cat(batch_losses).mean()

# if forward_only:
# return loss_mean
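
The per-modality loop above interleaves forward and backward so that only one modality's activations are resident on the GPU at a time, while gradients from both modalities still accumulate before the single optimizer step. A simplified, framework-free sketch of that pattern follows (the toy model, random tensors, and explicit optimizer step are assumptions; the real code goes through Megatron's fwd_bwd_function and microbatch splitting).

import torch

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Each modality arrives with its own batch size, as enabled by the zip sampler.
audio_batch = {"features": torch.randn(8, 16), "labels": torch.randint(0, 4, (8,))}
text_batch = {"features": torch.randn(32, 16), "labels": torch.randint(0, 4, (32,))}

optimizer.zero_grad()
batch_losses = []
for batch in (audio_batch, text_batch):
    if not batch:
        continue  # a modality may be absent from a given training step
    logits = model(batch["features"])  # forward for this modality only
    loss = torch.nn.functional.cross_entropy(logits, batch["labels"])
    loss.backward()  # backward immediately, freeing this modality's activations
    batch_losses.append(loss.detach().unsqueeze(0))

loss_mean = torch.cat(batch_losses).mean()  # mirrors the batch_losses averaging above
optimizer.step()  # gradients from both modalities have accumulated by now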
