3 changes: 3 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -208,12 +208,15 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Pop2Piano](https://huggingface.co/docs/transformers/model_doc/pop2piano#transformers.Pop2PianoForConditionalGeneration)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [MT5](https://huggingface.co/docs/transformers/model_doc/mt5#transformers.MT5Model)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
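As background for this doc change, here is a minimal usage sketch of running one of the listed encoder-decoder models with the SDPA attention implementation (the checkpoint, dtype, and device below are illustrative assumptions, not part of this diff):

```py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Request the SDPA attention implementation explicitly; requires a recent PyTorch release.
# "google-t5/t5-small" is only an example checkpoint from the list above.
tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained(
    "google-t5/t5-small",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```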
96 changes: 55 additions & 41 deletions src/transformers/models/longt5/modeling_longt5.py
@@ -420,6 +420,41 @@ def compute_bias(self, query_length, key_length, device=None):
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

def _shape(self, states, batch_size):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def _unshape(self, states, batch_size):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

def _project(self, hidden_states, proj_layer, key_value_states, past_key_value, batch_size):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = self._shape(proj_layer(hidden_states), batch_size)
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = self._shape(proj_layer(key_value_states), batch_size)

if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = self._shape(proj_layer(key_value_states), batch_size)
else:
# cross-attn
hidden_states = past_key_value
return hidden_states

def forward(
self,
hidden_states,
@@ -451,50 +486,25 @@ def forward(

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))

if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states

# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
Comment on lines +490 to +492
Library convention is for comments to go on the line above to avoid line splitting

Suggested change
query_states = self._shape(
self.q(hidden_states), batch_size
) # (batch_size, n_heads, seq_length, dim_per_head)
# (batch_size, n_heads, seq_length, dim_per_head)
query_states = self._shape(self.q(hidden_states), batch_size)


# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
key_states = self._project(
hidden_states,
self.k,
key_value_states,
past_key_value[0] if past_key_value is not None else None,
batch_size,
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
value_states = self._project(
hidden_states,
self.v,
key_value_states,
past_key_value[1] if past_key_value is not None else None,
batch_size,
)

# compute scores
@@ -539,7 +549,9 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask

attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
Comment on lines +552 to +554
Suggested change
attn_output = self._unshape(
torch.matmul(attn_weights, value_states), batch_size
) # (batch_size, seq_length, dim)
# (batch_size, seq_length, dim)
attn_output = self._unshape(torch.matmul(attn_weights, value_states), batch_size)

attn_output = self.o(attn_output)

present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
@@ -1007,6 +1019,7 @@ def unshape(states):

# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
class LongT5LayerSelfAttention(nn.Module):
Contributor comment: What's the reason for not adding support for LongT5?

# Ignore copy
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
@@ -1104,6 +1117,7 @@ def forward(

# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
class LongT5LayerCrossAttention(nn.Module):
# Ignore copy
def __init__(self, config):
super().__init__()
self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False)
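For reviewers skimming the refactor above, a minimal standalone sketch (with illustrative dimensions, not tied to any LongT5 config) of the tensor bookkeeping that the new `_shape` and `_unshape` helpers perform:

```py
import torch

# Illustrative dimensions only; a real model takes these from its config.
batch_size, seq_length, n_heads, head_dim = 2, 5, 8, 64
inner_dim = n_heads * head_dim

hidden = torch.randn(batch_size, seq_length, inner_dim)

# _shape: (batch_size, seq_length, inner_dim) -> (batch_size, n_heads, seq_length, head_dim)
shaped = hidden.view(batch_size, -1, n_heads, head_dim).transpose(1, 2)
assert shaped.shape == (batch_size, n_heads, seq_length, head_dim)

# _unshape: the inverse, back to (batch_size, seq_length, inner_dim)
unshaped = shaped.transpose(1, 2).contiguous().view(batch_size, -1, inner_dim)
assert torch.equal(unshaped, hidden)
```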