diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 6cc0ef7c9947..f598c8299d15 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -641,6 +641,9 @@ def __init__(self, config: BloomConfig): # Initialize weights and apply final processing self.post_init() + def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + return build_alibi_tensor(attention_mask, num_heads, dtype) + def get_input_embeddings(self): return self.word_embeddings @@ -750,7 +753,7 @@ def forward( else: attention_mask = attention_mask.to(hidden_states.device) - alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) causal_mask = self._prepare_attn_mask( attention_mask,