1 change: 1 addition & 0 deletions README.md
@@ -36,6 +36,7 @@ High-Performance GPU Kernels for Inference
- **POD-Attention**: Fused prefill+decode for mixed batching

### GEMM & Linear Operations
+- **BF16 GEMM**: BF16 matrix multiplication for SM10.0+ GPUs
- **FP8 GEMM**: Per-tensor and groupwise scaling
- **FP4 GEMM**: NVFP4 and MXFP4 matrix multiplication for Blackwell GPUs
- **Grouped GEMM**: Efficient batched matrix operations for LoRA and multi-expert routing
8 changes: 7 additions & 1 deletion benchmarks/README.md
@@ -24,6 +24,8 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati
- `group_gemm_fp8_nt_groupwise` - Group GEMM with FP8 data types using groupwise scaling.
- `bmm_fp8` - Batched matrix multiplication with FP8 inputs.
- `mm_fp4` - Matrix multiplication with NVFP4 inputs.
+- `mm_bf16` - Matrix multiplication with BF16 inputs (Blackwell SM10.0+).
+- `bmm_bf16` - Batched matrix multiplication with BF16 inputs (Blackwell SM10.0+).
- MOE:
- `trtllm_fp4_block_scale_moe` - MOE with FP4 quantized weights and block-wise scaling.
- `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling.
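As a reference for what the BF16 routines above compute, here is a hedged NumPy sketch (not FlashInfer code, and the helper names are made up for illustration): a BF16 GEMM typically rounds float32 inputs to bfloat16 precision (the top 16 bits of a float32: 1 sign, 8 exponent, 7 mantissa bits) and accumulates products in float32.

```python
import numpy as np

def quantize_bf16(x: np.ndarray) -> np.ndarray:
    """Round float32 values to bfloat16 precision (round-to-nearest-even)
    and widen them back to float32 for further computation."""
    u = x.astype(np.float32).view(np.uint32)
    # Keep the top 16 bits; add 0x7FFF plus the parity of bit 16 so that
    # exact ties round to an even bf16 mantissa (round-to-nearest-even).
    rounded = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000
    return rounded.view(np.float32)

def mm_bf16_reference(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference semantics for a BF16 GEMM: bf16 inputs, fp32 accumulation."""
    return quantize_bf16(a) @ quantize_bf16(b)
```

Small integers are exactly representable in bf16, so for such inputs the reference matches a plain float32 matmul; the rounding only becomes visible once values need more than 8 significant bits.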
@@ -219,7 +221,8 @@ The output CSV will contain detailed metrics including:
| `--mat2_dtype` | Data type for second matrix (for FP8 GEMM, e.g. `fp8_e4m3`) |
| `--use_128x4_sf_layout` | Use 128x4 scale/format layout for FP4 GEMM (for `mm_fp4` routine) |
| `--use_nvfp4` | Whether to use nvfp4 or mxfp4 quantization; defaults to False (for `mm_fp4` routine) |
-| `--autotune` | Enable autotune for supported operation (`trtllm` and `cutlass` backends for `mm_fp4` and `bmm_fp8` routines)|
+| `--autotune` | Enable autotune for supported operations (`mm_fp4`, `bmm_fp8`, `mm_bf16`, and `bmm_bf16` routines) |
+| `--bias` | Use bias for `mm_bf16` (enabled for the TGV backend) |
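A minimal sketch of how the GEMM flags above might be wired into an argument parser; the flag names come from this table, while the function name, help strings, and defaults are assumptions for illustration, not the benchmark's actual implementation:

```python
import argparse

def build_gemm_parser() -> argparse.ArgumentParser:
    """Sketch of a parser for the GEMM benchmark flags documented above."""
    p = argparse.ArgumentParser(description="GEMM benchmark flags (sketch)")
    p.add_argument("--mat2_dtype", default=None,
                   help="Data type for second matrix (for FP8 GEMM, e.g. fp8_e4m3)")
    p.add_argument("--use_128x4_sf_layout", action="store_true",
                   help="Use 128x4 scale-factor layout for FP4 GEMM (mm_fp4)")
    p.add_argument("--use_nvfp4", action="store_true",
                   help="Use nvfp4 instead of mxfp4 quantization (mm_fp4)")
    p.add_argument("--autotune", action="store_true",
                   help="Enable autotune for supported routines")
    p.add_argument("--bias", action="store_true",
                   help="Use bias for mm_bf16 (TGV backend)")
    return p
```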

### MOE Flags
| Flag | Description |
@@ -406,6 +409,8 @@ Legend:
| **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | |
| **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas |
| **mm_fp4** | | | | | | cudnn, trtllm, cutlass | cudnn, trtllm, cutlass | cudnn |
+| **mm_bf16** | | | | | | cudnn, cutlass, tgv | cudnn, cutlass, tgv | |
+| **bmm_bf16** | | | | | | cudnn, cutlass | cudnn, cutlass | |
| **trtllm_fp4_block_scale_moe** | | | | | | trtllm | trtllm | |
| **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | |
| **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | |
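A support matrix like the one above can be modeled as a routine-to-capability-to-backends mapping. The sketch below is hypothetical (the dict name, helper, and capability keys are assumptions), showing only the two new BF16 rows:

```python
# Hypothetical encoding of the support matrix: routine -> compute
# capability -> list of usable backends. Only the BF16 rows are shown.
SUPPORT_MATRIX = {
    "mm_bf16": {
        "10.0": ["cudnn", "cutlass", "tgv"],
        "10.3": ["cudnn", "cutlass", "tgv"],
    },
    "bmm_bf16": {
        "10.0": ["cudnn", "cutlass"],
        "10.3": ["cudnn", "cutlass"],
    },
}

def supported_backends(routine: str, compute_capability: str) -> list:
    """Return the backends usable for `routine` on a GPU of the given
    compute capability; empty list means unsupported."""
    return SUPPORT_MATRIX.get(routine, {}).get(compute_capability, [])
```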
@@ -452,6 +457,7 @@ Backend Legend:
- cudnn: cuDNN (via wrapper API)
- cudnn-native: cuDNN (direct API call)
- cutlass: CUTLASS
+- tgv: TGV
- trtllm: TensorRT-LLM
- trtllm-gen: TensorRT-LLM
- trtllm-native: TensorRT-LLM (out-of-wrapper)
5 changes: 4 additions & 1 deletion benchmarks/routines/flashinfer_benchmark_utils.py
@@ -35,6 +35,7 @@
"mma_sm",
"use_128x4_sf_layout",
"use_nvfp4",
+"bias",
],
"moe": [
"num_tokens",
@@ -153,6 +154,8 @@
"bmm_mxfp8",
"mm_fp4",
"mm_mxfp8",
+"mm_bf16",
+"bmm_bf16",
],
"moe": [
"trtllm_fp4_block_scale_moe",
@@ -353,7 +356,7 @@ def dtype_str_to_torch_dtype(dtype_str):
"11.0": ["cutlass"],
"12.0": [],
},
-# Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
+# Note: mm_fp4, mm_bf16, and bmm_bf16 use support checkers to filter backends, so they are not listed here
# MOE
"trtllm_fp4_block_scale_moe": {
"7.5": [],