Skip to content

Commit 6805e9c

Browse files
committed
address small comment
Signed-off-by: Raayan Dhar <[email protected]>
1 parent 1bee170 commit 6805e9c

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

include/flashinfer/pos_enc.cuh

Lines changed: 17 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -246,8 +246,6 @@ an error.
246246
247247
We guard these partial chunks by zero-padding unused lanes, and only writing
248248
back the elements that actually exist in the chunk.
249-
250-
Need to check to see how we impact perf.
251249
*/
252250
template <typename DType, typename QuantType, uint32_t vec_size>
253251
__device__ __forceinline__ void scale_store_partial_chunk(const DType* in_ptr, QuantType* out_ptr,
@@ -1165,8 +1163,9 @@ cudaError_t RopeQuantizeAppendPagedKVCache(
11651163
config.attrs = attribute;
11661164
config.numAttrs = 1;
11671165

1168-
auto kernel = RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, 1, DType, IdType,
1169-
QuantType, paged_kv_t<QuantType, IdType>>;
1166+
auto kernel =
1167+
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
1168+
QuantType, paged_kv_t<QuantType, IdType>>;
11701169
RopeQuantizeAppendPagedKVCacheParams params;
11711170
params.nnz = nnz;
11721171
params.num_qo_heads = num_qo_heads;
@@ -1239,8 +1238,8 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
12391238
config.numAttrs = 1;
12401239

12411240
auto kernel =
1242-
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, 1, DType, IdType, QuantType,
1243-
paged_kv_mla_t<QuantType, IdType>>;
1241+
RopeQuantizeAppendPagedKVCacheKernel<INTERLEAVE, vec_size, /*bdx=*/1, DType, IdType,
1242+
QuantType, paged_kv_mla_t<QuantType, IdType>>;
12441243
DType* v_in_nullptr = nullptr;
12451244
uint32_t num_kv_heads_1 = 1;
12461245
size_t k_rope_in_stride_h_dup = k_rope_in_stride;
@@ -1268,9 +1267,18 @@ cudaError_t RopeQuantizeAppendPagedMLACache(
12681267
params.quant_scale_q = quant_scale_q;
12691268
params.quant_scale_kv = quant_scale_kv;
12701269

1271-
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(
1272-
&config, kernel, q_rope_in, k_rope_in, q_nope_in, k_nope_in, v_in_nullptr, q_rope_out,
1273-
q_nope_out, paged_kv_mla, batch_indices, positions, cos_sin_cache, pos_ids, params));
1270+
FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel,
1271+
// inputs
1272+
q_rope_in, k_rope_in, q_nope_in, k_nope_in,
1273+
v_in_nullptr,
1274+
// q outputs
1275+
q_rope_out, q_nope_out,
1276+
// cache + indices
1277+
paged_kv_mla, batch_indices, positions,
1278+
// rope tables
1279+
cos_sin_cache, pos_ids,
1280+
// params
1281+
params));
12741282
});
12751283

12761284
return cudaSuccess;

0 commit comments

Comments (0)