91 changes: 34 additions & 57 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -96,7 +96,9 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
)


def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
def build_alibi_tensor(
max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
@@ -111,66 +113,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
number of heads
dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
device of the output alibi tensor
"""
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
slopes = torch.pow(base, powers)

if closest_power_of_2 != n_head:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]

if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)

slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
alibi = slopes * arange_tensor.expand(n_head, -1, -1)

alibi = alibi.to(dtype)

return alibi
lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
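The vectorized slope computation above is meant to reproduce the per-head geometric sequences that the deleted list-based `get_slopes` helper produced. A minimal sanity sketch, assuming this revision of `modeling_bloom` is importable; `reference_slopes` is reconstructed from the removed code purely for comparison and is not part of the PR:

import math

import torch

from transformers.models.bloom.modeling_bloom import build_alibi_tensor


def reference_slopes(n):
    # reconstruction of the removed list-based helper, kept only for comparison
    def slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n).is_integer():
        return slopes_power_of_2(n)
    closest = 2 ** math.floor(math.log2(n))
    return slopes_power_of_2(closest) + reference_slopes(2 * closest)[0::2][: n - closest]


alibi = build_alibi_tensor(8, 12, dtype=torch.float32)  # shape [n_head, 1, max_seq_len]
assert alibi.shape == (12, 1, 8)
# lengths = [0, 1, 2, ...], so position 1 of the last axis holds exactly one slope per head
assert torch.allclose(alibi[:, 0, 1], torch.tensor(reference_slopes(12), dtype=torch.float32))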


def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
"""
Pre-process the alibi tensor for padding.

Args:
alibi: ([`torch.tensor`], *required*):
alibi tensor to pre-process
attention_mask: ([`torch.tensor`], *required*):
attention mask to pre-process"""

# Sanity check that we are not inferring fewer tokens than the total sequence length
# This usually happens when the inference is done with past_key_values
# In this case we re-create the alibi tensor with the correct sequence length
if attention_mask.shape[-1] != alibi.shape[-1]:
alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
attention_mask.shape[0], 1, 1
)
# Get the indexes of the padding tokens
index_x0, index_y0 = torch.where(attention_mask == 0.0)
index_x1, index_y1 = torch.where(attention_mask == 1.0)

# Clone the embeddings - we can detach because the embeddings are not learned
# Get a reference tensor
slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)

# Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
# Only where you do not have padding. Replace padding tokens by zeros
# This operation can be seen as a shifting operation.
for i, index in enumerate(torch.unique(index_x0)):
slice_to_modify = torch.zeros_like(slice_reference_alibi)
index_shift = index_y1[index_x1 == index]
shift_value = len(index_shift)
slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
return alibi
attention mask to pre-process
"""

unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
# ^-- [batch, max_len], values correspond to element indices after removing padding
# We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
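The cumulative-sum trick above maps every real (non-padding) token to its index after stripping the padding, so left-padded rows pick up the low-position biases and the padded slots are then zeroed by the multiplication with `attention_mask`. A toy illustration of just the index computation (the mask values are made up, not taken from the PR):

import torch

# batch of 2, seq_len 4; the first row is left-padded with two tokens
attention_mask = torch.tensor([[0, 0, 1, 1],
                               [1, 1, 1, 1]])
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
print(unpadded_indices)
# tensor([[0, 0, 0, 1],
#         [0, 1, 2, 3]])
# Row 0: its two real tokens get relative positions 0 and 1, so take_along_dim
# pulls the alibi biases for positions 0 and 1; the padded slots keep index 0
# but are zeroed afterwards by the attention_mask multiplication.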


def dropout_add(x, residual, prob, training):
@@ -359,12 +336,12 @@ def forward(
output_attentions=False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
# repeat alibi tensor with the batch size
alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)

# apply preprocessing if the input is padded
if attention_mask is not None and 0 in attention_mask:
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
if attention_mask is not None:
alibi = pre_process_alibi_for_pad(alibi, attention_mask)
else:
# otherwise repeat alibi tensor with the batch size
alibi = alibi.repeat(hidden_states.shape[0], 1, 1)

mixed_x_layer = self.query_key_value(hidden_states)
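With this change any call that provides an `attention_mask` takes the padding-aware path, while mask-free calls keep the plain batch repeat; both branches should end up with an alibi of shape `[batch_size * num_heads, 1, seq_len]`. A hedged shape check with arbitrary sizes, again assuming this revision of `modeling_bloom` is importable:

import torch

from transformers.models.bloom.modeling_bloom import build_alibi_tensor, pre_process_alibi_for_pad

batch_size, num_heads, seq_len = 2, 4, 5
alibi = build_alibi_tensor(seq_len, num_heads, dtype=torch.float32)  # [num_heads, 1, seq_len]

attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)  # no padding anywhere
padded_branch = pre_process_alibi_for_pad(alibi, attention_mask)
repeated_branch = alibi.repeat(batch_size, 1, 1)
assert padded_branch.shape == repeated_branch.shape == (batch_size * num_heads, 1, seq_len)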

@@ -773,7 +750,7 @@ def forward(
current_sequence_length = hidden_states.shape[1]
if past_key_values[0] is not None:
current_sequence_length += past_key_values[0][0].shape[1]
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype, hidden_states.device)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
