From ae09d5bc5cd5c8ea0976c4ae0d1adb0074b297b4 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 9 Jun 2022 18:04:11 +0200 Subject: [PATCH 01/31] fix tolerance for a bloom slow test --- tests/models/bloom/test_modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 1f5c10d2ee59..30b1ac8767de 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -676,7 +676,7 @@ def test_hidden_states_transformers(self): } if cuda_available: - self.assertEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item()) + self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=4) else: self.assertAlmostEqual(MEAN_VALUE_LAST_LM, logits.last_hidden_state.mean().item(), places=3) From a3be0713d9d92c5f5bf79a56fdf12306977b81eb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 17 Jun 2022 16:42:45 +0200 Subject: [PATCH 02/31] enhance alibi padding - get rid of for loops - deals better with padded batched input - avoid useless cpu/gpu communication when creating alibi Co-authored-by: justheuristic --- .../models/bloom/modeling_bloom.py | 326 ++++++++++++++---- 1 file changed, 264 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 232bbc5f22e3..6e6e096f02f6 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -15,15 +15,20 @@ """PyTorch BLOOM model.""" import math -from typing import Tuple +from typing import Tuple, Union import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) from ...modeling_utils import PreTrainedModel from ...utils import logging from .configuration_bloom import BloomConfig @@ -42,7 +47,7 @@ "bigscience/bloom-1b3", "bigscience/bloom-2b5", "bigscience/bloom-6b3", - "bigscience/bloom-176b", + "bigscience/bloom", ] @@ -91,7 +96,9 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask): ) -def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16): +def build_alibi_tensor( + max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu") +) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -106,66 +113,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16): number of heads dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor + device: (`torch.device`, *optional*, default=`torch.device('cpu')`): + device of the output alibi tensor """ + closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) + base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, 
dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != n_head: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - def get_slopes(n): - def get_slopes_power_of_2(n): - start = 2 ** (-(2 ** -(math.log2(n) - 3))) - ratio = start - return [start * ratio**i for i in range(n)] - - if math.log2(n).is_integer(): - return get_slopes_power_of_2(n) - else: - closest_power_of_2 = 2 ** math.floor(math.log2(n)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] - ) - - slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1) - arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0) - alibi = slopes * arange_tensor.expand(n_head, -1, -1) - - alibi = alibi.to(dtype) - - return alibi + lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32) + return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype) -def pre_process_alibi_for_pad(alibi, attention_mask, num_heads): +def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor): """ Args: Pre-process the alibi tensor for padding. alibi: ([`torch.tensor`], *required*): alibi tensor to pre-process attention_mask: ([`torch.tensor`], *required*): - attention mask to pre-process""" - - # Sanity check if we are not inferring less tokens than the total sequence length - # This usually happens when the inference is done with past_key_values - # In this case we re-create the alibi tensor with the correct sequence length - if attention_mask.shape[-1] != alibi.shape[-1]: - alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat( - attention_mask.shape[0], 1, 1 - ) - # Get the indexes of the padding tokens - index_x0, index_y0 = torch.where(attention_mask == 0.0) - index_x1, index_y1 = torch.where(attention_mask == 1.0) - - # Clone the embeddings - we can detach because the embeddings are not learned - # Get a refence tensor - slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype) - - # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor - # Only where you do not have padding. Replace padding tokens by zeros - # This operation can be seen as a shifting operation. 
- for i, index in enumerate(torch.unique(index_x0)): - slice_to_modify = torch.zeros_like(slice_reference_alibi) - index_shift = index_y1[index_x1 == index] - shift_value = len(index_shift) - slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value] - alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify - return alibi + attention mask to pre-process + """ + + unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) + # ^-- [batch, max_len], values correspond to element indices after removing padding + # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 + alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) + return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1) def dropout_add(x, residual, prob, training): @@ -354,12 +336,12 @@ def forward( output_attentions=False, ): # hidden_states: [batch_size, seq_length, hidden_size] - # repeat alibi tensor with the batch size - alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device) - # apply preprocessing if the input is padded if attention_mask is not None and 0 in attention_mask: - alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads) + alibi = pre_process_alibi_for_pad(alibi, attention_mask) + # otherwise repeat alibi tensor with the batch size + else: + alibi = alibi.repeat(hidden_states.shape[0], 1, 1) mixed_x_layer = self.query_key_value(hidden_states) @@ -726,7 +708,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - ): + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -768,7 +750,7 @@ def forward( current_sequence_length = hidden_states.shape[1] if past_key_values[0] is not None: current_sequence_length += past_key_values[0][0].shape[1] - alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype) + alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -902,7 +884,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, - ): + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set @@ -959,3 +941,223 @@ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) for layer_past in past ) + + +@add_start_docstrings( + """ + The Bloom Model transformer with a sequence classification head on top (linear layer). + + [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. 
Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + BLOOM_START_DOCSTRING, +) +class BloomForSequenceClassification(BloomPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = BloomModel(config) + self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. 
Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + BLOOM_START_DOCSTRING, +) +class BloomForTokenClassification(BloomPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = BloomModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) From fcfe5b74f17733743ce9bd4a3ebf9e16464318d1 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Wed, 6 Jul 2022 16:23:05 +0200 Subject: [PATCH 03/31] optimize attention mask --- .../models/bloom/modeling_bloom.py | 119 ++++++++++-------- 1 file changed, 67 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6e6e096f02f6..4375946a06c8 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -51,6 +51,21 @@ ] +def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask_cond = torch.arange(mask.size(-1)) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. @@ -77,23 +92,18 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= return tensor_list -def attention_mask_func(attention_scores, attention_mask, causal_mask): - if attention_mask.dtype == torch.bool: - attention_mask_bool = ~attention_mask - else: - attention_mask_bool = (1 - attention_mask).bool() - - query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1) - padded_causal_mask = ( - attention_mask_bool[:, None, key_length - query_length : key_length, None] - + ~causal_mask[:, :, key_length - query_length : key_length, :key_length] - ).bool() - padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool() - # Make use of floats - return ( - attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0), - padded_causal_mask, - ) +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) def build_alibi_tensor( @@ -142,7 +152,6 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor) attention_mask: ([`torch.tensor`], *required*): attention mask to pre-process """ - unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) # ^-- [batch, max_len], values correspond to element indices after removing padding # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 @@ -234,34 +243,30 @@ def forward(self, x): class BloomScaledSoftmax(nn.Module): """ - fused operation: scaling + mask + softmax - Args: + fused operation: scaling + mask + softmax input_in_fp16 (`bool`, *required*): flag to indicate if input in fp16 data format. input_in_bf16 (`bool`, *required*): flag to indicate if input in bf16 data format. scaled_masked_softmax_fusion (`bool`, *required*): flag to indicate user want to use softmax fusion - mask_func (`function`, *required*): - mask function to be applied. softmax_in_fp32 (`bool`, *required*): if true, softmax in performed at fp32 precision. scale (`float`, *required*): scaling factor used in input tensor scaling. """ - def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale): + def __init__(self, scaled_masked_softmax_fusion, softmax_in_fp32, scale): super().__init__() self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale if not (self.scale is None or softmax_in_fp32): raise ValueError("softmax should be in fp32 when scaled") - def forward(self, input, mask, max_positions): + def forward(self, input, causal_mask): input_dtype = input.dtype input_in_16bit = input_dtype in [torch.float16, torch.bfloat16] softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype @@ -269,17 +274,7 @@ def forward(self, input, mask, max_positions): if self.scale is not None: input = input * self.scale - if mask is not None: - mask = mask.to(input.device) - causal_mask = ( - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)) - .view(1, 1, max_positions, max_positions) - .to(input.device) - ) - mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) - probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask) - else: - probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype) + probs = nn.functional.softmax(input + causal_mask, dim=-1, dtype=softmax_dtype) * (~causal_mask.bool()) if input_in_16bit and self.softmax_in_fp32: probs = probs.to(dtype=input_dtype) @@ -315,7 +310,6 @@ def __init__(self, config, layer_number=None): # Scaled Softmax self.scale_mask_softmax = BloomScaledSoftmax( self.masked_softmax_fusion, - attention_mask_func, self.attention_softmax_in_fp32, self.layer_number, ) @@ -335,14 +329,6 @@ def forward( use_cache=False, output_attentions=False, ): - # hidden_states: [batch_size, seq_length, hidden_size] - # apply preprocessing if the input is padded - if attention_mask is not None and 0 in attention_mask: - alibi = pre_process_alibi_for_pad(alibi, attention_mask) - # otherwise repeat alibi tensor with the batch size - else: - alibi = alibi.repeat(hidden_states.shape[0], 1, 1) - 
mixed_x_layer = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim] @@ -389,10 +375,7 @@ def forward( attention_scores = matmul_result.view(*output_size) # attention scores and attention mask [b, np, sq, sk] - max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2]) - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to( - value_layer.dtype - ) + attention_probs = self.scale_mask_softmax(attention_scores, attention_mask).to(value_layer.dtype) attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: @@ -686,6 +669,24 @@ def __init__(self, config): def get_input_embeddings(self): return self.word_embeddings + def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length + ).to(inputs_embeds.device) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + def set_input_embeddings(self, new_embeddings): self.word_embeddings = new_embeddings @@ -748,10 +749,24 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation current_sequence_length = hidden_states.shape[1] + past_key_values_length = 0 if past_key_values[0] is not None: - current_sequence_length += past_key_values[0][0].shape[1] + past_key_values_length = past_key_values[0][0].shape[1] + current_sequence_length += past_key_values_length alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device) + # apply preprocessing if the input is padded + if attention_mask is not None and 0 in attention_mask: + alibi = pre_process_alibi_for_pad(alibi, attention_mask) + # otherwise repeat alibi tensor with the batch size + else: + alibi = alibi.repeat(hidden_states.shape[0], 1, 1) + + if attention_mask is None: + attention_mask = torch.ones((hidden_states.shape[:-1]), device=hidden_states.device) + + causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -776,14 +791,14 @@ def custom_forward(*inputs): create_custom_forward(block), hidden_states, None, - attention_mask, + causal_mask, head_mask[i], ) else: outputs = block( hidden_states, layer_past=layer_past, - attention_mask=attention_mask, + attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, From 23d8eb3d061f9308c9d81b9b6fa8bb8d585987cf Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 7 Jul 2022 11:31:59 +0200 Subject: [PATCH 04/31] fix scaled softmax limit values --- src/transformers/models/bloom/modeling_bloom.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 4375946a06c8..59101e7af918 100644 --- 
a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -56,7 +56,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) @@ -155,7 +155,8 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor) unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) # ^-- [batch, max_len], values correspond to element indices after removing padding # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 - alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) + alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) # [num_heads, batch_size, max_len] + alibi = alibi.transpose(0, 1) # [batch_size, num_heads, max_len] return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1) @@ -274,7 +275,9 @@ def forward(self, input, causal_mask): if self.scale is not None: input = input * self.scale - probs = nn.functional.softmax(input + causal_mask, dim=-1, dtype=softmax_dtype) * (~causal_mask.bool()) + attn_weights = input + causal_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + probs = nn.functional.softmax(attn_weights, dim=-1, dtype=softmax_dtype) * (~causal_mask.bool()) if input_in_16bit and self.softmax_in_fp32: probs = probs.to(dtype=input_dtype) @@ -676,7 +679,7 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(inputs_embeds.device) + ).to(attention_mask.device) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -756,6 +759,8 @@ def forward( alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device) # apply preprocessing if the input is padded + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) if attention_mask is not None and 0 in attention_mask: alibi = pre_process_alibi_for_pad(alibi, attention_mask) # otherwise repeat alibi tensor with the batch size From 287d3c7abacd290e74f8048cb00cc278f8dcb4b7 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 7 Jul 2022 12:50:58 +0200 Subject: [PATCH 05/31] optimize building alibi tensor Co-authored-by: Younes Belkada --- .../models/bloom/modeling_bloom.py | 54 ++++++++----------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index fd6008e7ba00..d3cdd3fd9c47 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -107,7 +107,10 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): def build_alibi_tensor( - max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu") + attention_mask: torch.Tensor, + n_head: int, + dtype: torch.dtype = torch.bfloat16, + device: torch.device = 
torch.device("cpu"), ) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it @@ -116,9 +119,9 @@ def build_alibi_tensor( https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 Args: - Returns tensor shaped (n_head, 1, max_seq_len) - max_seq_len: (`int`, *required*): - max sequence length + Returns tensor shaped (batch_size * n_head, 1, max_seq_len) + attention_mask: (`torch.Tensor`, *required*) + attention mask n_head: (`int`, *required*): number of heads dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): @@ -139,25 +142,18 @@ def build_alibi_tensor( extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) - lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32) - return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype) + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + # This is more or less identical to T5's relative position bias: + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 + # batch_size = 1, n_head = n_head, query_length - -def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor): - """ - Args: - Pre-process the alibi tensor for padding. 
- alibi: ([`torch.tensor`], *required*): - alibi tensor to pre-process - attention_mask: ([`torch.tensor`], *required*): - attention mask to pre-process - """ - unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1) - # ^-- [batch, max_len], values correspond to element indices after removing padding - # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0 - alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0) # [num_heads, batch_size, max_len] - alibi = alibi.transpose(0, 1) # [batch_size, num_heads, max_len] - return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1) + arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask.unsqueeze(1) + alibi = slopes.unsqueeze(-1) * arange_tensor + alibi = alibi * attention_mask.unsqueeze(1) + return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype) def dropout_add(x, residual, prob, training): @@ -756,19 +752,13 @@ def forward( if past_key_values[0] is not None: past_key_values_length = past_key_values[0][0].shape[1] current_sequence_length += past_key_values_length - alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device) - - # apply preprocessing if the input is padded - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) - if attention_mask is not None and 0 in attention_mask: - alibi = pre_process_alibi_for_pad(alibi, attention_mask) - # otherwise repeat alibi tensor with the batch size - else: - alibi = alibi.repeat(hidden_states.shape[0], 1, 1) if attention_mask is None: attention_mask = torch.ones((hidden_states.shape[:-1]), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device) causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) From 342dad18d95b75411510339ce9f11715a3517c5e Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Thu, 7 Jul 2022 12:51:29 +0200 Subject: [PATCH 06/31] fix attention_mask shape when it's None --- src/transformers/models/bloom/modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d3cdd3fd9c47..8d168f63b2cc 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -754,7 +754,7 @@ def forward( current_sequence_length += past_key_values_length if attention_mask is None: - attention_mask = torch.ones((hidden_states.shape[:-1]), device=hidden_states.device) + attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) From 563c3234228a0878ff79bd58fd6042c2211c323a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:35:48 +0200 Subject: [PATCH 07/31] minor fixes - fix docstring + arg names --- .../models/bloom/modeling_bloom.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 8d168f63b2cc..6d3cfd3f1786 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -55,22 +55,23 @@ def 
_make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ """ Make causal mask used for bi-directional self-attention. """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min) + batch_size, target_length = input_ids_shape + mask = torch.full((target_length, target_length), torch.finfo(dtype).min) mask_cond = torch.arange(mask.size(-1)) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) + mask.masked_fill_(intermediate_mask, 0) mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1) + return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): """Split a tensor along its last dimension. Args: - tensor: ([`torch.tensor`], *required*): + tensor ([`torch.tensor`], *required*): input tensor to split num_partitions ([`int`], *required*): number of partitions to split the tensor @@ -96,10 +97,10 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len + batch_size, source_length = mask.size() + tgt_len = tgt_len if tgt_len is not None else source_length - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype) inverted_mask = 1.0 - expanded_mask @@ -120,8 +121,8 @@ def build_alibi_tensor( Args: Returns tensor shaped (batch_size * n_head, 1, max_seq_len) - attention_mask: (`torch.Tensor`, *required*) - attention mask + attention_mask (`torch.Tensor`, *required*): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). n_head: (`int`, *required*): number of heads dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): From 082a1b89c8d15524871ffe08f318d5e60ae1d527 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:39:02 +0200 Subject: [PATCH 08/31] remove colons in docstring --- src/transformers/models/bloom/modeling_bloom.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6d3cfd3f1786..094a1ad16986 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -123,11 +123,11 @@ def build_alibi_tensor( Returns tensor shaped (batch_size * n_head, 1, max_seq_len) attention_mask (`torch.Tensor`, *required*): Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
- n_head: (`int`, *required*): + n_head (`int`, *required*): number of heads - dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): dtype of the output tensor - device: (`torch.device`, *optional*, default=`torch.device('cpu')`): + device (`torch.device`, *optional*, default=`torch.device('cpu')`): device of the output alibi tensor """ closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) From af8dce9df79a9c483d53e9a9a44778b59a203348 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 7 Jul 2022 15:43:15 +0200 Subject: [PATCH 09/31] Apply suggestions from code review Co-authored-by: Patrick von Platen --- src/transformers/models/bloom/modeling_bloom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 094a1ad16986..044c18f2863b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -110,8 +110,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): def build_alibi_tensor( attention_mask: torch.Tensor, n_head: int, - dtype: torch.dtype = torch.bfloat16, - device: torch.device = torch.device("cpu"), + dtype, + device, ) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it From 2dc5134d597d0f38a26b6c14c918bad65d72f35a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:46:06 +0200 Subject: [PATCH 10/31] apply suggestion --- src/transformers/models/bloom/modeling_bloom.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 044c18f2863b..b3c531518e14 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -64,7 +64,8 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ if past_key_values_length > 0: mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1) - return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + return expanded_mask def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): From 9fbfc3693e04a813e36f0f4b5e2985738cffe558 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:51:23 +0200 Subject: [PATCH 11/31] remove unsued arg --- src/transformers/models/bloom/configuration_bloom.py | 2 -- src/transformers/models/bloom/modeling_bloom.py | 7 +------ 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index f7929dee8afa..4712b0f18a1b 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -128,7 +128,6 @@ def __init__( hidden_size=64, n_layer=2, n_head=8, - masked_softmax_fusion=True, layer_norm_epsilon=1e-5, initializer_range=0.02, use_cache=False, @@ -147,7 +146,6 @@ def __init__( self.hidden_size = hidden_size self.n_layer = n_layer self.n_head = n_head - self.masked_softmax_fusion = 
masked_softmax_fusion self.layer_norm_epsilon = layer_norm_epsilon self.initializer_range = initializer_range self.use_cache = use_cache diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index b3c531518e14..2595ba7799f3 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -248,17 +248,14 @@ class BloomScaledSoftmax(nn.Module): flag to indicate if input in fp16 data format. input_in_bf16 (`bool`, *required*): flag to indicate if input in bf16 data format. - scaled_masked_softmax_fusion (`bool`, *required*): - flag to indicate user want to use softmax fusion softmax_in_fp32 (`bool`, *required*): if true, softmax in performed at fp32 precision. scale (`float`, *required*): scaling factor used in input tensor scaling. """ - def __init__(self, scaled_masked_softmax_fusion, softmax_in_fp32, scale): + def __init__(self, softmax_in_fp32, scale): super().__init__() - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion self.softmax_in_fp32 = softmax_in_fp32 self.scale = scale @@ -295,7 +292,6 @@ def __init__(self, config, layer_number=None): self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - self.masked_softmax_fusion = config.masked_softmax_fusion self.hidden_dropout = config.hidden_dropout if self.head_dim * self.num_heads != self.hidden_size: @@ -310,7 +306,6 @@ def __init__(self, config, layer_number=None): # Scaled Softmax self.scale_mask_softmax = BloomScaledSoftmax( - self.masked_softmax_fusion, self.attention_softmax_in_fp32, self.layer_number, ) From 7ed70e41cf50961c08f631de947deb4f17d36278 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Thu, 7 Jul 2022 15:53:58 +0200 Subject: [PATCH 12/31] refactor a bit - use [:, None] for consistency --- src/transformers/models/bloom/modeling_bloom.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 2595ba7799f3..778839cee720 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -152,9 +152,9 @@ def build_alibi_tensor( # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 # batch_size = 1, n_head = n_head, query_length - arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask.unsqueeze(1) + arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None] alibi = slopes.unsqueeze(-1) * arange_tensor - alibi = alibi * attention_mask.unsqueeze(1) + alibi = alibi * attention_mask[:, None] return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype) From 6b8fb3933874adae657bb813e9f19c058d248fed Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 8 Jul 2022 17:54:54 +0200 Subject: [PATCH 13/31] refactor attention block Co-authored-by: Nouamane Tazi --- .../models/bloom/modeling_bloom.py | 174 +++++------------- 1 file changed, 49 insertions(+), 125 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 778839cee720..5fd531994c18 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -68,32 +68,6 @@ def _make_causal_mask(input_ids_shape: 
torch.Size, dtype: torch.dtype, past_key_ return expanded_mask -def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False): - """Split a tensor along its last dimension. - - Args: - tensor ([`torch.tensor`], *required*): - input tensor to split - num_partitions ([`int`], *required*): - number of partitions to split the tensor - contiguous_split_chunks ([`bool`], *optional*, default=`False`):: - If True, make each chunk contiguous in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - numerator, denominator = tensor.size()[last_dim], num_partitions - if not (numerator % denominator == 0): - raise ValueError(f"{numerator} is not divisible by {denominator}") - last_dim_size = numerator // denominator - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. @@ -240,46 +214,6 @@ def forward(self, x): return bloom_gelu_forward(x) -class BloomScaledSoftmax(nn.Module): - """ - Args: - fused operation: scaling + mask + softmax - input_in_fp16 (`bool`, *required*): - flag to indicate if input in fp16 data format. - input_in_bf16 (`bool`, *required*): - flag to indicate if input in bf16 data format. - softmax_in_fp32 (`bool`, *required*): - if true, softmax in performed at fp32 precision. - scale (`float`, *required*): - scaling factor used in input tensor scaling. - """ - - def __init__(self, softmax_in_fp32, scale): - super().__init__() - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise ValueError("softmax should be in fp32 when scaled") - - def forward(self, input, causal_mask): - input_dtype = input.dtype - input_in_16bit = input_dtype in [torch.float16, torch.bfloat16] - softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype - - if self.scale is not None: - input = input * self.scale - - attn_weights = input + causal_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - probs = nn.functional.softmax(attn_weights, dim=-1, dtype=softmax_dtype) * (~causal_mask.bool()) - - if input_in_16bit and self.softmax_in_fp32: - probs = probs.to(dtype=input_dtype) - - return probs - - class BloomAttention(nn.Module): def __init__(self, config, layer_number=None): super().__init__() @@ -291,7 +225,6 @@ def __init__(self, config, layer_number=None): self.num_heads = config.n_head self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 self.hidden_dropout = config.hidden_dropout if self.head_dim * self.num_heads != self.hidden_size: @@ -304,16 +237,39 @@ def __init__(self, config, layer_number=None): self.layer_number = max(1, layer_number) self.norm_factor = math.sqrt(self.head_dim) * self.layer_number - # Scaled Softmax - self.scale_mask_softmax = BloomScaledSoftmax( - self.attention_softmax_in_fp32, - self.layer_number, - ) - self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.hidden_size, self.hidden_size) self.attention_dropout = nn.Dropout(config.attention_dropout) + def _split_heads(self, fused_qkv): + 
""" + Split the last dimension into (num_heads, head_dim) + """ + new_tensor_shape = fused_qkv.size()[1:-1] + (fused_qkv.size()[0] * self.num_heads, 3 * self.head_dim) + # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) + fused_qkv = fused_qkv.transpose(1, 0) + fused_qkv = fused_qkv.reshape(*new_tensor_shape) + return torch.split(fused_qkv, self.head_dim, -1) + + def _split_attention(self, matmul_result): + return matmul_result.view( + matmul_result.size(0) // self.num_heads, self.num_heads, matmul_result.size(1), matmul_result.size(2) + ) + + def _merge_heads(self, x): + # What we want to achieve is: + # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim + + # First view to decompose the batch size + # batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim + x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim) + + # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim + x = x.permute(0, 2, 1, 3) + + # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim + return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim) + def forward( self, hidden_states, @@ -325,63 +281,43 @@ def forward( use_cache=False, output_attentions=False, ): - mixed_x_layer = self.query_key_value(hidden_states) - - # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim] - new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + fused_qkv = self.query_key_value(hidden_states) - # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + # [batch_size, seq_length, 3 x hidden_size] --> 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) + value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-1) if use_cache is True: present = (key_layer, value_layer) else: present = None - # [batch_size, head_dim, q_length, k_length] - output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1)) - - # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim] - query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1) - - # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim] - key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1) - - # slice alibi tensor until the query length - sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]] - - # Raw attention scores. 
[batch_size * num_heads, q_length, k_length] - beta = 1.0 / self.layer_number - - matmul_result = torch.baddbmm( - sliced_alibi, - query_layer.transpose(1, 0), - key_layer.transpose(1, 0).transpose(1, 2), - beta=beta, - alpha=(1.0 / self.norm_factor), - ) + # [num_heads*batch_size, q_length, k_length] + matmul_result = torch.bmm(query_layer.transpose(1, 0), key_layer.permute(1, 2, 0)) + alibi # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(*output_size) + attention_scores = self._split_attention(matmul_result) # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask).to(value_layer.dtype) + input_dtype = attention_scores.dtype + attn_weights = attention_scores + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) * ( + ~attention_mask.bool() + ) + # attention_probs = self.scale_mask_softmax(attention_scores, attention_mask).to(value_layer.dtype) attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask # context layer shape: [batch_size, num_heads, q_length, head_dim] - output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [k_length, batch_size x num_heads, head_dim] - value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1) + output_size = (value_layer.size(1) // self.num_heads, self.num_heads, value_layer.size(0), self.head_dim) # change view [batch_size x num_heads, q_length, k_length] attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) @@ -390,17 +326,7 @@ def forward( context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1)) # change view [batch_size, num_heads, q_length, head_dim] - context_layer = context_layer.view(*output_size) - - # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size] - new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,) - - context_layer = context_layer.view(*new_context_layer_shape) - - # Output. [q_length, batch_size, hidden_size] + context_layer = self._merge_heads(context_layer) # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 if self.pretraining_tp > 1 and self.slow_but_exact: @@ -414,11 +340,9 @@ def forward( else: output_tensor = self.dense(context_layer) - output = output_tensor.transpose(1, 0) - - output = dropout_add(output, residual, self.hidden_dropout, self.training) + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - outputs = (output, present) + outputs = (output_tensor, present) if output_attentions: outputs += (attention_probs,) From 68a4d39ca30132493a8e99bb3f07b269c9f8d223 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Fri, 8 Jul 2022 19:02:36 +0200 Subject: [PATCH 14/31] quick fixes --- src/transformers/models/bloom/modeling_bloom.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 5fd531994c18..15225deae20c 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -283,14 +283,14 @@ def forward( ): fused_qkv = self.query_key_value(hidden_states) - # [batch_size, seq_length, 3 x hidden_size] --> 3 x [batch_size, seq_length, num_heads, head_dim] + # [batch_size, seq_length, 3 x hidden_size] --> 3 x [seq_length, batch_size * num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-1) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) + value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) if use_cache is True: present = (key_layer, value_layer) @@ -320,7 +320,7 @@ def forward( output_size = (value_layer.size(1) // self.num_heads, self.num_heads, value_layer.size(0), self.head_dim) # change view [batch_size x num_heads, q_length, k_length] - attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], -1, output_size[2]) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1)) @@ -671,7 +671,7 @@ def forward( current_sequence_length = hidden_states.shape[1] past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[1] + past_key_values_length = past_key_values[0][0].shape[0] current_sequence_length += past_key_values_length if attention_mask is None: From 53bef9b7274632678df14fbaee45662d9b562bf6 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 8 Jul 2022 19:43:57 +0200 Subject: [PATCH 15/31] first attempt --- src/transformers/models/bloom/modeling_bloom.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 15225deae20c..6a124c86e53a 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -245,10 +245,11 @@ def _split_heads(self, fused_qkv): """ Split the last dimension into (num_heads, head_dim) """ - new_tensor_shape = fused_qkv.size()[1:-1] + (fused_qkv.size()[0] * self.num_heads, 3 * self.head_dim) + 
new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) - fused_qkv = fused_qkv.transpose(1, 0) + # fused_qkv = fused_qkv.permute(1, 0) fused_qkv = fused_qkv.reshape(*new_tensor_shape) + fused_qkv = fused_qkv.permute(0, 2, 1, 3) return torch.split(fused_qkv, self.head_dim, -1) def _split_attention(self, matmul_result): @@ -256,6 +257,9 @@ def _split_attention(self, matmul_result): matmul_result.size(0) // self.num_heads, self.num_heads, matmul_result.size(1), matmul_result.size(2) ) + def _merge_batch(self, x): + return x.reshape(x.size(0)*x.size(1), x.size(3), x.size(2)) + def _merge_heads(self, x): # What we want to achieve is: # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim @@ -283,19 +287,21 @@ def forward( ): fused_qkv = self.query_key_value(hidden_states) - # [batch_size, seq_length, 3 x hidden_size] --> 3 x [seq_length, batch_size * num_heads, head_dim] + # [batch_size, seq_length, 3 x hidden_size] --> 3 x [seq_length, batch_size, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0) + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) + value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) if use_cache is True: present = (key_layer, value_layer) else: present = None + + query_layer, key_layer, value_layer = self._merge_batch(query_layer), self._merge_batch(key_layer), self._merge_batch(value_layer) # [num_heads*batch_size, q_length, k_length] matmul_result = torch.bmm(query_layer.transpose(1, 0), key_layer.permute(1, 2, 0)) + alibi @@ -791,7 +797,6 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): "input_ids": input_ids, "past_key_values": past, "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, "attention_mask": attention_mask, } From 773d8e780fea41b8a8f77bf2bccbfbeacc91d50d Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Sat, 9 Jul 2022 14:17:01 +0200 Subject: [PATCH 16/31] refactor attention block and fix all tests except "test_simple_generation" - added comments to better explain attention block --- .../models/bloom/modeling_bloom.py | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6a124c86e53a..5251636ef85e 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -247,19 +247,11 @@ def _split_heads(self, fused_qkv): """ new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) - # fused_qkv = fused_qkv.permute(1, 0) + # fused_qkv = fused_qkv.transpose(1, 0) fused_qkv = fused_qkv.reshape(*new_tensor_shape) - fused_qkv = fused_qkv.permute(0, 2, 1, 3) + # fused_qkv = fused_qkv.permute(0, 2, 1, 3) return torch.split(fused_qkv, self.head_dim, -1) - def _split_attention(self, matmul_result): - return matmul_result.view( - matmul_result.size(0) // self.num_heads, self.num_heads, matmul_result.size(1), 
matmul_result.size(2) - ) - - def _merge_batch(self, x): - return x.reshape(x.size(0)*x.size(1), x.size(3), x.size(2)) - def _merge_heads(self, x): # What we want to achieve is: # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim @@ -285,51 +277,72 @@ def forward( use_cache=False, output_attentions=False, ): - fused_qkv = self.query_key_value(hidden_states) + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] - # [batch_size, seq_length, 3 x hidden_size] --> 3 x [seq_length, batch_size, num_heads, head_dim] + # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past - key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) - value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) + # concatenate along seq_length dimension + key_layer = torch.cat( + (past_key.type_as(key_layer), key_layer), dim=1 + ) # [batch_size, k_length, num_heads, head_dim] + value_layer = torch.cat( + (past_value.type_as(value_layer), value_layer), dim=1 + ) # [batch_size, k_length, num_heads, head_dim] if use_cache is True: present = (key_layer, value_layer) else: present = None - - query_layer, key_layer, value_layer = self._merge_batch(query_layer), self._merge_batch(key_layer), self._merge_batch(value_layer) - # [num_heads*batch_size, q_length, k_length] - matmul_result = torch.bmm(query_layer.transpose(1, 0), key_layer.permute(1, 2, 0)) + alibi + q = query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]) + k = key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]) + + + # [batch_size*num_heads, q_length, k_length] + matmul_result = ( + torch.bmm( + query_layer.transpose(1, 2).reshape( + -1, query_layer.shape[1], query_layer.shape[3] + ), # [batch_size*num_heads, q_length, head_dim] + key_layer.permute(0, 2, 3, 1).reshape( + -1, key_layer.shape[3], key_layer.shape[1] + ), # [batch_size*num_heads, head_dim, k_length] + ) + + alibi + ) + + import joblib + joblib.dump(matmul_result, "matmul_result.pkl") + joblib.dump(q, "q.pkl") + joblib.dump(k, "k.pkl") + joblib.dump(alibi, "alibi.pkl") # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = self._split_attention(matmul_result) + attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) - # attention scores and attention mask [b, np, sq, sk] input_dtype = attention_scores.dtype - attn_weights = attention_scores + attention_mask + attn_weights = attention_scores + attention_mask # [batch_size, num_heads, q_length, k_length] attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) * ( ~attention_mask.bool() - ) + ) # [batch_size, num_heads, q_length, k_length] # attention_probs = self.scale_mask_softmax(attention_scores, attention_mask).to(value_layer.dtype) attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: attention_probs = attention_probs * head_mask - # context layer shape: [batch_size, num_heads, q_length, head_dim] - output_size = (value_layer.size(1) // self.num_heads, self.num_heads, value_layer.size(0), self.head_dim) - # change view [batch_size x num_heads, q_length, k_length] 
- attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], -1, output_size[2]) + attention_probs_reshaped = attention_probs.view(*matmul_result.shape) # matmul: [batch_size * num_heads, q_length, head_dim] - context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1)) + context_layer = torch.bmm( + attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)) + ) # change view [batch_size, num_heads, q_length, head_dim] context_layer = self._merge_heads(context_layer) @@ -677,7 +690,7 @@ def forward( current_sequence_length = hidden_states.shape[1] past_key_values_length = 0 if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[0] + past_key_values_length = past_key_values[0][0].shape[1] current_sequence_length += past_key_values_length if attention_mask is None: From ed93b22757d58eb6f816860bb2e17d3f799de4fd Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Sun, 10 Jul 2022 19:39:01 +0200 Subject: [PATCH 17/31] remove debug lines and add TODO comment --- src/transformers/models/bloom/modeling_bloom.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 5251636ef85e..d342516e46b9 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -313,13 +313,7 @@ def forward( ), # [batch_size*num_heads, head_dim, k_length] ) + alibi - ) - - import joblib - joblib.dump(matmul_result, "matmul_result.pkl") - joblib.dump(q, "q.pkl") - joblib.dump(k, "k.pkl") - joblib.dump(alibi, "alibi.pkl") + ) # TODO: this doesn't give same results as torch.baddbmm() for fp16 # change view to [batch_size, num_heads, q_length, k_length] attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) From 91dfee56c0d6dec769f31e0db626e7af00c5abf4 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Sun, 10 Jul 2022 20:47:56 +0200 Subject: [PATCH 18/31] change `torch.bmm` to `torch.baddbmm` - fixes `test_simple_generation`but breaks `test_batch_generation_padd` --- src/transformers/models/bloom/modeling_bloom.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d342516e46b9..11da2a444772 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -297,14 +297,11 @@ def forward( present = (key_layer, value_layer) else: present = None - - q = query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]) - k = key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]) - - + # [batch_size*num_heads, q_length, k_length] matmul_result = ( - torch.bmm( + torch.baddbmm( + alibi, query_layer.transpose(1, 2).reshape( -1, query_layer.shape[1], query_layer.shape[3] ), # [batch_size*num_heads, q_length, head_dim] @@ -312,8 +309,7 @@ def forward( -1, key_layer.shape[3], key_layer.shape[1] ), # [batch_size*num_heads, head_dim, k_length] ) - + alibi - ) # TODO: this doesn't give same results as torch.baddbmm() for fp16 + ) # change view to [batch_size, num_heads, q_length, k_length] attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) From 2272cb0c4efee0ac65974590b34cd9593af51fdf Mon Sep 17 00:00:00 2001 From: 
Nouamane Tazi Date: Sun, 10 Jul 2022 22:09:07 +0200 Subject: [PATCH 19/31] styling --- .../models/bloom/modeling_bloom.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 11da2a444772..9b8d3c47792e 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -297,18 +297,16 @@ def forward( present = (key_layer, value_layer) else: present = None - + # [batch_size*num_heads, q_length, k_length] - matmul_result = ( - torch.baddbmm( - alibi, - query_layer.transpose(1, 2).reshape( - -1, query_layer.shape[1], query_layer.shape[3] - ), # [batch_size*num_heads, q_length, head_dim] - key_layer.permute(0, 2, 3, 1).reshape( - -1, key_layer.shape[3], key_layer.shape[1] - ), # [batch_size*num_heads, head_dim, k_length] - ) + matmul_result = torch.baddbmm( + alibi, + query_layer.transpose(1, 2).reshape( + -1, query_layer.shape[1], query_layer.shape[3] + ), # [batch_size*num_heads, q_length, head_dim] + key_layer.permute(0, 2, 3, 1).reshape( + -1, key_layer.shape[3], key_layer.shape[1] + ), # [batch_size*num_heads, head_dim, k_length] ) # change view to [batch_size, num_heads, q_length, k_length] From eb86c4369d465d6181a865e43fd198b28c0b7a02 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 11 Jul 2022 15:52:03 +0200 Subject: [PATCH 20/31] all tests are passing now - use `bmm` - add explanation for `allow_fp16_reduced_precision_reduction` Co-authored-by: Younes Belkada --- .../models/bloom/modeling_bloom.py | 52 +++++++++++++++---- tests/models/bloom/test_modeling_bloom.py | 45 +++++++++++++++- 2 files changed, 86 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 9b8d3c47792e..c2fb44d324a2 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -36,6 +36,9 @@ logger = logging.get_logger(__name__) +if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction == False: + logger.info("allow_fp16_reduced_precision_reduction is set to False, this can lead to slightly inconsistent results (batched/cached) under half precision mode. 
Use it at your own risk.") + _CHECKPOINT_FOR_DOC = "bigscience/Bloom" _CONFIG_FOR_DOC = "BloomConfig" _TOKENIZER_FOR_DOC = "BloomTokenizerFast" @@ -281,7 +284,7 @@ def forward( # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - query_layer = query_layer * (1 / math.sqrt(self.head_dim)) + # query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past @@ -298,22 +301,51 @@ def forward( else: present = None + beta = 1.0 / self.layer_number # [batch_size*num_heads, q_length, k_length] - matmul_result = torch.baddbmm( - alibi, - query_layer.transpose(1, 2).reshape( - -1, query_layer.shape[1], query_layer.shape[3] - ), # [batch_size*num_heads, q_length, head_dim] - key_layer.permute(0, 2, 3, 1).reshape( - -1, key_layer.shape[3], key_layer.shape[1] - ), # [batch_size*num_heads, head_dim, k_length] + # matmul_result = torch.baddbmm( + # alibi, + # query_layer.transpose(1, 2).reshape( + # -1, query_layer.shape[1], query_layer.shape[3] + # ), # [batch_size*num_heads, q_length, head_dim] + # key_layer.permute(0, 2, 3, 1).reshape( + # -1, key_layer.shape[3], key_layer.shape[1] + # ), # [batch_size*num_heads, head_dim, k_length] + # beta=beta, + # alpha=(1.0 / self.norm_factor), + # ) + + + matmul_result = ( + (1.0 / self.norm_factor) * torch.bmm( + query_layer.transpose(1, 2).reshape( + -1, query_layer.shape[1], query_layer.shape[3] + ), # [batch_size*num_heads, q_length, head_dim] + key_layer.permute(0, 2, 3, 1).reshape( + -1, key_layer.shape[3], key_layer.shape[1] + ), # [batch_size*num_heads, head_dim, k_length] + ) + + beta * alibi ) + # matmul_result = ( + # torch.bmm( + # query_layer.transpose(1, 2).reshape( + # -1, query_layer.shape[1], query_layer.shape[3] + # ), # [batch_size*num_heads, q_length, head_dim] + # key_layer.permute(0, 2, 3, 1).reshape( + # -1, key_layer.shape[3], key_layer.shape[1] + # ), # [batch_size*num_heads, head_dim, k_length] + # ) + # + alibi + # ) + + # change view to [batch_size, num_heads, q_length, k_length] attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) input_dtype = attention_scores.dtype - attn_weights = attention_scores + attention_mask # [batch_size, num_heads, q_length, k_length] + attn_weights = (attention_scores * self.layer_number) + attention_mask # [batch_size, num_heads, q_length, k_length] attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) * ( ~attention_mask.bool() diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 0b2501982d52..0a50e67cb3d6 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -377,6 +377,24 @@ def test_model_from_pretrained(self): @slow @require_torch_gpu def test_simple_generation(self): + # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations + # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200 + # We set allow_fp16_reduced_precision_reduction = True. 
Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms + # This discrepancy is observed only when using small models and seems to be stable for larger models. + # Our conclusion is that these operations are flaky for small inputs but seems to be stable for larger inputs (for the functions `baddmm` and `bmm`), and therefore for larger models. + + # Here is a summary of an ablation study of our observations + # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a" + # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS + # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS + # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS + # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL + + # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love" + # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS (for use_cache=True and use_cache=False) + # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS + # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS + path_350m = "bigscience/bloom-350m" model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() model = model.eval() @@ -386,11 +404,19 @@ def test_simple_generation(self): EXPECTED_OUTPUT = ( "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am" " a very good listener. I am a very good person, and I am a very good person. I am a" - ) + ) # for bloom-350m + + # EXPECTED_OUTPUT = ( + # "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake." + # " I love to cook and bake. I love to cook and bake. I love to cook and bake. I love" + # ) # for bloom-760m input_ids = tokenizer.encode(input_sentence, return_tensors="pt") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False greedy_output = model.generate(input_ids.cuda(), max_length=50) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + print(tokenizer.decode(greedy_output[0], skip_special_tokens=True)) self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow @@ -416,6 +442,23 @@ def test_batch_generation(self): @slow @require_torch_gpu def test_batch_generation_padd(self): + # With small models the test will fail because of the operator torch.baddm that will give inconsistent results + # for small models (<350m). 
+ + # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> FAIL + # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> FAIL + # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS + # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> FAIL + + # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS + # >=760m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS + # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS + # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS + + if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction == False: + # warning this could fail + pass + path_350m = "bigscience/bloom-350m" model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() model = model.eval() From 3ba1bd20c29e8cf438bea51ffaf380e6c516091e Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 11 Jul 2022 16:07:17 +0200 Subject: [PATCH 21/31] styling Co-authored-by: Younes Belkada --- .../models/bloom/modeling_bloom.py | 63 ++++++------------- tests/models/bloom/test_modeling_bloom.py | 20 +++--- 2 files changed, 28 insertions(+), 55 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index c2fb44d324a2..53f9f509fdeb 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -36,8 +36,11 @@ logger = logging.get_logger(__name__) -if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction == False: - logger.info("allow_fp16_reduced_precision_reduction is set to False, this can lead to slightly inconsistent results (batched/cached) under half precision mode. Use it at your own risk.") +if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction is False: + logger.info( + "allow_fp16_reduced_precision_reduction is set to False, this can lead to slightly inconsistent results" + " (batched/cached) under half precision mode. Use it at your own risk." 
+ ) _CHECKPOINT_FOR_DOC = "bigscience/Bloom" _CONFIG_FOR_DOC = "BloomConfig" @@ -284,7 +287,6 @@ def forward( # 3 x [batch_size, seq_length, num_heads, head_dim] (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - # query_layer = query_layer * (1 / math.sqrt(self.head_dim)) if layer_past is not None: past_key, past_value = layer_past @@ -302,55 +304,30 @@ def forward( present = None beta = 1.0 / self.layer_number - # [batch_size*num_heads, q_length, k_length] - # matmul_result = torch.baddbmm( - # alibi, - # query_layer.transpose(1, 2).reshape( - # -1, query_layer.shape[1], query_layer.shape[3] - # ), # [batch_size*num_heads, q_length, head_dim] - # key_layer.permute(0, 2, 3, 1).reshape( - # -1, key_layer.shape[3], key_layer.shape[1] - # ), # [batch_size*num_heads, head_dim, k_length] - # beta=beta, - # alpha=(1.0 / self.norm_factor), - # ) - - - matmul_result = ( - (1.0 / self.norm_factor) * torch.bmm( - query_layer.transpose(1, 2).reshape( - -1, query_layer.shape[1], query_layer.shape[3] - ), # [batch_size*num_heads, q_length, head_dim] - key_layer.permute(0, 2, 3, 1).reshape( - -1, key_layer.shape[3], key_layer.shape[1] - ), # [batch_size*num_heads, head_dim, k_length] - ) - + beta * alibi - ) - - # matmul_result = ( - # torch.bmm( - # query_layer.transpose(1, 2).reshape( - # -1, query_layer.shape[1], query_layer.shape[3] - # ), # [batch_size*num_heads, q_length, head_dim] - # key_layer.permute(0, 2, 3, 1).reshape( - # -1, key_layer.shape[3], key_layer.shape[1] - # ), # [batch_size*num_heads, head_dim, k_length] - # ) - # + alibi - # ) + # [batch_size*num_heads, q_length, k_length] + matmul_result = (1.0 / self.norm_factor) * torch.bmm( + query_layer.transpose(1, 2).reshape( + -1, query_layer.shape[1], query_layer.shape[3] + ), # [batch_size*num_heads, q_length, head_dim] + key_layer.permute(0, 2, 3, 1).reshape( + -1, key_layer.shape[3], key_layer.shape[1] + ), # [batch_size*num_heads, head_dim, k_length] + ) + beta * alibi # change view to [batch_size, num_heads, q_length, k_length] attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) + # We replace the scaled softmax by just a few line of code input_dtype = attention_scores.dtype - attn_weights = (attention_scores * self.layer_number) + attention_mask # [batch_size, num_heads, q_length, k_length] + attn_weights = ( + attention_scores * self.layer_number + ) + attention_mask # [batch_size, num_heads, q_length, k_length] attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) * ( ~attention_mask.bool() - ) # [batch_size, num_heads, q_length, k_length] - # attention_probs = self.scale_mask_softmax(attention_scores, attention_mask).to(value_layer.dtype) + ) + # [batch_size, num_heads, q_length, k_length] attention_probs = self.attention_dropout(attention_probs) if head_mask is not None: diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 0a50e67cb3d6..5522e2364984 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -377,10 +377,10 @@ def test_model_from_pretrained(self): @slow @require_torch_gpu def test_simple_generation(self): - # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations + # This test is a bit flaky. 
For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200 # We set allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms - # This discrepancy is observed only when using small models and seems to be stable for larger models. + # This discrepancy is observed only when using small models and seems to be stable for larger models. # Our conclusion is that these operations are flaky for small inputs but seems to be stable for larger inputs (for the functions `baddmm` and `bmm`), and therefore for larger models. # Here is a summary of an ablation study of our observations @@ -401,11 +401,11 @@ def test_simple_generation(self): tokenizer = BloomTokenizerFast.from_pretrained(path_350m) input_sentence = "I enjoy walking with my cute dog" - EXPECTED_OUTPUT = ( + EXPECTED_OUTPUT = ( # for bloom-350m "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am" " a very good listener. I am a very good person, and I am a very good person. I am a" - ) # for bloom-350m - + ) + # EXPECTED_OUTPUT = ( # "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake." # " I love to cook and bake. I love to cook and bake. I love to cook and bake. I love" @@ -443,22 +443,18 @@ def test_batch_generation(self): @require_torch_gpu def test_batch_generation_padd(self): # With small models the test will fail because of the operator torch.baddm that will give inconsistent results - # for small models (<350m). - + # for small models (<350m). 
+ # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> FAIL # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> FAIL # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> FAIL - + # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS # >=760m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS - if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction == False: - # warning this could fail - pass - path_350m = "bigscience/bloom-350m" model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() model = model.eval() From 18cc4d4a003e4211b3251532e56a1a221b220ca4 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 11 Jul 2022 16:17:48 +0200 Subject: [PATCH 22/31] fix support for accelerate Co-authored-by: Younes Belkada --- src/transformers/models/bloom/modeling_bloom.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 53f9f509fdeb..fc3b43530719 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -283,6 +283,7 @@ def forward( use_cache=False, output_attentions=False, ): + alibi = alibi.to(hidden_states.device) # to make the model possible to run under accelerate fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] # 3 x [batch_size, seq_length, num_heads, head_dim] From 1f95e29226a990d22654ef739c8890d586936678 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 11 Jul 2022 18:17:14 +0200 Subject: [PATCH 23/31] Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/bloom/modeling_bloom.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index fc3b43530719..546d90524a68 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -88,12 +88,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -def build_alibi_tensor( - attention_mask: torch.Tensor, - n_head: int, - dtype, - device, -) -> torch.Tensor: +def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value @@ -102,7 +97,7 @@ def build_alibi_tensor( Args: Returns tensor shaped (batch_size * n_head, 1, max_seq_len) - attention_mask (`torch.Tensor`, *required*): + attention_mask (`torch.Tensor`): Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
n_head (`int`, *required*): number of heads From 34af2671ffe22e24011f009e4d9ca5753aefc8b3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:19:04 +0200 Subject: [PATCH 24/31] remove attn softmax in fp32 --- src/transformers/models/bloom/configuration_bloom.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index 4712b0f18a1b..732636193670 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -72,9 +72,6 @@ class BloomConfig(PretrainedConfig): If set to `True`, it will skip bias add for each linear layer in the transformer blocks skip_bias_add_qkv (`bool`, *optional*, defaults to `False`): If set to `True`, it will skip bias add for the first linear layer in the transformer blocks - attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`): - If set to `True` and the `dtype` is set to `float16` it will scale the input of the Softmax function to - `fp32` hidden_dropout (`float`, *optional*, defaults to 0.1): Dropout rate of the dropout function on the bias dropout. attention_dropout (`float`, *optional*, defaults to 0.1): @@ -136,7 +133,6 @@ def __init__( apply_residual_connection_post_layernorm=False, hidden_dropout=0.0, attention_dropout=0.0, - attention_softmax_in_fp32=True, pretraining_tp=1, # TP rank used when training with megatron dtype="bfloat16", slow_but_exact=False, @@ -153,7 +149,6 @@ def __init__( self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm self.hidden_dropout = hidden_dropout self.attention_dropout = attention_dropout - self.attention_softmax_in_fp32 = attention_softmax_in_fp32 self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id From adb67d64b90ce02cf016443f8afad80079214fff Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:24:04 +0200 Subject: [PATCH 25/31] refactor comments --- .../models/bloom/modeling_bloom.py | 31 ++++++------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 546d90524a68..723ef6178e93 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -286,13 +286,9 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - # concatenate along seq_length dimension - key_layer = torch.cat( - (past_key.type_as(key_layer), key_layer), dim=1 - ) # [batch_size, k_length, num_heads, head_dim] - value_layer = torch.cat( - (past_value.type_as(value_layer), value_layer), dim=1 - ) # [batch_size, k_length, num_heads, head_dim] + # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] + key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) + value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) if use_cache is True: present = (key_layer, value_layer) @@ -301,28 +297,21 @@ def forward( beta = 1.0 / self.layer_number - # [batch_size*num_heads, q_length, k_length] + # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length] matmul_result = (1.0 / self.norm_factor) * torch.bmm( - query_layer.transpose(1, 2).reshape( - -1, query_layer.shape[1], query_layer.shape[3] - ), # [batch_size*num_heads, q_length, head_dim] - key_layer.permute(0, 
2, 3, 1).reshape( - -1, key_layer.shape[3], key_layer.shape[1] - ), # [batch_size*num_heads, head_dim, k_length] + query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]), + key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]), ) + beta * alibi # change view to [batch_size, num_heads, q_length, k_length] attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) - # We replace the scaled softmax by just a few line of code + # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length] input_dtype = attention_scores.dtype - attn_weights = ( - attention_scores * self.layer_number - ) + attention_mask # [batch_size, num_heads, q_length, k_length] + attn_weights = (attention_scores * self.layer_number) + attention_mask attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) - attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) * ( - ~attention_mask.bool() - ) + attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + attention_probs = attention_probs * (~attention_mask.bool()) # [batch_size, num_heads, q_length, k_length] attention_probs = self.attention_dropout(attention_probs) From ed9682dcdd7366e2a95a3e021c37a2434fda6a24 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:27:00 +0200 Subject: [PATCH 26/31] refactor a bit - remove warning message - remove print on test --- src/transformers/models/bloom/modeling_bloom.py | 6 ------ tests/models/bloom/test_modeling_bloom.py | 1 - 2 files changed, 7 deletions(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 723ef6178e93..728ebf1e53a3 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -36,12 +36,6 @@ logger = logging.get_logger(__name__) -if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction is False: - logger.info( - "allow_fp16_reduced_precision_reduction is set to False, this can lead to slightly inconsistent results" - " (batched/cached) under half precision mode. Use it at your own risk." 
- ) - _CHECKPOINT_FOR_DOC = "bigscience/Bloom" _CONFIG_FOR_DOC = "BloomConfig" _TOKENIZER_FOR_DOC = "BloomTokenizerFast" diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 5522e2364984..83c30d35dab2 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -416,7 +416,6 @@ def test_simple_generation(self): greedy_output = model.generate(input_ids.cuda(), max_length=50) torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True - print(tokenizer.decode(greedy_output[0], skip_special_tokens=True)) self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow From 3ef948cea8b23688a64af52362d5341841600419 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:32:03 +0200 Subject: [PATCH 27/31] refer to pytorch t5 --- src/transformers/models/bloom/modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 728ebf1e53a3..9bf200de5bd3 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -118,7 +118,7 @@ def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length) # => the query_length dimension will then be broadcasted correctly # This is more or less identical to T5's relative position bias: - # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426 + # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 # batch_size = 1, n_head = n_head, query_length arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None] From 992e318797bb640fdf02e8c4dadd4c4a8e6b44d3 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:42:41 +0200 Subject: [PATCH 28/31] change the slow tests - do the tests in fp32 - remove some comments - keep large comments --- tests/models/bloom/test_modeling_bloom.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 83c30d35dab2..01fcdbcdd069 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -396,7 +396,7 @@ def test_simple_generation(self): # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() + model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m) @@ -412,9 +412,7 @@ def test_simple_generation(self): # ) # for bloom-760m input_ids = tokenizer.encode(input_sentence, return_tensors="pt") - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False greedy_output = model.generate(input_ids.cuda(), max_length=50) - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @@ -422,7 +420,7 @@ def test_simple_generation(self): @require_torch_gpu 
def test_batch_generation(self): path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() + model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left") @@ -441,18 +439,6 @@ def test_batch_generation(self): @slow @require_torch_gpu def test_batch_generation_padd(self): - # With small models the test will fail because of the operator torch.baddm that will give inconsistent results - # for small models (<350m). - - # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> FAIL - # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> FAIL - # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS - # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> FAIL - - # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS - # >=760m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS - # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS - # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS path_350m = "bigscience/bloom-350m" model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() From 0b3db537f3aac3dacadb69dc32e32c91c23ccdc3 Mon Sep 17 00:00:00 2001 From: Nouamane Tazi Date: Mon, 11 Jul 2022 18:48:10 +0200 Subject: [PATCH 29/31] update expected output for `test_simple_generation` - we now test using fp32 --- tests/models/bloom/test_modeling_bloom.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 01fcdbcdd069..63aa06a0d6c8 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -402,8 +402,8 @@ def test_simple_generation(self): input_sentence = "I enjoy walking with my cute dog" EXPECTED_OUTPUT = ( # for bloom-350m - "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am" - " a very good listener. I am a very good person, and I am a very good person. I am a" + "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very " + "active person, and I enjoy working out, and I am a very active person. 
I am a very active person, and I" ) # EXPECTED_OUTPUT = ( @@ -413,7 +413,7 @@ def test_simple_generation(self): input_ids = tokenizer.encode(input_sentence, return_tensors="pt") greedy_output = model.generate(input_ids.cuda(), max_length=50) - + self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow From 81bf622cce2bef47d1d4318bb4ed2c613c1f2edb Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 18:52:23 +0200 Subject: [PATCH 30/31] make style + change comments a bit --- tests/models/bloom/test_modeling_bloom.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index 63aa06a0d6c8..d9f84a1a6a55 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -401,19 +401,15 @@ def test_simple_generation(self): tokenizer = BloomTokenizerFast.from_pretrained(path_350m) input_sentence = "I enjoy walking with my cute dog" - EXPECTED_OUTPUT = ( # for bloom-350m + # This output has been obtained using fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU + EXPECTED_OUTPUT = ( "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very " "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I" ) - # EXPECTED_OUTPUT = ( - # "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake." - # " I love to cook and bake. I love to cook and bake. I love to cook and bake. I love" - # ) # for bloom-760m - input_ids = tokenizer.encode(input_sentence, return_tensors="pt") greedy_output = model.generate(input_ids.cuda(), max_length=50) - + self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow From bf311482284525423f9cf58f15cb006bcf52bb4d Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 11 Jul 2022 19:03:17 +0200 Subject: [PATCH 31/31] fix dtype padd test --- tests/models/bloom/test_modeling_bloom.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index d9f84a1a6a55..ffe3247f3dd7 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -437,7 +437,7 @@ def test_batch_generation(self): def test_batch_generation_padd(self): path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda() + model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
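
Illustrative note (not part of the patch series above): a minimal, self-contained sketch of the mask-aware ALiBi construction the series moves toward, written from the `build_alibi_tensor` docstring and the slope code visible in the diffs. The function name `alibi_bias_sketch`, the dummy mask, and the final reshape convention are assumptions for illustration only, not the committed implementation.

import math
import torch

def alibi_bias_sketch(attention_mask: torch.Tensor, n_head: int, dtype=torch.float32) -> torch.Tensor:
    # attention_mask: [batch_size, seq_length], 1 for real tokens, 0 for padding
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != n_head:
        # extra interleaved slopes when the head count is not a power of two
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    # position of each non-padding token, counted from the first real token (padding positions stay 0)
    arange_tensor = (attention_mask.cumsum(-1) - 1) * attention_mask  # [batch_size, seq_length]
    # broadcast slopes over the batch: [1, n_head, 1] * [batch_size, 1, seq_length]
    alibi = slopes[None, :, None] * arange_tensor[:, None, :]
    return alibi.reshape(batch_size * n_head, 1, seq_length).to(dtype)

mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])  # left-padded batch of 2
print(alibi_bias_sketch(mask, n_head=4).shape)  # torch.Size([8, 1, 5])

Using `(attention_mask.cumsum(-1) - 1) * attention_mask` instead of a plain `arange` is what keeps the bias aligned with the first real token in left-padded batches.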
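
Illustrative note (not part of the patch series above): the shape bookkeeping the attention refactor converges on, summarized from the shape comments in the diffs. The variable names and dummy sizes below are placeholders, not the actual `BloomAttention` code.

import torch

batch_size, seq_length, num_heads, head_dim = 2, 5, 4, 8
hidden_size = num_heads * head_dim
fused_qkv = torch.randn(batch_size, seq_length, 3 * hidden_size)  # stand-in for the fused QKV projection output

# [batch_size, seq_length, 3 x hidden_size] -> [batch_size, seq_length, num_heads, 3 x head_dim]
fused_qkv = fused_qkv.view(batch_size, seq_length, num_heads, 3 * head_dim)
query, key, value = torch.split(fused_qkv, head_dim, dim=-1)  # each [batch_size, seq_length, num_heads, head_dim]

# layouts consumed by bmm/baddbmm
q = query.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)      # [b*h, q_length, head_dim]
k_t = key.permute(0, 2, 3, 1).reshape(batch_size * num_heads, head_dim, seq_length)  # [b*h, head_dim, k_length]
v = value.transpose(1, 2).reshape(batch_size * num_heads, seq_length, head_dim)      # [b*h, k_length, head_dim]

print(q.shape, k_t.shape, v.shape)

Keeping query, key, and value in [batch_size, seq_length, num_heads, head_dim] until just before the matmul is what allows past keys and values to be concatenated along dim=1 in the cached-generation path.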
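
Illustrative note (not part of the patch series above): the ablation comments in `test_simple_generation` compare a fused `torch.baddbmm` against a scaled `torch.bmm` plus bias. In exact arithmetic the two are identical; the fp16 discrepancies discussed in the tests come from reduced-precision GPU reductions, controlled by `torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction`. The sketch below uses random stand-in tensors on CPU/fp32 and sets beta to 1.0, omitting the per-layer beta scaling some intermediate patches experiment with.

import math
import torch

batch_size, num_heads, q_length, k_length, head_dim = 2, 4, 5, 5, 8
q = torch.randn(batch_size * num_heads, q_length, head_dim)
k_t = torch.randn(batch_size * num_heads, head_dim, k_length)
alibi = torch.randn(batch_size * num_heads, 1, k_length)  # broadcasts over q_length

inv_norm_factor = 1.0 / math.sqrt(head_dim)

# fused form: beta * alibi + alpha * (q @ k_t)
scores_baddbmm = torch.baddbmm(alibi, q, k_t, beta=1.0, alpha=inv_norm_factor)
# unfused form: scaled bmm plus the bias
scores_bmm = inv_norm_factor * torch.bmm(q, k_t) + alibi

# in fp32 on CPU the two should agree to within rounding error
print(torch.allclose(scores_baddbmm, scores_bmm, atol=1e-6))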