Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions csrc/trtllm_fused_moe_dev_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,16 @@ struct KernelTraits<1> {

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr int DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA = 128;

template <typename KernelParams>
__global__ void activationDeepSeekKernel(KernelParams params) {
using Type = typename KernelParams::Type;
int32_t constexpr NumTokensPerCta = KernelParams::NumTokensPerCta;
using KernelTraits = KernelTraits<NumTokensPerCta>;
using MaxOp = typename KernelTraits::MaxOp;
using PackedType = typename KernelTraits::PackedType;
using BlockReduce = cub::BlockReduce<PackedType, 128>;
using BlockReduce = cub::BlockReduce<PackedType, DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA>;

__shared__ float s_scaleOutArr[NumTokensPerCta];
__shared__ typename BlockReduce::TempStorage tempStorage;
Expand Down Expand Up @@ -235,6 +237,15 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
tokenCtaIdx += gridDim.y * NumTokensPerCta) {
for (int hiddenIdx = threadIdx.x + blockDim.x * blockIdx.x; hiddenIdx < params.innerDim / 2;
hiddenIdx += blockDim.x * gridDim.x) {
#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
scale1Arr[tokenInCtaIdx] = 0.0f;
scale2Arr[tokenInCtaIdx] = 0.0f;
dataX1Arr[tokenInCtaIdx] = 0.0f;
dataX2Arr[tokenInCtaIdx] = 0.0f;
outArr[tokenInCtaIdx] = 0.0f;
absOutArr[tokenInCtaIdx] = 0.0f;
}
Comment on lines +241 to +248
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The arrays outArr and absOutArr are unconditionally written to in the subsequent loop (lines 278-285) before being read. Therefore, initializing them to zero here is redundant and can be removed for a minor performance improvement.

        for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
          scale1Arr[tokenInCtaIdx] = 0.0f;
          scale2Arr[tokenInCtaIdx] = 0.0f;
          dataX1Arr[tokenInCtaIdx] = 0.0f;
          dataX2Arr[tokenInCtaIdx] = 0.0f;
        }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nekorobov Do you think Gemini's suggestion is reasonable?

#pragma unroll
for (int tokenInCtaIdx = 0; tokenInCtaIdx < NumTokensPerCta; tokenInCtaIdx++) {
int const tokenIdx = tokenCtaIdx + tokenInCtaIdx;
Expand Down Expand Up @@ -328,7 +339,6 @@ void run(Data const& data, void* stream) {
if (data.mUseDeepSeekFp8) {
constexpr int NUM_ELTS_PER_LOAD = 1;
constexpr int NUM_ELTS_PER_SF = 128;
int const NUM_THREADS_PER_CTA = 128;

int device{-1};
cudaGetDevice(&device);
Expand All @@ -355,8 +365,8 @@ void run(Data const& data, void* stream) {

const dim3 grid(gridSizeX, gridSizeY, data.topK);

LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid, NUM_THREADS_PER_CTA, 0,
stream);
LAUNCH_ACTIVATION(data, activationDeepSeekKernel, numTokensPerCta, grid,
DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream);
} else {
int const numThreads = 256;
const dim3 grid(data.innerDim / 128, data.topK, data.numTokens);
Expand Down