diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index ead183cfd2..eaffd1d0df 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -586,7 +586,7 @@ __global__ void permuteKernel(KernelParams params) { for (int k = 0; k < params.topK; k++) { int const expandedIdx = tokenIdx * params.topK + k; int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; - params.outPtr[permutedIdx * params.hiddenDim + hiddenIdx] = data; + params.outPtr[(int64_t)permutedIdx * params.hiddenDim + hiddenIdx] = data; } } if (params.useDeepSeekFp8) { @@ -597,7 +597,7 @@ __global__ void permuteKernel(KernelParams params) { int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx]; int const idx_in = tokenIdx + params.numTokens * scaleIdx; - int const idx_out = permutedIdx + params.totalNumPaddedTokens[0] * scaleIdx; + int64_t const idx_out = (int64_t)permutedIdx + (int64_t)params.totalNumPaddedTokens[0] * scaleIdx; params.outDqSfsPtr[idx_out] = params.inDqSfsPtr[idx_in]; } @@ -662,9 +662,9 @@ __global__ void finalizeKernel(KernelParams params) { if (params.expertWeightsPtr != nullptr) { TypeExpW const scale = params.expertWeightsPtr[expandedIdx]; data += - float{scale} * float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]}; + float{scale} * float{params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]}; } else { - data += float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]}; + data += float{params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]}; } } @@ -926,14 +926,14 @@ __global__ void finalizeDeepSeekKernel(KernelParams params) { continue; } int const totalNumPaddedTokens = params.totalNumPaddedTokens[0]; - int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128); + int64_t const scaleIdx = (int64_t)permutedIdx + (int64_t)totalNumPaddedTokens * (hiddenIdx / 128); float const blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1; float const expertProb = (float)params.expertWeightsPtr[tokenIdx * params.topK + k]; float const scale = expertProb * blockScale; acc += scale * - static_cast(params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]); + static_cast(params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]); } // The largest (finite) value that can be represented using E4m3.