diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 09bf91999274..b8692f407a90 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -751,15 +751,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
         # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
         # encoder_extended_attention_mask.transpose(-1, -2))
         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-
-        if self.dtype == torch.float16:
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
-        elif self.dtype in [torch.bfloat16, torch.float32]:
-            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
-        else:
-            raise ValueError(
-                f"{self.dtype} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`"
-            )
+        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min

         return encoder_extended_attention_mask

@@ -792,7 +784,7 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device):
         return extended_attention_mask

     def get_extended_attention_mask(
-        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
+        self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.dtype = None
     ) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -806,6 +798,9 @@ def get_extended_attention_mask(
         Returns:
             `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
         """
+        if dtype is None:
+            dtype = self.dtype
+
         if not (attention_mask.dim() == 2 and self.config.is_decoder):
             # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
             if device is not None:
@@ -836,8 +831,8 @@ def get_extended_attention_mask(
         # positions we want to attend and -10000.0 for masked positions.
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min

         return extended_attention_mask

     def get_head_mask(
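The removed branch in `invert_attention_mask` special-cased fp16 (`-1e4`) against bf16/fp32 (`-1e9`) and raised on any other dtype; `torch.finfo(self.dtype).min` collapses the branches because it is, by definition, the most negative finite value the dtype can represent. A minimal standalone sketch (not part of the diff) of why the hard-coded constant was fragile:

import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    hard_coded = torch.tensor(-1e9).to(dtype)  # overflows to -inf in float16
    finfo_min = torch.finfo(dtype).min         # always finite for `dtype`
    print(dtype, hard_coded.item(), finfo_min)

Since `(1.0 - mask) * torch.finfo(dtype).min` yields exactly 0 for attended positions and the dtype's finite minimum for masked ones, no per-dtype constant or dtype check is needed.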
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index ac2051b322f3..e5b717ef9c18 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py index bb7b1492c7bf..3b965bb9f2be 100644 --- a/src/transformers/models/canine/modeling_canine.py +++ b/src/transformers/models/canine/modeling_canine.py @@ -467,7 +467,7 @@ def forward( # Since attention_mask is 1.0 for positions we want to attend and 0.0 for # masked positions, this operation will create a tensor which is 0.0 for # positions we want to attend and -10000.0 for masked positions. - attention_mask = (1.0 - attention_mask.float()) * -10000.0 + attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min # Apply the attention mask (precomputed for all layers in CanineModel forward() function) attention_scores = attention_scores + attention_mask diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 5c34c658b59f..ddc2236371c2 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -638,7 +638,9 @@ def forward( bsz, seq_len = input_shape # CLIP's text model uses causal mask, prepare it here. 
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 - causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) + causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to( + hidden_states.device + ) # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -670,11 +672,11 @@ def forward( attentions=encoder_outputs.attentions, ) - def _build_causal_attention_mask(self, bsz, seq_len): + def _build_causal_attention_mask(self, bsz, seq_len, dtype): # lazily create causal attention mask, with full attention between the vision tokens # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(bsz, seq_len, seq_len) - mask.fill_(torch.tensor(float("-inf"))) + mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype) + mask.fill_(torch.tensor(torch.finfo(dtype).min)) mask.triu_(1) # zero out the lower diagonal mask = mask.unsqueeze(1) # expand mask return mask diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index cec2d0d345b2..c091c201de8a 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -435,7 +435,7 @@ def forward( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.n_layer) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 1ee847f4ea1a..63712863bfc8 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -578,7 +578,8 @@ def forward( hidden_states[~expand_attention_mask] = 0 # extend attention_mask - attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min attention_mask = attention_mask.expand( attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] ) diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 509d7250b719..90c393e5098b 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -188,7 +188,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool) - attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: # Apply the attention mask @@ -239,7 +243,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea # if only "normal" attention layer implements causal mask query_length, key_length = query.size(-2), key.size(-2) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() - attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) if attention_mask is not None: # Apply the attention mask @@ -578,7 +586,7 @@ def forward( # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index d261104ac7ad..ea966b7e8eea 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -1809,7 +1809,7 @@ def forward(self, q, k, mask: Optional[Tensor] = None): weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) if mask is not None: - weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) + weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min) weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size()) weights = self.dropout(weights) return weights diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index a93c08345a0b..fc5b5a7b0f7c 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -211,7 +211,9 @@ def unshape(x: torch.Tensor) -> torch.Tensor: q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head) scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length) mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length) - scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length) + scores = scores.masked_fill( + mask, torch.tensor(torch.finfo(scores.dtype).min) + ) # (bs, n_heads, q_length, k_length) weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length) weights = self.dropout(weights) # (bs, n_heads, q_length, k_length) diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 937b8a712821..d44bc80363d0 
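The same `_make_causal_mask` hunk repeats across the BART-family models above and in the led, m2m_100, marian, mbart, opt, pegasus, plbart, speech_to_text, trocr, xglm, and cookiecutter-template diffs below. Filling with the dtype's finite minimum instead of `float("-inf")` keeps masked logits finite, so a row that ends up fully masked degrades to a uniform softmax instead of the NaNs produced by exponentiating a row of `-inf`. A standalone sketch of the pattern, with names following the helpers above:

import torch

def make_causal_mask(tgt_len: int, dtype: torch.dtype) -> torch.Tensor:
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
    mask_cond = torch.arange(tgt_len)
    # zero out the diagonal and everything below it (positions each token may
    # attend to); the strict upper triangle keeps the large negative fill
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)
    return mask.to(dtype)

print(make_causal_mask(3, torch.float16))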
diff --git a/src/transformers/models/canine/modeling_canine.py b/src/transformers/models/canine/modeling_canine.py
index bb7b1492c7bf..3b965bb9f2be 100644
--- a/src/transformers/models/canine/modeling_canine.py
+++ b/src/transformers/models/canine/modeling_canine.py
@@ -467,7 +467,7 @@ def forward(
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
             # positions we want to attend and -10000.0 for masked positions.
-            attention_mask = (1.0 - attention_mask.float()) * -10000.0
+            attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
             # Apply the attention mask (precomputed for all layers in CanineModel forward() function)
             attention_scores = attention_scores + attention_mask
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 5c34c658b59f..ddc2236371c2 100755
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -638,7 +638,9 @@ def forward(
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
+        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
+            hidden_states.device
+        )
         # expand attention_mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -670,11 +672,11 @@ def forward(
             attentions=encoder_outputs.attentions,
         )

-    def _build_causal_attention_mask(self, bsz, seq_len):
+    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
         # lazily create causal attention mask, with full attention between the vision tokens
         # pytorch uses additive attention mask; fill with -inf
-        mask = torch.empty(bsz, seq_len, seq_len)
-        mask.fill_(torch.tensor(float("-inf")))
+        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
+        mask.fill_(torch.tensor(torch.finfo(dtype).min))
         mask.triu_(1)  # zero out the lower diagonal
         mask = mask.unsqueeze(1)  # expand mask
         return mask
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index cec2d0d345b2..c091c201de8a 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -435,7 +435,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index 1ee847f4ea1a..63712863bfc8 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -578,7 +578,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index 509d7250b719..90c393e5098b 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -188,7 +188,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -239,7 +243,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -578,7 +586,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
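The `torch.where` rewrite above recurs in the gpt2, gpt_neo, gpt_neox, gptj, and imagegpt attention blocks below. The comments in the hunks give the reason for the two-step construction; a standalone sketch of the pattern with toy shapes:

import torch

attn_weights = torch.randn(1, 1, 4, 4).half()
causal_mask = torch.ones(4, 4).tril().to(torch.bool)

# A plain Python float here can trip dtype promotion (the `expected scalar
# type float but found double` error), so the fill value is materialized as
# a 0-d tensor with the dtype and device of the scores.
mask_value = torch.tensor(torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)
mask_value = mask_value.to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)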
diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py
index d261104ac7ad..ea966b7e8eea 100644
--- a/src/transformers/models/detr/modeling_detr.py
+++ b/src/transformers/models/detr/modeling_detr.py
@@ -1809,7 +1809,7 @@ def forward(self, q, k, mask: Optional[Tensor] = None):
         weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

         if mask is not None:
-            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
+            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
         weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
         weights = self.dropout(weights)
         return weights
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index a93c08345a0b..fc5b5a7b0f7c 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -211,7 +211,9 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
         q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)
         mask = (mask == 0).view(mask_reshp).expand_as(scores)  # (bs, n_heads, q_length, k_length)
-        scores = scores.masked_fill(mask, torch.tensor(-float("inf")))  # (bs, n_heads, q_length, k_length)
+        scores = scores.masked_fill(
+            mask, torch.tensor(torch.finfo(scores.dtype).min)
+        )  # (bs, n_heads, q_length, k_length)

         weights = nn.functional.softmax(scores, dim=-1)  # (bs, n_heads, q_length, k_length)
         weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 937b8a712821..d44bc80363d0 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -323,8 +323,8 @@ def _prepare_fsmt_decoder_inputs(
         decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
     else:
         decoder_padding_mask = invert_mask(decoder_padding_mask)
-    causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
-        dtype=causal_mask_dtype, device=decoder_input_ids.device
+    causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
+        device=decoder_input_ids.device
     )
     return decoder_input_ids, decoder_padding_mask, causal_mask

@@ -908,7 +908,7 @@ def forward(
         if key_padding_mask is not None:  # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
-            attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
+            attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -975,7 +975,7 @@ def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):

 def fill_with_neg_inf(t):
     """FP16-compatible function that fills input_ids with -inf."""
-    return t.float().fill_(float("-inf")).type_as(t)
+    return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)


 # Public API
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index b5872be2815c..4c888cda8e01 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -199,7 +199,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -250,7 +254,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -811,7 +819,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 4e507d8d8598..ddcfb4e33677 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -189,7 +189,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

         query_length, key_length = query.size(-2), key.size(-2)
         causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
-        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+        # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -566,7 +570,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 8a1879a624ae..933b535b083c 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -196,7 +196,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

-        attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
+        mask_value = torch.finfo(attn_scores.dtype).min
+        # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+        # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
+        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -214,7 +218,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):


 def attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(~ltor_mask, -10000.0)
+    attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
     return attention_scores

@@ -460,7 +464,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index fed2ee12a8c9..cb05902ee422 100755
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -170,7 +170,12 @@ def _attn(
         key = key.to(torch.float32)

         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-        attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+
+        mask_value = torch.finfo(attn_weights.dtype).min
+        # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+        # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         attn_weights = attn_weights / self.scale_attn

@@ -605,7 +610,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index ab29fe13c936..c2b745f6d5a6 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -664,7 +664,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
@@ -753,7 +754,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 4a18c64a13dd..c80f16267c69 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -253,7 +253,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -304,7 +308,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
             causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
-            attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
+            mask_value = torch.finfo(attn_weights.dtype).min
+            # Needs to be a tensor, otherwise we get the error: `RuntimeError: expected scalar type float but found double`.
+            # Needs to be on the same device, otherwise we get the error: `RuntimeError: ..., x and y to be on the same device`
+            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

         if attention_mask is not None:
             # Apply the attention mask
@@ -765,7 +773,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index 2a48ba5f4fd5..e3a625416a7d 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -805,7 +805,7 @@ def forward(
         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

         if head_mask is not None:
             if head_mask.dim() == 1:
diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
index 7faa34eec430..382a7b305bbb 100755
--- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
+++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py
@@ -178,7 +178,9 @@ def forward(
             attention_scores += rel_pos
         if self.has_spatial_attention_bias:
             attention_scores += rel_2d_pos
-        attention_scores = attention_scores.float().masked_fill_(attention_mask.to(torch.bool), float("-inf"))
+        attention_scores = attention_scores.float().masked_fill_(
+            attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
+        )
         attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
@@ -916,7 +918,7 @@ def forward(
         extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)

         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

         if head_mask is not None:
             if head_mask.dim() == 1:
diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
index f301afafd80e..f3bdd2cd8d90 100644
--- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
+++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -886,7 +886,9 @@ def forward(
                 position_ids = position_ids.expand_as(input_ids)
             final_position_ids = position_ids

-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, None, device)
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
+            attention_mask, None, device, dtype=embedding_output.dtype
+        )

         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 6ac0bfccb196..3fba42b7d544 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), float("-inf"))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
@@ -207,7 +207,7 @@ def forward(
         # cast to fp32/fp16 then replace 1's with -inf
         float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
-            remove_from_windowed_attention_mask, -10000.0
+            remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
         )
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -579,7 +579,7 @@ def _concat_with_global_key_attn_probs(
         attn_probs_from_global_key[
             is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
-        ] = -10000.0
+        ] = torch.finfo(attn_probs_from_global_key.dtype).min

         return attn_probs_from_global_key
@@ -675,11 +675,11 @@ def _compute_global_attn_output_from_hidden(
         global_attn_scores[
             is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
-        ] = -10000.0
+        ] = torch.finfo(global_attn_scores.dtype).min

         global_attn_scores = global_attn_scores.masked_fill(
             is_index_masked[:, None, None, :],
-            -10000.0,
+            torch.finfo(global_attn_scores.dtype).min,
         )

         global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 30db98dea1ab..797e744bbf7d 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -586,7 +586,7 @@ def forward(
         # cast to fp32/fp16 then replace 1's with -inf
         float_mask = remove_from_windowed_attention_mask.type_as(query_vectors).masked_fill(
-            remove_from_windowed_attention_mask, -10000.0
+            remove_from_windowed_attention_mask, torch.finfo(query_vectors.dtype).min
         )
         # diagonal mask with zeros everywhere and -inf inplace of padding
         diagonal_mask = self._sliding_chunks_query_key_matmul(
@@ -958,7 +958,7 @@ def _concat_with_global_key_attn_probs(
         attn_probs_from_global_key[
             is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
-        ] = -10000.0
+        ] = torch.finfo(attn_probs_from_global_key.dtype).min

         return attn_probs_from_global_key
@@ -1054,11 +1054,11 @@ def _compute_global_attn_output_from_hidden(
         global_attn_scores[
             is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
-        ] = -10000.0
+        ] = torch.finfo(global_attn_scores.dtype).min

         global_attn_scores = global_attn_scores.masked_fill(
             is_index_masked[:, None, None, :],
-            -10000.0,
+            torch.finfo(global_attn_scores.dtype).min,
         )

         global_attn_scores = global_attn_scores.view(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
diff --git a/src/transformers/models/longt5/modeling_flax_longt5.py b/src/transformers/models/longt5/modeling_flax_longt5.py
index 8ea0b38bb430..766dc36888e2 100644
--- a/src/transformers/models/longt5/modeling_flax_longt5.py
+++ b/src/transformers/models/longt5/modeling_flax_longt5.py
@@ -549,10 +549,11 @@ def __call__(

         # replace masked positions with -10_000
         if attention_mask is not None:
+            mask_value = jnp.finfo(self.dtype).min
             attention_mask = jax.lax.select(
                 attention_mask > 0,
                 jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
-                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
+                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
             )

         if position_bias is None:
diff --git a/src/transformers/models/luke/modeling_luke.py b/src/transformers/models/luke/modeling_luke.py
index 4c2491aee730..cca68cd5354b 100644
--- a/src/transformers/models/luke/modeling_luke.py
+++ b/src/transformers/models/luke/modeling_luke.py
@@ -1076,7 +1076,7 @@ def get_extended_attention_mask(
             raise ValueError(f"Wrong shape for attention_mask (shape {attention_mask.shape})")

         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

         return extended_attention_mask
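LUKE overrides `get_extended_attention_mask`, so it gets the same treatment as the base implementation in `modeling_utils.py`. The new `dtype` parameter there also lets callers such as LayoutLMv3 (above) build the mask in the dtype of their embeddings rather than the module's parameter dtype. A standalone sketch of that contract, using a hypothetical free function `extend_attention_mask` in place of the mixin method:

import torch

def extend_attention_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, seq] -> [batch, 1, 1, seq]; 0 where attended, finfo-min where masked
    extended = attention_mask[:, None, None, :].to(dtype=dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 0]])
print(extend_attention_mask(mask, torch.float16))  # [[[[0., 0., -65504.]]]]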
diff --git a/src/transformers/models/lxmert/modeling_lxmert.py b/src/transformers/models/lxmert/modeling_lxmert.py
index 34b90441fede..749fd9ea0c47 100644
--- a/src/transformers/models/lxmert/modeling_lxmert.py
+++ b/src/transformers/models/lxmert/modeling_lxmert.py
@@ -959,13 +959,13 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+        extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

         # Process the visual attention mask
         if visual_attention_mask is not None:
             extended_visual_attention_mask = visual_attention_mask.unsqueeze(1).unsqueeze(2)
             extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=self.dtype)
-            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
+            extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * torch.finfo(self.dtype).min
         else:
             extended_visual_attention_mask = None
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 90de52d4c351..9e8030296d50 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -79,7 +79,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 0dc30ed0b476..32e59098ef11 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -81,7 +81,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index d342d5fcbf3b..b9057178a032 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -97,7 +97,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index f51367813781..e5e5da5da0c9 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -479,7 +479,7 @@ def forward(
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
         attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # Prepare head mask if needed
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 6db58a82d61a..9e74c98d43cc 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -61,7 +61,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 25a8676d6f1e..9e797af035cf 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 7ca17146f3c5..eb8b5d2b4167 100755
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -94,7 +94,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py
index 1ca6a0e49089..537d1f80d448 100644
--- a/src/transformers/models/prophetnet/modeling_prophetnet.py
+++ b/src/transformers/models/prophetnet/modeling_prophetnet.py
@@ -187,7 +187,9 @@ def ngram_attention_bias(sequence_length, ngram, device, dtype):
     """
     This function computes the bias for the predict stream
     """
-    left_block = torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * float("-inf")
+    left_block = (
+        torch.ones((ngram, sequence_length, sequence_length), device=device, dtype=dtype) * torch.finfo(dtype).min
+    )
     right_block = left_block.detach().clone()
     # create bias
     for stream_idx in range(ngram):
@@ -1336,7 +1338,7 @@ def forward(
         if attention_mask is not None:
             extended_attention_mask = (
                 1.0 - attention_mask[:, None, :].repeat(self.config.num_encoder_attention_heads, 1, 1)
-            ) * -10000.0
+            ) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_attention_mask.to(inputs_embeds.dtype)
         else:
             extended_attention_mask = None
@@ -1554,7 +1556,7 @@ def forward(
         if encoder_attention_mask is not None:
             extended_encoder_attention_mask = (
                 1.0 - encoder_attention_mask[:, None, :].repeat(self.config.num_decoder_attention_heads, 1, 1)
-            ) * -10000.0
+            ) * torch.finfo(self.dtype).min
             extended_encoder_attention_mask = extended_encoder_attention_mask.to(inputs_embeds.dtype)
         else:
             extended_encoder_attention_mask = None
@@ -1713,7 +1715,10 @@ def prepare_attention_mask(self, hidden_states, attention_mask):

         # get causal mask
         causal_mask = torch.full(
-            (seq_length, seq_length), -float("inf"), dtype=hidden_states.dtype, device=hidden_states.device
+            (seq_length, seq_length),
+            torch.finfo(hidden_states.dtype).min,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
         causal_mask = torch.triu(causal_mask, 1)
         extended_causal_mask = causal_mask[:seq_length, :seq_length][None, :, :].expand(
@@ -1722,7 +1727,7 @@ def prepare_attention_mask(self, hidden_states, attention_mask):

         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[:, None, :]) * -10000.0
+            extended_attention_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_causal_mask + extended_attention_mask
         else:
             extended_attention_mask = extended_causal_mask
@@ -1750,7 +1755,7 @@ def prepare_predict_attention_mask(self, hidden_states, attention_mask):

         # add usual attention mask
         if attention_mask is not None:
-            extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * -10000.0
+            extended_attention_mask = (1.0 - attention_mask[None, :, None, :]) * torch.finfo(self.dtype).min
             extended_attention_mask = extended_attention_mask.expand((self.ngram, batch_size, seq_length, seq_length))
             # predicted stream attention_mask should always be 0
             extended_attention_mask = torch.cat(
diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py
index e6de31a4cb5e..f63ea07ad960 100644
--- a/src/transformers/models/realm/modeling_realm.py
+++ b/src/transformers/models/realm/modeling_realm.py
@@ -869,8 +869,8 @@ def _spans_given_width(width):
         return starts, ends, span_masks

-    def mask_to_score(mask):
-        return (1.0 - mask.type(torch.float32)) * -10000.0
+    def mask_to_score(mask, dtype=torch.float32):
+        return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

     # [reader_beam_size, max_sequence_len, span_hidden_size * 2]
     hidden_states = self.dense_intermediate(hidden_states)
@@ -890,7 +890,7 @@ def mask_to_score(mask):
         # [reader_beam_size, num_candidates]
         reader_logits = self.dense_output(candidate_hidden).squeeze(-1)
         # [reader_beam_size, num_candidates]
-        reader_logits += mask_to_score(candidate_mask)
+        reader_logits += mask_to_score(candidate_mask, dtype=reader_logits.dtype)

         return reader_logits, candidate_starts, candidate_ends

@@ -1634,11 +1634,11 @@ def compute_correct_candidates(candidate_starts, candidate_ends, gold_starts, gold_ends):
     def marginal_log_loss(logits, is_correct):
         """Loss based on the negative marginal log-likelihood."""

-        def mask_to_score(mask):
-            return (1.0 - mask.type(torch.float32)) * -10000.0
+        def mask_to_score(mask, dtype=torch.float32):
+            return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

         # []
-        log_numerator = torch.logsumexp(logits + mask_to_score(is_correct), dim=-1)
+        log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
         log_denominator = torch.logsumexp(logits, dim=-1)
         return log_denominator - log_numerator
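The REALM change threads a `dtype` argument through `mask_to_score` so the additive penalty is computed in the same dtype as the logits it is added to, rather than always in float32. A standalone sketch of the marginal log-likelihood usage:

import torch

def mask_to_score(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    return (1.0 - mask.type(dtype)) * torch.finfo(dtype).min

logits = torch.tensor([2.0, -1.0, 0.5, 3.0])
is_correct = torch.tensor([1.0, 0.0, 1.0, 0.0])
# logsumexp over the correct candidates only; incorrect ones get the finfo-min penalty
log_numerator = torch.logsumexp(logits + mask_to_score(is_correct, dtype=logits.dtype), dim=-1)
log_denominator = torch.logsumexp(logits, dim=-1)
loss = log_denominator - log_numerator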
""" bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index ae8ba4fa34b0..1f94f6f9ad27 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -910,8 +910,8 @@ def forward( start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1) if attention_mask is not None: - start_logits = start_logits + (1 - attention_mask) * -10000.0 - end_logits = end_logits + (1 - attention_mask) * -10000.0 + start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min + end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min total_loss = None if start_positions is not None and end_positions is not None: @@ -1060,8 +1060,8 @@ def forward( attention_mask_for_each_question = attention_mask.unsqueeze(1).expand( batch_size, num_questions, sequence_length ) - start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0 - end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0 + start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min + end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min total_loss = None # [batch_size, num_questions, sequence_length] diff --git a/src/transformers/models/t5/modeling_flax_t5.py b/src/transformers/models/t5/modeling_flax_t5.py index 23f7436eab80..4094599592a3 100644 --- a/src/transformers/models/t5/modeling_flax_t5.py +++ b/src/transformers/models/t5/modeling_flax_t5.py @@ -409,10 +409,11 @@ def __call__( # replace masked positions with -10_000 if attention_mask is not None: + mask_value = jnp.finfo(self.dtype).min attention_mask = jax.lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0).astype(self.dtype), - jnp.full(attention_mask.shape, -1e4).astype(self.dtype), + jnp.full(attention_mask.shape, mask_value).astype(self.dtype), ) if position_bias is None: diff --git a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py index f647a13afead..b2c14029a074 100644 --- a/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py +++ b/src/transformers/models/trajectory_transformer/modeling_trajectory_transformer.py @@ -333,7 +333,7 @@ def forward( # [ batch_size x n_heads x sequence_length x sequence_length ] attn_weights = (torch.matmul(query, key.transpose(-2, -1))) * (1.0 / math.sqrt(key.size(-1))) attn_weights = attn_weights.masked_fill( - self.mask[:, :, :sequence_length, :sequence_length] == 0, float("-inf") + self.mask[:, :, :sequence_length, :sequence_length] == 0, torch.finfo(attn_weights.dtype).min ) attn_weights = F.softmax(attn_weights, dim=-1) self._attn_map = attn_weights.clone() diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py index 7986aa7af9b6..75793466c7a8 100644 --- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py @@ 
-327,21 +327,17 @@ def forward(self, w, r, attn_mask=None, mems=None, head_mask=None, output_attent attn_score = AC + BD attn_score.mul_(self.scale) + mask_value = torch.finfo(attn_score.dtype).min + # compute attention probability if attn_mask is not None and torch.sum(attn_mask).item(): attn_mask = attn_mask == 1 # Switch to bool if attn_mask.dim() == 2: - if next(self.parameters()).dtype == torch.float16: - attn_score = ( - attn_score.float().masked_fill(attn_mask[None, :, :, None], -65000).type_as(attn_score) - ) - else: - attn_score = attn_score.float().masked_fill(attn_mask[None, :, :, None], -1e30).type_as(attn_score) + attn_score = ( + attn_score.float().masked_fill(attn_mask[None, :, :, None], mask_value).type_as(attn_score) + ) elif attn_mask.dim() == 3: - if next(self.parameters()).dtype == torch.float16: - attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -65000).type_as(attn_score) - else: - attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], -1e30).type_as(attn_score) + attn_score = attn_score.float().masked_fill(attn_mask[:, :, :, None], mask_value).type_as(attn_score) # [qlen x klen x bsz x n_head] attn_prob = nn.functional.softmax(attn_score, dim=1) diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 3a960bb86f3c..a79e5e901d67 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_ Make causal mask used for bi-directional self-attention. """ bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf"))) + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min)) mask_cond = torch.arange(mask.size(-1)) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) mask = mask.to(dtype) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 8c17708fb726..744fbe731219 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -701,7 +701,8 @@ def forward( hidden_states[~expand_attention_mask] = 0 # extend attention_mask - attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min attention_mask = attention_mask.expand( attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] ) @@ -790,7 +791,8 @@ def forward( hidden_states[~expand_attention_mask] = 0 # extend attention_mask - attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0 + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min attention_mask = attention_mask.expand( attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] ) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 5c80c693e8ae..2ed4f43b515e 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -715,7 +715,8 @@ 
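Transformer-XL previously kept two magic numbers, -65000 for fp16 and -1e30 for everything else; a single `torch.finfo(attn_score.dtype).min` covers every dtype without branching on the parameter dtype. A standalone sketch with toy shapes:

import torch

attn_score = torch.randn(3, 3).half()
attn_mask = torch.triu(torch.ones(3, 3), 1).bool()

mask_value = torch.finfo(attn_score.dtype).min  # -65504.0 for float16
# masking happens in float32 (hence .float()/.type_as); the fp16 minimum
# round-trips through float32 without changing value
attn_score = attn_score.float().masked_fill(attn_mask, mask_value).type_as(attn_score)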
diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py
index 3a960bb86f3c..a79e5e901d67 100644
--- a/src/transformers/models/trocr/modeling_trocr.py
+++ b/src/transformers/models/trocr/modeling_trocr.py
@@ -50,7 +50,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 8c17708fb726..744fbe731219 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -701,7 +701,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
@@ -790,7 +791,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 5c80c693e8ae..2ed4f43b515e 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -715,7 +715,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
@@ -804,7 +805,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py
index 118ab1fe5c3c..983efada283b 100755
--- a/src/transformers/models/visual_bert/modeling_visual_bert.py
+++ b/src/transformers/models/visual_bert/modeling_visual_bert.py
@@ -1433,7 +1433,7 @@ def transpose_for_scores(self, x):
     def forward(self, query, key, attention_mask):
         attention_mask = attention_mask.to(query.dtype)
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-        attention_mask = (1.0 - attention_mask) * -10000.0
+        attention_mask = (1.0 - attention_mask) * torch.finfo(query.dtype).min

         mixed_query_layer = self.query(query)
         mixed_key_layer = self.key(key)
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 14f5c02e724e..69ac90053323 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -749,7 +749,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
@@ -837,7 +838,8 @@ def forward(
             hidden_states[~expand_attention_mask] = 0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index d4972f65b43f..1cdab6f90415 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -895,7 +895,8 @@ def forward(
             hidden_states[~attention_mask] = 0.0

             # extend attention_mask
-            attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
+            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
             attention_mask = attention_mask.expand(
                 attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
             )
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index fa2a8c6eb606..1709f6dd5ad1 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -120,7 +120,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py
index ebb3c503475c..79a1e0292e99 100755
--- a/src/transformers/models/xlm/modeling_xlm.py
+++ b/src/transformers/models/xlm/modeling_xlm.py
@@ -181,7 +181,7 @@ def unshape(x):
     q = q / math.sqrt(dim_per_head)  # (bs, n_heads, qlen, dim_per_head)
     scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
     mask = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
-    scores.masked_fill_(mask, -float("inf"))  # (bs, n_heads, qlen, klen)
+    scores.masked_fill_(mask, torch.finfo(scores.dtype).min)  # (bs, n_heads, qlen, klen)

     weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
     weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)  # (bs, n_heads, qlen, klen)
diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
index 7d09a77b70ec..5dcddd87f332 100755
--- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
+++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py
@@ -1632,7 +1632,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), float("-inf"))
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
     mask_cond = torch.arange(mask.size(-1))
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)
diff --git a/tests/models/fsmt/test_modeling_fsmt.py b/tests/models/fsmt/test_modeling_fsmt.py
index 9e487b609aae..4cc4055a69f2 100644
--- a/tests/models/fsmt/test_modeling_fsmt.py
+++ b/tests/models/fsmt/test_modeling_fsmt.py
@@ -351,9 +351,10 @@ def test_prepare_fsmt_decoder_inputs(self):
         config, *_ = self._get_config_and_data()
         input_ids = _long_tensor(([4, 4, 2]))
         decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
-        ignore = float("-inf")
+        causal_mask_dtype = torch.float32
+        ignore = torch.finfo(causal_mask_dtype).min
         decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
-            config, input_ids, decoder_input_ids
+            config, input_ids, decoder_input_ids, causal_mask_dtype=causal_mask_dtype
         )
         expected_causal_mask = torch.tensor(
             [[0, ignore, ignore], [0, 0, ignore], [0, 0, 0]]  # never attend to the final token, because its pad