diff --git a/src/transformers/models/bloom/configuration_bloom.py b/src/transformers/models/bloom/configuration_bloom.py index 23ecc6d92671..20584eeb96b0 100644 --- a/src/transformers/models/bloom/configuration_bloom.py +++ b/src/transformers/models/bloom/configuration_bloom.py @@ -91,7 +91,9 @@ class BloomConfig(PretrainedConfig): issue](https://github.com/pytorch/pytorch/issues/76232). A solution to obtain more accurate results is to enable this feature. Enabling this will hurt the computational time of the inference. Will be probably resolved in the future once the main model has been fine-tuned with TP_rank=1. - + force_lm_head_in_fp32 (`bool` defaults to `True`): + Casts `lm_head` in fp32 in order to increase the chances that obtained logits are totally ordered, ie with + no values that is equal to another. Example: ```python @@ -130,6 +132,7 @@ def __init__( attention_dropout=0.0, pretraining_tp=1, # TP rank used when training with megatron slow_but_exact=False, + force_lm_head_in_fp32=True, **kwargs, ): self.vocab_size = vocab_size @@ -149,6 +152,7 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.slow_but_exact = slow_but_exact + self.force_lm_head_in_fp32 = force_lm_head_in_fp32 super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 357d959c162c..a85b00c8d105 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -15,6 +15,7 @@ """PyTorch BLOOM model.""" import math +import warnings from typing import Tuple, Union import torch @@ -51,43 +52,41 @@ ] -def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): +def _make_causal_mask(input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int = 0): """ - Make causal mask used for bi-directional self-attention. + Make causal mask used for self-attention. """ batch_size, target_length = input_ids_shape - mask = torch.full((target_length, target_length), torch.finfo(dtype).min) - mask_cond = torch.arange(mask.size(-1)) - intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1) - mask.masked_fill_(intermediate_mask, 0) - mask = mask.to(dtype) + + mask = torch.ones((target_length, target_length), dtype=torch.bool, device=device) + mask.triu_(diagonal=1) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1) + past_key_values_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device) + mask = torch.cat([past_key_values_mask, mask], dim=-1) + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) return expanded_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None): +def _expand_mask(mask: torch.Tensor, tgt_len: int = None): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ - batch_size, source_length = mask.size() + batch_size, source_length = mask.shape tgt_len = tgt_len if tgt_len is not None else source_length - expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + expanded_mask = mask[:, None, None, :].to(torch.bool).expand(batch_size, 1, tgt_len, source_length) + return ~expanded_mask -def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor: +def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype: torch.dtype) -> torch.Tensor: """ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value `softmax(l+a) = softmax(l)`. Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly. Args: Returns tensor shaped (batch_size * n_head, 1, max_seq_len) @@ -101,16 +100,18 @@ def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) device of the output alibi tensor """ closest_power_of_2 = 2 ** math.floor(math.log2(n_head)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32) - powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != n_head: extra_base = torch.tensor( - 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32 + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 ) num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) # Note: alibi will added to the attention bias that will be applied to the query, key product of attention @@ -120,10 +121,8 @@ def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) # This is more or less identical to T5's relative position bias: # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527 # batch_size = 1, n_head = n_head, query_length - - arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None] - alibi = slopes.unsqueeze(-1) * arange_tensor - alibi = alibi * attention_mask[:, None] + arange_tensor = ((attention_mask.cumsum(-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype) @@ -210,7 +209,7 @@ def forward(self, x): class BloomAttention(nn.Module): - def __init__(self, config, layer_number=None): + def __init__(self, config, layer_number): super().__init__() self.pretraining_tp = config.pretraining_tp @@ -230,7 +229,8 @@ def __init__(self, config, layer_number=None): # Layer-wise attention scaling self.layer_number = max(1, layer_number) - self.norm_factor = math.sqrt(self.head_dim) * self.layer_number + self.inv_norm_factor = 1.0 / (math.sqrt(self.head_dim) * self.layer_number) + self.beta = 1.0 / self.layer_number self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True) self.dense = nn.Linear(self.hidden_size, self.hidden_size) @@ -240,34 +240,32 @@ def _split_heads(self, fused_qkv): """ Split the last dimension into (num_heads, head_dim) """ - new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) - # new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1)) - # fused_qkv = fused_qkv.transpose(1, 0) - fused_qkv = fused_qkv.reshape(*new_tensor_shape) - # fused_qkv = fused_qkv.permute(0, 2, 1, 3) + fused_qkv = fused_qkv.reshape(*fused_qkv.size()[:-1], self.num_heads, 3 * self.head_dim) return torch.split(fused_qkv, self.head_dim, -1) def _merge_heads(self, x): # What we want to achieve is: # batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim + batch_size_and_num_heads, seq_len, _ = x.shape + batch_size = batch_size_and_num_heads // self.num_heads # First view to decompose the batch size # batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim - x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim) + x = x.view(batch_size, self.num_heads, seq_len, self.head_dim) # batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim x = x.permute(0, 2, 1, 3) # batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim - return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim) + return x.reshape(batch_size, seq_len, self.num_heads * self.head_dim) def forward( self, hidden_states, residual, + alibi, + attention_mask, layer_past=None, - attention_mask=None, - alibi=None, head_mask=None, use_cache=False, output_attentions=False, @@ -280,7 +278,7 @@ def forward( if layer_past is not None: past_key, past_value = layer_past - # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim] + # concatenate along seq_length dimension -> [batch_size, kv_length, num_heads, head_dim] key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1) @@ -289,23 +287,29 @@ def forward( else: present = None - beta = 1.0 / self.layer_number + batch_size, q_length, _, _ = query_layer.shape + _, kv_length, _, _ = key_layer.shape # # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length] - matmul_result = (1.0 / self.norm_factor) * torch.bmm( - query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]), - key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]), - ) + beta * alibi + matmul_result = torch.baddbmm( + input=alibi, + batch1=query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim), + batch2=key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, kv_length), + beta=self.beta, + alpha=self.inv_norm_factor, + ) # change view to [batch_size, num_heads, q_length, k_length] - attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2)) + attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) - # We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length] + # we cast attention scores to fp32 compute scaled softmax and cast back into initial dtype - [batch_size, num_heads, q_length, k_length] input_dtype = attention_scores.dtype - attn_weights = (attention_scores * self.layer_number) + attention_mask - attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + attention_scores = attention_scores.float() + attn_weights = torch.masked_fill( + attention_scores * self.layer_number, attention_mask, torch.finfo(torch.float32).min + ) attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) - attention_probs = attention_probs * (~attention_mask.bool()) + # [batch_size, num_heads, q_length, k_length] attention_probs = self.attention_dropout(attention_probs) @@ -313,11 +317,12 @@ def forward( attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, k_length] - attention_probs_reshaped = attention_probs.view(*matmul_result.shape) + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) # matmul: [batch_size * num_heads, q_length, head_dim] context_layer = torch.bmm( - attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)) + attention_probs_reshaped, + value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, kv_length, self.head_dim), ) # change view [batch_size, num_heads, q_length, head_dim] @@ -376,7 +381,7 @@ def forward(self, hidden_states, residual): class BloomBlock(nn.Module): - def __init__(self, config, layer_number=None): + def __init__(self, config, layer_number): super().__init__() hidden_size = config.hidden_size @@ -393,12 +398,12 @@ def __init__(self, config, layer_number=None): def forward( self, hidden_states, + alibi, + attention_mask, layer_past=None, - attention_mask=None, head_mask=None, use_cache=False, output_attentions=False, - alibi=None, ): # hidden_states: [batch_size, seq_length, hidden_size] @@ -522,11 +527,6 @@ def _set_gradient_checkpointing(self, module, value=False): - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.max_position_embeddings - 1]`. - - [What are position IDs?](../glossary#position-ids) head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: @@ -567,7 +567,6 @@ def __init__(self, config): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) - self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks @@ -588,16 +587,18 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None + device = attention_mask.device + if input_shape[-1] > 1: combined_attention_mask = _make_causal_mask( - input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length - ).to(attention_mask.device) + input_shape, device=device, past_key_values_length=past_key_values_length + ) if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) + expanded_attn_mask = _expand_mask(attention_mask, tgt_len=input_shape[-1]) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask ) return combined_attention_mask @@ -625,6 +626,12 @@ def forward( output_hidden_states=None, return_dict=None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + if position_ids is not None: + warnings.warn( + "`position_ids` is deprecated and will be removed in v5.0.0", + FutureWarning, + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -637,6 +644,7 @@ def forward( elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) + input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: @@ -652,7 +660,7 @@ def forward( head_mask = self.get_head_mask(head_mask, self.config.n_layer) if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + inputs_embeds = self.word_embeddings(input_ids).to(self.word_embeddings_layernorm.weight.dtype) hidden_states = self.word_embeddings_layernorm(inputs_embeds) @@ -672,9 +680,11 @@ def forward( if attention_mask is None: attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device) else: + attention_mask_shape = attention_mask.size() + attention_mask = attention_mask.view(-1, attention_mask_shape[-1]) attention_mask = attention_mask.to(hidden_states.device) - alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device) + alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype) causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length) @@ -694,14 +704,14 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, use_cache, output_attentions, alibi) + return module(*inputs, use_cache=use_cache, output_attentions=output_attentions) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, - None, + alibi, causal_mask, head_mask[i], ) @@ -760,6 +770,12 @@ def __init__(self, config): # Initialize weights and apply final processing self.post_init() + def tie_weights(self): + super(BloomForCausalLM, self).tie_weights() + + if self.config.force_lm_head_in_fp32: + self.lm_head.to(torch.float32) + def get_output_embeddings(self): return self.lm_head @@ -772,16 +788,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): input_ids = input_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past: - position_ids = position_ids[:, -1].unsqueeze(-1) - else: - position_ids = None + return { "input_ids": input_ids, "past_key_values": past, @@ -816,13 +823,18 @@ def forward( `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` """ + if position_ids is not None: + warnings.warn( + "`position_ids` is deprecated and will be removed in v5.0.0", + FutureWarning, + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -832,6 +844,8 @@ def forward( ) hidden_states = transformer_outputs[0] + if self.config.force_lm_head_in_fp32: + hidden_states = hidden_states.to(self.lm_head.weight.dtype) lm_logits = self.lm_head(hidden_states) loss = None @@ -922,6 +936,11 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ + if position_ids is not None: + warnings.warn( + "`position_ids` is deprecated and will be removed in v5.0.0", + FutureWarning, + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -929,7 +948,6 @@ def forward( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -1051,6 +1069,11 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ + if position_ids is not None: + warnings.warn( + "`position_ids` is deprecated and will be removed in v5.0.0", + FutureWarning, + ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -1058,7 +1081,6 @@ def forward( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, - position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/tests/models/bloom/test_modeling_bloom.py b/tests/models/bloom/test_modeling_bloom.py index b0307b922c25..5cf378e658d5 100644 --- a/tests/models/bloom/test_modeling_bloom.py +++ b/tests/models/bloom/test_modeling_bloom.py @@ -96,7 +96,7 @@ def __init__( def get_large_model_config(self): return BloomConfig.from_pretrained("bigscience/bloom") - def prepare_config_and_inputs(self, gradient_checkpointing=False): + def prepare_config_and_inputs(self, gradient_checkpointing=False, force_lm_head_in_fp32=True): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) input_mask = None @@ -107,11 +107,13 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): if self.use_labels: sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - config = self.get_config(gradient_checkpointing=gradient_checkpointing) + config = self.get_config( + gradient_checkpointing=gradient_checkpointing, force_lm_head_in_fp32=force_lm_head_in_fp32 + ) return (config, input_ids, input_mask, sequence_labels) - def get_config(self, gradient_checkpointing=False, slow_but_exact=True): + def get_config(self, gradient_checkpointing=False, force_lm_head_in_fp32=True, slow_but_exact=True): return BloomConfig( vocab_size=self.vocab_size, seq_length=self.seq_length, @@ -130,7 +132,7 @@ def get_config(self, gradient_checkpointing=False, slow_but_exact=True): num_labels=self.num_labels, gradient_checkpointing=gradient_checkpointing, slow_but_exact=slow_but_exact, - dtype="float32", + force_lm_head_in_fp32=force_lm_head_in_fp32, ) def create_and_check_bloom_model(self, config, input_ids, input_mask, *args): @@ -138,7 +140,8 @@ def create_and_check_bloom_model(self, config, input_ids, input_mask, *args): model.to(torch_device) model.eval() - result = model(input_ids) + with torch.no_grad(): + result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(len(result.past_key_values), config.n_layer) @@ -150,9 +153,10 @@ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args model.eval() # first forward pass - outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True) - outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids)) - outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids)) + with torch.no_grad(): + outputs = model(input_ids, attention_mask=torch.ones_like(input_ids), use_cache=True) + outputs_use_cache_conf = model(input_ids, attention_mask=torch.ones_like(input_ids)) + outputs_no_past = model(input_ids, use_cache=False, attention_mask=torch.ones_like(input_ids)) self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) @@ -165,8 +169,9 @@ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args # append to next input_ids and token_type_ids next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) - output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"] + with torch.no_grad(): + output_from_no_past = model(next_input_ids)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -174,7 +179,7 @@ def create_and_check_bloom_model_past(self, config, input_ids, input_mask, *args output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice)) def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, input_mask, *args): model = BloomModel(config=config) @@ -187,7 +192,8 @@ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, in attn_mask[:, half_seq_length:] = 0 # first forward pass - output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + with torch.no_grad(): + output, past = model(input_ids, attention_mask=attn_mask).to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -205,8 +211,9 @@ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, in ) # get two different outputs - output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"] + with torch.no_grad(): + output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past, attention_mask=attn_mask)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -214,7 +221,7 @@ def create_and_check_bloom_model_attention_mask_past(self, config, input_ids, in output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach() # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice)) def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, input_mask, *args): model = BloomModel(config=config) @@ -222,7 +229,8 @@ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, inpu model.eval() # first forward pass - outputs = model(input_ids, attention_mask=input_mask, use_cache=True) + with torch.no_grad(): + outputs = model(input_ids, attention_mask=input_mask, use_cache=True) output, past = outputs.to_tuple() @@ -234,10 +242,11 @@ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, inpu next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) - output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[ - "last_hidden_state" - ] + with torch.no_grad(): + output_from_no_past = model(next_input_ids, attention_mask=next_attention_mask)["last_hidden_state"] + output_from_past = model(next_tokens, attention_mask=next_attention_mask, past_key_values=past)[ + "last_hidden_state" + ] self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) # select random slice @@ -246,14 +255,16 @@ def create_and_check_bloom_model_past_large_inputs(self, config, input_ids, inpu output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() # test that outputs are equal for slice - self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice)) def create_and_check_lm_head_model(self, config, input_ids, input_mask, *args): model = BloomForCausalLM(config) model.to(torch_device) model.eval() - result = model(input_ids, labels=input_ids) + with torch.no_grad(): + result = model(input_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) @@ -263,7 +274,9 @@ def create_and_check_sequence_classification_model(self, config, input_ids, inpu model.to(torch_device) model.eval() - result = model(input_ids, attention_mask=input_mask) + with torch.no_grad(): + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) def create_and_check_token_classification_model(self, config, input_ids, input_mask, *args): @@ -271,7 +284,9 @@ def create_and_check_token_classification_model(self, config, input_ids, input_m model.to(torch_device) model.eval() - result = model(input_ids, attention_mask=input_mask) + with torch.no_grad(): + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels)) def create_and_check_forward_and_backwards( @@ -368,6 +383,47 @@ def test_bloom_weight_initialization(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_bloom_weight_initialization(*config_and_inputs) + @require_torch_gpu + def test_force_lm_head_in_fp32_is_close_to_fp16(self): + model_name = "bigscience/bigscience-small-testing" + + _, input_ids, input_mask, _ = self.model_tester.prepare_config_and_inputs() + model = BloomForCausalLM.from_pretrained(model_name, force_lm_head_in_fp32=True, torch_dtype=torch.float16).to( + torch_device + ) + model_in_fp16 = BloomForCausalLM.from_pretrained( + model_name, force_lm_head_in_fp32=False, torch_dtype=torch.float16 + ).to(torch_device) + + # Test that the model have the correct precisions + for key, value in model.state_dict().items(): + if key in ["transformer.word_embeddings.weight", "lm_head.weight"]: + self.assertEqual(value.dtype, torch.float32) + else: + self.assertEqual(value.dtype, torch.float16) + for value in model_in_fp16.state_dict().values(): + self.assertEqual(value.dtype, torch.float16) + + model.eval() + model_in_fp16.eval() + + with torch.no_grad(): + output = model(input_ids=input_ids, attention_mask=input_mask).logits + output_in_fp16 = model_in_fp16(input_ids=input_ids, attention_mask=input_mask).logits + + # We guarantee that models in fp16 and fp16 with `force_lm_head_in_fp32=True` are close. + self.assertTrue(torch.allclose(output, output_in_fp16.to(torch.float32), atol=1e-4, rtol=1e-4)) + + # We verify that fp16 have value collapses due to output vocabulary begin higher that maximum range of fp16. + random_batch_id = torch.randint(input_ids.shape[0], ()) + random_sequence_id = torch.randint(input_ids.shape[1], ()) + # Test that we see at least one collapse in fp16 and none when `force_lm_head_in_fp32=True`,ie the `len(unique_values) < vocabulary_size`. + # We could test that it's smaller that fp16 max range as well + self.assertTrue( + len(torch.unique(output_in_fp16[random_batch_id, random_sequence_id])) < output_in_fp16.shape[-1] + ) + self.assertTrue(len(torch.unique(output[random_batch_id, random_sequence_id])) < output_in_fp16.shape[-1]) + @slow def test_model_from_pretrained(self): for model_name in BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -377,47 +433,58 @@ def test_model_from_pretrained(self): @slow @require_torch_gpu def test_simple_generation(self): - # This test is a bit flaky. For some GPU architectures, pytorch sets by default allow_fp16_reduced_precision_reduction = True and some operations - # do not give the same results under this configuration, especially torch.baddmm and torch.bmm. https://pytorch.org/docs/stable/notes/numerical_accuracy.html#fp16-on-mi200 - # As we leave the default value (True) for allow_fp16_reduced_precision_reduction , the tests failed when running in half-precision with smaller models (350m) - # Please see: https://pytorch.org/docs/stable/notes/cuda.html#reduced-precision-reduction-in-fp16-gemms - # This discrepancy is observed only when using small models and seems to be stable for larger models. - # Our conclusion is that these operations are flaky for small inputs but seems to be stable for larger inputs (for the functions `baddmm` and `bmm`), and therefore for larger models. - - # Here is a summary of an ablation study of our observations - # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, and I love to watch the kids play. I am a very active person, and I am a very good listener. I am a very good person, and I am a very good person. I am a" - # 350m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS - # 350m + allow_fp16_reduced_precision_reduction = False + torch.baddm ==> PASS - # 350m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS - # 350m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> FAIL - - # EXPECTED_OUTPUT = "I enjoy walking with my cute dog, but I also enjoy hiking, biking, and swimming. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love to cook and bake. I love" - # >=760m + allow_fp16_reduced_precision_reduction = True + torch.baddm ==> PASS (for use_cache=True and use_cache=False) - # >=760m + allow_fp16_reduced_precision_reduction = True + torch.bmm ==> PASS - # >=760m + allow_fp16_reduced_precision_reduction = False + torch.bmm ==> PASS - path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda() + model = BloomForCausalLM.from_pretrained( + path_350m, torch_dtype="auto", use_cache=True, revision="gs555750" + ).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m) input_sentence = "I enjoy walking with my cute dog" - # This output has been obtained using fp32 model on the huggingface DGX workstation - NVIDIA A100 GPU EXPECTED_OUTPUT = ( - "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very " - "active person, and I enjoy working out, and I am a very active person. I am a very active person, and I" + "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very active" + " person, and I enjoy working out, and I am a very active person. I am a very active person, and I" + ) + + input_ids = tokenizer.encode(input_sentence, return_tensors="pt") + greedy_output = model.generate(input_ids.cuda(), max_length=50) + + self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + + @slow + @require_torch_gpu + def test_simple_generation_match_with_fp32(self): + path_350m = "bigscience/bloom-350m" + model = BloomForCausalLM.from_pretrained( + path_350m, torch_dtype="auto", use_cache=True, revision="gs555750" + ).cuda() + model_fp32 = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda() + model.eval() + model_fp32.eval() + + tokenizer = BloomTokenizerFast.from_pretrained(path_350m) + + input_sentence = "I enjoy walking with my cute dog" + EXPECTED_OUTPUT = ( + "I enjoy walking with my cute dog, and I love to watch the kids play with the kids. I am a very active" + " person, and I enjoy working out, and I am a very active person. I am a very active person, and I" ) input_ids = tokenizer.encode(input_sentence, return_tensors="pt") greedy_output = model.generate(input_ids.cuda(), max_length=50) + greedy_output_in_fp32 = model_fp32.generate(input_ids.cuda(), max_length=50) self.assertEqual(tokenizer.decode(greedy_output[0], skip_special_tokens=True), EXPECTED_OUTPUT) + # We test that fp32 has the same result + self.assertEqual(tokenizer.decode(greedy_output_in_fp32[0], skip_special_tokens=True), EXPECTED_OUTPUT) @slow @require_torch_gpu def test_batch_generation(self): path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda() + model = BloomForCausalLM.from_pretrained( + path_350m, torch_dtype="auto", use_cache=True, revision="gs555750" + ).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left") @@ -438,7 +505,9 @@ def test_batch_generation(self): def test_batch_generation_padd(self): path_350m = "bigscience/bloom-350m" - model = BloomForCausalLM.from_pretrained(path_350m, use_cache=True, revision="gs555750").cuda() + model = BloomForCausalLM.from_pretrained( + path_350m, torch_dtype="auto", use_cache=True, revision="gs555750" + ).cuda() model = model.eval() tokenizer = BloomTokenizerFast.from_pretrained(path_350m, padding_side="left") @@ -478,10 +547,10 @@ class BloomEmbeddingTest(unittest.TestCase): You need to install tokenizers following this readme: - - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles + - https://huggingface.co/bigscience/tokenizer Tokenizer used during training: - - https://huggingface.co/bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles + - https://huggingface.co/bigscience/tokenizer # TODO change the script (or just add skip) when building the env with tokenizers 0.12.0 """ @@ -490,9 +559,10 @@ def setUp(self): super().setUp() self.path_bigscience_model = "bigscience/bigscience-small-testing" - @require_torch def test_embeddings(self): - model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, torch_dtype="auto") # load in fp32 + model = BloomForCausalLM.from_pretrained( + self.path_bigscience_model, torch_dtype="auto", force_lm_head_in_fp32=False + ) # load in bf16 model.eval() EMBEDDINGS_DS_BEFORE_LN_BF_16_MEAN = { @@ -698,7 +768,9 @@ def test_embeddings(self): tensor_ids = torch.LongTensor([EXAMPLE_IDS]) with torch.no_grad(): embeddings = model.transformer.word_embeddings(tensor_ids) - embeddings_ln = model.transformer.word_embeddings_layernorm(embeddings) # + if model.config.force_lm_head_in_fp32: + embeddings = embeddings.to(model.transformer.word_embeddings_layernorm.weight.dtype) + embeddings_ln = model.transformer.word_embeddings_layernorm(embeddings) # first check the embeddings before LN output_dict = {"min": {}, "max": {}, "mean": {}, "sum": {"value": embeddings.sum().item()}} for i, idx in enumerate(EXAMPLE_IDS): @@ -720,12 +792,11 @@ def test_embeddings(self): for j, idx in enumerate(output_dict[key].keys()): self.assertAlmostEqual(EMBEDDINGS_DS_AFTER_LN[key][idx], output_dict_norm[key][idx], places=1) - @require_torch def test_hidden_states_transformers(self): cuda_available = torch.cuda.is_available() - model = BloomModel.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to( - torch_device - ) + model = BloomModel.from_pretrained( + self.path_bigscience_model, use_cache=False, torch_dtype="auto", force_lm_head_in_fp32=False + ).to(torch_device) model.eval() # fmt: off @@ -750,10 +821,11 @@ def test_hidden_states_transformers(self): self.assertDictEqual(MIN_MAX_DICT, output_dict) - @require_torch def test_logits(self): cuda_available = torch.cuda.is_available() - model = BloomForCausalLM.from_pretrained(self.path_bigscience_model, use_cache=False, torch_dtype="auto").to( + model = BloomForCausalLM.from_pretrained( + self.path_bigscience_model, use_cache=False, torch_dtype="auto", force_lm_head_in_fp32=False + ).to( torch_device ) # load in bf16 model.eval()