Skip to content

Commit 9b3d7cc

Browse files
authored
[None][feat] Update TRT-LLM Gen MoE kernels (#7970)
Signed-off-by: Nikita Korobov <[email protected]>
1 parent 01423ac commit 9b3d7cc

File tree

217 files changed

+7259
-4229
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

217 files changed

+7259
-4229
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
110110
&& options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch
111111
&& tileSize == mOptions.tileSize)
112112
{
113+
auto sm = configs[i].mSm;
114+
if (sm != SmVersion::Sm100f)
115+
{
116+
int smVersion = tensorrt_llm::common::getSMVersion();
117+
if (smVersion == 100 && sm != SmVersion::Sm100a)
118+
{
119+
continue;
120+
}
121+
else if (smVersion == 103 && sm != SmVersion::Sm103a)
122+
{
123+
continue;
124+
}
125+
}
126+
113127
// FIXME: Disable split-k for now.
114128
if (options.mClusterDimZ != 1)
115129
{

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,7 @@ class BatchedGemmInterface
548548
// Aligns the pointer to the alignment
549549
template <typename Dtype>
550550
inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const;
551+
551552
// Returns the size of the workspace buffers in bytes
552553
std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
553554

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/KernelMetaInfo.h

Lines changed: 6620 additions & 3740 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)