diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index d2a1cb5755f3..5b4959dc2055 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -389,10 +389,8 @@ def forward( # shaw's relative positional embedding dist = attention_dists.to(hidden_states.device) rel_pos_emb = self.rel_pos_emb(dist) - rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) pos_attn = ( - torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) - * self.scale + torch.einsum("bnhid,ijd->bnhij", query_states, rel_pos_emb) * self.scale ) if remainder > 0: