diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4f0db88fe702..2b80937e8580 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -610,6 +610,8 @@ steps: --ignore=lora/test_qwen3moe_tp.py parallelism: 4 +##### .buildkite/test_areas/pytorch.yaml ##### +# corresponds to .buildkite/test_areas/pytorch.yaml - label: PyTorch Compilation Unit Tests # 15min timeout_in_minutes: 30 mirror_hardwares: [amdexperimental, amdproduction] @@ -627,6 +629,20 @@ steps: # they do not suffer from https://github.com/vllm-project/vllm/issues/28965 - "find compile/ -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" +# corresponds to .buildkite/test_areas/pytorch.yaml +- label: PyTorch Compilation Passes Unit Tests + timeout_in_minutes: 20 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + source_file_dependencies: + - vllm/ + - tests/compile/passes + commands: + # TODO: clean up this comment if not needed. It is used to + # keep track of the test changes during vLLM IR Ops refactoring. + # Use `find` to launch multiple instances of pytest. + - "find compile/passes -maxdepth 1 -name 'test_*.py' -exec pytest -s -v {} \\\\;" + - label: PyTorch Fullgraph Smoke Test # 15min timeout_in_minutes: 30 mirror_hardwares: [amdexperimental, amdproduction] @@ -1211,41 +1227,6 @@ steps: - pytest -v -s tests/kernels/moe/test_flashinfer.py - pytest -v -s tests/kernels/moe/test_cutedsl_moe.py -- label: Blackwell Fusion and Compile Tests # 30 min - timeout_in_minutes: 40 - working_dir: "/vllm-workspace/" - gpu: b200 - source_file_dependencies: - - csrc/quantization/fp4/ - - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py - - vllm/v1/attention/backends/flashinfer.py - - vllm/v1/worker/ - - vllm/v1/cudagraph_dispatcher.py - - vllm/compilation/ - # can affect pattern matching - - vllm/model_executor/layers/layernorm.py - - vllm/model_executor/layers/activation.py - - vllm/model_executor/layers/quantization/input_quant_fp8.py - - tests/compile/passes/test_fusion_attn.py - - tests/compile/passes/test_silu_mul_quant_fusion.py - - tests/compile/passes/distributed/test_fusion_all_reduce.py - - tests/compile/fullgraph/test_full_graph.py - commands: - - nvidia-smi - - pytest -v -s tests/compile/passes/test_fusion_attn.py - - pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py - # this runner has 2 GPUs available even though num_gpus=2 is not set - - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py - - # # Limit to Inductor partition, no custom ops, and allreduce & attn fusion to reduce running time - # # Wrap with quotes to escape yaml - # - "pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm -k 'True and not +quant_fp8 and not +rms_norm'" - # Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293 - # in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated.
- - # test_fp8_kv_scale_compile requires FlashAttention (not supported on default L4/L40) - - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile - - label: Blackwell GPT-OSS Eval timeout_in_minutes: 60 working_dir: "/vllm-workspace/" @@ -1371,7 +1352,6 @@ steps: - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - VLLM_TEST_SAME_HOST=1 VLLM_TEST_WITH_DEFAULT_DEVICE_SET=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed' - - pytest -v -s compile/correctness_e2e/test_sequence_parallel.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s v1/worker/test_worker_memory_snapshot.py @@ -1601,16 +1581,16 @@ steps: commands: - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/passes/distributed/test_async_tp.py - pytest -v -s tests/compile/passes/distributed/test_sequence_parallelism.py - - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py + # TODO: this test is not supported on ROCm; there are aiter kernels for this. + # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py #- pytest -v -s tests/compile/distributed/test_fusions_e2e.py::test_tp2_attn_quant_allreduce_rmsnorm # - "VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/distributed/test_fusions_e2e.py -k 'not Llama-4'" # Old E2E tests were removed in https://github.com/vllm-project/vllm/pull/33293 # in favor of new tests in fusions_e2e. We avoid replicating the new jobs in this file as it's deprecated. - - - VLLM_TEST_CLEAN_GPU_MEMORY=1 pytest -v -s tests/compile/correctness_e2e/test_sequence_parallel.py - pytest -v -s tests/distributed/test_context_parallel.py - HIP_VISIBLE_DEVICES=0,1 VLLM_LOGGING_LEVEL=DEBUG python3 examples/offline_inference/data_parallel.py --model=Qwen/Qwen1.5-MoE-A2.7B -tp=1 -dp=2 --max-model-len=2048 --all2all-backend=allgather_reducescatter --disable-nccl-for-dp-synchronization - - pytest -v -s tests/v1/distributed/test_dbo.py + # this test is not supported on ROCm + # - pytest -v -s tests/v1/distributed/test_dbo.py ##### B200 test ##### - label: Distributed Tests (B200) # optional @@ -1721,6 +1701,93 @@ steps: commands: - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 +##### .buildkite/test_areas/compile.yaml ##### +# We are setting up the tests gradually so that it is also easier for the +# CI team to review and upstream to the pipelinev2. +# The following tests are important for vLLM IR Ops refactoring, +# which affects fusion passes on ROCm. So we have to +# enable them as soon as possible.
+ +## TODO: Enable the test in this group +# # corresponds to .buildkite/test_areas/compile.yaml +# - label: Fusion and Compile Unit Tests (2xMI325 GPUs) +# timeout_in_minutes: 20 +# working_dir: "/vllm-workspace/" +# mirror_hardwares: [amdexperimental, amdproduction, tj] +# agent_pool: mi325_1 # changed to 1 GPU until the fusion all reduce test is enabled, then revert back to 2 GPUs +# source_file_dependencies: +# - csrc/quantization/fp4/ +# - vllm/model_executor/layers/quantization/ +# - vllm/model_executor/layers/layernorm.py +# - vllm/model_executor/layers/activation.py +# - vllm/model_executor/layers/attention/attention.py +# - vllm/v1/attention/backends/flashinfer.py +# - vllm/compilation/ # TODO(luka) limit to vllm/compilation/passes +# - tests/compile/test_fusion_attn.py +# - tests/compile/test_silu_mul_quant_fusion.py +# - tests/compile/distributed/test_fusion_all_reduce.py +# - tests/compile/fullgraph/test_full_graph.py +# commands: +# - rocm-smi +# # we run all backend tests on ROCm +# # These two tests are covered in "PyTorch Compilation Passes Unit Tests" +# # - "pytest -v -s tests/compile/passes/test_fusion_attn.py" +# # - "pytest -v -s tests/compile/passes/test_silu_mul_quant_fusion.py" +# # TODO: this test is not supported on ROCm; there are aiter kernels for this. +# # - pytest -v -s tests/compile/passes/distributed/test_fusion_all_reduce.py +# # TODO: find out more details +# # - pytest -v -s tests/compile/fullgraph/test_full_graph.py::test_fp8_kv_scale_compile + +# corresponds to .buildkite/test_areas/compile.yaml +- label: Fusion E2E Quick (MI325) + timeout_in_minutes: 15 + working_dir: "/vllm-workspace/" + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + num_devices: 1 + source_file_dependencies: + - csrc/quantization/ + - vllm/model_executor/ + - vllm/v1/attention/ + - vllm/compilation/ + - tests/compile/fusions_e2e/ + commands: + - rocm-smi + # Run all models and attn backends but only Inductor partition and native custom ops + - "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and not +rms_norm and not +quant_fp8'" + # Unlike on CUDA, Qwen requires +rms_norm and +quant_fp8, as rms+quant fusion is only supported on AITER + - "pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k 'inductor_partition and +rms_norm and +quant_fp8 and qwen3'" + +# corresponds to .buildkite/test_areas/compile.yaml +- label: Fusion E2E Config Sweep (MI325) + timeout_in_minutes: 30 + working_dir: "/vllm-workspace/" + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + num_devices: 1 + source_file_dependencies: + - csrc/quantization/ + - vllm/compilation/ + # can affect pattern matching + - vllm/model_executor/layers/layernorm.py + - vllm/model_executor/layers/activation.py + - vllm/model_executor/layers/attention/attention.py + - vllm/model_executor/layers/quantization/input_quant_fp8.py + - tests/compile/fusions_e2e/ + commands: + - rocm-smi + # Run just llama3 (fp8) for all config combinations + - pytest -v -s tests/compile/fusions_e2e/test_tp1_quant.py -k "llama-3" + +## There are no ops on ROCm for these tests. ## The tests still pass but the logs are not useful.
+## The fused ops just call torch.ops.symm_mem, which +## exists on ROCm even though it does not actually work. # - label: AsyncTP Correctness Tests (2xMI325 GPUs) # - label: Fusion E2E TP2 Quick (MI325) # - label: Fusion E2E TP2 AsyncTP Config Sweep (MI325) # - label: Fusion E2E TP2 (MI325) # - label: Sequence Parallel Correctness Tests (2xMI325 GPUs) ##################################################################################################################################### diff --git a/tests/compile/fusions_e2e/common.py b/tests/compile/fusions_e2e/common.py index 284a9d66b957..2c6dc2b3ebbc 100644 --- a/tests/compile/fusions_e2e/common.py +++ b/tests/compile/fusions_e2e/common.py @@ -13,6 +13,7 @@ class Matches(NamedTuple): # simple pointwise + aiter_rms_quant_fusion: int = 0 rms_quant_fusion: int = 0 act_quant_fusion: int = 0 norm_rope_fusion: int = 0 @@ -82,6 +83,9 @@ def has_cuda_graph_wrapper_metadata() -> bool: ] FUSION_LOG_PATTERNS: dict[str, re.Pattern] = { + "aiter_rms_quant_fusion": re.compile( + r"RocmAiterRMSNormQuantFusionPass Replaced (\d+) patterns" + ), "rms_quant_fusion": re.compile(r"rms_quant_fusion.py:\d+] Replaced (\d+) patterns"), "act_quant_fusion": re.compile(r"act_quant_fusion.py:\d+] Replaced (\d+) patterns"), "norm_rope_fusion": re.compile( diff --git a/tests/compile/fusions_e2e/conftest.py b/tests/compile/fusions_e2e/conftest.py index 40b4de57f66f..d083b6f14e4b 100644 --- a/tests/compile/fusions_e2e/conftest.py +++ b/tests/compile/fusions_e2e/conftest.py @@ -63,9 +63,14 @@ def run( compilation_config: dict, matches_check: list[str], use_deepgemm: bool = False, + use_aiter: bool = False, tp_size: int = 1, ): monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1" if use_deepgemm else "0") + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1" if use_aiter else "0") + from vllm._aiter_ops import rocm_aiter_ops + + rocm_aiter_ops.refresh_env_variables() # Disable, compile cache to make sure custom passes run. # Otherwise, we can't verify fusion happened through the logs.
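For reviewers, a note that is not part of the patch: the new "aiter_rms_quant_fusion" entry above keys off the pass class name rather than a "file.py:line]" prefix, presumably because several ROCm passes live in the same rocm_aiter_fusion.py file and the file prefix would not distinguish them. A minimal sketch of how the regex lines up with the message emitted after the logger change further down in this diff, assuming the pass replaced 6 patterns:

import re

# Same regex as the new FUSION_LOG_PATTERNS["aiter_rms_quant_fusion"] entry.
pattern = re.compile(r"RocmAiterRMSNormQuantFusionPass Replaced (\d+) patterns")

# Shape of the log line after the change to __call__ in rocm_aiter_fusion.py:
#   logger.debug("%s Replaced %s patterns", self.__class__.__name__, self.matched_count)
message = "RocmAiterRMSNormQuantFusionPass Replaced 6 patterns"

match = pattern.search(message)
assert match is not None and int(match.group(1)) == 6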
diff --git a/tests/compile/fusions_e2e/models.py b/tests/compile/fusions_e2e/models.py index f54f617c64d4..e18bc1ee5652 100644 --- a/tests/compile/fusions_e2e/models.py +++ b/tests/compile/fusions_e2e/models.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +from vllm._aiter_ops import is_aiter_found_and_supported +from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer from vllm.v1.attention.backends.registry import AttentionBackendEnum @@ -24,6 +26,24 @@ AttentionBackendCase(backend=AttentionBackendEnum.TRITON_ATTN), id="TRITON_ATTN" ) +ROCM_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.ROCM_ATTN), + id="ROCM_ATTN", + marks=pytest.mark.skipif( + not current_platform.is_rocm(), + reason="ROCm attention only for AMD", + ), +) + +ROCM_AITER_UNIFIED_ATTN = pytest.param( + AttentionBackendCase(backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN), + id="ROCM_AITER_UNIFIED_ATTN", + marks=pytest.mark.skipif( + not is_aiter_found_and_supported(), + reason="ROCM_AITER_UNIFIED_ATTN only for AMD when AITER is installed", + ), +) + # Models llama3_8b = ModelFusionInfo( model_name="meta-llama/Llama-3.1-8B-Instruct", @@ -49,7 +69,6 @@ llama3_8b_fp4 = ModelFusionInfo( model_name="nvidia/Llama-3.1-8B-Instruct-FP4", matches=lambda n_layers: Matches( - rms_quant_fusion=0, act_quant_fusion=n_layers, attn_quant_fusion=n_layers, ar_rms_fusion=n_layers * 2 + 1, @@ -79,7 +98,6 @@ model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-NVFP4", hf_overrides=lambda n_layers: {"text_config": {"num_hidden_layers": n_layers}}, matches=lambda n_layers: Matches( - rms_quant_fusion=0, attn_quant_fusion=n_layers, ar_rms_fusion=n_layers * 2, sequence_parallel=n_layers * 2, diff --git a/tests/compile/fusions_e2e/test_tp1_quant.py b/tests/compile/fusions_e2e/test_tp1_quant.py index f98400c2e26d..917116515f89 100644 --- a/tests/compile/fusions_e2e/test_tp1_quant.py +++ b/tests/compile/fusions_e2e/test_tp1_quant.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from vllm.utils.flashinfer import is_flashinfer_fp8_blockscale_gemm_supported from .common import ( @@ -16,6 +17,8 @@ ) from .models import ( FLASHINFER_ATTN, + ROCM_AITER_UNIFIED_ATTN, + ROCM_ATTN, TRITON_ATTN, llama3_8b_fp4, llama3_8b_fp8, @@ -29,12 +32,33 @@ "model_name, matches_fn, model_kwargs, hf_overrides, use_deepgemm", [ (*llama3_8b_fp8, False), - (*llama4_scout_fp8, False), (*qwen3_a3b_fp8, False), - (*qwen3_a3b_fp8, True), + pytest.param( + *llama4_scout_fp8, + False, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), + reason="Llama4 Scout FP8 only supported on CUDA", + ), + ), + pytest.param( + *qwen3_a3b_fp8, + True, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA" + ), + ), + ], +) +@pytest.mark.parametrize( + "attn_backend", + [ + TRITON_ATTN, + FLASHINFER_ATTN, + ROCM_ATTN, + ROCM_AITER_UNIFIED_ATTN, ], ) -@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN]) @pytest.mark.parametrize("n_layers", [6]) @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm")) @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION) @@ -81,6 +105,8 @@ def test_tp1_fp8_fusions( ), ) + use_aiter = current_platform.is_rocm() and ("qwen" in model_name.lower()) + matches_check = [ "rms_quant_fusion", "act_quant_fusion", @@ -88,6 +114,15 @@ def test_tp1_fp8_fusions( 
"attn_quant_fusion", ] + if use_aiter: + matches_check[0] = "aiter_rms_quant_fusion" + + matches = matches._replace(aiter_rms_quant_fusion=matches.rms_quant_fusion) + # TODO: enable the `norm_rope_fusion` test, + # On ROCm norm_rope_fusion is only supported without + # enabling AITER. + matches_check.remove("norm_rope_fusion") + run_e2e_fusion_test( model_name, matches, @@ -96,6 +131,7 @@ def test_tp1_fp8_fusions( compilation_config, matches_check, use_deepgemm=use_deepgemm, + use_aiter=use_aiter, ) diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 18b19565c1fc..ab4aefcaf79a 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from ...utils import multi_gpu_test from .common import ( @@ -26,6 +27,8 @@ qwen3_a3b_fp8, ) +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 921839ea0692..9657d64b88f7 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -5,6 +5,7 @@ import pytest from vllm.config import PassConfig +from vllm.platforms import current_platform from ...utils import multi_gpu_test from .common import ( @@ -23,6 +24,8 @@ qwen3_a3b, ) +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize( diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py index 78c3cf92a067..a0fe717ba026 100644 --- a/tests/compile/passes/distributed/test_sequence_parallelism.py +++ b/tests/compile/passes/distributed/test_sequence_parallelism.py @@ -36,6 +36,8 @@ from vllm.utils.system_utils import update_environment_variables from vllm.utils.torch_utils import set_random_seed +pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") + FP8_DTYPE = current_platform.fp8_dtype() prompts = [ "Hello, my name is", diff --git a/tests/compile/passes/test_silu_mul_quant_fusion.py b/tests/compile/passes/test_silu_mul_quant_fusion.py index cc06208ea758..a77b4e6de7bd 100644 --- a/tests/compile/passes/test_silu_mul_quant_fusion.py +++ b/tests/compile/passes/test_silu_mul_quant_fusion.py @@ -182,8 +182,24 @@ def ops_in_model_after(self): "model_class, enable_quant_fp8_custom_op, force_kernel", list(itertools.product([TestSiluMulFp8QuantModel], [True, False], TEST_KERNELS)) + [ - (TestSiluMulNvfp4QuantModel, False, None), - (TestSiluMulGroupFp8QuantModel, False, None), + pytest.param( + TestSiluMulNvfp4QuantModel, + False, + None, + marks=pytest.mark.skipif( + not current_platform.is_cuda(), reason="CUDA only" + ), + ), + # GroupFP8Quant fusion only works with AITER on ROCm. + # and the enable_quant_fp8_custom_op must be True. 
+ pytest.param( + TestSiluMulGroupFp8QuantModel, + True, + None, + marks=pytest.mark.skipif( + not current_platform.is_rocm(), reason="ROCm only" + ), + ), ], ) @pytest.mark.skipif( @@ -201,6 +217,7 @@ def test_fusion_silu_and_mul_quant( enable_silu_mul_custom_op: bool, enable_quant_fp8_custom_op: bool, force_kernel: FP8ScaledMMLinearKernel | None, + monkeypatch: pytest.MonkeyPatch, ): if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported(): pytest.skip("NVFP4 is not supported on this GPU.") @@ -227,13 +244,16 @@ def test_fusion_silu_and_mul_quant( ), ) - with set_current_vllm_config(config): + with set_current_vllm_config(config), monkeypatch.context() as m: fusion_passes = [ActivationQuantFusionPass(config)] - if IS_AITER_FOUND: + if IS_AITER_FOUND and model_class is TestSiluMulGroupFp8QuantModel: + from vllm._aiter_ops import rocm_aiter_ops from vllm.compilation.passes.fusion.rocm_aiter_fusion import ( RocmAiterSiluMulFp8GroupQuantFusionPass, ) + m.setenv("VLLM_ROCM_USE_AITER", "1") + rocm_aiter_ops.refresh_env_variables() fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)] passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)] diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index d8131ce952d2..59c94db5e812 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -5,7 +5,6 @@ import torch._inductor.pattern_matcher as pm from torch import fx from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._ops import OpOverload import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops @@ -15,6 +14,7 @@ GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic128Sym, ) from vllm.platforms import current_platform @@ -312,7 +312,9 @@ def __init__(self, config: VllmConfig) -> None: @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", self.matched_count) + logger.debug( + "%s Replaced %s patterns", self.__class__.__name__, self.matched_count + ) def uuid(self) -> str: fusion_patterns = [ @@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() - def __init__(self, quant_op: OpOverload) -> None: + def __init__(self) -> None: self.silu_and_mul_matcher = MatcherSiluAndMul() - self.quant_op = quant_op + self.quant_matcher = MatcherQuantFP8( + quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True + ) def get_inputs(self) -> list[torch.Tensor]: return [ @@ -346,7 +350,7 @@ def pattern( input: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: at1 = self.silu_and_mul_matcher(input) - at2 = self.quant_op(at1, 128) + at2 = self.quant_matcher(at1) return at2[0], at2[1] def replacement( @@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ - AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op() - TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default - - QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] - @enable_fake_mode def __init__(self, config: VllmConfig) -> None: super().__init__(config) @@ -383,8 +382,7 @@ def __init__(self, config: VllmConfig) -> None: 
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" ) - for quant_op in self.QUANT_OPS: - AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + AiterSiluMulFp8GroupQuantPattern().register(self.patterns) self.dump_patterns(config, self.patterns)
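For reviewers, a minimal sketch (not part of the patch) of how the refactored pass is expected to be driven, mirroring the test_silu_mul_quant_fusion.py changes above; the helper name and the bare os.environ usage are illustrative assumptions, the tests themselves use pytest's monkeypatch and a real VllmConfig:

import os

from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.passes.fusion.rocm_aiter_fusion import (
    RocmAiterSiluMulFp8GroupQuantFusionPass,
)

def build_aiter_group_quant_pass(config):
    # AITER has to be enabled (and the cached env state refreshed) before the
    # pass is built, since the pattern now goes through MatcherQuantFP8 with
    # match_rocm_aiter=True instead of iterating over explicit quant ops.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    rocm_aiter_ops.refresh_env_variables()
    return RocmAiterSiluMulFp8GroupQuantFusionPass(config)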