@@ -246,8 +246,6 @@ an error.
246246
247247We guard these partial chunks by zero-padding unused lanes, and only writing
248248back the elements that actually exist in the chunk.
249-
250- Need to check to see how we impact perf.
251249*/
252250template <typename DType, typename QuantType, uint32_t vec_size>
253251__device__ __forceinline__ void scale_store_partial_chunk (const DType* in_ptr, QuantType* out_ptr,
@@ -1165,8 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
11651163 config.attrs = attribute;
11661164 config.numAttrs = 1 ;
11671165
1168- auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, 1 , DType, IdType,
1169- QuantType, paged_kv_t <QuantType, IdType>>;
1166+ auto kernel =
1167+ RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /* bdx=*/ 1 , DType, IdType,
1168+ QuantType, paged_kv_t <QuantType, IdType>>;
11701169 RopeQuantizeAppendPagedKVCacheParams params;
11711170 params.nnz = nnz;
11721171 params.num_qo_heads = num_qo_heads;
@@ -1239,8 +1238,8 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
12391238 config.numAttrs = 1 ;
12401239
12411240 auto kernel =
1242- RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, 1 , DType, IdType, QuantType ,
1243- paged_kv_mla_t <QuantType, IdType>>;
1241+ RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /* bdx= */ 1 , DType, IdType,
1242+ QuantType, paged_kv_mla_t <QuantType, IdType>>;
12441243 DType* v_in_nullptr = nullptr ;
12451244 uint32_t num_kv_heads_1 = 1 ;
12461245 size_t k_rope_in_stride_h_dup = k_rope_in_stride;
@@ -1268,9 +1267,18 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
12681267 params.quant_scale_q = quant_scale_q;
12691268 params.quant_scale_kv = quant_scale_kv;
12701269
1271- FLASHINFER_CUDA_CALL (cudaLaunchKernelEx (
1272- &config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in_nullptr, q_rope_out,
1273- q_nope_out, paged_kv_mla, batch_indices, positions, cos_sin_cache, pos_ids, params));
1270+ FLASHINFER_CUDA_CALL (cudaLaunchKernelEx (&config, kernel,
1271+ // inputs
1272+ q_rope_in, k_rope_in, q_nope_in, k_nope_in,
1273+ v_in_nullptr,
1274+ // q outputs
1275+ q_rope_out, q_nope_out,
1276+ // cache + indices
1277+ paged_kv_mla, batch_indices, positions,
1278+ // rope tables
1279+ cos_sin_cache, pos_ids,
1280+ // params
1281+ params));
12741282 });
12751283
12761284 return cudaSuccess;
0 commit comments