diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9a26dc611515..1f4e683596e2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -1422,3 +1422,10 @@ steps: num_gpus: 2 commands: - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor/config-b200.txt + +- label: MoE Refactor Integration Test (B200 DP - TEMPORARY) # optional + gpu: b200 + optional: true + num_gpus: 2 + commands: + - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt diff --git a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py index 626b3b160044..9c6edee7b264 100644 --- a/benchmarks/kernels/benchmark_cutlass_moe_fp8.py +++ b/benchmarks/kernels/benchmark_cutlass_moe_fp8.py @@ -6,13 +6,16 @@ but use different quantization strategies and backends. """ -import nvtx import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.worker.workspace import init_workspace_manager @@ -59,6 +62,7 @@ def bench_run( per_out_ch: bool, mkn: tuple[int, int, int], ): + init_workspace_manager(torch.cuda.current_device()) (m, k, n) = mkn dtype = torch.half @@ -121,85 +125,6 @@ def bench_run( # Force per-tensor quantization for all cases per_act_token = False - # Create stride tensors for CUTLASS - 
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device) - ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device) - c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device) - c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device) - - def run_triton_moe( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - a2_scale: torch.Tensor, - num_repeats: int, - ): - quant_config = fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_out_ch, - ) - - for _ in range(num_repeats): - fused_experts( - a, - w1, - w2, - topk_weights, - topk_ids, - quant_config=quant_config, - ) - - def run_cutlass_moe_fp8( - a: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - a1_scale: torch.Tensor, - a2_scale: torch.Tensor, - num_repeats: int, - ): - quant_config = fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_act_token_quant=per_act_token, - per_out_ch_quant=per_out_ch, - ) - - for _ in range(num_repeats): - with nvtx.annotate("cutlass_moe_fp8", color="blue"): - cutlass_moe_fp8( - a=a, - w1_q=w1, - w2_q=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - ab_strides1=ab_strides1, - ab_strides2=ab_strides2, - c_strides1=c_strides1, - c_strides2=c_strides2, - quant_config=quant_config, - activation="silu", - global_num_experts=num_experts, - ) - # Pre-create quantization config to avoid creating it inside CUDA graph quant_config = fp8_w8a8_moe_quant_config( 
w1_scale=w1_scale, @@ -210,23 +135,30 @@ def run_cutlass_moe_fp8( per_out_ch_quant=per_out_ch, ) + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + out_dtype=a.dtype, + e=num_experts, + n=n, + k=k, + quant_config=quant_config, + device=w1.device, + ), + ) + # Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly) cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): # Capture 10 invocations like benchmark_moe.py for _ in range(10): - cutlass_moe_fp8( - a=a, - w1_q=w1_fp8q_cutlass, - w2_q=w2_fp8q_cutlass, - topk_weights=topk_weights, - topk_ids=topk_ids, - ab_strides1=ab_strides1, - ab_strides2=ab_strides2, - c_strides1=c_strides1, - c_strides2=c_strides2, - quant_config=quant_config, + fn( + a, + w1_fp8q_cutlass, + w2_fp8q_cutlass, + topk_weights, + topk_ids, activation="silu", global_num_experts=num_experts, ) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 4390be8770c1..b30a1263878b 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -5,14 +5,18 @@ import torch.utils.benchmark as benchmark from benchmark_shapes import WEIGHT_SHAPES_MOE +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_topk, ) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.utils.argparse_utils import 
FlexibleArgumentParser from vllm.v1.worker.workspace import init_workspace_manager @@ -45,6 +49,7 @@ def bench_run( per_out_ch: bool, mkn: tuple[int, int, int], ): + init_workspace_manager(torch.cuda.current_device()) label = "Quant Matmul" sub_label = ( @@ -82,11 +87,6 @@ def bench_run( a, score, topk, renormalize=False ) - ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - def run_triton_moe( a: torch.Tensor, w1: torch.Tensor, @@ -120,10 +120,6 @@ def run_cutlass_moe( w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, per_act_token: bool, @@ -135,31 +131,29 @@ def run_cutlass_moe( per_act_token_quant=per_act_token, ) - for _ in range(num_repeats): - cutlass_moe_fp8( - a, - w1, - w2, - topk_weights, - topk_ids, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + out_dtype=a.dtype, + # NOTE(rob): w2 is shaped as [E, hidden, intermediate] + e=w2.shape[0], + n=w2.shape[2], + k=w2.shape[1], quant_config=quant_config, - ) + device=w1.device, + ), + ) + + for _ in range(num_repeats): + fn(a, w1, w2, topk_weights, topk_ids) def run_cutlass_from_graph( a: torch.Tensor, a_scale: torch.Tensor, - w1_q: torch.Tensor, - w2_q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, ): @@ -169,21 +163,23 @@ def run_cutlass_from_graph( 
per_act_token_quant=per_act_token, ) + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + out_dtype=a.dtype, + # NOTE(rob): w2 is shaped as [E, hidden, intermediate] + e=w2.shape[0], + n=w2.shape[2], + k=w2.shape[1], + quant_config=quant_config, + device=w1.device, + ), + ) + with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) ): - return cutlass_moe_fp8( - a, - w1_q, - w2_q, - topk_weights, - topk_ids, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, - quant_config=quant_config, - ) + return fn(a, w1, w2, topk_weights, topk_ids) def run_triton_from_graph( a: torch.Tensor, @@ -227,10 +223,6 @@ def replay_graph(graph, num_repeats): w2_q, w1_scale, w2_scale, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, topk_weights, topk_ids, ) @@ -268,10 +260,6 @@ def replay_graph(graph, num_repeats): "w1_scale": w1_scale, "w2_scale": w2_scale, "per_act_token": per_act_token, - "ab_strides1": ab_strides1, - "ab_strides2": ab_strides2, - "c_strides1": c_strides1, - "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -330,10 +318,6 @@ def replay_graph(graph, num_repeats): w2_q, w1_scale, w2_scale, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, topk_weights, topk_ids, per_act_token, @@ -342,7 +326,7 @@ def replay_graph(graph, num_repeats): results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 26a281f4e4fb..b31cfd61161f 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ 
b/benchmarks/kernels/benchmark_moe.py @@ -48,8 +48,6 @@ def clear_triton_cache(): # Try to clear Triton's runtime cache try: - import triton - if ( hasattr(triton, "runtime") and hasattr(triton.runtime, "cache") diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 11c6e488f958..d683b538c415 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -87,7 +87,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | triton (batched) | batched | all1 | G,A,T | silu, gelu | 6 | Y | [`BatchedTritonExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.BatchedTritonExperts] | | deep gemm | standard,
batched | fp8 | G(128),A,T | silu, gelu | 6 | Y | [`deep_gemm_moe_fp8`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.deep_gemm_moe_fp8],
[`DeepGemmExperts`][vllm.model_executor.layers.fused_moe.deep_gemm_moe.DeepGemmExperts],
[`BatchedDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe.BatchedDeepGemmExperts] | | cutlass_fp4 | standard,
batched | nvfp4 | A,T | silu | Y | Y | [`cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp4],
[`CutlassExpertsFp4`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp4] | -| cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],
[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | +| cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml new file mode 100644 index 000000000000..9d62c542a085 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Llama-4-Scout-Fp8-ModelOpt-triton.yaml @@ -0,0 +1,5 @@ +model_name: "nvidia/Llama-4-Scout-17B-16E-Instruct-FP8" +accuracy_threshold: 0.92 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml new file mode 100644 index 000000000000..276d63f4ee10 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml @@ -0,0 +1,8 @@ +model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" +accuracy_threshold: 0.88 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml new file mode 100644 index 000000000000..54e6ab7b35f6 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml @@ -0,0 +1,9 @@ +model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" +accuracy_threshold: 0.88 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel 
--all2all-backend deepep_low_latency --disable-uvicorn-access-log" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" + VLLM_USE_DEEP_GEMM_E8M0: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml new file mode 100644 index 000000000000..eee58539c4a8 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml @@ -0,0 +1,8 @@ +model_name: "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" +accuracy_threshold: 0.88 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml new file mode 100644 index 000000000000..2083df585f4d --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml @@ -0,0 +1,8 @@ +model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block" +accuracy_threshold: 0.85 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_high_throughput" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml new file mode 100644 index 000000000000..1d4cbfe96b68 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml @@ -0,0 +1,9 @@ +model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block" +accuracy_threshold: 0.85 +num_questions: 1319 
+num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel --all2all-backend deepep_low_latency --disable-uvicorn-access-log" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" + VLLM_USE_DEEP_GEMM_E8M0: "0" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml new file mode 100644 index 000000000000..246549d62961 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml @@ -0,0 +1,8 @@ +model_name: "RedHatAI/Qwen3-30B-A3B-FP8-block" +accuracy_threshold: 0.85 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel" +env: + VLLM_USE_DEEP_GEMM: "1" + VLLM_USE_DEEP_GEMM_MOE: "1" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml new file mode 100644 index 000000000000..53fd62bac839 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml @@ -0,0 +1,8 @@ +model_name: "RedHatAI/Qwen3-30B-A3B-NVFP4" +accuracy_threshold: 0.88 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --data-parallel-size 2 --enable-expert-parallel" +env: + VLLM_USE_FLASHINFER_MOE_FP4: "1" + VLLM_FLASHINFER_MOE_BACKEND: "throughput" diff --git a/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt new file mode 100644 index 000000000000..c1b405fd1d00 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt @@ -0,0 +1,8 @@ +Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ht.yaml +Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm-deepep-ll.yaml +Qwen3-30B-A3B-Fp8-AutoFp8-deepgemm.yaml 
+Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ht.yaml +Qwen3-30B-A3B-Fp8-CT-Block-deepgemm-deepep-ll.yaml +Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml +Qwen3-30B-A3B-NvFp4-CT-fi-cutlass.yaml +Qwen3-30B-A3B-NvFp4-ModelOpt-fi-cutlass-fi-a2av.yaml diff --git a/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml new file mode 100644 index 000000000000..bf8c93921f41 --- /dev/null +++ b/tests/evals/gsm8k/configs/moe-refactor/Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml @@ -0,0 +1,5 @@ +model_name: "RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic" +accuracy_threshold: 0.92 +num_questions: 1319 +num_fewshot: 5 +server_args: "--enforce-eager --max-model-len 8192 --tensor-parallel-size 2" diff --git a/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml b/tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml similarity index 100% rename from tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml rename to tests/evals/gsm8k/configs/moe-refactor/Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml diff --git a/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt b/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt index bf02f1363be3..9d86e432e84f 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt +++ b/tests/evals/gsm8k/configs/moe-refactor/config-b200.txt @@ -1,3 +1,4 @@ +Llama-4-Scout-Fp8-CT-vllm-cutlass.yaml Llama-4-Scout-Fp8-ModelOpt-fi-trtllm.yaml Qwen3-30B-A3B-Fp8-AutoFp8-fi-trtllm.yaml Qwen3-30B-A3B-NvFp4-CT-vllm-cutlass.yaml diff --git a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt index 9725db7c8be2..2c25ea2c2caa 100644 --- a/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt +++ b/tests/evals/gsm8k/configs/moe-refactor/config-h100.txt @@ -5,7 +5,7 @@ Qwen3-30B-A3B-Fp8-AutoFp8-marlin.yaml 
Qwen3-30B-A3B-Fp8-AutoFp8-triton.yaml Qwen3-30B-A3B-Fp8-CT-Block-deepgemm.yaml Qwen3-30B-A3B-Fp8-CT-Block-marlin.yaml -Qwen3-30B-A3B-Fp8-CT-Block-vllm-cutlass.yaml +Qwen3-30B-A3B-Fp8-CT-Block-triton.yaml Qwen3-30B-A3B-Fp8-CT-Channel-marlin.yaml Qwen3-30B-A3B-Fp8-CT-Channel-vllm-cutlass.yaml Llama-4-Scout-Fp8-ModelOpt-fi-cutlass.yaml diff --git a/tests/evals/gsm8k/test_gsm8k_correctness.py b/tests/evals/gsm8k/test_gsm8k_correctness.py index 991b905211ff..6b2cb02e9401 100644 --- a/tests/evals/gsm8k/test_gsm8k_correctness.py +++ b/tests/evals/gsm8k/test_gsm8k_correctness.py @@ -61,6 +61,7 @@ def test_gsm8k_correctness(config_filename): server_args.extend( [ "--trust-remote-code", + "--disable-uvicorn-access-log", ] ) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 4a57affdfbf4..cd5bf47d69e5 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -7,17 +7,22 @@ import pytest import torch +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, + CutlassExpertsFp8, run_cutlass_moe_fp8, ) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -150,16 +155,15 @@ def make_moe_tensors_8bit( def run_with_expert_maps( - num_experts: int, num_local_experts: int, **cutlass_moe_kwargs + num_experts: int, + num_local_experts: int, + quant_config: 
FusedMoEQuantConfig, + **cutlass_moe_kwargs, ): def slice_experts(): slice_params = [ - "w1_q", - "w2_q", - "ab_strides1", - "ab_strides2", - "c_strides1", - "c_strides2", + "w1", + "w2", ] full_tensors = { k: v @@ -167,8 +171,6 @@ def slice_experts(): if k in slice_params and k in cutlass_moe_kwargs } - quant_config = cutlass_moe_kwargs["quant_config"] - for i in range(0, num_experts, num_local_experts): s, e = i, i + num_local_experts @@ -187,13 +189,23 @@ def slice_experts(): new_quant_config._w1.scale = quant_config.w1_scale[s:e] new_quant_config._w2.scale = quant_config.w2_scale[s:e] - cutlass_moe_kwargs["quant_config"] = new_quant_config - - yield cutlass_moe_kwargs - - out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"]) - for kwargs in slice_experts(): - out_tensor = out_tensor + cutlass_moe_fp8(**kwargs) + yield cutlass_moe_kwargs, new_quant_config + + out_tensor = torch.zeros_like(cutlass_moe_kwargs["hidden_states"]) + for kwargs, new_quant_config in slice_experts(): + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + out_dtype=kwargs["hidden_states"].dtype, + # NOTE(rob): w2 is shaped as [E, hidden, intermediate] + e=kwargs["w2"].shape[0], # type: ignore[union-attr] + n=kwargs["w2"].shape[2], # type: ignore[union-attr] + k=kwargs["w2"].shape[1], # type: ignore[union-attr] + quant_config=new_quant_config, + device="cuda", + ), + ) + out_tensor = out_tensor + kernel(**kwargs) return out_tensor @@ -230,27 +242,35 @@ def run_8_bit( ) kwargs = { - "a": moe_tensors.a, - "w1_q": moe_tensors.w1_q, # type: ignore[union-attr] - "w2_q": moe_tensors.w2_q, # type: ignore[union-attr] + "hidden_states": moe_tensors.a, + "w1": moe_tensors.w1_q, # type: ignore[union-attr] + "w2": moe_tensors.w2_q, # type: ignore[union-attr] "topk_weights": topk_weights, "topk_ids": topk_ids, - "ab_strides1": moe_tensors.ab_strides1, - "ab_strides2": moe_tensors.ab_strides2, - "c_strides1": moe_tensors.c_strides1, - "c_strides2": 
moe_tensors.c_strides2, - "quant_config": quant_config, } num_experts = moe_tensors.w1.size(0) with_ep = num_local_experts is not None or num_local_experts == num_experts if not with_ep: - return cutlass_moe_fp8(**kwargs) + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp8( + out_dtype=moe_tensors.a.dtype, + # NOTE(rob): w2 is shaped as [E, hidden, intermediate] + e=moe_tensors.w2_q.shape[0], # type: ignore[union-attr] + n=moe_tensors.w2_q.shape[2], # type: ignore[union-attr] + k=moe_tensors.w2_q.shape[1], # type: ignore[union-attr] + quant_config=quant_config, + device="cuda", + ), + ) + return kernel(**kwargs) assert num_local_experts is not None return run_with_expert_maps( num_experts, num_local_experts, # type: ignore[arg-type] + quant_config, **kwargs, ) diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index c23107965340..bb2f6b873941 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -11,12 +11,17 @@ FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, ) +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( + FlashInferExperts, +) from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_flashinfer_per_tensor_scale_fp8, - flashinfer_cutlass_moe_fp8, + apply_fi_trtllm_fp8_per_tensor_moe, register_scales_for_trtllm_fp8_per_tensor_moe, - rotate_flashinfer_fp8_moe_weights, + rotate_weights_for_fi_trtllm_fp8_per_tensor_moe, swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import input_to_float8 @@ -103,6 +108,7 @@ def make_moe_tensors_8bit( w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2) layer = torch.nn.Module() + layer.orig_dtype = torch.bfloat16 layer.w13_weight = 
w13_quantized.clone() layer.w2_weight = w2_quantized.clone() layer.w13_input_scale = a1_scale @@ -115,10 +121,10 @@ def make_moe_tensors_8bit( pcp_size=1, dp_size=1, ep_size=1, - tp_rank=1, - pcp_rank=1, - dp_rank=1, - ep_rank=1, + tp_rank=0, + pcp_rank=0, + dp_rank=0, + ep_rank=0, use_ep=False, all2all_backend="naive", ) @@ -126,7 +132,9 @@ def make_moe_tensors_8bit( # flashinfer expects swapped rows for w13 layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if is_trtllm: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) + rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( + layer.w13_weight, layer.w2_weight + ) register_scales_for_trtllm_fp8_per_tensor_moe( layer, layer.w13_weight_scale, @@ -199,7 +207,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( quant_config=quant_config, ) - flashinfer_output = apply_flashinfer_per_tensor_scale_fp8( + flashinfer_output = apply_fi_trtllm_fp8_per_tensor_moe( layer=td.layer, hidden_states=td.hidden_states, router_logits=score, @@ -277,17 +285,34 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig: td.layer.get_fused_moe_quant_config = get_fused_moe_quant_config td.layer.quant_method = td.layer - flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8( + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP( + defer_input_quant=quant_config.is_block_quantized + ), + FlashInferExperts( + out_dtype=td.layer.orig_dtype, + quant_config=quant_config, + ep_rank=td.layer.moe_parallel_config.ep_rank, + ep_size=td.layer.moe_parallel_config.ep_size, + tp_rank=td.layer.moe_parallel_config.tp_rank, + tp_size=td.layer.moe_parallel_config.tp_size, + use_dp=False, + use_deepseek_fp8_block_scale=False, + ), + ) + + flashinfer_cutlass_output = kernel( td.hidden_states, - td.layer, + td.layer.w13_weight, + td.layer.w2_weight, topk_weights, topk_ids, + inplace=False, activation=activation, global_num_experts=e, expert_map=None, apply_router_weight_on_input=True, ) - 
torch.testing.assert_close( output, flashinfer_cutlass_output, atol=5.5e-2, rtol=1e-2 ) diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index a4b6d35987e1..d823d2b7a35d 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -15,7 +15,6 @@ Fp8Config, Fp8KVCacheMethod, Fp8LinearMethod, - Fp8MoeBackend, Fp8MoEMethod, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -278,8 +277,18 @@ def per_tensor_dequantize(tensor, inv_scale, dtype): # this is the case for marlin as well as per-tensor Fp8MoEMethod @pytest.mark.parametrize("use_marlin", [False]) # skip True def test_fp8_reloading( - method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init + method_cls, + is_checkpoint_fp8_serialized, + weight_block_size, + use_marlin, + dist_init, + monkeypatch, ): + # NOTE(rob): this test fails when using DeepGEMM because the + # shapes are invalid. Previously the test was passing because + # we set fp8_backend to None, which sidestepped the issue. 
+ monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "0") + if is_checkpoint_fp8_serialized is False: pytest.skip("FP8 weight reloading does not support online quantization") @@ -307,6 +316,7 @@ def test_fp8_reloading( params_dtype=torch.bfloat16, weight_loader=default_weight_loader, ) + method.use_marlin = use_marlin else: layer = FusedMoE( @@ -325,11 +335,6 @@ def test_fp8_reloading( weight_loader=default_weight_loader, ) - # Fp8LinearMethod uses use_marlin - # Fp8MoEMethod uses fp8_backend - method.use_marlin = use_marlin - method.fp8_backend = Fp8MoeBackend.MARLIN if use_marlin else None - # capture weights format during loading original_metadata = [ (name, param.shape, getattr(param, "weight_loader", default_weight_loader)) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3d248e7fb994..e63404086ed9 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -73,7 +73,6 @@ def get_config() -> dict[str, Any] | None: CutlassExpertsFp8, CutlassExpertsW4A8Fp8, cutlass_moe_fp4, - cutlass_moe_fp8, cutlass_moe_w4a8_fp8, ) from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts @@ -96,7 +95,6 @@ def get_config() -> dict[str, Any] | None: "fused_experts", "get_config_file_name", "GroupedTopk", - "cutlass_moe_fp8", "cutlass_moe_fp4", "cutlass_moe_w4a8_fp8", "CutlassExpertsFp8", diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index c585cbc1ab5d..6e397f1e76a1 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -249,20 +249,28 @@ def run_cutlass_moe_fp8( class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): def __init__( self, + e: int, + n: int, + k: int, out_dtype: torch.dtype | None, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - 
c_strides2: torch.Tensor, quant_config: FusedMoEQuantConfig, + device: torch.device, ): assert quant_config.use_fp8_w8a8 super().__init__(quant_config) + + # E: num_experts + # N: intermediate size per partition + # K: hidden dim + ab_strides1_c_strides2 = torch.full((e,), k, device=device, dtype=torch.int64) + ab_strides2 = torch.full((e,), n, device=device, dtype=torch.int64) + c_strides1 = torch.full((e,), 2 * n, device=device, dtype=torch.int64) + self.out_dtype = out_dtype - self.ab_strides1 = ab_strides1 + self.ab_strides1 = ab_strides1_c_strides2 self.ab_strides2 = ab_strides2 self.c_strides1 = c_strides1 - self.c_strides2 = c_strides2 + self.c_strides2 = ab_strides1_c_strides2 def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. @@ -329,24 +337,6 @@ def apply( class CutlassExpertsFp8(CutlassExpertsFp8Base): - def __init__( - self, - out_dtype: torch.dtype | None, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, - quant_config: FusedMoEQuantConfig, - ): - super().__init__( - out_dtype, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, - quant_config, - ) - @property def activation_formats( self, @@ -390,21 +380,10 @@ def __init__( self, max_experts_per_worker: int, num_dispatchers: int, - out_dtype: torch.dtype | None, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, - quant_config: FusedMoEQuantConfig, + *args, + **kwargs, ): - super().__init__( - out_dtype, - ab_strides1, - ab_strides2, - c_strides1, - c_strides2, - quant_config, - ) + super().__init__(*args, **kwargs) assert max_experts_per_worker > 0 self.max_experts_per_worker = max_experts_per_worker self.num_dispatchers = num_dispatchers @@ -445,113 +424,6 @@ def workspace_shapes( return (workspace1, workspace2, output) -def cutlass_moe_fp8( - a: torch.Tensor, - w1_q: torch.Tensor, - w2_q: 
torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides1: torch.Tensor, - c_strides2: torch.Tensor, - quant_config: FusedMoEQuantConfig, - activation: str = "silu", - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - global_num_experts: int = -1, -) -> torch.Tensor: - """ - This function computes a a8w8-quantized Mixture of Experts (MoE) layer - using two sets of quantized weights, w1_q and w2_q, and top-k gating - mechanism. The matrix multiplications are implemented with CUTLASS - grouped gemm. - - Parameters: - - a (torch.Tensor): The input tensor to the MoE layer. - Shape: [M, K] - - w1_q (torch.Tensor): The first set of fp8-quantized expert weights. - Shape: [num_experts, K, 2N] (the weights are passed transposed) - - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. - Shape: [num_experts, N, K] (the weights are passed transposed) - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - topk_ids (torch.Tensor): The token->expert mappings. - - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. - Shape: [num_experts] or [num_experts, 2N] - - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. - Shape: [num_experts] or [num_experts, K] - - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm. - Shape: [num_experts] - - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm. - Shape: [num_experts] - - c_strides1 (torch.Tensor): The output strides for the first gemm. - Shape: [num_experts] - - c_strides2 (torch.Tensor): The output strides for the second gemm. - Shape: [num_experts] - - per_act_token (Optional[bool]): Whether the scale is per-token or - per-tensor. - - activation (str): The activation function to use. - - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. 
- Shape: scalar or [M] - - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to - quantize the intermediate result between the gemms. - Shape: scalar or [M] - - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, - every Rank is responsible for a subset of experts. expert_map is a - mapping from global expert-id to local expert-id. When expert_map[i] - is -1, it means that this Rank is not responsible for global - expert-id i. - - apply_router_weight_on_input (bool): When true, the topk weights are - applied directly on the inputs. This is only applicable when topk is 1. - - global_num_experts (int): The total number of experts. - - Returns: - - torch.Tensor: The fp16 output tensor after applying the MoE layer. - """ - assert quant_config is not None - - if quant_config.a1_scale is not None: - assert quant_config.per_act_token_quant == (quant_config.a1_scale.numel() != 1) - if quant_config.a2_scale is not None: - assert quant_config.per_act_token_quant == (quant_config.a2_scale.numel() != 1) - - if quant_config.w1_scale is not None: - if quant_config.per_out_ch_quant: - assert quant_config.w1_scale.dim() > 1 and quant_config.w1_scale.size( - 1 - ) == w1_q.size(1) - else: - assert ( - quant_config.w1_scale.dim() == 1 or quant_config.w1_scale.size(1) == 1 - ) - - num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) - - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - CutlassExpertsFp8( - out_dtype=a.dtype, - ab_strides1=ab_strides1, - ab_strides2=ab_strides2, - c_strides1=c_strides1, - c_strides2=c_strides2, - quant_config=quant_config, - ), - ) - - return fn( - a, - w1_q, - w2_q, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max diff --git 
a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py new file mode 100644 index 000000000000..14ef6b9aaa5e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): + """Base class for runtime dispatching of expert implementations.""" + + def __init__( + self, + experts: mk.FusedMoEPermuteExpertsUnpermute, + fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, + ): + super().__init__(experts.quant_config) + self.fallback_experts = fallback_experts + self.experts = experts + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + assert ( + self.fallback_experts.activation_formats == self.experts.activation_formats + ) + return self.fallback_experts.activation_formats + + def supports_chunking(self) -> bool: + assert ( + self.experts.supports_chunking() + == self.fallback_experts.supports_chunking() + ) + return ( + self.experts.supports_chunking() + and self.fallback_experts.supports_chunking() + ) + + def supports_expert_map(self) -> bool: + assert ( + self.experts.supports_expert_map() + == self.fallback_experts.supports_expert_map() + ) + return ( + self.experts.supports_expert_map() + and self.fallback_experts.supports_expert_map() + ) + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + e_war = self.experts.finalize_weight_and_reduce_impl() + fbe_war = self.fallback_experts.finalize_weight_and_reduce_impl() + is_dge_war = e_war is not None + is_fbe_war = fbe_war is not None + + if is_dge_war and is_fbe_war: + assert e_war == fbe_war, ( + "Both implementations should agree on WeightAndReduce impls. 
" + f"Got e_war: {e_war}, and fbe_war: {fbe_war}" + ) + + if e_war is not None: + return e_war + assert fbe_war is not None + return fbe_war + + @abstractmethod + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + raise NotImplementedError + + @abstractmethod + def _select_experts_impl( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + raise NotImplementedError + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + experts = self._select_experts_impl(hidden_states, w1, w2) + experts.apply( + output, + hidden_states, + w1, + w2, + topk_weights, + topk_ids, + activation, + global_num_experts, + expert_map, + a1q_scale, + a2_scale, + workspace13, + workspace2, + expert_tokens_meta, + apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 51e06ac54f49..3bb5a23abb7b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -100,7 +100,7 @@ def flashinfer_fused_moe_blockscale_fp8_fake( ) -def flashinfer_fused_moe_per_tensor_scale_fp8( +def fi_trtllm_fp8_per_tensor_moe( routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, @@ -158,7 +158,7 @@ 
def flashinfer_fused_moe_per_tensor_scale_fp8( ) -def flashinfer_fused_moe_per_tensor_scale_fp8_fake( +def fi_trtllm_fp8_per_tensor_moe_fake( routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, hidden_states: torch.Tensor, @@ -184,9 +184,9 @@ def flashinfer_fused_moe_per_tensor_scale_fp8_fake( # TODO(bnell): Does this really need to be a torch.op? direct_register_custom_op( - op_name="flashinfer_fused_moe_per_tensor_scale_fp8", - op_func=flashinfer_fused_moe_per_tensor_scale_fp8, + op_name="fi_trtllm_fp8_per_tensor_moe", + op_func=fi_trtllm_fp8_per_tensor_moe, mutates_args=["hidden_states"], - fake_impl=flashinfer_fused_moe_per_tensor_scale_fp8_fake, + fake_impl=fi_trtllm_fp8_per_tensor_moe_fake, tags=(torch.Tag.needs_fixed_stride_order,), ) diff --git a/vllm/model_executor/layers/fused_moe/oracle/__init__.py b/vllm/model_executor/layers/fused_moe/oracle/__init__.py new file mode 100644 index 000000000000..208f01a7cb5e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/oracle/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py new file mode 100644 index 000000000000..f5c3b9af611f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + fp8_w8a8_moe_quant_config, + fp8_w8a16_moe_quant_config, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + 
get_flashinfer_moe_backend, + make_fp8_moe_alpha_scales_for_fi, + prepare_fp8_moe_layer_for_fi, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + prepare_fp8_moe_layer_for_deepgemm, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + prepare_fp8_moe_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_group_gemm_supported, +) +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import is_deep_gemm_supported +from vllm.utils.flashinfer import has_flashinfer_moe +from vllm.utils.import_utils import has_deep_gemm + +logger = init_logger(__name__) + + +class Fp8MoeBackend(Enum): + NONE = 0 + FLASHINFER_TRTLLM = 1 + FLASHINFER_CUTLASS = 2 + DEEPGEMM = 3 + MARLIN = 4 + TRITON = 5 + AITER = 6 + VLLM_CUTLASS = 7 + + +def select_fp8_moe_backend( + block_quant: bool, + tp_size: int, + with_lora_support: bool, + is_act_and_mul: bool = True, + allow_vllm_cutlass: bool = False, +) -> Fp8MoeBackend: + """ + Select the primary FP8 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + # TODO(rob): in a future PR, we will query each mk for + # supported features and return the mk directly, just like + # we do for the Attention Backend. + + if with_lora_support: + return Fp8MoeBackend.TRITON + + def _make_log_backend(backend_name: str): + return f"Using {backend_name} backend for FP8 MoE" + + # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. 
+ if ( + current_platform.is_cuda() + and ( + current_platform.is_device_capability_family(100) + or current_platform.is_device_capability(90) + ) + and envs.VLLM_USE_FLASHINFER_MOE_FP8 + and has_flashinfer_moe() + ): + backend = get_flashinfer_moe_backend() + if backend == FlashinferMoeBackend.TENSORRT_LLM: + logger.info_once(_make_log_backend("FlashInfer TRTLLM")) + if not is_act_and_mul: + raise ValueError( + "FlashInfer TRTLLM FP8 MoE backend only supports " + "act_and_mul gate_up_project fusion. Please set " + "VLLM_USE_FLASHINFER_MOE_FP8=throughput to use the " + "FlashInfer CUTLASS backend instead." + ) + return Fp8MoeBackend.FLASHINFER_TRTLLM + else: + if block_quant and current_platform.is_device_capability_family(100): + raise ValueError( + "FlashInfer FP8 MoE throughput backend does not " + "support block quantization on SM100. Please use " + "VLLM_FLASHINFER_MOE_BACKEND=latency to use the " + "FlashInfer TRTLLM backend instead." + ) + logger.info_once(_make_log_backend("FlashInfer CUTLASS")) + return Fp8MoeBackend.FLASHINFER_CUTLASS + + # weight-only path for older GPUs without native FP8 + if ( + current_platform.is_cuda() and not current_platform.has_device_capability(89) + ) or envs.VLLM_TEST_FORCE_FP8_MARLIN: + logger.info_once(_make_log_backend("Marlin"), scope="local") + return Fp8MoeBackend.MARLIN + + # Determine if we should use DeepGEMM with block-quantized weights: + # - If explicitly set by user, respect their choice + # - If not explicitly set (default), disable when TP size is >= 8 + moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM + if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and tp_size >= 8: + moe_use_deep_gemm = False + logger.info_once( + "DeepGEMM MoE is disabled by default when TP size is >= 8. 
" + "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.", + scope="local", + ) + + use_deep_gemm = envs.VLLM_USE_DEEP_GEMM + if not is_deep_gemm_supported(): + use_deep_gemm = False + logger.info_once( + "DeepGEMM is disabled because the platform does not support it.", + scope="local", + ) + + if use_deep_gemm and moe_use_deep_gemm and block_quant: + if not has_deep_gemm(): + logger.warning_once( + "DeepGEMM backend requested but not available.", scope="local" + ) + elif is_deep_gemm_supported(): + logger.info_once(_make_log_backend("DeepGEMM"), scope="local") + return Fp8MoeBackend.DEEPGEMM + + if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: + logger.info_once(_make_log_backend("ROCm AITER"), scope="local") + return Fp8MoeBackend.AITER + + if allow_vllm_cutlass and not block_quant and cutlass_group_gemm_supported(): + logger.info_once(_make_log_backend("vLLM CUTLASS"), scope="local") + return Fp8MoeBackend.VLLM_CUTLASS + + # default to Triton + logger.info_once(_make_log_backend("Triton"), scope="local") + return Fp8MoeBackend.TRITON + + +def convert_to_fp8_moe_kernel_format( + fp8_backend: Fp8MoeBackend, + layer: torch.nn.Module, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_input_scale: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + block_quant = hasattr(layer, "weight_block_size") + if fp8_backend == Fp8MoeBackend.DEEPGEMM: + assert block_quant + w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_deepgemm( + w13, + w2, + w13_scale, + w2_scale, + tuple(layer.weight_block_size), + ) + elif fp8_backend == Fp8MoeBackend.AITER: + w13, w2 = rocm_aiter_ops.shuffle_weights(w13, w2) + elif fp8_backend == Fp8MoeBackend.MARLIN: + w13, w2, w13_scale, w2_scale = prepare_fp8_moe_layer_for_marlin( + layer, + w13, + w2, + w13_scale, + w2_scale, + ) + elif fp8_backend in [ + Fp8MoeBackend.FLASHINFER_CUTLASS, + 
Fp8MoeBackend.FLASHINFER_TRTLLM,
+ ]:
+ w13, w2, w13_scale = prepare_fp8_moe_layer_for_fi(
+ layer=layer,
+ w13=w13,
+ w2=w2,
+ w13_scale=w13_scale,
+ w13_input_scale=w13_input_scale,
+ w2_scale=w2_scale,
+ w2_input_scale=w2_input_scale,
+ is_trtllm=(fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM),
+ )
+
+ return w13, w2, w13_scale, w2_scale
+
+
+def make_fp8_moe_quant_config(
+ fp8_backend: Fp8MoeBackend,
+ w1_scale: torch.Tensor,
+ w2_scale: torch.Tensor,
+ a1_scale: torch.Tensor | None,
+ a2_scale: torch.Tensor | None,
+ block_shape: list[int] | None = None,
+) -> FusedMoEQuantConfig | None:
+ """
+ Create FusedMoEQuantConfig for the specified FP8 Backend.
+ The FusedMoEQuantConfig holds the scales that are used
+ at runtime by the Modular Kernel abstraction.
+
+ Note that certain kernels (e.g. Flashinfer CUTLASS) need
+ special Quant configs to handle non-standard inputs to
+ their kernel interfaces.
+
+ In a future PR, this function should become
+ a method of the modular kernel itself.
+ """
+ # TRTLLM does not use Modular Kernel abstraction yet.
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
+ return None
+
+ # MARLIN is mixed precision W8A16 config.
+ if fp8_backend == Fp8MoeBackend.MARLIN:
+ return fp8_w8a16_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ block_shape=block_shape,
+ )
+
+ # Flashinfer CUTLASS per-tensor uses single dq scale
+ # (alpha = w_scale * a_scale) and inverse a2 scale.
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
+ assert a1_scale is not None and a2_scale is not None
+ g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
+ w1_scale,
+ a1_scale,
+ w2_scale,
+ a2_scale,
+ )
+ return fp8_w8a8_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ a1_gscale=(1.0 / a1_scale),
+ a2_gscale=(1.0 / a2_scale),
+ g1_alphas=g1_alphas,
+ g2_alphas=g2_alphas,
+ )
+ # All other backends use normal config. 
+ return fp8_w8a8_moe_quant_config(
+ w1_scale=w1_scale,
+ w2_scale=w2_scale,
+ a1_scale=a1_scale,
+ a2_scale=a2_scale,
+ block_shape=block_shape,
+ )
+
+
+def make_fp8_moe_kernel(
+ layer: torch.nn.Module,
+ moe_quant_config: FusedMoEQuantConfig,
+ moe_config: FusedMoEConfig,
+ fp8_backend: Fp8MoeBackend,
+) -> tuple[mk.FusedMoEModularKernel, bool]:
+ # Delayed import is required since the oracle is imported
+ # by CPU backends which cannot import all of these experts.
+ # TODO: update the experts to make this not happen.
+ from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+ MoEPrepareAndFinalizeNoEP,
+ )
+
+ # NOTE(rob): this is a WIP refactor. We are first migrating
+ # all of the kernels in the TP case to use mk. Once this is
+ # done, then we will initialize the TP case and DP/EP case
+ # via the same code path (i.e. via maybe_init_modular_kernel).
+ # NOTE(rob): in progress migrating all into this format.
+ use_inplace = True
+ if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
+ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+ FlashInferExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ MoEPrepareAndFinalizeNoEP(
+ defer_input_quant=moe_quant_config.is_block_quantized
+ ),
+ FlashInferExperts(
+ out_dtype=layer.orig_dtype,
+ quant_config=moe_quant_config,
+ ep_rank=moe_config.ep_rank,
+ ep_size=moe_config.ep_size,
+ tp_rank=moe_config.tp_rank,
+ tp_size=moe_config.tp_size,
+ use_dp=(moe_config.dp_size > 1),
+ use_deepseek_fp8_block_scale=moe_quant_config.is_block_quantized,
+ ),
+ )
+ use_inplace = False
+
+ elif fp8_backend == Fp8MoeBackend.AITER:
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ AiterExperts,
+ )
+
+ kernel = mk.FusedMoEModularKernel(
+ # TODO: make defer_input_quant an attr of the AiterExperts
+ MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+ AiterExperts(quant_config=moe_quant_config),
+ )
+ elif fp8_backend == Fp8MoeBackend.MARLIN:
+ from 
vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, + ) + + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + MarlinExperts(quant_config=moe_quant_config), + ) + elif fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( + TritonOrCutlassExperts, + ) + + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonOrCutlassExperts( + out_dtype=moe_config.in_dtype, + e=layer.local_num_experts, + n=layer.intermediate_size_per_partition, + k=layer.hidden_size, + device=layer.w13_weight.device, + quant_config=moe_quant_config, + ), + ) + elif fp8_backend == Fp8MoeBackend.DEEPGEMM: + from vllm.model_executor.layers.fused_moe import ( + TritonOrDeepGemmExperts, + ) + + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonOrDeepGemmExperts(quant_config=moe_quant_config), + ) + else: + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, + ) + + assert fp8_backend == Fp8MoeBackend.TRITON + kernel = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + TritonExperts(quant_config=moe_quant_config), + ) + return kernel, use_inplace diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py new file mode 100644 index 000000000000..e874ba609be0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts +from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from 
vllm.platforms import current_platform + + +class TritonOrCutlassExperts(FallbackExperts): + """Cutlass with fallback to Triton for low latency shapes on SM100.""" + + def __init__( + self, + e: int, + n: int, + k: int, + out_dtype: torch.dtype | None, + quant_config: FusedMoEQuantConfig, + device: torch.dtype, + ): + self.is_sm100 = current_platform.has_device_capability(100) + super().__init__( + experts=CutlassExpertsFp8(e, n, k, out_dtype, quant_config, device), + fallback_experts=TritonExperts(quant_config), + ) + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # Small batch fallback for sm100. + if self.is_sm100 and M <= 8: + return self.fallback_experts.workspace_shapes( + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + else: + return self.experts.workspace_shapes( + M, + N, + K, + topk, + global_num_experts, + local_num_experts, + expert_tokens_meta, + ) + + def _select_experts_impl( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + ) -> mk.FusedMoEPermuteExpertsUnpermute: + # Small batch fallback for sm100. 
+ if self.is_sm100 and hidden_states.shape[0] <= 8: + return self.fallback_experts + else: + return self.experts diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index b8e0837162ef..4fcc1a7c1fc0 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -10,78 +10,22 @@ _valid_deep_gemm, _valid_deep_gemm_shape, ) +from vllm.model_executor.layers.fused_moe.fallback import FallbackExperts from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts from vllm.utils.deep_gemm import ( - get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, ) -class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - quant_config: FusedMoEQuantConfig, - allow_deep_gemm: bool = False, - ): - super().__init__(quant_config) - - self.triton_expert = TritonExperts(quant_config) - - self.allow_deep_gemm = ( - allow_deep_gemm - and self.quant_config.use_fp8_w8a8 - and self.block_shape == get_mk_alignment_for_contiguous_layout() - ) - - self.deep_gemm_expert = ( - DeepGemmExperts(self.quant_config) if self.allow_deep_gemm else None - ) +class TritonOrDeepGemmExperts(FallbackExperts): + """DeepGemm with fallback to Triton for low latency shapes.""" - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - assert ( - self.deep_gemm_expert is None - or self.triton_expert.activation_formats - == self.deep_gemm_expert.activation_formats - ) - return self.triton_expert.activation_formats - - def supports_chunking(self) -> bool: - dge = self.deep_gemm_expert - te = self.triton_expert - return (dge is None or dge.supports_chunking()) and ( - te is None or te.supports_chunking() + def __init__(self, quant_config: FusedMoEQuantConfig): + super().__init__( + experts=DeepGemmExperts(quant_config), + 
fallback_experts=TritonExperts(quant_config), ) - def supports_expert_map(self) -> bool: - dge = self.deep_gemm_expert - te = self.triton_expert - return (dge is None or dge.supports_expert_map()) and ( - te is None or te.supports_expert_map() - ) - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - dge = self.deep_gemm_expert - te = self.triton_expert - dge_war = dge.finalize_weight_and_reduce_impl() if dge else None - te_war = te.finalize_weight_and_reduce_impl() if te else None - is_dge_war = dge_war is not None - is_te_war = te_war is not None - - if is_dge_war and is_te_war: - assert dge_war == te_war, ( - "Both implementations should agree on WeightAndReduce impls. " - f"Got dge_war: {dge_war}, and te_war: {te_war}" - ) - - if dge_war is not None: - return dge_war - - assert te_war is not None - return te_war - def workspace_shapes( self, M: int, @@ -95,11 +39,8 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
- if self.allow_deep_gemm and ( - is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K) - ): - assert self.deep_gemm_expert is not None - return self.deep_gemm_expert.workspace_shapes( + if is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K): + return self.experts.workspace_shapes( M, N, K, @@ -109,7 +50,7 @@ def workspace_shapes( expert_tokens_meta, ) else: - return self.triton_expert.workspace_shapes( + return self.fallback_experts.workspace_shapes( M, N, K, @@ -119,45 +60,13 @@ def workspace_shapes( expert_tokens_meta, ) - def apply( + def _select_experts_impl( self, - output: torch.Tensor, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - use_deep_gemm = self.allow_deep_gemm and ( - is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2) - ) - - experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert - assert experts is not None - - experts.apply( - output, - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - activation, - global_num_experts, - expert_map, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_tokens_meta, - apply_router_weight_on_input, - ) + ) -> mk.FusedMoEPermuteExpertsUnpermute: + if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2): + return self.experts + else: + return self.fallback_experts diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1094d9d55a1b..a2b3aec4457e 100644 --- 
a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -13,10 +13,8 @@ ) from torch.nn.parameter import Parameter -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops -from vllm._aiter_ops import rocm_aiter_ops from vllm.distributed import get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( @@ -31,6 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, + fp8_w8a16_moe_quant_config, int4_w4a16_moe_quant_config, int4_w4afp8_moe_quant_config, int8_w8a8_moe_quant_config, @@ -46,11 +45,16 @@ MarlinExperts, fused_marlin_moe, ) +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + select_fp8_moe_backend, +) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, ) -from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, @@ -63,8 +67,8 @@ get_flashinfer_moe_backend, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - expert_weight_is_col_major, - requant_weight_ue8m0_inplace, + process_fp8_input_tensor_strategy_moe, + process_fp8_weight_tensor_strategy_moe, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_moe_marlin_supports_layer, @@ -76,29 +80,17 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( prepare_moe_fp4_layer_for_marlin, ) -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 
import ( - prepare_moe_fp8_layer_for_marlin, -) from vllm.model_executor.layers.quantization.utils.quant_utils import ( convert_bf16_scales_to_fp8, convert_packed_uint4b8_to_signed_int4_inplace, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, normalize_e4m3fn_to_e4m3fnuz, - per_tensor_dequantize, ) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types -from vllm.utils.deep_gemm import ( - get_col_major_tma_aligned_tensor, - get_mk_alignment_for_contiguous_layout, - is_deep_gemm_e8m0_used, - is_deep_gemm_supported, -) -from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -657,10 +649,6 @@ def __init__( moe: FusedMoEConfig, layer_name: str | None = None, ): - from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig, - ) - super().__init__(moe) self.weight_quant = weight_quant self.input_quant = input_quant @@ -687,42 +675,31 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization." 
) - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - self.use_marlin = ( - not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN - and not self.block_quant - ) - # Disable marlin for rocm - if current_platform.is_rocm(): - self.use_marlin = False - - self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() - - # cutlass path - self.is_fp8_w8a8_sm100 = CompressedTensorsConfig._is_fp8_w8a8_sm100( - self.weight_quant, self.input_quant + self.fp8_backend = select_fp8_moe_backend( + block_quant=self.block_quant, + tp_size=moe.tp_size, + with_lora_support=moe.is_lora_enabled, + # TODO(rob): enable selecting this externally. + allow_vllm_cutlass=True, ) - self.use_cutlass = not self.block_quant and ( - CompressedTensorsConfig._is_fp8_w8a8_sm90( - self.weight_quant, self.input_quant + if self.fp8_backend != Fp8MoeBackend.MARLIN: + per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + ) + if per_act_token != per_channel_quant: + raise NotImplementedError( + "For FP8 Fused MoE layers, per-token and per-channel must be " + "used together." + ) + # TODO(rob): hook this up in a follow up PR. + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: + raise NotImplementedError( + "FlashInfer TRTLLM backend not supported for compressed-tensors yet." 
) - or self.is_fp8_w8a8_sm100 - ) self.disable_expert_map = False - self.layer_name = layer_name - self.marlin_input_dtype = ( - get_marlin_input_dtype(layer_name) if self.use_marlin else None - ) - self.allow_deep_gemm = ( - self.block_quant - and envs.VLLM_MOE_USE_DEEP_GEMM - and is_deep_gemm_supported() - and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout() - ) + self.kernel: mk.FusedMoEModularKernel | None = None def create_weights( self, @@ -880,163 +857,75 @@ def create_weights( layer.w2_input_scale = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.static_input_scales: - assert self.input_quant.strategy == QuantizationStrategy.TENSOR - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.w13_input_scale) or not all_close_1d( - layer.w2_input_scale - ): - logger.warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer." - ) - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False - ) - + # Allow for accessing weights and scales in standard way. + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_input_scale = layer.w13_input_scale + w2_input_scale = layer.w2_input_scale + + # MI300x and MI325x use FNUZ format for FP8. Convert if needed. 
if current_platform.is_fp8_fnuz(): - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) + w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w13, w13_scale, w13_input_scale ) - w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False + w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w2, w2_scale, w2_input_scale ) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) - # For Per-TENSOR case, Fp8 moe kernel needs single weight scale - # for w13 per expert. Use max then dequant and requant each expert. 
- if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) - - # Property to determine if AITER is used - if self.rocm_aiter_moe_enabled: - # reshaping weights is required for aiter moe kernel. - shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( - layer.w13_weight.data, layer.w2_weight.data + # Per tensor kernels require single activation scale. Use the max. + if self.static_input_scales: + assert self.input_quant.strategy == QuantizationStrategy.TENSOR + assert w13_input_scale is not None and w2_input_scale is not None + w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( + w13_input_scale, w2_input_scale ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_input_scale", w2_input_scale) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - - elif self.use_marlin: - ( - workspace, - w13_weight, - w2_weight, - w13_weight_scale, - w2_weight_scale, - ) = prepare_moe_fp8_layer_for_marlin( - layer, - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - input_dtype=self.marlin_input_dtype, + # Per-tensor kernels use a single scale, for W13, but on disk there + # is a separate scale for W1 and W3. Requantize with the max scale. 
+ if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + process_fp8_weight_tensor_strategy_moe( + w13, + w13_scale, + shard_size=layer.intermediate_size_per_partition, + num_experts=layer.num_local_experts, + ) + + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) + + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, ) - layer.workspace = workspace - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w2_weight", w2_weight) - replace_parameter(layer, "w13_weight_scale", w13_weight_scale) - replace_parameter(layer, "w2_weight_scale", w2_weight_scale) - - if self.use_cutlass: - assert self.weight_quant.strategy != QuantizationStrategy.BLOCK - device = layer.w13_weight.device - # ab_strides1 and c_strides2 are the same - self.ab_strides1_c_strides2 = torch.full( - (layer.local_num_experts,), - layer.hidden_size, - device=device, - dtype=torch.int64, - ) - self.ab_strides2 = torch.full( - (layer.local_num_experts,), - layer.intermediate_size_per_partition, - device=device, - dtype=torch.int64, - ) - self.c_strides1 = torch.full( - (layer.local_num_experts,), - 2 * layer.intermediate_size_per_partition, - device=device, - dtype=torch.int64, - ) - - if is_deep_gemm_e8m0_used() and 
self.block_quant: - assert layer.weight_block_size is not None - # Re-quantise the expert weights so their scales are UE8M0. - block_sz = tuple(layer.weight_block_size) - requant_weight_ue8m0_inplace( - layer.w13_weight.data, - layer.w13_weight_scale.data, - block_sz, - ) - requant_weight_ue8m0_inplace( - layer.w2_weight.data, - layer.w2_weight_scale.data, - block_sz, - ) - - # Ensure column-major TMA alignment expected by DeepGEMM. - if expert_weight_is_col_major(layer.w13_weight_scale): - layer.w13_weight_scale = get_col_major_tma_aligned_tensor( - layer.w13_weight_scale - ) - if expert_weight_is_col_major(layer.w2_weight_scale): - layer.w2_weight_scale = get_col_major_tma_aligned_tensor( - layer.w2_weight_scale - ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if self.use_marlin or self.rocm_aiter_moe_enabled: + if self.fp8_backend in [Fp8MoeBackend.MARLIN, Fp8MoeBackend.AITER]: return None else: return super().maybe_make_prepare_finalize(routing_tables) @@ -1048,7 +937,7 @@ def select_gemm_impl( ) -> FusedMoEPermuteExpertsUnpermute: # cutlass path assert self.moe_quant_config is not None - if self.use_cutlass: + if self.fp8_backend == Fp8MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8, @@ -1064,26 +953,27 @@ def select_gemm_impl( ): logger.debug("CutlassBatchedExpertsFp8(%s)", self.__class__.__name__) experts = CutlassBatchedExpertsFp8( - self.moe.num_local_experts, - num_dispatchers, - self.moe.in_dtype, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, + max_experts_per_worker=self.moe.num_local_experts, + num_dispatchers=num_dispatchers, + out_dtype=self.moe.in_dtype, + e=layer.local_num_experts, + n=layer.intermediate_size_per_partition, + k=layer.hidden_size, + 
device=layer.w13_weight.device, quant_config=self.moe_quant_config, ) else: logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__) experts = CutlassExpertsFp8( - self.moe.in_dtype, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, + out_dtype=self.moe.in_dtype, + e=layer.local_num_experts, + n=layer.intermediate_size_per_partition, + k=layer.hidden_size, + device=layer.w13_weight.device, quant_config=self.moe_quant_config, ) + # TODO(rob): investigate disable_expert_map self.disable_expert_map = ( num_dispatchers > 1 or not experts.supports_expert_map() ) @@ -1096,13 +986,14 @@ def select_gemm_impl( from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, ) + from vllm.model_executor.layers.fused_moe.fused_moe import ( + TritonExperts, + ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) - assert not self.rocm_aiter_moe_enabled and not self.use_marlin - - use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + assert self.fp8_backend not in [Fp8MoeBackend.AITER, Fp8MoeBackend.MARLIN] if ( prepare_finalize.activation_format @@ -1111,28 +1002,7 @@ def select_gemm_impl( max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - if use_deep_gemm and not has_deep_gemm(): - raise RuntimeError( - "DeepGEMM requested for MoE layer but not installed." - ) - - compatible_with_deep_gemm = ( - self.moe_quant_config.use_fp8_w8a8 - and self.moe_quant_config.block_shape - == get_mk_alignment_for_contiguous_layout() - ) - - # If this MoE layer is compatible with DeepGEMM, the proper env - # vars are set and DeepGEMM is not installed, throw an error. 
- if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm(): - raise RuntimeError( - f"MoE layer incompatible with DeepGEMM, expected " - f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}" - f"or block_shape {self.moe_quant_config.block_shape}" - f"=={get_mk_alignment_for_contiguous_layout()}." - ) - - if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm(): + if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) return BatchedDeepGemmExperts( max_num_tokens=max_num_tokens_per_rank, @@ -1148,17 +1018,22 @@ def select_gemm_impl( ) else: - logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) - return TritonOrDeepGemmExperts( - self.moe_quant_config, - allow_deep_gemm=use_deep_gemm, - ) + if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: + logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) + return TritonOrDeepGemmExperts(self.moe_quant_config) + else: + logger.debug("TritonExperts(%s)", self.__class__.__name__) + return TritonExperts(self.moe_quant_config) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.use_marlin: - return None + if self.fp8_backend == Fp8MoeBackend.MARLIN: + return fp8_w8a16_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + block_shape=self.weight_block_size, + ) per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL @@ -1184,118 +1059,23 @@ def apply( router_logits=router_logits, ) - per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN - per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL - - if self.use_marlin: - assert layer.activation == "silu", ( - f"{layer.activation} not supported for Marlin MoE." 
- ) - return fused_marlin_moe( - x, - layer.w13_weight, - layer.w2_weight, - None, - None, - layer.w13_weight_scale, - layer.w2_weight_scale, - router_logits, - topk_weights, - topk_ids, - quant_type_id=scalar_types.float8_e4m3fn.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - input_dtype=self.marlin_input_dtype, - workspace=layer.workspace, - ) - - elif self.rocm_aiter_moe_enabled: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, - ) - - assert per_act_token == per_channel_quant - assert self.moe_quant_config is not None - return rocm_aiter_fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - ) - - # cutlass path - elif self.use_cutlass: - assert self.moe_quant_config is not None - - # small-batch fallback on SM100 - if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: - from vllm.model_executor.layers.fused_moe import fused_experts - - assert per_act_token == per_channel_quant - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=None - if self.disable_expert_map - else layer.expert_map, # ??? 
- quant_config=self.moe_quant_config, - allow_deep_gemm=self.allow_deep_gemm, - ) - else: - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8, - ) - - assert per_act_token == per_channel_quant - assert self.moe_quant_config is not None - return cutlass_moe_fp8( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - quant_config=self.moe_quant_config, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=None if self.disable_expert_map else layer.expert_map, - ab_strides1=self.ab_strides1_c_strides2, - ab_strides2=self.ab_strides2, - c_strides1=self.c_strides1, - c_strides2=self.ab_strides1_c_strides2, - ) - - else: - from vllm.model_executor.layers.fused_moe import fused_experts + assert self.kernel is not None + result = self.kernel( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=self.use_inplace, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + # TODO(rob): investigate the disable_expert_map introduced by: + # https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501 + expert_map=None if self.disable_expert_map else layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) - assert per_act_token == per_channel_quant - assert self.moe_quant_config is not None - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=layer.activation, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - quant_config=self.moe_quant_config, - allow_deep_gemm=self.allow_deep_gemm, - ) + return result @property def supports_eplb(self) -> bool: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 
1223c6902e5f..2879315a6886 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from enum import Enum from typing import TYPE_CHECKING, Any, Optional import torch @@ -27,13 +26,17 @@ FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEParallelConfig, FusedMoEQuantConfig, RoutingMethodType, - fp8_w8a8_moe_quant_config, - fp8_w8a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + make_fp8_moe_quant_config, + select_fp8_moe_backend, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -46,25 +49,20 @@ ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - FlashinferMoeBackend, - apply_flashinfer_per_tensor_scale_fp8, + apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - get_flashinfer_moe_backend, - make_fp8_moe_alpha_scales_for_fi, - register_scales_for_trtllm_fp8_per_tensor_moe, - rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, - swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, create_fp8_scale_parameter, create_fp8_weight_parameter, - deepgemm_post_process_fp8_weight_block, maybe_post_process_fp8_weight_block, + process_fp8_input_tensor_strategy_moe, process_fp8_weight_block_strategy, process_fp8_weight_tensor_strategy, + process_fp8_weight_tensor_strategy_moe, validate_fp8_block_shape, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( @@ -73,7 +71,6 @@ 
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, - prepare_moe_fp8_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, @@ -81,12 +78,10 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, - all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, - per_tensor_dequantize, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -96,11 +91,8 @@ from vllm.model_executor.utils import replace_parameter, set_weight_attrs from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( - is_deep_gemm_e8m0_used, is_deep_gemm_supported, ) -from vllm.utils.flashinfer import has_flashinfer_moe -from vllm.utils.import_utils import has_deep_gemm if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -110,107 +102,6 @@ logger = init_logger(__name__) -class Fp8MoeBackend(Enum): - NONE = 0 - FLASHINFER_TRTLLM = 1 - FLASHINFER_CUTLASS = 2 - DEEPGEMM = 3 - MARLIN = 4 - TRITON = 5 - AITER = 6 - - -def get_fp8_moe_backend( - block_quant: bool, - moe_parallel_config: FusedMoEParallelConfig, - with_lora_support: bool, -) -> Fp8MoeBackend | None: - """ - Select the primary FP8 MoE backend - Note: Shape-specific fallbacks may still occur at runtime. - """ - if current_platform.is_xpu(): - return None - if with_lora_support: - return Fp8MoeBackend.TRITON - # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100. 
- if ( - current_platform.is_cuda() - and ( - current_platform.is_device_capability_family(100) - or current_platform.is_device_capability(90) - ) - and envs.VLLM_USE_FLASHINFER_MOE_FP8 - and has_flashinfer_moe() - ): - backend = get_flashinfer_moe_backend() - if backend == FlashinferMoeBackend.TENSORRT_LLM: - logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100") - return Fp8MoeBackend.FLASHINFER_TRTLLM - else: - if block_quant and current_platform.is_device_capability_family(100): - raise ValueError( - "FlashInfer FP8 MoE throughput backend does not " - "support block quantization on SM100. Please use " - "VLLM_FLASHINFER_MOE_BACKEND=latency " - "instead." - ) - logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100") - return Fp8MoeBackend.FLASHINFER_CUTLASS - - # weight-only path for older GPUs without native FP8 - use_marlin = ( - not current_platform.has_device_capability(89) - or envs.VLLM_TEST_FORCE_FP8_MARLIN - ) - if current_platform.is_rocm(): - use_marlin = False - if use_marlin: - logger.info_once("Using Marlin backend for FP8 MoE") - return Fp8MoeBackend.MARLIN - - # Determine if we should use DeepGEMM with block-quantized weights: - # - If explicitly set by user, respect their choice - # - If not explicitly set (default), disable when TP size is >= 8 - moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM - if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8: - moe_use_deep_gemm = False - logger.info_once( - "DeepGEMM MoE is disabled by default when TP size is >= 8. " - "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.", - scope="local", - ) - - # Determine if we should use DeepGEMM (top-level enable switch) - # - If explicitly set by user, respect their choice - # - If not platform supports DeepGEMM, disable it - # This helps avoid warning messages on unsupported platforms. 
- use_deep_gemm = envs.VLLM_USE_DEEP_GEMM - if not is_deep_gemm_supported(): - use_deep_gemm = False - logger.info_once( - "DeepGEMM is disabled because the platform does not support it.", - scope="local", - ) - - if use_deep_gemm and moe_use_deep_gemm and block_quant: - if not has_deep_gemm(): - logger.warning_once( - "DeepGEMM backend requested but not available.", scope="local" - ) - elif is_deep_gemm_supported(): - logger.info_once("Using DeepGEMM backend for FP8 MoE", scope="local") - return Fp8MoeBackend.DEEPGEMM - - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MOE: - logger.info_once("Using ROCm AITER backend for FP8 MoE", scope="local") - return Fp8MoeBackend.AITER - - # default to Triton - logger.info_once("Using Triton backend for FP8 MoE") - return Fp8MoeBackend.TRITON - - class Fp8Config(QuantizationConfig): """Config class for FP8.""" @@ -348,7 +239,6 @@ def get_quant_method( moe_quant_method = Fp8MoEMethod(self, layer) else: moe_quant_method = Fp8OnlineMoEMethod(self, layer) - moe_quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return moe_quant_method elif isinstance(layer, Attention): return Fp8KVCacheMethod(self) @@ -736,40 +626,24 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): super().__init__(layer.moe_config) - self.layer = layer self.quant_config = quant_config self.weight_block_size = self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None self.weight_scale_name = ( "weight_scale_inv" if self.block_quant else "weight_scale" ) - self.fp8_backend = get_fp8_moe_backend( - self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled + self.fp8_backend = select_fp8_moe_backend( + block_quant=self.block_quant, + tp_size=layer.moe_parallel_config.tp_size, + with_lora_support=self.moe.is_lora_enabled, ) - self.marlin_input_dtype = None - self.flashinfer_moe_backend: FlashinferMoeBackend | None = None - if 
self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: if self.block_quant and self.weight_block_size != [128, 128]: raise NotImplementedError( "FlashInfer CUTLASS FP8 MoE backend only supports block " "size [128, 128]." ) - if not self.block_quant: - if layer.renormalize or layer.custom_routing_function is not None: - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend does custom routing " - f"function or renormalization, but got {layer.renormalize} and " - f"{layer.custom_routing_function}." - ) - if layer.scoring_func != "sigmoid": - raise NotImplementedError( - "FlashInfer CUTLASS FP8 MoE backend only supports " - f"'sigmoid' scoring function, but got {layer.scoring_func}." - ) if layer.activation != "silu": raise NotImplementedError( "FlashInfer CUTLASS FP8 MoE backend only supports SiLU " @@ -778,12 +652,17 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): dynamic_per_token = ( not self.block_quant and self.quant_config.activation_scheme != "static" ) - if self.flashinfer_moe_backend is not None and dynamic_per_token: + if dynamic_per_token and self.fp8_backend in [ + Fp8MoeBackend.FLASHINFER_TRTLLM, + Fp8MoeBackend.FLASHINFER_CUTLASS, + ]: raise NotImplementedError( "FlashInfer FP8 MoE backend does not support dynamic per token " "activation quantization." 
) + self.kernel: mk.FusedMoEModularKernel | None = None + def create_weights( self, layer: Module, @@ -907,148 +786,43 @@ def create_weights( layer.w13_input_scale = None layer.w2_input_scale = None - def _convert_weights_to_kernel_format( + def _setup_kernel( self, layer: Module, - w13_weight: torch.Tensor, - w2_weight: torch.Tensor, - w13_weight_scale: torch.Tensor, - w2_weight_scale: torch.Tensor, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, w13_input_scale: torch.Tensor | None, w2_input_scale: torch.Tensor | None, ) -> None: - if self.fp8_backend == Fp8MoeBackend.DEEPGEMM: - assert self.block_quant - w13_weight, w13_weight_scale = deepgemm_post_process_fp8_weight_block( - wq=w13_weight, - ws=w13_weight_scale, - quant_block_shape=tuple(layer.weight_block_size), - use_e8m0=is_deep_gemm_e8m0_used(), - ) - w2_weight, w2_weight_scale = deepgemm_post_process_fp8_weight_block( - wq=w2_weight, - ws=w2_weight_scale, - quant_block_shape=tuple(layer.weight_block_size), - use_e8m0=is_deep_gemm_e8m0_used(), - ) - elif self.fp8_backend == Fp8MoeBackend.AITER: - w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( - w13_weight, w2_weight - ) - elif self.fp8_backend == Fp8MoeBackend.MARLIN: - ( - workspace, - w13_weight, - w2_weight, - w13_weight_scale, - w2_weight_scale, - ) = prepare_moe_fp8_layer_for_marlin( - layer, - w13_weight, - w2_weight, - w13_weight_scale, - w2_weight_scale, - input_dtype=self.marlin_input_dtype, - ) - layer.workspace = workspace - - elif self.fp8_backend in [ - Fp8MoeBackend.FLASHINFER_CUTLASS, - Fp8MoeBackend.FLASHINFER_TRTLLM, - ]: - w13_weight = swap_w13_to_w31(w13_weight) - if self.block_quant: - w13_weight_scale = swap_w13_to_w31(w13_weight_scale) - else: - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - rotate_flashinfer_fp8_moe_weights(w13_weight, w2_weight) - register_scales_for_trtllm_fp8_per_tensor_moe( - layer=layer, - w13_weight_scale=w13_weight, - 
w13_input_scale=w13_input_scale, - w2_weight_scale=w2_weight, - w2_input_scale=w2_input_scale, - ) - - elif self.fp8_backend == Fp8MoeBackend.AITER: - w13_weight, w2_weight = rocm_aiter_ops.shuffle_weights( - w13_weight, w2_weight - ) + # Shuffle weights to runtime format. + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) # Replace parameters with updated versions. Note that this helper # function ensures the replacement is compatible with RL weight reloads. - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w2_weight", w2_weight) - replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_weight_scale) - replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_weight_scale) - - def _setup_kernel(self, layer: Module) -> None: - """Setup Modular Kernel for TP Case""" - # NOTE(rob): this is a WIP refactor. We are first migrating - # all of the kernels in the TP case to use mk. Once this is - # done, then we will initialzie the TP case and DP/EP case - # via the same code path (i.e. via maybe_init_modular_kernel). - # NOTE(rob): in progress migrating all into this format. - - from vllm.model_executor.layers.fused_moe import ( - TritonOrDeepGemmExperts, - ) - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - FlashInferExperts, - ) - from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( - MarlinExperts, - ) - from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, - ) - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - AiterExperts, - ) - - # Flashinfer TRTLLM does not use the modular kernel abstraction. 
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) + replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) + # Setup modular kernel for TP case. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - assert self.moe_quant_config is not None - self.use_inplace = True - - if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - self.kernel = mk.FusedMoEModularKernel( - # TODO: make defer_input_quant an attr of the FlashInferExperts - MoEPrepareAndFinalizeNoEP(defer_input_quant=self.block_quant), - FlashInferExperts( - out_dtype=layer.orig_dtype, - quant_config=self.moe_quant_config, - ep_rank=self.moe.ep_rank, - ep_size=self.moe.ep_size, - tp_rank=self.moe.tp_rank, - tp_size=self.moe.tp_size, - use_dp=(self.moe.dp_size > 1), - use_deepseek_fp8_block_scale=self.block_quant, - ), - ) - self.use_inplace = False - - elif self.fp8_backend == Fp8MoeBackend.AITER: - self.kernel = mk.FusedMoEModularKernel( - # TODO: make defer_input_quant an attr of the AiterExperts - MoEPrepareAndFinalizeNoEP(defer_input_quant=True), - AiterExperts(quant_config=self.moe_quant_config), - ) - elif self.fp8_backend == Fp8MoeBackend.MARLIN: - self.kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - MarlinExperts(quant_config=self.moe_quant_config), - ) - else: - self.kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonOrDeepGemmExperts( - quant_config=self.moe_quant_config, - allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM), - ), + if self.moe_quant_config: + self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, ) def process_weights_after_loading(self, layer: Module) -> None: @@ -1056,78 +830,58 @@ def process_weights_after_loading(self, 
layer: Module) -> None: return # Allow for accessing weights and scales in standard way. - w13_weight = layer.w13_weight - w2_weight = layer.w2_weight - w13_weight_scale = getattr(layer, f"w13_{self.weight_scale_name}") - w2_weight_scale = getattr(layer, f"w2_{self.weight_scale_name}") + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = getattr(layer, f"w13_{self.weight_scale_name}") + w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") w13_input_scale = layer.w13_input_scale w2_input_scale = layer.w2_input_scale # MI300x and MI325x use FNUZ format for FP8. Convert if needed. if current_platform.is_fp8_fnuz(): - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - w13_weight, w13_weight_scale, w13_input_scale - ) + w13, w13_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w13, + w13_scale, + w13_input_scale, ) - w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( - w2_weight, w2_weight_scale, w2_input_scale + w2, w2_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + w2, + w2_scale, + w2_input_scale, ) # Per tensor kernels require single activation scale. Use the max. if self.quant_config.activation_scheme == "static": assert not self.block_quant assert w13_input_scale is not None and w2_input_scale is not None - if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale): - logger.warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer." 
- ) - replace_parameter(layer, "w13_input_scale", w13_input_scale.max()) - replace_parameter(layer, "w2_input_scale", w2_input_scale.max()) + w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( + w13_input_scale, w2_input_scale + ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_input_scale", w2_input_scale) # Per tensor kernels require single weight scale for w13 per expert, but # on disk there is a scale for w1 and w3. Use the max to requantize. if not self.block_quant: shard_size = layer.intermediate_size_per_partition - max_w13_scales = w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - w13_weight[expert_id][start : start + shard_size, :], - w13_weight_scale[expert_id][shard_id], - ) - w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - w13_weight_scale = max_w13_scales + w13, w13_scale = process_fp8_weight_tensor_strategy_moe( + w13, w13_scale, shard_size, layer.local_num_experts + ) - # Shuffle weights into the runtime format. - self._convert_weights_to_kernel_format( - layer=layer, - w13_weight=w13_weight, - w2_weight=w2_weight, - w13_weight_scale=w13_weight_scale, - w2_weight_scale=w2_weight_scale, - w13_input_scale=w13_input_scale, - w2_input_scale=w2_input_scale, + # Shuffle weights to runtime format and setup kernel. + self._setup_kernel( + layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - # Setup modular kernel for TP case. 
- self._setup_kernel(layer) - def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: - if ( - self.fp8_backend == Fp8MoeBackend.AITER - or self.fp8_backend == Fp8MoeBackend.MARLIN - or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - ): + if self.fp8_backend in [ + Fp8MoeBackend.AITER, + Fp8MoeBackend.MARLIN, + Fp8MoeBackend.FLASHINFER_TRTLLM, + ]: return None elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( @@ -1184,7 +938,7 @@ def select_gemm_impl( ) elif self.moe.is_lora_enabled: return TritonExperts(quant_config=self.moe_quant_config) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: # Select GEMM experts with block-scale when weights are block-quantized experts = select_cutlass_fp8_gemm_impl( self.moe, @@ -1193,17 +947,23 @@ def select_gemm_impl( ) logger.debug_once("Using %s", experts.__class__.__name__) return experts - else: + elif self.fp8_backend == Fp8MoeBackend.DEEPGEMM: logger.debug( "TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s", self.__class__.__name__, self.weight_block_size, False, ) - return TritonOrDeepGemmExperts( - quant_config=self.moe_quant_config, - allow_deep_gemm=(self.fp8_backend == Fp8MoeBackend.DEEPGEMM), + return TritonOrDeepGemmExperts(self.moe_quant_config) + else: + assert self.fp8_backend == Fp8MoeBackend.TRITON + logger.debug( + "TritonExperts(%s): block_size=%s, per_act_token=%s", + self.__class__.__name__, + self.weight_block_size, + False, ) + return TritonExperts(self.moe_quant_config) def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -1212,42 +972,13 @@ def get_fused_moe_quant_config( if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None - # MARLIN uses mixed precision W8A16 config. 
- if self.fp8_backend == Fp8MoeBackend.MARLIN: - return fp8_w8a16_moe_quant_config( - w1_scale=getattr(layer, f"w13_{self.weight_scale_name}"), - w2_scale=getattr(layer, f"w2_{self.weight_scale_name}"), - block_shape=self.weight_block_size, - ) - w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") a1_scale = layer.w13_input_scale a2_scale = layer.w2_input_scale - # Flashinfer CUTLASS per-tensor uses single dq scale - # (alpha = w_scale * a_scale) and inverse a2 scale. - if ( - self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS - and not self.block_quant - ): - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w1_scale, - a1_scale, - w2_scale, - a2_scale, - ) - return fp8_w8a8_moe_quant_config( - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=(1.0 / a2_scale), - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - ) - - # All other backends use normal config. - return fp8_w8a8_moe_quant_config( + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -1269,7 +1000,7 @@ def apply( x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: # TODO(rob): convert this to MK. 
if layer.enable_eplb: raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") @@ -1308,10 +1039,7 @@ def apply( routed_scaling=layer.routed_scaling_factor, ) else: - assert ( - not layer.renormalize and layer.custom_routing_function is not None - ) - result = apply_flashinfer_per_tensor_scale_fp8( + result = apply_fi_trtllm_fp8_per_tensor_moe( layer=layer, hidden_states=x, router_logits=router_logits, @@ -1327,6 +1055,8 @@ def apply( hidden_states=x, router_logits=router_logits, ) + + assert self.kernel is not None result = self.kernel( x, layer.w13_weight, @@ -1358,7 +1088,6 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): assert not quant_config.is_checkpoint_fp8_serialized assert quant_config.activation_scheme == "dynamic" assert quant_config.weight_block_size is None - assert self.flashinfer_moe_backend is None def create_weights( self, @@ -1447,6 +1176,8 @@ def patched_weight_loader(param, loaded_weight, *args, **kwargs): ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) layer.w13_input_scale = None layer.w2_input_scale = None @@ -1457,33 +1188,30 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp16, quantize in place. 
fp8_dtype = current_platform.fp8_dtype() - w13_weight = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) - w2_weight = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) + w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype) + w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype) + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale for expert in range(layer.local_num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight[expert, :, :]) + w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant( + layer.w13_weight[expert, :, :] ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight[expert, :, :]) + w2[expert, :, :], w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight[expert, :, :] ) - replace_parameter(layer, "w13_weight", w13_weight) - replace_parameter(layer, "w2_weight", w2_weight) - # Shuffle weights into the runtime format. - self._convert_weights_to_kernel_format( - layer=layer, - w13_weight=w13_weight, - w2_weight=w2_weight, - w13_weight_scale=layer.w13_weight_scale, - w2_weight_scale=layer.w2_weight_scale, - w13_input_scale=None, - w2_input_scale=None, + # Shuffle weights to runtime format and setup kernel. + self._setup_kernel( + layer, + w13, + w2, + w13_scale, + w2_scale, + layer.w13_input_scale, + layer.w2_input_scale, ) - # Setup modular kernel for TP case. 
- self._setup_kernel(layer) - class Fp8KVCacheMethod(BaseKVCacheMethod): """ diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index b6752d7f9913..bd7a90a80af1 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -15,7 +15,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, - fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe @@ -24,6 +23,13 @@ FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + make_fp8_moe_quant_config, + select_fp8_moe_backend, +) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, @@ -45,19 +51,16 @@ ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, - apply_flashinfer_per_tensor_scale_fp8, + apply_fi_trtllm_fp8_per_tensor_moe, build_flashinfer_fp8_cutlass_moe_prepare_finalize, - flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, - make_fp8_moe_alpha_scales_for_fi, - register_scales_for_trtllm_fp8_per_tensor_moe, - rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, - swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, + process_fp8_input_tensor_strategy_moe, + process_fp8_weight_tensor_strategy_moe, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, @@ -85,13 +88,12 @@ ModelWeightParameter, PerTensorScaleParameter, ) +from vllm.model_executor.utils import replace_parameter from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, 
has_flashinfer, - has_flashinfer_moe, ) -from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -721,38 +723,23 @@ def __init__( layer: FusedMoE, ) -> None: super().__init__(layer.moe_config) - self.layer = layer self.quant_config = quant_config - from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - cutlass_fp8_supported, + assert self.quant_config.is_checkpoint_fp8_serialized + self.fp8_backend = select_fp8_moe_backend( + block_quant=False, + tp_size=layer.moe_parallel_config.tp_size, + with_lora_support=self.moe.is_lora_enabled, ) - - self.cutlass_fp8_supported = cutlass_fp8_supported() - self.flashinfer_moe_backend: FlashinferMoeBackend | None = None - if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): - self.flashinfer_moe_backend = get_flashinfer_moe_backend() - if ( - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - and not self.moe.is_act_and_mul - ): - logger.info_once( - "Non-gated MoE is not supported for min-latency mode," - "falling back to high-throughput mode" - ) - self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS - - logger.info_once( - f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" - ) + self.kernel: mk.FusedMoEModularKernel | None = None def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: # TRT LLM not supported with all2all yet. - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: return None - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: # TP case: avoid convert to ModularKernelMethod - to be refactored. 
if self.moe.dp_size == 1: return None @@ -787,6 +774,9 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + layer.orig_dtype = params_dtype + layer.num_experts = num_experts + # Use FP8 dtype if checkpoint is serialized weight_dtype = ( torch.float8_e4m3fn @@ -826,217 +816,121 @@ def create_weights( ) layer.register_parameter("w2_weight", w2_weight) - if self.quant_config.is_checkpoint_fp8_serialized: - # WEIGHT SCALES - Per-tensor scaling for ModelOpts - # For gated MoE, allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - # For non-gated MoE, allocate 1 scale for w13. - if self.moe.is_act_and_mul: - w13_weight_scale_shape = (num_experts, 2) - else: - w13_weight_scale_shape = (num_experts, 1) - w13_weight_scale = PerTensorScaleParameter( - data=torch.full( - w13_weight_scale_shape, - 1.0, - dtype=torch.float32, - ), - weight_loader=weight_loader, - ) - w2_weight_scale = PerTensorScaleParameter( - data=torch.full((num_experts,), 1.0, dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - - # Set weight loader attributes for scales - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - - # INPUT SCALES - Per-tensor scaling for ModelOpt - w13_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts,), 1.0, dtype=torch.float32), - weight_loader=weight_loader, - ) - w2_input_scale = PerTensorScaleParameter( - data=torch.full((num_experts,), 1.0, dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - layer.register_parameter("w2_input_scale", w2_input_scale) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - """Process FP8 MoE weights after loading from serialized checkpoint. 
- Only supports pre-quantized checkpoints with FP8 weights and scales. - """ - - if self.flashinfer_moe_backend is not None: - self._maybe_pad_intermediate_for_flashinfer(layer) - - layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) - layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + # WEIGHT SCALES - Per-tensor scaling for ModelOpts + # For gated MoE, allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. + w13_weight_scale = PerTensorScaleParameter( + data=torch.full( + (num_experts, 2 if self.moe.is_act_and_mul else 1), + 1.0, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + w2_weight_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) - from vllm._custom_ops import scaled_fp8_quant - from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - per_tensor_dequantize, + # INPUT SCALES - Per-tensor scaling for ModelOpt + w13_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + w2_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, ) + layer.register_parameter("w13_input_scale", w13_input_scale) + layer.register_parameter("w2_input_scale", w2_input_scale) - # Handle scale parameters - if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None: - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max of the w1 and w3 scales - # then dequant and requant each expert. 
- if ( - layer.w13_weight_scale.dim() == 2 - and layer.w13_weight_scale.shape[1] == 2 - ): - assert self.moe.is_act_and_mul, ( - "w13_weight_scale should have 2 elements per expert " - "only for gated MoE" - ) - # Get the maximum scale across w1 and w3 for each expert - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - - # Requantize each expert's weights using the combined scale - # w13_weight (num_experts, 2 * intermediate_size, hidden_size) - # where the first intermediate_size rows are w1, the next are w3 - intermediate_size = layer.w13_weight.shape[1] // 2 - for expert_id in range(layer.w13_weight.shape[0]): - start = 0 - for shard_id in range(2): # w1 and w3 - # Dequantize using the original scale for this shard - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][ - start : start + intermediate_size, : - ], - layer.w13_weight_scale[expert_id][shard_id], - ) - # Requantize using the combined max scale - - ( - layer.w13_weight[expert_id][ - start : start + intermediate_size, : - ], - _, - ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - - start += intermediate_size - - # Update the scale parameter to be per-expert - layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) - else: - layer.w13_weight_scale = Parameter( - layer.w13_weight_scale.data, requires_grad=False - ) + def _setup_kernel( + self, + layer: torch.nn.Module, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + ): + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) - if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None: - layer.w2_weight_scale = Parameter( - layer.w2_weight_scale.data, requires_grad=False - ) - # Input 
scales must be equal for each expert in fp8 MoE layers. - if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None: - layer.w13_input_scale = Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None: - layer.w2_input_scale = Parameter( - layer.w2_input_scale.max(), requires_grad=False + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + + # Setup modular kernel for TP case. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + if self.moe_quant_config: + self.kernel, self.use_inplace = make_fp8_moe_kernel( + layer=layer, + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, ) - if self.flashinfer_moe_backend is not None: - if self.moe.is_act_and_mul: - layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) - - # NOTE: this adds some attributes used by the trtllm kernel, - # which does not conform to the modular kernels abstraction (yet). - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) - register_scales_for_trtllm_fp8_per_tensor_moe( - layer=layer, - w13_weight_scale=layer.w13_weight_scale, - w13_input_scale=layer.w13_input_scale, - w2_weight_scale=layer.w2_weight_scale, - w2_input_scale=layer.w2_input_scale, - ) - - def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: - """Pad intermediate size so FlashInfer kernels' alignment constraints hold. - - Some FlashInfer FP8 MoE kernels require the (gated) intermediate size - used for GEMM to be divisible by a small alignment value. 
When this is - not satisfied (e.g. with certain tensor-parallel sizes), we pad the - gate/up and down projection weights along the intermediate dim. - """ - if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"): - return - - # Current local intermediate size (per partition) is the K dimension of - # the down projection. - num_experts, hidden_size, intermediate = layer.w2_weight.shape - - min_alignment = 16 - padded_intermediate = round_up(intermediate, min_alignment) - - if padded_intermediate == intermediate: - return - - logger.info( - "Padding intermediate size from %d to %d for up/down projection weights.", - intermediate, - padded_intermediate, + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_input_scale = layer.w13_input_scale + w2_input_scale = layer.w2_input_scale + + # Per tensor kernels require single activation scale. Use the max. + w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( + w13_input_scale, w2_input_scale + ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_input_scale", w2_input_scale) + + # Per tensor kernels require single weight scale for w13 per expert, but + # on disk there is a scale for w1 and w3. Use the max to requantize. + shard_size = layer.intermediate_size_per_partition + w13, w13_scale = process_fp8_weight_tensor_strategy_moe( + w13, + w13_scale, + shard_size, + num_experts=layer.w13_weight.shape[0], + is_act_and_mul=self.moe.is_act_and_mul, ) - up_mult = 2 if self.moe.is_act_and_mul else 1 - padded_gate_up_dim = up_mult * padded_intermediate - - # Pad w13 and w12 along its intermediate dimension. 
- w13 = layer.w13_weight.data - padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) - padded_w13[:, : w13.shape[1], :] = w13 - layer.w13_weight.data = padded_w13 - - w2 = layer.w2_weight.data - padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) - padded_w2[:, :, :intermediate] = w2 - layer.w2_weight.data = padded_w2 - - if hasattr(layer, "intermediate_size_per_partition"): - layer.intermediate_size_per_partition = padded_intermediate + # Shuffle weights to runtime format and setup kernel. + self._setup_kernel( + layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - # TRTLLM does not use modular kernels - return None - - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - layer.w13_weight_scale, - layer.w13_input_scale, - layer.w2_weight_scale, - layer.w2_input_scale, - ) - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - a1_gscale=(1.0 / layer.w13_input_scale), - a2_gscale=(1.0 / layer.w2_input_scale), - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, - ) - else: - assert self.flashinfer_moe_backend is None - return fp8_w8a8_moe_quant_config( - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) def apply( self, @@ -1044,17 +938,18 @@ def apply( x: torch.Tensor, 
router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: + if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: if layer.enable_eplb: raise NotImplementedError( - "EPLB not supported for `ModelOptFp8MoEMethod` yet." + "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." ) + # TODO(rob): this validation should happen at kernel selection + # time in the oracle rather than here. assert layer.activation == "silu", ( f"Expected 'silu' activation but got {layer.activation}" ) - assert not layer.renormalize - return apply_flashinfer_per_tensor_scale_fp8( + return apply_fi_trtllm_fp8_per_tensor_moe( layer=layer, hidden_states=x, router_logits=router_logits, @@ -1066,46 +961,34 @@ def apply( apply_router_weight_on_input=layer.apply_router_weight_on_input, ) - # Expert selection topk_weights, topk_ids = layer.select_experts( hidden_states=x, router_logits=router_logits, ) - if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + # TODO(rob): this validation should happen at kernel selection + # time in the oracle rather than here. 
+ if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: assert layer.activation in ("silu", "relu2_no_mul"), ( "Expected activation to be in ('silu', 'relu2_no_mul')," f"but got {layer.activation}" ) - return flashinfer_cutlass_moe_fp8( - x, - layer, - topk_weights, - topk_ids, - inplace=False, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - else: - from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts - assert self.moe_quant_config is not None + assert self.kernel is not None + result = self.kernel( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=self.use_inplace, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + ) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=layer.activation, - quant_config=self.moe_quant_config, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) + return result ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 15d37f7d366b..4ab618dc44ef 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -22,7 +22,7 @@ ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - prepare_moe_fp8_layer_for_marlin, + prepare_fp8_moe_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( 
OCP_MX_BLOCK_SIZE, @@ -315,8 +315,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) elif self.use_marlin: - (workspace, w13_weight, w2_weight, w13_weight_scale, w2_weight_scale) = ( - prepare_moe_fp8_layer_for_marlin( + w13_weight, w2_weight, w13_weight_scale, w2_weight_scale = ( + prepare_fp8_moe_layer_for_marlin( layer, layer.w13_weight, layer.w2_weight, @@ -324,7 +324,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_weight_scale, ) ) - layer.workspace = workspace # TODO(rob): once we apply refactor to Quark, switch to using # replace_parameter for compatibility with reloading in RL. layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index b73c44b3130d..a9b30b780587 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -18,6 +18,7 @@ create_flashinfer_prepare_finalize, ) from vllm.platforms import current_platform +from vllm.utils.math_utils import round_up logger = init_logger(__name__) @@ -58,9 +59,10 @@ def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: ) -def rotate_flashinfer_fp8_moe_weights( +def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( gemm1_weights: torch.Tensor, gemm2_weights: torch.Tensor ): + """Shuffle weights for for FI TRT-LLM Format""" from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_a epilogue_tile_m = 128 @@ -105,16 +107,16 @@ def rotate_flashinfer_fp8_moe_weights( def register_scales_for_trtllm_fp8_per_tensor_moe( layer: torch.nn.Module, - w13_weight_scale: torch.Tensor, + w13_scale: torch.Tensor, w13_input_scale: torch.Tensor, - w2_weight_scale: torch.Tensor, + w2_scale: torch.Tensor, w2_input_scale: torch.Tensor, ) 
-> None: """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=w13_weight_scale, + w13_scale=w13_scale, w13_input_scale=w13_input_scale, - w2_scale=w2_weight_scale, + w2_scale=w2_scale, w2_input_scale=w2_input_scale, ) layer.w2_input_scale_inv = 1.0 / w2_input_scale @@ -123,7 +125,7 @@ def register_scales_for_trtllm_fp8_per_tensor_moe( layer.output2_scales_scalar = g2_alphas -def apply_flashinfer_per_tensor_scale_fp8( +def apply_fi_trtllm_fp8_per_tensor_moe( layer: torch.nn.Module, hidden_states: torch.Tensor, router_logits: torch.Tensor, @@ -139,16 +141,23 @@ def apply_flashinfer_per_tensor_scale_fp8( import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 from vllm.model_executor.models.llama4 import Llama4MoE + # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe assert ( hasattr(layer, "output1_scales_scalar") and hasattr(layer, "output1_scales_gate_scalar") and hasattr(layer, "output2_scales_scalar") ) - assert layer.custom_routing_function == Llama4MoE.custom_routing_function, ( - "FusedMoE flashinfer kernels are only supported for Llama4" + # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe + assert ( + hasattr(layer, "output1_scales_scalar") + and hasattr(layer, "output1_scales_gate_scalar") + and hasattr(layer, "output2_scales_scalar") ) - return torch.ops.vllm.flashinfer_fused_moe_per_tensor_scale_fp8( + + is_llama4 = layer.custom_routing_function == Llama4MoE.custom_routing_function + assert is_llama4, "FusedMoE flashinfer kernels are only supported for Llama4" + return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( routing_logits=router_logits, routing_bias=routing_bias, hidden_states=hidden_states, @@ -221,50 +230,6 @@ def select_cutlass_fp8_gemm_impl( ) -def flashinfer_cutlass_moe_fp8( - hidden_states: torch.Tensor, - layer: torch.nn.Module, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - 
inplace: bool = False, - activation: str = "silu", - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - use_deepseek_fp8_block_scale: bool = False, - moe: FusedMoEConfig | None = None, -) -> torch.Tensor: - quant_config = layer.quant_method.get_fused_moe_quant_config(layer) - assert quant_config is not None - - # Construct modular kernel with block-scale support when requested. - fused_experts = mk.FusedMoEModularKernel( - build_flashinfer_fp8_cutlass_moe_prepare_finalize( - moe=moe, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale - ), - select_cutlass_fp8_gemm_impl( - moe=moe, - quant_config=quant_config, - out_dtype=hidden_states.dtype, - use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, - ), - moe_parallel_config=layer.moe_parallel_config, - ) - - return fused_experts( - hidden_states, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=inplace, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, @@ -301,3 +266,104 @@ def is_flashinfer_supporting_global_sf(backend: FlashinferMoeBackend | None) -> FlashinferMoeBackend.TENSORRT_LLM, ) return backend in backends_supporting_global_sf + + +def align_fp8_moe_weights_for_fi( + w13: torch.Tensor, w2: torch.Tensor, is_act_and_mul: bool +) -> tuple[torch.Tensor, torch.Tensor, int]: + """Pad intermediate size so FlashInfer kernels' alignment constraints hold. + + Some FlashInfer FP8 MoE kernels require the (gated) intermediate size + used for GEMM to be divisible by a small alignment value. When this is + not satisfied (e.g. with certain tensor-parallel sizes), we pad the + gate/up and down projection weights along the intermediate dim. 
+ """ + + # Current local intermediate size (per partition) is the K dimension of + # the down projection. + num_experts, hidden_size, intermediate = w2.shape + + min_alignment = 16 + padded_intermediate = round_up(intermediate, min_alignment) + + if padded_intermediate == intermediate: + return w13, w2, intermediate + + logger.info_once( + "Padding intermediate size from %d to %d for up/down projection weights.", + intermediate, + padded_intermediate, + scope="local", + ) + + up_mult = 2 if is_act_and_mul else 1 + padded_gate_up_dim = up_mult * padded_intermediate + + # Pad w13 and w2 along its intermediate dimension. + padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) + padded_w13[:, : w13.shape[1], :] = w13 + + padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) + padded_w2[:, :, :intermediate] = w2 + + return padded_w13, padded_w2, padded_intermediate + + +def prepare_fp8_moe_layer_for_fi( + layer: torch.nn.Module, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w13_input_scale: torch.Tensor | None, + w2_scale: torch.Tensor, + w2_input_scale: torch.Tensor | None, + is_trtllm: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Convert Fp8 MoE weights to flashinfer kernel format + + Note that for trtllm we update the model state dict + with the scale format needed for these kernels. + + Note that for per-tensor, we update the layer's + intermediate size if the weights needed padding. + """ + + assert hasattr(layer.moe_config, "is_act_and_mul") + block_quant = ( + hasattr(layer, "weight_block_size") and layer.weight_block_size is not None + ) + + # Some FI MoE kernels require internal alignment of 16 + # for the gate-up proj. Pad the weights to respect this. 
+ if not block_quant: + w13, w2, new_intermediate = align_fp8_moe_weights_for_fi( + w13, + w2, + layer.moe_config.is_act_and_mul, + ) + layer.intermediate_size_per_partition = new_intermediate + + # FI kernels require W31 layout rather than W13. + if layer.moe_config.is_act_and_mul: + w13 = swap_w13_to_w31(w13) + if block_quant: + w13_scale = swap_w13_to_w31(w13_scale) + + # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle + # and registration of alpha scales. Note that we do not register + # as nn.Parameters since they are not needed for weight-reloading. + if is_trtllm and not block_quant: + assert w13_input_scale is not None + assert w2_input_scale is not None + + rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2) + register_scales_for_trtllm_fp8_per_tensor_moe( + layer, + w13_scale=w13_scale, + w13_input_scale=w13_input_scale, + w2_scale=w2_scale, + w2_input_scale=w2_input_scale, + ) + + return w13, w2, w13_scale diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 1c9e36e02248..9d74becd588d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -21,6 +21,8 @@ ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED, + all_close_1d, + per_tensor_dequantize, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, @@ -1350,6 +1352,29 @@ def deepgemm_post_process_fp8_weight_block( return wq, dg_ws +def prepare_fp8_moe_layer_for_deepgemm( + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + block_shape: tuple[int], +): + w13, w13_scale = deepgemm_post_process_fp8_weight_block( + wq=w13, + ws=w13_scale, + quant_block_shape=block_shape, + use_e8m0=is_deep_gemm_e8m0_used(), + ) + w2, w2_scale = deepgemm_post_process_fp8_weight_block( + wq=w2, + ws=w2_scale, + 
quant_block_shape=block_shape, + use_e8m0=is_deep_gemm_e8m0_used(), + ) + + return w13, w2, w13_scale, w2_scale + + def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor: """Pad the weight tensor. This is an optimization on ROCm platform, which can benefit from tensors located far enough from one another in memory""" @@ -1584,7 +1609,49 @@ def maybe_post_process_fp8_weight_block(layer: torch.nn.Module): replace_parameter(layer, scale_attr, dg_weight_scale) -def expert_weight_is_col_major(x: torch.Tensor) -> bool: - assert x.dim() == 3 - b, m, n = x.shape - return x.stride(0) == m * n and x.stride(1) == 1 and x.stride(2) == m +def process_fp8_weight_tensor_strategy_moe( + weight: torch.Tensor, + weight_scales: torch.Tensor, + shard_size: int, + num_experts: int, + is_act_and_mul: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + """Process moe weights for tensor-wise quantization strategy.""" + max_scales = weight_scales.max(dim=1).values + + # For w1 case (i.e. not w13): just collapse the last dim since + # there is already just one scale per expert in this case. + if not is_act_and_mul: + assert weight_scales.shape[1] == 1 + return weight, weight_scales.max() + + # For w13 case (common): require single scale for w13 per expert, but + # on disk there is a scale for w1 and w3. Use the max to requantize. 
+ for expert_id in range(num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + weight[expert_id][start : start + shard_size, :], + weight_scales[expert_id][shard_id], + ) + weight[expert_id][start : start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_scales[expert_id] + ) + start += shard_size + return weight, max_scales + + +def process_fp8_input_tensor_strategy_moe( + w13_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Process moe input scales for tensor-wise quantization strategy.""" + + if not all_close_1d(w13_input_scale) or not all_close_1d(w2_input_scale): + logger.info_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." + ) + + return w13_input_scale.max(), w2_input_scale.max() diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0c2f984f9f0f..0e21c81f70f8 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -496,7 +496,7 @@ def get__quant_fp8_method() -> QuantFP8: return _quant_fp8_method -def get_marlin_input_dtype(prefix): +def get_marlin_input_dtype(prefix: str | None = None): if envs.VLLM_MARLIN_INPUT_DTYPE is None: return elif envs.VLLM_MARLIN_INPUT_DTYPE.lower() == "int8": diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py index a8d9db224860..91b93c76cb32 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.marlin_utils import ( USE_FP32_REDUCE_DEFAULT, + get_marlin_input_dtype, 
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, @@ -197,26 +198,28 @@ def prepare_fp8_layer_for_marlin( replace_parameter(layer, "bias", bias) -def prepare_moe_fp8_layer_for_marlin( +def prepare_fp8_moe_layer_for_marlin( layer: torch.nn.Module, w13_weight: torch.Tensor, w2_weight: torch.Tensor, w13_weight_scale: torch.Tensor, w2_weight_scale: torch.Tensor, - input_dtype: torch.dtype | None = None, -) -> tuple[ - torch.Tensor, # workspace - torch.Tensor, # w13_weight - torch.Tensor, # w2_weight - torch.Tensor, # w13_weight_scale - torch.Tensor, # w2_weight_scale -]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Shuffle weights and scales into marlin format. + + Note that this function has the side effect of adding a `workspace` + attribute to the layer. This `workspace` does not need to be + registered as a Parameter as it is not used during weight reloading. + """ + logger.warning_once( "Your GPU does not have native support for FP8 computation but " "FP8 quantization is being used. Weight-only FP8 compression will " "be used leveraging the Marlin kernel. This may degrade " "performance for compute-heavy workloads." ) + input_dtype = get_marlin_input_dtype() if input_dtype is not None and input_dtype.itemsize == 1: raise NotImplementedError("Marlin W8A8 is not supported.") @@ -227,7 +230,9 @@ def prepare_moe_fp8_layer_for_marlin( # WORKSPACE device = layer.w13_weight.device - workspace = marlin_make_workspace_new(device, 4) + # NOTE(rob): we do not need to register the workspace as a param + # because it is not used as part of the weight reloading process. 
+ layer.workspace = marlin_make_workspace_new(device, 4) perm = torch.empty(0, dtype=torch.int, device=device) # WEIGHT @@ -310,13 +315,7 @@ def permute_scales(scales: torch.Tensor, name: str) -> torch.Tensor: w13_weight_scale = permute_scales(w13_weight_scale, "w13") w2_weight_scale = permute_scales(w2_weight_scale, "w2") - return ( - workspace, - w13_weight, - w2_weight, - w13_weight_scale, - w2_weight_scale, - ) + return w13_weight, w2_weight, w13_weight_scale, w2_weight_scale def pack_fp8_to_int32(