Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
#pragma once

#include <string>
#include <cassert>
#include <string>

namespace batchedGemm
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct BatchedGemmData
// Otherwise, shape is [M / 128, K / 128].
// The rightmost dimension is contiguous in memory.
//
// If DeepSeek FP8 recipe is not used, but for MxFp{4,8} and NvFp4 formats:
// If DeepSeek FP8 recipe is not used, but for MxFp{4,8}, MxInt4 and NvFp4 formats:
// The layout of scaling factors for A is always R128c4
// M must be a multiple of 128.
// K must be a multiple of 64.
Expand All @@ -138,7 +138,8 @@ struct BatchedGemmData
// Where paddedM is M if (routeAct == true && batchM), or
// sum(divUpMul(M[bi], tileM) for bi in B) if batchM,
// otherwise divUpMul(M, tileM) * B.
// Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype::E4m3.
// Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype is Dtype::E4m3 for
// NvFp4, Dtype::UE8m0 for MxFp{4,8} formats, Dtype::Bfloat16 for MxInt4.
//
// Otherwise should be set to nullptr.
void const* mPtrSfA{nullptr};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
*/
#pragma once

#include "GemmOptions.h"
#include "GemmGatedActOptions.h"
#include "BatchedGemmEnums.h"
#include "GemmGatedActOptions.h"
#include "GemmOptions.h"

#include <cstdint>
#include <vector>

#ifndef TLLM_GEN_EXPORT_INTERFACE
#include "trtllm/gen/GenCtx.h"
#include "trtllm/gen/CudaRunner.h"
#include "trtllm/gen/GenCtx.h"
#else
#include <iostream>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ namespace tg = trtllm::gen;
// Type of the gated activation
enum class ActType
{
// For ActType == SwiGlu, ideally we would like to have something like
// gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) *
// sigmoid(alpha * x1 * scaleGate)).
// But for now, we use the simplified version
// gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)),
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
// beta' = beta / dequantScaleAb, scaleC = quantScaleC * dequantScaleAb.
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
// clang-format off
// For ActType == SwiGlu, ideally we would like to have something like
// gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)).
// But for now, we use the simplified version
// gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)),
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
// beta' = beta / dequantScaleAb, scaleC = quantScaleC * dequantScaleAb.
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
// clang-format on
SwiGlu,
// For ActType == GeGlu, we use the simplified version
// gatedAct = scaleC' * (x0 + beta') * ((x1 * scaleGate) * phi(alpha * x1 * scaleGate)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
#include "KernelTraits.h"
#include "trtllm/gen/CudaArchDecl.h"
#include "trtllm/gen/DtypeDecl.h"
#include "trtllm/gen/SfLayoutDecl.h"
#include "trtllm/gen/MmaDecl.h"
#include "trtllm/gen/SfLayoutDecl.h"
#ifndef TLLM_GEN_EXPORT_INTERFACE
#include "trtllm/gen/GenCtx.h"
#include "trtllm/gen/CudaRunner.h"
#include "trtllm/gen/GenCtx.h"
#else
#ifdef TLLM_GEN_EXPORT_FLASHINFER
#include <string>
Expand Down Expand Up @@ -720,9 +720,10 @@ inline bool checkAndUpdateGemmOptions(
#endif // TLLM_PUBLIC_RELEASE

// Check that the A cast is supported.
// Currently, we only support {MxFp4, NvFp4} -> Bf16.
// Currently, we only support {MxFp4, NvFp4, MxInt4} -> Bf16.
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1
|| options.mDtypeA == tg::Dtype::MxInt4)
&& options.mDtypeMmaA == tg::Dtype::Bfloat16)
|| (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3),
"Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", tg::dtypeToString(options.mDtypeMmaA));
Expand Down Expand Up @@ -1423,6 +1424,19 @@ inline bool checkAndUpdateGemmOptions(
}
}

if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8
&& options.mTileScheduler == TileScheduler::Persistent)
{
if (updateOptions)
{
options.mUseCustomMmaSchedule = true;
}
else
{
TLLM_CHECK_ERROR(false, "TileScheduler::Persistent and !UseCustomMmaSchedule is not supported.");
}
}

if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit)
{
TLLM_LOG_WARNING(
Expand Down Expand Up @@ -1623,8 +1637,8 @@ inline CUresult loadCubinData(CUmodule* module, Config const& config)
// Trtllm links the cubin into the executable while Flashinfer loads the cubin from storage.
#ifdef TLLM_GEN_EXPORT_FLASHINFER
#ifdef TLLM_GEN_GEMM_CUBIN_PATH
static std::string const tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
std::string const sha256 = config.mHash ? config.mHash : "";
static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
const std::string sha256 = config.mHash ? config.mHash : "";
std::string fileName = config.mFunctionName;
if (!fileName.empty())
{
Expand Down
Loading
Loading