diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py index f5e2d9ddb752..996f72fc2025 100644 --- a/tests/compile/piecewise/test_multiple_graphs.py +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -6,7 +6,6 @@ """ import torch from torch import nn -from torch.library import Library from vllm.compilation.backends import set_model_tag from vllm.compilation.counter import compilation_counter @@ -14,11 +13,12 @@ support_torch_compile) from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) +from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa +# This import automatically registers torch ops for testing (like silly.attention) +import tests.compile.testing_ops BATCH_SIZE = 32 MLP_SIZE = 128 @@ -26,27 +26,6 @@ RANDOM_SEED = 0 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @support_torch_compile class ParentModel(nn.Module): @@ -277,9 +256,5 @@ def test_multi_graph_piecewise_compile_outputs_equal(): outputs.append( run_model(vllm_config, model, inputs, cudagraph_runtime_mode)) - # Generally don't expect outputs with and without inductor - # to be bitwise equivalent - assert torch.allclose(outputs[0], outputs[1]) - # Expect bitwise equivalence using inductor w/ and w/o cudagraph assert torch.equal(outputs[0], outputs[2]) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 2d1a72d44ec7..4df709b1ba86 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -7,7 +7,6 @@ import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile @@ -15,34 +14,10 @@ VllmConfig, set_current_vllm_config) from vllm.envs import VLLM_USE_V1 from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -global_counter = 0 - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - global global_counter - global_counter += 1 - print(f"{global_counter=}") - out.copy_(q) - out[0] += 1 - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, +# This import also automatically registers torch ops for testing (like silly.attention) +from tests.compile.testing_ops import ( + get_global_counter, reset_global_counter ) @@ -58,9 +33,8 @@ def __init__(self, def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Overall effect: - x += 1 - x[0] += 2 + Overall effect with unified attention implementation: + input [0., 0.] -> final output [19., 19.] global_counter += 2 """ x = x + 1 @@ -121,13 +95,12 @@ def test_simple_piecewise_compile(use_inductor): model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() - global global_counter - global_counter = 0 + reset_global_counter() with set_forward_context( None, vllm_config=vllm_config, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, batch_descriptor=BatchDescriptor(num_tokens=2, )): output = model(input) - assert global_counter == 2 - assert torch.allclose(output.cpu(), torch.tensor([3., 1.])) + assert get_global_counter() == 2 + assert torch.allclose(output.cpu(), torch.tensor([19., 19.])) diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index bcfd0d834c5d..5088322ffab2 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -14,38 +14,15 @@ import pytest import torch from torch import nn -from torch.library import Library from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa - - -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) +# This import automatically registers torch ops for testing (like silly.attention) +import tests.compile.testing_ops @dataclass diff --git a/tests/compile/test_decorator.py b/tests/compile/test_decorator.py index 51f8ddd566d5..e6f6a6733017 100644 --- a/tests/compile/test_decorator.py +++ b/tests/compile/test_decorator.py @@ -2,44 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from torch import nn -from torch.library import Library +# This import automatically registers torch ops for testing (like silly.attention) +import tests.compile.testing_ops from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, CUDAGraphMode, VllmConfig, set_current_vllm_config) from vllm.forward_context import BatchDescriptor, set_forward_context -from vllm.utils import direct_register_custom_op - -# create a library to hold the custom op -silly_lib = Library("silly", "FRAGMENT") # noqa BATCH_SIZE = 32 MLP_SIZE = 128 -def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - out.copy_(q) - out += k - out += v - - -def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: - return - - -direct_register_custom_op( - op_name="attention", - op_func=silly_attention, - mutates_args=["out"], - fake_impl=silly_attention_fake, - target_lib=silly_lib, -) - - @torch.inference_mode def run_model(vllm_config: VllmConfig, model: nn.Module, cudagraph_runtime_mode: CUDAGraphMode): diff --git a/tests/compile/testing_ops.py b/tests/compile/testing_ops.py new file mode 100644 index 000000000000..5b25cf58ab58 --- /dev/null +++ b/tests/compile/testing_ops.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Shared PyTorch custom operations for compilation tests. + +Centralizes custom operation definitions to avoid duplicate registrations. +""" + +import torch +from torch.library import Library + +from vllm.utils import direct_register_custom_op + +# Shared library for all compilation test operations +# Using "silly" namespace to match existing test expectations +silly_lib = Library("silly", "FRAGMENT") + + +# Global counter that counts the number of times attention is invoked +_global_counter = 0 + + +def get_global_counter(): + """Get the current global counter value""" + return _global_counter + + +def reset_global_counter(): + """Reset the global counter to 0""" + global _global_counter + _global_counter = 0 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """ + Unified attention implementation that depends on all inputs and affects the output. + Always increments a global counter that tests can use or ignore. + """ + global _global_counter + + # Always increment the global counter + _global_counter += 1 + + # Unified implementation that depends on all inputs + out.copy_(q + k + v) + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + """Fake implementation for testing""" + return + + +# Register the unified attention operation +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) \ No newline at end of file