From 487151de954b26f8f2d3263b88d2a182f2a1b0cd Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Mon, 26 Jan 2026 05:01:36 -0800 Subject: [PATCH 1/8] Add support for Mistral Large 3 inference with Flashinfer MoE Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- requirements/cuda.txt | 2 +- ...evice_name=NVIDIA_B200,dtype=fp8_w8a8.json | 147 ++++++++++++++++++ .../E=128,N=512,device_name=NVIDIA_B200.json | 147 ++++++++++++++++++ ...vice_name=NVIDIA_GB200,dtype=fp8_w8a8.json | 147 ++++++++++++++++++ .../E=128,N=512,device_name=NVIDIA_H200.json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ .../layers/fused_moe/flashinfer_trtllm_moe.py | 7 +- .../quantization/utils/flashinfer_utils.py | 18 +-- vllm/model_executor/models/deepseek_v2.py | 14 +- vllm/utils/flashinfer.py | 2 +- 10 files changed, 762 insertions(+), 16 deletions(-) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/requirements/cuda.txt b/requirements/cuda.txt index b8f69713c267..8e9a9063877f 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -10,4 +10,4 @@ torchaudio==2.9.1 # These must be updated alongside torch torchvision==0.24.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile -flashinfer-python==0.6.1 +flashinfer-python==0.6.2 diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..959589a361d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..58201bd9b027 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_B200.json @@ -0,0 +1,147 @@ +{ + 
"triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..959589a361d7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + 
"16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..040d8fb94a20 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.4.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..a15963d2fe72 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.1", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 647108cc44fd..7829b068716f 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -71,8 +71,11 @@ def _supports_routing_method( ] elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): # NOTE(rob): kernel requires Llama4. - return routing_method == RoutingMethodType.Llama4 - + return routing_method in [ + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] else: raise ValueError("Unsupported quantization scheme.") diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index cd82b54322ff..e61d79e97899 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -113,15 +113,15 @@ def apply_fi_trtllm_fp8_per_tensor_moe( and hasattr(layer, "output2_scales_scalar") ) - # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe - assert ( - hasattr(layer, "output1_scales_scalar") - and hasattr(layer, "output1_scales_gate_scalar") - and hasattr(layer, "output2_scales_scalar") - ) + if layer.routing_method_type == RoutingMethodType.Llama4: + assert ( + not layer.renormalize + and layer.custom_routing_function == Llama4MoE.custom_routing_function + ), ( + "FusedMoE flashinfer kernels with Llama4 routing method are only " + "supported for Llama4" + ) - is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function - assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4" return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( routing_logits=router_logits, routing_bias=routing_bias, @@ -140,7 +140,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe( local_expert_offset=layer.ep_rank * layer.local_num_experts, local_num_experts=layer.local_num_experts, use_routing_scales_on_input=apply_router_weight_on_input, - routing_method_type=RoutingMethodType.Llama4, + routing_method_type=layer.routing_method_type, ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f8907ed86efa..4e465f0fe166 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -295,6 +295,14 @@ def __init__( prefix=f"{prefix}.shared_experts", ) + n_group = getattr(config, "n_group", 1) + topk_group = getattr(config, "topk_group", 1) + use_grouped_topk = True + if (n_group, topk_group) == (1, 1): + n_group = 
None + topk_group = None + use_grouped_topk = False + self.experts = SharedFusedMoE( shared_experts=self.shared_experts, gate=self.gate, @@ -305,9 +313,9 @@ def __init__( reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=getattr(config, "n_group", 1), - topk_group=getattr(config, "topk_group", 1), + use_grouped_topk=use_grouped_topk, + num_expert_group=n_group, + topk_group=topk_group, prefix=f"{prefix}.experts", scoring_func=getattr(config, "scoring_func", "softmax"), # we do scaling outside, set factor to 1.0 to avoid double mul diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 067c6fb3e785..155530258ee9 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -409,7 +409,7 @@ def flashinfer_mm_fp4( use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: - from flashinfer import mm_fp4 as flashinfer_mm_fp4_ + from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ return flashinfer_mm_fp4_( A, From 389f2ea6c564c6d5d8d903bfea5b8041de197eef Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 00:52:19 -0800 Subject: [PATCH 2/8] Bump flashinfer version in docker Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- docker/Dockerfile | 2 +- docker/Dockerfile.nightly_torch | 4 ++-- docker/versions.json | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 5f9649144a0f..804dfd85892e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -586,7 +586,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # This is ~1.1GB and only changes when FlashInfer version bumps # https://docs.flashinfer.ai/installation.html # From versions.json: .flashinfer.version -ARG FLASHINFER_VERSION=0.6.1 +ARG FLASHINFER_VERSION=0.6.2 RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system flashinfer-cubin==${FLASHINFER_VERSION} \ && uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \ diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 7731c0477f5f..87a9144952e2 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -221,13 +221,13 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. # build flashinfer for torch nightly from source around 10 mins -# release version: v0.6.1 +# release version: v0.6.2 # todo(elainewy): cache flashinfer build result for faster build ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/uv \ echo "git clone flashinfer..." \ - && git clone --depth 1 --branch v0.6.1 --recursive https://github.com/flashinfer-ai/flashinfer.git \ + && git clone --depth 1 --branch v0.6.2 --recursive https://github.com/flashinfer-ai/flashinfer.git \ && cd flashinfer \ && git submodule update --init --recursive \ && echo "finish git clone flashinfer..." 
\ diff --git a/docker/versions.json b/docker/versions.json index 630ab8607fca..24cf7f06cc41 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -68,7 +68,7 @@ "default": "true" }, "FLASHINFER_VERSION": { - "default": "0.6.1" + "default": "0.6.2" }, "GDRCOPY_CUDA_VERSION": { "default": "12.8" From 186aecc5cfb5daa8765b154cc309b4283432a96a Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 05:17:24 -0800 Subject: [PATCH 3/8] Update comment Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 7829b068716f..d010556ed018 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -70,7 +70,7 @@ def _supports_routing_method( RoutingMethodType.RenormalizeNaive, ] elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): - # NOTE(rob): kernel requires Llama4. + # NOTE(dbari): as above, potentially allow others here. return routing_method in [ RoutingMethodType.Llama4, RoutingMethodType.Renormalize, From 4e5484d25703d83487d0f545ed66f1e16e7074ed Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 05:19:31 -0800 Subject: [PATCH 4/8] Extend assert Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- .../layers/quantization/utils/flashinfer_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index e61d79e97899..8932e6e06fbb 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -121,6 +121,10 @@ def apply_fi_trtllm_fp8_per_tensor_moe( "FusedMoE flashinfer kernels with Llama4 routing method are only " "supported for Llama4" ) + else: + assert layer.custom_routing_function is None, ( + "Custom routing function is only supported for Llama4" + ) return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( routing_logits=router_logits, From 026b4a67121fc4cde9c78ab5a06b70cf58e0d49c Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 05:21:13 -0800 Subject: [PATCH 5/8] Revert unneeded change Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- vllm/utils/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 155530258ee9..067c6fb3e785 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -409,7 +409,7 @@ def flashinfer_mm_fp4( use_8x4_sf_layout: bool, backend: str, ) -> torch.Tensor: - from flashinfer.gemm import mm_fp4 as flashinfer_mm_fp4_ + from flashinfer import mm_fp4 as flashinfer_mm_fp4_ return flashinfer_mm_fp4_( A, From 463b08d6216ec1151d727d3d9f3e97a527cb95b3 Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 05:23:53 -0800 Subject: [PATCH 6/8] Add Triton configs for H200 TP8 with/without EP Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> 
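For reviewers unfamiliar with the tuned-config format: each JSON file added in this series is keyed by batch size (M), and each entry gives the Triton kernel parameters (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages) to use at that M, while the filename encodes the expert count E, the per-rank intermediate size N, the device name, and the quantization dtype/block shape. The Python sketch below shows a simplified lookup; it is illustrative only (the nearest-M selection and the helper name are assumptions, not vLLM's actual fused_moe config lookup):

import json

def load_tuned_moe_config(path: str, m: int) -> dict:
    # Pick the tuned kernel parameters for the listed batch size closest to m.
    # Simplified stand-in for vLLM's real config selection logic.
    with open(path) as f:
        configs = json.load(f)
    batch_sizes = [int(k) for k in configs if k.isdigit()]  # skip "triton_version"
    best = min(batch_sizes, key=lambda b: abs(b - m))
    return configs[str(best)]

# Example usage with one of the files added by this patch.
cfg = load_tuned_moe_config(
    "vllm/model_executor/layers/fused_moe/configs/"
    "E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json",
    m=48,
)
print(cfg["BLOCK_SIZE_M"], cfg["num_warps"], cfg["num_stages"])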
--- ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 147 ++++++++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..3357dc223f76 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..9f2a4baa2d1d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} From eb1dd0f60d6c94d537fd123b2835b56227e99932 Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Thu, 29 Jan 2026 07:23:20 -0800 Subject: [PATCH 7/8] Extend benchmark_moe.py for Mistral Large 3 Additionally: - Fixed a serialization bug when calling Ray - Fixed the selection of LLM architecture for some other models Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- benchmarks/kernels/benchmark_moe.py | 49 ++++++++++++++++++++++------- 1 file 
changed, 38 insertions(+), 11 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index db1468f82073..773926bfffe1 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) -from vllm.platforms import current_platform from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton from vllm.utils.argparse_utils import FlexibleArgumentParser @@ -482,6 +481,8 @@ def benchmark( block_quant_shape: list[int] = None, use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: + # local import to allow serialization by ray + set_random_seed(self.seed) dtype_str = _get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 @@ -535,6 +536,9 @@ def tune( block_quant_shape: list[int], use_deep_gemm: bool, ) -> dict[str, int]: + # local import to allow serialization by ray + from vllm.platforms import current_platform + best_config = None best_time = float("inf") if current_platform.is_rocm(): @@ -646,20 +650,28 @@ def save_configs( f.write("\n") +def get_compressed_tensors_block_structure(config, default_value=None): + config_groups = config.get("config_groups", {}) + if len(config_groups) != 1: + return default_value + group = next(iter(config_groups.values())) + weights = group.get("weights", {}) + block_structure = weights.get("block_structure", default_value) + return block_structure + + def get_weight_block_size_safety(config, default_value=None): quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get("weight_block_size", default_value) + if "weight_block_size" in quantization_config: + return quantization_config["weight_block_size"] + return get_compressed_tensors_block_structure( + quantization_config, default_value + ) return default_value -def main(args: argparse.Namespace): - print(args) - - config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) - if args.model_prefix: - config = getattr(config, args.model_prefix) - +def get_model_params(config): if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -677,6 +689,7 @@ def main(args: argparse.Namespace): "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", "NemotronHForCausalLM", + "MistralLarge3ForCausalLM", ): E = config.n_routed_experts topk = config.num_experts_per_tok @@ -697,16 +710,20 @@ def main(args: argparse.Namespace): topk = text_config.num_experts_per_tok intermediate_size = text_config.moe_intermediate_size hidden_size = text_config.hidden_size - elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): + elif config.architectures[0] == "HunYuanMoEV1ForCausalLM": E = config.num_experts topk = config.moe_topk[0] intermediate_size = config.moe_intermediate_size[0] hidden_size = config.hidden_size - elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]: + elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration": E = config.thinker_config.text_config.num_experts topk = config.thinker_config.text_config.num_experts_per_tok intermediate_size = config.thinker_config.text_config.moe_intermediate_size hidden_size = config.thinker_config.text_config.hidden_size + elif config.architectures[0] == "PixtralForConditionalGeneration": + # Pixtral can contain different LLM architectures, 
+ # recurse to get their parameters + return get_model_params(config.get_text_config()) else: # Support for llama4 config = config.get_text_config() @@ -715,6 +732,16 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size hidden_size = config.hidden_size + return E, topk, intermediate_size, hidden_size + + +def main(args: argparse.Namespace): + print(args) + + config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) + if args.model_prefix: + config = getattr(config, args.model_prefix) + E, topk, intermediate_size, hidden_size = get_model_params(config) enable_ep = bool(args.enable_expert_parallel) if enable_ep: ensure_divisibility(E, args.tp_size, "Number of experts") From a1359593c6c0408ab5d27f9579e52fe8d564e751 Mon Sep 17 00:00:00 2001 From: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Date: Fri, 30 Jan 2026 00:54:01 -0800 Subject: [PATCH 8/8] Fix test_flashinfer_per_tensor_moe_fp8_no_graph Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> --- tests/kernels/moe/test_flashinfer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index a484a3f210e3..1c512b5b1688 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -135,6 +135,8 @@ def make_moe_tensors_8bit( layer.w2_input_scale, ) layer.custom_routing_function = Llama4MoE.custom_routing_function + layer.routing_method_type = RoutingMethodType.Llama4 + layer.renormalize = False layer.intermediate_size_per_partition = n layer.ep_rank = 0 layer.local_num_experts = e
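Supplementary note on the quantization-config handling added to benchmark_moe.py in patch 7: get_weight_block_size_safety now falls back to reading the block structure out of a compressed-tensors style quantization_config when no explicit weight_block_size key is present. A minimal, self-contained sketch of that fallback follows; the sample config dict is hypothetical and only meant to show the shape the helper expects:

def get_compressed_tensors_block_structure(config: dict, default_value=None):
    # Mirrors the helper added in patch 7: only a single config group is handled.
    config_groups = config.get("config_groups", {})
    if len(config_groups) != 1:
        return default_value
    group = next(iter(config_groups.values()))
    return group.get("weights", {}).get("block_structure", default_value)

# Hypothetical compressed-tensors quantization_config with one group of
# block-quantized weights; the [128, 128] value is an assumption for illustration.
sample_quantization_config = {
    "quant_method": "compressed-tensors",
    "config_groups": {
        "group_0": {"weights": {"num_bits": 8, "block_structure": [128, 128]}},
    },
}

print(get_compressed_tensors_block_structure(sample_quantization_config))  # [128, 128]
print(get_compressed_tensors_block_structure({"config_groups": {}}))       # None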