Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions csrc/trtllm_fused_moe_dev_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ __global__ void permuteKernel(KernelParams params) {
for (int k = 0; k < params.topK; k++) {
int const expandedIdx = tokenIdx * params.topK + k;
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];
params.outPtr[permutedIdx * params.hiddenDim + hiddenIdx] = data;
params.outPtr[(int64_t)permutedIdx * params.hiddenDim + hiddenIdx] = data;
}
}
if (params.useDeepSeekFp8) {
Expand All @@ -597,7 +597,7 @@ __global__ void permuteKernel(KernelParams params) {
int const permutedIdx = params.expandedIdxToPermutedIdx[expandedIdx];

int const idx_in = tokenIdx + params.numTokens * scaleIdx;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency and to prevent potential integer overflow with large numTokens and hiddenDim, it's safer to use int64_t for this index calculation as well, similar to the change for idx_out. The product params.numTokens * scaleIdx could overflow if both numTokens and hiddenDim (which determines the range of scaleIdx) are large.

          int64_t const idx_in = (int64_t)tokenIdx + (int64_t)params.numTokens * scaleIdx;

int const idx_out = permutedIdx + params.totalNumPaddedTokens[0] * scaleIdx;
int64_t const idx_out = (int64_t)permutedIdx + (int64_t)params.totalNumPaddedTokens[0] * scaleIdx;

params.outDqSfsPtr[idx_out] = params.inDqSfsPtr[idx_in];
}
Expand Down Expand Up @@ -662,9 +662,9 @@ __global__ void finalizeKernel(KernelParams params) {
if (params.expertWeightsPtr != nullptr) {
TypeExpW const scale = params.expertWeightsPtr[expandedIdx];
data +=
float{scale} * float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
float{scale} * float{params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]};
} else {
data += float{params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]};
data += float{params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]};
}
}

Expand Down Expand Up @@ -926,14 +926,14 @@ __global__ void finalizeDeepSeekKernel(KernelParams params) {
continue;
}
int const totalNumPaddedTokens = params.totalNumPaddedTokens[0];
int const scaleIdx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int64_t const scaleIdx = (int64_t)permutedIdx + (int64_t)totalNumPaddedTokens * (hiddenIdx / 128);
float const blockScale = params.inDqSfsPtr ? params.inDqSfsPtr[scaleIdx] : 1;

float const expertProb = (float)params.expertWeightsPtr[tokenIdx * params.topK + k];

float const scale = expertProb * blockScale;
acc += scale *
static_cast<float>(params.inPtr[permutedIdx * params.hiddenDimPadded + hiddenIdx]);
static_cast<float>(params.inPtr[(int64_t)permutedIdx * params.hiddenDimPadded + hiddenIdx]);
}

// The largest (finite) value that can be represented using E4m3.
Expand Down