Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 73 additions & 32 deletions csrc/rocm/attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
Expand All @@ -291,6 +292,13 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int rowid = laneid / 16;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx]) != 1) {
return;
}

const auto partition_idx = blockIdx.y;

constexpr int T_PAR_SIZE = 256; // token partition size set to 256
Expand Down Expand Up @@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// fetch Q in shared across warps and then write to registers
const int local_qhead_idx = 4 * warpid + rowid;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const int64_t seq_idx64 = static_cast<int64_t>(seq_idx);
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE;
q + query_start_off * q_stride + global_qhead_idx * HEAD_SIZE;

const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B;
if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) {
Expand Down Expand Up @@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
Expand All @@ -794,6 +804,12 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4id = laneid % 4;

const auto seq_idx = blockIdx.x;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}
const auto partition_idx = blockIdx.y;
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
Expand Down Expand Up @@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}

// fetch q elements
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elemsc
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
q + query_start_off * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;

Expand Down Expand Up @@ -1267,10 +1285,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;

// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if (query_start_loc_ptr != nullptr &&
(query_start_loc_ptr[seq_idx + 1] - query_start_loc_ptr[seq_idx] != 1)) {
return;
}

const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
[[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
Expand Down Expand Up @@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;

OUTT* out_ptr = out + static_cast<int64_t>(seq_idx) * num_heads * HEAD_SIZE +
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
static_cast<int64_t>(head_idx) * HEAD_SIZE;
if constexpr (std::is_same<OUTT, bit8_t>::value) {
out_ptr[threadIdx.x] =
Expand All @@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
Expand All @@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
Expand All @@ -1515,41 +1546,42 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
UNREACHABLE_CODE
}
// clang-format on

#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma16_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
HEAD_SIZE, NTHR, ALIBI_ENABLED, \
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);

#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions);

template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
Expand All @@ -1559,16 +1591,24 @@ void paged_attention_custom_launcher(
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& k_scale, torch::Tensor& v_scale) {
int num_seqs = query.size(0);
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);

// NOTE: query start location is optional for V0 decode should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const int* query_start_loc_ptr =
query_start_loc
? reinterpret_cast<const int*>(query_start_loc.value().data_ptr())
: nullptr;

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
Expand Down Expand Up @@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
Expand Down Expand Up @@ -1750,6 +1790,7 @@ void paged_attention(
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
Expand Down
5 changes: 3 additions & 2 deletions csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
4 changes: 3 additions & 1 deletion csrc/rocm/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" Tensor context_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
Expand Down
2 changes: 1 addition & 1 deletion requirements/common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ tqdm
blake3
py-cpuinfo
transformers >= 4.50.3
huggingface-hub[hf-xet] >= 0.30.0 # Required for Xet downloads.
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
Expand Down
2 changes: 1 addition & 1 deletion requirements/test.in
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.50.3
huggingface-hub[hf-xet]>=0.30.0 # Required for Xet downloads.
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
buildkite-test-collector==0.1.9
Expand Down
4 changes: 4 additions & 0 deletions tests/kernels/test_prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
Expand All @@ -180,6 +181,7 @@ def test_contexted_kv_attention(
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
Expand Down Expand Up @@ -397,6 +399,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
Expand All @@ -413,6 +416,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
block_table,
b_start_loc,
b_seq_len,
MAX_CTX_LEN,
max_input_len,
k_scale,
v_scale,
Expand Down
6 changes: 4 additions & 2 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def paged_attention_rocm(
scale: float,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
query_start_loc: Optional[torch.Tensor],
block_size: int,
max_seq_len: int,
alibi_slopes: Optional[torch.Tensor],
Expand All @@ -120,8 +121,9 @@ def paged_attention_rocm(
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)


def mla_decode_kvcache_cpu(
Expand Down
24 changes: 6 additions & 18 deletions vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@
PagedAttentionMetadata)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 256
_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"])


class ROCmFlashAttentionBackend(AttentionBackend):
Expand Down Expand Up @@ -790,9 +787,9 @@ def forward(
num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads
use_custom = _use_rocm_custom_paged_attention(
use_custom = use_rocm_custom_paged_attention(
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len)
decode_meta.max_decode_seq_len, self.sliding_window)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
Expand All @@ -817,6 +814,8 @@ def forward(
out = output[num_prefill_tokens:]
else:
out = output

query_start_loc = None
ops.paged_attention_rocm(
out,
exp_sums,
Expand All @@ -833,6 +832,7 @@ def forward(
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
query_start_loc,
block_size,
max_seq_len,
self.alibi_slopes,
Expand Down Expand Up @@ -898,15 +898,3 @@ def _sdpa_attention(
start = end

return output


def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (_ON_MI250_MI300 and not _ON_NAVI
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
Loading