Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
64890fb
refactor trtllm-gen enums and gemm runner
IwakuraRein Feb 26, 2026
a730b0c
update gemm headers
IwakuraRein Feb 27, 2026
7ff5f21
add mxfp8 gemm (WIP)
IwakuraRein Feb 27, 2026
9e03b40
wip
IwakuraRein Feb 28, 2026
30c40fe
temporary fix for coexistence of two majorness
IwakuraRein Feb 28, 2026
6eda259
update unit test
IwakuraRein Feb 28, 2026
616f493
add layout A to options
IwakuraRein Mar 2, 2026
ef9726c
update test_mm_mxfp8.py
IwakuraRein Mar 2, 2026
e5009cf
update benchmark
IwakuraRein Mar 2, 2026
66a93be
update cubins
IwakuraRein Mar 3, 2026
e50ad06
update mxfp8 test
IwakuraRein Mar 3, 2026
ca102c4
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 5, 2026
5944694
update artifact
IwakuraRein Mar 5, 2026
aa63322
move sflayout to tllm_enums; expose sf swizzle layout in the mxfp8 qu…
IwakuraRein Mar 9, 2026
9392c1e
update test_mm_mxfp8; update comments in mm_mxfp8
IwakuraRein Mar 9, 2026
ed9620d
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 9, 2026
39335cf
fix merge
IwakuraRein Mar 9, 2026
b0eb943
fix mxfp8 benchmark
IwakuraRein Mar 9, 2026
4750dab
fix benchmark 128x4 layout
IwakuraRein Mar 9, 2026
a59e22e
address comments
IwakuraRein Mar 9, 2026
e3293f6
update comments
IwakuraRein Mar 9, 2026
80356a0
default to 128x4
IwakuraRein Mar 10, 2026
8316e83
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 10, 2026
2a38fde
fix typo in artifacts.py
IwakuraRein Mar 10, 2026
8e5d4e9
update checksum hash
IwakuraRein Mar 11, 2026
305ae95
fix csrc/trtllm_low_latency_gemm_runner.cu
IwakuraRein Mar 13, 2026
f338e78
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 13, 2026
612b98a
fix ci error and typo
IwakuraRein Mar 16, 2026
4ed50cf
add block-major-k
IwakuraRein Mar 16, 2026
e38922c
fix low latency gemm
IwakuraRein Mar 17, 2026
db85eb6
remove silent exit
IwakuraRein Mar 17, 2026
ea91fb1
update benchmark
IwakuraRein Mar 17, 2026
60c828e
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 17, 2026
21b097c
Merge remote-tracking branch 'upstream/main' into mxfp8-gemm
IwakuraRein Mar 19, 2026
eb08363
fix merge
IwakuraRein Mar 19, 2026
0be4849
fix merge
IwakuraRein Mar 19, 2026
5b2f490
fix merge
IwakuraRein Mar 19, 2026
7ca7788
fix typo
IwakuraRein Mar 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ def dtype_str_to_torch_dtype(dtype_str):
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cutlass", "cute-dsl"],
"10.3": ["cutlass", "cute-dsl"],
"10.0": ["cutlass", "cute-dsl", "trtllm"],
"10.3": ["cutlass", "cute-dsl", "trtllm"],
"11.0": ["cutlass"],
"12.0": [],
},
Expand Down
109 changes: 71 additions & 38 deletions benchmarks/routines/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,12 @@ def testMmMxfp8(args):
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = ["cutlass", "cute-dsl", "auto"]
autotune_supported_backends = [
"cutlass",
"cute-dsl",
"trtllm",
"auto",
]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
Expand Down Expand Up @@ -1336,42 +1341,73 @@ def testMmMxfp8(args):
print("[ERROR] No backends to test. Exiting.")
return res

## Prepare input tensors
# Use swizzled layout for optimal performance
is_sf_swizzled_layout = True

inputs = {}
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(
input, is_sf_swizzled_layout=is_sf_swizzled_layout
)

mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
mat2_mxfp8, mat2_scale = mxfp8_quantize(
mat2, is_sf_swizzled_layout=is_sf_swizzled_layout
)
for backend in backends:
## Prepare input tensors
# Use swizzled layout for optimal performance
is_sf_swizzled_layout = backend in ["cutlass", "trtllm"]

if not is_sf_swizzled_layout:
sf_layout_input = flashinfer.SfLayout.layout_linear
elif backend == "cutlass" or args.use_128x4_sf_layout:
sf_layout_input = flashinfer.SfLayout.layout_128x4
elif backend == "trtllm":
if not args.use_128x4_sf_layout:
sf_layout_input = flashinfer.SfLayout.layout_8x4
else:
sf_layout_input = flashinfer.SfLayout.layout_128x4
input_mxfp8, input_scale = mxfp8_quantize(
input, sf_swizzle_layout=sf_layout_input
)
Comment on lines +1350 to +1363
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

