Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions csrc/trtllm_fused_moe_dev_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,16 @@ __global__ void activationKernel(KernelParams params) {
// Loop over hidden dim
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
hiddenIdx += blockDim.x * gridDim.x) {
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
// Use int64_t to avoid overflow when permutedIdx * innerDim > INT32_MAX
int64_t const baseIdx = (int64_t)permutedIdx * params.innerDim + hiddenIdx;

float x1 = (float)params.inPtr[baseIdx];
float x2 = (float)params.inPtr[baseIdx + params.innerDim / 2];

float act = silu(x2);
Type out = (Type)(act * x1);

int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
int64_t const outIdx = (int64_t)permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = out;
}
}
Expand Down Expand Up @@ -261,11 +262,14 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
}

// Process blocks for this CTA
int const baseIdx = permutedIdx * params.innerDim + hiddenIdx;
// Use int64_t to avoid overflow when permutedIdx * innerDim > INT32_MAX
int64_t const baseIdx = (int64_t)permutedIdx * params.innerDim + hiddenIdx;

int const scale1Idx = permutedIdx + totalNumPaddedTokens * (hiddenIdx / 128);
int const scale2Idx = permutedIdx + totalNumPaddedTokens *
((hiddenIdx / 128) + (params.innerDim / 2 / 128));
int64_t const scale1Idx =
(int64_t)permutedIdx + (int64_t)totalNumPaddedTokens * (hiddenIdx / 128);
int64_t const scale2Idx =
(int64_t)permutedIdx +
(int64_t)totalNumPaddedTokens * ((hiddenIdx / 128) + (params.innerDim / 2 / 128));

scale1Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale1Idx];
scale2Arr[tokenInCtaIdx] = params.inDqSfsPtr[scale2Idx];
Expand Down Expand Up @@ -304,8 +308,8 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
float scaleOut =
fmaxf(aMaxArr[tokenInCtaIdx] / E4m3MaxVal, std::numeric_limits<float>::min());
s_scaleOutArr[tokenInCtaIdx] = scaleOut;
int const scaleOut_idx =
permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
int64_t const scaleOut_idx = (int64_t)permutedIdxArr[tokenInCtaIdx] +
(int64_t)totalNumPaddedTokens * (hiddenIdx / 128);
params.outDqSfsPtr[scaleOut_idx] = scaleOut;
}
}
Expand All @@ -322,7 +326,7 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
continue;
}
float const scaleOut = s_scaleOutArr[tokenInCtaIdx];
int const outIdx = permutedIdx * (params.innerDim / 2) + hiddenIdx;
int64_t const outIdx = (int64_t)permutedIdx * (params.innerDim / 2) + hiddenIdx;
params.outPtr[outIdx] = static_cast<Type>(outArr[tokenInCtaIdx] / scaleOut);
}
}
Expand Down
Loading