Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention, GPTBigCodeForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
GPTBigCodeForCausalLM,
upcast_masked_softmax,
upcast_softmax,
)

from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter

Expand Down Expand Up @@ -57,6 +62,90 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA is not None else None
self.block_size = 4096

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    """
    Compute scaled attention weights and the resulting attention output.

    This method should be deleted when https://github.com/huggingface/transformers/pull/34508 is merged.
    Copied from GPTBigCodeAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
    The only differences are:
    - in self._attn, use torch.matmul instead of torch.baddbmm when the device used for query is not cpu
    - the zero-filled input buffer required by torch.baddbmm is only allocated on the cpu path,
      so no unused (batch, heads, query, key)-sized tensor is created on other devices

    Args:
        query: MQA models: (batch_size, query_length, num_heads * head_dim).
            MHA models: (batch_size, num_heads, query_length, head_dim).
        key: MQA models: (batch_size, head_dim, key_length).
            MHA models: (batch_size, num_heads, head_dim, key_length).
        value: tensor multiplied by the attention weights; presumably
            (batch_size, key_length, head_dim) for MQA and
            (batch_size, num_heads, key_length, head_dim) for MHA -- confirm against the caller.
        attention_mask: optional mask; where it is false the score is replaced by the mask value
            before the softmax.
        head_mask: optional multiplicative mask applied to the attention weights after dropout
            (transposed on dims 1 and 2 for MQA).

    Returns:
        Tuple `(attn_output, attn_weights)`. For MQA `attn_output` has the shape of `query` and
        `attn_weights` is (batch_size, query_length, num_heads, key_length); for MHA `attn_weights`
        is (batch_size, num_heads, query_length, key_length).
    """
    dtype = query.dtype
    softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else dtype
    upcast = dtype != softmax_dtype

    # When the softmax is upcast, the scores are divided by (layer_idx + 1) here and the same
    # factor is handed to the upcast softmax helpers as `unscale`.
    unscale = self.layer_idx + 1 if self.scale_attention_softmax_in_fp32 and upcast else 1
    scale_factor = unscale**-1
    if self.scale_attn_weights:
        scale_factor /= self.head_dim**0.5

    # MQA models: (batch_size, query_length, num_heads * head_dim)
    # MHA models: (batch_size, num_heads, query_length, head_dim)
    query_shape = query.shape
    batch_size = query_shape[0]
    key_length = key.size(-1)
    if self.multi_query:
        # (batch_size, query_length, num_heads, head_dim) x (batch_size, head_dim, key_length)
        # -> (batch_size, query_length, num_heads, key_length)
        query_length = query_shape[1]
        attn_shape = (batch_size, query_length, self.num_heads, key_length)
        attn_view = (batch_size, query_length * self.num_heads, key_length)
        # No copy needed for MQA 2, or when layer_past is provided.
        query = query.reshape(batch_size, query_length * self.num_heads, self.head_dim)
    else:
        # (batch_size, num_heads, query_length, head_dim) x (batch_size, num_heads, head_dim, key_length)
        # -> (batch_size, num_heads, query_length, key_length)
        query_length = query_shape[2]
        attn_shape = (batch_size, self.num_heads, query_length, key_length)
        attn_view = (batch_size * self.num_heads, query_length, key_length)
        # Always copies
        query = query.reshape(batch_size * self.num_heads, query_length, self.head_dim)
        # No copy when layer_past is provided.
        key = key.reshape(batch_size * self.num_heads, self.head_dim, key_length)

    if query.device.type == "cpu":
        # A zero-filled input with beta=1 is needed because of a bug in pytorch
        # https://github.com/pytorch/pytorch/issues/80588.
        # The bug was fixed in https://github.com/pytorch/pytorch/pull/96086,
        # but the fix has not been released as of pytorch version 2.0.0.
        attn_weights = torch.zeros(attn_view, device=query.device, dtype=query.dtype)
        attn_weights = torch.baddbmm(attn_weights, query, key, beta=1, alpha=scale_factor).view(attn_shape)
    else:
        # Formula for torch.baddbmm: out = beta * attn_weights + scale_factor * (query . key)
        # for beta = 0, it simplifies to: out = scale_factor * (query . key),
        # so no input buffer has to be allocated on this path.
        attn_weights = (torch.matmul(query, key) * scale_factor).view(attn_shape)

    if upcast:
        # Use a fused kernel to prevent a large overhead from casting and scaling.
        # Sub-optimal when the key length is not a multiple of 8.
        if attention_mask is None:
            attn_weights = upcast_softmax(attn_weights, unscale, softmax_dtype)
        else:
            mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)
            attn_weights = upcast_masked_softmax(attn_weights, attention_mask, mask_value, unscale, softmax_dtype)
    else:
        if attention_mask is not None:
            mask_value = self._get_mask_value(attn_weights.device, softmax_dtype)

            # The fused kernel is very slow when the key length is not a multiple of 8, so we skip fusion.
            attn_weights = torch.where(attention_mask, attn_weights, mask_value)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)

    attn_weights = self.attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
        if self.multi_query:
            head_mask = head_mask.transpose(1, 2)
        attn_weights = attn_weights * head_mask

    if self.multi_query:
        # Flatten (query_length, num_heads) back together for the batched matmul with value.
        attn_output = torch.bmm(attn_weights.view(attn_view), value).view(query_shape)
    else:
        attn_output = torch.matmul(attn_weights, value)

    return attn_output, attn_weights

def gaudi_flash_attn_v1(
self,
query_layer,
Expand Down