Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions vllm/attention/ops/triton_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,12 @@ def kernel_unified_attention_2d(
L = L * alpha + l_j
M = m_j

if SLIDING_WINDOW:
qpos_lo = q_block_local_idx * BLOCK_Q
V = tl.where(
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
)

# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)

Expand Down Expand Up @@ -672,6 +678,12 @@ def kernel_unified_attention_3d(
L = L * alpha + l_j
M = m_j

if SLIDING_WINDOW:
qpos_lo = q_block_local_idx * BLOCK_Q
V = tl.where(
(context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
)

# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)

Expand Down