diff --git a/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh index 5a1fa2b6df27..6ef424891164 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/kvcache.cuh @@ -130,6 +130,7 @@ struct StoreKVCacheKernel { auto I = SymbolicSize{"indices_stride"}; auto dtype = SymbolicDType{}; auto device = SymbolicDevice{}; + auto indice_dtype = SymbolicDType{}; device.set_options(); TensorMatcher({B, D}) // @@ -150,7 +151,7 @@ struct StoreKVCacheKernel { .verify(v_cache); TensorMatcher({B}) // .with_strides({I}) - .with_dtype() + .with_dtype(indice_dtype) .with_device(device) .verify(indices); @@ -171,7 +172,8 @@ struct StoreKVCacheKernel { .batch_size = static_cast(B.unwrap()), }; // select kernel and update num_split if needed - const auto kernel = dtype.is_type() ? get_kernel(num_split) : get_kernel(num_split); + const auto use_int32 = indice_dtype.is_type(); + const auto kernel = use_int32 ? get_kernel(num_split) : get_kernel(num_split); const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps); LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()) // .enable_pdl(kUsePDL)(kernel, params);