diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index 9bf200de5bd3..b209146cb7b6 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -305,7 +305,7 @@ def forward(
         attn_weights = (attention_scores * self.layer_number) + attention_mask
         attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
         attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
-        attention_probs = attention_probs * (~attention_mask.bool())
+        attention_probs = attention_probs

         # [batch_size, num_heads, q_length, k_length]
         attention_probs = self.attention_dropout(attention_probs)
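
For context, a minimal standalone sketch (not part of this diff, and using made-up tensor shapes) of the behavior the removed line relied on: when the additive attention_mask already holds the dtype minimum at masked positions, the softmax in the hunk above drives those positions to ~0 probability on its own, which is what the elementwise multiply with (~attention_mask.bool()) was enforcing explicitly.

# Illustrative only; shapes and values are assumptions, not taken from the model.
import torch
from torch import nn

scores = torch.randn(1, 1, 1, 4)  # [batch_size, num_heads, q_length, k_length]
neg_min = torch.finfo(scores.dtype).min
additive_mask = torch.tensor([[[[0.0, 0.0, neg_min, neg_min]]]])  # last two keys masked

attn_weights = scores + additive_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(scores.dtype)

print(attention_probs)  # masked positions come out ~0 without any extra masking multiply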