8 changes: 6 additions & 2 deletions csrc/trtllm_fused_moe_dev_kernel.cu
@@ -299,10 +299,14 @@ __global__ void activationDeepSeekKernel(KernelParams params) {
       if (permutedIdx == -1) {
         continue;
       }
-      s_scaleOutArr[tokenInCtaIdx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
+      // Make sure the scale is strictly positive to avoid division by zero in case the
+      // maximum is zero.
+      float scaleOut =
+          fmaxf(aMaxArr[tokenInCtaIdx] / E4m3MaxVal, std::numeric_limits<float>::min());
+      s_scaleOutArr[tokenInCtaIdx] = scaleOut;
       int const scaleOut_idx =
           permutedIdxArr[tokenInCtaIdx] + totalNumPaddedTokens * (hiddenIdx / 128);
-      params.outDqSfsPtr[scaleOut_idx] = aMaxArr[tokenInCtaIdx] / E4m3MaxVal;
+      params.outDqSfsPtr[scaleOut_idx] = scaleOut;
     }
   }
   __syncthreads();
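Editorial note on the hunk above: clamping with `fmaxf(..., std::numeric_limits<float>::min())` keeps the scale strictly positive, so later divisions by `scaleOut` stay well-defined even for all-zero rows (when the row maximum is zero, every value is zero and `0 / FLT_MIN == 0`). The fix also computes the clamped value once and writes it both to shared memory and to the global scale buffer, removing the duplicated division. A minimal standalone sketch of the pattern; `kE4m3MaxVal = 448.0f` (the largest finite FP8 e4m3 value) is assumed here for illustration:

```cpp
#include <cmath>
#include <limits>

// Largest finite value representable in FP8 e4m3 (assumed for illustration).
constexpr float kE4m3MaxVal = 448.0f;

// Returns a strictly positive dequantization scale, so quantizing as
// val / scale never divides by zero, even when the row-wise absolute
// maximum is exactly zero.
inline float computeDequantScale(float rowAbsMax) {
  return fmaxf(rowAbsMax / kE4m3MaxVal, std::numeric_limits<float>::min());
}
```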
4 changes: 2 additions & 2 deletions flashinfer/artifacts.py
@@ -89,7 +89,7 @@ class ArtifactPath:
 
     TRTLLM_GEN_FMHA: str = "75d477a640f268ea9ad117cc596eb39245713b9e/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
-        "ccae3ed120a12a2c6922b458086b460413dbf731/batched_gemm-0d275a2-9936841"
+        "fea3b0ecfe11d7b34556042aeb5d0465ad101500/batched_gemm-332ffef-9936841"
     )
     TRTLLM_GEN_GEMM: str = (
         "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3"
@@ -110,7 +110,7 @@ class CheckSumHash:
         "e014d7a54c396733ef012b223603c1be2861019f88faa5dcc882ed1ecfe5c2d9"
     )
     TRTLLM_GEN_BMM: str = (
-        "b7689d3046493806251351c2744c6d7faed6af25518647a955b35c4919b014fc"
+        "1c3c7ae0755a0acb7ad35da7dbdb90ab71c253dc289051faa2b4e3180dfc4b23"
     )
     DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
     TRTLLM_GEN_GEMM: str = (
@@ -97,7 +97,7 @@ struct BatchedGemmData {
   // The matrix A. The data type is controlled by options.mDtypeA.
   //
   // If (routeAct == true && batchM), the shape is [M, K]
-  // Else
+  // Else if (batchStrideInTokens > 0)
   //   If batchM:
   //     Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), K].
   //     Logical strides are [K, 1].
@@ -113,6 +113,14 @@ struct BatchedGemmData {
   //     Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
   //     Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
   //     where blockK is 128B.
+  // Else // batchStrideInTokens == 0
+  //   If batchM:
+  //     Logical shape is [M, K].
+  //     Logical strides are [K, 1].
+  //
+  //   If batchN:
+  //     Logical shape is [B, divUpMul(M, tileM), K].
+  //     Logical strides are [divUpMul(M, tileM) * K, K, 1].
   void const* mPtrA{nullptr};
 
   // The block scaling factors to dequantize A.
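Editorial note: the padded shapes documented above for A (and the symmetric ones for B below) all hinge on `divUpMul`, which pads a dimension up to a multiple of the tile size. A short sketch under the assumption that `divUpMul(a, b)` has the obvious round-up semantics; its definition is not shown in this diff, and `totalPaddedRowsA` is a name invented here for illustration:

```cpp
#include <cstdint>

// Round a up to the next multiple of b (assumed semantics of divUpMul).
inline int32_t divUpMul(int32_t a, int32_t b) { return ((a + b - 1) / b) * b; }

// Total padded rows of A across the batch when batchM and
// batchStrideInTokens > 0: sum(divUpMul(M[bi], tileM) for bi in B).
inline int64_t totalPaddedRowsA(int32_t const* m, int32_t numBatches, int32_t tileM) {
  int64_t total = 0;
  for (int32_t bi = 0; bi < numBatches; ++bi) {
    total += divUpMul(m[bi], tileM);
  }
  return total;
}
```

For example, with tileM = 128 and per-batch token counts {3, 130}, the logical A buffer spans divUpMul(3, 128) + divUpMul(130, 128) = 128 + 256 = 384 rows.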
@@ -160,7 +168,7 @@ struct BatchedGemmData {
   //
   // If (routeAct == true && batchN), the shape is [N, K]
   //
-  // Else
+  // Else if (batchStrideInTokens > 0)
   //   If batchN:
   //     Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), K].
   //     Logical strides are [K, 1].
@@ -176,6 +184,15 @@ struct BatchedGemmData {
   //     Logical shape is [B, K / blockK, divUpMul(N, tileN), blockK].
   //     Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1].
   //     where blockK is 128B.
+  //
+  // Else // batchStrideInTokens == 0
+  //   If batchN:
+  //     Logical shape is [N, K].
+  //     Logical strides are [K, 1].
+  //
+  //   If batchM:
+  //     Logical shape is [B, divUpMul(N, tileN), K].
+  //     Logical strides are [divUpMul(N, tileN) * K, K, 1].
   void const* mPtrB{nullptr};
 
   // The scaling factors to dequantize B.
@@ -256,6 +273,13 @@ struct BatchedGemmData {
   // Shape is [B].
   float const* mPtrScaleC{nullptr};
 
+  // The pre-activation scaling factor (typically dequantA * dequantB) for non-gated non-linear
+  // activation.
+  // Only used when a non-linear activation is applied (e.g., GELU, Relu2).
+  // When used, scaleC should be quantScaleC only, and this scale is applied before the
+  // activation. Shape is [B].
+  float const* mPtrScaleAct{nullptr};
+
   // The output gate scale for Fp8 (not DeepSeek FP8) and NvFp4 quantization.
   // TensorRT-LLM API requires a scaling factor on the device.
   // scaleGate = dequantA * dequantB,
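Editorial note: to make the new `mPtrScaleAct` comment concrete, for a non-gated non-linear activation the per-batch scaling splits into a dequantization applied before the activation (`scaleAct`, typically `dequantA * dequantB`) and a quantization applied after it (`scaleC`, now `quantScaleC` only). A per-element sketch of that order, using Relu2 as the example activation; the names are illustrative, not the kernel's actual epilogue code:

```cpp
#include <cmath>

// Epilogue order implied by the comment above: dequantize with scaleAct,
// apply the non-linear activation, then quantize with scaleC.
inline float nonGatedEpilogue(float acc, float scaleAct, float quantScaleC) {
  float x = scaleAct * acc;  // pre-activation dequantization
  float r = fmaxf(x, 0.0f);
  float y = r * r;           // Relu2: relu(x)^2
  return quantScaleC * y;    // post-activation quantization
}
```

Splitting the scales matters because the activation is non-linear: act(s * x) != s * act(x) in general, so folding both factors into a single post-activation scale would change the result.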
@@ -478,7 +502,127 @@ class BatchedGemmInterface {
   int32_t run(BatchedGemmConfig const& config, void* workspace,
               BatchedGemmData const& batchedGemmData, void* cudaStream,
               int32_t /*multiProcessorCount*/, bool usePdl = true,
-              std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt);
+              std::optional<std::reference_wrapper<ModuleCache>> moduleCache = std::nullopt) {
+    // Get options from config and data.
+    auto options = getOptionsFromConfigAndData(config, batchedGemmData);
+
+    bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
+    bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 &&
+                                options.mDtypeB == tg::Dtype::E4m3;
+
+    auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData);
+    float* dPtrRowMax{nullptr};
+    uint32_t* dPtrRowMaxBars{nullptr};
+
+    // Set the completion barriers to 0 if needed.
+    if (useDeepSeekFp8 && options.mFusedAct) {
+      dPtrRowMax = reinterpret_cast<float*>(alignPtr(reinterpret_cast<char*>(workspace), 1024));
+      dPtrRowMaxBars = reinterpret_cast<uint32_t*>(
+          alignPtr(reinterpret_cast<char*>(dPtrRowMax) + workspaceSizes[0], 1024));
+      auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1],
+                                 reinterpret_cast<cudaStream_t>(cudaStream));
+      if (err != cudaSuccess) {
+        return 1;
+      }
+    }
+
+    auto [numCtaBatch, numCtaTile, numCtaInner] =
+        getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
+    auto kernelParams = KernelParamsSetup::setKernelParams(
+        options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
+        batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
+        batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
+        batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
+        batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
+        batchedGemmData.mInputBuffers.mPtrScaleAct, batchedGemmData.mInputBuffers.mPtrScaleGate,
+        batchedGemmData.mInputBuffers.mPtrClampLimit,
+        batchedGemmData.mInputBuffers.mPtrGatedActAlpha,
+        batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap,
+        dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
+        batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
+        batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
+        batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);
+
+    // The size of the grid.
+    std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
+                                       : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
+
+    BatchedGemmConfig batchedGemmConfig = config;
+#ifndef TLLM_GEN_EXPORT_INTERFACE
+    // Generate and compile the kernel if data is not provided.
+    if (config.mData == nullptr) {
+      batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig);
+    }
+    TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set");
+    batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid,
+                                       /* cluster */ {},
+                                       /* instanceId */ batchedGemmConfig.mInstanceIdx);
+    return 0;
+#endif
+
+    CUmodule cuModule;
+    CUfunction cuFunction;
+
+    if (moduleCache.has_value()) {
+      ModuleCache& moduleCacheRef = moduleCache.value().get();
+
+      // Modules are associated with a specific context, so the context is included in the key
+      CUcontext ctx;
+      unsigned long long ctxId;
+      cuCtxGetCurrent(&ctx);
+      cuCtxGetId(ctx, &ctxId);
+
+      // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
+      // string in decimal representation.
+      std::string const ctxName =
+          std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
+      std::string const funcName = std::string(batchedGemmConfig.mFunctionName);
+      auto const moduleKey = ctxName + funcName;
+      auto module = moduleCacheRef.find(moduleKey);
+
+      // Use cache if module is found, otherwise load and insert into cache
+      if (module != moduleCacheRef.end()) {
+        cuFunction = std::get<1>(module->second);
+      } else {
+        gemm::loadCubinData(&cuModule, batchedGemmConfig);
+        cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+        moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
+      }
+    } else {
+      gemm::loadCubinData(&cuModule, batchedGemmConfig);
+      cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
+    }
+
+    // Prepare the grid/block.
+    dim3 block3{static_cast<uint32_t>(batchedGemmConfig.mNumThreadsPerCTA),
+                static_cast<uint32_t>(1), static_cast<uint32_t>(1)};
+    dim3 grid3{(grid.size() > 0 ? static_cast<uint32_t>(grid[0]) : 1u),
+               (grid.size() > 1 ? static_cast<uint32_t>(grid[1]) : 1u),
+               (grid.size() > 2 ? static_cast<uint32_t>(grid[2]) : 1u)};
+    // Prepare the cluster size.
+    dim3 cluster3{static_cast<uint32_t>(options.mClusterDimX),
+                  static_cast<uint32_t>(options.mClusterDimY),
+                  static_cast<uint32_t>(options.mClusterDimZ)};
+
+    // Whether PDL can safely be enabled
+    const bool pdlSafe = batchedGemmConfig.mOptions.mGridWaitForPrimaryRouting ||
+                         batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit ||
+                         batchedGemmConfig.mOptions.mGridWaitForPrimaryA ||
+                         batchedGemmConfig.mOptions.mGridWaitForPrimaryB;
+
+    // Run the kernel.
+    auto result = trtllm::gen::launchKernel((void*)&kernelParams, cudaStream,
+                                            batchedGemmConfig.mSharedMemSize, cuFunction, block3,
+                                            grid3, cluster3, usePdl && pdlSafe);
+    if (result != CUDA_SUCCESS) {
+      return result;
+    }
+    // If a module cache has not been given, unload the module to avoid leaking
+    if (!moduleCache.has_value()) {
+      cuModuleUnload(cuModule);
+    }
+    return 0;
+  }

//////////////////////////////////////////////////////////////////////////////////////////////////
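Editorial note: two details in the inlined `run()` above are worth spelling out. The module cache is keyed by CUDA context as well as kernel name, because a `CUmodule` is only usable in the context it was loaded into. A standalone sketch of that key construction, assuming `ModuleCache` maps strings to `(CUmodule, CUfunction)` tuples (its real definition is not shown in this diff):

```cpp
#include <cuda.h>

#include <string>
#include <tuple>
#include <unordered_map>

// Assumed shape of ModuleCache; the real typedef is not part of this diff.
using ModuleCache = std::unordered_map<std::string, std::tuple<CUmodule, CUfunction>>;

// Build a cache key from the current CUDA context id plus the kernel name.
// The raw ctxId bytes are embedded directly into the string, exactly as
// run() does, avoiding a custom hash or a decimal conversion. Error
// checking is elided for brevity.
inline std::string makeModuleKey(char const* functionName) {
  CUcontext ctx;
  unsigned long long ctxId;
  cuCtxGetCurrent(&ctx);
  cuCtxGetId(ctx, &ctxId);
  return std::string(reinterpret_cast<char const*>(&ctxId), sizeof(ctxId)) + functionName;
}
```

Note the matching lifetime rule: when no cache is supplied, `run()` unloads the module after the launch, while a supplied cache deliberately keeps modules loaded for reuse, making unloading the cache owner's responsibility. Likewise, PDL is only enabled (`usePdl && pdlSafe`) when at least one grid-wait option is set, presumably because programmatic dependent launch is only safe if the kernel actually waits on the prior grid.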

@@ -683,130 +827,6 @@ class BatchedGemmInterface {
   int32_t mNumRotations;
 };
 
-int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspace,
-                                  BatchedGemmData const& batchedGemmData, void* cudaStream,
-                                  int32_t /*multiProcessorCount*/, bool usePdl,
-                                  std::optional<std::reference_wrapper<ModuleCache>> moduleCache) {
-  // Get options from config and data.
-  auto options = getOptionsFromConfigAndData(config, batchedGemmData);
-
-  bool const batchM = options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM;
-  bool const useDeepSeekFp8 = options.mUseDeepSeekFp8 && options.mDtypeA == tg::Dtype::E4m3 &&
-                              options.mDtypeB == tg::Dtype::E4m3;
-
-  auto workspaceSizes = getWorkspaceSizesInBytes(config, batchedGemmData);
-  float* dPtrRowMax{nullptr};
-  uint32_t* dPtrRowMaxBars{nullptr};
-
-  // Set the completion barriers to 0 if needed.
-  if (useDeepSeekFp8 && options.mFusedAct) {
-    dPtrRowMax = reinterpret_cast<float*>(alignPtr(reinterpret_cast<char*>(workspace), 1024));
-    dPtrRowMaxBars = reinterpret_cast<uint32_t*>(
-        alignPtr(reinterpret_cast<char*>(dPtrRowMax) + workspaceSizes[0], 1024));
-    auto err = cudaMemsetAsync((void*)dPtrRowMaxBars, 0x00, workspaceSizes[1],
-                               reinterpret_cast<cudaStream_t>(cudaStream));
-    if (err != cudaSuccess) {
-      return 1;
-    }
-  }
-
-  auto [numCtaBatch, numCtaTile, numCtaInner] =
-      getGridDim(options, batchedGemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
-  auto kernelParams = KernelParamsSetup::setKernelParams(
-      options, batchM, batchedGemmData.mInputBuffers.mPtrA, batchedGemmData.mInputBuffers.mPtrB,
-      batchedGemmData.mOutputBuffers.mPtrC, batchedGemmData.mInputBuffers.mPtrSfA,
-      batchedGemmData.mInputBuffers.mPtrSfB, batchedGemmData.mInputBuffers.mPtrPerTokenSfA,
-      batchedGemmData.mInputBuffers.mPtrPerTokenSfB, batchedGemmData.mInputBuffers.mPtrBias,
-      batchedGemmData.mOutputBuffers.mPtrSfC, batchedGemmData.mInputBuffers.mPtrScaleC,
-      batchedGemmData.mInputBuffers.mPtrScaleGate, batchedGemmData.mInputBuffers.mPtrClampLimit,
-      batchedGemmData.mInputBuffers.mPtrGatedActAlpha,
-      batchedGemmData.mInputBuffers.mPtrGatedActBeta, batchedGemmData.mInputBuffers.mPtrRouteMap,
-      dPtrRowMax, dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
-      batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
-      batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
-      batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, numCtaBatch);
-
-  // The size of the grid.
-  std::vector<int32_t> grid = batchM ? std::vector<int32_t>{numCtaBatch, numCtaTile, numCtaInner}
-                                     : std::vector<int32_t>{numCtaTile, numCtaBatch, numCtaInner};
-
-  BatchedGemmConfig batchedGemmConfig = config;
-#ifndef TLLM_GEN_EXPORT_INTERFACE
-  // Generate and compile the kernel if data is not provided.
-  if (config.mData == nullptr) {
-    batchedGemmConfig = generateAndCompileKernel(batchedGemmConfig);
-  }
-  TLLM_CHECK_ERROR(batchedGemmConfig.mCudaRunner != nullptr, "CudaRunner is not set");
-  batchedGemmConfig.mCudaRunner->run((void*)&kernelParams, (void*)cudaStream, grid,
-                                     /* cluster */ {},
-                                     /* instanceId */ batchedGemmConfig.mInstanceIdx);
-  return 0;
-#endif
-
-  CUmodule cuModule;
-  CUfunction cuFunction;
-
-  if (moduleCache.has_value()) {
-    ModuleCache& moduleCacheRef = moduleCache.value().get();
-
-    // Modules are associated with a specific context, so the context is included in the key
-    CUcontext ctx;
-    unsigned long long ctxId;
-    cuCtxGetCurrent(&ctx);
-    cuCtxGetId(ctx, &ctxId);
-
-    // Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
-    // string in decimal representation.
-    std::string const ctxName =
-        std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
-    std::string const funcName = std::string(batchedGemmConfig.mFunctionName);
-    auto const moduleKey = ctxName + funcName;
-    auto module = moduleCacheRef.find(moduleKey);
-
-    // Use cache if module is found, otherwise load and insert into cache
-    if (module != moduleCacheRef.end()) {
-      cuFunction = std::get<1>(module->second);
-    } else {
-      gemm::loadCubinData(&cuModule, batchedGemmConfig);
-      cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
-      moduleCacheRef.insert(std::make_pair(moduleKey, std::make_tuple(cuModule, cuFunction)));
-    }
-  } else {
-    gemm::loadCubinData(&cuModule, batchedGemmConfig);
-    cuModuleGetFunction(&cuFunction, cuModule, batchedGemmConfig.mFunctionName);
-  }
-
-  // Prepare the grid/block.
-  dim3 block3{static_cast<uint32_t>(batchedGemmConfig.mNumThreadsPerCTA), static_cast<uint32_t>(1),
-              static_cast<uint32_t>(1)};
-  dim3 grid3{(grid.size() > 0 ? static_cast<uint32_t>(grid[0]) : 1u),
-             (grid.size() > 1 ? static_cast<uint32_t>(grid[1]) : 1u),
-             (grid.size() > 2 ? static_cast<uint32_t>(grid[2]) : 1u)};
-  // Prepare the cluster size.
-  dim3 cluster3{static_cast<uint32_t>(options.mClusterDimX),
-                static_cast<uint32_t>(options.mClusterDimY),
-                static_cast<uint32_t>(options.mClusterDimZ)};
-
-  // Whether PDL can safely be enabled
-  const bool pdlSafe = batchedGemmConfig.mOptions.mGridWaitForPrimaryRouting ||
-                       batchedGemmConfig.mOptions.mGridWaitForPrimaryEarlyExit ||
-                       batchedGemmConfig.mOptions.mGridWaitForPrimaryA ||
-                       batchedGemmConfig.mOptions.mGridWaitForPrimaryB;
-
-  // Run the kernel.
-  auto result =
-      trtllm::gen::launchKernel((void*)&kernelParams, cudaStream, batchedGemmConfig.mSharedMemSize,
-                                cuFunction, block3, grid3, cluster3, usePdl && pdlSafe);
-  if (result != CUDA_SUCCESS) {
-    return result;
-  }
-  // If a module cache has not been given, unload the module to avoid leaking
-  if (!moduleCache.has_value()) {
-    cuModuleUnload(cuModule);
-  }
-  return 0;
-}
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 }  // namespace batchedGemm