diff --git a/vllm_spyre_next/vllm_spyre_next/testing/models.py b/vllm_spyre_next/vllm_spyre_next/testing/models.py index 07581cb23..8c154cd7f 100644 --- a/vllm_spyre_next/vllm_spyre_next/testing/models.py +++ b/vllm_spyre_next/vllm_spyre_next/testing/models.py @@ -51,6 +51,8 @@ class AllowEntry: param_allows: Parameter combinations to allow (whitelist). If specified, only these parameter values will run. param_overrides: Parameter values to replace upstream defaults with. + fixture_names: Fixture names to inject for this test (e.g. "foo" for a + custom fixture that prints "hello world"). """ test: str @@ -59,6 +61,7 @@ class AllowEntry: param_skips: tuple[ParamSkip, ...] = () param_allows: tuple[ParamAllow, ...] = () param_overrides: tuple[ParamOverride, ...] = () + fixture_names: tuple[str, ...] = () @dataclass(frozen=True) diff --git a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py index 52c44bbf3..8cb0f8140 100644 --- a/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py +++ b/vllm_spyre_next/vllm_spyre_next/testing/pytest_plugin.py @@ -43,6 +43,7 @@ from pathlib import Path import pytest +from vllm.v1.attention.backends.registry import AttentionBackendEnum import yaml from vllm_spyre_next.testing.models import ( @@ -119,6 +120,7 @@ def _parse_config(raw_tests: dict) -> UpstreamTestConfig: param_skips=tuple(param_skips), param_allows=tuple(param_allows), param_overrides=tuple(param_overrides), + fixture_names=tuple(allow.get("fixture_names", ())), ) ) block_list = [BlockEntry(test=b["test"]) for b in file_entry.get("block_list", [])] @@ -325,6 +327,26 @@ def _prepare_upstream_tests_dir() -> Path: return tests_dir +def _temp_upstream_code_edits(upstream_tests_dir: Path): + """Apply small code edits to the upstream tests directory before importing. + + These should be _temporary_ edits to source code for vllm tests while we work to make them more + portable. 
This should only be used where mocking is not possible or too cumbersome. + """ + + # Mocking out torch.device seems impossible to do (at least multiple rounds of Bob and Claude + # were unsuccessful). So we patch the source code to change the hardcoded + # `torch.device("cuda:0")` to `torch.device("cpu")`. + hardcoded_cuda_test_path = ( + upstream_tests_dir / "v1" / "attention" / "test_attention_backends.py" + ) + with open(hardcoded_cuda_test_path) as f: + content = f.read() + content = content.replace('torch.device("cuda:0")', 'torch.device("cpu")') + with open(hardcoded_cuda_test_path, "w") as f: + f.write(content) + + # --------------------------------------------------------------------------- # Pytest Hooks # --------------------------------------------------------------------------- @@ -365,6 +387,7 @@ def pytest_configure(config): try: # Clone vLLM to cache upstream_tests_base = _prepare_upstream_tests_dir() + _temp_upstream_code_edits(upstream_tests_base) config._upstream_tests_base = upstream_tests_base # Determine which test paths to inject @@ -498,6 +521,10 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item elif allow_entry.mode == "xfail_strict": item.add_marker(pytest.mark.xfail(strict=True)) + # Inject fixtures for tests that have fixture_names defined + for fixture_name in allow_entry.fixture_names: + item.fixturenames.append(fixture_name) + # Reorder tests so that tests with "model" in the name run first _reorder_tests_by_name(items) @@ -597,6 +624,33 @@ def should_do_global_cleanup_after_test(): return False +@pytest.fixture() +def patch_backend_list(request, monkeypatch): + """This fixture patches things for tests/v1/attention/test_attention_backends.py""" + + # The BACKENDS_TO_TEST list has to be patched with only our backend + our_backend_list = [ + AttentionBackendEnum.CUSTOM, + ] + test_module = request.node.module + monkeypatch.setattr(test_module, "BACKENDS_TO_TEST", our_backend_list) + + # 
_test_backend_correctness may be called with a hardcoded AttentionBackendEnum.FLEX_ATTENTION,
+    # which we want to ignore
+    orig_tbc = test_module._test_backend_correctness
+
+    def tbc_wrapper(
+        batch_spec, model, backend_to_test: list[AttentionBackendEnum | str], *args, **kwargs
+    ):
+        if "AttentionBackendEnum.FLEX_ATTENTION" in str(backend_to_test):
+            return
+        return orig_tbc(batch_spec, model, backend_to_test, *args, **kwargs)
+
+    monkeypatch.setattr(test_module, "_test_backend_correctness", tbc_wrapper)
+
+    yield
+
+
 @pytest.hookimpl(tryfirst=True)
 def pytest_fixture_setup(fixturedef, request):
     """Override fixtures when running upstream vLLM tests."""
diff --git a/vllm_spyre_next/vllm_spyre_next/testing/upstream_tests.yaml b/vllm_spyre_next/vllm_spyre_next/testing/upstream_tests.yaml
index d273502c7..3ee1f7737 100644
--- a/vllm_spyre_next/vllm_spyre_next/testing/upstream_tests.yaml
+++ b/vllm_spyre_next/vllm_spyre_next/testing/upstream_tests.yaml
@@ -1,7 +1,7 @@
 # Upstream test filter configuration for vllm-spyre-next.
 #
 # Only tests listed here will run from upstream vLLM. All other upstream
-# tests are skipped by default (opt-in / whitelist model).
+# tests are skipped by default (opt-in / allowlist model).
 #
 # block_list entries take precedence over allow_list entries.
 #
@@ -14,6 +14,9 @@
 #                              Parameter name -> list of values to skip
 # allow_list[].params.override
 #                              Parameter name -> replacement values (replaces upstream defaults)
+# allow_list[].fixture_names   Fixture names to inject for this test.
+#                              These fixtures are automatically added to the test's
+#                              fixturenames during collection.
 
# block_list[].test fnmatch glob matched against test function name tests: @@ -42,8 +45,18 @@ tests: tags: [facebook, upstream, uses_subprocess] params: allow: # skip every model except facebook/opt-125m - model: + model: - facebook/opt-125m - block_list: - - test: "test_fused_rms_norm_quant" + - rel_path: tests/v1/attention/test_attention_backends.py + allow_list: + - test: "test_causal_backend_correctness" + mode: mandatory_pass + tags: [attention, upstream] + params: + allow: # skip TP cases that we don't support + tensor_parallel_size: + - 1 + fixture_names: + - "patch_backend_list" +