162 changes: 137 additions & 25 deletions csrc/rocm/attention.cu
@@ -2410,6 +2410,45 @@ typedef struct _B8x16 {
_B8x8 xy[2];
} _B8x16;

// Convert 8 FP8 (e4m3) values to 8 f16/bf16 values for software dequant.
// Uses HIP runtime FP8 conversion (portable across gfx11/gfx12).
template <typename T>
__device__ __forceinline__ _B16x8 convert_b8x8_to_b16x8(const _B8x8 input) {
_B16x8 ret;
if constexpr (std::is_same<T, _Float16>::value) {
union {
uint4 u32x4;
_B16x8 b16x8;
} cvt;
cvt.u32x4 = vllm::fp8::vec_conversion<uint4, uint2>(input);
ret = cvt.b16x8;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
// Reuse vector fp8->half conversion, then convert 4 half2 packs to bf16x2.
union {
uint4 u32x4;
uint32_t u32[4];
} f16x8;
f16x8.u32x4 = vllm::fp8::vec_conversion<uint4, uint2>(input);
union {
__hip_bfloat162 bf16x2[4];
_B16x8 b16x8;
} cvt;
#pragma unroll
for (int i = 0; i < 4; i++) {
union {
uint32_t u32;
__half2 h2;
} half2_u;
half2_u.u32 = f16x8.u32[i];
cvt.bf16x2[i] = __float22bfloat162_rn(__half22float2(half2_u.h2));
}
ret = cvt.b16x8;
} else {
static_assert(false, "unsupported 16b dtype");
}
return ret;
}
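As a side note (not part of this diff): the value mapping the helper above reproduces is the standard OCP FP8 E4M3 encoding — 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities. A minimal host-side reference decode, purely illustrative and independent of the HIP vec_conversion path used above:

#include <cmath>
#include <cstdint>

// Illustrative sketch only (not part of the diff): scalar E4M3 -> float decode.
static float fp8_e4m3_to_float(uint8_t v) {
  const float sign = (v & 0x80) ? -1.0f : 1.0f;
  const int exp = (v >> 3) & 0xF;  // biased exponent (bias 7)
  const int man = v & 0x7;         // 3-bit mantissa
  if (exp == 0xF && man == 0x7) return NAN;                 // only NaN, no infinities
  if (exp == 0) return sign * std::ldexp((float)man, -9);   // subnormal: (man/8) * 2^-6
  return sign * std::ldexp(8.0f + (float)man, exp - 10);    // normal: (1 + man/8) * 2^(exp-7)
}

For example, 0x40 decodes to 2.0 and the largest finite encoding (0x7E) to 448; the kernel itself never decodes per byte — it converts 8 values at a time through the vec_conversion path above. The sketch only pins down the numeric mapping.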

template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx8 gcn_wmma16x16x16_instr(const bit16x8& inpA,
const bit16x8& inpB,
Expand Down Expand Up @@ -2566,8 +2605,12 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 2xQKHE_16B across
// warp

_B16x8 Qlocal[QKHELOOP]; // note that 16 contiguous elements of Q should
// be fetched per lane for 16 bit cache types
// Q loading always uses scalar_t-based constants (independent of cache type)
constexpr int Q_ELEMS_16B = 16 / sizeof(scalar_t); // always 8 for f16/bf16
constexpr int Q_PER_FETCH = Q_ELEMS_16B * ROWS_PER_WARP; // always 16
constexpr int QHELOOP = HEAD_SIZE / Q_PER_FETCH; // always HEAD_SIZE/16

_B16x8 Qlocal[QHELOOP];

constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t);
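For a quick sanity check of the Q-loading constants above (not part of this diff; assumes HEAD_SIZE = 128 and ROWS_PER_WARP = 2, which is what the "always 16" comment implies):

// Illustrative sketch: the constants above for a 2-byte scalar_t and a 128-wide head.
constexpr int kQElems16B = 16 / 2;            // 8 f16/bf16 values per 16-byte load
constexpr int kQPerFetch = kQElems16B * 2;    // 16 head elements per fetch across both rows
constexpr int kQheLoop   = 128 / kQPerFetch;  // 8 fetches cover the whole head
static_assert(kQElems16B == 8 && kQPerFetch == 16 && kQheLoop == 8, "");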

Expand Down Expand Up @@ -2598,12 +2641,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int local_qhead_idx = lane16id % GQA_RATIO;
const int global_qhead_idx = wg_start_head_idx + local_qhead_idx;
const scalar_t* q_ptr = q + query_start_off * q_stride +
global_qhead_idx * HEAD_SIZE +
rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD;
global_qhead_idx * HEAD_SIZE + rowid * Q_ELEMS_16B;
if (lane16id < GQA_RATIO) {
#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * QKHE_PER_FETCH;
for (int qkhe_depth = 0; qkhe_depth < QHELOOP; qkhe_depth++) {
const scalar_t* q_fetch_ptr = q_ptr + qkhe_depth * Q_PER_FETCH;
const _B16x8* q_fetch_ptr_16B =
reinterpret_cast<const _B16x8*>(q_fetch_ptr);
Qlocal[qkhe_depth] = *q_fetch_ptr_16B;
Expand Down Expand Up @@ -2632,7 +2674,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
__syncthreads();

#pragma unroll
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
for (int qkhe_depth = 0; qkhe_depth < QHELOOP; qkhe_depth++) {
Qlocal[qkhe_depth] =
shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO][0];
}
Expand Down Expand Up @@ -2739,16 +2781,56 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}

// calculate post qk wmma scale
float scale2 = scale;
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
scale2 *= *k_scale;
}

floatx8 dout[TLOOP];
// qk wmma
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
dout[token_depth] = {0};
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8,
dout[token_depth]);
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocal[token_depth][qkhe_depth].u16x8, Qlocal[qkhe_depth].u16x8,
dout[token_depth]);
} else {
// FP8 KV cache: each row loads 16 FP8 K values covering different
// head-element ranges (row 0: base+[0..15], row 1: base+[16..31]).
// But Q is loaded with 16 contiguous head elements per WMMA split
// across both rows (row 0: lower 8, row 1: upper 8).
// Splitting the 16 FP8 bytes into xy[0]/xy[1] by byte position
// creates a cross-row mismatch:
// j=0: row0=[0..7] OK, row1=[16..23] WRONG (Q expects [8..15])
// j=1: row0=[8..15] WRONG (Q expects [16..23]), row1=[24..31] OK
// Fix: exchange inner halves between rows via cross-row shuffle.
auto Ktmp = Klocal[token_depth][qkhe_depth];
_B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);

// Row 0 sends xy[1] (he 8..15), row 1 sends xy[0] (he 16..23).
// After shfl_xor: row 0 gets he 16..23, row 1 gets he 8..15.
const _B8x8 inner = Ktmp8x16.xy[1 - rowid];
_B8x8 cross;
cross.x = __shfl_xor(inner.x, 16);
cross.y = __shfl_xor(inner.y, 16);

#pragma unroll
for (int j = 0; j < 2; j++) {
// j==rowid: use own outer half; j!=rowid: use cross-row data.
// j=0: row0=xy[0](he 0..7), row1=cross(he 8..15) -> he [0..15]
// j=1: row0=cross(he 16..23), row1=xy[1](he 24..31) -> he [16..31]
_B8x8 Kfp8;
Kfp8.x = (j == rowid) ? Ktmp8x16.xy[j].x : cross.x;
Kfp8.y = (j == rowid) ? Ktmp8x16.xy[j].y : cross.y;
_B16x8 Kconv = convert_b8x8_to_b16x8<scalar_t>(Kfp8);
dout[token_depth] = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Kconv.u16x8, Qlocal[qkhe_depth * 2 + j].u16x8, dout[token_depth]);
}
}
}
dout[token_depth] *= scale;
dout[token_depth] *= scale2;
}
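The cross-row exchange above works because __shfl_xor with a lane mask of 16 pairs lane i with lane i ^ 16, i.e. the lane holding the same lane16id in the other row of a wave32. A standalone illustration (not part of this diff; assumes wave32 as on gfx11/gfx12):

// Illustrative sketch: __shfl_xor(x, 16) swaps values between the two 16-lane
// rows of a wave32; each lane reads from the lane with the same lane16id in
// the other row, which is the hand-off used for the FP8 K inner halves above.
__global__ void shfl_xor_rows_demo(int* out) {  // launch as a single block of 32 threads
  const int lane  = threadIdx.x & 31;
  const int rowid = lane / 16;                  // 0 for lanes 0..15, 1 for lanes 16..31
  const int got   = __shfl_xor(rowid, 16);      // row-0 lanes receive 1, row-1 lanes receive 0
  out[threadIdx.x] = got;
}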

// calculate qk_max and exp_sum per warp and write to shared memory
Expand Down Expand Up @@ -2833,22 +2915,50 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
_B16x8 outelems[VHELOOP];
// Softmax V wmma
// v layout: 16he across lanes x 16 tokens per lane
// VTLANELOOP_F16 is the number of 8-f16-element chunks per V fetch group.
// For f16 cache: VTLANELOOP=2, so it matches VTLANELOOP_F16 directly.
// For FP8 cache: VTLANELOOP=1 (one 16-byte fetch = 16 FP8 values),
// which we split into 2 halves of 8, so effective loop count is 2.
constexpr int VTLANELOOP_F16 = DIVIDE_ROUND_UP(
VTOKENS_PER_LANE, Q_ELEMS_16B); // always 2 for 16 vtokens / 8 per f16
for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) {
floatx8 tmp_out = {0};

for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int offset = rowid * VTLANELOOP + vfetch_depth;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
// if output format is 16 qheads across 16 lanes, 16 head elems spread
// across rows
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8,
shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8,
tmp_out);
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
const int offset = rowid * VTLANELOOP + vfetch_depth;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vlocal[vtoken_depth][vhe_depth][vfetch_depth].u16x8,
shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8,
tmp_out);
}
} else {
// FP8 KV cache: each Vlocal entry has 16 FP8 values (16 bytes).
// Split into two _B8x8, convert each to _B16x8, and do 2 WMMA calls.
for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) {
auto Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth];
_B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
for (int j = 0; j < 2; j++) {
_B16x8 Vconv = convert_b8x8_to_b16x8<scalar_t>(Vtmp8x16.xy[j]);
const int vf_idx = vfetch_depth * 2 + j;
const int offset = rowid * VTLANELOOP_F16 + vf_idx;
const int offset1 = offset % ROWS_PER_WARP;
const int offset2 = offset / ROWS_PER_WARP;
tmp_out = gcn_wmma16x16x16_instr<scalar_t, 0, 0, 0>(
Vconv.u16x8,
shared_logits[vtoken_depth][offset2][lane16id][offset1].u16x8,
tmp_out);
}
}
}
}
// apply post Softmax V wmma v_scale
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
tmp_out *= *v_scale;
}
outelems[vhe_depth] = from_floatx8<scalar_t>(tmp_out);
}
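With VTLANELOOP = 1 in the FP8 branch, vf_idx reduces to j, so offset = rowid * VTLANELOOP_F16 + vf_idx runs over 0..3 per lane pair exactly like the 16-bit branch's rowid * VTLANELOOP + vfetch_depth. A small compile-time enumeration of that mapping (not part of this diff; assumes ROWS_PER_WARP = 2 and VTLANELOOP_F16 = 2):

// Illustrative sketch: the FP8 V path visits each shared_logits (offset1, offset2)
// slot exactly once per lane pair, matching the 16-bit path's indexing.
static_assert([] {
  bool seen[2][2] = {};
  for (int rowid = 0; rowid < 2; ++rowid)
    for (int j = 0; j < 2; ++j) {
      const int offset  = rowid * 2 + j;  // rowid * VTLANELOOP_F16 + vf_idx
      const int offset1 = offset % 2;     // offset % ROWS_PER_WARP
      const int offset2 = offset / 2;     // offset / ROWS_PER_WARP
      if (seen[offset2][offset1]) return false;
      seen[offset2][offset1] = true;
    }
  return seen[0][0] && seen[0][1] && seen[1][0] && seen[1][1];
}(), "FP8 V path covers each shared_logits slot exactly once");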

Expand Down Expand Up @@ -3391,7 +3501,7 @@ void paged_attention_custom_launcher_navi(
torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
Expand Down Expand Up @@ -3421,8 +3531,10 @@ void paged_attention_custom_launcher_navi(

const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: Navi does not support fp8.
const auto fp8_out_scale_ptr = nullptr;
const auto fp8_out_scale_ptr =
fp8_out_scale
? reinterpret_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());

const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
Expand Down Expand Up @@ -3568,7 +3680,7 @@ void paged_attention_custom_launcher_navi(
ALIBI_ENABLED, MFMA_TYPE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
max_seq_len, alibi_slopes, k_scale, v_scale); \
max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
}

#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
18 changes: 10 additions & 8 deletions tests/kernels/attention/test_attention.py
@@ -141,14 +141,14 @@ def test_paged_attention(
):
pytest.skip()

if (
version == "rocm"
and current_platform.is_navi()
and (
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
)
):
pytest.skip()
if version == "rocm" and current_platform.is_navi():
# gfx12 (RDNA4) supports FP8 KV cache via software dequant;
# within is_navi(), supports_fp8() implies gfx12.
is_gfx12 = current_platform.supports_fp8()
fp8_unsupported = kv_cache_dtype == "fp8" and not is_gfx12
block_size_ok = block_size == 16 or (is_gfx12 and block_size == 32)
if fp8_unsupported or head_size != 128 or not block_size_ok or use_alibi:
pytest.skip()

global PARTITION_SIZE

Expand Down Expand Up @@ -196,6 +196,8 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0]

# Using default kv_scale
# NOTE: non-trivial k_scale/v_scale would exercise FP8 dequant paths but
# the reference computation does not apply scales, so keep at 1.0 for now.
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

# Call the paged attention kernel.
11 changes: 8 additions & 3 deletions vllm/platforms/rocm.py
@@ -146,6 +146,7 @@ def _get_gcn_arch() -> str:
_GCN_ARCH = _get_gcn_arch()

_ON_GFX1X = any(arch in _GCN_ARCH for arch in ["gfx11", "gfx12"])
_ON_GFX12 = "gfx12" in _GCN_ARCH
_ON_MI3XX = any(arch in _GCN_ARCH for arch in ["gfx942", "gfx950"])
_ON_GFX9 = any(arch in _GCN_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
_ON_GFX942 = "gfx942" in _GCN_ARCH
Expand Down Expand Up @@ -269,16 +270,20 @@ def use_rocm_custom_paged_attention(
)

else:
# gfx12 (RDNA4) supports FP8 KV cache via software dequant
fp8_ok = kv_cache_dtype in ("fp8", "fp8_e4m3") and _ON_GFX12
block_size_ok = block_size == 16 or (_ON_GFX12 and block_size == 32)
gqa_min = 1 if _ON_GFX12 else 3
return (
_ON_GFX1X
and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128
and block_size == 16
and (gqa_ratio >= 3 and gqa_ratio <= 16)
and block_size_ok
and (gqa_ratio >= gqa_min and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024
and alibi_slopes is None
and kv_cache_dtype == "auto"
and (kv_cache_dtype == "auto" or fp8_ok)
and sinks is None
)
