Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hold this PR for now since we need to confirm if there is perf regression.

Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
---
AccessModifierOffset: -4
AlignAfterOpenBracket: DontAlign
AlignConsecutiveAssignments: None
AlignConsecutiveDeclarations: None
AlignOperands: false
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: true
AllowShortBlocksOnASingleLine: Empty
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BasedOnStyle: None
BinPackArguments: true
BinPackParameters: true
BreakBeforeBinaryOperators: All
BreakBeforeBraces: Allman
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: true
ColumnLimit: 120
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: false
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ExperimentalAutoDetectBinPacking: false
ForEachMacros: [ foreach, Q_FOREACH, BOOST_FOREACH ]
IncludeBlocks: Preserve
IncludeCategories:
- Regex: '^"(llvm|llvm-c|clang|clang-c)/'
Priority: 2
- Regex: '^(<|"(gtest|isl|json)/)'
Priority: 3
- Regex: '.*'
Priority: 1
IndentCaseLabels: false
IndentWidth: 4
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: true
Language: Cpp
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 4
ObjCSpaceAfterProperty: true
ObjCSpaceBeforeProtocolList: true
PenaltyBreakBeforeFirstCallParameter: 19
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
QualifierAlignment: Right
ReflowComments: true
SeparateDefinitionBlocks: Always
SortIncludes: false
SpaceAfterCStyleCast: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInCStyleCastParentheses: false
SpacesInContainerLiterals: true
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: c++14
TabWidth: 4
UseTab: Never
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/
#pragma once

#include <cassert>
#include <string>
#include <cassert>

namespace batchedGemm
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct BatchedGemmData
// Otherwise, shape is [M / 128, K / 128].
// The rightmost dimension is contiguous in memory.
//
// If DeepSeek FP8 recipe is not used, but for MxFp{4,8}, MxInt4 and NvFp4 formats:
// If DeepSeek FP8 recipe is not used, but for MxFp{4,8} and NvFp4 formats:
// The layout of scaling factors for A is always R128c4
// M must be a multiple of 128.
// K must be a multiple of 64.
Expand All @@ -138,8 +138,7 @@ struct BatchedGemmData
// Where paddedM is M if (routeAct == true && batchM), or
// sum(divUpMul(M[bi], tileM) for bi in B) if batchM,
// otherwise divUpMul(M, tileM) * B.
// Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype is Dtype::E4m3 for
// NvFp4, Dtype::UE8m0 for MxFp{4,8} formats, Dtype::Bfloat16 for MxInt4.
// Dtype is Dtype::Fp32 if DeepSeek FP8 recipe is used, otherwise Dtype::E4m3.
//
// Otherwise should be set to nullptr.
void const* mPtrSfA{nullptr};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
*/
#pragma once

#include "BatchedGemmEnums.h"
#include "GemmGatedActOptions.h"
#include "GemmOptions.h"
#include "GemmGatedActOptions.h"
#include "BatchedGemmEnums.h"

#include <cstdint>
#include <vector>

#ifndef TLLM_GEN_EXPORT_INTERFACE
#include "trtllm/gen/CudaRunner.h"
#include "trtllm/gen/GenCtx.h"
#include "trtllm/gen/CudaRunner.h"
#else
#include <iostream>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,15 @@ namespace tg = trtllm::gen;
// Type of the gated activation
enum class ActType
{
// clang-format off
// For ActType == SwiGlu, ideally we would like to have something like
// gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)).
// But for now, we use the simplified version
// gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)),
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
// beta' = beta / dequantScaleAb, scaleC = quantScaleC * dequantScaleAb.
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
// clang-format on
// For ActType == SwiGlu, ideally we would like to have something like
// gatedAct = quantScaleC * (x0 * dequantScaleAb + beta) * ((x1 * scaleGate) *
// sigmoid(alpha * x1 * scaleGate)).
// But for now, we use the simplified version
// gatedAct = scaleC * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)),
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
// beta' = beta / dequantScaleAb, scaleC = quantScaleC * dequantScaleAb.
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
SwiGlu,
// For ActType == GeGlu, we use the simplified version
// gatedAct = scaleC' * (x0 + beta') * ((x1 * scaleGate) * phi(alpha * x1 * scaleGate)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
#include "KernelTraits.h"
#include "trtllm/gen/CudaArchDecl.h"
#include "trtllm/gen/DtypeDecl.h"
#include "trtllm/gen/MmaDecl.h"
#include "trtllm/gen/SfLayoutDecl.h"
#include "trtllm/gen/MmaDecl.h"
#ifndef TLLM_GEN_EXPORT_INTERFACE
#include "trtllm/gen/CudaRunner.h"
#include "trtllm/gen/GenCtx.h"
#include "trtllm/gen/CudaRunner.h"
#else
#ifdef TLLM_GEN_EXPORT_FLASHINFER
#include <string>
Expand Down Expand Up @@ -720,10 +720,9 @@ inline bool checkAndUpdateGemmOptions(
#endif // TLLM_PUBLIC_RELEASE

// Check that the A cast is supported.
// Currently, we only support {MxFp4, NvFp4, MxInt4} -> Bf16.
// Currently, we only support {MxFp4, NvFp4} -> Bf16.
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1
|| options.mDtypeA == tg::Dtype::MxInt4)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1)
&& options.mDtypeMmaA == tg::Dtype::Bfloat16)
|| (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3),
"Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", tg::dtypeToString(options.mDtypeMmaA));
Expand Down Expand Up @@ -1424,19 +1423,6 @@ inline bool checkAndUpdateGemmOptions(
}
}

if (isBlackwell && !options.mUseCustomMmaSchedule && !options.mUseDeepSeekFp8
&& options.mTileScheduler == TileScheduler::Persistent)
{
if (updateOptions)
{
options.mUseCustomMmaSchedule = true;
}
else
{
TLLM_CHECK_ERROR(false, "TileScheduler::Persistent and !UseCustomMmaSchedule is not supported.");
}
}

if (options.mEnablesDelayedEarlyExit && options.mEnablesEarlyExit)
{
TLLM_LOG_WARNING(
Expand Down Expand Up @@ -1637,8 +1623,8 @@ inline CUresult loadCubinData(CUmodule* module, Config const& config)
// Trtllm links the cubin into the executable while Flashinfer loads the cubin from storage.
#ifdef TLLM_GEN_EXPORT_FLASHINFER
#ifdef TLLM_GEN_GEMM_CUBIN_PATH
static const std::string tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
const std::string sha256 = config.mHash ? config.mHash : "";
static std::string const tllm_gen_gemm_cubin_path = std::string(TLLM_GEN_GEMM_CUBIN_PATH);
std::string const sha256 = config.mHash ? config.mHash : "";
std::string fileName = config.mFunctionName;
if (!fileName.empty())
{
Expand Down
Loading
Loading