4 changes: 3 additions & 1 deletion megatron/arguments.py
@@ -689,9 +689,11 @@ def _add_data_args(parser):
                        help='Reset posistion ids after end-of-document token.')
     group.add_argument('--reset-attention-mask', action='store_true',
                        help='Reset self attention maske after '
-                       'end-of-document token.')
+                       'end-of-document token. Attention between tokens from different documents is null.')
     group.add_argument('--eod-mask-loss', action='store_true',
                        help='Mask loss for the end of document tokens.')
+    group.add_argument('--loss-on-targets-only', action='store_true',

Contributor

For the --reset-attention-mask parameter, I recommend adding a note along these lines:


    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention maske after '
                       'end-of-document token. If running in --prefix-lm, then this will set per-document prefixing.')

Member Author

I think we can improve this while being less verbose. Something along the lines of "token interactions are document-constrained in the self-attention mask".

Member Author

Okay, my opinion after thinking about it:

  • either we rename --reset-attention-mask to something like --document-specific-attention and add a comment that documents are independent from one another (using the eod token as a document separator),
  • or we change the help text to say something like "Documents, segmented using eod tokens, are considered independent. Attention between tokens from different documents has a null score."

The reason I like these versions better is that they force orthogonalisation of features, i.e. those flags (prefix and reset attention) have to be independent from one another. Otherwise, whenever you add something else to the attention mechanism, you have to add another note in the help text.
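
To make the proposed help text concrete, here is a minimal sketch of what "documents are considered independent" means for a causal mask. It uses plain Python lists instead of tensors, and the helper name document_causal_mask is hypothetical (it is not part of this PR); the eod token is assumed to belong to the document it terminates, which matches the attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 line in megatron/utils.py.

```python
def document_causal_mask(tokens, eod_token):
    """Return mask[i][j] == 1 if position i may attend to position j.

    Hypothetical helper for illustration only: causal attention restricted
    to the current document, with the eod token closing each document.
    """
    seq_length = len(tokens)
    # Assign a document id to every position; the eod token belongs to the
    # document it terminates.
    doc_ids = []
    current_doc = 0
    for token in tokens:
        doc_ids.append(current_doc)
        if token == eod_token:
            current_doc += 1
    return [
        [1 if j <= i and doc_ids[i] == doc_ids[j] else 0 for j in range(seq_length)]
        for i in range(seq_length)
    ]


# Two documents in one row: [5, 6, EOD] and [7, 8], with EOD = 0.
mask = document_causal_mask([5, 6, 0, 7, 8], eod_token=0)
for row in mask:
    print(row)
```

Rows 3 and 4 (the second document) have zeros in columns 0 to 2, so no token attends across the document boundary, independently of any prefix setting.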

Member

Add an argument here to enable/disable prefix-lm
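
One way this could look, as a sketch only: the flag name --prefix-lm is an assumption (the review asks for "an argument" without naming it), and select_attn_mask_type is a hypothetical helper showing how the model code could read the choice from args rather than from a constructor parameter.

```python
import argparse
import enum


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    prefix = 3


def build_parser():
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="data")
    # Assumed flag name; the review only asks for an enable/disable argument.
    group.add_argument("--prefix-lm", action="store_true",
                       help="Use a prefix LM attention mask instead of a causal one.")
    group.add_argument("--loss-on-targets-only", action="store_true",
                       help="Mask loss on input sequence.")
    return parser


def select_attn_mask_type(args):
    # Hypothetical helper: derive the mask type from parsed arguments.
    return AttnMaskType.prefix if args.prefix_lm else AttnMaskType.causal


args = build_parser().parse_args(["--prefix-lm"])
default_args = build_parser().parse_args([])
print(select_attn_mask_type(args), select_attn_mask_type(default_args))
```

With a store_true flag the default stays causal, so existing GPT launch scripts would keep their behaviour unchanged.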

+                       help='Mask loss on input sequence.')

     return parser

1 change: 1 addition & 0 deletions megatron/enums.py
@@ -26,6 +26,7 @@ class AttnType(enum.Enum):
 class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
+    prefix = 3

 class PositionEmbeddingType(enum.Enum):
     rotary = 1
4 changes: 3 additions & 1 deletion megatron/model/fused_softmax.py
@@ -92,6 +92,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         scale: scaling factor used in input tensor scaling.

     """
+    custom_kernel_friendly_attn_mask_type = [AttnMaskType.causal, AttnMaskType.padding]

     def __init__(
         self,
@@ -134,7 +135,8 @@ def forward(self, input, mask):

         # invoke custom kernel
         if self.input_in_float16 and mask is not None and \
-            custom_kernel_constraint and self.scaled_masked_softmax_fusion:
+            custom_kernel_constraint and self.scaled_masked_softmax_fusion and \
+            self.attn_mask_type in self.custom_kernel_friendly_attn_mask_type:
             scale = self.scale if self.scale is not None else 1.0

             if self.attn_mask_type == AttnMaskType.causal:
28 changes: 18 additions & 10 deletions megatron/model/gpt_model.py
@@ -67,11 +67,14 @@ def post_language_model_processing(lm_output, labels, logit_weights,
 class GPTModel(MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self,
-                 num_tokentypes=0,
-                 parallel_output=True,
-                 pre_process=True,
-                 post_process=True):
+    def __init__(
+        self,
+        num_tokentypes=0,
+        parallel_output=True,
+        pre_process=True,
+        post_process=True,
+        prefix_lm=False,

Member

pass prefix_lm using args instead

+    ):
         super(GPTModel, self).__init__()
         args = get_args()

@@ -83,7 +86,8 @@ def __init__(self,
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            encoder_attn_mask_type=AttnMaskType.causal,
+            # TODO: Change naming of class from GPT to something that encapsulate prefix lm.

Member Author

As mentioned previously, I think naming this class GPT is wrong. Can we find a better name? Something like autoregressive, maybe? Basically, an abstraction over both GPT and prefix LM.

+            encoder_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers),
@@ -157,9 +161,12 @@ def CrossEntropy(output, labels):
 class GPTModelPipe(PipelineModule,MegatronModule):
     """GPT-2 Language model."""

-    def __init__(self,
-                 num_tokentypes=0,
-                 parallel_output=True):
+    def __init__(
+        self,
+        num_tokentypes=0,
+        parallel_output=True,
+        prefix_lm=False

Member

same, use args.

+    ):
         args = get_args()
         self.parallel_output = parallel_output

@@ -199,7 +206,8 @@ def _to_float16(inputs):
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std,
                                                                        args.num_layers),
                     layer_number=layer_idx,
-                    self_attn_mask_type=AttnMaskType.causal))
+                    # TODO: Change naming of class from GPT to something that encapsulate prefix lm.
+                    self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal))


         # Undo data format change
7 changes: 5 additions & 2 deletions megatron/text_generation_utils.py
@@ -41,13 +41,16 @@ def get_batch(context_tokens):

     # Move to GPU.
     tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
-    # Get the attention mask and postition ids.
+    # Get the attention mask and position ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss)
+        args.eod_mask_loss,
+        prefix_indices=None,
+        loss_on_targets_only=args.loss_on_targets_only
+    )

     return tokens, attention_mask, position_ids

127 changes: 119 additions & 8 deletions megatron/utils.py
@@ -16,6 +16,7 @@
 """General utilities."""

 import sys
+from random import randint

 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP
@@ -144,12 +145,23 @@ def check_adlr_autoresume_termination(iteration, model,
         sys.exit(0)


-def get_ltor_masks_and_position_ids(data,
-                                    eod_token,
-                                    reset_position_ids,
-                                    reset_attention_mask,
-                                    eod_mask_loss):
-    """Build masks and position id for left to right model."""
+def get_ltor_masks_and_position_ids(

Member

I didn't review get_ltor_masks_and_position_ids and get_prefix_indices yet.

+    data,
+    eod_token,
+    reset_position_ids,
+    reset_attention_mask,
+    eod_mask_loss,
+    prefix_indices,
+    loss_on_targets_only,
+):
+    """
+    Build masks and position id for left to right model.
+    :param prefix_indices: argument can have multiple types:
+        - None signifies that the model is fully autoregressive.
+        - List[int] the argument holds all prefix indices that split a row into an input and a target
+        - List[List[int]] the argument holds all prefix indices that split documents between input and target.
+    :param loss_on_targets_only: bool to determine if we should mask loss on prefix.
+    """

     # Extract batch size and sequence length.
     micro_batch_size, seq_length = data.size()
@@ -182,6 +194,14 @@ def get_ltor_masks_and_position_ids(data,

             # Find indecies where EOD token is.
             eod_index = position_ids[b, data[b] == eod_token]
+
+            # If the last eod token is not the last token of the sequence, we suppose that there is a partial document
+            # We treat this case as if we add an eod token at the end of the sequence.
+            if data[b][-1] != eod_token:
+                eod_index = torch.cat(
+                    (eod_index, torch.tensor([len(data[b])], dtype=eod_index.dtype, device=eod_index.device))
+                )
+
             # Detach indecies from positions if going to modify positions.
             if reset_position_ids:
                 eod_index = eod_index.clone()
@@ -190,13 +210,31 @@ def get_ltor_masks_and_position_ids(data,
                 prev_index = 0
                 for j in range(eod_index.size()[0]):
                     i = eod_index[j]
-                    # Mask attention loss.
+
                     if reset_attention_mask:
+                        # Prevent cross document interactions.
                         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+
+                        # Prefix lm per document.
+                        if prefix_indices:
+                            assert isinstance(prefix_indices[b], list), f"prefix for a row has to be document specific, and consequently return a list, got {prefix_indices[b]}"
+                            attention_mask[b, 0, prev_index: prefix_indices[b][j], prev_index: prefix_indices[b][j]] = 1
+                            if loss_on_targets_only:
+                                loss_mask[b, prev_index: prefix_indices[b][j]] = 0.0
+
                     # Reset positions.
                     if reset_position_ids:
                         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-                        prev_index = i + 1
+
+                    prev_index = i + 1
+
+            # Prefix lm per row.
+            if prefix_indices is not None and (reset_attention_mask is False):
+                assert isinstance(prefix_indices[b], int), \
+                    f"prefix for a row has to be row specific, and consequently return an int, got {prefix_indices[b]}"
+                attention_mask[b, 0, :prefix_indices[b], :prefix_indices[b]] = 1

Member

just to make sure I understand, do you have one prefix index per batch or one prefix index per instance?

Member Author

So there are two cases:

  • if reset_attention_mask is True: I make use of eod, so I might end up with multiple prefixes in a row
  • if reset_attention_mask is False: I treat the row as a single document, so you end up with as many prefixes as you have rows in your batch.

My code might be confusing because of for batch_id in micro_batch_size, which should probably read for row in micro_batch_size.
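
The two cases can be sketched on a single row of length 5, using plain Python lists in place of tensors. The helper names causal_mask and apply_prefix are hypothetical, and the prefix indices are picked by hand rather than sampled; this illustrates the masking rule only, not the PR's implementation.

```python
def causal_mask(seq_length):
    # mask[i][j] == 1 means position i may attend to position j.
    return [[1 if j <= i else 0 for j in range(seq_length)] for i in range(seq_length)]


def apply_prefix(mask, start, prefix_index):
    """Let positions start..prefix_index-1 attend to each other in both directions."""
    for i in range(start, prefix_index):
        for j in range(start, prefix_index):
            mask[i][j] = 1


# reset_attention_mask is False: the row is one document with one prefix index (3).
row_mask = causal_mask(5)
apply_prefix(row_mask, 0, 3)

# reset_attention_mask is True: eod at index 2 splits the row into
# document 0 (positions 0..2) and document 1 (positions 3..4).
doc_mask = causal_mask(5)
for i in range(3, 5):
    for j in range(0, 3):
        doc_mask[i][j] = 0  # prevent cross-document attention
# One (document start, prefix index) pair per document.
for start, prefix_index in [(0, 2), (3, 4)]:
    apply_prefix(doc_mask, start, prefix_index)

for row in row_mask:
    print(row)
print()
for row in doc_mask:
    print(row)
```

In both cases target positions remain causal; only the prefix block becomes bidirectional, and the prefix never sees the target.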

@ibeltagy ibeltagy Sep 15, 2021

Member

In the reset_attention_mask is False case, the way I read it, there is one attention mask per batch (not per row). Check here:

att_mask_batch = 1

You can see that the mask shape is [1, 1, seqlen, seqlen].
So here:
attention_mask[b, 0, :prefix_indices[b], :prefix_indices[b]] = 1

This line shouldn't work, because b can be 1 or greater while the first dimension of the mask has size 1.
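
A small sketch of the problem, under stated assumptions: nested Python lists stand in for the tensor, the head dimension is dropped for brevity, and the helper names allocate_attention_mask and can_write_prefix are hypothetical.

```python
def allocate_attention_mask(micro_batch_size, seq_length, reset_attention_mask):
    # Mirrors the quoted logic: one shared mask unless masks are built per row.
    att_mask_batch = micro_batch_size if reset_attention_mask else 1
    return [
        [[1 if j <= i else 0 for j in range(seq_length)] for i in range(seq_length)]
        for _ in range(att_mask_batch)
    ]


def can_write_prefix(mask, b, prefix_index):
    """Try to set the prefix block of row b; report whether row b exists."""
    try:
        for i in range(prefix_index):
            for j in range(prefix_index):
                mask[b][i][j] = 1
    except IndexError:
        return False
    return True


shared = allocate_attention_mask(micro_batch_size=4, seq_length=3, reset_attention_mask=False)
per_row = allocate_attention_mask(micro_batch_size=4, seq_length=3, reset_attention_mask=True)

print(len(shared), can_write_prefix(shared, b=1, prefix_index=2))
print(len(per_row), can_write_prefix(per_row, b=1, prefix_index=2))
```

With a shared mask only b = 0 is addressable, so a row-dependent prefix needs one mask per row, which is what the follow-up in this thread discusses.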

@thomasw21 thomasw21 Sep 16, 2021

Member Author

Ah, nice catch! So I'm pretty sure we want to remove the logic here:

if reset_attention_mask:
    att_mask_batch = micro_batch_size
else:
    att_mask_batch = 1

Essentially, I'd say this is an optimisation: the masking is batch-size independent in GPT, but would be batch-size dependent in prefix LM. We could also keep it batch-size independent, I guess, by sampling a single prefix index for the whole batch. WDYT? #0cdb0a941ddeefbdb1ccab3598f1e34bb38c35a3

Member

I am fine with a single prefix index per batch.

@thomasw21 thomasw21 Sep 16, 2021

Member Author

Hmm, this commit made the prefix row-dependent: thomasw21@0cdb0a9. I can revert to a single index for the whole batch, as you suggested.

Do you have insights as to why one would work better than the other?

Member

Being row-dependent is better, assuming it doesn't increase time and memory that much.

@thomasw21 thomasw21 Sep 16, 2021

Member Author

Okay, let's keep this as is, and monitor whether we're much slower compared to GPT. Bear in mind:

  • prefix lm is not using a custom CUDA kernel anymore, so some loss in time/memory is expected

Maybe we can test on a 350M version and then switch if we think having a single index might be faster?

+                if loss_on_targets_only:
+                    loss_mask[b, :prefix_indices[b]] = 0.0

     # Convert attention mask to binary:
     attention_mask = (attention_mask < 0.5)
@@ -226,3 +264,76 @@ def flops_calculator(model, args, iteration_time):
     effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (iteration_time * 1000.0 * gpus_per_model)

     print_rank_0(f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B")
+
+def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_mask):

@huu4ontocord huu4ontocord Aug 5, 2021

Contributor

I recommend adding prefix_separator_token = None as a parameter. This could also be a command-line argument: you would need to either get the token id as an int from the command line, or look up the token id in the vocab, for "|||" for example.

Member Author

Let's not do that just yet. The reason is that I don't yet know how we'll work with prompts. Otherwise yes, I agree we need a parser to find those partial indices (I assumed we would get them from preprocessing or similar); if not, we might need to compute them here.
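
For reference, a sketch of what such a parser could look like if it were ever needed. Everything here is an assumption: the function name, the separator_token_id parameter, and the convention that the separator itself counts as part of the prefix. Rows without a separator yield None, which the get_prefix_indices docstring describes as "sample one randomly".

```python
from typing import List, Optional


def partial_prefix_indices_from_separator(
    rows: List[List[int]], separator_token_id: int
) -> List[Optional[int]]:
    """Hypothetical parser: one optional prefix index per row.

    The prefix index is the position just after the first separator token,
    so the separator is treated as part of the prefix (an assumption).
    """
    indices: List[Optional[int]] = []
    for row in rows:
        if separator_token_id in row:
            indices.append(row.index(separator_token_id) + 1)
        else:
            # No separator found: leave the choice to random sampling.
            indices.append(None)
    return indices


rows = [[11, 12, 99, 13, 14], [21, 22, 23, 24, 25]]
result = partial_prefix_indices_from_separator(rows, separator_token_id=99)
print(result)
```

This covers only the per-row case; a per-document variant would need to repeat the search between consecutive eod tokens.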

+    """
+    Helper function in order to:
+     - randomly choose prefix index when there's no constraint
+     - check that prefix are compatible with convention.
+
+    :param data: torch.Tensor
+    :param eod_token: int, token_id used to signal end of document
+    :param partial_prefix_indices: this agument can have multiple types:
+     - None, it signals that all prefix indices are randomly sampled.
+     - List[Optional[int]], its length has to be equal to mini batch size. It stores all the indices for per row prefix.
+        Optional means that if set to None, we allows ourselves to sample one randomly.
+     - List[List[Optional[int]]], it follows the following rules:
+        - The first dimension refers to that sample, ie len(partial_prefix_indices) == len(data)
+        - The second dimension refers to the number of document of that sample, ie
+            len(partial_prefix_indices[b]) == (data[b] == eod_token).sum() (+1 for the last partial document).
+        - partial_prefix_indices have to be interleaved with eod_indices, ie
+            eod_indices[b][d-1] < partial_prefix_indices[b][d] < eod_indices[b][d] + 1 or is None.
+        - Optional means that if set to None, we allows ourselves to sample one randomly.
+    :param reset_attention_mask: bool, determines if prefixes are to be per document or per row.
+    :return Depending if prefix is per document or per row, the method returns:
+     - List[List[int]]: prefix indices for each document in case of per document prefix
+     - List[int]: prefix indices for rows else.
+    """
+    micro_batch_size, seq_length = data.size()
+    prefix_indices = []
+
+    assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}"
+    for batch_id in range(micro_batch_size):
+        prefix_indices.append([])
+        # Compute the index of all eod tokens in data.
+        eod_indices = (data[batch_id] == eod_token).nonzero().squeeze(-1)
+
+        # If the last eod token is not the last token of the sequence, we suppose that there is a partial document
+        # We treat this case as if we add an eod token at the end of the sequence.
+        if data[batch_id][-1] != eod_token:
+            eod_indices = torch.cat(
+                (eod_indices, torch.tensor([len(data[batch_id])], dtype = eod_indices.dtype, device = eod_indices.device))
+            )
+
+        # Prefix lm per document.
+        if reset_attention_mask:
+            prev_index = 0
+            assert partial_prefix_indices is None or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents"
+
+            for doc_id, eod_index in enumerate(eod_indices):
+                assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list on indices for each row, got {partial_prefix_indices[batch_id]}"
+                if partial_prefix_indices is None or partial_prefix_indices[batch_id][doc_id] is None:
+                    # We need to randomly generate a prefix index that satisfies the interleave condition in the docstring
+                    prefix_index = randint(prev_index, eod_index)
+                else:
+                    # We get value from partial_prefix_indices, and run validation on that value
+                    prefix_index = partial_prefix_indices[batch_id][doc_id]
+                    assert prev_index <= prefix_index < eod_index, f"Prefix index needs to be between documents indices, {prev_index} <= {prefix_index} < {eod_index} should be True."
+
+                prefix_indices[batch_id].append(prefix_index)
+                prev_index = eod_index + 1
+
+        # Prefix lm per row.
+        else:
+            assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], int), \
+                f"Per document prefix has to store an int for each row, got {partial_prefix_indices[batch_id]}"
+
+            if partial_prefix_indices is None or partial_prefix_indices[batch_id] is None:
+                # We need to randomly generate a prefix index
+                prefix_index = randint(0, seq_length - 1)
+            else:
+                # We get value from partial_prefix_indices, and run validation on that value
+                prefix_index = partial_prefix_indices[batch_id]
+                assert 0 <= prefix_index < seq_length - 1, f"Prefix index needs to be between documents indices, 0 <= {prefix_index} < {seq_length - 1} should be True."
+            prefix_indices[batch_id].append(prefix_index)
+    return prefix_indices
17 changes: 12 additions & 5 deletions pretrain_gpt.py
@@ -25,7 +25,7 @@
 from megatron.data.gpt_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel, GPTModelPipe
 from megatron.training import pretrain
-from megatron.utils import get_ltor_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
 from megatron.utils import average_losses_across_data_parallel_group

 import deepspeed
@@ -40,6 +40,7 @@ def model_provider(pre_process=True, post_process=True):
     see_memory_usage(f"Before Building Model", force=True)

     args = get_args()
+
     with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                              remote_device=None if args.remote_device=='none' else args.remote_device,
                              config=args.deepspeed_config,
@@ -53,7 +54,7 @@ def model_provider(pre_process=True, post_process=True):
             # We need to call model.set_batch_fn after deepspeed.initialize
             model._megatron_batch_fn = get_batch_pipe

-            # Predompute the attention mask and store it in args. This avoids having to
+            # Precompute the attention mask and store it in args. This avoids having to
             # pipeline it as an activation during training. The mask is constant, and thus
             # we can reuse it.
             attention_mask = torch.tril(torch.ones(
@@ -108,7 +109,10 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss)
+        args.eod_mask_loss,
+        prefix_indices=None,
+        loss_on_targets_only=args.loss_on_targets_only
+    )

     return tokens, labels, loss_mask, attention_mask, position_ids

@@ -129,13 +133,16 @@ def get_batch_pipe(data):
     labels = tokens_[:, 1:].contiguous()
     tokens = tokens_[:, :-1].contiguous()

-    # Get the masks and postition ids.
+    # Get the masks and position ids.
     attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss)
+        args.eod_mask_loss,
+        prefix_indices=None,
+        loss_on_targets_only=args.loss_on_targets_only
+    )

     return (tokens, position_ids, attention_mask), (labels, loss_mask)