"cute-dsl" and "auto" backends receive incorrect input quantization.

Per the mm_mxfp8 documentation (context snippet 1), the "cute-dsl" backend "currently requires swizzled 1D scales" and "auto" "selects the CUTLASS backend." However, is_sf_swizzled_layout is False for both since they're not in ["cutlass", "trtllm"].

This causes mxfp8_quantize to produce non-swizzled scales for backends that expect swizzled layouts, likely resulting in incorrect results or failures.

🔧 Suggested fix
-        is_sf_swizzled_layout = backend in ["cutlass", "trtllm"]
+        is_sf_swizzled_layout = backend in ["cutlass", "trtllm", "cute-dsl", "auto"]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
is_sf_swizzled_layout = backend in ["cutlass", "trtllm"]
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(
input, is_sf_swizzled_layout=is_sf_swizzled_layout
)
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(
input, is_sf_swizzled_layout=is_sf_swizzled_layout
)
is_sf_swizzled_layout = backend in ["cutlass", "trtllm", "cute-dsl", "auto"]
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(
input, is_sf_swizzled_layout=is_sf_swizzled_layout
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/routines/gemm.py` around lines 1335 - 1340, The backend selection
sets is_sf_swizzled_layout only for ["cutlass","trtllm"], but "cute-dsl"
requires swizzled 1D scales and "auto" maps to CUTLASS, so mxfp8_quantize is
given the wrong layout; update the logic that computes is_sf_swizzled_layout
(used when calling mxfp8_quantize) to also treat backend == "cute-dsl" and
backend == "auto" as swizzled (i.e., set is_sf_swizzled_layout = True for those
values) so the scales passed to mxfp8_quantize match the backend expectations.

# when using trtllm, the shuffle_matrix_sf_a will swizzle the layout.
mat2_mxfp8, mat2_scale = mxfp8_quantize(
mat2,
is_sf_swizzled_layout=False
if backend == "trtllm"
else is_sf_swizzled_layout,
)

if args.verbose >= 2:
print(f"[VVERBOSE] {input_mxfp8.shape = }")
print(f"[VVERBOSE] {input_mxfp8.dtype = }")
print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
print(f"[VVERBOSE] {input_scale.shape = }")
print(f"[VVERBOSE] {input_scale.dtype = }")
print(f"[VVERBOSE] {mat2_scale.shape = }")
print(f"[VVERBOSE] {mat2_scale.dtype = }")
if backend == "trtllm":
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend in ["cutlass", "cute-dsl", "auto"]:
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
a_descale=input_scale,
b_descale=mat2_scale, # mm_mxfp8 handles swizzled 1D internally
out_dtype=res_dtype,
backend=backend,
mat2_mxfp8 = shuffle_matrix_a(mat2_mxfp8, 128).reshape(n, k)
mat2_scale = shuffle_matrix_sf_a(
mat2_scale.reshape(n, k // 32),
128,
num_elts_per_sf=32,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
mat2_scale = mat2_scale.t()

if args.verbose >= 2:
print(f"[VERBOSE] {backend}: {input_mxfp8.shape = }")
print(f"[VERBOSE] {backend}: {input_mxfp8.dtype = }")
print(f"[VERBOSE] {backend}: {mat2_mxfp8.shape = }")
print(f"[VERBOSE] {backend}: {mat2_mxfp8.dtype = }")
print(f"[VERBOSE] {backend}: {input_scale.shape = }")
print(f"[VERBOSE] {backend}: {input_scale.dtype = }")
print(f"[VERBOSE] {backend}: {mat2_scale.shape = }")
print(f"[VERBOSE] {backend}: {mat2_scale.dtype = }")
inputs[backend] = (input_mxfp8, mat2_mxfp8, input_scale, mat2_scale)

def run_backend(
backend: str,
inputs: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
) -> torch.Tensor:
assert backend in ["cutlass", "trtllm", "cute-dsl", "auto"], (
f"Unsupported backend: {backend}"
)
input_mxfp8, mat2_mxfp8, input_scale, mat2_scale = inputs
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
a_descale=input_scale,
b_descale=mat2_scale,
out_dtype=res_dtype,
backend=backend,
use_8x4_sf_layout=backend == "trtllm" and not args.use_128x4_sf_layout,
)

has_reference_output = False
if run_refcheck:
Expand All @@ -1391,10 +1427,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
for _ in range(warmup_iters):
run_backend(
cur_backend,
input_mxfp8,
mat2_mxfp8,
input_scale,
mat2_scale,
inputs[cur_backend],
)
elif cache_path:
with autotune(False, cache=cache_path):
Expand All @@ -1406,7 +1439,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(
cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
cur_backend, inputs[cur_backend]
).detach()
backend_times[cur_backend] = bench_gpu_time(
fn=run_backend,
Expand All @@ -1416,7 +1449,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
enable_cupti=args.use_cupti,
use_cuda_graph=is_cuda_graph_compatible,
cold_l2_cache=True,
input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
input_args=(cur_backend, inputs[cur_backend]),
)

# Minimum cosine similarity for swizzled layout
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/routines/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def testNvfp4Quantize(args):
Returns:
dict: List of dictionaries containing performance results
"""
from flashinfer.fp4_quantization import SfLayout
from flashinfer import SfLayout

if args.verbose >= 1:
print("[INFO] Running testNvfp4Quantize")
Expand Down
17 changes: 6 additions & 11 deletions csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
// linear layout. See QuantizationSFLayout enum for more details about the supported layouts.
// returns
void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF,
bool isSfSwizzledLayout, int64_t alignment, bool enable_pdl) {
int64_t sfSwizzleLayout, int64_t alignment, bool enable_pdl) {
CHECK_CUDA(input);
CHECK_CONTIGUOUS(input);

Expand All @@ -50,8 +50,7 @@ void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF

const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

auto const layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::QuantizationSFLayout::LINEAR;
auto const layout = static_cast<tensorrt_llm::QuantizationSFLayout>(sfSwizzleLayout);

#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \
tensorrt_llm::kernels::invokeMxFP8Quantization( \
Expand Down Expand Up @@ -94,7 +93,7 @@ inline uint8_t float_to_ue8m0(float value) {

// Used in tests to quantize mxe4m3 tensors on host.
void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor,
bool is_sf_swizzled_layout) {
int64_t sfSwizzleLayout) {
int32_t const sf_vec_size = 32;
auto fp32_dtype = DLDataType{kDLFloat, 32, 1};
CHECK_INPUT_TYPE(x_fp32, fp32_dtype);
Expand All @@ -104,9 +103,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc
int hidden_dim = data_shape[1];
int groups_per_hidden_dim = hidden_dim / sf_vec_size;

tensorrt_llm::QuantizationSFLayout layout =
is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::QuantizationSFLayout::LINEAR;
auto const layout = static_cast<tensorrt_llm::QuantizationSFLayout>(sfSwizzleLayout);

for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
for (int group = 0; group < groups_per_hidden_dim; ++group) {
Expand Down Expand Up @@ -141,7 +138,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc

// Used in tests to dequantize mxe4m3 tensors on host.
void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf,
TensorView float_tensor, bool is_sf_swizzled_layout) {
TensorView float_tensor, int64_t sfSwizzleLayout) {
int32_t const sf_vec_size = 32;
CHECK_INPUT_TYPE(value_e4m3, dl_uint8);
CHECK_INPUT_TYPE(scale_ue8m08sf, dl_uint8);
Expand All @@ -153,9 +150,7 @@ void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf,
int hidden_dim = data_shape[1];
int groups_per_hidden_dim = hidden_dim / sf_vec_size;

tensorrt_llm::QuantizationSFLayout layout =
is_sf_swizzled_layout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4
: tensorrt_llm::QuantizationSFLayout::LINEAR;
auto const layout = static_cast<tensorrt_llm::QuantizationSFLayout>(sfSwizzleLayout);
for (size_t ti = 0; ti < static_cast<size_t>(data_shape[0]); ++ti) {
for (int group = 0; group < groups_per_hidden_dim; ++group) {
float* float_ptr =
Expand Down
6 changes: 3 additions & 3 deletions csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn,
// alignment: sfVecSize
// returns fp8_quantized and block_scale_factors.
void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF,
bool is_sf_swizzled_layout, int64_t alignment, bool enable_pdl);
int64_t sfSwizzleLayout, int64_t alignment, bool enable_pdl);

// x_fp32: [M, K], fp32_quantized (on the host)
// sfSwizzleLayout: int64_t value of the tensorrt_llm::QuantizationSFLayout enum selecting how the
// scale factors are stored (swizzled or linear). See QuantizationSFLayout for the available layouts.
// returns fp8_quantized and block_scale_factors (on the host).
void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView scale_tensor,
bool is_sf_swizzled_layout = true);
int64_t sfSwizzleLayout = 2);

void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf,
TensorView float_tensor, bool is_sf_swizzled_layout = true);
TensorView float_tensor, int64_t sfSwizzleLayout = 2);
Loading
Loading