From af94a3d1ab8e28c4da2c8812a907158001173b4c Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 17 Jun 2022 16:47:35 +0200 Subject: [PATCH 1/3] - get rid of for loops - deals better with padded batched input - avoid useless cpu/gpu communication when creating alibi Co-authored-by: justheuristic --- .../models/bloom/modeling_bloom.py | 89 +++++++------------ 1 file changed, 33 insertions(+), 56 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index e1bc39a16324..6e6e096f02f6 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -96,7 +96,9 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask): ) -def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16): +def build_alibi_tensor( + max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu") +) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -111,66 +113,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16): number of heads dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor + device: (`torch.device`, *optional*, default=`torch.device('cpu')`): + device of the output alibi tensor """ + closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != n_head: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1) - arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0) - alibi = slopes * arange_tensor.expand(n_head, -1, -1) - - alibi = alibi.to(dtype) - - return alibi + lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32) + return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype) -def pre_process_alibi_for_pad(alibi, attention_mask, num_heads): +def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor): """ Args: Pre-process the alibi tensor for padding. alibi: ([`torch.tensor`], *required*): alibi tensor to pre-process attention_mask: ([`torch.tensor`], *required*): - attention mask to pre-process""" - - # Sanity check if we are not inferring less tokens than the total sequence length - # This usually happens when the inference is done with past_key_values - # In this case we re-create the alibi tensor with the correct sequence length - if attention_mask.shape[-1] != alibi.shape[-1]: - alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat( - attention_mask.shape[0], 1, 1 - ) - # Get the indexes of the padding tokens - index_x0, index_y0 = torch.where(attention_mask == 0.0) - index_x1, index_y1 = torch.where(attention_mask == 1.0) - - # Clone the embeddings - we can detach because the embeddings are not learned - # Get a refence tensor - slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype) - - # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor - # Only where you do not have padding. Replace padding tokens by zeros - # This operation can be seen as a shifting operation. - for i, index in enumerate(torch.unique(index_x0)): - slice_to_modify = torch.zeros_like(slice_reference_alibi) - index_shift = index_y1[index_x1 == index] - shift_value = len(index_shift) - slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value] - alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify - return alibi + attention mask to pre-process + """ + + unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) + # ^-- [batch, max_len], values correspond to element indices after removing padding + # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 + alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) + return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1) def dropout_add(x, residual, prob, training): @@ -359,12 +336,12 @@ def forward( output_attentions=False, ): # hidden_states: [batch_size, seq_length, hidden_size] - # repeat alibi tensor with the batch size - alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device) - # apply preprocessing if the input is padded if attention_mask is not None and 0 in attention_mask: - alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads) + alibi = pre_process_alibi_for_pad(alibi, attention_mask) + # otherwise repeat alibi tensor with the batch size + else: + alibi = alibi.repeat(hidden_states.shape[0], 1, 1) mixed_x_layer = self.query_key_value(hidden_states) @@ -773,7 +750,7 @@ def forward( current_sequence_length = hidden_states.shape[1] if past_key_values[0] is not None: current_sequence_length += past_key_values[0][0].shape[1] - alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype) + alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): From 9f47dbde7bf2c76fd700b996a907f7ed97f0fde8 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 17 Jun 2022 17:09:33 +0200 Subject: [PATCH 2/3] Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: justheuristic --- src/transformers/models/bloom/modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6e6e096f02f6..71861f38ed6a 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -337,7 +337,7 @@ def forward( ): # hidden_states: [batch_size, seq_length, hidden_size] # apply preprocessing if the input is padded - if attention_mask is not None and 0 in attention_mask: + if attention_mask is not None: alibi = pre_process_alibi_for_pad(alibi, attention_mask) # otherwise repeat alibi tensor with the batch size else: From 38ef869c1a631c7ab0540b5fa94051bb8338afa4 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 24 Jun 2022 15:13:07 +0200 Subject: [PATCH 3/3] Update src/transformers/models/bloom/modeling_bloom.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/bloom/modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 71861f38ed6a..15e3e44da94b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -339,8 +339,8 @@ def forward( # apply preprocessing if the input is padded if attention_mask is not None: alibi = pre_process_alibi_for_pad(alibi, attention_mask) - # otherwise repeat alibi tensor with the batch size else: + # otherwise repeat alibi tensor with the batch size alibi = alibi.repeat(hidden_states.shape[0], 1, 1) mixed_x_layer = self.query_key_value(hidden_states)