9 changes: 5 additions & 4 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -56,7 +56,7 @@ def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_
Make causal mask used for bi-directional self-attention.
"""
batch_size, target_length = input_ids_shape
-   mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
+   mask = torch.full((target_length, target_length), -torch.inf)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
mask.masked_fill_(intermediate_mask, 0)
@@ -79,7 +79,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):

inverted_mask = 1.0 - expanded_mask

-   return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+   return inverted_mask.masked_fill(inverted_mask.to(torch.bool), -torch.inf)


def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
@@ -303,7 +303,9 @@ def forward(
# We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
input_dtype = attention_scores.dtype
attn_weights = (attention_scores * self.layer_number) + attention_mask
-   attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+   attn_weights = torch.clip(
+       attn_weights, torch.finfo(attn_weights.dtype).min, torch.finfo(attn_weights.dtype).max
+   )
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
attention_probs = attention_probs * (~attention_mask.bool())
# [batch_size, num_heads, q_length, k_length]
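
A minimal, self-contained sketch of what the new clipping does (not part of the diff; shapes, values, and the scores/mask names are invented for illustration): with the additive mask now filled with -inf, torch.clip pins masked scores back to the dtype's finite minimum, and the float32 softmax then assigns them ~0 probability.

import torch
from torch import nn

dtype = torch.float16
scores = torch.randn(1, 1, 4, 4).to(dtype)   # [batch_size, num_heads, q_length, k_length], toy values
mask = torch.zeros(1, 1, 4, 4, dtype=dtype)
mask[..., -1] = -torch.inf                   # mask out the last key position with -inf

attn_weights = scores + mask                 # masked entries are now -inf
attn_weights = torch.clip(attn_weights, torch.finfo(dtype).min, torch.finfo(dtype).max)
probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(dtype)
print(probs[..., -1])                        # ~0 everywhere: the masked position gets no attention weight
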
@@ -599,7 +601,6 @@ def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_ke
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
Contributor:
Stupid question, but when you sum torch.finfo(dtype).min with torch.finfo(dtype).min, isn't the result no longer the masked value?

Member Author:
In float16, torch.finfo(dtype).min + torch.finfo(dtype).min overflows to -inf.

Member Author:

We would get something like this:

>> print(attention_mask)
tensor([[[[     0., -65504., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0., -65504., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0., -65504., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0., -65504., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0., -65504., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0., -65504.],
          [     0.,      0.,      0.,      0.,      0.,      0.,      0.]]],


        [[[-65504.,    -inf,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504.,    -inf, -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504., -65504., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0., -65504., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0., -65504., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0., -65504.],
          [-65504., -65504., -65504.,      0.,      0.,      0.,      0.]]]],
       device='cuda:0', dtype=torch.float16)
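
As a quick sanity check (not part of the PR; the variable names below are made up), the overflow behind the answer above is easy to reproduce:

import torch

dtype = torch.float16
old_fill = torch.full((2, 2), torch.finfo(dtype).min, dtype=dtype)  # old fill value, -65504 in fp16
print(old_fill + old_fill)  # every entry becomes -inf: the finite minimum overflows when two masks are summed

new_fill = torch.full((2, 2), -torch.inf, dtype=dtype)              # new fill value
print(new_fill + new_fill)  # stays -inf, so adding masks together is stable
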

)

return combined_attention_mask

def set_input_embeddings(self, new_embeddings):