@@ -68,54 +68,50 @@ class TrtllmGenBatchedGemmRunner
6868 int32_t configIndex) const ;
6969
7070 // Generic GEMM interface
71- void run (int32_t m, int32_t n, int32_t k, int32_t validM, int32_t validN, int32_t validK,
72- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
73- void const * a, void const * sfA, void const * b, void const * sfB, void const * perTokensSfA,
74- void const * perTokensSfB, float const * scaleC, float const * scaleGateC, float const * bias,
75- float const * swiGluAlpha, float const * swiGluBeta, float const * clampLimit, void * c, void * outSfC,
76- int32_t const * routeMap, int32_t const * totalNumPaddedTokens, int32_t const * ctaIdxXyToBatchIdx,
77- int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas, void * workspace, CUstream stream,
78- int device, int32_t configIndex);
71+ void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, int32_t numTokens,
72+ int32_t numBatches, int32_t maxNumCtasInBatchDim, void const * a, void const * sfA, void const * b,
73+ void const * sfB, void const * perTokensSfA, void const * perTokensSfB, float const * scaleC,
74+ float const * scaleGateC, float const * bias, float const * swiGluAlpha, float const * swiGluBeta,
75+ float const * clampLimit, void * c, void * outSfC, int32_t const * routeMap, int32_t const * totalNumPaddedTokens,
76+ int32_t const * ctaIdxXyToBatchIdx, int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas,
77+ void * workspace, CUstream stream, int device, int32_t configIndex);
7978
8079 // Block-scaling GEMM
8180 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * sfA,
8281 void const * b, void const * sfB, void * c, void * outSfC, void * workspace, CUstream stream, int device,
83- int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
82+ int32_t configIndex);
8483
8584 // Block-scaling GEMM with SwiGLU activation
8685 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * sfA,
8786 void const * b, void const * sfB, float const * bias, float const * swiGluAlpha, float const * swiGluBeta,
8887 float const * clampLimit, void * c, void * outSfC, void * workspace, CUstream stream, int device,
89- int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
88+ int32_t configIndex);
9089
9190 // FP8 per-tensor scaling GEMM
9291 void run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, void const * a, void const * b,
9392 float const * scaleC, float const * scaleGateC, void * c, void * workspace, CUstream stream, int device,
94- int32_t configIndex, int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 );
93+ int32_t configIndex);
9594
9695 // Get the list of configs that passed the validation based on the constructor options
9796 [[nodiscard]] std::vector<int64_t > getPassingConfigIndices () const
9897 {
9998 return mPassingConfigIndices ;
10099 }
101100
102- // Get the kernel name from the config index
103- [[nodiscard]] std::string getKernelNameFromConfigIndex (int32_t configIndex) const ;
104-
105101 // Get the list of config indices that are valid for the given problem shape
106102 [[nodiscard]] std::vector<int64_t > getValidConfigIndices (int32_t m, int32_t n, int32_t k,
107- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
108- int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
103+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
104+ int32_t maxNumCtasInBatchDim ) const ;
109105
110106 // Get a default config index that is valid for the given problem shape
111107 // This will be used as the fallback config if using auto-tuning
112108 [[nodiscard]] int64_t getDefaultValidConfigIndex (int32_t m, int32_t n, int32_t k,
113- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
114- int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
109+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
110+ int32_t maxNumCtasInBatchDim ) const ;
115111
116112 [[nodiscard]] bool isValidConfigIndex (int32_t configIndex, int32_t m, int32_t n, int32_t k,
117- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
118- int32_t validM = - 1 , int32_t validN = - 1 , int32_t validK = - 1 ) const ;
113+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
114+ int32_t maxNumCtasInBatchDim ) const ;
119115
120116private:
121117 void selectGemmConfig (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens, int32_t numTokens,
0 commit comments