Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions vllm/model_executor/models/granite_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,8 @@ def forward(
# shaw's relative positional embedding
dist = attention_dists.to(hidden_states.device)
rel_pos_emb = self.rel_pos_emb(dist)
- rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn = (
- torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
+ torch.einsum("bnhid,ijd->bnhij", query_states, rel_pos_emb)
* self.scale
)

Expand Down
Loading