92 changes: 92 additions & 0 deletions benchmark/kernels/quantization/README.md
@@ -0,0 +1,92 @@
# W8A8 Block-wise Quantization Kernel Tuning

Auto-tune Triton FP8/INT8 block-wise quantization kernels for optimal performance.
**@ispobock** (Collaborator, Nov 30, 2025):

> When do we need to use Triton FP8 block-wise quantization kernel instead of DeepGEMM?

**Author reply:**

> Replying to here .

## When to Use Triton FP8 Block-wise Quantization Kernel vs DeepGEMM

**Use Triton FP8 Block-wise Quantization Kernel when:**
- Output dtype is NOT `bfloat16` (e.g., `float16`, `float32`)
- DeepGEMM is disabled (environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`)
- Running on GPUs with compute capability < SM90 (DeepGEMM requires SM90+)
- You need cross-platform compatibility (Triton works on both NVIDIA and AMD GPUs)

**Use DeepGEMM when:**
- Output dtype is `bfloat16` AND DeepGEMM is enabled
- Running on NVIDIA GPUs with compute capability >= SM90 (e.g., H100, H200)
- You need maximum performance for production workloads (DeepGEMM is highly optimized for the Hopper architecture)

**Note:** DeepGEMM requires CUDA compute capability >= 9.0 (SM90+). It is specifically optimized for NVIDIA Hopper GPUs (H100/H200).

The kernel selection logic in SGLang automatically chooses DeepGEMM when these conditions are met (see the `w8a8_block_fp8_matmul` function in `fp8_kernel.py`) and otherwise falls back to the Triton implementation.
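As a rough illustration of this dispatch rule, the conditions above can be expressed as a single predicate. The function name and signature below are hypothetical, not SGLang's actual API; the authoritative check lives in `w8a8_block_fp8_matmul` in `fp8_kernel.py`.

```python
import os


def select_w8a8_backend(output_dtype: str, sm_version: int) -> str:
    """Illustrative sketch of the backend choice described above.

    DeepGEMM is used only when it is enabled, the output dtype is
    bfloat16, and the GPU is SM90+; otherwise fall back to Triton.
    """
    # SGLANG_ENABLE_JIT_DEEPGEMM=0 disables DeepGEMM explicitly.
    deepgemm_enabled = os.environ.get("SGLANG_ENABLE_JIT_DEEPGEMM", "1") != "0"
    if deepgemm_enabled and output_dtype == "bfloat16" and sm_version >= 90:
        return "deepgemm"
    return "triton"
```

This mirrors why, for example, a `float16` output or an A100 (SM80) always routes to the Triton kernel.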

## Quick Start

**Default (DeepSeek-V3):**
```bash
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --tp-size 8
```

**Custom Model (specify N and K):**
```bash
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600
```

## Parameters

- `--N`, `--K`: Weight matrix dimensions (N=output_dim, K=input_dim). If not specified, the predefined DeepSeek-V3 shapes for `--tp-size` are used
- `--tp-size`: Tensor parallelism size for DeepSeek-V3 (default: 8)
- `--input-type`: `fp8` or `int8` (default: fp8)
- `--block-n`, `--block-k`: Block quantization granularity (default: 128)
- `--batch-size`: Tune only a single batch size instead of the full sweep (optional)

## How to Calculate N and K

For a linear layer `y = xW^T` where `x` is (M, K) and `W` is (N, K):
- **N**: Output features (weight matrix output dimension)
- **K**: Input features (weight matrix input dimension)

**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) and TP=1
```bash
# QKV projection: Q(8192) + K(1024) + V(1024) = 10240
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 10240 --K 5120

# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 51200 --K 5120

# MLP down projection
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600

# O projection (if separate from QKV)
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 8192
```

If TP=8:

```bash
# QKV projection: (Q(8192) + K(1024) + V(1024)) / TP = 10240 / 8 = 1280
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 1280 --K 5120

# MLP gate+up (SwiGLU): 2 * intermediate_size / TP = 51200 / 8 = 6400
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 6400 --K 5120

# MLP down projection: K = intermediate_size / TP = 25600 / 8 = 3200
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 3200

# O projection (if separate from QKV): K = 8192 / 8 = 1024
python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 1024
```
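The arithmetic in the two blocks above can be captured in a small helper. This is illustrative only and not part of the tuning script; it assumes Q/KV heads, the gate+up projection, and the down projection's input dimension all divide evenly by the TP size.

```python
def weight_shapes(hidden, inter, heads, kv_heads, head_dim, tp=1):
    """Compute the (N, K) pairs used in the tuning commands above."""
    # Q heads and KV heads are both sharded across TP ranks.
    qkv_n = (heads * head_dim + 2 * kv_heads * head_dim) // tp
    # SwiGLU fuses gate and up projections: 2 * intermediate_size.
    gate_up_n = 2 * inter // tp
    # O projection input is the full attention output, sharded by TP.
    o_k = heads * head_dim // tp
    return {
        "qkv": (qkv_n, hidden),
        "gate_up": (gate_up_n, hidden),
        "down": (hidden, inter // tp),
        "o": (hidden, o_k),
    }
```

For the Qwen3-VL-32B numbers above, `weight_shapes(5120, 25600, 64, 8, 128, tp=8)` reproduces the four TP=8 command lines.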

## Output

Generates JSON config files saved to `python/sglang/srt/layers/quantization/configs/`:
```
N={N},K={K},device_name={DEVICE},dtype=fp8_w8a8,block_shape=[128,128].json
```

Config maps batch size to optimal kernel parameters:
```json
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, ...},
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, ...}
}
```
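A sketch of how such a config might be consumed at runtime: load the JSON, then pick the entry for the batch size being run. The nearest-key rule below is an assumption for illustration; the authoritative lookup lives in SGLang's kernel code.

```python
import json


def load_block_config(path, m):
    """Return the tuned kernel parameters for batch size m.

    Assumes the file maps stringified batch sizes to parameter dicts,
    as shown above; picks the tuned batch size closest to m.
    """
    with open(path) as f:
        cfg = {int(k): v for k, v in json.load(f).items()}
    best = min(cfg, key=lambda k: abs(k - m))
    return cfg[best]
```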
35 changes: 33 additions & 2 deletions benchmark/kernels/quantization/tuning_block_wise_kernel.py
@@ -84,6 +84,8 @@ def w8a8_block_matmul(
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)

needs_masking = bool(K % config["BLOCK_SIZE_K"] != 0)

def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -127,6 +129,7 @@ def grid(META):
Bs.stride(1),
Bs.stride(0),
**config,
needs_masking=needs_masking,
)

return C
@@ -428,7 +431,13 @@ def main(args):
batch_sizes = [args.batch_size]
num_gpus = 1 # If only one batch size, use only one GPU

weight_shapes = get_weight_shapes(args.tp_size)
# Support manual N and K specification
if args.N is not None and args.K is not None:
weight_shapes = [(args.N, args.K)]
print(f"Using manually specified weight shape: N={args.N}, K={args.K}")
else:
weight_shapes = get_weight_shapes(args.tp_size)
print(f"Using predefined weight shapes for TP size {args.tp_size}")

batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)

@@ -453,7 +462,25 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument(
"--tp-size",
"-tp",
type=int,
default=8,
help="Tensor parallelism size (ignored if --N and --K are specified)",
)
parser.add_argument(
"--N",
type=int,
default=None,
help="Output dimension N of the (N, K) weight matrix",
)
parser.add_argument(
"--K",
type=int,
default=None,
help="Input dimension K of the (N, K) weight matrix",
)
parser.add_argument(
"--input-type", type=str, choices=["fp8", "int8"], default="fp8"
)
@@ -471,4 +498,8 @@ def main(args):
)
args = parser.parse_args()

# Validate arguments
if (args.N is None) != (args.K is None):
parser.error("--N and --K must be specified together or not at all")

main(args)
@@ -0,0 +1,26 @@
{
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
@@ -0,0 +1,26 @@
{
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
@@ -0,0 +1,26 @@
{
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
}
}
@@ -0,0 +1,26 @@
{
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}
16 changes: 16 additions & 0 deletions python/sglang/srt/layers/quantization/configs/README.md
@@ -0,0 +1,16 @@
# W8A8 Block FP8 Kernel Configurations

This directory contains optimized kernel configurations for the W8A8 block FP8 matrix multiplication kernel.

## Configuration File Format

Configuration files are named using the following pattern:
```
N={N},K={K},device_name={DEVICE_NAME},dtype=fp8_w8a8,block_shape=[{BLOCK_N},{BLOCK_K}].json
```

Where:
- `N`: Output dimension (number of rows of the (N, K) weight matrix)
- `K`: Input dimension (number of columns of both the weight matrix and the activation matrix)
- `DEVICE_NAME`: GPU device name with spaces replaced by underscores (e.g., `NVIDIA_H100_80GB_HBM3`)
- `BLOCK_N`, `BLOCK_K`: Block quantization granularity (typically `[128,128]`)
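The naming pattern above can be built with a short helper. This is a sketch, not SGLang's code; the raw device name would come from something like `torch.cuda.get_device_name()` (not used here), with spaces replaced by underscores per the rule stated above.

```python
def config_filename(n, k, device_name, block_n=128, block_k=128):
    """Build a config filename following the pattern in this README."""
    # Normalize the device name: spaces become underscores.
    device = device_name.replace(" ", "_")
    return (
        f"N={n},K={k},device_name={device},"
        f"dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json"
    )
```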