16 changes: 16 additions & 0 deletions cpp/include/tensorrt_llm/common/cudaUtils.h
@@ -38,6 +38,7 @@
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32
@@ -432,6 +433,21 @@ inline int getMaxSharedMemoryPerBlockOptin()
return nByteMaxSharedMemoryPerBlockOptin;
}

template <typename T>
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
{
static std::unordered_map<T, int> cache;
auto it = cache.find(kernel);
if (it != cache.end())
{
return it->second;
}
int numBlocks;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
cache[kernel] = numBlocks;
return numBlocks;
}

template <typename T1, typename T2>
inline size_t divUp(T1 const& a, T2 const& b)
{
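For reference, a minimal sketch of how the new helper is meant to size grids (hypothetical scaleKernel / launchScale, not part of this PR). Note the cache is keyed by the kernel pointer alone, so each kernel should always be queried with the same blockSize and dynamicSMemSize:

// Grid-stride kernel, so any grid size computes the correct result.
__global__ void scaleKernel(float* data, int n)
{
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
    {
        data[i] *= 2.0f;
    }
}

void launchScale(float* data, int n, cudaStream_t stream)
{
    int constexpr kThreads = 256;
    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(&scaleKernel, kThreads, 0);
    // Saturate the device, but never launch more blocks than there is work.
    int const blocks = std::min(smCount * maxBlocksPerSM,
        static_cast<int>(tensorrt_llm::common::divUp(n, kThreads)));
    scaleKernel<<<blocks, kThreads, 0, stream>>>(data, n);
}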
126 changes: 113 additions & 13 deletions cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
@@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
@@ -110,7 +110,7 @@
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
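
For context, cudaGridDependencySynchronize() and cudaTriggerProgrammaticLaunchCompletion() are the CUDA runtime equivalents of the griddepcontrol.wait / griddepcontrol.launch_dependents PTX they replace. A minimal sketch of the programmatic dependent launch (PDL) pattern these calls implement, with a hypothetical consumer kernel (not part of this PR):

// Hypothetical consumer kernel illustrating the PDL protocol on sm_90+.
__global__ void consumerKernel(float const* upstream_out, float* out, int n)
{
    // Prologue work that does not read upstream_out may run before the sync.
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize(); // wait until the upstream grid's writes are visible
#endif
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
    {
        out[i] = upstream_out[i] * 2.0f;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaTriggerProgrammaticLaunchCompletion(); // let the next grid start launching early
#endif
}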

@@ -141,12 +141,12 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
}
#endif

auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
@@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
int32_t const token_idx = blockIdx.x;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
@@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

@@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_UNPERMUTE

template <typename InputType, int32_t kThreadsPerBlock>
__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
{
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;

InputType rmem[kElemPerCopy];
#pragma unroll
for (int32_t j = 0; j < kElemPerCopy; j++)
{
rmem[j] = 0;
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif

int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
{
int32_t const tile_idx = permuted_idx / tile_size;
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
{
continue;
}
int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
int32_t const token_idx = expanded_idx / top_k;
int32_t const topk_idx = expanded_idx % top_k;

bool is_first_in_topk = true;
for (int32_t k = 0; k < topk_idx; k++)
{
if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
{
is_first_in_topk = false;
break;
}
}
if (!is_first_in_topk)
{
continue;
}

auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
{
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
}
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
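
A note on the is_first_in_topk scan above: each output row must be zeroed exactly once, so only the block handling the token's first valid top-k slot performs the memset. A host-side reference of the ownership rule (hypothetical helper, for illustration only):

// Returns true iff this expanded slot is the first valid one in its token's
// top-k group, i.e. the slot whose block zeroes the token's output row.
bool ownsMemset(int32_t const* expanded_idx_to_permuted_idx, int32_t expanded_idx, int32_t top_k)
{
    int32_t const token_idx = expanded_idx / top_k;
    int32_t const topk_idx = expanded_idx % top_k;
    for (int32_t k = 0; k < topk_idx; k++)
    {
        if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
        {
            return false; // an earlier slot is valid and zeroes the row instead
        }
    }
    return true;
}

For example, with top_k = 2, token 5 owns expanded slots 10 and 11: if slot 10 maps to a valid permuted index its block zeroes the row and slot 11 skips; if slot 10 was dropped (mapped to a negative index), slot 11 takes over.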

template <typename InputType>
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
cudaStream_t stream)
{
int32_t constexpr kThreadsPerBlock = 256;
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);

auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
}
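
Host side, it is the programmatic stream serialization attribute (set in the launcher above) that arms the early launch. A hedged sketch of the general pattern, reusing the hypothetical consumerKernel from earlier:

// Hypothetical launcher: with the attribute set, this kernel may begin launching
// before the previous kernel in `stream` finishes; its body must then call
// cudaGridDependencySynchronize() before reading upstream results.
void launchWithPdl(float const* upstream_out, float* out, int n, int blocks, int threads, cudaStream_t stream)
{
    cudaLaunchConfig_t config{};
    config.gridDim = blocks;
    config.blockDim = threads;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1; // 0 falls back to a normal serialized launch
    config.attrs = &attr;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, consumerKernel, upstream_out, out, n);
}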

#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \
template void moeOutputMemset<InputType>(InputType * input, int32_t const* tile_idx_to_mn_limit, \
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size, \
int32_t const top_k, int32_t const tile_size, cudaStream_t stream)

INSTANTIATE_MOE_OUTPUT_MEMSET(half);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_OUTPUT_MEMSET

template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
int32_t kThreadsPerBlock>
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
@@ -297,7 +396,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
ActFn act{};

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif

float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
@@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
}

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

@@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
}
#endif

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
float const* global_sf, SFType* output_sf,
int32_t const* tile_idx_to_mn_limit,
@@ -424,6 +519,11 @@
};
auto kernel = get_act_kernel(activation_params.activation_type);

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
6 changes: 6 additions & 0 deletions cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h
@@ -32,6 +32,12 @@ void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t co
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
cudaStream_t stream);

template <typename InputType>
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
cudaStream_t stream);
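
A hedged sketch of a call site for the new entry point (buffer names are illustrative, not taken from this PR):

// Hypothetical call site: zero the final output rows once before accumulation.
moeOutputMemset<half>(final_output, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
    permuted_idx_to_expanded_idx, num_non_exiting_tiles, max_num_permuted_tokens,
    hidden_size, top_k, tile_size, stream);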

template <typename InputType, typename OutputType, typename SFType>
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
37 changes: 21 additions & 16 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
int64_t num_padding_tokens = 0;
#endif

static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
int64_t const threads = EXPAND_THREADS_PER_BLOCK;

auto func = [&]()
{
#ifdef ENABLE_FP8
@@ -1637,6 +1632,12 @@
}
}();

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
int32_t const threads = EXPAND_THREADS_PER_BLOCK;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
@@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
if (parallelism_config.ep_size > 1 && enable_alltoall)
{
// If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = smCount * 8;
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
config.gridDim = blocks;
config.blockDim = threads;
auto func = final_scales
? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
: &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM
= tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
int32_t const threads = FINALIZE_THREADS_PER_BLOCK;

config.gridDim = blocks;
config.blockDim = threads;
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
@@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
int64_t num_padding_tokens = 0;
#endif

static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;

auto fn = [&]()
{
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
@@ -2302,6 +2301,12 @@
}
}();

static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;

cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
@@ -647,7 +647,9 @@ void run(Data& data, void* stream)
//
// The upper bound is a strict requirement. The number of blocks should be determined by querying
// the device properties, or set conservatively low.
static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// WAR: Reserve 8 SMs for overlapping kernels.
int const numBlocksCoop = smCount - 8;

// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
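
For context, a cooperative launch requires every block of the grid to be resident simultaneously, so exceeding the device bound makes the launch fail outright rather than queue. A hedged sketch of the occupancy-derived bound (hypothetical coopKernel; the code above conservatively assumes one block per SM):

// Upper bound on a co-resident grid, keeping 8 SMs free for overlapping kernels.
int maxBlocksPerSM = 0;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
    &maxBlocksPerSM, coopKernel, numThreadsHist, /*dynamicSMemSize=*/0));
int const gridUpperBound = (smCount - 8) * maxBlocksPerSM;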