diff --git a/sgl-kernel/csrc/elementwise/concat_mla.cu b/sgl-kernel/csrc/elementwise/concat_mla.cu index 0335dc724a9c..7d5b8595c8da 100644 --- a/sgl-kernel/csrc/elementwise/concat_mla.cu +++ b/sgl-kernel/csrc/elementwise/concat_mla.cu @@ -18,11 +18,11 @@ __global__ void concat_mla_k_kernel( const nv_bfloat16* __restrict__ k_nope, const nv_bfloat16* __restrict__ k_rope, const int num_tokens, - const int k_stride_0, + const int64_t k_stride_0, const int k_stride_1, - const int k_nope_stride_0, + const int64_t k_nope_stride_0, const int k_nope_stride_1, - const int k_rope_stride_0) { + const int64_t k_rope_stride_0) { const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; const int token_id = flat_warp_id / NUM_HEAD_CHUNKS; const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS; @@ -126,11 +126,11 @@ __global__ void concat_mla_absorb_q_kernel( nv_bfloat16* out, const int num_items, const int dim_1, - const int a_stride_0, + const int64_t a_stride_0, const int a_stride_1, - const int b_stride_0, + const int64_t b_stride_0, const int b_stride_1, - const int out_stride_0, + const int64_t out_stride_0, const int out_stride_1) { const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32; const int lane_id = get_lane_id();