diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py
index 4ddd47c6dd65..c2ea2c025a0b 100644
--- a/vllm/v1/attention/ops/triton_unified_attention.py
+++ b/vllm/v1/attention/ops/triton_unified_attention.py
@@ -231,31 +231,57 @@ def kernel_unified_attention_2d(
 
     # iterate through tiles (now limited to the sliding window range)
     for j in range(tile_start, tile_end):
-        seq_offset = j * TILE_SIZE + offs_t
-        tile_mask = seq_offset < max_seq_prefix_len
-
-        physical_block_idx = tl.load(
-            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
-        ).to(tl.int64)
-
-        v_offset = (
-            physical_block_idx[:, None] * stride_v_cache_0
-            + kv_head_idx * stride_v_cache_2
-            + offs_d[None, :] * stride_v_cache_3
-            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
-        )
+        if TILE_SIZE == BLOCK_SIZE:
+            physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j).to(
+                tl.int64
+            )
 
-        k_offset = (
-            physical_block_idx[None, :] * stride_k_cache_0
-            + kv_head_idx * stride_k_cache_2
-            + offs_d[:, None] * stride_k_cache_3
-            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
-        )
+            v_offset = (
+                physical_block_idx * stride_v_cache_0
+                + kv_head_idx * stride_v_cache_2
+                + offs_d[None, :] * stride_v_cache_3
+                + offs_t[:, None] * stride_v_cache_1
+            )
+
+            k_offset = (
+                physical_block_idx * stride_k_cache_0
+                + kv_head_idx * stride_k_cache_2
+                + offs_d[:, None] * stride_k_cache_3
+                + offs_t[None, :] * stride_k_cache_1
+            )
+
+            K_load_mask = dim_mask[:, None]
+            V_load_mask = dim_mask[None, :]
+            seq_offset = j * BLOCK_SIZE + offs_t
+        else:
+            seq_offset = j * TILE_SIZE + offs_t
+            tile_mask = seq_offset < max_seq_prefix_len
+
+            physical_block_idx = tl.load(
+                block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
+            ).to(tl.int64)
+
+            v_offset = (
+                physical_block_idx[:, None] * stride_v_cache_0
+                + kv_head_idx * stride_v_cache_2
+                + offs_d[None, :] * stride_v_cache_3
+                + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
+            )
+
+            k_offset = (
+                physical_block_idx[None, :] * stride_k_cache_0
+                + kv_head_idx * stride_k_cache_2
+                + offs_d[:, None] * stride_k_cache_3
+                + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
+            )
+
+            K_load_mask = dim_mask[:, None] & tile_mask[None, :]
+            V_load_mask = dim_mask[None, :] & tile_mask[:, None]
 
         # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(
             key_cache_ptr + k_offset,
-            mask=dim_mask[:, None] & tile_mask[None, :],
+            mask=K_load_mask,
             other=0.0,
         )
 
@@ -270,7 +296,7 @@ def kernel_unified_attention_2d(
         # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(
             value_cache_ptr + v_offset,
-            mask=dim_mask[None, :] & tile_mask[:, None],
+            mask=V_load_mask,
             other=0.0,
         )
 
@@ -586,31 +612,57 @@ def kernel_unified_attention_3d(
         max(segm_idx * tiles_per_segment, tile_start),
         min((segm_idx + 1) * tiles_per_segment, tile_end),
     ):
-        seq_offset = j * TILE_SIZE + offs_t
-        tile_mask = seq_offset < max_seq_prefix_len
-
-        physical_block_idx = tl.load(
-            block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
-        ).to(tl.int64)
-
-        v_offset = (
-            physical_block_idx[:, None] * stride_v_cache_0
-            + kv_head_idx * stride_v_cache_2
-            + offs_d[None, :] * stride_v_cache_3
-            + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
-        )
+        if TILE_SIZE == BLOCK_SIZE:
+            physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j).to(
+                tl.int64
+            )
 
-        k_offset = (
-            physical_block_idx[None, :] * stride_k_cache_0
-            + kv_head_idx * stride_k_cache_2
-            + offs_d[:, None] * stride_k_cache_3
-            + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
-        )
+            v_offset = (
+                physical_block_idx * stride_v_cache_0
+                + kv_head_idx * stride_v_cache_2
+                + offs_d[None, :] * stride_v_cache_3
+                + offs_t[:, None] * stride_v_cache_1
+            )
+
+            k_offset = (
+                physical_block_idx * stride_k_cache_0
+                + kv_head_idx * stride_k_cache_2
+                + offs_d[:, None] * stride_k_cache_3
+                + offs_t[None, :] * stride_k_cache_1
+            )
+
+            K_load_mask = dim_mask[:, None]
+            V_load_mask = dim_mask[None, :]
+            seq_offset = j * BLOCK_SIZE + offs_t
+        else:
+            seq_offset = j * TILE_SIZE + offs_t
+            tile_mask = seq_offset < max_seq_prefix_len
+
+            physical_block_idx = tl.load(
+                block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
+            ).to(tl.int64)
+
+            v_offset = (
+                physical_block_idx[:, None] * stride_v_cache_0
+                + kv_head_idx * stride_v_cache_2
+                + offs_d[None, :] * stride_v_cache_3
+                + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
+            )
+
+            k_offset = (
+                physical_block_idx[None, :] * stride_k_cache_0
+                + kv_head_idx * stride_k_cache_2
+                + offs_d[:, None] * stride_k_cache_3
+                + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
+            )
+
+            K_load_mask = dim_mask[:, None] & tile_mask[None, :]
+            V_load_mask = dim_mask[None, :] & tile_mask[:, None]
 
         # K : (HEAD_SIZE, TILE_SIZE)
        K_load = tl.load(
             key_cache_ptr + k_offset,
-            mask=dim_mask[:, None] & tile_mask[None, :],
+            mask=K_load_mask,
             other=0.0,
         )
 
@@ -625,7 +677,7 @@ def kernel_unified_attention_3d(
         # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(
             value_cache_ptr + v_offset,
-            mask=dim_mask[None, :] & tile_mask[:, None],
+            mask=V_load_mask,
             other=0.0,
         )
 
@@ -971,6 +1023,14 @@ def unified_attention(
         is_prefill=False,
     )
 
+    if (
+        current_platform.is_rocm()
+        and current_platform.is_navi()
+        and q.element_size() >= 2
+    ):
+        TILE_SIZE_PREFILL = block_size
+        TILE_SIZE_DECODE = block_size
+
     # Launch the 2D kernel if
     # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
     # 2. The batch includes at least one prefill request, or
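
The common thread in both kernel hunks: when `TILE_SIZE == BLOCK_SIZE`, every element of a tile falls inside the same physical KV-cache block, so the per-element block-table gather and the `% BLOCK_SIZE` arithmetic collapse to a single scalar `tl.load` plus plain `offs_t` indexing, and the K/V load masks drop the `tile_mask` term. Below is a minimal NumPy sketch of that addressing identity; it is illustrative only, not vLLM code, and `block_table` is a hypothetical logical-to-physical block map:

```python
# Illustrative sketch (not vLLM code) of the identity the new
# TILE_SIZE == BLOCK_SIZE fast path relies on.
import numpy as np

BLOCK_SIZE = 16
TILE_SIZE = BLOCK_SIZE                   # the case the new branch handles
block_table = np.random.permutation(64)  # hypothetical logical->physical map

for j in range(4):                       # a few tiles
    offs_t = np.arange(TILE_SIZE)
    seq_offset = j * TILE_SIZE + offs_t

    # generic path: a block-table lookup and a modulo per element
    gathered = block_table[seq_offset // BLOCK_SIZE]
    in_block = seq_offset % BLOCK_SIZE

    # fast path: one scalar lookup per tile, in-block offsets are just offs_t
    assert (gathered == block_table[j]).all()
    assert (in_block == offs_t).all()
```

The final hunk then pins `TILE_SIZE_PREFILL` and `TILE_SIZE_DECODE` to `block_size` on ROCm Navi for dtypes of at least two bytes, presumably so both kernels always take this fast path on that platform.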