t5-rpe-fix targeting r1.10.0; raise exception for PP>2. #4469
@@ -287,6 +287,7 @@ def __init__(
         megatron_legacy=False,
         bias=True,
         headscale=False,
+        has_relative_attention_bias=False,
     ):
         super(ParallelAttention, self).__init__()

@@ -299,6 +300,7 @@ def __init__(
         self.attn_mask_type = attn_mask_type
         self.megatron_legacy = megatron_legacy
         self.headscale = headscale
+        self.has_relative_attention_bias = has_relative_attention_bias

         if kv_channels is None:
             assert (

@@ -382,7 +384,7 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.relative_attention_num_buckets = relative_attention_num_buckets
         self.relative_attention_max_distance = relative_attention_max_distance
-        if self.position_embedding_type == 'relative':
+        if self.position_embedding_type == 'relative' and self.has_relative_attention_bias:
             self.relative_attention_bias = torch.nn.Embedding(
                 relative_attention_num_buckets, self.num_attention_heads_per_partition
             ).to(torch.cuda.current_device())

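For orientation: `relative_attention_bias` is an embedding table holding one learned scalar per (bucket, head) pair, so the bias depends only on the bucketed query-key offset, not on token content. A minimal sketch of the shapes involved, using made-up bucket and head counts (not NeMo's defaults):

import torch

# Hypothetical sizes, for illustration only.
num_buckets, num_heads = 32, 12
relative_attention_bias = torch.nn.Embedding(num_buckets, num_heads)

# A (query_length, key_length) grid of bucket indices maps to one learned
# scalar per head: (query_length, key_length, num_heads).
buckets = torch.randint(0, num_buckets, (8, 8))
bias = relative_attention_bias(buckets)       # (8, 8, 12)
bias = bias.permute(2, 0, 1).unsqueeze(0)     # (1, 12, 8, 8)

The final permute/unsqueeze mirrors how such a bias is broadcast against (batch, heads, query, key) attention scores.
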
@@ -498,7 +500,7 @@ def compute_bias(self, query_length, key_length):
         relative_position = memory_position - context_position  # shape (query_length, key_length)
         relative_position_bucket = self._relative_position_bucket(
             relative_position,  # shape (query_length, key_length)
-            bidirectional=(self.layer_type != LayerType.decoder),  # (not self.is_decoder),
+            bidirectional=(self.attention_type != AttnMaskType.causal),  # self.is_decoder and self_attention.
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )

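The `_relative_position_bucket` call above follows the T5 bucketing scheme: nearby offsets each get their own bucket, larger offsets share logarithmically sized buckets up to `max_distance`, and the `bidirectional` flag splits the bucket range between past and future positions. A self-contained sketch modeled on the HuggingFace T5 implementation (not the NeMo code itself):

import math
import torch

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Modeled on the HuggingFace T5 bucketing scheme.
    ret = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        ret = ret + (relative_position > 0).long() * num_buckets  # future vs. past half
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # Half of each range holds exact offsets; the rest are log-spaced bins.
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    rel_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    rel_if_large = torch.min(rel_if_large, torch.full_like(rel_if_large, num_buckets - 1))
    return ret + torch.where(is_small, relative_position, rel_if_large)

positions = torch.arange(8)
buckets = relative_position_bucket(positions[None, :] - positions[:, None])  # memory - context
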
@@ -683,9 +685,12 @@ def forward(

         if position_bias is None:
-            if self.position_embedding_type == 'relative':
-                position_bias = self.compute_bias(real_seq_length, key_length)
-            else:
-                pass  # HuggingFace implementation initialize position_bias to zero when not using
+            if self.has_relative_attention_bias:
+                position_bias = self.compute_bias(real_seq_length, key_length)
+            elif attention_mask is not None:
+                position_bias = torch.zeros_like(attention_mask).to(torch.cuda.current_device())
+            else:
+                position_bias = torch.zeros(1, key_length, key_length).to(torch.cuda.current_device())

         # if key and values are already calculated
         # we want only the last query position bias

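With this fallback, a layer that does not own the bias table no longer leaves `position_bias` as `None`: it gets an all-zero tensor shaped like the attention mask, or `(1, key_length, key_length)` when no mask is available. This matches the HuggingFace convention of treating an absent bias as zero while keeping a concrete tensor available to hand to subsequent layers.
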
@@ -975,6 +980,7 @@ def __init__(
         normalization='layernorm',
         transformer_block_type='pre_ln',
         headscale=False,
+        has_relative_attention_bias=False,
     ):
         super(ParallelTransformerLayer_, self).__init__()

@@ -1038,6 +1044,7 @@ def __init__(
             megatron_legacy=megatron_legacy,
             bias=bias,
             headscale=headscale,
+            has_relative_attention_bias=has_relative_attention_bias,
         )
         # Normformer normalization
         if transformer_block_type == 'normformer':

@@ -1079,6 +1086,7 @@ def __init__(
             megatron_legacy=megatron_legacy,
             bias=bias,
             headscale=headscale,
+            has_relative_attention_bias=False,
         )
         # Normformer normalization
         if transformer_block_type == 'normformer':

@@ -1207,14 +1215,6 @@ def forward(
         # Post-LN: x -> MHA -> Residual -> LN -> MLP -> Residual -> LN
         # Normformer: x -> LN -> MHA -> LN -> Residual -> MLP (w/LN) -> Residual

-        if type(hidden_states) is tuple:
-            if len(hidden_states) == 2:
-                hidden_states, position_bias = hidden_states
-            elif len(hidden_states) == 3:
-                hidden_states, position_bias, encoder_decoder_position_bias = hidden_states
-            else:
-                raise IndexError('Hidden_states needs to be tuple containing 2 or 3 elements.')

Review comment on lines -1210 to -1217: Can you delete this?

Reply: I think he moved it further down: if type(x_) is tuple:

         residual = hidden_states
         # Layer norm at the beginning of the transformer layer.
         if self.transformer_block_type in ['pre_ln', 'normformer']:

@@ -1228,6 +1228,7 @@ def forward(
             set_inference_key_value_memory=set_inference_key_value_memory,
             inference_max_sequence_len=inference_max_sequence_len,
             rotary_pos_emb=self_attention_pos_emb,
+            position_bias=position_bias,
         )

         if get_key_value:

@@ -1484,7 +1485,7 @@ def __init__(

         self.num_layers = self.get_num_layers(num_layers)
         # Transformer layers.
-        def build_layer(layer_number):
+        def build_layer(layer_number, has_relative_attention_bias=False):
             if isinstance(layer_type, list):
                 lt = layer_type[layer_number - 1]
             else:

@@ -1522,6 +1523,7 @@ def build_layer(layer_number):
                 normalization=normalization,
                 transformer_block_type=transformer_block_type,
                 headscale=headscale,
+                has_relative_attention_bias=has_relative_attention_bias,
             )

         if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:

@@ -1558,7 +1560,14 @@ def build_layer(layer_number):
         else:
             offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

-        self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])
+        self.layers = torch.nn.ModuleList(
+            [
+                build_layer(
+                    i + 1 + offset, has_relative_attention_bias=(i == 0) and parallel_state.is_pipeline_first_stage()
+                )
+                for i in range(self.num_layers)
+            ]
+        )

         if self.post_process:
             # Final layer norm before output.

@@ -1624,9 +1633,16 @@ def custom_forward(*inputs):
                     encoder_output,
                     enc_dec_attn_mask,
                     rotary_pos_emb,
-                    position_bias,
-                    encoder_decoder_position_bias,
+                    position_bias=position_bias,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                 )
+                if type(x_) is tuple:
+                    if len(x_) == 2:
+                        x_, position_bias = x_
+                    elif len(x_) == 3:
+                        x_, position_bias, encoder_decoder_position_bias = x_
+                    else:
+                        raise IndexError('Hidden_states (x_) needs to be tuple containing 2 or 3 elements.')
                 return x_

             return custom_forward

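This is where the tuple unpacking removed from `ParallelTransformerLayer_.forward` ended up: under activation checkpointing a layer now returns `(hidden_states, position_bias)` or `(hidden_states, position_bias, encoder_decoder_position_bias)`, and `custom_forward` unpacks the tuple so the bias computed by the first layer is threaded into the next layer of the checkpointed chunk.
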
@@ -1704,9 +1720,10 @@ def forward(
         inference_max_sequence_len=None,
         rotary_pos_emb=None,  # list of positional embedding tensors, first one self attention, second one and third one are for cross attention (q, k)
         retrieved_emb=None,  # tensor of retrieved embedding of shape [b, k, r, n, d]
+        position_bias=None,
+        encoder_decoder_position_bias=None,
     ):
-        position_bias = None
-        encoder_decoder_position_bias = None

         # Checks.
         if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing'

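Promoting `position_bias` and `encoder_decoder_position_bias` from locals that were reset to `None` on every call into `forward` parameters lets a caller supply a bias that was already computed, instead of the transformer discarding it unconditionally.
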
Review comment: Can we change this comment? I think it should be encoder + self-attention?