Commit 1f9ac3f

Revert "[None][feat] Update TRTLLM MoE cubins; reduce mxfp4 weight padding requirement; tighten TMA bound (#9025)"
This reverts commit 86cfb3e.
1 parent 24f5cd7 commit 1f9ac3f

File tree

1,434 files changed: +9892 / -22312 lines


cpp/include/tensorrt_llm/common/cudaUtils.h

Lines changed: 0 additions & 6 deletions
@@ -19,9 +19,6 @@
 #include "tensorrt_llm/common/cudaBf16Wrapper.h"
 #include "tensorrt_llm/common/cudaDriverWrapper.h"
 #include "tensorrt_llm/common/cudaFp8Utils.h"
-#if ENABLE_FP4
-#include <cuda_fp4.h>
-#endif
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/common/tllmException.h"
 #include <algorithm>
@@ -548,9 +545,6 @@ template void printArrayInfo(__nv_bfloat16 const* ptr, uint64_t nElement, std::s
 #ifdef ENABLE_FP8
 template void printArrayInfo(__nv_fp8_e4m3 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
 #endif
-#ifdef ENABLE_FP4
-template void printArrayInfo(__nv_fp4_e2m1 const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
-#endif
 template void printArrayInfo(uint32_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
 template void printArrayInfo(uint64_t const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
 template void printArrayInfo(int const* ptr, uint64_t nElement, std::string name, bool const bPrintElement);
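
Note: the deleted lines follow the usual feature-macro guard pattern, where the header that defines a type and the explicit instantiation that uses it are gated on the same macro (the original mixes #if and #ifdef for ENABLE_FP4). Below is a minimal, self-contained, hypothetical sketch of that pattern; printArrayInfoDemo is a stand-in, not the real printArrayInfo (which also takes a bPrintElement flag and inspects device memory).

#include <cstdint>
#include <cstdio>
#include <string>

#ifdef ENABLE_FP4
#include <cuda_fp4.h> // defines __nv_fp4_e2m1 on toolkits that ship FP4 support
#endif

// Hypothetical stand-in for printArrayInfo.
template <typename T>
void printArrayInfoDemo(T const* ptr, uint64_t nElement, std::string name)
{
    std::printf("%s: %llu elements at %p\n", name.c_str(), static_cast<unsigned long long>(nElement),
        static_cast<void const*>(ptr));
}

// Unconditional instantiations compile on every toolkit.
template void printArrayInfoDemo(uint32_t const*, uint64_t, std::string);

#ifdef ENABLE_FP4
// Gated on the same macro as the include: the symbol only exists when the
// toolkit actually provides the FP4 type.
template void printArrayInfoDemo(__nv_fp4_e2m1 const*, uint64_t, std::string);
#endif

Gating the instantiation on the same macro as the include keeps the translation unit compiling on toolkits that lack <cuda_fp4.h>.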

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 58 additions & 252 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h

Lines changed: 16 additions & 20 deletions
@@ -68,54 +68,50 @@ class TrtllmGenBatchedGemmRunner
         int32_t configIndex) const;

     // Generic GEMM interface
-    void run(int32_t m, int32_t n, int32_t k, int32_t validM, int32_t validN, int32_t validK,
-        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-        void const* a, void const* sfA, void const* b, void const* sfB, void const* perTokensSfA,
-        void const* perTokensSfB, float const* scaleC, float const* scaleGateC, float const* bias,
-        float const* swiGluAlpha, float const* swiGluBeta, float const* clampLimit, void* c, void* outSfC,
-        int32_t const* routeMap, int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx,
-        int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, CUstream stream,
-        int device, int32_t configIndex);
+    void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
+        int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
+        void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
+        float const* scaleGateC, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
+        float const* clampLimit, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
+        int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
+        void* workspace, CUstream stream, int device, int32_t configIndex);

     // Block-scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
-        int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
+        int32_t configIndex);

     // Block-scaling GEMM with SwiGLU activation
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
         void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
         float const* clampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device,
-        int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
+        int32_t configIndex);

     // FP8 per-tensor scaling GEMM
     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
         float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
-        int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1);
+        int32_t configIndex);

     // Get the list of configs that passed the validation based on the constructor options
     [[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const
     {
         return mPassingConfigIndices;
     }

-    // Get the kernel name from the config index
-    [[nodiscard]] std::string getKernelNameFromConfigIndex(int32_t configIndex) const;
-
     // Get the list of config indices that are valid for the given problem shape
     [[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
-        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-        int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;

     // Get a default config index that is valid for the given problem shape
     // This will be used as the fallback config if using auto-tuning
     [[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
-        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-        int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;

     [[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t m, int32_t n, int32_t k,
-        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
-        int32_t validM = -1, int32_t validN = -1, int32_t validK = -1) const;
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;

 private:
     void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
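
For context on the signature change: the reverted commit had threaded optional valid extents (validM/validN/validK) through these overloads as trailing parameters defaulted to -1, which keeps existing call sites source-compatible; the revert drops them again. Note the generic run() overload instead placed them right after m/n/k, so its callers did need updating. Below is a minimal, hypothetical sketch of the trailing-default pattern; ToyRunner and runWithValid are stand-ins, not the real TrtllmGenBatchedGemmRunner API.

#include <cstdint>
#include <cstdio>
#include <vector>

struct ToyRunner // hypothetical stand-in
{
    // Shape restored by this revert: no valid extents.
    void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t configIndex)
    {
        std::printf("run: m=%d n=%d k=%d batches=%zu config=%d\n", m, n, k, batchedTokens.size(), configIndex);
    }

    // Shape the revert removes (sketch): trailing defaults, assumed to mean
    // "use the full extent" when left at -1, keep old callers compiling.
    void runWithValid(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
        int32_t configIndex, int32_t validM = -1, int32_t validN = -1, int32_t validK = -1)
    {
        (void) batchedTokens;
        std::printf("runWithValid: m=%d n=%d k=%d config=%d valid=(%d,%d,%d)\n", m, n, k, configIndex, validM,
            validN, validK);
    }
};

int main()
{
    ToyRunner runner;
    std::vector<int32_t> batchedTokens{128, 64};
    runner.run(256, 512, 1024, batchedTokens, /*configIndex=*/0);
    runner.runWithValid(256, 512, 1024, batchedTokens, 0);                 // defaults apply: (-1, -1, -1)
    runner.runWithValid(256, 512, 1024, batchedTokens, 0, 250, 512, 1000); // explicit valid extents
    return 0;
}

Defaulted trailing parameters are a common way to extend an interface without touching existing callers; the limitation is that the new arguments must sit at the end of the parameter list.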

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/.clang-format

Lines changed: 0 additions & 78 deletions
This file was deleted.

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmEnums.h

Lines changed: 2 additions & 11 deletions
@@ -16,8 +16,8 @@
 */
 #pragma once

-#include <string>
 #include <cassert>
+#include <string>

 namespace batchedGemm
 {
@@ -34,9 +34,7 @@ enum class RouteImpl
     // Use LDGSTS to do the routing
     Ldgsts = 1,
     // Use UTMALDG.GATHER4 to do the routing
-    Tma = 2,
-    // Use LDG+STS to do the routing
-    LdgPlusSts = 3
+    Tma = 2
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -62,13 +60,6 @@ inline bool doesRouteImplUseTma(RouteImpl mode)

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-inline bool doesRouteImplUseLdgPlusSts(RouteImpl mode)
-{
-    return (mode == RouteImpl::LdgPlusSts);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 } // namespace batchedGemm

 ////////////////////////////////////////////////////////////////////////////////////////////////////
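
After the revert, RouteImpl loses the LdgPlusSts enumerator and its doesRouteImplUseLdgPlusSts() predicate, leaving LDGSTS- and TMA-based routing. Below is a minimal, self-contained sketch of the resulting enum-plus-predicate pattern; the NoRoute enumerator at value 0 and the doesRouteImplUseLdgsts() helper are assumptions (their definitions sit above the excerpted hunks), not verbatim source.

#include <cassert>

namespace sketch
{

enum class RouteImpl
{
    NoRoute = 0, // assumed: not visible in the excerpt above
    // Use LDGSTS to do the routing
    Ldgsts = 1,
    // Use UTMALDG.GATHER4 to do the routing
    Tma = 2
};

// Assumed by analogy with the deleted doesRouteImplUseLdgPlusSts().
inline bool doesRouteImplUseLdgsts(RouteImpl mode)
{
    return (mode == RouteImpl::Ldgsts);
}

inline bool doesRouteImplUseTma(RouteImpl mode)
{
    return (mode == RouteImpl::Tma);
}

} // namespace sketch

int main()
{
    using sketch::RouteImpl;
    assert(sketch::doesRouteImplUseTma(RouteImpl::Tma));
    assert(!sketch::doesRouteImplUseLdgsts(RouteImpl::Tma));
    return 0;
}

Each routing mode gets a doesRouteImplUse*() predicate, so call sites test the routing semantics they need rather than comparing raw enumerator values inline; removing a mode then means deleting one enumerator and one predicate, exactly as in this diff.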
