-
Notifications
You must be signed in to change notification settings - Fork 226
Prefix lm #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prefix lm #52
Changes from 21 commits
8b04e0e
661553f
43c9137
4056539
a5acbf5
10ff060
cdde433
06076c7
957d1c9
b9fcb7b
8004731
f415dc8
a8d172b
220637f
d2d5086
6d03d7a
f926720
5409341
7e335e1
f64977f
dca47cf
3f75537
07ca952
2dae74b
4a09bb3
7a0710e
ccae9db
2eaf6c7
2529380
8e44d61
113c636
7577931
dfb6a9b
f21a662
a41e478
825375c
217f54b
d078e54
63121a9
fda81a2
01fc083
83c4d95
c7c65bb
84eb016
3dadd16
04c79f3
ebfbfce
e46f326
a983cab
d562d7b
1095d7e
bab5cc4
8661ca2
293554a
e287bf0
c45109e
473127f
2845047
98113c6
598d7ee
2be1e51
9d350c9
baf2e2a
32da2e7
4c92ca8
82b69e8
7898c9a
e1318f0
4a35d50
90e0a0d
b9fbe3d
f69a002
c63eb38
bc5313b
60a5884
fa134ed
3b678fe
fe5f9b6
a3840d7
e454761
150b5a1
3f9efbb
9802ca6
fcb34b3
8e43c96
3d004ef
675ef50
e304b21
30fddd2
23397ad
28a712d
d49d6e5
bbfac96
0cdb0a9
a7c51aa
6db9b97
295e8d0
6a96fb9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,11 +67,14 @@ def post_language_model_processing(lm_output, labels, logit_weights, | |
| class GPTModel(MegatronModule): | ||
| """GPT-2 Language model.""" | ||
|
|
||
| def __init__(self, | ||
| num_tokentypes=0, | ||
| parallel_output=True, | ||
| pre_process=True, | ||
| post_process=True): | ||
| def __init__( | ||
| self, | ||
| num_tokentypes=0, | ||
| parallel_output=True, | ||
| pre_process=True, | ||
| post_process=True, | ||
| prefix_lm=False, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pass |
||
| ): | ||
| super(GPTModel, self).__init__() | ||
| args = get_args() | ||
|
|
||
|
|
@@ -83,7 +86,8 @@ def __init__(self, | |
| self.language_model, self._language_model_key = get_language_model( | ||
| num_tokentypes=num_tokentypes, | ||
| add_pooler=False, | ||
| encoder_attn_mask_type=AttnMaskType.causal, | ||
| # TODO: Change naming of class from GPT to something that encapsulate prefix lm. | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned previously, I think naming this class GPT is wrong. Can we obtain a better naming? Something like autoregressive maybe? Basically an abstraction of both gpt and prefix-lm |
||
| encoder_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal, | ||
| init_method=init_method_normal(args.init_method_std), | ||
| scaled_init_method=scaled_init_method_normal(args.init_method_std, | ||
| args.num_layers), | ||
|
|
@@ -157,9 +161,12 @@ def CrossEntropy(output, labels): | |
| class GPTModelPipe(PipelineModule,MegatronModule): | ||
| """GPT-2 Language model.""" | ||
|
|
||
| def __init__(self, | ||
| num_tokentypes=0, | ||
| parallel_output=True): | ||
| def __init__( | ||
| self, | ||
| num_tokentypes=0, | ||
| parallel_output=True, | ||
| prefix_lm=False | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same, use args. |
||
| ): | ||
| args = get_args() | ||
| self.parallel_output = parallel_output | ||
|
|
||
|
|
@@ -199,7 +206,8 @@ def _to_float16(inputs): | |
| output_layer_init_method=scaled_init_method_normal(args.init_method_std, | ||
| args.num_layers), | ||
| layer_number=layer_idx, | ||
| self_attn_mask_type=AttnMaskType.causal)) | ||
| # TODO: Change naming of class from GPT to something that encapsulate prefix lm. | ||
| self_attn_mask_type=AttnMaskType.prefix if prefix_lm else AttnMaskType.causal)) | ||
|
|
||
|
|
||
| # Undo data format change | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |||||||||||||
| """General utilities.""" | ||||||||||||||
|
|
||||||||||||||
| import sys | ||||||||||||||
| from random import randint | ||||||||||||||
|
|
||||||||||||||
| import torch | ||||||||||||||
| from torch.nn.parallel import DistributedDataParallel as torchDDP | ||||||||||||||
|
|
@@ -144,12 +145,23 @@ def check_adlr_autoresume_termination(iteration, model, | |||||||||||||
| sys.exit(0) | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| def get_ltor_masks_and_position_ids(data, | ||||||||||||||
| eod_token, | ||||||||||||||
| reset_position_ids, | ||||||||||||||
| reset_attention_mask, | ||||||||||||||
| eod_mask_loss): | ||||||||||||||
| """Build masks and position id for left to right model.""" | ||||||||||||||
| def get_ltor_masks_and_position_ids( | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't review |
||||||||||||||
| data, | ||||||||||||||
| eod_token, | ||||||||||||||
| reset_position_ids, | ||||||||||||||
| reset_attention_mask, | ||||||||||||||
| eod_mask_loss, | ||||||||||||||
| prefix_indices, | ||||||||||||||
| loss_on_targets_only, | ||||||||||||||
| ): | ||||||||||||||
| """ | ||||||||||||||
| Build masks and position id for left to right model. | ||||||||||||||
| :param prefix_indices: argument can have multiple types: | ||||||||||||||
| - None signifies that the model is fully autoregressive. | ||||||||||||||
| - List[int] the argument holds all prefix indices that split a row into an input and a target | ||||||||||||||
| - List[List[int]] the argument holds all prefix indices that split documents between input and target. | ||||||||||||||
| :param loss_on_targets_only: bool to determine if we should mask loss on prefix. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| # Extract batch size and sequence length. | ||||||||||||||
| micro_batch_size, seq_length = data.size() | ||||||||||||||
|
|
@@ -182,6 +194,14 @@ def get_ltor_masks_and_position_ids(data, | |||||||||||||
|
|
||||||||||||||
| # Find indecies where EOD token is. | ||||||||||||||
| eod_index = position_ids[b, data[b] == eod_token] | ||||||||||||||
|
|
||||||||||||||
| # If the last eod token is not the last token of the sequence, we suppose that there is a partial document | ||||||||||||||
| # We treat this case as if we add an eod token at the end of the sequence. | ||||||||||||||
| if data[b][-1] != eod_token: | ||||||||||||||
| eod_index = torch.cat( | ||||||||||||||
| (eod_index, torch.tensor([len(data[b])], dtype=eod_index.dtype, device=eod_index.device)) | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
| # Detach indecies from positions if going to modify positions. | ||||||||||||||
| if reset_position_ids: | ||||||||||||||
| eod_index = eod_index.clone() | ||||||||||||||
|
|
@@ -190,13 +210,31 @@ def get_ltor_masks_and_position_ids(data, | |||||||||||||
| prev_index = 0 | ||||||||||||||
| for j in range(eod_index.size()[0]): | ||||||||||||||
| i = eod_index[j] | ||||||||||||||
| # Mask attention loss. | ||||||||||||||
|
|
||||||||||||||
| if reset_attention_mask: | ||||||||||||||
| # Prevent cross document interactions. | ||||||||||||||
| attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 | ||||||||||||||
|
|
||||||||||||||
| # Prefix lm per document. | ||||||||||||||
| if prefix_indices: | ||||||||||||||
| assert isinstance(prefix_indices[b], list), f"prefix for a row has to be document specific, and consequently return a list, got {prefix_indices[b]}" | ||||||||||||||
| attention_mask[b, 0, prev_index: prefix_indices[b][j], prev_index: prefix_indices[b][j]] = 1 | ||||||||||||||
| if loss_on_targets_only: | ||||||||||||||
| loss_mask[b, prev_index: prefix_indices[b][j]] = 0.0 | ||||||||||||||
|
|
||||||||||||||
| # Reset positions. | ||||||||||||||
| if reset_position_ids: | ||||||||||||||
| position_ids[b, (i + 1):] -= (i + 1 - prev_index) | ||||||||||||||
| prev_index = i + 1 | ||||||||||||||
|
|
||||||||||||||
| prev_index = i + 1 | ||||||||||||||
|
|
||||||||||||||
|
thomasw21 marked this conversation as resolved.
|
||||||||||||||
| # Prefix lm per row. | ||||||||||||||
| if prefix_indices is not None and (reset_attention_mask is False): | ||||||||||||||
| assert isinstance(prefix_indices[b], int), \ | ||||||||||||||
| f"prefix for a row has to be row specific, and consequently return an int, got {prefix_indices[b]}" | ||||||||||||||
| attention_mask[b, 0, :prefix_indices[b], :prefix_indices[b]] = 1 | ||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just to make sure I understand, do you have one prefix index per batch or one prefix index per instance?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So there's two case:
My code might be confusing because of
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the Megatron-DeepSpeed/megatron/utils.py Line 173 in 28a712d
you see that mask shape is [1, 1, seqlen, seqlen].So here Megatron-DeepSpeed/megatron/utils.py Line 235 in 28a712d
This line shouldn't work because b can be greater than 1.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah nice catch ! So I'm pretty sure we want to remove the logic here Megatron-DeepSpeed/megatron/utils.py Lines 170 to 173 in 28a712d
Essentially I'd say this is an optimisation as the masking is batch size independent in gpt, but would be batch size dependent in prefix lm? We can also work with batch size independent also I guess, sampling a single prefix id for the whole batch. WDYT? #0cdb0a941ddeefbdb1ccab3598f1e34bb38c35a3
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am fine with a single prefix index per batch.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hum this made the prefix row dependent, thomasw21@0cdb0a9 . I can revert to handle single index for the whole batch as you suggested. Do you have insights as to why one would work better than the other?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. being row dependent is better assuming it doesn't increase time and memory that much
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay let's keep this as is, and monitor if we're much slower compared to gpt. Bear in mind:
Maybe we can test on a 350m version and then switch if we think have a single index might be faster? |
||||||||||||||
| if loss_on_targets_only: | ||||||||||||||
| loss_mask[b, :prefix_indices[b]] = 0.0 | ||||||||||||||
|
thomasw21 marked this conversation as resolved.
|
||||||||||||||
|
|
||||||||||||||
| # Convert attention mask to binary: | ||||||||||||||
| attention_mask = (attention_mask < 0.5) | ||||||||||||||
|
|
@@ -226,3 +264,76 @@ def flops_calculator(model, args, iteration_time): | |||||||||||||
| effective_tera_flops_per_gpu = giga_flops_per_model_per_train_step / (iteration_time * 1000.0 * gpus_per_model) | ||||||||||||||
|
|
||||||||||||||
| print_rank_0(f"Effective Tera Flops per GPU: {round(effective_tera_flops_per_gpu, 2)} and total parameters {round(approx_parameters_in_billions, 3)} B") | ||||||||||||||
|
|
||||||||||||||
| def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_mask): | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. recommend adding
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not do that just yet. Reason is that I don't know yet how we'll work with prompts. Otherwise yes I agree with need a parser to find those partial indices (I assumed we would get them in preprocessing or whatever), but if not, we might need to compute them here. |
||||||||||||||
| """ | ||||||||||||||
| Helper function in order to: | ||||||||||||||
| - randomly choose prefix index when there's no constraint | ||||||||||||||
| - check that prefix are compatible with convention. | ||||||||||||||
|
|
||||||||||||||
| :param data: torch.Tensor | ||||||||||||||
| :param eod_token: int, token_id used to signal end of document | ||||||||||||||
| :param partial_prefix_indices: this agument can have multiple types: | ||||||||||||||
| - None, it signals that all prefix indices are randomly sampled. | ||||||||||||||
| - List[Optional[int]], its length has to be equal to mini batch size. It stores all the indices for per row prefix. | ||||||||||||||
| Optional means that if set to None, we allows ourselves to sample one randomly. | ||||||||||||||
| - List[List[Optional[int]]], it follows the following rules: | ||||||||||||||
| - The first dimension refers to that sample, ie len(partial_prefix_indices) == len(data) | ||||||||||||||
| - The second dimension refers to the number of document of that sample, ie | ||||||||||||||
| len(partial_prefix_indices[b]) == (data[b] == eod_token).sum() (+1 for the last partial document). | ||||||||||||||
| - partial_prefix_indices have to be interleaved with eod_indices, ie | ||||||||||||||
| eod_indices[b][d-1] < partial_prefix_indices[b][d] < eod_indices[b][d] + 1 or is None. | ||||||||||||||
| - Optional means that if set to None, we allows ourselves to sample one randomly. | ||||||||||||||
| :param reset_attention_mask: bool, determines if prefixes are to be per document or per row. | ||||||||||||||
| :return Depending if prefix is per document or per row, the method returns: | ||||||||||||||
| - List[List[int]]: prefix indices for each document in case of per document prefix | ||||||||||||||
| - List[int]: prefix indices for rows else. | ||||||||||||||
| """ | ||||||||||||||
| micro_batch_size, seq_length = data.size() | ||||||||||||||
| prefix_indices = [] | ||||||||||||||
|
|
||||||||||||||
|
thomasw21 marked this conversation as resolved.
|
||||||||||||||
| assert partial_prefix_indices is None or len(partial_prefix_indices) == micro_batch_size, f"partial_prefix_indices has to be None or its length equal to {micro_batch_size}, got {len(partial_prefix_indices)}" | ||||||||||||||
| for batch_id in range(micro_batch_size): | ||||||||||||||
| prefix_indices.append([]) | ||||||||||||||
| # Compute the index of all eod tokens in data. | ||||||||||||||
| eod_indices = (data[batch_id] == eod_token).nonzero().squeeze(-1) | ||||||||||||||
|
|
||||||||||||||
| # If the last eod token is not the last token of the sequence, we suppose that there is a partial document | ||||||||||||||
| # We treat this case as if we add an eod token at the end of the sequence. | ||||||||||||||
| if data[batch_id][-1] != eod_token: | ||||||||||||||
| eod_indices = torch.cat( | ||||||||||||||
| (eod_indices, torch.tensor([len(data[batch_id])], dtype = eod_indices.dtype, device = eod_indices.device)) | ||||||||||||||
| ) | ||||||||||||||
|
|
||||||||||||||
|
thomasw21 marked this conversation as resolved.
Outdated
|
||||||||||||||
| # Prefix lm per document. | ||||||||||||||
| if reset_attention_mask: | ||||||||||||||
| prev_index = 0 | ||||||||||||||
| assert partial_prefix_indices is None or len(partial_prefix_indices[batch_id]) == len(eod_indices), f"The number of prefixes has to match the number of documents, complete or partial. Got {len(partial_prefix_indices[batch_id])} prefixes and {len(eod_indices)} documents" | ||||||||||||||
|
|
||||||||||||||
| for doc_id, eod_index in enumerate(eod_indices): | ||||||||||||||
| assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], list), f"Per document prefix has to store a list on indices for each row, got {partial_prefix_indices[batch_id]}" | ||||||||||||||
| if partial_prefix_indices is None or partial_prefix_indices[batch_id][doc_id] is None: | ||||||||||||||
| # We need to randomly generate a prefix index that satisfies the interleave condition in the docstring | ||||||||||||||
| prefix_index = randint(prev_index, eod_index) | ||||||||||||||
| else: | ||||||||||||||
| # We get value from partial_prefix_indices, and run validation on that value | ||||||||||||||
| prefix_index = partial_prefix_indices[batch_id][doc_id] | ||||||||||||||
| assert prev_index <= prefix_index < eod_index, f"Prefix index needs to be between documents indices, {prev_index} <= {prefix_index} < {eod_index} should be True." | ||||||||||||||
|
|
||||||||||||||
| prefix_indices[batch_id].append(prefix_index) | ||||||||||||||
| prev_index = eod_index + 1 | ||||||||||||||
|
|
||||||||||||||
| # Prefix lm per row. | ||||||||||||||
| else: | ||||||||||||||
| assert partial_prefix_indices is None or isinstance(partial_prefix_indices[batch_id], int), \ | ||||||||||||||
| f"Per document prefix has to store an int for each row, got {partial_prefix_indices[batch_id]}" | ||||||||||||||
|
|
||||||||||||||
| if partial_prefix_indices is None or partial_prefix_indices[batch_id] is None: | ||||||||||||||
| # We need to randomly generate a prefix index | ||||||||||||||
| prefix_index = randint(0, seq_length - 1) | ||||||||||||||
| else: | ||||||||||||||
| # We get value from partial_prefix_indices, and run validation on that value | ||||||||||||||
| prefix_index = partial_prefix_indices[batch_id] | ||||||||||||||
| assert 0 <= prefix_index < seq_length - 1, f"Prefix index needs to be between documents indices, 0 <= {prefix_index} < {seq_length - 1} should be True." | ||||||||||||||
| prefix_indices[batch_id].append(prefix_index) | ||||||||||||||
| return prefix_indices | ||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reset-attention-mask parameter, I recommend adding a note that says
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can improve while being less verbose. Something along the lines that "token interactions are document constrained in self attention mask"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay my opinion after thinking about it:
--reset-attention-maskto something like--document-specific-attentionand add a comment that documents are independent from one another (using eod token as a document segmenter)Documents, segmented using eod tokens, are considered independent. Attention between tokens from different documents have a null score.The reason why I like these versions better is to force orthogonalisation of features, ie those flags (prefix and reset attention) have to be independent from one another. Otherwise if you end up adding something else on the attention mechanism whenever have to add a note in the help
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add an argument here to enable/disable prefix-lm