Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
557db0a
wip: does not compile yet
nekorobov Feb 5, 2026
45cdb86
fix: compiles, but hangs in autotuning
nekorobov Feb 5, 2026
d8c15b4
banned splitK and tileN 256, unit test works
nekorobov Feb 5, 2026
8a7a269
Merge remote-tracking branch 'origin/main' into nkorobov/mxfp8-trtllm…
IwakuraRein Feb 5, 2026
77c49a7
upd
IwakuraRein Feb 5, 2026
3e1a29f
add mxfp8 bench
IwakuraRein Feb 5, 2026
b12c461
fix test
IwakuraRein Feb 6, 2026
46eddfa
upd comments
IwakuraRein Feb 6, 2026
b046320
drop tile==8 and use unroll loop 2x
IwakuraRein Feb 6, 2026
acf0c39
fix test
IwakuraRein Feb 6, 2026
2702ee2
WAR: drop all UnrollLoop2xForMma kernels
IwakuraRein Feb 6, 2026
1dc688d
Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…
IwakuraRein Feb 7, 2026
4e83b82
address comment
IwakuraRein Feb 9, 2026
aae1719
fix unit test
IwakuraRein Feb 9, 2026
73d7594
fix hang and segfault
nekorobov Feb 10, 2026
4354ec4
use permute cache in unit test (WIP)
IwakuraRein Feb 10, 2026
0944312
use permute cache in unit test (WIP)
IwakuraRein Feb 10, 2026
aa85e94
Revert "use permute cache in unit test (WIP)"
IwakuraRein Feb 11, 2026
a7ebf1e
Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…
IwakuraRein Feb 12, 2026
4815a0c
address comments
IwakuraRein Feb 13, 2026
e18d73c
intermediate_size_factor
IwakuraRein Feb 13, 2026
b9f198d
Merge remote-tracking branch 'origin/main' into siyuanf/mxfp8-trtllm-…
IwakuraRein Feb 13, 2026
c310276
address comments
IwakuraRein Feb 13, 2026
33acaa2
quick fix
IwakuraRein Feb 13, 2026
03cac02
fix intermediate_size_factor initialization
IwakuraRein Feb 14, 2026
19417d1
allow split k
IwakuraRein Feb 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
mxfp8_quantize,
)
from flashinfer.fused_moe import (
Fp8QuantizationType,
trtllm_fp4_block_scale_moe,
trtllm_mxint4_block_scale_moe,
trtllm_fp8_per_tensor_scale_moe,
Expand Down Expand Up @@ -53,7 +54,7 @@ def mxint4_quantize(

def bench_trtllm_gen_fused_moe_autotuner_fp8(
tune_max_num_tokens: Optional[int],
quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block"],
quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block", "MxFP8xMxFP8"],
num_tokens: int,
num_experts: int,
hidden_size: int,
Expand All @@ -79,29 +80,54 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
torch.bfloat16
)

is_block_scale = quant_mode == "Fp8-Block"
if not is_block_scale:
is_block_scale = quant_mode != "Fp8-Per-Tensor"
if quant_mode == "Fp8-Per-Tensor":
hidden_states, hidden_states_scale = fp8_quantize(hidden_states)
w13, w13_scale = fp8_quantize(w13)
w2, w2_scale = fp8_quantize(w2)
else:
# block scale quantization is too slow, so we use per-tensor quantization for now
hidden_states, hidden_states_scale = fp8_quantize(hidden_states)
w13, w13_scale = fp8_quantize(w13)
w2, w2_scale = fp8_quantize(w2)
hidden_states_scale = torch.full(
(hidden_size // 128, num_tokens), hidden_states_scale.item(), device=device
)
w13_scale = torch.full(
(num_experts, intermediate_size * 2 // 128, hidden_size // 128),
w13_scale.item(),
device=device,
)
w2_scale = torch.full(
(num_experts, hidden_size // 128, intermediate_size // 128),
w2_scale.item(),
device=device,
)
scale_vec_size = 128 if quant_mode == "Fp8-Block" else 32
if quant_mode == "Fp8-Block":
# block scale quantization is too slow, so we use per-tensor quantization for now
hidden_states, hidden_states_scale = fp8_quantize(
hidden_states
) # scalar quantization
w13, w13_scale = fp8_quantize(w13) # scalar quantization
w2, w2_scale = fp8_quantize(w2) # scalar quantization
hidden_states_scale = torch.full(
(hidden_size // scale_vec_size, num_tokens),
hidden_states_scale.item(),
device=device,
)
w13_scale = torch.full(
(
num_experts,
intermediate_size * 2 // scale_vec_size,
hidden_size // scale_vec_size,
),
w13_scale.item(),
device=device,
)
w2_scale = torch.full(
(
num_experts,
hidden_size // scale_vec_size,
intermediate_size // scale_vec_size,
),
w2_scale.item(),
device=device,
)
else: # MxFP8xMxFP8
hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False)
w13, w13_scale = mxfp8_quantize(w13, True)
w2, w2_scale = mxfp8_quantize(w2, True)
hidden_states_scale = hidden_states_scale.view(torch.uint8).reshape(
num_tokens, -1
)
w13_scale = w13_scale.view(torch.uint8).reshape(
num_experts, intermediate_size * 2, -1
)
w2_scale = w2_scale.view(torch.uint8).reshape(num_experts, hidden_size, -1)

output1_scale_scalar = (
torch.tensor([hidden_states_scale * w13_scale] * num_experts, device=device)
Expand Down Expand Up @@ -136,12 +162,15 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8(
local_num_experts=num_experts,
routed_scaling_factor=2.5,
routing_method_type=RoutingMethodType.DeepSeekV3.value,
use_shuffled_weight=False,
weight_layout=WeightLayout.MajorK.value, # weight_layout
use_shuffled_weight=quant_mode == "MxFP8xMxFP8",
weight_layout=WeightLayout.MajorK.value,
enable_pdl=enable_pdl,
tune_max_num_tokens=num_tokens
if tune_max_num_tokens is None
else tune_max_num_tokens,
fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8
if quant_mode == "Fp8-Block"
else Fp8QuantizationType.MxFp8,
)
else:
fn = partial(
Expand Down Expand Up @@ -468,6 +497,7 @@ def bench(do_autotune):
"MxFP4xMxFP8",
"MxFP4xBf16",
"MxInt4xBf16",
"MxFP8xMxFP8",
"Fp8-Per-Tensor",
"Fp8-Block",
],
Expand Down Expand Up @@ -505,7 +535,7 @@ def bench(do_autotune):
args = parser.parse_args()
fn = (
bench_trtllm_gen_fused_moe_autotuner_fp8
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]
if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block", "MxFP8xMxFP8"]
else bench_trtllm_gen_fused_moe_autotuner_mxint4
if args.quant_mode == "MxInt4xBf16"
else bench_trtllm_gen_fused_moe_autotuner_fp4
Expand Down
16 changes: 9 additions & 7 deletions csrc/trtllm_batched_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(
size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(
int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, int32_t configIndex) const {
BatchedGemmData gemmData;
BatchedGemmData gemmData{};
gemmData.mProblemDimensions.mNumBatches = numBatches;
gemmData.mProblemDimensions.mNumTokens = numTokens;
gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput;
Expand Down Expand Up @@ -174,11 +174,12 @@ void TrtllmGenBatchedGemmRunner::run(
CUstream stream, int device, int32_t configIndex, bool enable_pdl) {
auto bmm = BatchedGemmInterface();

BatchedGemmData gemmData;
BatchedGemmData gemmData{};

auto const configs = bmm.getBatchedGemmConfigs();

auto const& config = configs[configIndex];
// printf("running config %d: %s\n", configIndex, config.mFunctionName);

FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0");
if (!mOptions.staticBatch) {
Expand Down Expand Up @@ -327,7 +328,7 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(

int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

BatchedGemmData gemmData;
BatchedGemmData gemmData{};
// Dims
gemmData.mProblemDimensions.mNumBatches = numBatches;
gemmData.mProblemDimensions.mNumTokens = numTokens;
Expand Down Expand Up @@ -436,7 +437,7 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t
auto const bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();

BatchedGemmData gemmData;
BatchedGemmData gemmData{};
// Dims
gemmData.mProblemDimensions.mNumBatches = numBatches;
gemmData.mProblemDimensions.mNumTokens = numTokens;
Expand All @@ -451,12 +452,13 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;

auto const& config = configs[configIndex];

// FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in
// trtllm-gen ac83afb
return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1;
return bmm.isValidConfig(config, gemmData);
}

} // namespace kernels
Expand Down
Loading
Loading