Commit 757579a

limit num of tactics to tune for
Signed-off-by: Anthony Chang <[email protected]>
1 parent b8b10cb commit 757579a

File tree

5 files changed: +61 -1 lines changed

  cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu
  cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h
  cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp
  cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp
  cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp


cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu

Lines changed: 3 additions & 1 deletion
@@ -251,7 +251,9 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void*
     if (mDtypeWeights == btg::Dtype::MxE2m1 && mDtypeAct == btg::Dtype::MxE4m3)
     {
         // The multiple is no less than 128 as TMA requires it for CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B types
-        // FIXME: enforce valid hidden dim to be multiple of 512 due to unhandled OOB read in routeAct
+        // FIXME: enforce valid hidden dim to be multiple of 512 due to unhandled OOB read in routeAct. Please keep this
+        // in sync with
+        // tensorrt_llm/_torch/modules/fused_moe/quantization.py:MXFP4WeightTRTLLMGenFusedMoEMethod.input_hidden_alignment
         validHiddenSize = tensorrt_llm::common::roundUp(validHiddenSize, 512);
     }
     auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim);

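The FIXME above ties the hidden-size padding to a 512-element alignment. Below is a minimal sketch of that round-up, assuming tensorrt_llm::common::roundUp has the usual round-up-to-a-multiple semantics; the name roundUpSketch and the example value are hypothetical.

```cpp
#include <cstdint>
#include <iostream>

// Sketch only: assumes roundUp(value, multiple) means "smallest multiple of `multiple` >= value".
static int64_t roundUpSketch(int64_t value, int64_t multiple)
{
    return (value + multiple - 1) / multiple * multiple;
}

int main()
{
    // e.g. a hidden size of 2880 would be treated as the next multiple of 512 -> 3072
    std::cout << roundUpSketch(2880, 512) << std::endl;
    return 0;
}
```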
cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h

Lines changed: 29 additions & 0 deletions
@@ -22,6 +22,8 @@
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h"
 #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
+#include "tensorrt_llm/thop/thUtils.h"
+#include <set>
 #include <string>
 
 namespace tensorrt_llm
@@ -31,6 +33,33 @@ namespace kernels
 namespace trtllmGenFp8BlockScaleMoe
 {
 
+inline std::set<int32_t> computeSelectedTileN(std::vector<int32_t> const& supported_tile_nums, int64_t const num_tokens,
+    int64_t const top_k, int64_t const num_local_experts)
+{
+    float const avg_tokens_per_expert = static_cast<float>(num_tokens * top_k) / num_local_experts;
+    // assume supported_tile_nums is sorted
+    int32_t tile_tokens_dim = std::clamp(
+        torch_ext::nextPowerOfTwo(avg_tokens_per_expert), supported_tile_nums.front(), supported_tile_nums.back());
+    auto it = std::find(supported_tile_nums.begin(), supported_tile_nums.end(), tile_tokens_dim);
+
+    std::set<int32_t> selected_tile_nums;
+    selected_tile_nums.insert(tile_tokens_dim);
+    if (std::next(it) != supported_tile_nums.end())
+    {
+        selected_tile_nums.insert(*std::next(it));
+        if (std::next(std::next(it)) != supported_tile_nums.end())
+        {
+            selected_tile_nums.insert(*std::next(std::next(it)));
+        }
+    }
+    if (it != supported_tile_nums.begin())
+    {
+        selected_tile_nums.insert(*std::prev(it));
+    }
+
+    return selected_tile_nums;
+}
+
 namespace Routing
 {
 
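For reference, here is a minimal, self-contained sketch of the selection heuristic that the new computeSelectedTileN helper implements: pick the supported tile size closest to the average tokens per expert, plus up to two larger neighbors and one smaller neighbor. The helper nextPowerOfTwoSketch and the example values stand in for torch_ext::nextPowerOfTwo and real shapes and are assumptions, not the library API.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Hypothetical stand-in for torch_ext::nextPowerOfTwo.
static int32_t nextPowerOfTwoSketch(float v)
{
    int32_t p = 1;
    while (static_cast<float>(p) < v)
    {
        p *= 2;
    }
    return p;
}

// Sketch of the tile selection; assumes supportedTileN is a sorted list of powers of two.
static std::set<int32_t> selectTileNSketch(
    std::vector<int32_t> const& supportedTileN, int64_t numTokens, int64_t topK, int64_t numLocalExperts)
{
    float const avgTokensPerExpert = static_cast<float>(numTokens * topK) / numLocalExperts;
    int32_t const tileTokensDim
        = std::clamp(nextPowerOfTwoSketch(avgTokensPerExpert), supportedTileN.front(), supportedTileN.back());
    auto it = std::find(supportedTileN.begin(), supportedTileN.end(), tileTokensDim);

    std::set<int32_t> selected{tileTokensDim};
    if (std::next(it) != supportedTileN.end())
    {
        selected.insert(*std::next(it)); // one tile size larger
        if (std::next(std::next(it)) != supportedTileN.end())
        {
            selected.insert(*std::next(std::next(it))); // two tile sizes larger
        }
    }
    if (it != supportedTileN.begin())
    {
        selected.insert(*std::prev(it)); // one tile size smaller
    }
    return selected;
}

int main()
{
    // Assumed example: 128 tokens, top-8 routing, 32 local experts -> 32 average tokens per expert.
    std::vector<int32_t> const supportedTileN{8, 16, 32, 64, 128};
    for (int32_t tileN : selectTileNSketch(supportedTileN, 128, 8, 32))
    {
        std::cout << tileN << " "; // prints: 16 32 64 128
    }
    std::cout << std::endl;
    return 0;
}
```

Restricting candidates to this small neighborhood keeps autotuning time bounded while still covering the tile sizes adjacent to the best static guess.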
cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,7 @@ namespace torch_ext
 namespace btg = batchedGemm::trtllm::gen;
 using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
 using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
+using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 
 std::vector<torch::Tensor> run_fp4_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
@@ -419,6 +420,11 @@ class FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         std::vector<std::vector<int64_t>> tactics;
         for (auto& [tileN, runner] : mRunners)
         {
+            auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts);
+            if (chosen.find(tileN) == chosen.end())
+            {
+                continue;
+            }
             auto config_indices_per_runner
                 = runner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
             for (auto cfg : config_indices_per_runner)
@@ -500,6 +506,11 @@ class FP8FP4BlockScaleMoeRunner : public torch::CustomClassHolder
         std::vector<std::vector<int64_t>> tactics;
         for (auto& [tileN, runner] : mRunners)
         {
+            auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts);
+            if (chosen.find(tileN) == chosen.end())
+            {
+                continue;
+            }
             auto config_indices_per_runner
                 = runner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
             for (auto cfg : config_indices_per_runner)

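The same guard recurs in every runner class's tactics loop in this and the following files: compute the candidate tile sizes once, then skip runners whose tileN is not among them so only a handful of (tileN, config) pairs reach the autotuner. A minimal, self-contained sketch of that pattern follows, with a hypothetical RunnerSketch standing in for MoE::Runner and the chosen tile set passed in directly rather than recomputed.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical stand-in for the per-tile-size MoE GEMM runner held in mRunners.
struct RunnerSketch
{
    std::vector<int64_t> getValidConfigIndices() const
    {
        return {0, 1, 2}; // pretend each runner exposes three kernel configs
    }
};

// Sketch of the filtering pattern added to the tactics-enumeration loops.
std::vector<std::pair<int32_t, int64_t>> collectTacticsSketch(
    std::map<int32_t, RunnerSketch> const& runners, std::set<int32_t> const& chosenTileN)
{
    std::vector<std::pair<int32_t, int64_t>> tactics;
    for (auto const& [tileN, runner] : runners)
    {
        if (chosenTileN.find(tileN) == chosenTileN.end())
        {
            continue; // tile size not worth tuning for the current problem shape
        }
        for (auto cfg : runner.getValidConfigIndices())
        {
            tactics.emplace_back(tileN, cfg);
        }
    }
    return tactics;
}

int main()
{
    std::map<int32_t, RunnerSketch> runners{{8, {}}, {16, {}}, {32, {}}, {64, {}}, {128, {}}};
    std::set<int32_t> chosen{16, 32, 64, 128}; // e.g. the output of the selection sketch above
    std::cout << collectTacticsSketch(runners, chosen).size() << " tactics" << std::endl; // 12 tactics
    return 0;
}
```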
cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 6 additions & 0 deletions
@@ -32,6 +32,7 @@ namespace torch_ext
 namespace btg = batchedGemm::trtllm::gen;
 using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
 using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
+using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 
 at::Tensor run_fp8_block_scale_moe(at::optional<at::Tensor> const& routing_logits,
     std::optional<at::Tensor> const& routing_bias, at::Tensor const& hidden_states,
@@ -335,6 +336,11 @@ class FP8BlockScaleMoeRunner : public torch::CustomClassHolder
         std::vector<std::vector<int64_t>> tactics;
         for (auto& [tileN, runner] : mRunners)
         {
+            auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts);
+            if (chosen.find(tileN) == chosen.end())
+            {
+                continue;
+            }
             auto config_indices_per_runner
                 = runner->getValidConfigIndices(topK, hiddenSize, intermediateSize, numLocalExperts, numTokens);
             for (auto cfg : config_indices_per_runner)

cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp

Lines changed: 12 additions & 0 deletions
@@ -30,6 +30,7 @@ namespace torch_ext
 namespace btg = batchedGemm::trtllm::gen;
 using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType;
 using MoeRunnerType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::Runner;
+using tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::computeSelectedTileN;
 
 torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional<torch::Tensor> const& routing_logits,
     torch::optional<torch::Tensor> const& routing_bias, torch::Tensor const& hidden_states,
@@ -500,6 +501,11 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
         std::vector<std::vector<int64_t>> tactics;
         for (auto& [tileN, runner] : mRunners)
         {
+            auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts);
+            if (chosen.find(tileN) == chosen.end())
+            {
+                continue;
+            }
             auto config_indices_per_runner = runner->getValidConfigIndices(
                 topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize);
             for (auto cfg : config_indices_per_runner)
@@ -587,6 +593,12 @@ class MxE4m3MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
         std::vector<std::vector<int64_t>> tactics;
         for (auto& [tileN, runner] : mRunners)
         {
+            auto chosen = computeSelectedTileN(mSupportedTileN, numTokens, topK, numLocalExperts);
+            if (chosen.find(tileN) == chosen.end())
+            {
+                continue;
+            }
+
             auto config_indices_per_runner = runner->getValidConfigIndices(
                 topK, hiddenSize, intermediateSize, numLocalExperts, numTokens, validHiddenSize, validIntermediateSize);
             for (auto cfg : config_indices_per_runner)
