19 changes: 7 additions & 12 deletions src/transformers/modeling_utils.py
@@ -751,15 +751,7 @@ def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
# encoder_extended_attention_mask = (encoder_extended_attention_mask ==
# encoder_extended_attention_mask.transpose(-1, -2))
encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility

if self.dtype == torch.float16:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
elif self.dtype in [torch.bfloat16, torch.float32]:
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
else:
raise ValueError(
f"{self.dtype} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`"
)
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min

return encoder_extended_attention_mask
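(For reference, a minimal sketch, not part of the diff, showing what `torch.finfo(dtype).min` evaluates to for the dtypes involved, which is why a single expression can replace the per-dtype constants removed above.)

import torch

# Most negative finite value representable in each dtype, now used as the
# additive mask fill instead of the hard-coded -1e4 / -1e9 constants.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    print(dtype, torch.finfo(dtype).min)
# torch.float16   -65504.0
# torch.bfloat16  -3.3895313892515355e+38
# torch.float32   -3.4028234663852886e+38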

@@ -792,7 +784,7 @@ def create_extended_attention_mask_for_decoder(input_shape, attention_mask, devi
return extended_attention_mask

def get_extended_attention_mask(
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
) -> Tensor:
"""
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -806,6 +798,9 @@ def get_extended_attention_mask(
Returns:
`torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
"""
if dtype is None:
dtype = self.dtype

if not (attention_mask.dim() == 2 and self.config.is_decoder):
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
if device is not None:
@@ -836,8 +831,8 @@ def get_extended_attention_mask(
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
@ydshieh (Collaborator, Author) commented on Jun 3, 2022:

Add a dtype argument, so we can use it in LayoutLMv3Model

return extended_attention_mask
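(A minimal standalone sketch, not taken from the PR, of the mask arithmetic this method performs; as the review comment above notes, the new dtype argument lets a caller such as LayoutLMv3Model pick the mask dtype explicitly. Shapes and values are illustrative.)

import torch

# 2D padding mask: 1 = attend, 0 = masked, shape (batch, seq_len)
attention_mask = torch.tensor([[1, 1, 1, 0]])
dtype = torch.float16  # e.g. passed explicitly instead of defaulting to self.dtype

extended = attention_mask[:, None, None, :].to(dtype)  # (batch, 1, 1, seq_len), fp16 compatibility
extended = (1.0 - extended) * torch.finfo(dtype).min   # 0 where visible, finfo(dtype).min where masked
print(extended)
# tensor([[[[     0.,      0.,      0., -65504.]]]], dtype=torch.float16)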

def get_head_mask(
2 changes: 1 addition & 1 deletion src/transformers/models/albert/modeling_albert.py
@@ -728,7 +728,7 @@ def forward(

extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min

Reviewer:

imho, the incantation is obtuse enough to make a utility function:

def near_neg_inf(dtype):
    return torch.finfo(dtype).min

Contributor:

Haha, I think I suggested to do the exact opposite for readability - what do you think @sgugger @LysandreJik ?

Collaborator:

I personally prefer to avoid those kinds of one-line functions as it makes it harder to read the code, with the reader having to navigate.

head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

embedding_output = self.embeddings(
2 changes: 1 addition & 1 deletion src/transformers/models/bart/modeling_bart.py
@@ -93,7 +93,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
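(An illustrative sketch of the causal mask the updated _make_causal_mask builds, here for tgt_len=3 in float16: zeros on and below the diagonal, the dtype's most negative value above it. The torch.tensor(...) wrapper from the diff is omitted for brevity.)

import torch

dtype = torch.float16
tgt_len = 3
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
print(mask)
# tensor([[     0., -65504., -65504.],
#         [     0.,      0., -65504.],
#         [     0.,      0.,      0.]], dtype=torch.float16)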
@@ -96,7 +96,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), float("-inf"))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
2 changes: 1 addition & 1 deletion src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -83,7 +83,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
@@ -80,7 +80,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.tensor(float("-inf")))
mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
mask_cond = torch.arange(mask.size(-1))
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
2 changes: 1 addition & 1 deletion src/transformers/models/canine/modeling_canine.py
@@ -467,7 +467,7 @@ def forward(
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
attention_mask = (1.0 - attention_mask.float()) * -10000.0
attention_mask = (1.0 - attention_mask.float()) * torch.finfo(attention_scores.dtype).min
# Apply the attention mask (precomputed for all layers in CanineModel forward() function)
attention_scores = attention_scores + attention_mask

10 changes: 6 additions & 4 deletions src/transformers/models/clip/modeling_clip.py
@@ -638,7 +638,9 @@ def forward(
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -670,11 +672,11 @@ def forward(
attentions=encoder_outputs.attentions,
)

def _build_causal_attention_mask(self, bsz, seq_len):
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len)
mask.fill_(torch.tensor(float("-inf")))
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
2 changes: 1 addition & 1 deletion src/transformers/models/ctrl/modeling_ctrl.py
@@ -435,7 +435,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
3 changes: 2 additions & 1 deletion src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -578,7 +578,8 @@ def forward(
hidden_states[~expand_attention_mask] = 0

# extend attention_mask
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
attention_mask = attention_mask.expand(
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
)
@@ -188,7 +188,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -239,7 +243,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -578,7 +586,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
2 changes: 1 addition & 1 deletion src/transformers/models/detr/modeling_detr.py
@@ -1809,7 +1809,7 @@ def forward(self, q, k, mask: Optional[Tensor] = None):
weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head)

if mask is not None:
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min)
weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size())
weights = self.dropout(weights)
return weights
4 changes: 3 additions & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -211,7 +211,9 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
q = q / math.sqrt(dim_per_head) # (bs, n_heads, q_length, dim_per_head)
scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)
mask = (mask == 0).view(mask_reshp).expand_as(scores) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(mask, torch.tensor(-float("inf"))) # (bs, n_heads, q_length, k_length)
scores = scores.masked_fill(
mask, torch.tensor(torch.finfo(scores.dtype).min)
) # (bs, n_heads, q_length, k_length)

weights = nn.functional.softmax(scores, dim=-1) # (bs, n_heads, q_length, k_length)
weights = self.dropout(weights) # (bs, n_heads, q_length, k_length)
8 changes: 4 additions & 4 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -323,8 +323,8 @@ def _prepare_fsmt_decoder_inputs(
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_padding_mask)
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1).to(
dtype=causal_mask_dtype, device=decoder_input_ids.device
causal_mask = triu_onnx(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=causal_mask_dtype)), 1).to(
device=decoder_input_ids.device
)
return decoder_input_ids, decoder_padding_mask, causal_mask

@@ -908,7 +908,7 @@ def forward(
if key_padding_mask is not None: # don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.masked_fill(reshaped, torch.finfo(attn_weights.dtype).min)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -975,7 +975,7 @@ def _use_saved_state(self, k, v, saved_state, key_padding_mask, static_kv, bsz):

def fill_with_neg_inf(t):
"""FP16-compatible function that fills a input_ids with -inf."""
return t.float().fill_(float("-inf")).type_as(t)
return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)


# Public API
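(A quick self-contained check, not from the PR, of the updated helper above; the docstring is adjusted here to describe the new fill value. Output shown for float16.)

import torch

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with its dtype's most negative finite value."""
    return t.float().fill_(torch.finfo(t.dtype).min).type_as(t)

print(fill_with_neg_inf(torch.zeros(2, 2, dtype=torch.float16)))
# tensor([[-65504., -65504.],
#         [-65504., -65504.]], dtype=torch.float16)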
14 changes: 11 additions & 3 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -199,7 +199,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -250,7 +254,11 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -811,7 +819,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
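(A minimal standalone sketch, with made-up shapes, of the new mask_value handling in _attn: the fill value is materialized as a tensor with the same dtype and device as the attention weights before torch.where, rather than being passed as a Python float.)

import torch

attn_weights = torch.randn(1, 1, 4, 4, dtype=torch.float16)
causal_mask = torch.tril(torch.ones(4, 4)).bool()  # True on and below the diagonal

mask_value = torch.finfo(attn_weights.dtype).min
# Wrap as a tensor on the weights' dtype/device; a bare Python float can trip
# dtype promotion ("expected scalar type float but found double") or device checks.
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
print(attn_weights[0, 0])  # strictly-upper-triangle entries are now -65504. in fp16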
8 changes: 6 additions & 2 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -189,7 +189,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):

query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -566,7 +570,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
10 changes: 7 additions & 3 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -196,7 +196,11 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_scores = torch.einsum("bik,bjk->bij", query, key) / self.norm_factor
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

attn_scores = torch.where(causal_mask, attn_scores, self.masked_bias.to(attn_scores.dtype))
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)

if attention_mask is not None:
# Apply the attention mask
@@ -214,7 +218,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):


def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, -10000.0)
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores


@@ -460,7 +464,7 @@ def forward(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head