diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 8890270d397b..3e421feb2cd0 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -317,7 +317,7 @@ def forward(
         # for the decoder
         is_cross_attention = key_value_states is not None

-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, _, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states)
@@ -351,13 +351,15 @@ def forward(
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)

+        query_length = query_states.shape[1]
+        tgt_len = key_states.shape[-2]
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
         key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
         value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
-        _, query_length, _, _ = query_states.shape

         attn_dropout = self.dropout if self.training else 0.0
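For context, here is a minimal standalone sketch (not part of the patch) of the shape mismatch the diff addresses: during cached decoding the query projection covers only the newly fed token(s), while the concatenated key/value states also cover past positions, so a single `tgt_len` taken from `hidden_states.size()` cannot be used for both reshapes. All tensor sizes below (`bsz`, `num_heads`, `head_dim`, `past_len`) are illustrative assumptions, not values from the model.

```python
# Minimal sketch of the query-length vs. key-length mismatch during cached decoding.
import torch

bsz, num_heads, head_dim = 2, 12, 64   # illustrative sizes, not from the patch
past_len, new_len = 5, 1               # decoding one new token with a cache of 5
kv_len = past_len + new_len

# Query projection is computed only for the new token(s): (bsz, new_len, hidden_dim)
query_states = torch.randn(bsz, new_len, num_heads * head_dim)

# Cached keys are stored as (bsz, num_heads, kv_len, head_dim); built here the same
# way the attention layer shapes them (view + transpose), so transposing back below
# round-trips to a view-compatible layout.
key_states = (
    torch.randn(bsz, kv_len, num_heads * head_dim)
    .view(bsz, kv_len, num_heads, head_dim)
    .transpose(1, 2)
)

# What the fix computes: one length for the queries, another for the keys/values.
query_length = query_states.shape[1]   # 1 -> reshapes the queries
tgt_len = key_states.shape[-2]         # 6 -> reshapes the keys/values

query_states = query_states.view(bsz, query_length, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, num_heads, head_dim)
print(query_states.shape, key_states.shape)  # torch.Size([2, 1, 12, 64]) torch.Size([2, 6, 12, 64])

# Before the fix, tgt_len came from hidden_states.size() (here 1), so reshaping the
# 6-position key/value tensors with it would raise a size-mismatch RuntimeError.
```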