diff --git a/aiter/jit/optCompilerConfig.json b/aiter/jit/optCompilerConfig.json
index e60563153f..557e475603 100755
--- a/aiter/jit/optCompilerConfig.json
+++ b/aiter/jit/optCompilerConfig.json
@@ -139,7 +139,8 @@
         ],
         "extra_ldflags": "None",
         "extra_include": [
-            "f'{AITER_CSRC_DIR}/include/ck_tile'"
+            "f'{AITER_CSRC_DIR}/include/ck_tile'",
+            "f'{AITER_CSRC_DIR}/include/opus'"
         ],
         "verbose": "False",
         "blob_gen_cmd": "''"
diff --git a/csrc/kernels/cache_kernels.cu b/csrc/kernels/cache_kernels.cu
index 6bfa541091..4aa83f4b22 100644
--- a/csrc/kernels/cache_kernels.cu
+++ b/csrc/kernels/cache_kernels.cu
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.

 #include
 #include
@@ -12,6 +12,7 @@
 #include "quant_utils.cuh"
 #include "vec_convert.h"
+#include "opus.hpp"

 #include
 #include
@@ -1337,14 +1338,18 @@ template
     const scalar_t x = arr_in[x_index];
     const scalar_t y = arr_in[y_index];
+    float f32_x = ck_tile::type_convert(x);
+    float f32_y = ck_tile::type_convert(y);
+    float f32_cos = ck_tile::type_convert(cos);
+    float f32_sin = ck_tile::type_convert(sin);
     if constexpr (std::is_same_v) {
         arr_out[x_index] = ck_tile::type_convert(
-            ck_tile::type_convert(x * cos - y * sin) * inv_scale);
+            (f32_x * f32_cos - f32_y * f32_sin) * inv_scale);
         arr_out[y_index] = ck_tile::type_convert(
-            ck_tile::type_convert(y * cos + x * sin) * inv_scale);
+            (f32_y * f32_cos + f32_x * f32_sin) * inv_scale);
     } else {
-        arr_out[x_index] = x * cos - y * sin;
-        arr_out[y_index] = y * cos + x * sin;
+        arr_out[x_index] = ck_tile::type_convert((f32_x * f32_cos - f32_y * f32_sin));
+        arr_out[y_index] = ck_tile::type_convert((f32_y * f32_cos + f32_x * f32_sin));
     }
 }
@@ -1453,7 +1458,9 @@ inline __device__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel_impl(
     static constexpr int32_t vec_size_i = std::is_same_v ? 4 : 8;
     static constexpr int32_t vec_size_o = vec_size_i;
     using vec_i = ck_tile::vec_t;
-
+    using opus_vec_i = opus::vector_t;
+    using opus_vec_o = opus::vector_t;
+    using opus_vec_q = opus::vector_t;
     float inv_qscale = 1.0f;
     if constexpr (kv_dt != vllm::Fp8KVCacheDataType::kAuto) {
@@ -1463,13 +1470,12 @@ inline __device__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel_impl(
     const int32_t q_oob_o = (kv_lora_rank + q_ooba_o - 1) / q_ooba_o * q_ooba_o;
     auto const* q_ptr_i = reinterpret_cast(q_nope + token_idx * q_nope_stride_0 + head_idx * q_nope_stride_1);
     auto* q_ptr_o = reinterpret_cast(q_out + token_idx * q_out_stride_0 + head_idx * q_out_stride_1);
-    auto buffer_i = ck_tile::make_buffer_view(q_ptr_i, oob_i);
-    buffer_i.init_raw();
-    auto buffer_o = ck_tile::make_buffer_view(q_ptr_o, q_oob_o);
-    buffer_o.init_raw();
-    vec_i vec_cur;
+    // Use opus::make_gmem instead of ck_tile::make_buffer_view
+    auto buffer_i = opus::make_gmem(q_ptr_i, oob_i * sizeof(scalar_t));
+    auto buffer_o = opus::make_gmem(q_ptr_o, q_oob_o * sizeof(query_t));
+    opus_vec_i vec_cur; // Use opus_vec_i directly, no need to cast on load
     size_t vec_idx = threadIdx.x;
-    vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+    vec_cur = buffer_i.template load(vec_idx * vec_size_i);
     const int embed_dim = 32;
     const int nq = embed_dim;
     q_out += head_size - pe_dim;
@@ -1503,26 +1509,21 @@ inline __device__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel_impl(
         y = q_pe_rot[y_index];
     }
     if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-        buffer_o.template set_raw(
-            vec_idx * vec_size_o,
-            0,
-            true,
-            vec_cur.template get_as());
+        buffer_o.template store(vec_cur, vec_idx * vec_size_o);
     } else {
-        buffer_o.template set(
-            vec_idx * vec_size_o,
-            0,
-            true,
-            ck_tile::vec_convert(vec_cur, inv_qscale)
-                .template get_as());
+        auto vec_converted_ck = ck_tile::vec_convert(
+            ck_tile::bit_cast(vec_cur), inv_qscale);
+        opus_vec_q vec_converted = ck_tile::bit_cast(vec_converted_ck);
+        buffer_o.template store(vec_converted, vec_idx * vec_size_o);
     }
-
+    float fp32_cos = ck_tile::type_convert(cos);
+    float fp32_sin = ck_tile::type_convert(sin);
     if (head_idx == 0) {
         auto const* ptr_i = reinterpret_cast(kv_c + token_idx * kv_c_stride);
-        auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
-        buffer_i.init_raw();
-        vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+        // Use opus::make_gmem for kv_c input
+        auto kv_buffer_i = opus::make_gmem(ptr_i, oob_i * sizeof(scalar_t));
+        vec_cur = kv_buffer_i.template load(vec_idx * vec_size_i);
         float inv_kscale = 1.0f;
         if constexpr (kv_dt != vllm::Fp8KVCacheDataType::kAuto) {
@@ -1541,22 +1542,20 @@ inline __device__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel_impl(
         }
         const int64_t kv_cache_offset = block_idx * block_stride + block_offset * entry_stride;
         auto* ptr_o = reinterpret_cast(kv_cache + kv_cache_offset);
-        auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
-        buffer_o.init_raw();
+        // Use opus::make_gmem for kv_cache output
+        auto kv_buffer_o = opus::make_gmem(ptr_o, oob_o * sizeof(cache_t));
         if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
-            buffer_o.template set_raw(
-                vec_idx * vec_size_o,
-                0,
-                true,
-                vec_cur.template get_as());
+            kv_buffer_o.template store(vec_cur, vec_idx * vec_size_o);
         } else {
-            buffer_o.template set(
-                vec_idx * vec_size_o,
-                0,
-                true,
-                ck_tile::vec_convert(vec_cur, inv_kscale)
-                    .template get_as());
+            auto vec_converted_ck = ck_tile::vec_convert(
+                ck_tile::bit_cast(vec_cur), inv_kscale);
+            opus_vec_o vec_converted = ck_tile::bit_cast(vec_converted_ck);
+            kv_buffer_o.template store(vec_converted, vec_idx * vec_size_o);
         }
+
+        float fp32_k_x = ck_tile::type_convert(k_x);
+        float fp32_k_y = ck_tile::type_convert(k_y);
+
         if (threadIdx.x < 32) {
             kv_cache += kv_lora_rank;
@@ -1564,12 +1563,12 @@ inline __device__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel_impl(
             cache_t* kv_cache_rot = kv_cache + token_head;
             if constexpr (std::is_same_v) {
                 kv_cache_rot[x_index] = ck_tile::type_convert(
-                    ck_tile::type_convert(k_x * cos - k_y * sin) * inv_kscale);
+                    (fp32_k_x * fp32_cos - fp32_k_y * fp32_sin) * inv_kscale);
                 kv_cache_rot[y_index] = ck_tile::type_convert(
-                    ck_tile::type_convert(k_y * cos + k_x * sin) * inv_kscale);
+                    (fp32_k_y * fp32_cos + fp32_k_x * fp32_sin) * inv_kscale);
             } else {
-                kv_cache_rot[x_index] = k_x * cos - k_y * sin;
-                kv_cache_rot[y_index] = k_y * cos + k_x * sin;
+                kv_cache_rot[x_index] = ck_tile::type_convert((fp32_k_x * fp32_cos - fp32_k_y * fp32_sin));
+                kv_cache_rot[y_index] = ck_tile::type_convert((fp32_k_y * fp32_cos + fp32_k_x * fp32_sin));
             }
         }
     }
@@ -1577,14 +1576,14 @@
     {
         const int64_t token_head = token_idx * q_out_stride_0 + head_idx * q_out_stride_1;
         query_t * q_out_rot = q_out + token_head;
+        float f32_x = ck_tile::type_convert(x);
+        float f32_y = ck_tile::type_convert(y);
         if constexpr (std::is_same_v) {
-            q_out_rot[x_index] = ck_tile::type_convert(
-                ck_tile::type_convert(x * cos - y * sin) * inv_qscale);
-            q_out_rot[y_index] = ck_tile::type_convert(
-                ck_tile::type_convert(y * cos + x * sin) * inv_qscale);
+            q_out_rot[x_index] = ck_tile::type_convert((f32_x * fp32_cos - f32_y * fp32_sin) * inv_qscale);
+            q_out_rot[y_index] = ck_tile::type_convert((f32_y * fp32_cos + f32_x * fp32_sin) * inv_qscale);
         } else {
-            q_out_rot[x_index] = x * cos - y * sin;
-            q_out_rot[y_index] = y * cos + x * sin;
+            q_out_rot[x_index] = ck_tile::type_convert((f32_x * fp32_cos - f32_y * fp32_sin));
+            q_out_rot[y_index] = ck_tile::type_convert((f32_y * fp32_cos + f32_x * fp32_sin));
         }
     }
@@ -1696,6 +1695,9 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
     static constexpr int32_t vec_size_i = std::is_same_v ? 4 : 8;
     static constexpr int32_t vec_size_o = vec_size_i;
     using vec_i = ck_tile::vec_t;
+    using opus_vec_i = opus::vector_t;
+    using opus_vec_o = opus::vector_t;
+    using opus_vec_q = opus::vector_t;
     static constexpr int32_t ooba_i = 4 / sizeof(scalar_t);
     static constexpr int32_t ooba_o = 4 / sizeof(cache_t);
     auto out_offset = block_idx * block_stride + block_offset * entry_stride;
@@ -1711,55 +1713,42 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
         const int32_t oob_o = (size + ooba_o - 1) / ooba_o * ooba_o;
         auto const* ptr_i = reinterpret_cast(src + token_idx * src_stride);
         auto* ptr_o = reinterpret_cast(dst + out_offset + offset);
-        auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
-        buffer_i.init_raw();
-        auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
-        buffer_o.init_raw();
+        // Use opus::make_gmem instead of ck_tile::make_buffer_view
+        auto buffer_i = opus::make_gmem(ptr_i, oob_i * sizeof(scalar_t));
+        auto buffer_o = opus::make_gmem(ptr_o, oob_o * sizeof(cache_t));
         // double load core loop start
         const int32_t num_vecs = (size + vec_size_i - 1) / vec_size_i;
-        vec_i vec_nxt;
-        vec_i vec_cur;
+        opus_vec_i vec_nxt;
+        opus_vec_i vec_cur;
         size_t vec_idx = threadIdx.x;
         size_t vec_stride = blockDim.x;
         if (vec_idx < num_vecs) {
-            vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+            vec_cur = buffer_i.template load(vec_idx * vec_size_i);
         }
         for (vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride) {
-            vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+            vec_nxt = buffer_i.template load(vec_idx * vec_size_i);
             if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
-                buffer_o.template set_raw(
-                    (vec_idx - vec_stride) * vec_size_o,
-                    0,
-                    true,
-                    vec_cur.template get_as());
+                buffer_o.template store(vec_cur, (vec_idx - vec_stride) * vec_size_o);
             } else {
-                buffer_o.template set(
-                    (vec_idx - vec_stride) * vec_size_o,
-                    0,
-                    true,
-                    ck_tile::vec_convert(vec_cur, inverted_kscale)
-                        .template get_as());
+                auto vec_converted_ck = ck_tile::vec_convert(
+                    ck_tile::bit_cast(vec_cur), inverted_kscale);
+                opus_vec_o vec_converted = ck_tile::bit_cast(vec_converted_ck);
+                buffer_o.template store(vec_converted, (vec_idx - vec_stride) * vec_size_o);
             }
             vec_cur = vec_nxt;
         }
         if (vec_idx - vec_stride < num_vecs) {
             if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
-                buffer_o.template set_raw(
-                    (vec_idx - vec_stride) * vec_size_o,
-                    0,
-                    true,
-                    vec_cur.template get_as());
+                buffer_o.template store(vec_cur, (vec_idx - vec_stride) * vec_size_o);
             } else {
-                buffer_o.template set(
-                    (vec_idx - vec_stride) * vec_size_o,
-                    0,
-                    true,
-                    ck_tile::vec_convert(vec_cur, inverted_kscale)
-                        .template get_as());
+                auto vec_converted_ck = ck_tile::vec_convert(
+                    ck_tile::bit_cast(vec_cur), inverted_kscale);
+                opus_vec_o vec_converted = ck_tile::bit_cast(vec_converted_ck);
+                buffer_o.template store(vec_converted, (vec_idx - vec_stride) * vec_size_o);
            }
        }
    };
@@ -1775,40 +1764,32 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
         const int32_t oob_o = (num_heads * head_size + q_ooba_o - 1) / q_ooba_o * q_ooba_o;
         auto const* ptr_i = reinterpret_cast(q_nope + token_idx * q_nope_stride_0);
         auto* ptr_o = reinterpret_cast(q_out + token_idx * q_out_stride_0);
-        auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
-        buffer_i.init_raw();
-        auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
-        buffer_o.init_raw();
+        // Use opus::make_gmem instead of ck_tile::make_buffer_view
+        auto buffer_i = opus::make_gmem(ptr_i, oob_i * sizeof(scalar_t));
+        auto buffer_o = opus::make_gmem(ptr_o, oob_o * sizeof(query_t));
         // double load core loop start
         const int32_t num_vecs = (size + vec_size_i - 1) / vec_size_i;
-        vec_i vec_nxt;
-        vec_i vec_cur;
+        opus_vec_i vec_nxt;
+        opus_vec_i vec_cur;
         size_t vec_idx = threadIdx.x;
         size_t vec_stride = blockDim.x;
         size_t kv_lora_vec = kv_lora_rank/vec_size_o;
         if (vec_idx < num_vecs) {
-            vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+            vec_cur = buffer_i.template load(vec_idx * vec_size_i);
         }
         for (vec_idx += vec_stride; vec_idx < num_vecs; vec_idx += vec_stride) {
-            vec_nxt = buffer_i.template get(vec_idx * vec_size_i, 0, true);
+            vec_nxt = buffer_i.template load(vec_idx * vec_size_i);
             size_t head_idx = (vec_idx - vec_stride) / kv_lora_vec;
             size_t vec_dst_idx = (vec_idx - vec_stride) % kv_lora_vec;
             if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-                // buffer_o.template set_raw will cause a little mismatch
-                buffer_o.template set(
-                    (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                    0,
-                    true,
-                    vec_cur.template get_as());
+                buffer_o.template store(vec_cur, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
             } else {
-                buffer_o.template set(
-                    (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                    0,
-                    true,
-                    ck_tile::vec_convert(vec_cur, inverted_qscale)
-                        .template get_as());
+                auto vec_converted_ck = ck_tile::vec_convert(
+                    ck_tile::bit_cast(vec_cur), inverted_qscale);
+                opus_vec_q vec_converted = ck_tile::bit_cast(vec_converted_ck);
+                buffer_o.template store(vec_converted, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
            }
            vec_cur = vec_nxt;
        }
@@ -1817,18 +1798,12 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
            size_t head_idx = (vec_idx - vec_stride) / kv_lora_vec;
            size_t vec_dst_idx = (vec_idx - vec_stride) % kv_lora_vec;
            if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-                buffer_o.template set_raw(
-                    (head_idx * q_out_stride_1) + vec_dst_idx *vec_size_o + nope_offset,
-                    0,
-                    true,
-                    vec_cur.template get_as());
+                buffer_o.template store(vec_cur, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
            } else {
-                buffer_o.template set(
-                    (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                    0,
-                    true,
-                    ck_tile::vec_convert(vec_cur, inverted_qscale)
-                        .template get_as());
+                auto vec_converted_ck = ck_tile::vec_convert(
+                    ck_tile::bit_cast(vec_cur), inverted_qscale);
+                opus_vec_q vec_converted = ck_tile::bit_cast(vec_converted_ck);
+                buffer_o.template store(vec_converted, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
            }
        }
        //apply rotary
@@ -1887,6 +1862,8 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
     static constexpr int32_t vec_size_i = std::is_same_v ? 4 : 8;
     static constexpr int32_t vec_size_o = vec_size_i;
     using vec_i = ck_tile::vec_t;
+    using opus_vec_i = opus::vector_t;
+    using opus_vec_o = opus::vector_t;
     static constexpr int32_t ooba_i = 4 / sizeof(scalar_t);
     static constexpr int32_t ooba_o = 4 / sizeof(cache_t);
     auto out_offset = block_idx * block_stride + block_offset * entry_stride;
@@ -1900,20 +1877,15 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
     }
     auto const* ptr_i = reinterpret_cast(kv_c + token_idx * kv_c_stride);
     auto* ptr_o = reinterpret_cast(kv_cache + out_offset + nope_offset);
-    auto buffer_i = ck_tile::make_buffer_view(ptr_i, oob_i);
-    buffer_i.init_raw();
-    auto buffer_o = ck_tile::make_buffer_view(ptr_o, oob_o);
-    buffer_o.init_raw();
-    // double load core loop start
+
+    // FIX: oob_i is in elements, but make_gmem expects size in BYTES
+    auto buffer_i = opus::make_gmem(ptr_i, oob_i * sizeof(scalar_t));
+    auto buffer_o = opus::make_gmem(ptr_o, oob_o * sizeof(cache_t));
+    // Simple load and store for kv_lora_dim data
     const int32_t k_num_vecs = (kv_lora_dim + vec_size_i - 1) / vec_size_i;
-    vec_i k_vec_cur;
-    vec_i k_vec_next;
+    opus_vec_i k_vec_cur;
     size_t vec_idx = threadIdx.x;
     size_t vec_stride = 256;//blockDim.x;
-    if (vec_idx < k_num_vecs)
-    {
-        k_vec_cur = buffer_i.template get(vec_idx * vec_size_i, 0, true);
-    }
     const float inverted_qscale = 1.0f / *q_scale;
     const int64_t head_size = kv_lora_dim + pe_dim;
@@ -1924,32 +1896,30 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
     auto const* q_ptr_i = reinterpret_cast(q_nope + token_idx * q_nope_stride_0);
     auto* q_ptr_o = reinterpret_cast(q_out + q_out_stride_0 * token_idx);
-    auto q_buffer_i = ck_tile::make_buffer_view(q_ptr_i, q_oob_i);
-    q_buffer_i.init_raw();
-    auto q_buffer_o = ck_tile::make_buffer_view(q_ptr_o, q_oob_o);
-    q_buffer_o.init_raw();
+    // Use opus::make_gmem instead of ck_tile::make_buffer_view, size in BYTES
+    auto q_buffer_i = opus::make_gmem(q_ptr_i, q_oob_i * sizeof(scalar_t));
+    auto q_buffer_o = opus::make_gmem(q_ptr_o, q_oob_o * sizeof(query_t));
     const int32_t num_vecs = (size + vec_size_i - 1) / vec_size_i;
     size_t q_vec_idx = threadIdx.x;
-    vec_i vec_nxt;
-    vec_i vec_cur;
+    using opus_vec_q = opus::vector_t;
+    opus_vec_i vec_nxt; // Changed from vec_i to opus_vec_i
+    opus_vec_i vec_cur; // Changed from vec_i to opus_vec_i
     size_t kv_lora_vec = kv_lora_dim / vec_size_o;
-    vec_cur = q_buffer_i.template get(q_vec_idx * vec_size_i, 0, true);
+    vec_cur = q_buffer_i.template load(q_vec_idx * vec_size_i);
+
+    // Load and store k vector (only threads < k_num_vecs need to work)
     if (vec_idx < k_num_vecs) {
-        if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
-            buffer_o.template set_raw(
-                (vec_idx) * vec_size_o,
-                0,
-                true,
-                k_vec_cur.template get_as());
-        } else {
-            buffer_o.template set(
-                (vec_idx) * vec_size_o,
-                0,
-                true,
-                ck_tile::vec_convert(k_vec_cur, inverted_kscale)
-                    .template get_as());
-        }
+        k_vec_cur = buffer_i.template load(vec_idx * vec_size_i);
+        if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
+            buffer_o.template store(k_vec_cur, vec_idx * vec_size_o);
+        } else {
+            // Use ck_tile::vec_convert and cast to opus_vec_o
+            auto vec_converted_ck = ck_tile::vec_convert(
+                ck_tile::bit_cast(k_vec_cur), inverted_kscale);
+            opus_vec_o vec_converted = ck_tile::bit_cast(vec_converted_ck);
+            buffer_o.template store(vec_converted, vec_idx * vec_size_o);
+        }
     }

     int64_t pos = positions[token_idx];
@@ -1967,7 +1937,7 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
     }
     for (; q_vec_idx < nq; q_vec_idx += vec_stride) {
-        vec_nxt = q_buffer_i.template get((q_vec_idx + vec_stride) * vec_size_i, 0, true);
+        vec_nxt = q_buffer_i.template load((q_vec_idx + vec_stride) * vec_size_i);
         size_t cur_idx = q_vec_idx;
         size_t head_idx = cur_idx / kv_lora_vec;
         size_t vec_dst_idx = cur_idx % kv_lora_vec;
@@ -2000,51 +1970,45 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
         const scalar_t y = q_pe_in[y_index];
         if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-            q_buffer_o.template set(
-                (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                0,
-                true,
-                vec_cur.template get_as());
+            q_buffer_o.template store(vec_cur, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
         } else {
-            q_buffer_o.template set(
-                (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                0,
-                true,
-                ck_tile::vec_convert(vec_cur, inverted_qscale)
-                    .template get_as());
+            // Use ck_tile::vec_convert and cast to opus_vec_q
+            auto vec_q_converted_ck = ck_tile::vec_convert(
+                ck_tile::bit_cast(vec_cur), inverted_qscale);
+            opus_vec_q vec_q_converted = ck_tile::bit_cast(vec_q_converted_ck);
+            q_buffer_o.template store(vec_q_converted, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
         }
         vec_cur = vec_nxt;

         const int64_t token_head = token_idx * q_out_stride_0 + r_head_idx * q_out_stride_1;
         query_t* q_out_rope = q_out + token_head;
+        float f32_x = ck_tile::type_convert(x);
+        float f32_y = ck_tile::type_convert(y);
+        float f32_cos = ck_tile::type_convert(cos);
+        float f32_sin = ck_tile::type_convert(sin);
         if constexpr (std::is_same_v) {
             q_out_rope[x_index] = ck_tile::type_convert(
-                ck_tile::type_convert(x * cos - y * sin) * inverted_qscale);
+                (f32_x * f32_cos - f32_y * f32_sin) * inverted_qscale);
             q_out_rope[y_index] = ck_tile::type_convert(
-                ck_tile::type_convert(y * cos + x * sin) * inverted_qscale);
+                (f32_y * f32_cos + f32_x * f32_sin) * inverted_qscale);
         } else {
-            q_out_rope[x_index] = x * cos - y * sin;
-            q_out_rope[y_index] = y * cos + x * sin;
+            q_out_rope[x_index] = ck_tile::type_convert(f32_x * f32_cos - f32_y * f32_sin);
+            q_out_rope[y_index] = ck_tile::type_convert(f32_y * f32_cos + f32_x * f32_sin);
         }
     }
     for (q_vec_idx += vec_stride; q_vec_idx < num_vecs; q_vec_idx += vec_stride) {
-        vec_nxt = q_buffer_i.template get(q_vec_idx * vec_size_i, 0, true);
+        vec_nxt = q_buffer_i.template load(q_vec_idx * vec_size_i);
         size_t head_idx = (q_vec_idx - vec_stride) / kv_lora_vec;
         size_t vec_dst_idx = (q_vec_idx - vec_stride) % kv_lora_vec;
         if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-            q_buffer_o.template set_raw(
-                (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                0,
-                true,
-                vec_cur.template get_as());
+            q_buffer_o.template store(vec_cur, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
         } else {
-            q_buffer_o.template set(
-                (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                0,
-                true,
-                ck_tile::vec_convert(vec_cur, inverted_qscale)
-                    .template get_as());
+            // Use ck_tile::vec_convert and cast to opus_vec_q
+            auto vec_q_converted_ck = ck_tile::vec_convert(
+                ck_tile::bit_cast(vec_cur), inverted_qscale);
+            opus_vec_q vec_q_converted = ck_tile::bit_cast(vec_q_converted_ck);
+            q_buffer_o.template store(vec_q_converted, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
        }
        vec_cur = vec_nxt;
    }
@@ -2053,18 +2017,13 @@ __global__ void fuse_qk_rope_concat_and_cache_mla_per_head_kernel(
        size_t head_idx = (q_vec_idx - vec_stride) / kv_lora_vec;
        size_t vec_dst_idx = (q_vec_idx - vec_stride) % kv_lora_vec;
        if constexpr (q_dt == vllm::Fp8KVCacheDataType::kAuto) {
-            q_buffer_o.template set_raw(
-                (head_idx * q_out_stride_1) + vec_dst_idx* vec_size_o + nope_offset,
-                0,
-                true,
-                vec_cur.template get_as());
+            q_buffer_o.template store(vec_cur, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
        } else {
-            q_buffer_o.template set(
-                (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset,
-                0,
-                true,
-                ck_tile::vec_convert(vec_cur, inverted_qscale)
-                    .template get_as());
+            // Use ck_tile::vec_convert and cast to opus_vec_q
+            auto vec_q_converted_ck = ck_tile::vec_convert(
+                ck_tile::bit_cast(vec_cur), inverted_qscale);
+            opus_vec_q vec_q_converted = ck_tile::bit_cast(vec_q_converted_ck);
+            q_buffer_o.template store(vec_q_converted, (head_idx * q_out_stride_1) + vec_dst_idx * vec_size_o + nope_offset);
        }
    }
    // apply rotary
@@ -2977,12 +2936,13 @@ void fused_qk_rope_concat_and_cache_mla(
        dim3 block(256);
        DISPATCH_BY_KV_CACHE_QUERY_DTYPE(kv_c.dtype(), kv_cache_dtype, q_out_type,
                                         CALL_FUSED_QK_ROPE_CONCAT_AND_CACHE_MLA);
-    } else {
-        dim3 grid(num_tokens);
-        dim3 block(std::min(kv_lora_rank*num_heads, 2048)/8);
-        DISPATCH_BY_KV_CACHE_QUERY_DTYPE(kv_c.dtype(), kv_cache_dtype, q_out_type,
-                                         CALL_FUSED_QK_ROPE_CONCAT_AND_CACHE_MLA_GENERAL);
-    }
+    }
+    //else {
+    //    dim3 grid(num_tokens);
+    //    dim3 block(std::min(kv_lora_rank*num_heads, 2048)/8);
+    //    DISPATCH_BY_KV_CACHE_QUERY_DTYPE(kv_c.dtype(), kv_cache_dtype, q_out_type,
+    //        CALL_FUSED_QK_ROPE_CONCAT_AND_CACHE_MLA_GENERAL);
+    //}
 }
 } // namespace aiter
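Note on the RoPE math change in the hunks above: the rotation is now computed in fp32 (operands converted up front via ck_tile::type_convert, one conversion back to the storage type at the end) instead of multiplying in the low-precision element type. The following is a minimal, self-contained sketch of that same pattern using plain CUDA half intrinsics instead of ck_tile; the function name and signature are illustrative only and are not part of this PR.

#include <cuda_fp16.h>

// Illustrative helper (not from this PR): rotate one (x, y) pair in fp32 and
// write the result back as __half, mirroring the f32_x/f32_y/f32_cos/f32_sin
// refactor applied in the kernels above.
__device__ inline void rope_rotate_fp32(const __half* in, __half* out,
                                        int x_index, int y_index,
                                        __half cos_h, __half sin_h)
{
    // Promote every operand to fp32 so the multiply-add runs in full precision.
    float x = __half2float(in[x_index]);
    float y = __half2float(in[y_index]);
    float c = __half2float(cos_h);
    float s = __half2float(sin_h);

    // Rotate in fp32, then convert back to the element type exactly once.
    out[x_index] = __float2half(x * c - y * s);
    out[y_index] = __float2half(y * c + x * s);
}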