diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 9a26dc611515..1f4e683596e2 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -1422,3 +1422,10 @@ steps:
num_gpus: 2
commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor/config-b200.txt
+
+- label: MoE Refactor Integration Test (B200 DP - TEMPORARY) # optional
+ gpu: b200
+ optional: true
+ num_gpus: 2
+ commands:
+ - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
index 626b3b160044..9c6edee7b264 100644
--- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
+++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py
@@ -6,13 +6,16 @@
but use different quantization strategies and backends.
"""
-import nvtx
import torch
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
@@ -59,6 +62,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
+ init_workspace_manager(torch.cuda.current_device())
(m, k, n) = mkn
dtype = torch.half
@@ -121,85 +125,6 @@ def bench_run(
# Force per-tensor quantization for all cases
per_act_token = False
- # Create stride tensors for CUTLASS
- ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
- ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
- c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
- c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
-
- def run_triton_moe(
- a: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- w1_scale: torch.Tensor,
- w2_scale: torch.Tensor,
- a1_scale: torch.Tensor,
- a2_scale: torch.Tensor,
- num_repeats: int,
- ):
- quant_config = fp8_w8a8_moe_quant_config(
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=a2_scale,
- per_act_token_quant=per_act_token,
- per_out_ch_quant=per_out_ch,
- )
-
- for _ in range(num_repeats):
- fused_experts(
- a,
- w1,
- w2,
- topk_weights,
- topk_ids,
- quant_config=quant_config,
- )
-
- def run_cutlass_moe_fp8(
- a: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
- w1_scale: torch.Tensor,
- w2_scale: torch.Tensor,
- a1_scale: torch.Tensor,
- a2_scale: torch.Tensor,
- num_repeats: int,
- ):
- quant_config = fp8_w8a8_moe_quant_config(
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=a2_scale,
- per_act_token_quant=per_act_token,
- per_out_ch_quant=per_out_ch,
- )
-
- for _ in range(num_repeats):
- with nvtx.annotate("cutlass_moe_fp8", color="blue"):
- cutlass_moe_fp8(
- a=a,
- w1_q=w1,
- w2_q=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- ab_strides1=ab_strides1,
- ab_strides2=ab_strides2,
- c_strides1=c_strides1,
- c_strides2=c_strides2,
- quant_config=quant_config,
- activation="silu",
- global_num_experts=num_experts,
- )
-
# Pre-create quantization config to avoid creating it inside CUDA graph
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
@@ -210,23 +135,30 @@ def run_cutlass_moe_fp8(
per_out_ch_quant=per_out_ch,
)
+ fn = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ CutlassExpertsFp8(
+ out_dtype=a.dtype,
+ e=num_experts,
+ n=n,
+ k=k,
+ quant_config=quant_config,
+ device=w1.device,
+ ),
+ )
+
# Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
cutlass_stream = torch.cuda.Stream()
cutlass_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
# Capture 10 invocations like benchmark_moe.py
for _ in range(10):
- cutlass_moe_fp8(
- a=a,
- w1_q=w1_fp8q_cutlass,
- w2_q=w2_fp8q_cutlass,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- ab_strides1=ab_strides1,
- ab_strides2=ab_strides2,
- c_strides1=c_strides1,
- c_strides2=c_strides2,
- quant_config=quant_config,
+ fn(
+ a,
+ w1_fp8q_cutlass,
+ w2_fp8q_cutlass,
+ topk_weights,
+ topk_ids,
activation="silu",
global_num_experts=num_experts,
)
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index 4390be8770c1..b30a1263878b 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -5,14 +5,18 @@
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
fused_topk,
)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
@@ -45,6 +49,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
+ init_workspace_manager(torch.cuda.current_device())
label = "Quant Matmul"
sub_label = (
@@ -82,11 +87,6 @@ def bench_run(
a, score, topk, renormalize=False
)
- ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
- ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
- c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
- c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
-
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -120,10 +120,6 @@ def run_cutlass_moe(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -135,31 +131,29 @@ def run_cutlass_moe(
per_act_token_quant=per_act_token,
)
- for _ in range(num_repeats):
- cutlass_moe_fp8(
- a,
- w1,
- w2,
- topk_weights,
- topk_ids,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
+ fn = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ CutlassExpertsFp8(
+ out_dtype=a.dtype,
+ # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
+ e=w2.shape[0],
+ n=w2.shape[2],
+ k=w2.shape[1],
quant_config=quant_config,
- )
+ device=w1.device,
+ ),
+ )
+
+ for _ in range(num_repeats):
+ fn(a, w1, w2, topk_weights, topk_ids)
def run_cutlass_from_graph(
a: torch.Tensor,
a_scale: torch.Tensor,
- w1_q: torch.Tensor,
- w2_q: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -169,21 +163,23 @@ def run_cutlass_from_graph(
per_act_token_quant=per_act_token,
)
+ fn = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ CutlassExpertsFp8(
+ out_dtype=a.dtype,
+ # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
+ e=w2.shape[0],
+ n=w2.shape[2],
+ k=w2.shape[1],
+ quant_config=quant_config,
+ device=w1.device,
+ ),
+ )
+
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
- return cutlass_moe_fp8(
- a,
- w1_q,
- w2_q,
- topk_weights,
- topk_ids,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
- quant_config=quant_config,
- )
+ return fn(a, w1, w2, topk_weights, topk_ids)
def run_triton_from_graph(
a: torch.Tensor,
@@ -227,10 +223,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
topk_weights,
topk_ids,
)
@@ -268,10 +260,6 @@ def replay_graph(graph, num_repeats):
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
- "ab_strides1": ab_strides1,
- "ab_strides2": ab_strides2,
- "c_strides1": c_strides1,
- "c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -330,10 +318,6 @@ def replay_graph(graph, num_repeats):
w2_q,
w1_scale,
w2_scale,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -342,7 +326,7 @@ def replay_graph(graph, num_repeats):
results.append(
benchmark.Timer(
- stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 26a281f4e4fb..b31cfd61161f 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -48,8 +48,6 @@ def clear_triton_cache():
# Try to clear Triton's runtime cache
try:
- import triton
-
if (
hasattr(triton, "runtime")
and hasattr(triton.runtime, "cache")
diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 11c6e488f958..d683b538c415 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -87,7 +87,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] |
| deep gemm | standard,batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] |
| cutlass_fp4 | standard,batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] |
-| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
+| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
new file mode 100644
index 000000000000..9d62c542a085
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml
@@ -0,0 +1,5 @@
+model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8"
+accuracy_threshold: 0.92
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml
new file mode 100644
index 000000000000..276d63f4ee10
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml
@@ -0,0 +1,8 @@
+model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
+accuracy_threshold: 0.88
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
new file mode 100644
index 000000000000..54e6ab7b35f6
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
@@ -0,0 +1,9 @@
+model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
+accuracy_threshold: 0.88
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
+ VLLM_USE_DEEP_GEMM_E8M0: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
new file mode 100644
index 000000000000..eee58539c4a8
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
@@ -0,0 +1,8 @@
+model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8"
+accuracy_threshold: 0.88
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml
new file mode 100644
index 000000000000..2083df585f4d
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml
@@ -0,0 +1,8 @@
+model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
+accuracy_threshold: 0.85
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
new file mode 100644
index 000000000000..1d4cbfe96b68
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
@@ -0,0 +1,9 @@
+model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
+accuracy_threshold: 0.85
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
+ VLLM_USE_DEEP_GEMM_E8M0: "0"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
new file mode 100644
index 000000000000..246549d62961
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
@@ -0,0 +1,8 @@
+model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block"
+accuracy_threshold: 0.85
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+env:
+ VLLM_USE_DEEP_GEMM: "1"
+ VLLM_USE_DEEP_GEMM_MOE: "1"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
new file mode 100644
index 000000000000..53fd62bac839
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
@@ -0,0 +1,8 @@
+model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4"
+accuracy_threshold: 0.88
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel"
+env:
+ VLLM_USE_FLASHINFER_MOE_FP4: "1"
+ VLLM_FLASHINFER_MOE_BACKEND: "throughput"
diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
new file mode 100644
index 000000000000..c1b405fd1d00
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
@@ -0,0 +1,8 @@
+Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml
+Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml
+Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml
+Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml
+Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml
+Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
+Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml
+Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass-fi-a2av.yaml
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml
new file mode 100644
index 000000000000..bf8c93921f41
--- /dev/null
+++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml
@@ -0,0 +1,5 @@
+model_name: "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic"
+accuracy_threshold: 0.92
+num_questions: 1319
+num_fewshot: 5
+server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2"
diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
similarity index 100%
rename from tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml
rename to tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
diff --git a/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt b/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
index bf02f1363be3..9d86e432e84f 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
+++ b/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt
@@ -1,3 +1,4 @@
+Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml
Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml
diff --git a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt
index 9725db7c8be2..2c25ea2c2caa 100644
--- a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt
+++ b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt
@@ -5,7 +5,7 @@ Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml
Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml
Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml
-Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml
+Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-marlin.yaml
Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml
Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml
diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py
index 991b905211ff..6b2cb02e9401 100644
--- a/tests/evals/gsm8k/test_gsm8k_correctness.py
+++ b/tests/evals/gsm8k/test_gsm8k_correctness.py
@@ -61,6 +61,7 @@ def test_gsm8k_correctness(config_filename):
server_args.extend(
[
"--trust-remote-code",
+ "--disable-uvicorn-access-log",
]
)
diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py
index 4a57affdfbf4..cd5bf47d69e5 100644
--- a/tests/kernels/moe/test_cutlass_moe.py
+++ b/tests/kernels/moe/test_cutlass_moe.py
@@ -7,17 +7,22 @@
import pytest
import torch
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
- cutlass_moe_fp8,
+ CutlassExpertsFp8,
run_cutlass_moe_fp8,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -150,16 +155,15 @@ def make_moe_tensors_8bit(
def run_with_expert_maps(
- num_experts: int, num_local_experts: int, **cutlass_moe_kwargs
+ num_experts: int,
+ num_local_experts: int,
+ quant_config: FusedMoEQuantConfig,
+ **cutlass_moe_kwargs,
):
def slice_experts():
slice_params = [
- "w1_q",
- "w2_q",
- "ab_strides1",
- "ab_strides2",
- "c_strides1",
- "c_strides2",
+ "w1",
+ "w2",
]
full_tensors = {
k: v
@@ -167,8 +171,6 @@ def slice_experts():
if k in slice_params and k in cutlass_moe_kwargs
}
- quant_config = cutlass_moe_kwargs["quant_config"]
-
for i in range(0, num_experts, num_local_experts):
s, e = i, i + num_local_experts
@@ -187,13 +189,23 @@ def slice_experts():
new_quant_config._w1.scale = quant_config.w1_scale[s:e]
new_quant_config._w2.scale = quant_config.w2_scale[s:e]
- cutlass_moe_kwargs["quant_config"] = new_quant_config
-
- yield cutlass_moe_kwargs
-
- out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
- for kwargs in slice_experts():
- out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)
+ yield cutlass_moe_kwargs, new_quant_config
+
+ out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"])
+ for kwargs, new_quant_config in slice_experts():
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ CutlassExpertsFp8(
+ out_dtype=kwargs["hidden_states"].dtype,
+ # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
+ e=kwargs["w2"].shape[0], # type: ignore[union-attr]
+ n=kwargs["w2"].shape[2], # type: ignore[union-attr]
+ k=kwargs["w2"].shape[1], # type: ignore[union-attr]
+ quant_config=new_quant_config,
+ device="cuda",
+ ),
+ )
+ out_tensor = out_tensor + kernel(**kwargs)
return out_tensor
@@ -230,27 +242,35 @@ def run_8_bit(
)
kwargs = {
- "a": moe_tensors.a,
- "w1_q": moe_tensors.w1_q, # type: ignore[union-attr]
- "w2_q": moe_tensors.w2_q, # type: ignore[union-attr]
+ "hidden_states": moe_tensors.a,
+ "w1": moe_tensors.w1_q, # type: ignore[union-attr]
+ "w2": moe_tensors.w2_q, # type: ignore[union-attr]
"topk_weights": topk_weights,
"topk_ids": topk_ids,
- "ab_strides1": moe_tensors.ab_strides1,
- "ab_strides2": moe_tensors.ab_strides2,
- "c_strides1": moe_tensors.c_strides1,
- "c_strides2": moe_tensors.c_strides2,
- "quant_config": quant_config,
}
num_experts = moe_tensors.w1.size(0)
with_ep = num_local_experts is not None or num_local_experts == num_experts
if not with_ep:
- return cutlass_moe_fp8(**kwargs)
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ CutlassExpertsFp8(
+ out_dtype=moe_tensors.a.dtype,
+ # NOTE(rob): w2 is shaped as [E, hidden, intermediate]
+ e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
+ n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
+ k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
+ quant_config=quant_config,
+ device="cuda",
+ ),
+ )
+ return kernel(**kwargs)
assert num_local_experts is not None
return run_with_expert_maps(
num_experts,
num_local_experts, # type: ignore[arg-type]
+ quant_config,
**kwargs,
)
diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py
index c23107965340..bb2f6b873941 100644
--- a/tests/kernels/moe/test_flashinfer.py
+++ b/tests/kernels/moe/test_flashinfer.py
@@ -11,12 +11,17 @@
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- apply_flashinfer_per_tensor_scale_fp8,
- flashinfer_cutlass_moe_fp8,
+ apply_fi_trtllm_fp8_per_tensor_moe,
register_scales_for_trtllm_fp8_per_tensor_moe,
- rotate_flashinfer_fp8_moe_weights,
+ rotate_weights_for_fi_trtllm_fp8_per_tensor_moe,
swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8
@@ -103,6 +108,7 @@ def make_moe_tensors_8bit(
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
layer = torch.nn.Module()
+ layer.orig_dtype = torch.bfloat16
layer.w13_weight = w13_quantized.clone()
layer.w2_weight = w2_quantized.clone()
layer.w13_input_scale = a1_scale
@@ -115,10 +121,10 @@ def make_moe_tensors_8bit(
pcp_size=1,
dp_size=1,
ep_size=1,
- tp_rank=1,
- pcp_rank=1,
- dp_rank=1,
- ep_rank=1,
+ tp_rank=0,
+ pcp_rank=0,
+ dp_rank=0,
+ ep_rank=0,
use_ep=False,
all2all_backend="naive",
)
@@ -126,7 +132,9 @@ def make_moe_tensors_8bit(
# flashinfer expects swapped rows for w13
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
if is_trtllm:
- rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
+ rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
+ layer.w13_weight, layer.w2_weight
+ )
register_scales_for_trtllm_fp8_per_tensor_moe(
layer,
layer.w13_weight_scale,
@@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
quant_config=quant_config,
)
- flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
+ flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe(
layer=td.layer,
hidden_states=td.hidden_states,
router_logits=score,
@@ -277,17 +285,34 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config
td.layer.quant_method = td.layer
- flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=quant_config.is_block_quantized
+ ),
+ FlashInferExperts(
+ out_dtype=td.layer.orig_dtype,
+ quant_config=quant_config,
+ ep_rank=td.layer.moe_parallel_config.ep_rank,
+ ep_size=td.layer.moe_parallel_config.ep_size,
+ tp_rank=td.layer.moe_parallel_config.tp_rank,
+ tp_size=td.layer.moe_parallel_config.tp_size,
+ use_dp=False,
+ use_deepseek_fp8_block_scale=False,
+ ),
+ )
+
+ flashinfer_cutlass_output = kernel(
td.hidden_states,
- td.layer,
+ td.layer.w13_weight,
+ td.layer.w2_weight,
topk_weights,
topk_ids,
+ inplace=False,
activation=activation,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
)
-
torch.testing.assert_close(
output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2
)
diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py
index a4b6d35987e1..d823d2b7a35d 100644
--- a/tests/quantization/test_fp8.py
+++ b/tests/quantization/test_fp8.py
@@ -15,7 +15,6 @@
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
- Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -278,8 +277,18 @@ def per_tensor_dequantize(tensor, inv_scale, dtype):
# this is the case for marlin as well as per-tensor Fp8MoEMethod
@pytest.mark.parametrize("use_marlin", [False]) # skip True
def test_fp8_reloading(
- method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
+ method_cls,
+ is_checkpoint_fp8_serialized,
+ weight_block_size,
+ use_marlin,
+ dist_init,
+ monkeypatch,
):
+ # NOTE(rob): this test fails when using DeepGEMM because the
+ # shapes are invalid. Previously the test was passing because
+ # we set fp8_backend to None, which sidestepped the issue.
+ monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0")
+
if is_checkpoint_fp8_serialized is False:
pytest.skip("FP8 weight reloading does not support online quantization")
@@ -307,6 +316,7 @@ def test_fp8_reloading(
params_dtype=torch.bfloat16,
weight_loader=default_weight_loader,
)
+ method.use_marlin = use_marlin
else:
layer = FusedMoE(
@@ -325,11 +335,6 @@ def test_fp8_reloading(
weight_loader=default_weight_loader,
)
- # Fp8LinearMethod uses use_marlin
- # Fp8MoEMethod uses fp8_backend
- method.use_marlin = use_marlin
- method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None
-
# capture weights format during loading
original_metadata = [
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 3d248e7fb994..e63404086ed9 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -73,7 +73,6 @@ def get_config() -> dict[str, Any] | None:
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_fp4,
- cutlass_moe_fp8,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
@@ -96,7 +95,6 @@ def get_config() -> dict[str, Any] | None:
"fused_experts",
"get_config_file_name",
"GroupedTopk",
- "cutlass_moe_fp8",
"cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index c585cbc1ab5d..6e397f1e76a1 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -249,20 +249,28 @@ def run_cutlass_moe_fp8(
class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
+ e: int,
+ n: int,
+ k: int,
out_dtype: torch.dtype | None,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
quant_config: FusedMoEQuantConfig,
+        device: torch.device | str,
):
assert quant_config.use_fp8_w8a8
super().__init__(quant_config)
+
+ # E: num_experts
+ # N: intermediate size per partition
+ # K: hidden dim
+ ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64)
+ ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64)
+ c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64)
+
self.out_dtype = out_dtype
- self.ab_strides1 = ab_strides1
+ self.ab_strides1 = ab_strides1_c_strides2
self.ab_strides2 = ab_strides2
self.c_strides1 = c_strides1
- self.c_strides2 = c_strides2
+ self.c_strides2 = ab_strides1_c_strides2
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
@@ -329,24 +337,6 @@ def apply(
class CutlassExpertsFp8(CutlassExpertsFp8Base):
- def __init__(
- self,
- out_dtype: torch.dtype | None,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- ):
- super().__init__(
- out_dtype,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
- quant_config,
- )
-
@property
def activation_formats(
self,
@@ -390,21 +380,10 @@ def __init__(
self,
max_experts_per_worker: int,
num_dispatchers: int,
- out_dtype: torch.dtype | None,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
+ *args,
+ **kwargs,
):
- super().__init__(
- out_dtype,
- ab_strides1,
- ab_strides2,
- c_strides1,
- c_strides2,
- quant_config,
- )
+ super().__init__(*args, **kwargs)
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.num_dispatchers = num_dispatchers
@@ -445,113 +424,6 @@ def workspace_shapes(
return (workspace1, workspace2, output)
-def cutlass_moe_fp8(
- a: torch.Tensor,
- w1_q: torch.Tensor,
- w2_q: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- ab_strides1: torch.Tensor,
- ab_strides2: torch.Tensor,
- c_strides1: torch.Tensor,
- c_strides2: torch.Tensor,
- quant_config: FusedMoEQuantConfig,
- activation: str = "silu",
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- global_num_experts: int = -1,
-) -> torch.Tensor:
- """
- This function computes a a8w8-quantized Mixture of Experts (MoE) layer
- using two sets of quantized weights, w1_q and w2_q, and top-k gating
- mechanism. The matrix multiplications are implemented with CUTLASS
- grouped gemm.
-
- Parameters:
- - a (torch.Tensor): The input tensor to the MoE layer.
- Shape: [M, K]
- - w1_q (torch.Tensor): The first set of fp8-quantized expert weights.
- Shape: [num_experts, K, 2N] (the weights are passed transposed)
- - w2_q (torch.Tensor): The second set of fp8-quantized expert weights.
- Shape: [num_experts, N, K] (the weights are passed transposed)
- - topk_weights (torch.Tensor): The weights of each token->expert mapping.
- - topk_ids (torch.Tensor): The token->expert mappings.
- - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
- Shape: [num_experts] or [num_experts, 2N]
- - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
- Shape: [num_experts] or [num_experts, K]
- - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
- Shape: [num_experts]
- - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
- Shape: [num_experts]
- - c_strides1 (torch.Tensor): The output strides for the first gemm.
- Shape: [num_experts]
- - c_strides2 (torch.Tensor): The output strides for the second gemm.
- Shape: [num_experts]
- - per_act_token (Optional[bool]): Whether the scale is per-token or
- per-tensor.
- - activation (str): The activation function to use.
- - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
- Shape: scalar or [M]
- - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
- quantize the intermediate result between the gemms.
- Shape: scalar or [M]
- - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
- every Rank is responsible for a subset of experts. expert_map is a
- mapping from global expert-id to local expert-id. When expert_map[i]
- is -1, it means that this Rank is not responsible for global
- expert-id i.
- - apply_router_weight_on_input (bool): When true, the topk weights are
- applied directly on the inputs. This is only applicable when topk is 1.
- - global_num_experts (int): The total number of experts.
-
- Returns:
- - torch.Tensor: The fp16 output tensor after applying the MoE layer.
- """
- assert quant_config is not None
-
- if quant_config.a1_scale is not None:
- assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1)
- if quant_config.a2_scale is not None:
- assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1)
-
- if quant_config.w1_scale is not None:
- if quant_config.per_out_ch_quant:
- assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size(
- 1
- ) == w1_q.size(1)
- else:
- assert (
- quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1
- )
-
- num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
-
- fn = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- CutlassExpertsFp8(
- out_dtype=a.dtype,
- ab_strides1=ab_strides1,
- ab_strides2=ab_strides2,
- c_strides1=c_strides1,
- c_strides2=c_strides2,
- quant_config=quant_config,
- ),
- )
-
- return fn(
- a,
- w1_q,
- w2_q,
- topk_weights,
- topk_ids,
- activation=activation,
- global_num_experts=num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
-
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py
new file mode 100644
index 000000000000..14ef6b9aaa5e
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/fallback.py
@@ -0,0 +1,126 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+
+
+class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
+ """Base class for runtime dispatching of expert implementations."""
+
+ def __init__(
+ self,
+ experts: mk.FusedMoEPermuteExpertsUnpermute,
+ fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
+ ):
+ super().__init__(experts.quant_config)
+ self.fallback_experts = fallback_experts
+ self.experts = experts
+
+ @property
+ def activation_formats(
+ self,
+ ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+ assert (
+ self.fallback_experts.activation_formats == self.experts.activation_formats
+ )
+ return self.fallback_experts.activation_formats
+
+ def supports_chunking(self) -> bool:
+ assert (
+ self.experts.supports_chunking()
+ == self.fallback_experts.supports_chunking()
+ )
+ return (
+ self.experts.supports_chunking()
+ and self.fallback_experts.supports_chunking()
+ )
+
+ def supports_expert_map(self) -> bool:
+ assert (
+ self.experts.supports_expert_map()
+ == self.fallback_experts.supports_expert_map()
+ )
+ return (
+ self.experts.supports_expert_map()
+ and self.fallback_experts.supports_expert_map()
+ )
+
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ e_war = self.experts.finalize_weight_and_reduce_impl()
+ fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl()
+ is_dge_war = e_war is not None
+ is_fbe_war = fbe_war is not None
+
+ if is_dge_war and is_fbe_war:
+ assert e_war == fbe_war, (
+ "Both implementations should agree on WeightAndReduce impls. "
+ f"Got e_war: {e_war}, and fbe_war: {fbe_war}"
+ )
+
+ if e_war is not None:
+ return e_war
+ assert fbe_war is not None
+ return fbe_war
+
+ @abstractmethod
+ def workspace_shapes(
+ self,
+ M: int,
+ N: int,
+ K: int,
+ topk: int,
+ global_num_experts: int,
+ local_num_experts: int,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _select_experts_impl(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ raise NotImplementedError
+
+ def apply(
+ self,
+ output: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ a2_scale: torch.Tensor | None,
+ workspace13: torch.Tensor,
+ workspace2: torch.Tensor,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ apply_router_weight_on_input: bool,
+ ):
+ experts = self._select_experts_impl(hidden_states, w1, w2)
+ experts.apply(
+ output,
+ hidden_states,
+ w1,
+ w2,
+ topk_weights,
+ topk_ids,
+ activation,
+ global_num_experts,
+ expert_map,
+ a1q_scale,
+ a2_scale,
+ workspace13,
+ workspace2,
+ expert_tokens_meta,
+ apply_router_weight_on_input,
+ )
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index 51e06ac54f49..3bb5a23abb7b 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -100,7 +100,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake(
)
-def flashinfer_fused_moe_per_tensor_scale_fp8(
+def fi_trtllm_fp8_per_tensor_moe(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
@@ -158,7 +158,7 @@ def flashinfer_fused_moe_per_tensor_scale_fp8(
)
-def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
+def fi_trtllm_fp8_per_tensor_moe_fake(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
hidden_states: torch.Tensor,
@@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake(
# TODO(bnell): Does this really need to be a torch.op?
direct_register_custom_op(
- op_name="flashinfer_fused_moe_per_tensor_scale_fp8",
- op_func=flashinfer_fused_moe_per_tensor_scale_fp8,
+ op_name="fi_trtllm_fp8_per_tensor_moe",
+ op_func=fi_trtllm_fp8_per_tensor_moe,
mutates_args=["hidden_states"],
- fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake,
+ fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
tags=(torch.Tag.needs_fixed_stride_order,),
)
diff --git a/vllm/model_executor/layers/fused_moe/oracle/__init__.py b/vllm/model_executor/layers/fused_moe/oracle/__init__.py
new file mode 100644
index 000000000000..208f01a7cb5e
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/oracle/__init__.py
@@ -0,0 +1,2 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
new file mode 100644
index 000000000000..f5c3b9af611f
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -0,0 +1,358 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import Enum
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEQuantConfig,
+ fp8_w8a8_moe_quant_config,
+ fp8_w8a16_moe_quant_config,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ FlashinferMoeBackend,
+ get_flashinfer_moe_backend,
+ make_fp8_moe_alpha_scales_for_fi,
+ prepare_fp8_moe_layer_for_fi,
+)
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+ prepare_fp8_moe_layer_for_deepgemm,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+ prepare_fp8_moe_layer_for_marlin,
+)
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+ cutlass_group_gemm_supported,
+)
+from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import is_deep_gemm_supported
+from vllm.utils.flashinfer import has_flashinfer_moe
+from vllm.utils.import_utils import has_deep_gemm
+
+logger = init_logger(__name__)
+
+
+class Fp8MoeBackend(Enum):
+ NONE = 0
+ FLASHINFER_TRTLLM = 1
+ FLASHINFER_CUTLASS = 2
+ DEEPGEMM = 3
+ MARLIN = 4
+ TRITON = 5
+ AITER = 6
+ VLLM_CUTLASS = 7
+
+
+def select_fp8_moe_backend(
+ block_quant: bool,
+ tp_size: int,
+ with_lora_support: bool,
+ is_act_and_mul: bool = True,
+ allow_vllm_cutlass: bool = False,
+) -> Fp8MoeBackend:
+ """
+ Select the primary FP8 MoE backend
+ Note: Shape-specific fallbacks may still occur at runtime.
+ """
+ # TODO(rob): in a future PR, we will query each mk for
+ # supported features and return the mk directly, just like
+ # we do for the Attention Backend.
+
+ if with_lora_support:
+ return Fp8MoeBackend.TRITON
+
+ def _make_log_backend(backend_name: str):
+ return f"Using {backend_name} backend for FP8 MoE"
+
+ # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
+ if (
+ current_platform.is_cuda()
+ and (
+ current_platform.is_device_capability_family(100)
+ or current_platform.is_device_capability(90)
+ )
+ and envs.VLLM_USE_FLASHINFER_MOE_FP8
+ and has_flashinfer_moe()
+ ):
+ backend = get_flashinfer_moe_backend()
+ if backend == FlashinferMoeBackend.TENSORRT_LLM:
+ logger.info_once(_make_log_backend("FlashInfer TRTLLM"))
+ if not is_act_and_mul:
+ raise ValueError(
+ "FlashInfer TRTLLM FP8 MoE backend only supports "
+ "act_and_mul gate_up_project fusion. Please set "
+ "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the "
+ "FlashInfer CUTLASS backend instead."
+ )
+ return Fp8MoeBackend.FLASHINFER_TRTLLM
+ else:
+ if block_quant and current_platform.is_device_capability_family(100):
+ raise ValueError(
+ "FlashInfer FP8 MoE throughput backend does not "
+ "support block quantization on SM100. Please use "
+ "VLLM_FLASHINFER_MOE_BACKEND=latency to use the "
+ "FlashInfer TRTLLM backend instead."
+ )
+ logger.info_once(_make_log_backend("FlashInfer CUTLASS"))
+ return Fp8MoeBackend.FLASHINFER_CUTLASS
+
+ # weight-only path for older GPUs without native FP8
+ if (
+ current_platform.is_cuda() and not current_platform.has_device_capability(89)
+ ) or envs.VLLM_TEST_FORCE_FP8_MARLIN:
+ logger.info_once(_make_log_backend("Marlin"), scope="local")
+ return Fp8MoeBackend.MARLIN
+
+ # Determine if we should use DeepGEMM with block-quantized weights:
+ # - If explicitly set by user, respect their choice
+ # - If not explicitly set (default), disable when TP size is >= 8
+ moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
+ if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8:
+ moe_use_deep_gemm = False
+ logger.info_once(
+ "DeepGEMM MoE is disabled by default when TP size is >= 8. "
+ "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
+ scope="local",
+ )
+
+ use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
+ if not is_deep_gemm_supported():
+ use_deep_gemm = False
+ logger.info_once(
+ "DeepGEMM is disabled because the platform does not support it.",
+ scope="local",
+ )
+
+ if use_deep_gemm and moe_use_deep_gemm and block_quant:
+ if not has_deep_gemm():
+ logger.warning_once(
+ "DeepGEMM backend requested but not available.", scope="local"
+ )
+ elif is_deep_gemm_supported():
+ logger.info_once(_make_log_backend("DeepGEMM"), scope="local")
+ return Fp8MoeBackend.DEEPGEMM
+
+ if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
+ logger.info_once(_make_log_backend("ROCm AITER"), scope="local")
+ return Fp8MoeBackend.AITER
+
+ if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported():
+ logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local")
+ return Fp8MoeBackend.VLLM_CUTLASS
+
+ # default to Triton
+ logger.info_once(_make_log_backend("Triton"), scope="local")
+ return Fp8MoeBackend.TRITON
+
+
+def convert_to_fp8_moe_kernel_format(
+ fp8_backend: Fp8MoeBackend,
+ layer: torch.nn.Module,
+ w13: torch.Tensor,
+ w2: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w13_input_scale: torch.Tensor | None,
+ w2_input_scale: torch.Tensor | None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ block_quant = hasattr(layer, "weight_block_size")
+ if fp8_backend == Fp8MoeBackend.DEEPGEMM:
+ assert block_quant
+ w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm(
+ w13,
+ w2,
+ w13_scale,
+ w2_scale,
+ tuple(layer.weight_block_size),
+ )
+ elif fp8_backend == Fp8MoeBackend.AITER:
+ w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2)
+ elif fp8_backend == Fp8MoeBackend.MARLIN:
+ w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin(
+ layer,
+ w13,
+ w2,
+ w13_scale,
+ w2_scale,
+ )
+ elif fp8_backend in [
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ ]:
+ w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
+ layer=layer,
+ w13=w13,
+ w2=w2,
+ w13_scale=w13_scale,
+ w13_input_scale=w13_input_scale,
+ w2_scale=w2_scale,
+ w2_input_scale=w2_input_scale,
+ is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
+ )
+
+ return w13, w2, w13_scale, w2_scale
+
+
+def make_fp8_moe_quant_config(
+ fp8_backend: Fp8MoeBackend,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a1_scale: torch.Tensor | None,
+ a2_scale: torch.Tensor | None,
+ block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig | None:
+ """
+    Create FusedMoEQuantConfig for the specified FP8 Backend.
+ The FusedMoEQuantConfig holds the scales that are used
+ at runtime by the Modular Kernel abstraction.
+
+ Note that certain kernels (e.g. Flashinfer CUTLASS) need
+ special Quant configs to handle non-standard inputs to
+ their kernel interfaces.
+
+    In a future PR, this function should become
+    a method of the modular kernel itself.
+ """
+ # TRTLLM does not use Modular Kernel abstraction yet.
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ return None
+
+ # MARLIN is mixed precision W8A16 config.
+ if fp8_backend == Fp8MoeBackend.MARLIN:
+ return fp8_w8a16_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_shape=block_shape,
+ )
+
+ # Flashinfer CUTLASS per-tensor uses single dq scale
+ # (alpha = w_scale * a_scale) and inverse a2 scale.
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
+ assert a1_scale is not None and a2_scale is not None
+ g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
+ w1_scale,
+ a1_scale,
+ w2_scale,
+ a2_scale,
+ )
+ return fp8_w8a8_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ a1_gscale=(1.0 / a1_scale),
+ a2_gscale=(1.0 / a2_scale),
+ g1_alphas=g1_alphas,
+ g2_alphas=g2_alphas,
+ )
+ # All other backends use normal config.
+ return fp8_w8a8_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ block_shape=block_shape,
+ )
+
+
+def make_fp8_moe_kernel(
+ layer: torch.nn.Module,
+ moe_quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
+ fp8_backend: Fp8MoeBackend,
+) -> tuple[mk.FusedMoEModularKernel, bool]:
+ # Delayed import is required since the oracle is imported
+ # by CPU backends which cannot import all of these experts.
+ # TODO: update the experts to make this not happen.
+ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+ )
+
+ # NOTE(rob): this is a WIP refactor. We are first migrating
+ # all of the kernels in the TP case to use mk. Once this is
+    # done, then we will initialize the TP case and DP/EP case
+ # via the same code path (i.e. via maybe_init_modular_kernel).
+ # NOTE(rob): in progress migrating all into this format.
+ use_inplace = True
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=moe_quant_config.is_block_quantized
+ ),
+ FlashInferExperts(
+ out_dtype=layer.orig_dtype,
+ quant_config=moe_quant_config,
+ ep_rank=moe_config.ep_rank,
+ ep_size=moe_config.ep_size,
+ tp_rank=moe_config.tp_rank,
+ tp_size=moe_config.tp_size,
+ use_dp=(moe_config.dp_size > 1),
+ use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
+ ),
+ )
+ use_inplace = False
+
+ elif fp8_backend == Fp8MoeBackend.AITER:
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ AiterExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ # TODO: make defer_input_quant an attr of the AiterExperts
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ AiterExperts(quant_config=moe_quant_config),
+ )
+ elif fp8_backend == Fp8MoeBackend.MARLIN:
+ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ MarlinExperts(quant_config=moe_quant_config),
+ )
+ elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import (
+ TritonOrCutlassExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ TritonOrCutlassExperts(
+ out_dtype=moe_config.in_dtype,
+ e=layer.local_num_experts,
+ n=layer.intermediate_size_per_partition,
+ k=layer.hidden_size,
+ device=layer.w13_weight.device,
+ quant_config=moe_quant_config,
+ ),
+ )
+ elif fp8_backend == Fp8MoeBackend.DEEPGEMM:
+ from vllm.model_executor.layers.fused_moe import (
+ TritonOrDeepGemmExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ TritonOrDeepGemmExperts(quant_config=moe_quant_config),
+ )
+ else:
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ TritonExperts,
+ )
+
+ assert fp8_backend == Fp8MoeBackend.TRITON
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(),
+ TritonExperts(quant_config=moe_quant_config),
+ )
+ return kernel, use_inplace
diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
new file mode 100644
index 000000000000..e874ba609be0
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
+from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
+from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
+from vllm.platforms import current_platform
+
+
+class TritonOrCutlassExperts(FallbackExperts):
+ """Cutlass with fallback to Triton for low latency shapes on SM100."""
+
+ def __init__(
+ self,
+ e: int,
+ n: int,
+ k: int,
+ out_dtype: torch.dtype | None,
+ quant_config: FusedMoEQuantConfig,
+        device: torch.device,
+ ):
+ self.is_sm100 = current_platform.has_device_capability(100)
+ super().__init__(
+ experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device),
+ fallback_experts=TritonExperts(quant_config),
+ )
+
+ def workspace_shapes(
+ self,
+ M: int,
+ N: int,
+ K: int,
+ topk: int,
+ global_num_experts: int,
+ local_num_experts: int,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ # Small batch fallback for sm100.
+ if self.is_sm100 and M <= 8:
+ return self.fallback_experts.workspace_shapes(
+ M,
+ N,
+ K,
+ topk,
+ global_num_experts,
+ local_num_experts,
+ expert_tokens_meta,
+ )
+ else:
+ return self.experts.workspace_shapes(
+ M,
+ N,
+ K,
+ topk,
+ global_num_experts,
+ local_num_experts,
+ expert_tokens_meta,
+ )
+
+ def _select_experts_impl(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ # Small batch fallback for sm100.
+ if self.is_sm100 and hidden_states.shape[0] <= 8:
+ return self.fallback_experts
+ else:
+ return self.experts
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index b8e0837162ef..4fcc1a7c1fc0 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -10,78 +10,22 @@
_valid_deep_gemm,
_valid_deep_gemm_shape,
)
+from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.utils.deep_gemm import (
- get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
-class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(
- self,
- quant_config: FusedMoEQuantConfig,
- allow_deep_gemm: bool = False,
- ):
- super().__init__(quant_config)
-
- self.triton_expert = TritonExperts(quant_config)
-
- self.allow_deep_gemm = (
- allow_deep_gemm
- and self.quant_config.use_fp8_w8a8
- and self.block_shape == get_mk_alignment_for_contiguous_layout()
- )
-
- self.deep_gemm_expert = (
- DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None
- )
+class TritonOrDeepGemmExperts(FallbackExperts):
+ """DeepGemm with fallback to Triton for low latency shapes."""
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- assert (
- self.deep_gemm_expert is None
- or self.triton_expert.activation_formats
- == self.deep_gemm_expert.activation_formats
- )
- return self.triton_expert.activation_formats
-
- def supports_chunking(self) -> bool:
- dge = self.deep_gemm_expert
- te = self.triton_expert
- return (dge is None or dge.supports_chunking()) and (
- te is None or te.supports_chunking()
+ def __init__(self, quant_config: FusedMoEQuantConfig):
+ super().__init__(
+ experts=DeepGemmExperts(quant_config),
+ fallback_experts=TritonExperts(quant_config),
)
- def supports_expert_map(self) -> bool:
- dge = self.deep_gemm_expert
- te = self.triton_expert
- return (dge is None or dge.supports_expert_map()) and (
- te is None or te.supports_expert_map()
- )
-
- def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
- dge = self.deep_gemm_expert
- te = self.triton_expert
- dge_war = dge.finalize_weight_and_reduce_impl() if dge else None
- te_war = te.finalize_weight_and_reduce_impl() if te else None
- is_dge_war = dge_war is not None
- is_te_war = te_war is not None
-
- if is_dge_war and is_te_war:
- assert dge_war == te_war, (
- "Both implementations should agree on WeightAndReduce impls. "
- f"Got dge_war: {dge_war}, and te_war: {te_war}"
- )
-
- if dge_war is not None:
- return dge_war
-
- assert te_war is not None
- return te_war
-
def workspace_shapes(
self,
M: int,
@@ -95,11 +39,8 @@ def workspace_shapes(
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
- if self.allow_deep_gemm and (
- is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)
- ):
- assert self.deep_gemm_expert is not None
- return self.deep_gemm_expert.workspace_shapes(
+ if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K):
+ return self.experts.workspace_shapes(
M,
N,
K,
@@ -109,7 +50,7 @@ def workspace_shapes(
expert_tokens_meta,
)
else:
- return self.triton_expert.workspace_shapes(
+ return self.fallback_experts.workspace_shapes(
M,
N,
K,
@@ -119,45 +60,13 @@ def workspace_shapes(
expert_tokens_meta,
)
- def apply(
+ def _select_experts_impl(
self,
- output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- activation: str,
- global_num_experts: int,
- expert_map: torch.Tensor | None,
- a1q_scale: torch.Tensor | None,
- a2_scale: torch.Tensor | None,
- workspace13: torch.Tensor,
- workspace2: torch.Tensor,
- expert_tokens_meta: mk.ExpertTokensMetadata | None,
- apply_router_weight_on_input: bool,
- ):
- use_deep_gemm = self.allow_deep_gemm and (
- is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)
- )
-
- experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
- assert experts is not None
-
- experts.apply(
- output,
- hidden_states,
- w1,
- w2,
- topk_weights,
- topk_ids,
- activation,
- global_num_experts,
- expert_map,
- a1q_scale,
- a2_scale,
- workspace13,
- workspace2,
- expert_tokens_meta,
- apply_router_weight_on_input,
- )
+ ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
+ return self.experts
+ else:
+ return self.fallback_experts
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 1094d9d55a1b..a2b3aec4457e 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,10 +13,8 @@
)
from torch.nn.parameter import Parameter
-import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
-from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -31,6 +29,7 @@
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
+ fp8_w8a16_moe_quant_config,
int4_w4a16_moe_quant_config,
int4_w4afp8_moe_quant_config,
int8_w8a8_moe_quant_config,
@@ -46,11 +45,16 @@
MarlinExperts,
fused_marlin_moe,
)
+from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
+ Fp8MoeBackend,
+ convert_to_fp8_moe_kernel_format,
+ make_fp8_moe_kernel,
+ select_fp8_moe_backend,
+)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
-from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize,
flashinfer_trtllm_fp4_moe,
@@ -63,8 +67,8 @@
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
- expert_weight_is_col_major,
- requant_weight_ue8m0_inplace,
+ process_fp8_input_tensor_strategy_moe,
+ process_fp8_weight_tensor_strategy_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_moe_marlin_supports_layer,
@@ -76,29 +80,17 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_moe_fp4_layer_for_marlin,
)
-from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
- prepare_moe_fp8_layer_for_marlin,
-)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
swizzle_blockscale,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- all_close_1d,
normalize_e4m3fn_to_e4m3fnuz,
- per_tensor_dequantize,
)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
-from vllm.utils.deep_gemm import (
- get_col_major_tma_aligned_tensor,
- get_mk_alignment_for_contiguous_layout,
- is_deep_gemm_e8m0_used,
- is_deep_gemm_supported,
-)
-from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
@@ -657,10 +649,6 @@ def __init__(
moe: FusedMoEConfig,
layer_name: str | None = None,
):
- from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
- CompressedTensorsConfig,
- )
-
super().__init__(moe)
self.weight_quant = weight_quant
self.input_quant = input_quant
@@ -687,42 +675,31 @@ def __init__(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
-
- # For GPUs that lack FP8 hardware support, we can leverage the Marlin
- # kernel for fast weight-only FP8 quantization
- self.use_marlin = (
- not current_platform.has_device_capability(89)
- or envs.VLLM_TEST_FORCE_FP8_MARLIN
- and not self.block_quant
- )
- # Disable marlin for rocm
- if current_platform.is_rocm():
- self.use_marlin = False
-
- self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
-
- # cutlass path
- self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100(
- self.weight_quant, self.input_quant
+ self.fp8_backend = select_fp8_moe_backend(
+ block_quant=self.block_quant,
+ tp_size=moe.tp_size,
+ with_lora_support=moe.is_lora_enabled,
+ # TODO(rob): enable selecting this externally.
+ allow_vllm_cutlass=True,
)
- self.use_cutlass = not self.block_quant and (
- CompressedTensorsConfig._is_fp8_w8a8_sm90(
- self.weight_quant, self.input_quant
+ if self.fp8_backend != Fp8MoeBackend.MARLIN:
+ per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
+ per_channel_quant = (
+ self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ )
+ if per_act_token != per_channel_quant:
+ raise NotImplementedError(
+ "For FP8 Fused MoE layers, per-token and per-channel must be "
+ "used together."
+ )
+ # TODO(rob): hook this up in a follow up PR.
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ raise NotImplementedError(
+ "FlashInfer TRTLLM backend not supported for compressed-tensors yet."
)
- or self.is_fp8_w8a8_sm100
- )
self.disable_expert_map = False
- self.layer_name = layer_name
- self.marlin_input_dtype = (
- get_marlin_input_dtype(layer_name) if self.use_marlin else None
- )
- self.allow_deep_gemm = (
- self.block_quant
- and envs.VLLM_MOE_USE_DEEP_GEMM
- and is_deep_gemm_supported()
- and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout()
- )
+ self.kernel: mk.FusedMoEModularKernel | None = None
def create_weights(
self,
@@ -880,163 +857,75 @@ def create_weights(
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- # Fp8 moe kernels require a single activation scale.
- # We take the max of all the scales in case they differ.
- if self.static_input_scales:
- assert self.input_quant.strategy == QuantizationStrategy.TENSOR
- if layer.w13_input_scale is None or layer.w2_input_scale is None:
- raise ValueError(
- "QuantConfig has static quantization, but found "
- "activation scales are None."
- )
- if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
- layer.w2_input_scale
- ):
- logger.warning_once(
- "Found input_scales that are not equal for "
- "fp8 MoE layer. Using the maximum across experts "
- "for each layer."
- )
- layer.w13_input_scale = torch.nn.Parameter(
- layer.w13_input_scale.max(), requires_grad=False
- )
- layer.w2_input_scale = torch.nn.Parameter(
- layer.w2_input_scale.max(), requires_grad=False
- )
-
+ # Allow for accessing weights and scales in standard way.
+ w13 = layer.w13_weight
+ w2 = layer.w2_weight
+ w13_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
+ w13_input_scale = layer.w13_input_scale
+ w2_input_scale = layer.w2_input_scale
+
+ # MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
- # Normalize the weights and scales
- w13_weight, w13_weight_scale, w13_input_scale = (
- normalize_e4m3fn_to_e4m3fnuz(
- layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
- )
+ w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w13, w13_scale, w13_input_scale
)
- w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
- layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
- )
- # Reset the parameter
- layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
- layer.w13_weight_scale = torch.nn.Parameter(
- w13_weight_scale, requires_grad=False
+ w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w2, w2_scale, w2_input_scale
)
- if w13_input_scale is not None:
- layer.w13_input_scale = torch.nn.Parameter(
- w13_input_scale, requires_grad=False
- )
- layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
- layer.w2_weight_scale = torch.nn.Parameter(
- w2_weight_scale, requires_grad=False
- )
- if w2_input_scale is not None:
- layer.w2_input_scale = torch.nn.Parameter(
- w2_input_scale, requires_grad=False
- )
- # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
- # for w13 per expert. Use max then dequant and requant each expert.
- if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
- assert layer.w13_weight_scale is not None
- shard_size = layer.intermediate_size_per_partition
- max_w13_scales = layer.w13_weight_scale.max(dim=1).values
- for expert_id in range(layer.local_num_experts):
- start = 0
- for shard_id in range(2):
- dq_weight = per_tensor_dequantize(
- layer.w13_weight[expert_id][start : start + shard_size, :],
- layer.w13_weight_scale[expert_id][shard_id],
- )
- layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
- ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
- )
- start += shard_size
- layer.w13_weight_scale = torch.nn.Parameter(
- max_w13_scales, requires_grad=False
- )
-
- # Property to determine if AITER is used
- if self.rocm_aiter_moe_enabled:
- # reshaping weights is required for aiter moe kernel.
- shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
- layer.w13_weight.data, layer.w2_weight.data
+ # Per tensor kernels require single activation scale. Use the max.
+ if self.static_input_scales:
+ assert self.input_quant.strategy == QuantizationStrategy.TENSOR
+ assert w13_input_scale is not None and w2_input_scale is not None
+ w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
+ w13_input_scale, w2_input_scale
)
+ replace_parameter(layer, "w13_input_scale", w13_input_scale)
+ replace_parameter(layer, "w2_input_scale", w2_input_scale)
- layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
- layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
-
- elif self.use_marlin:
- (
- workspace,
- w13_weight,
- w2_weight,
- w13_weight_scale,
- w2_weight_scale,
- ) = prepare_moe_fp8_layer_for_marlin(
- layer,
- layer.w13_weight,
- layer.w2_weight,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- input_dtype=self.marlin_input_dtype,
+ # Per-tensor kernels use a single scale, for W13, but on disk there
+ # is a separate scale for W1 and W3. Requantize with the max scale.
+ if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+ process_fp8_weight_tensor_strategy_moe(
+ w13,
+ w13_scale,
+ shard_size=layer.intermediate_size_per_partition,
+ num_experts=layer.num_local_experts,
+ )
+
+ w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
+ fp8_backend=self.fp8_backend,
+ layer=layer,
+ w13=w13,
+ w2=w2,
+ w13_scale=w13_scale,
+ w2_scale=w2_scale,
+ w13_input_scale=w13_input_scale,
+ w2_input_scale=w2_input_scale,
+ )
+
+ # Replace parameters with updated versions. Note that this helper
+ # function ensures the replacement is compatible with RL weight reloads.
+ replace_parameter(layer, "w13_weight", w13)
+ replace_parameter(layer, "w2_weight", w2)
+ replace_parameter(layer, "w13_weight_scale", w13_scale)
+ replace_parameter(layer, "w2_weight_scale", w2_scale)
+
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+ if self.moe_quant_config:
+ self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ layer=layer,
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ fp8_backend=self.fp8_backend,
)
- layer.workspace = workspace
- replace_parameter(layer, "w13_weight", w13_weight)
- replace_parameter(layer, "w2_weight", w2_weight)
- replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
- replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
-
- if self.use_cutlass:
- assert self.weight_quant.strategy != QuantizationStrategy.BLOCK
- device = layer.w13_weight.device
- # ab_strides1 and c_strides2 are the same
- self.ab_strides1_c_strides2 = torch.full(
- (layer.local_num_experts,),
- layer.hidden_size,
- device=device,
- dtype=torch.int64,
- )
- self.ab_strides2 = torch.full(
- (layer.local_num_experts,),
- layer.intermediate_size_per_partition,
- device=device,
- dtype=torch.int64,
- )
- self.c_strides1 = torch.full(
- (layer.local_num_experts,),
- 2 * layer.intermediate_size_per_partition,
- device=device,
- dtype=torch.int64,
- )
-
- if is_deep_gemm_e8m0_used() and self.block_quant:
- assert layer.weight_block_size is not None
- # Re-quantise the expert weights so their scales are UE8M0.
- block_sz = tuple(layer.weight_block_size)
- requant_weight_ue8m0_inplace(
- layer.w13_weight.data,
- layer.w13_weight_scale.data,
- block_sz,
- )
- requant_weight_ue8m0_inplace(
- layer.w2_weight.data,
- layer.w2_weight_scale.data,
- block_sz,
- )
-
- # Ensure column-major TMA alignment expected by DeepGEMM.
- if expert_weight_is_col_major(layer.w13_weight_scale):
- layer.w13_weight_scale = get_col_major_tma_aligned_tensor(
- layer.w13_weight_scale
- )
- if expert_weight_is_col_major(layer.w2_weight_scale):
- layer.w2_weight_scale = get_col_major_tma_aligned_tensor(
- layer.w2_weight_scale
- )
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if self.use_marlin or self.rocm_aiter_moe_enabled:
+ if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]:
return None
else:
return super().maybe_make_prepare_finalize(routing_tables)
@@ -1048,7 +937,7 @@ def select_gemm_impl(
) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
assert self.moe_quant_config is not None
- if self.use_cutlass:
+ if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
@@ -1064,26 +953,27 @@ def select_gemm_impl(
):
logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassBatchedExpertsFp8(
- self.moe.num_local_experts,
- num_dispatchers,
- self.moe.in_dtype,
- ab_strides1=self.ab_strides1_c_strides2,
- ab_strides2=self.ab_strides2,
- c_strides1=self.c_strides1,
- c_strides2=self.ab_strides1_c_strides2,
+ max_experts_per_worker=self.moe.num_local_experts,
+ num_dispatchers=num_dispatchers,
+ out_dtype=self.moe.in_dtype,
+ e=layer.local_num_experts,
+ n=layer.intermediate_size_per_partition,
+ k=layer.hidden_size,
+ device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
else:
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
experts = CutlassExpertsFp8(
- self.moe.in_dtype,
- ab_strides1=self.ab_strides1_c_strides2,
- ab_strides2=self.ab_strides2,
- c_strides1=self.c_strides1,
- c_strides2=self.ab_strides1_c_strides2,
+ out_dtype=self.moe.in_dtype,
+ e=layer.local_num_experts,
+ n=layer.intermediate_size_per_partition,
+ k=layer.hidden_size,
+ device=layer.w13_weight.device,
quant_config=self.moe_quant_config,
)
+ # TODO(rob): investigate disable_expert_map
self.disable_expert_map = (
num_dispatchers > 1 or not experts.supports_expert_map()
)
@@ -1096,13 +986,14 @@ def select_gemm_impl(
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
+ from vllm.model_executor.layers.fused_moe.fused_moe import (
+ TritonExperts,
+ )
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
- assert not self.rocm_aiter_moe_enabled and not self.use_marlin
-
- use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+ assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN]
if (
prepare_finalize.activation_format
@@ -1111,28 +1002,7 @@ def select_gemm_impl(
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
- if use_deep_gemm and not has_deep_gemm():
- raise RuntimeError(
- "DeepGEMM requested for MoE layer but not installed."
- )
-
- compatible_with_deep_gemm = (
- self.moe_quant_config.use_fp8_w8a8
- and self.moe_quant_config.block_shape
- == get_mk_alignment_for_contiguous_layout()
- )
-
- # If this MoE layer is compatible with DeepGEMM, the proper env
- # vars are set and DeepGEMM is not installed, throw an error.
- if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
- raise RuntimeError(
- f"MoE layer incompatible with DeepGEMM, expected "
- f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
- f"or block_shape {self.moe_quant_config.block_shape}"
- f"=={get_mk_alignment_for_contiguous_layout()}."
- )
-
- if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
+ if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
@@ -1148,17 +1018,22 @@ def select_gemm_impl(
)
else:
- logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
- return TritonOrDeepGemmExperts(
- self.moe_quant_config,
- allow_deep_gemm=use_deep_gemm,
- )
+ if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
+ logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
+ return TritonOrDeepGemmExperts(self.moe_quant_config)
+ else:
+ logger.debug("TritonExperts(%s)", self.__class__.__name__)
+ return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if self.use_marlin:
- return None
+ if self.fp8_backend == Fp8MoeBackend.MARLIN:
+ return fp8_w8a16_moe_quant_config(
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ block_shape=self.weight_block_size,
+ )
per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1184,118 +1059,23 @@ def apply(
router_logits=router_logits,
)
- per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
- per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL
-
- if self.use_marlin:
- assert layer.activation == "silu", (
- f"{layer.activation} not supported for Marlin MoE."
- )
- return fused_marlin_moe(
- x,
- layer.w13_weight,
- layer.w2_weight,
- None,
- None,
- layer.w13_weight_scale,
- layer.w2_weight_scale,
- router_logits,
- topk_weights,
- topk_ids,
- quant_type_id=scalar_types.float8_e4m3fn.id,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- input_dtype=self.marlin_input_dtype,
- workspace=layer.workspace,
- )
-
- elif self.rocm_aiter_moe_enabled:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501
- rocm_aiter_fused_experts,
- )
-
- assert per_act_token == per_channel_quant
- assert self.moe_quant_config is not None
- return rocm_aiter_fused_experts(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- activation=layer.activation,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- expert_map=layer.expert_map,
- quant_config=self.moe_quant_config,
- )
-
- # cutlass path
- elif self.use_cutlass:
- assert self.moe_quant_config is not None
-
- # small-batch fallback on SM100
- if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8:
- from vllm.model_executor.layers.fused_moe import fused_experts
-
- assert per_act_token == per_channel_quant
- return fused_experts(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=layer.activation,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=None
- if self.disable_expert_map
- else layer.expert_map, # ???
- quant_config=self.moe_quant_config,
- allow_deep_gemm=self.allow_deep_gemm,
- )
- else:
- from vllm.model_executor.layers.fused_moe.cutlass_moe import (
- cutlass_moe_fp8,
- )
-
- assert per_act_token == per_channel_quant
- assert self.moe_quant_config is not None
- return cutlass_moe_fp8(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights,
- topk_ids,
- quant_config=self.moe_quant_config,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- expert_map=None if self.disable_expert_map else layer.expert_map,
- ab_strides1=self.ab_strides1_c_strides2,
- ab_strides2=self.ab_strides2,
- c_strides1=self.c_strides1,
- c_strides2=self.ab_strides1_c_strides2,
- )
-
- else:
- from vllm.model_executor.layers.fused_moe import fused_experts
+ assert self.kernel is not None
+ result = self.kernel(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
+ inplace=self.use_inplace,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ # TODO(rob): investigate the disable_expert_map introduced by:
+ # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
+ expert_map=None if self.disable_expert_map else layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ )
- assert per_act_token == per_channel_quant
- assert self.moe_quant_config is not None
- return fused_experts(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=layer.activation,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- quant_config=self.moe_quant_config,
- allow_deep_gemm=self.allow_deep_gemm,
- )
+ return result
@property
def supports_eplb(self) -> bool:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 1223c6902e5f..2879315a6886 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
import torch
@@ -27,13 +26,17 @@
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
- fp8_w8a8_moe_quant_config,
- fp8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
+from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
+ Fp8MoeBackend,
+ convert_to_fp8_moe_kernel_format,
+ make_fp8_moe_kernel,
+ make_fp8_moe_quant_config,
+ select_fp8_moe_backend,
+)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -46,25 +49,20 @@
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- FlashinferMoeBackend,
- apply_flashinfer_per_tensor_scale_fp8,
+ apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- get_flashinfer_moe_backend,
- make_fp8_moe_alpha_scales_for_fi,
- register_scales_for_trtllm_fp8_per_tensor_moe,
- rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
- swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
- deepgemm_post_process_fp8_weight_block,
maybe_post_process_fp8_weight_block,
+ process_fp8_input_tensor_strategy_moe,
process_fp8_weight_block_strategy,
process_fp8_weight_tensor_strategy,
+ process_fp8_weight_tensor_strategy_moe,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -73,7 +71,6 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
- prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -81,12 +78,10 @@
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
- all_close_1d,
cutlass_block_fp8_supported,
cutlass_fp8_supported,
maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz,
- per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -96,11 +91,8 @@
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
- is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
-from vllm.utils.flashinfer import has_flashinfer_moe
-from vllm.utils.import_utils import has_deep_gemm
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -110,107 +102,6 @@
logger = init_logger(__name__)
-class Fp8MoeBackend(Enum):
- NONE = 0
- FLASHINFER_TRTLLM = 1
- FLASHINFER_CUTLASS = 2
- DEEPGEMM = 3
- MARLIN = 4
- TRITON = 5
- AITER = 6
-
-
-def get_fp8_moe_backend(
- block_quant: bool,
- moe_parallel_config: FusedMoEParallelConfig,
- with_lora_support: bool,
-) -> Fp8MoeBackend | None:
- """
- Select the primary FP8 MoE backend
- Note: Shape-specific fallbacks may still occur at runtime.
- """
- if current_platform.is_xpu():
- return None
- if with_lora_support:
- return Fp8MoeBackend.TRITON
- # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
- if (
- current_platform.is_cuda()
- and (
- current_platform.is_device_capability_family(100)
- or current_platform.is_device_capability(90)
- )
- and envs.VLLM_USE_FLASHINFER_MOE_FP8
- and has_flashinfer_moe()
- ):
- backend = get_flashinfer_moe_backend()
- if backend == FlashinferMoeBackend.TENSORRT_LLM:
- logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
- return Fp8MoeBackend.FLASHINFER_TRTLLM
- else:
- if block_quant and current_platform.is_device_capability_family(100):
- raise ValueError(
- "FlashInfer FP8 MoE throughput backend does not "
- "support block quantization on SM100. Please use "
- "VLLM_FLASHINFER_MOE_BACKEND=latency "
- "instead."
- )
- logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
- return Fp8MoeBackend.FLASHINFER_CUTLASS
-
- # weight-only path for older GPUs without native FP8
- use_marlin = (
- not current_platform.has_device_capability(89)
- or envs.VLLM_TEST_FORCE_FP8_MARLIN
- )
- if current_platform.is_rocm():
- use_marlin = False
- if use_marlin:
- logger.info_once("Using Marlin backend for FP8 MoE")
- return Fp8MoeBackend.MARLIN
-
- # Determine if we should use DeepGEMM with block-quantized weights:
- # - If explicitly set by user, respect their choice
- # - If not explicitly set (default), disable when TP size is >= 8
- moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
- if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
- moe_use_deep_gemm = False
- logger.info_once(
- "DeepGEMM MoE is disabled by default when TP size is >= 8. "
- "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
- scope="local",
- )
-
- # Determine if we should use DeepGEMM (top-level enable switch)
- # - If explicitly set by user, respect their choice
- # - If not platform supports DeepGEMM, disable it
- # This helps avoid warning messages on unsupported platforms.
- use_deep_gemm = envs.VLLM_USE_DEEP_GEMM
- if not is_deep_gemm_supported():
- use_deep_gemm = False
- logger.info_once(
- "DeepGEMM is disabled because the platform does not support it.",
- scope="local",
- )
-
- if use_deep_gemm and moe_use_deep_gemm and block_quant:
- if not has_deep_gemm():
- logger.warning_once(
- "DeepGEMM backend requested but not available.", scope="local"
- )
- elif is_deep_gemm_supported():
- logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local")
- return Fp8MoeBackend.DEEPGEMM
-
- if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE:
- logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local")
- return Fp8MoeBackend.AITER
-
- # default to Triton
- logger.info_once("Using Triton backend for FP8 MoE")
- return Fp8MoeBackend.TRITON
-
-
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@@ -348,7 +239,6 @@ def get_quant_method(
moe_quant_method = Fp8MoEMethod(self, layer)
else:
moe_quant_method = Fp8OnlineMoEMethod(self, layer)
- moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
@@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
super().__init__(layer.moe_config)
- self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.weight_scale_name = (
"weight_scale_inv" if self.block_quant else "weight_scale"
)
- self.fp8_backend = get_fp8_moe_backend(
- self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
+ self.fp8_backend = select_fp8_moe_backend(
+ block_quant=self.block_quant,
+ tp_size=layer.moe_parallel_config.tp_size,
+ with_lora_support=self.moe.is_lora_enabled,
)
- self.marlin_input_dtype = None
- self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
- elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
if self.block_quant and self.weight_block_size != [128, 128]:
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports block "
"size [128, 128]."
)
- if not self.block_quant:
- if layer.renormalize or layer.custom_routing_function is not None:
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend does custom routing "
- f"function or renormalization, but got {layer.renormalize} and "
- f"{layer.custom_routing_function}."
- )
- if layer.scoring_func != "sigmoid":
- raise NotImplementedError(
- "FlashInfer CUTLASS FP8 MoE backend only supports "
- f"'sigmoid' scoring function, but got {layer.scoring_func}."
- )
if layer.activation != "silu":
raise NotImplementedError(
"FlashInfer CUTLASS FP8 MoE backend only supports SiLU "
@@ -778,12 +652,17 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
dynamic_per_token = (
not self.block_quant and self.quant_config.activation_scheme != "static"
)
- if self.flashinfer_moe_backend is not None and dynamic_per_token:
+ if dynamic_per_token and self.fp8_backend in [
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ Fp8MoeBackend.FLASHINFER_CUTLASS,
+ ]:
raise NotImplementedError(
"FlashInfer FP8 MoE backend does not support dynamic per token "
"activation quantization."
)
+ self.kernel: mk.FusedMoEModularKernel | None = None
+
def create_weights(
self,
layer: Module,
@@ -907,148 +786,43 @@ def create_weights(
layer.w13_input_scale = None
layer.w2_input_scale = None
- def _convert_weights_to_kernel_format(
+ def _setup_kernel(
self,
layer: Module,
- w13_weight: torch.Tensor,
- w2_weight: torch.Tensor,
- w13_weight_scale: torch.Tensor,
- w2_weight_scale: torch.Tensor,
+ w13: torch.Tensor,
+ w2: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
w13_input_scale: torch.Tensor | None,
w2_input_scale: torch.Tensor | None,
) -> None:
- if self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
- assert self.block_quant
- w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block(
- wq=w13_weight,
- ws=w13_weight_scale,
- quant_block_shape=tuple(layer.weight_block_size),
- use_e8m0=is_deep_gemm_e8m0_used(),
- )
- w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block(
- wq=w2_weight,
- ws=w2_weight_scale,
- quant_block_shape=tuple(layer.weight_block_size),
- use_e8m0=is_deep_gemm_e8m0_used(),
- )
- elif self.fp8_backend == Fp8MoeBackend.AITER:
- w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
- w13_weight, w2_weight
- )
- elif self.fp8_backend == Fp8MoeBackend.MARLIN:
- (
- workspace,
- w13_weight,
- w2_weight,
- w13_weight_scale,
- w2_weight_scale,
- ) = prepare_moe_fp8_layer_for_marlin(
- layer,
- w13_weight,
- w2_weight,
- w13_weight_scale,
- w2_weight_scale,
- input_dtype=self.marlin_input_dtype,
- )
- layer.workspace = workspace
-
- elif self.fp8_backend in [
- Fp8MoeBackend.FLASHINFER_CUTLASS,
- Fp8MoeBackend.FLASHINFER_TRTLLM,
- ]:
- w13_weight = swap_w13_to_w31(w13_weight)
- if self.block_quant:
- w13_weight_scale = swap_w13_to_w31(w13_weight_scale)
- else:
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight)
- register_scales_for_trtllm_fp8_per_tensor_moe(
- layer=layer,
- w13_weight_scale=w13_weight,
- w13_input_scale=w13_input_scale,
- w2_weight_scale=w2_weight,
- w2_input_scale=w2_input_scale,
- )
-
- elif self.fp8_backend == Fp8MoeBackend.AITER:
- w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights(
- w13_weight, w2_weight
- )
+ # Shuffle weights to runtime format.
+ w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
+ fp8_backend=self.fp8_backend,
+ layer=layer,
+ w13=w13,
+ w2=w2,
+ w13_scale=w13_scale,
+ w2_scale=w2_scale,
+ w13_input_scale=w13_input_scale,
+ w2_input_scale=w2_input_scale,
+ )
# Replace parameters with updated versions. Note that this helper
# function ensures the replacement is compatible with RL weight reloads.
- replace_parameter(layer, "w13_weight", w13_weight)
- replace_parameter(layer, "w2_weight", w2_weight)
- replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale)
- replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale)
-
- def _setup_kernel(self, layer: Module) -> None:
- """Setup Modular Kernel for TP Case"""
- # NOTE(rob): this is a WIP refactor. We are first migrating
- # all of the kernels in the TP case to use mk. Once this is
- # done, then we will initialzie the TP case and DP/EP case
- # via the same code path (i.e. via maybe_init_modular_kernel).
- # NOTE(rob): in progress migrating all into this format.
-
- from vllm.model_executor.layers.fused_moe import (
- TritonOrDeepGemmExperts,
- )
- from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
- FlashInferExperts,
- )
- from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
- MarlinExperts,
- )
- from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
- )
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- AiterExperts,
- )
-
- # Flashinfer TRTLLM does not use the modular kernel abstraction.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return
+ replace_parameter(layer, "w13_weight", w13)
+ replace_parameter(layer, "w2_weight", w2)
+ replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
+ replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
+ # Setup modular kernel for TP case.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- assert self.moe_quant_config is not None
- self.use_inplace = True
-
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- self.kernel = mk.FusedMoEModularKernel(
- # TODO: make defer_input_quant an attr of the FlashInferExperts
- MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant),
- FlashInferExperts(
- out_dtype=layer.orig_dtype,
- quant_config=self.moe_quant_config,
- ep_rank=self.moe.ep_rank,
- ep_size=self.moe.ep_size,
- tp_rank=self.moe.tp_rank,
- tp_size=self.moe.tp_size,
- use_dp=(self.moe.dp_size > 1),
- use_deepseek_fp8_block_scale=self.block_quant,
- ),
- )
- self.use_inplace = False
-
- elif self.fp8_backend == Fp8MoeBackend.AITER:
- self.kernel = mk.FusedMoEModularKernel(
- # TODO: make defer_input_quant an attr of the AiterExperts
- MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
- AiterExperts(quant_config=self.moe_quant_config),
- )
- elif self.fp8_backend == Fp8MoeBackend.MARLIN:
- self.kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- MarlinExperts(quant_config=self.moe_quant_config),
- )
- else:
- self.kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonOrDeepGemmExperts(
- quant_config=self.moe_quant_config,
- allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
- ),
+ if self.moe_quant_config:
+ self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ layer=layer,
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ fp8_backend=self.fp8_backend,
)
def process_weights_after_loading(self, layer: Module) -> None:
@@ -1056,78 +830,58 @@ def process_weights_after_loading(self, layer: Module) -> None:
return
# Allow for accessing weights and scales in standard way.
- w13_weight = layer.w13_weight
- w2_weight = layer.w2_weight
- w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}")
- w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}")
+ w13 = layer.w13_weight
+ w2 = layer.w2_weight
+ w13_scale = getattr(layer, f"w13_{self.weight_scale_name}")
+ w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
w13_input_scale = layer.w13_input_scale
w2_input_scale = layer.w2_input_scale
# MI300x and MI325x use FNUZ format for FP8. Convert if needed.
if current_platform.is_fp8_fnuz():
- w13_weight, w13_weight_scale, w13_input_scale = (
- normalize_e4m3fn_to_e4m3fnuz(
- w13_weight, w13_weight_scale, w13_input_scale
- )
+ w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w13,
+ w13_scale,
+ w13_input_scale,
)
- w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
- w2_weight, w2_weight_scale, w2_input_scale
+ w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+ w2,
+ w2_scale,
+ w2_input_scale,
)
# Per tensor kernels require single activation scale. Use the max.
if self.quant_config.activation_scheme == "static":
assert not self.block_quant
assert w13_input_scale is not None and w2_input_scale is not None
- if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale):
- logger.warning_once(
- "Found input_scales that are not equal for "
- "fp8 MoE layer. Using the maximum across experts "
- "for each layer."
- )
- replace_parameter(layer, "w13_input_scale", w13_input_scale.max())
- replace_parameter(layer, "w2_input_scale", w2_input_scale.max())
+ w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
+ w13_input_scale, w2_input_scale
+ )
+ replace_parameter(layer, "w13_input_scale", w13_input_scale)
+ replace_parameter(layer, "w2_input_scale", w2_input_scale)
# Per tensor kernels require single weight scale for w13 per expert, but
# on disk there is a scale for w1 and w3. Use the max to requantize.
if not self.block_quant:
shard_size = layer.intermediate_size_per_partition
- max_w13_scales = w13_weight_scale.max(dim=1).values
- for expert_id in range(layer.local_num_experts):
- start = 0
- for shard_id in range(2):
- dq_weight = per_tensor_dequantize(
- w13_weight[expert_id][start : start + shard_size, :],
- w13_weight_scale[expert_id][shard_id],
- )
- w13_weight[expert_id][start : start + shard_size, :], _ = (
- ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
- )
- start += shard_size
- w13_weight_scale = max_w13_scales
+ w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
+ w13, w13_scale, shard_size, layer.local_num_experts
+ )
- # Shuffle weights into the runtime format.
- self._convert_weights_to_kernel_format(
- layer=layer,
- w13_weight=w13_weight,
- w2_weight=w2_weight,
- w13_weight_scale=w13_weight_scale,
- w2_weight_scale=w2_weight_scale,
- w13_input_scale=w13_input_scale,
- w2_input_scale=w2_input_scale,
+ # Shuffle weights to runtime format and setup kernel.
+ self._setup_kernel(
+ layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
- # Setup modular kernel for TP case.
- self._setup_kernel(layer)
-
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
- if (
- self.fp8_backend == Fp8MoeBackend.AITER
- or self.fp8_backend == Fp8MoeBackend.MARLIN
- or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- ):
+ if self.fp8_backend in [
+ Fp8MoeBackend.AITER,
+ Fp8MoeBackend.MARLIN,
+ Fp8MoeBackend.FLASHINFER_TRTLLM,
+ ]:
return None
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize(
@@ -1184,7 +938,7 @@ def select_gemm_impl(
)
elif self.moe.is_lora_enabled:
return TritonExperts(quant_config=self.moe_quant_config)
- elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+ elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# Select GEMM experts with block-scale when weights are block-quantized
experts = select_cutlass_fp8_gemm_impl(
self.moe,
@@ -1193,17 +947,23 @@ def select_gemm_impl(
)
logger.debug_once("Using %s", experts.__class__.__name__)
return experts
- else:
+ elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM:
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__,
self.weight_block_size,
False,
)
- return TritonOrDeepGemmExperts(
- quant_config=self.moe_quant_config,
- allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM),
+ return TritonOrDeepGemmExperts(self.moe_quant_config)
+ else:
+ assert self.fp8_backend == Fp8MoeBackend.TRITON
+ logger.debug(
+ "TritonExperts(%s): block_size=%s, per_act_token=%s",
+ self.__class__.__name__,
+ self.weight_block_size,
+ False,
)
+ return TritonExperts(self.moe_quant_config)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1212,42 +972,13 @@ def get_fused_moe_quant_config(
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
- # MARLIN uses mixed precision W8A16 config.
- if self.fp8_backend == Fp8MoeBackend.MARLIN:
- return fp8_w8a16_moe_quant_config(
- w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"),
- w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"),
- block_shape=self.weight_block_size,
- )
-
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
a2_scale = layer.w2_input_scale
- # Flashinfer CUTLASS per-tensor uses single dq scale
- # (alpha = w_scale * a_scale) and inverse a2 scale.
- if (
- self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS
- and not self.block_quant
- ):
- g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
- w1_scale,
- a1_scale,
- w2_scale,
- a2_scale,
- )
- return fp8_w8a8_moe_quant_config(
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=(1.0 / a2_scale),
- g1_alphas=g1_alphas,
- g2_alphas=g2_alphas,
- )
-
- # All other backends use normal config.
- return fp8_w8a8_moe_quant_config(
+ return make_fp8_moe_quant_config(
+ fp8_backend=self.fp8_backend,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
@@ -1269,7 +1000,7 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
# TODO(rob): convert this to MK.
if layer.enable_eplb:
raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
@@ -1308,10 +1039,7 @@ def apply(
routed_scaling=layer.routed_scaling_factor,
)
else:
- assert (
- not layer.renormalize and layer.custom_routing_function is not None
- )
- result = apply_flashinfer_per_tensor_scale_fp8(
+ result = apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@@ -1327,6 +1055,8 @@ def apply(
hidden_states=x,
router_logits=router_logits,
)
+
+ assert self.kernel is not None
result = self.kernel(
x,
layer.w13_weight,
@@ -1358,7 +1088,6 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module):
assert not quant_config.is_checkpoint_fp8_serialized
assert quant_config.activation_scheme == "dynamic"
assert quant_config.weight_block_size is None
- assert self.flashinfer_moe_backend is None
def create_weights(
self,
@@ -1447,6 +1176,8 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs):
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
layer.w13_input_scale = None
layer.w2_input_scale = None
@@ -1457,33 +1188,30 @@ def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
- w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
- w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
+ w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
+ w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
+ w13_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
for expert in range(layer.local_num_experts):
- w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
- ops.scaled_fp8_quant(layer.w13_weight[expert, :, :])
+ w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
+ layer.w13_weight[expert, :, :]
)
- w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
- ops.scaled_fp8_quant(layer.w2_weight[expert, :, :])
+ w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant(
+ layer.w2_weight[expert, :, :]
)
- replace_parameter(layer, "w13_weight", w13_weight)
- replace_parameter(layer, "w2_weight", w2_weight)
- # Shuffle weights into the runtime format.
- self._convert_weights_to_kernel_format(
- layer=layer,
- w13_weight=w13_weight,
- w2_weight=w2_weight,
- w13_weight_scale=layer.w13_weight_scale,
- w2_weight_scale=layer.w2_weight_scale,
- w13_input_scale=None,
- w2_input_scale=None,
+ # Shuffle weights to runtime format and setup kernel.
+ self._setup_kernel(
+ layer,
+ w13,
+ w2,
+ w13_scale,
+ w2_scale,
+ layer.w13_input_scale,
+ layer.w2_input_scale,
)
- # Setup modular kernel for TP case.
- self._setup_kernel(layer)
-
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index b6752d7f9913..bd7a90a80af1 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -15,7 +15,6 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
- fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
@@ -24,6 +23,13 @@
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
)
+from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
+ Fp8MoeBackend,
+ convert_to_fp8_moe_kernel_format,
+ make_fp8_moe_kernel,
+ make_fp8_moe_quant_config,
+ select_fp8_moe_backend,
+)
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
@@ -45,19 +51,16 @@
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
- apply_flashinfer_per_tensor_scale_fp8,
+ apply_fi_trtllm_fp8_per_tensor_moe,
build_flashinfer_fp8_cutlass_moe_prepare_finalize,
- flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
is_flashinfer_supporting_global_sf,
- make_fp8_moe_alpha_scales_for_fi,
- register_scales_for_trtllm_fp8_per_tensor_moe,
- rotate_flashinfer_fp8_moe_weights,
select_cutlass_fp8_gemm_impl,
- swap_w13_to_w31,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
+ process_fp8_input_tensor_strategy_moe,
+ process_fp8_weight_tensor_strategy_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
get_marlin_input_dtype,
@@ -85,13 +88,12 @@
ModelWeightParameter,
PerTensorScaleParameter,
)
+from vllm.model_executor.utils import replace_parameter
from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import (
flashinfer_scaled_fp4_mm,
has_flashinfer,
- has_flashinfer_moe,
)
-from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -721,38 +723,23 @@ def __init__(
layer: FusedMoE,
) -> None:
super().__init__(layer.moe_config)
- self.layer = layer
self.quant_config = quant_config
- from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- cutlass_fp8_supported,
+ assert self.quant_config.is_checkpoint_fp8_serialized
+ self.fp8_backend = select_fp8_moe_backend(
+ block_quant=False,
+ tp_size=layer.moe_parallel_config.tp_size,
+ with_lora_support=self.moe.is_lora_enabled,
)
-
- self.cutlass_fp8_supported = cutlass_fp8_supported()
- self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
- if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
- self.flashinfer_moe_backend = get_flashinfer_moe_backend()
- if (
- self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
- and not self.moe.is_act_and_mul
- ):
- logger.info_once(
- "Non-gated MoE is not supported for min-latency mode,"
- "falling back to high-throughput mode"
- )
- self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
-
- logger.info_once(
- f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
- )
+ self.kernel: mk.FusedMoEModularKernel | None = None
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet.
- if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
return None
- elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+ elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
# TP case: avoid convert to ModularKernelMethod - to be refactored.
if self.moe.dp_size == 1:
return None
@@ -787,6 +774,9 @@ def create_weights(
params_dtype: torch.dtype,
**extra_weight_attrs,
):
+ layer.orig_dtype = params_dtype
+ layer.num_experts = num_experts
+
# Use FP8 dtype if checkpoint is serialized
weight_dtype = (
torch.float8_e4m3fn
@@ -826,217 +816,121 @@ def create_weights(
)
layer.register_parameter("w2_weight", w2_weight)
- if self.quant_config.is_checkpoint_fp8_serialized:
- # WEIGHT SCALES - Per-tensor scaling for ModelOpts
- # For gated MoE, allocate 2 scales for w1 and w3 respectively.
- # They will be combined to a single scale after weight loading.
- # For non-gated MoE, allocate 1 scale for w13.
- if self.moe.is_act_and_mul:
- w13_weight_scale_shape = (num_experts, 2)
- else:
- w13_weight_scale_shape = (num_experts, 1)
- w13_weight_scale = PerTensorScaleParameter(
- data=torch.full(
- w13_weight_scale_shape,
- 1.0,
- dtype=torch.float32,
- ),
- weight_loader=weight_loader,
- )
- w2_weight_scale = PerTensorScaleParameter(
- data=torch.full((num_experts,), 1.0, dtype=torch.float32),
- weight_loader=weight_loader,
- )
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
- layer.register_parameter("w2_weight_scale", w2_weight_scale)
-
- # Set weight loader attributes for scales
- extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
- )
-
- # INPUT SCALES - Per-tensor scaling for ModelOpt
- w13_input_scale = PerTensorScaleParameter(
- data=torch.full((num_experts,), 1.0, dtype=torch.float32),
- weight_loader=weight_loader,
- )
- w2_input_scale = PerTensorScaleParameter(
- data=torch.full((num_experts,), 1.0, dtype=torch.float32),
- weight_loader=weight_loader,
- )
- layer.register_parameter("w13_input_scale", w13_input_scale)
- layer.register_parameter("w2_input_scale", w2_input_scale)
-
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- """Process FP8 MoE weights after loading from serialized checkpoint.
- Only supports pre-quantized checkpoints with FP8 weights and scales.
- """
-
- if self.flashinfer_moe_backend is not None:
- self._maybe_pad_intermediate_for_flashinfer(layer)
-
- layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
- layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+ # WEIGHT SCALES - Per-tensor scaling for ModelOpts
+ # For gated MoE, allocate 2 scales for w1 and w3 respectively.
+ # They will be combined to a single scale after weight loading.
+ # For non-gated MoE, allocate 1 scale for w13.
+ w13_weight_scale = PerTensorScaleParameter(
+ data=torch.full(
+ (num_experts, 2 if self.moe.is_act_and_mul else 1),
+ 1.0,
+ dtype=torch.float32,
+ ),
+ weight_loader=weight_loader,
+ )
+ w2_weight_scale = PerTensorScaleParameter(
+ data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
- from vllm._custom_ops import scaled_fp8_quant
- from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
- per_tensor_dequantize,
+ # INPUT SCALES - Per-tensor scaling for ModelOpt
+ w13_input_scale = PerTensorScaleParameter(
+ data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+ weight_loader=weight_loader,
+ )
+ w2_input_scale = PerTensorScaleParameter(
+ data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+ weight_loader=weight_loader,
)
+ layer.register_parameter("w13_input_scale", w13_input_scale)
+ layer.register_parameter("w2_input_scale", w2_input_scale)
- # Handle scale parameters
- if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
- # Fp8 moe kernel needs single weight scale for w13 per expert.
- # We take the max of the w1 and w3 scales
- # then dequant and requant each expert.
- if (
- layer.w13_weight_scale.dim() == 2
- and layer.w13_weight_scale.shape[1] == 2
- ):
- assert self.moe.is_act_and_mul, (
- "w13_weight_scale should have 2 elements per expert "
- "only for gated MoE"
- )
- # Get the maximum scale across w1 and w3 for each expert
- max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-
- # Requantize each expert's weights using the combined scale
- # w13_weight (num_experts, 2 * intermediate_size, hidden_size)
- # where the first intermediate_size rows are w1, the next are w3
- intermediate_size = layer.w13_weight.shape[1] // 2
- for expert_id in range(layer.w13_weight.shape[0]):
- start = 0
- for shard_id in range(2): # w1 and w3
- # Dequantize using the original scale for this shard
- dq_weight = per_tensor_dequantize(
- layer.w13_weight[expert_id][
- start : start + intermediate_size, :
- ],
- layer.w13_weight_scale[expert_id][shard_id],
- )
- # Requantize using the combined max scale
-
- (
- layer.w13_weight[expert_id][
- start : start + intermediate_size, :
- ],
- _,
- ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-
- start += intermediate_size
-
- # Update the scale parameter to be per-expert
- layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
- else:
- layer.w13_weight_scale = Parameter(
- layer.w13_weight_scale.data, requires_grad=False
- )
+ def _setup_kernel(
+ self,
+ layer: torch.nn.Module,
+ w13: torch.Tensor,
+ w2: torch.Tensor,
+ w13_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ w13_input_scale: torch.Tensor,
+ w2_input_scale: torch.Tensor,
+ ):
+ w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format(
+ fp8_backend=self.fp8_backend,
+ layer=layer,
+ w13=w13,
+ w2=w2,
+ w13_scale=w13_scale,
+ w2_scale=w2_scale,
+ w13_input_scale=w13_input_scale,
+ w2_input_scale=w2_input_scale,
+ )
- if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
- layer.w2_weight_scale = Parameter(
- layer.w2_weight_scale.data, requires_grad=False
- )
- # Input scales must be equal for each expert in fp8 MoE layers.
- if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
- layer.w13_input_scale = Parameter(
- layer.w13_input_scale.max(), requires_grad=False
- )
- if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
- layer.w2_input_scale = Parameter(
- layer.w2_input_scale.max(), requires_grad=False
+ # Replace parameters with updated versions. Note that this helper
+ # function ensures the replacement is compatible with RL weight reloads.
+ replace_parameter(layer, "w13_weight", w13)
+ replace_parameter(layer, "w2_weight", w2)
+ replace_parameter(layer, "w13_weight_scale", w13_scale)
+ replace_parameter(layer, "w2_weight_scale", w2_scale)
+
+ # Setup modular kernel for TP case.
+ self.moe_quant_config = self.get_fused_moe_quant_config(layer)
+ if self.moe_quant_config:
+ self.kernel, self.use_inplace = make_fp8_moe_kernel(
+ layer=layer,
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ fp8_backend=self.fp8_backend,
)
- if self.flashinfer_moe_backend is not None:
- if self.moe.is_act_and_mul:
- layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
-
- # NOTE: this adds some attributes used by the trtllm kernel,
- # which does not conform to the modular kernels abstraction (yet).
- if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
- rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight)
- register_scales_for_trtllm_fp8_per_tensor_moe(
- layer=layer,
- w13_weight_scale=layer.w13_weight_scale,
- w13_input_scale=layer.w13_input_scale,
- w2_weight_scale=layer.w2_weight_scale,
- w2_input_scale=layer.w2_input_scale,
- )
-
- def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None:
- """Pad intermediate size so FlashInfer kernels' alignment constraints hold.
-
- Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
- used for GEMM to be divisible by a small alignment value. When this is
- not satisfied (e.g. with certain tensor-parallel sizes), we pad the
- gate/up and down projection weights along the intermediate dim.
- """
- if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"):
- return
-
- # Current local intermediate size (per partition) is the K dimension of
- # the down projection.
- num_experts, hidden_size, intermediate = layer.w2_weight.shape
-
- min_alignment = 16
- padded_intermediate = round_up(intermediate, min_alignment)
-
- if padded_intermediate == intermediate:
- return
-
- logger.info(
- "Padding intermediate size from %d to %d for up/down projection weights.",
- intermediate,
- padded_intermediate,
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ w13 = layer.w13_weight
+ w2 = layer.w2_weight
+ w13_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
+ w13_input_scale = layer.w13_input_scale
+ w2_input_scale = layer.w2_input_scale
+
+ # Per tensor kernels require single activation scale. Use the max.
+ w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe(
+ w13_input_scale, w2_input_scale
+ )
+ replace_parameter(layer, "w13_input_scale", w13_input_scale)
+ replace_parameter(layer, "w2_input_scale", w2_input_scale)
+
+ # Per tensor kernels require single weight scale for w13 per expert, but
+ # on disk there is a scale for w1 and w3. Use the max to requantize.
+ shard_size = layer.intermediate_size_per_partition
+ w13, w13_scale = process_fp8_weight_tensor_strategy_moe(
+ w13,
+ w13_scale,
+ shard_size,
+ num_experts=layer.w13_weight.shape[0],
+ is_act_and_mul=self.moe.is_act_and_mul,
)
- up_mult = 2 if self.moe.is_act_and_mul else 1
- padded_gate_up_dim = up_mult * padded_intermediate
-
- # Pad w13 and w12 along its intermediate dimension.
- w13 = layer.w13_weight.data
- padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
- padded_w13[:, : w13.shape[1], :] = w13
- layer.w13_weight.data = padded_w13
-
- w2 = layer.w2_weight.data
- padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
- padded_w2[:, :, :intermediate] = w2
- layer.w2_weight.data = padded_w2
-
- if hasattr(layer, "intermediate_size_per_partition"):
- layer.intermediate_size_per_partition = padded_intermediate
+ # Shuffle weights to runtime format and setup kernel.
+ self._setup_kernel(
+ layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
+ )
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
- if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
- # TRTLLM does not use modular kernels
- return None
-
- elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
- g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
- layer.w13_weight_scale,
- layer.w13_input_scale,
- layer.w2_weight_scale,
- layer.w2_input_scale,
- )
- return fp8_w8a8_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale,
- a1_gscale=(1.0 / layer.w13_input_scale),
- a2_gscale=(1.0 / layer.w2_input_scale),
- g1_alphas=g1_alphas,
- g2_alphas=g2_alphas,
- )
- else:
- assert self.flashinfer_moe_backend is None
- return fp8_w8a8_moe_quant_config(
- w1_scale=layer.w13_weight_scale,
- w2_scale=layer.w2_weight_scale,
- a1_scale=layer.w13_input_scale,
- a2_scale=layer.w2_input_scale,
- )
+ w1_scale = layer.w13_weight_scale
+ w2_scale = layer.w2_weight_scale
+ a1_scale = layer.w13_input_scale
+ a2_scale = layer.w2_input_scale
+
+ return make_fp8_moe_quant_config(
+ fp8_backend=self.fp8_backend,
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ )
def apply(
self,
@@ -1044,17 +938,18 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
if layer.enable_eplb:
raise NotImplementedError(
- "EPLB not supported for `ModelOptFp8MoEMethod` yet."
+ "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
)
+ # TODO(rob): this validation should happen at kernel selection
+ # time in the oracle rather than here.
assert layer.activation == "silu", (
f"Expected 'silu' activation but got {layer.activation}"
)
-
assert not layer.renormalize
- return apply_flashinfer_per_tensor_scale_fp8(
+ return apply_fi_trtllm_fp8_per_tensor_moe(
layer=layer,
hidden_states=x,
router_logits=router_logits,
@@ -1066,46 +961,34 @@ def apply(
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
- # Expert selection
topk_weights, topk_ids = layer.select_experts(
hidden_states=x,
router_logits=router_logits,
)
- if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
+ # TODO(rob): this validation should happen at kernel selection
+ # time in the oracle rather than here.
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
assert layer.activation in ("silu", "relu2_no_mul"), (
"Expected activation to be in ('silu', 'relu2_no_mul'),"
f"but got {layer.activation}"
)
- return flashinfer_cutlass_moe_fp8(
- x,
- layer,
- topk_weights,
- topk_ids,
- inplace=False,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- )
- else:
- from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
- assert self.moe_quant_config is not None
+ assert self.kernel is not None
+ result = self.kernel(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
+ inplace=self.use_inplace,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ )
- return fused_experts(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=layer.activation,
- quant_config=self.moe_quant_config,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- )
+ return result
ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 15d37f7d366b..4ab618dc44ef 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -22,7 +22,7 @@
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
- prepare_moe_fp8_layer_for_marlin,
+ prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
@@ -315,8 +315,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
elif self.use_marlin:
- (workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = (
- prepare_moe_fp8_layer_for_marlin(
+ w13_weight, w2_weight, w13_weight_scale, w2_weight_scale = (
+ prepare_fp8_moe_layer_for_marlin(
layer,
layer.w13_weight,
layer.w2_weight,
@@ -324,7 +324,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w2_weight_scale,
)
)
- layer.workspace = workspace
# TODO(rob): once we apply refactor to Quark, switch to using
# replace_parameter for compatibility with reloading in RL.
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index b73c44b3130d..a9b30b780587 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -18,6 +18,7 @@
create_flashinfer_prepare_finalize,
)
from vllm.platforms import current_platform
+from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
@@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
)
-def rotate_flashinfer_fp8_moe_weights(
+def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor
):
+    """Shuffle weights into the FI TRT-LLM format."""
from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a
epilogue_tile_m = 128
@@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights(
def register_scales_for_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
- w13_weight_scale: torch.Tensor,
+ w13_scale: torch.Tensor,
w13_input_scale: torch.Tensor,
- w2_weight_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
w2_input_scale: torch.Tensor,
) -> None:
"""Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
- w13_scale=w13_weight_scale,
+ w13_scale=w13_scale,
w13_input_scale=w13_input_scale,
- w2_scale=w2_weight_scale,
+ w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
)
layer.w2_input_scale_inv = 1.0 / w2_input_scale
@@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe(
layer.output2_scales_scalar = g2_alphas
-def apply_flashinfer_per_tensor_scale_fp8(
+def apply_fi_trtllm_fp8_per_tensor_moe(
layer: torch.nn.Module,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
@@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8(
import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
from vllm.model_executor.models.llama4 import Llama4MoE
+ # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
assert (
hasattr(layer, "output1_scales_scalar")
and hasattr(layer, "output1_scales_gate_scalar")
and hasattr(layer, "output2_scales_scalar")
)
- assert layer.custom_routing_function == Llama4MoE.custom_routing_function, (
- "FusedMoE flashinfer kernels are only supported for Llama4"
+ # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
+ assert (
+ hasattr(layer, "output1_scales_scalar")
+ and hasattr(layer, "output1_scales_gate_scalar")
+ and hasattr(layer, "output2_scales_scalar")
)
- return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8(
+
+ is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function
+ assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4"
+ return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
routing_logits=router_logits,
routing_bias=routing_bias,
hidden_states=hidden_states,
@@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl(
)
-def flashinfer_cutlass_moe_fp8(
- hidden_states: torch.Tensor,
- layer: torch.nn.Module,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- inplace: bool = False,
- activation: str = "silu",
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- use_deepseek_fp8_block_scale: bool = False,
- moe: FusedMoEConfig | None = None,
-) -> torch.Tensor:
- quant_config = layer.quant_method.get_fused_moe_quant_config(layer)
- assert quant_config is not None
-
- # Construct modular kernel with block-scale support when requested.
- fused_experts = mk.FusedMoEModularKernel(
- build_flashinfer_fp8_cutlass_moe_prepare_finalize(
- moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale
- ),
- select_cutlass_fp8_gemm_impl(
- moe=moe,
- quant_config=quant_config,
- out_dtype=hidden_states.dtype,
- use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
- ),
- moe_parallel_config=layer.moe_parallel_config,
- )
-
- return fused_experts(
- hidden_states,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights,
- topk_ids,
- inplace=inplace,
- activation=activation,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
-
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) ->
FlashinferMoeBackend.TENSORRT_LLM,
)
return backend in backends_supporting_global_sf
+
+
def align_fp8_moe_weights_for_fi(
    w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Pad intermediate size so FlashInfer kernels' alignment constraints hold.

    Some FlashInfer FP8 MoE kernels require the (gated) intermediate size
    used for GEMM to be divisible by a small alignment value. When this is
    not satisfied (e.g. with certain tensor-parallel sizes), we pad the
    gate/up and down projection weights along the intermediate dim.

    Returns the (possibly padded) w13 and w2 plus the new intermediate size.
    The caller is expected to update the layer's intermediate size with the
    returned value, so consumers split gate/up at the PADDED boundary.
    """

    # Current local intermediate size (per partition) is the K dimension of
    # the down projection.
    num_experts, hidden_size, intermediate = w2.shape

    min_alignment = 16
    padded_intermediate = round_up(intermediate, min_alignment)

    if padded_intermediate == intermediate:
        # Already aligned; return the weights untouched.
        return w13, w2, intermediate

    logger.info_once(
        "Padding intermediate size from %d to %d for up/down projection weights.",
        intermediate,
        padded_intermediate,
        scope="local",
    )

    up_mult = 2 if is_act_and_mul else 1
    padded_gate_up_dim = up_mult * padded_intermediate

    # Pad w13 along its intermediate dimension. For the gated case, w13
    # stores the gate and up projections stacked along dim 1, each of size
    # `intermediate`. Since the layer's intermediate size is updated to the
    # padded value, downstream consumers split gate/up at
    # `padded_intermediate` — so each chunk must be padded INDEPENDENTLY; a
    # contiguous copy would leave part of `up` inside the gate chunk.
    padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size))
    if is_act_and_mul:
        padded_w13[:, :intermediate, :] = w13[:, :intermediate, :]
        padded_w13[:, padded_intermediate : padded_intermediate + intermediate, :] = (
            w13[:, intermediate:, :]
        )
    else:
        padded_w13[:, :intermediate, :] = w13

    # Pad w2 along its (trailing) intermediate dimension.
    padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate))
    padded_w2[:, :, :intermediate] = w2

    return padded_w13, padded_w2, padded_intermediate
+
+
def prepare_fp8_moe_layer_for_fi(
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_input_scale: torch.Tensor | None,
    w2_scale: torch.Tensor,
    w2_input_scale: torch.Tensor | None,
    is_trtllm: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert Fp8 MoE weights to flashinfer kernel format

    Note that for trtllm we update the model state dict
    with the scale format needed for these kernels.

    Note that for per-tensor, we update the layer's
    intermediate size if the weights needed padding.
    """

    assert hasattr(layer.moe_config, "is_act_and_mul")
    is_gated = layer.moe_config.is_act_and_mul
    is_block_quant = getattr(layer, "weight_block_size", None) is not None

    if not is_block_quant:
        # Some FI MoE kernels require internal alignment of 16 for the
        # gate-up proj; pad the weights (and record the new size) if needed.
        w13, w2, aligned_intermediate = align_fp8_moe_weights_for_fi(
            w13, w2, is_gated
        )
        layer.intermediate_size_per_partition = aligned_intermediate

    if is_gated:
        # FI kernels require W31 layout rather than W13.
        w13 = swap_w13_to_w31(w13)
        if is_block_quant:
            w13_scale = swap_w13_to_w31(w13_scale)

    if is_trtllm and not is_block_quant:
        # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
        # and registration of alpha scales. Note that we do not register
        # as nn.Parameters since they are not needed for weight-reloading.
        assert w13_input_scale is not None and w2_input_scale is not None
        rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2)
        register_scales_for_trtllm_fp8_per_tensor_moe(
            layer,
            w13_scale=w13_scale,
            w13_input_scale=w13_input_scale,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
        )

    return w13, w2, w13_scale
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index 1c9e36e02248..9d74becd588d 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -21,6 +21,8 @@
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
+ all_close_1d,
+ per_tensor_dequantize,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
@@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block(
return wq, dg_ws
def prepare_fp8_moe_layer_for_deepgemm(
    w13: torch.Tensor,
    w2: torch.Tensor,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    block_shape: tuple[int, ...],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert block-quantized FP8 MoE weights to DeepGEMM's layout.

    Post-processes both the gate/up (w13) and down (w2) projection
    weights together with their block scales, and returns the
    transformed (w13, w2, w13_scale, w2_scale).

    `block_shape` is the quantization block shape shared by both
    projections (presumably a (block_n, block_k) pair — the original
    `tuple[int]` 1-tuple annotation looked wrong; TODO confirm at callers).
    """
    # The E8M0 mode is a global property; query it once for both calls.
    use_e8m0 = is_deep_gemm_e8m0_used()

    w13, w13_scale = deepgemm_post_process_fp8_weight_block(
        wq=w13,
        ws=w13_scale,
        quant_block_shape=block_shape,
        use_e8m0=use_e8m0,
    )
    w2, w2_scale = deepgemm_post_process_fp8_weight_block(
        wq=w2,
        ws=w2_scale,
        quant_block_shape=block_shape,
        use_e8m0=use_e8m0,
    )

    return w13, w2, w13_scale, w2_scale
+
+
def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
"""Pad the weight tensor. This is an optimization on ROCm platform, which
can benefit from tensors located far enough from one another in memory"""
@@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
replace_parameter(layer, scale_attr, dg_weight_scale)
-def expert_weight_is_col_major(x: torch.Tensor) -> bool:
- assert x.dim() == 3
- b, m, n = x.shape
- return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m
def process_fp8_weight_tensor_strategy_moe(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    shard_size: int,
    num_experts: int,
    is_act_and_mul: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Process moe weights for tensor-wise quantization strategy.

    Collapses the per-shard weight scales (shape [num_experts, num_shards])
    to a single scale per expert, requantizing the w13 shards in place when
    the gate/up shards had distinct scales.

    Returns the (possibly requantized, in-place) weight and a 1-D tensor of
    per-expert scales.
    """
    # One scale per expert: the max over that expert's per-shard scales.
    max_scales = weight_scales.max(dim=1).values

    # For w1 case (i.e. not w13): just collapse the last dim since
    # there is already just one scale per expert in this case. Note this
    # must stay per-expert — a global `.max()` over all experts would
    # collapse everything to a single scalar.
    if not is_act_and_mul:
        assert weight_scales.shape[1] == 1
        return weight, max_scales

    # For w13 case (common): require single scale for w13 per expert, but
    # on disk there is a scale for w1 and w3. Use the max to requantize.
    for expert_id in range(num_experts):
        start = 0
        for shard_id in range(2):
            dq_weight = per_tensor_dequantize(
                weight[expert_id][start : start + shard_size, :],
                weight_scales[expert_id][shard_id],
            )
            weight[expert_id][start : start + shard_size, :], _ = ops.scaled_fp8_quant(
                dq_weight, max_scales[expert_id]
            )
            start += shard_size
    return weight, max_scales
+
+
def process_fp8_input_tensor_strategy_moe(
    w13_input_scale: torch.Tensor,
    w2_input_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Collapse per-expert MoE input scales to a single scale each.

    Logs once if the per-expert scales are not (approximately) equal,
    since taking the max then alters effective quantization for some
    experts. Returns the (w13, w2) scalar max scales.
    """

    scales_uniform = all_close_1d(w13_input_scale) and all_close_1d(w2_input_scale)
    if not scales_uniform:
        logger.info_once(
            "Found input_scales that are not equal for "
            "fp8 MoE layer. Using the maximum across experts "
            "for each layer."
        )

    return w13_input_scale.max(), w2_input_scale.max()
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 0c2f984f9f0f..0e21c81f70f8 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8:
return _quant_fp8_method
-def get_marlin_input_dtype(prefix):
+def get_marlin_input_dtype(prefix: str | None = None):
if envs.VLLM_MARLIN_INPUT_DTYPE is None:
return
elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8":
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
index a8d9db224860..91b93c76cb32 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
@@ -8,6 +8,7 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
+ get_marlin_input_dtype,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
@@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin(
replace_parameter(layer, "bias", bias)
-def prepare_moe_fp8_layer_for_marlin(
+def prepare_fp8_moe_layer_for_marlin(
layer: torch.nn.Module,
w13_weight: torch.Tensor,
w2_weight: torch.Tensor,
w13_weight_scale: torch.Tensor,
w2_weight_scale: torch.Tensor,
- input_dtype: torch.dtype | None = None,
-) -> tuple[
- torch.Tensor, # workspace
- torch.Tensor, # w13_weight
- torch.Tensor, # w2_weight
- torch.Tensor, # w13_weight_scale
- torch.Tensor, # w2_weight_scale
-]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Shuffle weights and scales into marlin format.
+
+ Note that this function has the side effect of adding a `workspace`
+ attribute to the layer. This `workspace` does not need to be
+ registered as a Parameter as it is not used during weight reloading.
+ """
+
logger.warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads."
)
+ input_dtype = get_marlin_input_dtype()
if input_dtype is not None and input_dtype.itemsize == 1:
raise NotImplementedError("Marlin W8A8 is not supported.")
@@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin(
# WORKSPACE
device = layer.w13_weight.device
- workspace = marlin_make_workspace_new(device, 4)
+ # NOTE(rob): we do not need to register the workspace as a param
+ # because it is not used as part of the weight reloading process.
+ layer.workspace = marlin_make_workspace_new(device, 4)
perm = torch.empty(0, dtype=torch.int, device=device)
# WEIGHT
@@ -310,13 +315,7 @@ def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor:
w13_weight_scale = permute_scales(w13_weight_scale, "w13")
w2_weight_scale = permute_scales(w2_weight_scale, "w2")
- return (
- workspace,
- w13_weight,
- w2_weight,
- w13_weight_scale,
- w2_weight_scale,
- )
+ return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale
def pack_fp8_to_int32(