diff --git a/tests/compile/correctness_e2e/test_async_tp.py b/tests/compile/correctness_e2e/test_async_tp.py index 3539e4d5abb4..932e513258d0 100644 --- a/tests/compile/correctness_e2e/test_async_tp.py +++ b/tests/compile/correctness_e2e/test_async_tp.py @@ -13,6 +13,17 @@ from vllm.config import ( CompilationMode, ) +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer + +NVFP4_MODEL_ID = "nvidia/Llama-3.1-8B-Instruct-NVFP4" +NVFP4_HF_OVERRIDES = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, +} @create_new_process_for_each_test() @@ -82,3 +93,65 @@ def test_async_tp_pass_correctness( ] compare_two_settings(model_id, async_tp_args, tp_args, method="generate") + + +@create_new_process_for_each_test() +def test_async_tp_pass_nvfp4_correctness(num_gpus_available: int, monkeypatch): + if ( + not current_platform.is_cuda() + or not current_platform.is_device_capability_family(100) + ): + pytest.skip("NVFP4 requires Blackwell") + if not has_flashinfer(): + pytest.skip("FlashInfer is required for the NVFP4 AsyncTP path") + + monkeypatch.setenv("VLLM_NVFP4_GEMM_BACKEND", "flashinfer-cutlass") + + tp_size = 2 + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs") + + common_args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + "--load-format", + "dummy", + "--hf-overrides", + json.dumps(NVFP4_HF_OVERRIDES), + ] + + compilation_config = { + "mode": CompilationMode.VLLM_COMPILE, + "compile_sizes": [2, 4, 8], + "splitting_ops": [], + "pass_config": { + "enable_sp": True, + "fuse_gemm_comms": True, + "fuse_allreduce_rms": False, + "sp_min_token_num": 1, + }, + } + + async_tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + "--compilation_config", + json.dumps(compilation_config), + ] + + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + compare_two_settings(NVFP4_MODEL_ID, async_tp_args, tp_args, method="generate") diff --git a/tests/compile/correctness_e2e/test_sequence_parallel.py b/tests/compile/correctness_e2e/test_sequence_parallel.py index 4b7cb814e74a..56d18b5d0e22 100644 --- a/tests/compile/correctness_e2e/test_sequence_parallel.py +++ b/tests/compile/correctness_e2e/test_sequence_parallel.py @@ -21,12 +21,14 @@ from vllm.platforms import current_platform from vllm.utils.torch_utils import is_torch_equal_or_newer -from ...models.registry import HF_EXAMPLE_MODELS +from ...models.registry import HF_EXAMPLE_MODELS, _HfExamplesInfo from ...utils import compare_two_settings, create_new_process_for_each_test logger = init_logger("test_sequence_parallel") VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" +NVFP4_MODEL_ID = "nvidia/Llama-3.1-8B-Instruct-NVFP4" +NVFP4_MODEL_INFO = _HfExamplesInfo(NVFP4_MODEL_ID) class ParallelSetup(NamedTuple): @@ -41,6 +43,7 @@ class ParallelSetup(NamedTuple): class SPTestOptions(NamedTuple): multi_node_only: bool load_format: str | None = None + model_info: _HfExamplesInfo | None = None @dataclass @@ -170,6 +173,7 @@ def _compare_sp( *, method: Literal["generate", "encode"], is_multimodal: bool, + dtype: str = "float16", ): ( tp_size, @@ -180,14 +184,15 @@ def _compare_sp( chunked_prefill, ) = parallel_setup - multi_node_only, load_format = test_options + multi_node_only = test_options.multi_node_only + load_format = test_options.load_format - model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info = test_options.model_info or HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info.check_transformers_version(on_fail="skip") trust_remote_code = model_info.trust_remote_code tokenizer_mode = model_info.tokenizer_mode - hf_overrides = model_info.hf_overrides + hf_overrides = dict(model_info.hf_overrides) require_embed_inputs = model_info.require_embed_inputs if load_format == "dummy": @@ -220,7 +225,7 @@ def _compare_sp( common_args = [ # use half precision for speed and memory savings in CI environment "--dtype", - "float16", + dtype, "--max-model-len", "2048", "--max-num-seqs", @@ -352,3 +357,37 @@ def test_tp_sp_generation( method="generate", is_multimodal=False, ) + + +@create_new_process_for_each_test() +def test_tp_sp_nvfp4_generation(num_gpus_available: int): + if ( + not current_platform.is_cuda() + or not current_platform.is_device_capability_family(100) + ): + pytest.skip("NVFP4 requires Blackwell") + + _compare_sp( + NVFP4_MODEL_ID, + ParallelSetup( + tp_size=2, + pp_size=1, + fuse_norm_quant=True, + fuse_act_quant=True, + eager_mode=True, + chunked_prefill=False, + ), + "mp", + "auto", + SPTestOptions( + multi_node_only=False, + load_format="dummy", + model_info=NVFP4_MODEL_INFO, + ), + num_gpus_available, + use_inductor_graph_partition=False, + fuse_gemm_comms=False, + method="generate", + is_multimodal=False, + dtype="bfloat16", + ) diff --git a/tests/compile/fullgraph/test_toy_llama.py b/tests/compile/fullgraph/test_toy_llama.py index 915fbc6ce7f3..69c758702e8a 100644 --- a/tests/compile/fullgraph/test_toy_llama.py +++ b/tests/compile/fullgraph/test_toy_llama.py @@ -17,7 +17,6 @@ import torch from torch import nn -from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import ( CompilationConfig, @@ -340,6 +339,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor: def test_toy_llama( backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path ): + from vllm.compilation.counter import compilation_counter + # We disable the vLLM compile cache into a new tmp dir for 1 reason: # 1. To make sure we can properly track the number of Inductor compilations. monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index baa7bdef0a7d..a22c68f4bf92 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -13,11 +13,13 @@ AttentionBackendCase, Matches, custom_ops_combos, + is_blackwell, ) from .models import ( FLASHINFER_ATTN, TRITON_ATTN, llama3_8b, + llama3_8b_fp4, llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b, @@ -90,6 +92,69 @@ def test_tp2_async_tp_fp8_fusions( ) +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "model_name, matches_fn, model_kwargs, hf_overrides", + [llama3_8b_fp4], +) +@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN]) +@pytest.mark.parametrize("n_layers", [4]) +@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm")) +@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) +@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4") +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") +def test_tp2_async_tp_nvfp4_fusions( + model_name: str, + matches_fn: Callable[[int], Matches], + model_kwargs: dict, + hf_overrides: Callable[[int], dict], + attn_backend: AttentionBackendCase, + n_layers: int, + custom_ops: str, + inductor_graph_partition: bool, + run_e2e_fusion_test, +): + # NVFP4 currently wires the all-gather + GEMM path only. + matches = matches_fn(n_layers)._replace(async_tp=n_layers * 2) + + # Reduce size of model and skip weight loading time + model_kwargs["hf_overrides"] = hf_overrides(n_layers) + model_kwargs["load_format"] = "dummy" + model_kwargs["max_model_len"] = 1024 + model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False} + + compilation_config = dict( + use_inductor_graph_partition=inductor_graph_partition, + custom_ops=custom_ops.split(","), + pass_config=PassConfig( + fuse_act_quant=True, + fuse_attn_quant=True, + enable_sp=True, + fuse_gemm_comms=True, + fuse_allreduce_rms=False, + # Override threshold for testing (models have small hidden_size) + sp_min_token_num=512, + ), + ) + + matches_check = [ + "act_quant_fusion", + "attn_quant_fusion", + "sequence_parallel", + "async_tp", + ] + + run_e2e_fusion_test( + model_name, + matches, + model_kwargs, + attn_backend, + compilation_config, + matches_check, + tp_size=2, + ) + + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( "model_name, matches_fn, model_kwargs, hf_overrides", diff --git a/vllm/compilation/passes/fusion/collective_fusion.py b/vllm/compilation/passes/fusion/collective_fusion.py index 2b74eae8dd32..29d79c9b92ce 100644 --- a/vllm/compilation/passes/fusion/collective_fusion.py +++ b/vllm/compilation/passes/fusion/collective_fusion.py @@ -74,6 +74,36 @@ def _flashinfer_scaled_mm_out( ) +def _flashinfer_fp4_mm_out( + A: torch.Tensor, + B: torch.Tensor, + *, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out: torch.Tensor, + alpha: torch.Tensor, + out_dtype: torch.dtype | None = None, + use_8x4_sf_layout: bool = False, + backend: str = "cutlass", +) -> None: + from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm_out + + assert A.ndim == 2 and B.ndim == 2 and out.ndim == 2, ( + "FlashInfer FP4 symm_mem adapter expects 2D inputs and output" + ) + flashinfer_scaled_fp4_mm_out( + A, + B, + scale_a, + scale_b, + alpha, + out=out, + out_dtype=out_dtype or out.dtype, + use_8x4_sf_layout=use_8x4_sf_layout, + backend=backend, + ) + + def fused_flashinfer_scaled_matmul_reduce_scatter_fake( A: torch.Tensor, B: torch.Tensor, @@ -197,6 +227,90 @@ def fused_all_gather_flashinfer_scaled_matmul( return outputs[0] +def fused_all_gather_flashinfer_fp4_matmul_fake( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale_shard: torch.Tensor, + B_scale: torch.Tensor, + alpha: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, + view_a_scale_as_fp8: bool = False, + use_8x4_sf_layout: bool = False, + backend: str = "cutlass", +) -> torch.Tensor: + world_size = c10d._resolve_process_group(group_name).size() + output_shape = list(A_shard.shape) + output_shape[gather_dim] *= world_size + output_shape[-1] = B.shape[1] + return torch.empty( + output_shape, + dtype=out_dtype or torch.bfloat16, + device=A_shard.device, + ) + + +def fused_all_gather_flashinfer_fp4_matmul( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale_shard: torch.Tensor, + B_scale: torch.Tensor, + alpha: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, + view_a_scale_as_fp8: bool = False, + use_8x4_sf_layout: bool = False, + backend: str = "cutlass", +) -> torch.Tensor: + assert gather_dim == 0, ( + "FlashInfer FP4 symm_mem adapter currently only supports gather_dim=0" + ) + assert A_shard.ndim == 2 and A_scale_shard.ndim == 2 and B.ndim == 2, ( + "FlashInfer FP4 symm_mem adapter expects 2D inputs" + ) + if view_a_scale_as_fp8: + A_scale_shard = A_scale_shard.view(torch.float8_e4m3fn) + + group = c10d._resolve_process_group(group_name) + world_size = group.size() + output = A_shard.new_empty( + A_shard.shape[0] * world_size, + B.shape[1], + dtype=out_dtype or torch.bfloat16, + ) + output_shards = output.chunk(world_size) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_scale = A_scale_shard.new_empty( + A_scale_shard.shape[0] * world_size, + A_scale_shard.shape[1], + ) + + def fp4_shard_consumer(shards: list[torch.Tensor], rank: int) -> None: + _flashinfer_fp4_mm_out( + shards[0], + B, + scale_a=shards[1], + scale_b=B_scale, + alpha=alpha, + out=output_shards[rank], + out_dtype=out_dtype, + use_8x4_sf_layout=use_8x4_sf_layout, + backend=backend, + ) + + torch.distributed._symmetric_memory._pipelined_multi_all_gather_and_consume( + [A_shard, A_scale_shard], + fp4_shard_consumer, + [A, A_scale], + group_name, + False, + ) + return output + + direct_register_custom_op( op_name="fused_flashinfer_scaled_matmul_reduce_scatter", op_func=fused_flashinfer_scaled_matmul_reduce_scatter, @@ -209,6 +323,12 @@ def fused_all_gather_flashinfer_scaled_matmul( fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake, ) +direct_register_custom_op( + op_name="fused_all_gather_flashinfer_fp4_matmul", + op_func=fused_all_gather_flashinfer_fp4_matmul, + fake_impl=fused_all_gather_flashinfer_fp4_matmul_fake, +) + class BasePattern: def __init__(self, dtype: torch.dtype, device: str | None) -> None: @@ -682,6 +802,101 @@ def _replacement( return _replacement +class FlashInferAllGatherFP4Pattern( + BasePattern, VllmPatternReplacement[..., torch.Tensor] +): + def __init__( + self, + dtype: torch.dtype, + device: str | None, + backend: str, + use_8x4_sf_layout: bool, + a_scale_view: str, + ) -> None: + super().__init__(dtype, device) + self.backend = backend + self.use_8x4_sf_layout = use_8x4_sf_layout + self.a_scale_view = a_scale_view + + def get_inputs(self) -> list[torch.Tensor]: + a_shard_2d = torch.empty([8, 8], device=self.device, dtype=torch.uint8) + b_2d = torch.empty([8, 16], device=self.device, dtype=torch.uint8) + a_scale_shard = torch.empty([128, 4], device=self.device, dtype=torch.int32) + b_scale = torch.empty([4, 128], device=self.device, dtype=torch.uint8) + alpha = torch.empty([], device=self.device, dtype=torch.float32) + return [ + a_shard_2d, + b_2d, + a_scale_shard, + b_scale, + alpha, + ] + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale_shard: torch.Tensor, + b_scale: torch.Tensor, + alpha: torch.Tensor, + ) -> torch.Tensor: + all_gather_a = torch.ops.vllm.all_gather.default( + a_shard_2d, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + all_gather_a_scale = torch.ops.vllm.all_gather.default( + a_scale_shard, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + a_scale = all_gather_a_scale + if self.a_scale_view in ("float8", "float8_uint8"): + a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn) + if self.a_scale_view in ("uint8", "float8_uint8"): + a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8) + return torch.ops.vllm.flashinfer_mm_fp4.default( + all_gather_a, + b_2d, + a_scale, + b_scale, + alpha, + self.dtype, + self.use_8x4_sf_layout, + self.backend, + ) + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale_shard: torch.Tensor, + b_scale: torch.Tensor, + alpha: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.vllm.fused_all_gather_flashinfer_fp4_matmul.default( + a_shard_2d, + b_2d, + a_scale_shard, + b_scale, + alpha, + 0, + self.tp.device_group.group_name, + self.dtype, + self.a_scale_view in ("float8", "float8_uint8"), + self.use_8x4_sf_layout, + self.backend, + ) + + return _replacement + + class AsyncTPPass(VllmFusionPatternMatcherPass): @enable_fake_mode def __init__(self, config: VllmConfig) -> None: @@ -718,6 +933,34 @@ def __init__(self, config: VllmConfig) -> None: self.register( FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device) ) + if hasattr(torch.ops.vllm, "flashinfer_mm_fp4"): + for backend in ("cutlass", "cudnn"): + for a_scale_view in ("float8_uint8", "uint8"): + self.register( + FlashInferAllGatherFP4Pattern( + self.model_dtype, + self.device, + backend, + use_8x4_sf_layout=False, + a_scale_view=a_scale_view, + ) + ) + for use_8x4_sf_layout in (False, True): + for a_scale_view in ("float8",): + self.register( + FlashInferAllGatherFP4Pattern( + self.model_dtype, + self.device, + "trtllm", + use_8x4_sf_layout=use_8x4_sf_layout, + a_scale_view=a_scale_view, + ) + ) + # NVFP4 reduce-scatter does not need scale communication: FP4 + # scales are consumed by the local GEMM and only BF16 partial + # outputs are reduced. Keep this PR scoped to the all-gather + # path; reduce-scatter needs a dedicated FP4 producer rather + # than the existing FP8-style helper. self.dump_patterns(config, self.pm_pass) diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py index 2c7a1390bdb8..f5255526c7d1 100644 --- a/vllm/compilation/passes/fusion/sequence_parallelism.py +++ b/vllm/compilation/passes/fusion/sequence_parallelism.py @@ -8,6 +8,7 @@ import torch import torch._inductor.pattern_matcher as pm import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass import vllm.ir.ops @@ -27,6 +28,10 @@ logger = init_logger(__name__) +if hasattr(torch.ops._C, "scaled_fp4_quant"): + SCALED_FP4_QUANT_OUT_OVERLOAD = torch.ops._C.scaled_fp4_quant.out + SCALED_FP4_QUANT_DEFAULT_OVERLOAD = torch.ops._C.scaled_fp4_quant.default + # Min hidden size per device capability for sequence parallelism # Only apply sequence parallelism for models with hidden_size >= threshold SP_MIN_HIDDEN_SIZE: dict[int, int] = { @@ -332,6 +337,129 @@ def replacement( ) +class FirstAllReduceRMSNormStaticNVFP4Pattern(_SequenceParallelPatternHelper): + def get_inputs(self) -> list[torch.Tensor]: + input = self.empty([8, 16]) + weight = self.empty([16]) + input_global_scale = self.empty_f32([1, 1]) + quant_output = torch.empty([8, 8], device=self.device, dtype=torch.uint8) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) + return [input, weight, input_global_scale, quant_output, output_scale] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + input: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + quant_output: torch.Tensor, + output_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(input) + rms = vllm.ir.ops.rms_norm(all_reduce, weight, self.epsilon) + quant = auto_functionalized( + SCALED_FP4_QUANT_OUT_OVERLOAD, + input=rms, + input_scale=input_global_scale, + is_sf_swizzled_layout=True, + output=quant_output, + output_scale=output_scale, + ) + return quant[1], all_reduce, quant[2] + + def replacement( + input: torch.Tensor, + weight: torch.Tensor, + input_global_scale: torch.Tensor, + quant_output: torch.Tensor, + output_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + reduce_scatter = self._reduce_scatter(input) + rms = vllm.ir.ops.rms_norm(reduce_scatter, weight, self.epsilon) + rms = torch.ops.aten.view.default(rms, [-1, rms.shape[-1]]) + quant = SCALED_FP4_QUANT_DEFAULT_OVERLOAD( + rms, + input_global_scale, + True, + ) + return ( + self._all_gather(quant[0]), + reduce_scatter, + self._all_gather(quant[1]), + ) + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + +class MiddleAllReduceRMSNormStaticNVFP4Pattern(_SequenceParallelPatternHelper): + def get_inputs(self) -> list[torch.Tensor]: + mm_1 = self.empty([8, 16]) + residual = self.empty([8, 16]) + rms_norm_weights = self.empty([16]) + input_global_scale = self.empty_f32([1, 1]) + quant_output = torch.empty([8, 8], device=self.device, dtype=torch.uint8) + output_scale = torch.empty([128, 4], device=self.device, dtype=torch.int32) + return [ + residual, + mm_1, + rms_norm_weights, + input_global_scale, + quant_output, + output_scale, + ] + + def register(self, pm_pass: PatternMatcherPass) -> None: + def pattern( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + input_global_scale: torch.Tensor, + quant_output: torch.Tensor, + output_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + all_reduce = self._all_reduce(mm_1) + rms, residual_out = vllm.ir.ops.fused_add_rms_norm( + all_reduce, residual, rms_norm_weights, self.epsilon + ) + quant = auto_functionalized( + SCALED_FP4_QUANT_OUT_OVERLOAD, + input=rms, + input_scale=input_global_scale, + is_sf_swizzled_layout=True, + output=quant_output, + output_scale=output_scale, + ) + return quant[1], residual_out, quant[2] + + def replacement( + residual: torch.Tensor, + mm_1: torch.Tensor, + rms_norm_weights: torch.Tensor, + input_global_scale: torch.Tensor, + quant_output: torch.Tensor, + output_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Keep this slice in sync with the non-quantized SP replacement: + # once the previous SP pattern fires, it becomes a no-op. + reduce_scatter = self._reduce_scatter(mm_1) + residual = residual[0 : reduce_scatter.size(0), ...] + rms, residual_out = vllm.ir.ops.fused_add_rms_norm( + reduce_scatter, residual, rms_norm_weights, self.epsilon + ) + rms = torch.ops.aten.view.default(rms, [-1, rms.shape[-1]]) + quant = SCALED_FP4_QUANT_DEFAULT_OVERLOAD( + rms, + input_global_scale, + True, + ) + return self._all_gather(quant[0]), residual_out, self._all_gather(quant[1]) + + pm.register_replacement( + pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass + ) + + class SequenceParallelismPass(VllmPatternMatcherPass): """ This pass enables sequence parallelism for models. @@ -404,6 +532,14 @@ def __init__(self, config: VllmConfig) -> None: epsilon, self.model_dtype, self.device ).register(self.patterns) + if "SCALED_FP4_QUANT_OUT_OVERLOAD" in globals(): + FirstAllReduceRMSNormStaticNVFP4Pattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) + MiddleAllReduceRMSNormStaticNVFP4Pattern( + epsilon, self.model_dtype, self.device + ).register(self.patterns) + # Normal RMSNorm patterns FirstAllReduceRMSNormPattern( epsilon, self.model_dtype, self.device diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 828ff08a067d..44fcc19c2d2b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -685,6 +685,47 @@ def flashinfer_scaled_fp4_mm( ) +def flashinfer_scaled_fp4_mm_out( + a: torch.Tensor, + b: torch.Tensor, + block_scale_a: torch.Tensor, + block_scale_b: torch.Tensor, + alpha: torch.Tensor, + out: torch.Tensor, + out_dtype: torch.dtype | None, + use_8x4_sf_layout: bool, + backend: str, +) -> torch.Tensor: + assert a.ndim == 2 and b.ndim == 2 and out.ndim == 2 + assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 + assert a.stride(-1) == 1 + assert a.shape[1] == b.shape[0] + assert out.shape == (a.shape[0], b.shape[1]) + assert out.device.type == "cuda" + + if backend in ("cutlass", "cudnn"): + if block_scale_a.dtype != torch.uint8: + block_scale_a = block_scale_a.view(torch.uint8) + if block_scale_b.dtype != torch.uint8: + block_scale_b = block_scale_b.view(torch.uint8) + + from flashinfer import mm_fp4 as flashinfer_mm_fp4_ + + flashinfer_mm_fp4_( + a, + b, + block_scale_a, + block_scale_b, + alpha, + out_dtype or out.dtype, + out=out, + block_size=16, + use_8x4_sf_layout=use_8x4_sf_layout, + backend=backend, + ) + return out + + def flashinfer_scaled_fp8_mm( a: torch.Tensor, b: torch.Tensor, @@ -864,6 +905,7 @@ def is_flashinfer_cudnn_fp8_prefill_attn_supported() -> bool: "can_use_trtllm_attention", "use_trtllm_attention", "flashinfer_scaled_fp4_mm", + "flashinfer_scaled_fp4_mm_out", "flashinfer_scaled_fp8_mm", "flashinfer_scaled_fp8_mm_out", "flashinfer_quant_nvfp4_8x4_sf_layout",