diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py
index f7929dee8afa..732636193670 100644
--- a/src/transformers/models/bloom/configuration_bloom.py
+++ b/src/transformers/models/bloom/configuration_bloom.py
@@ -72,9 +72,6 @@ class BloomConfig(PretrainedConfig):
             If set to `True`, it will skip bias add for each linear layer in the transformer blocks
         skip_bias_add_qkv (`bool`, *optional*, defaults to `False`):
             If set to `True`, it will skip bias add for the first linear layer in the transformer blocks
-        attention_softmax_in_fp32 (`bool`, *optional*, defaults to `True`):
-            If set to `True` and the `dtype` is set to `float16` it will scale the input of the Softmax function to
-            `fp32`
         hidden_dropout (`float`, *optional*, defaults to 0.1):
             Dropout rate of the dropout function on the bias dropout.
         attention_dropout (`float`, *optional*, defaults to 0.1):
@@ -128,7 +125,6 @@ def __init__(
         hidden_size=64,
         n_layer=2,
         n_head=8,
-        masked_softmax_fusion=True,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=False,
@@ -137,7 +133,6 @@ def __init__(
         apply_residual_connection_post_layernorm=False,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        attention_softmax_in_fp32=True,
         pretraining_tp=1,  # TP rank used when training with megatron
         dtype="bfloat16",
         slow_but_exact=False,
@@ -147,7 +142,6 @@ def __init__(
         self.hidden_size = hidden_size
         self.n_layer = n_layer
         self.n_head = n_head
-        self.masked_softmax_fusion = masked_softmax_fusion
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -155,7 +149,6 @@ def __init__(
         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
-        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 315ea90627c2..9bf200de5bd3 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -51,49 +51,38 @@
 ]
 
 
-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
-    """Split a tensor along its last dimension.
+def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    batch_size, target_length = input_ids_shape
+    mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
+    mask_cond = torch.arange(mask.size(-1))
+    intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
+    mask.masked_fill_(intermediate_mask, 0)
+    mask = mask.to(dtype)
 
-    Args:
-        tensor: ([`torch.tensor`], *required*):
-            input tensor to split
-        num_partitions ([`int`], *required*):
-            number of partitions to split the tensor
-        contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
-            If True, make each chunk contiguous in memory.
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
+    expanded_mask = mask[None, None, :, :].expand(
+        batch_size, 1, target_length, target_length + past_key_values_length
+    )
+    return expanded_mask
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
     """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    numerator, denominator = tensor.size()[last_dim], num_partitions
-    if not (numerator % denominator == 0):
-        raise ValueError(f"{numerator} is not divisible by {denominator}")
-    last_dim_size = numerator // denominator
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
-
-
-def attention_mask_func(attention_scores, attention_mask, causal_mask):
-    attention_mask_bool = ~attention_mask.bool()
-
-    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
-    padded_causal_mask = torch.logical_or(
-        attention_mask_bool[:, None, key_length - query_length : key_length, None],
-        ~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
-    )
-    padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
-    # Make use of floats
-    return (
-        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
-        padded_causal_mask,
-    )
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    batch_size, source_length = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else source_length
+
+    expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
+    inverted_mask = 1.0 - expanded_mask
 
-def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+
+
+def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -101,73 +90,41 @@ def build_alibi_tensor(max_seq_len, n_head, device, dtype=torch.bfloat16):
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
 
     Args:
-    Returns tensor shaped (n_head, 1, max_seq_len)
-        max_seq_len: (`int`, *required*):
-            max sequence length
-        n_head: (`int`, *required*):
+    Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        n_head (`int`, *required*):
             number of heads
-        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
+        device (`torch.device`, *optional*, default=`torch.device('cpu')`):
+            device of the output alibi tensor
     """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
 
-    def get_slopes(n):
-        def get_slopes_power_of_2(n):
-            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio**i for i in range(n)]
-
-        if math.log2(n).is_integer():
-            return get_slopes_power_of_2(n)
-        else:
-            closest_power_of_2 = 2 ** math.floor(math.log2(n))
-            return (
-                get_slopes_power_of_2(closest_power_of_2)
-                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-            )
-
-    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
-    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
-    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-
-    alibi = alibi.to(device=device, dtype=dtype)
-
-    return alibi
-
+    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    # batch_size = 1, n_head = n_head, query_length
 
-def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
-    """
-    Args:
-    Pre-process the alibi tensor for padding.
-        alibi: ([`torch.tensor`], *required*):
-            alibi tensor to pre-process
-        attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process"""
-
-    # Sanity check if we are not inferring less tokens than the total sequence length
-    # This usually happens when the inference is done with past_key_values
-    # In this case we re-create the alibi tensor with the correct sequence length
-    if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.device, alibi.dtype).repeat(
-            attention_mask.shape[0], 1, 1
-        )
-    # Get the indexes of the padding tokens
-    index_x0, index_y0 = torch.where(attention_mask == 0.0)
-    index_x1, index_y1 = torch.where(attention_mask == 1.0)
-
-    # Clone the embeddings - we can detach because the embeddings are not learned
-    # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.device, alibi.dtype)
-
-    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
-    # Only where you do not have padding. Replace padding tokens by zeros
-    # This operation can be seen as a shifting operation.
-    for i, index in enumerate(torch.unique(index_x0)):
-        slice_to_modify = torch.zeros_like(slice_reference_alibi)
-        index_shift = index_y1[index_x1 == index]
-        shift_value = len(index_shift)
-        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
-        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
-    return alibi
+    arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
+    alibi = slopes.unsqueeze(-1) * arange_tensor
+    alibi = alibi * attention_mask[:, None]
+    return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)
 
 
 def dropout_add(x, residual, prob, training):
@@ -252,58 +209,6 @@ def forward(self, x):
         return bloom_gelu_forward(x)
 
 
-class BloomScaledSoftmax(nn.Module):
-    """
-    fused operation: scaling + mask + softmax
-
-    Args:
-        input_in_fp16 (`bool`, *required*):
-            flag to indicate if input in fp16 data format.
-        input_in_bf16 (`bool`, *required*):
-            flag to indicate if input in bf16 data format.
-        scaled_masked_softmax_fusion (`bool`, *required*):
-            flag to indicate user want to use softmax fusion
-        mask_func (`function`, *required*):
-            mask function to be applied.
-        softmax_in_fp32 (`bool`, *required*):
-            if true, softmax in performed at fp32 precision.
-        scale (`float`, *required*):
-            scaling factor used in input tensor scaling.
-    """
-
-    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
-        super().__init__()
-        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
-        self.mask_func = mask_func
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-
-        if not (self.scale is None or softmax_in_fp32):
-            raise ValueError("softmax should be in fp32 when scaled")
-
-    def forward(self, input, mask, max_positions):
-        input_dtype = input.dtype
-        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
-        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
-
-        if self.scale is not None:
-            input = input * self.scale
-
-        if mask is None:
-            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
-
-        mask = mask.to(input.device)
-        seq_ids = torch.arange(max_positions, device=input.device)
-        causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
-        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-        probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-
-        if input_in_16bit and self.softmax_in_fp32:
-            probs = probs.to(dtype=input_dtype)
-
-        return probs
-
-
 class BloomAttention(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
@@ -315,8 +220,6 @@ def __init__(self, config, layer_number=None):
         self.num_heads = config.n_head
         self.head_dim = self.hidden_size // self.num_heads
         self.split_size = self.hidden_size
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = config.masked_softmax_fusion
         self.hidden_dropout = config.hidden_dropout
 
         if self.head_dim * self.num_heads != self.hidden_size:
@@ -329,18 +232,35 @@ def __init__(self, config, layer_number=None):
         self.layer_number = max(1, layer_number)
         self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
 
-        # Scaled Softmax
-        self.scale_mask_softmax = BloomScaledSoftmax(
-            self.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            self.layer_number,
-        )
-
         self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
         self.dense = nn.Linear(self.hidden_size, self.hidden_size)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
 
+    def _split_heads(self, fused_qkv):
+        """
+        Split the last dimension into (num_heads, head_dim)
+        """
+        new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+        # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
+        # fused_qkv = fused_qkv.transpose(1, 0)
+        fused_qkv = fused_qkv.reshape(*new_tensor_shape)
+        # fused_qkv = fused_qkv.permute(0, 2, 1, 3)
+        return torch.split(fused_qkv, self.head_dim, -1)
+
+    def _merge_heads(self, x):
+        # What we want to achieve is:
+        # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
+
+        # First view to decompose the batch size
+        # batch_size * num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
+        x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
+
+        # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
+        x = x.permute(0, 2, 1, 3)
+
+        # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
+        return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
+
     def forward(
         self,
         hidden_states,
@@ -352,25 +272,15 @@ def forward(
         use_cache=False,
         output_attentions=False,
     ):
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # repeat alibi tensor with the batch size
-        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
-
-        # apply preprocessing if the input is padded
-        if attention_mask is not None and 0 in attention_mask:
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
-
-        mixed_x_layer = self.query_key_value(hidden_states)
-
-        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
-        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+        alibi = alibi.to(hidden_states.device)  # so that the model can run under accelerate
+        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
 
-        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
+        # 3 x [batch_size, seq_length, num_heads, head_dim]
+        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
         if layer_past is not None:
             past_key, past_value = layer_past
+            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
             value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
 
@@ -379,66 +289,39 @@ def forward(
         else:
             present = None
 
-        # [batch_size, head_dim, q_length, k_length]
-        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
-
-        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
-        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
-
-        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
-        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
-
-        # slice alibi tensor until the query length
-        sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
-
-        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         beta = 1.0 / self.layer_number
-        matmul_result = torch.baddbmm(
-            sliced_alibi,
-            query_layer.transpose(1, 0),
-            key_layer.transpose(1, 0).transpose(1, 2),
-            beta=beta,
-            alpha=(1.0 / self.norm_factor),
-        )
+        # [batch_size*num_heads, q_length, head_dim] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
+        matmul_result = (1.0 / self.norm_factor) * torch.bmm(
+            query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
+            key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
+        ) + beta * alibi
 
         # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(*output_size)
-
-        # attention scores and attention mask [b, np, sq, sk]
-        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
-            value_layer.dtype
-        )
+        attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))
+
+        # We replace the scaled softmax with just a few lines of code - [batch_size, num_heads, q_length, k_length]
+        input_dtype = attention_scores.dtype
+        attn_weights = (attention_scores * self.layer_number) + attention_mask
+        attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+        attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+        attention_probs = attention_probs * (~attention_mask.bool())
+        # [batch_size, num_heads, q_length, k_length]
         attention_probs = self.attention_dropout(attention_probs)
 
         if head_mask is not None:
             attention_probs = attention_probs * head_mask
 
-        # context layer shape: [batch_size, num_heads, q_length, head_dim]
-        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
-
-        # change view [k_length, batch_size x num_heads, head_dim]
-        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
-        # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
+        attention_probs_reshaped = attention_probs.view(*matmul_result.shape)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
+        context_layer = torch.bmm(
+            attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
+        )
 
         # change view [batch_size, num_heads, q_length, head_dim]
-        context_layer = context_layer.view(*output_size)
-
-        # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
-
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [q_length, batch_size, hidden_size]
+        context_layer = self._merge_heads(context_layer)
 
         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
         if self.pretraining_tp > 1 and self.slow_but_exact:
@@ -452,11 +335,9 @@ def forward(
         else:
             output_tensor = self.dense(context_layer)
 
-        output = output_tensor.transpose(1, 0)
-
-        output = dropout_add(output, residual, self.hidden_dropout, self.training)
+        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
 
-        outputs = (output, present)
+        outputs = (output_tensor, present)
         if output_attentions:
             outputs += (attention_probs,)
 
@@ -703,6 +584,24 @@ def __init__(self, config):
     def get_input_embeddings(self):
         return self.word_embeddings
 
+    def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
+            ).to(attention_mask.device)
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
 
@@ -765,9 +664,19 @@ def forward(
 
         # Compute alibi tensor: check build_alibi_tensor documentation
         current_sequence_length = hidden_states.shape[1]
+        past_key_values_length = 0
         if past_key_values[0] is not None:
-            current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.device, hidden_states.dtype)
+            past_key_values_length = past_key_values[0][0].shape[1]
+            current_sequence_length += past_key_values_length
+
+        if attention_mask is None:
+            attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
+        else:
+            attention_mask = attention_mask.to(hidden_states.device)
+
+        alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
+
+        causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
@@ -793,14 +702,14 @@ def custom_forward(*inputs):
                     create_custom_forward(block),
                     hidden_states,
                     None,
-                    attention_mask,
+                    causal_mask,
                    head_mask[i],
                 )
             else:
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
-                    attention_mask=attention_mask,
+                    attention_mask=causal_mask,
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
@@ -877,7 +786,6 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
             "attention_mask": attention_mask,
         }
 
diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py
index 0b2501982d52..ffe3247f3dd7 100644
--- a/tests/models/bloom/test_modeling_bloom.py
+++ b/tests/models/bloom/test_modeling_bloom.py
@@ -377,15 +377,34 @@ def test_model_from_pretrained(self):
     @slow
     @require_torch_gpu
     def test_simple_generation(self):
+        # This test is a bit flaky. On some GPU architectures, pytorch sets allow_fp16_reduced_precision_reduction = True by default, and some operations
+        # do not give the same results under this configuration, especially torch.baddbmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200
+        # We set allow_fp16_reduced_precision_reduction = True. Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms
+        # This discrepancy is observed only when using small models; outputs seem to be stable for larger models.
+        # Our conclusion is that these operations are flaky for small inputs but seem to be stable for larger inputs (for the functions `baddbmm` and `bmm`), and therefore for larger models.
+
+        # Here is a summary of an ablation study of our observations:
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a"
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddbmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS
+        # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL
+
+        # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love"
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddbmm ==> PASS (for use_cache=True and use_cache=False)
+        # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS
+        # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS
+
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m)
 
         input_sentence = "I enjoy walking with my cute dog"
+        # This output has been obtained using the fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU
         EXPECTED_OUTPUT = (
-            "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am"
-            " a very good listener. I am a very good person, and I am a very good person. I am a"
+            "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very "
+            "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I"
         )
 
         input_ids = tokenizer.encode(input_sentence, return_tensors="pt")
@@ -397,7 +416,7 @@ def test_batch_generation(self):
 
     @require_torch_gpu
     def test_batch_generation(self):
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
 
@@ -416,8 +435,9 @@ def test_batch_generation(self):
     @slow
     @require_torch_gpu
     def test_batch_generation_padd(self):
+
         path_350m = "bigscience/bloom-350m"
-        model = BloomForCausalLM.from_pretrained(path_350m, torch_dtype="auto", use_cache=True).cuda()
+        model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True).cuda()
         model = model.eval()
         tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left")
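
For illustration, a minimal sketch of how the new mask helpers and build_alibi_tensor fit together, assuming the patched modeling_bloom module is importable; the function names come from the diff above, while the toy shapes and padding pattern are invented for the example:

import torch

from transformers.models.bloom.modeling_bloom import (
    _expand_mask,
    _make_causal_mask,
    build_alibi_tensor,
)

batch_size, seq_len, n_head, past_len = 2, 5, 8, 3
dtype = torch.float32

# Padding mask over past + current tokens; the first sequence has two padded positions.
attention_mask = torch.ones(batch_size, past_len + seq_len)
attention_mask[0, :2] = 0

# Causal part, [batch_size, 1, seq_len, past_len + seq_len]: 0 where attention is
# allowed, finfo.min where it is masked (mirrors what _prepare_attn_mask does).
causal_mask = _make_causal_mask(torch.Size((batch_size, seq_len)), dtype, past_key_values_length=past_len)
padding_mask = _expand_mask(attention_mask, dtype, tgt_len=seq_len)
combined_mask = causal_mask + padding_mask
print(combined_mask.shape)  # torch.Size([2, 1, 5, 8])

# ALiBi bias, [batch_size * n_head, 1, past_len + seq_len]; positions are counted with a
# cumsum over the mask, so left-padded positions contribute zero bias.
alibi = build_alibi_tensor(attention_mask, n_head, dtype, torch.device("cpu"))
print(alibi.shape)  # torch.Size([16, 1, 8])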
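
The vectorized slope computation in the new build_alibi_tensor is intended to reproduce the removed recursive get_slopes exactly: base^1 .. base^c for the closest power of two c <= n_head, plus every other slope of the 2c-head schedule for the remaining heads. A self-contained check of that equivalence, with both variants re-inlined here under hypothetical names:

import math

import torch


def get_slopes_reference(n):
    # The recursive implementation removed by this diff.
    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (
        get_slopes_power_of_2(closest_power_of_2)
        + get_slopes_reference(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
    )


def get_slopes_vectorized(n_head):
    # The tensorized computation introduced by this diff (device/dtype arguments dropped).
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))))
    slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2))
    if closest_power_of_2 != n_head:
        extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))))
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


for n_head in (8, 12, 16, 112):  # includes non-power-of-two head counts
    reference = torch.tensor(get_slopes_reference(n_head))
    assert torch.allclose(reference, get_slopes_vectorized(n_head)), n_head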
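
The new _split_heads/_merge_heads pair replaces the old transpose-heavy reshaping. A standalone round-trip sketch, with the two methods re-inlined as free functions over arbitrary toy sizes:

import torch

num_heads, head_dim = 4, 16
batch_size, seq_len = 2, 5


def split_heads(fused_qkv):
    # [batch_size, seq_len, 3 * hidden] -> 3 x [batch_size, seq_len, num_heads, head_dim]
    new_shape = fused_qkv.size()[:-1] + (num_heads, 3 * head_dim)
    return torch.split(fused_qkv.reshape(*new_shape), head_dim, -1)


def merge_heads(x):
    # [batch_size * num_heads, seq_len, head_dim] -> [batch_size, seq_len, num_heads * head_dim]
    x = x.view(x.size(0) // num_heads, num_heads, x.size(1), head_dim)
    x = x.permute(0, 2, 1, 3)
    return x.reshape(x.size(0), x.size(1), num_heads * head_dim)


fused_qkv = torch.randn(batch_size, seq_len, 3 * num_heads * head_dim)
query_layer, key_layer, value_layer = split_heads(fused_qkv)
print(query_layer.shape)  # torch.Size([2, 5, 4, 16])

# The attention path flattens value to [batch_size * num_heads, seq_len, head_dim]
# before the bmm with the attention probabilities, then merges back.
flat_value = value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
merged = merge_heads(flat_value)
assert torch.equal(merged, value_layer.reshape(batch_size, seq_len, -1))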
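
One subtle point in the new attention path: matmul_result is scaled by 1 / norm_factor = 1 / (sqrt(head_dim) * layer_number) while alibi is added with beta = 1 / layer_number, so multiplying attention_scores by self.layer_number before adding the attention mask restores the standard 1 / sqrt(head_dim) scaling with the alibi bias at full strength. A small numeric check of that identity (toy shapes, names ours):

import math

import torch

head_dim, layer_number = 16, 3
norm_factor = math.sqrt(head_dim) * layer_number
beta = 1.0 / layer_number

qk = torch.randn(2, 4, 4)     # raw query-key products
alibi = torch.randn(2, 4, 4)  # alibi bias, same toy shape

matmul_result = (1.0 / norm_factor) * qk + beta * alibi
attn_weights = matmul_result * layer_number

# layer_number cancels: standard 1/sqrt(head_dim) scaling plus full-strength alibi.
assert torch.allclose(attn_weights, qk / math.sqrt(head_dim) + alibi, atol=1e-6)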
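
Finally, the ablation comments in test_simple_generation refer to PyTorch's reduced-precision reductions in fp16 GEMMs. The flag below is standard PyTorch (available since 1.12); this is only a sketch of how one would pin it down when reproducing the ablation:

import torch

# Force fp32 accumulation inside fp16 GEMMs (torch.bmm / torch.baddbmm), trading
# speed for run-to-run stability of small-model generations on some GPUs.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False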