Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 102 additions & 42 deletions vllm/v1/attention/ops/triton_unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,31 +231,57 @@ def kernel_unified_attention_2d(

# iterate through tiles (now limited to the sliding window range)
for j in range(tile_start, tile_end):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
if TILE_SIZE == BLOCK_SIZE:
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j).to(
tl.int64
)

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
v_offset = (
physical_block_idx * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ offs_t[:, None] * stride_v_cache_1
)

k_offset = (
physical_block_idx * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ offs_t[None, :] * stride_k_cache_1
)

K_load_mask = dim_mask[:, None]
V_load_mask = dim_mask[None, :]
seq_offset = j * BLOCK_SIZE + offs_t
else:
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)

K_load_mask = dim_mask[:, None] & tile_mask[None, :]
V_load_mask = dim_mask[None, :] & tile_mask[:, None]

# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
mask=K_load_mask,
other=0.0,
)

Expand All @@ -270,7 +296,7 @@ def kernel_unified_attention_2d(
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
mask=V_load_mask,
other=0.0,
)

Expand Down Expand Up @@ -586,31 +612,57 @@ def kernel_unified_attention_3d(
max(segm_idx * tiles_per_segment, tile_start),
min((segm_idx + 1) * tiles_per_segment, tile_end),
):
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
if TILE_SIZE == BLOCK_SIZE:
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j).to(
tl.int64
)

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
v_offset = (
physical_block_idx * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ offs_t[:, None] * stride_v_cache_1
)

k_offset = (
physical_block_idx * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ offs_t[None, :] * stride_k_cache_1
)

K_load_mask = dim_mask[:, None]
V_load_mask = dim_mask[None, :]
seq_offset = j * BLOCK_SIZE + offs_t
else:
seq_offset = j * TILE_SIZE + offs_t
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)

K_load_mask = dim_mask[:, None] & tile_mask[None, :]
V_load_mask = dim_mask[None, :] & tile_mask[:, None]

# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
mask=K_load_mask,
other=0.0,
)

Expand All @@ -625,7 +677,7 @@ def kernel_unified_attention_3d(
# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
mask=V_load_mask,
other=0.0,
)

Expand Down Expand Up @@ -971,6 +1023,14 @@ def unified_attention(
is_prefill=False,
)

if (
current_platform.is_rocm()
and current_platform.is_navi()
and q.element_size() >= 2
):
TILE_SIZE_PREFILL = block_size
TILE_SIZE_DECODE = block_size

# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
# 2. The batch includes at least one prefill request, or
Expand Down
Loading