diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/trtllm_fused_moe_dev_kernel.cu index 7a58042041..a19c89638d 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/trtllm_fused_moe_dev_kernel.cu @@ -196,6 +196,8 @@ struct KernelTraits<1> { //////////////////////////////////////////////////////////////////////////////////////////////////// +constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128; + template __global__ void activationDeepSeekKernel(KernelParams params) { using Type = typename KernelParams::Type; @@ -203,7 +205,7 @@ __global__ void activationDeepSeekKernel(KernelParams params) { using KernelTraits = KernelTraits; using MaxOp = typename KernelTraits::MaxOp; using PackedType = typename KernelTraits::PackedType; - using BlockReduce = cub::BlockReduce; + using BlockReduce = cub::BlockReduce; __shared__ float s_scaleOutArr[NumTokensPerCta]; __shared__ typename BlockReduce::TempStorage tempStorage; @@ -235,6 +237,15 @@ __global__ void activationDeepSeekKernel(KernelParams params) { tokenCtaIdx += gridDim.y * NumTokensPerCta) { for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2; hiddenIdx += blockDim.x * gridDim.x) { +#pragma unroll + for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { + scale1Arr[tokenInCtaIdx] = 0.0f; + scale2Arr[tokenInCtaIdx] = 0.0f; + dataX1Arr[tokenInCtaIdx] = 0.0f; + dataX2Arr[tokenInCtaIdx] = 0.0f; + outArr[tokenInCtaIdx] = 0.0f; + absOutArr[tokenInCtaIdx] = 0.0f; + } #pragma unroll for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) { int const tokenIdx = tokenCtaIdx + tokenInCtaIdx; @@ -328,7 +339,6 @@ void run(Data const& data, void* stream) { if (data.mUseDeepSeekFp8) { constexpr int NUM_ELTS_PER_LOAD = 1; constexpr int NUM_ELTS_PER_SF = 128; - int const NUM_THREADS_PER_CTA = 128; int device{-1}; cudaGetDevice(&device); @@ -355,8 +365,8 @@ void run(Data const& data, void* stream) { const dim3 grid(gridSizeX, gridSizeY, data.topK); - LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0, - stream); + LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, + DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream); } else { int const numThreads = 256; const dim3 grid(data.innerDim / 128, data.topK, data.numTokens);