switch to mcore dataset [with FIM support] #8149

Merged: 49 commits merged into main from dpykhtar/mcore_ds on Jan 25, 2024
Changes from 3 commits
Commits
49 commits
d6247fd
switch to mcore dataset
dimapihtar Jan 9, 2024
f18928e
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 10, 2024
12de8fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2024
a957966
remove commented lines
dimapihtar Jan 10, 2024
afebd65
fix rank issue
dimapihtar Jan 11, 2024
8a07519
fix rank issue
dimapihtar Jan 11, 2024
20241f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2024
6734fb3
remove unnecessary prints
dimapihtar Jan 11, 2024
31735e0
Merge branch 'dpykhtar/mcore_ds' of https://github.com/NVIDIA/NeMo in…
dimapihtar Jan 11, 2024
48f4b8c
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 11, 2024
e37dba3
fix typo
dimapihtar Jan 12, 2024
91789b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2024
d4d515e
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 15, 2024
e43dd09
add FIM support
dimapihtar Jan 15, 2024
4877011
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2024
2d9dc42
revert gpt config
dimapihtar Jan 16, 2024
e198bdb
change if statement
dimapihtar Jan 16, 2024
a3e3c9c
add starcoder config
dimapihtar Jan 16, 2024
6114e2a
add starcoder config
dimapihtar Jan 16, 2024
7253779
remove commented lines
dimapihtar Jan 16, 2024
9229474
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 16, 2024
1d00520
code changes
dimapihtar Jan 16, 2024
2ab25b7
change mcore commit
dimapihtar Jan 16, 2024
9635e31
add copyright header
dimapihtar Jan 16, 2024
98a26be
code changes
dimapihtar Jan 17, 2024
9cc81ef
remove if statement
dimapihtar Jan 17, 2024
1940ff9
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 17, 2024
b779990
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 17, 2024
ecf7bef
fix batch for fine-tuning
dimapihtar Jan 17, 2024
902f296
move is_dataset_built_on_rank function
dimapihtar Jan 18, 2024
0b292c2
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 18, 2024
4fe6b68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2024
5cd0c9f
remove commented lines
dimapihtar Jan 18, 2024
d235e47
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 19, 2024
9d9bd3f
config changes
dimapihtar Jan 22, 2024
9b02de9
revert gpt config
dimapihtar Jan 22, 2024
8d7014d
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 22, 2024
c3b26ce
revert falcon model
dimapihtar Jan 23, 2024
d41c6fc
fix falcon tests
dimapihtar Jan 23, 2024
54fa7a1
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 23, 2024
11d7d4d
fix tests
dimapihtar Jan 23, 2024
52b1960
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 23, 2024
771be12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2024
417f1d4
revert gpt config
dimapihtar Jan 24, 2024
6dfdf9b
Merge branch 'dpykhtar/mcore_ds' of https://github.com/NVIDIA/NeMo in…
dimapihtar Jan 24, 2024
cef6719
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 24, 2024
dd4ac81
comment out MockGPTDataset test
dimapihtar Jan 25, 2024
d1b09b7
Merge branch 'dpykhtar/mcore_ds' of https://github.com/NVIDIA/NeMo in…
dimapihtar Jan 25, 2024
f11c4fd
Merge branch 'main' into dpykhtar/mcore_ds
dimapihtar Jan 25, 2024
26 changes: 13 additions & 13 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -9,15 +9,15 @@ trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 16
precision: bf16
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
max_steps: 5000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 50
val_check_interval: 250
limit_val_batches: 25
limit_test_batches: 500
accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
gradient_clip_val: 1.0
@@ -28,17 +28,17 @@ exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_gpt
create_wandb_logger: False
create_wandb_logger: True
wandb_logger_kwargs:
project: null
name: null
project: mcore_ds_test
name: mcore_ds_new
resume_if_exists: True
resume_ignore_no_checkpoint: True
resume_from_checkpoint: ${model.resume_from_checkpoint}
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
save_top_k: 5
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
@@ -47,7 +47,7 @@ exp_manager:

model:
# use GPTModel from megatron.core
mcore_gpt: False
mcore_gpt: True

# specify micro_batch_size, global_batch_size, and model parallelism
# gradient accumulation will be done automatically based on data_parallel_size
@@ -210,7 +210,7 @@ model:
# Dictionary: can override from CLI "model.data.data_prefix"={"train":[1.0, /path/to/data], "validation":/path/to/data, "test":/path/to/test}
# Or see example below:
# "model.data.data_prefix: {train:[1.0,/path/to/data], validation:[/path/to/data], test:[/path/to/test]}"
data_prefix: ???
data_prefix: [1.0, /home/data/test_text_document]
dimapihtar marked this conversation as resolved.
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_impl: mmap
splits_string: 900,50,50
@@ -236,16 +236,16 @@ model:
gen_shape: False # Generate model and kernel details including input shapes

optim:
name: fused_adam
name: distributed_fused_adam
lr: 2e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 0
warmup_steps: 50
constant_steps: 500
min_lr: 2e-5

gc_interval: 0
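Editor's note: the consumed_samples comment attached to max_steps in the config above is simple arithmetic. A minimal sketch with assumed parallelism values (only accumulate_grad_batches = 1 and max_steps = 5000 come from the config; the rest are illustrative):

# Assumed illustrative values, not taken from the PR.
micro_batch_size = 4
data_parallel_size = 8
accumulate_grad_batches = 1  # fixed for Megatron models, per the config comment
global_step = 5000           # max_steps from the updated config

consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
print(consumed_samples)  # 160000 samples consumed by the end of training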
153 changes: 140 additions & 13 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -75,7 +75,9 @@
HAVE_APEX = False

try:
from megatron.core import InferenceParams, parallel_state
from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
@@ -111,6 +113,15 @@
return name_spec_dict[spec_name]


global is_dataset_built_on_rank


def is_dataset_built_on_rank():
return (
mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
) and mpu.get_tensor_model_parallel_rank() == 0


class MegatronGPTExportableModel(torch.nn.Module, Exportable):
"""
Megatron GPT Wrapper for ONNX export
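Editor's note on the new is_dataset_built_on_rank helper above: with pipeline parallelism, only the first stage (embeddings) and the last stage (loss) consume dataset samples, and within a tensor-parallel group only rank 0 needs to build them, because get_batch broadcasts the data to the other ranks. A minimal sketch of the same boolean gate, using hypothetical rank values instead of the real mpu calls:

def built_on_rank(pp_rank: int, pp_world_size: int, tp_rank: int) -> bool:
    # Mirrors (is_pipeline_first_stage() or is_pipeline_last_stage()) and tp_rank == 0.
    is_first = pp_rank == 0
    is_last = pp_rank == pp_world_size - 1
    return (is_first or is_last) and tp_rank == 0

# With a hypothetical 4-stage pipeline and 2-way tensor parallelism,
# only (pp_rank 0, tp_rank 0) and (pp_rank 3, tp_rank 0) build the dataset.
for pp in range(4):
    for tp in range(2):
        print(pp, tp, built_on_rank(pp, 4, tp))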
@@ -231,6 +242,10 @@
self.if_first_step = 0
self.prev_global_batch_size = None

self.reset_position_ids = cfg.data.get('reset_position_ids', False)
self.reset_attention_mask = cfg.data.get('reset_attention_mask', False)
self.eod_mask_loss = cfg.data.get('eod_mask_loss', False)

if not self.megatron_amp_O2 and self.cfg.get('virtual_pipeline_model_parallel_size', None):
raise ValueError('Virtual pipeline model parallel is only supported when using megatron_amp_O2')

@@ -837,6 +852,102 @@
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList(iters)

def get_ltor_masks_and_position_ids(
self, data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss
):
"""Build masks and position id for left to right model."""

# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()

# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length
)

# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0

# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()

if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):

# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()

# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1

# Convert attention mask to binary:
attention_mask = attention_mask < 0.5

return attention_mask, loss_mask, position_ids

def get_batch(self, data_iterator):
"""Generate a batch."""

# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None

# Items and their type.
keys = ['text']
datatype = torch.int64

# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)

# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

# Get the masks and position ids.
attention_mask, loss_mask, position_ids = self.get_ltor_masks_and_position_ids(
tokens, self.tokenizer.eos_id, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss
)

batch = {
'tokens': tokens,
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids,
}
# slice batch along sequence dimension for context parallelism
batch = self.get_batch_on_this_context_parallel_rank(batch)

return batch

def get_batch_on_this_context_parallel_rank(self, batch):
cp_size = self.cfg.get('context_parallel_size', 1)
num_valid_tokens_in_ub = None
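Editor's note: to make the reset flags in the new get_ltor_masks_and_position_ids concrete, here is a small self-contained illustration (not from the PR) that mirrors its position-id reset loop on a toy packed batch, assuming an EOD token id of 0:

import torch

# Two 8-token sequences packed with EOD separators (assumed EOD id: 0).
data = torch.tensor([
    [5, 6, 0, 7, 8, 9, 0, 4],
    [3, 3, 3, 0, 2, 2, 2, 2],
])
eod_token = 0
micro_batch_size, seq_length = data.size()

# Default position ids: 0..seq_length-1 for every sample.
position_ids = torch.arange(seq_length).unsqueeze(0).expand_as(data).clone()

# reset_position_ids=True restarts positions after every EOD token,
# exactly like the loop in the method above.
for b in range(micro_batch_size):
    prev_index = 0
    for i in (data[b] == eod_token).nonzero(as_tuple=True)[0].tolist():
        position_ids[b, i + 1:] -= i + 1 - prev_index
        prev_index = i + 1

print(position_ids)
# tensor([[0, 1, 2, 0, 1, 2, 3, 0],
#         [0, 1, 2, 3, 0, 1, 2, 3]])

# eod_mask_loss=True simply zeroes the loss at EOD positions:
loss_mask = torch.ones_like(data, dtype=torch.float)
loss_mask[data == eod_token] = 0.0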
@@ -867,7 +978,8 @@
def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None):

# Get data batch
batch = next(dataloader_iter)
# batch = next(dataloader_iter)
batch = self.get_batch(dataloader_iter)

# Transfer needed data to GPU
required_keys = set()
@@ -1094,18 +1206,33 @@
1
] = 1 # This is to make sure we only have one epoch on every validation iteration

self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
cfg=self.cfg,
trainer=self.trainer,
data_prefix=self.cfg.data.data_prefix,
data_impl=self.cfg.data.data_impl,
splits_string=self.cfg.data.splits_string,
train_valid_test_num_samples=train_valid_test_num_samples,
seq_length=self.cfg.data.seq_length,
seed=self.cfg.seed,
skip_warmup=self.cfg.data.get('skip_warmup', True),
tokenizer=self.tokenizer,
# self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets(
# cfg=self.cfg,
# trainer=self.trainer,
# data_prefix=self.cfg.data.data_prefix,
# data_impl=self.cfg.data.data_impl,
# splits_string=self.cfg.data.splits_string,
# train_valid_test_num_samples=train_valid_test_num_samples,
# seq_length=self.cfg.data.seq_length,
# seed=self.cfg.seed,
# skip_warmup=self.cfg.data.get('skip_warmup', True),
# tokenizer=self.tokenizer,
# )

dataset_config = GPTDatasetConfig(
is_built_on_rank=is_dataset_built_on_rank,
random_seed=self.cfg.seed,
sequence_length=self.cfg.data.seq_length,
blend=self.cfg.data.data_prefix,
blend_per_split=None,
split=self.cfg.data.splits_string,
path_to_cache=self.cfg.data.index_mapping_dir,
)

self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder(
GPTDataset, train_valid_test_num_samples, dataset_config,
).build()

if self._train_ds is not None:
logging.info(f'Length of train dataset: {len(self._train_ds)}')
if self._validation_ds is not None:
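Editor's note: the PR title and the "add FIM support" commit add fill-in-the-middle training on top of the mcore dataset path, but that code is not visible in this three-commit view. The sketch below is therefore only a generic illustration of the standard PSM-style FIM document transform, with hypothetical sentinel ids (real runs use tokenizer special tokens such as StarCoder's <fim_prefix>/<fim_suffix>/<fim_middle>); it is not the PR's implementation.

import numpy as np

# Hypothetical sentinel token ids, for the sketch only.
FIM_PREFIX, FIM_SUFFIX, FIM_MIDDLE = 50001, 50002, 50003

def psm_fim_transform(tokens: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Split a document at two random points and emit it in
    prefix-suffix-middle (PSM) order, the common FIM permutation."""
    lo, hi = sorted(rng.integers(0, len(tokens) + 1, size=2))
    prefix, middle, suffix = tokens[:lo], tokens[lo:hi], tokens[hi:]
    return np.concatenate([
        [FIM_PREFIX], prefix,
        [FIM_SUFFIX], suffix,
        [FIM_MIDDLE], middle,  # the model learns to generate the middle last
    ])

rng = np.random.default_rng(0)
doc = np.arange(100, 110)  # stand-in for a tokenized source file
print(psm_fim_transform(doc, rng))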