Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions tests/compile/backend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref
from collections.abc import Sequence
from copy import deepcopy
from typing import Callable, Union
Expand All @@ -10,7 +11,25 @@

from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config
from vllm.compilation.pass_manager import with_pattern_match_debug
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import VllmConfig, get_current_vllm_config


class LazyInitPass(InductorPass):
    """
    Wrapper that defers construction of an inductor pass until call time.

    Some passes need state that only exists once compilation has started,
    so instead of constructing the pass up front, this wrapper builds the
    wrapped pass on each invocation and applies it immediately.
    """

    def __init__(self, pass_cls: type[VllmInductorPass],
                 vllm_config: VllmConfig):
        self.pass_cls = pass_cls
        # Hold only a weak proxy so we don't create a reference cycle
        # back to the config object.
        self.vllm_config = weakref.proxy(vllm_config)  # avoid cycle

    def __call__(self, graph: fx.Graph) -> None:
        # Instantiate the wrapped pass now, then run it on the graph.
        fresh_pass = self.pass_cls(self.vllm_config)
        fresh_pass(graph)


class TestBackend:
Expand Down Expand Up @@ -40,6 +59,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs):
example_inputs,
config_patches=self.inductor_config)

@with_pattern_match_debug
def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
for pass_ in self.custom_passes:
Expand All @@ -64,4 +84,4 @@ def check_after_ops(self, ops: Sequence[OpOverload]):
num_pre = len(list(find_op_nodes(op, self.graph_pre_pass)))
num_post = len(list(find_op_nodes(op, self.graph_post_pass)))
assert num_pre == 0, f"Unexpected op {op.name()} in pre-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
assert num_post > 0, f"Op {op.name()} not found in post-pass graph"
10 changes: 9 additions & 1 deletion tests/compile/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import vllm
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config import CompilationConfig, VllmConfig
from vllm.utils import _is_torch_equal_or_newer


Expand All @@ -26,6 +26,14 @@ def test_use_cudagraphs_dynamic(monkeypatch):
assert not vllm_config.compilation_config.use_cudagraph


def test_custom_op() -> None:
    """Check custom_ops syntax validation in CompilationConfig.

    Op names must carry a '+' or '-' prefix; a bare name is invalid.
    """
    # proper syntax: both prefixed forms are accepted without error
    _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # A bare op name (no '+'/'-' prefix) must raise a ValueError.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        _ = CompilationConfig(custom_ops=["quant_fp8"])


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
Expand Down
10 changes: 6 additions & 4 deletions tests/compile/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fusion import FUSED_OPS, RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
Expand Down Expand Up @@ -58,11 +59,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)

passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass
] if do_fusion else [noop_pass, cleanup_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
Expand Down
15 changes: 8 additions & 7 deletions tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import pytest
import torch

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass)
RMSNormQuantFusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
Expand Down Expand Up @@ -79,15 +79,15 @@ def ops_in_model_after(self):


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("num_tokens", [257])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize("cuda_force_torch",
[True, False] if cutlass_fp8_supported() else [True])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cuda_force_torch):
Expand All @@ -104,9 +104,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(noop_pass, fusion_pass)
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
model = TestModel(hidden_size, eps, static, cuda_force_torch)

# First dimension dynamic
Expand Down
5 changes: 4 additions & 1 deletion tests/compile/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
Expand Down Expand Up @@ -215,8 +216,10 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass,
cleanup_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
Expand Down
27 changes: 14 additions & 13 deletions tests/compile/test_fusion_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
import pytest
import torch._dynamo

from tests.compile.backend import TestBackend
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
ModelConfig, PassConfig, SchedulerConfig, VllmConfig,
set_current_vllm_config)
Expand All @@ -40,13 +41,12 @@
@pytest.mark.parametrize(
"model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize(
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test CUDA and ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="V0 attention fusion only supported on ROCm.")
def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,

# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(model,
enforce_eager=True,
Expand Down Expand Up @@ -188,7 +188,7 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
device=self.device,
)

def build_attn_metadata(self, batch_size: int):
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata."""

# Create common attn metadata
Expand Down Expand Up @@ -393,9 +393,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,

# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)
attn_pass = lambda *args, **kw: AttnFusionPass(vllm_config)(*args, **kw
)
test_backend = TestBackend(noop_pass, attn_pass)
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

test_backend = TestBackend(noop_pass, attn_pass, cleanup_pass)

# Compile model with fusion enabled
model_compiled = torch.compile(model_fused,
Expand Down
19 changes: 12 additions & 7 deletions tests/compile/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import vllm.envs as envs
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import FusionPass
from vllm.compilation.fusion import RMSNormQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.compilation.vllm_inductor_pass import VllmInductorPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
Expand Down Expand Up @@ -104,7 +106,7 @@ def __init__(self,
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)

self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm,
Expand Down Expand Up @@ -137,8 +139,7 @@ def forward(self, hidden_states, residual):
# layer normalization
norm_output, residual_output = self.norm(all_reduce, residual)

# for static input quantization
# self.fp8_linear is initialized with use_per_token_if_dynamic=False
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(norm_output,
self.w,
self.wscale,
Expand Down Expand Up @@ -253,16 +254,20 @@ def sequence_parallelism_pass_on_test_model(
dtype=dtype,
seed=42)

sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
sequence_parallelism_pass = SequenceParallelismPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
cleanup_pass = PostCleanupPass(vllm_config)

passes_for_backend = [noop_pass, sequence_parallelism_pass]
passes_for_backend: list[VllmInductorPass] = \
[noop_pass, sequence_parallelism_pass]

if enable_fusion:
fusion_pass = FusionPass.instance(vllm_config)
fusion_pass = RMSNormQuantFusionPass(vllm_config)
passes_for_backend.append(fusion_pass)

passes_for_backend.append(cleanup_pass)

backend_no_func = TestBackend(*passes_for_backend)
backend_func = TestBackend(*passes_for_backend, func_pass)

Expand Down
11 changes: 10 additions & 1 deletion tests/compile/test_silu_mul_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# yapf: enable
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
Expand Down Expand Up @@ -66,6 +67,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):

def __init__(self, hidden_size: int, **kwargs):
super().__init__()
from vllm.compilation.activation_quant_fusion import (
silu_and_mul_nvfp4_quant_supported)
assert silu_and_mul_nvfp4_quant_supported

self.silu_and_mul = SiluAndMul()
self.w = torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE)
Expand Down Expand Up @@ -117,7 +122,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
pass_config=PassConfig(enable_fusion=True, enable_noop=True))
fusion_pass = ActivationQuantFusionPass(config)

backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
passes = [
NoOpEliminationPass(config), fusion_pass,
PostCleanupPass(config)
]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size,
cuda_force_torch=cuda_force_torch)

Expand Down
Loading