diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index ead183cfd2..63e8aef5a4 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -78,7 +78,8 @@ __global__ void activationKernel(KernelParams params) { // Loop over hidden dim for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; hiddenIdx += blockDim.x * gridDim.x) { - int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; + // Use int64_t to avoid overflow when permutedIdx * innerDim > INT32_MAX + int64_t const baseIdx = (int64_t)permutedIdx * params.innerDim + hiddenIdx; float x1 = (float)params.inPtr[baseIdx]; float x2 = (float)params.inPtr[baseIdx + params.innerDim / 2]; @@ -86,7 +87,7 @@ __global__ void activationKernel(KernelParams params) { float act = silu(x2); Type out = (Type)(act * x1); - int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; + int64_t const outIdx = (int64_t)permutedIdx * (params.innerDim / 2) + hiddenIdx; params.outPtr[outIdx] = out; } } @@ -261,11 +262,14 @@ __global__ void activationDeepSeekKernel(KernelParams params) { } // Process blocks for this CTA - int const baseIdx = permutedIdx * params.innerDim + hiddenIdx; + // Use int64_t to avoid overflow when permutedIdx * innerDim > INT32_MAX + int64_t const baseIdx = (int64_t)permutedIdx * params.innerDim + hiddenIdx; - int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); - int const scale2Idx = permutedIdx + totalNumPaddedTokens * - ((hiddenIdx / 128) + (params.innerDim / 2 / 128)); + int64_t const scale1Idx = + (int64_t)permutedIdx + (int64_t)totalNumPaddedTokens * (hiddenIdx / 128); + int64_t const scale2Idx = + (int64_t)permutedIdx + + (int64_t)totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128)); scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx]; scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx]; @@ -304,8 +308,8 @@ __global__ void activationDeepSeekKernel(KernelParams params) { float scaleOut = fmaxf(aMaxArr[tokenInCtaIdx] / E4m3MaxVal, std::numeric_limits::min()); s_scaleOutArr[tokenInCtaIdx] = scaleOut; - int const scaleOut_idx = - permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128); + int64_t const scaleOut_idx = (int64_t)permutedIdxArr[tokenInCtaIdx] + + (int64_t)totalNumPaddedTokens * (hiddenIdx / 128); params.outDqSfsPtr[scaleOut_idx] = scaleOut; } } @@ -322,7 +326,7 @@ __global__ void activationDeepSeekKernel(KernelParams params) { continue; } float const scaleOut = s_scaleOutArr[tokenInCtaIdx]; - int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx; + int64_t const outIdx = (int64_t)permutedIdx * (params.innerDim / 2) + hiddenIdx; params.outPtr[outIdx] = static_cast(outArr[tokenInCtaIdx] / scaleOut); } }