Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 2 additions & 79 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ jobs:
# Lower bound support
- vllm_version:
name: "vLLM:lowest"
repo: "git+https://github.com/vllm-project/vllm --tag v0.11.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.17.0"
test_suite:
name: "backward compat"
markers: "compat or (cpu and basic and not quantized and not sb)"
Expand All @@ -94,86 +94,9 @@ jobs:
os: "ubuntu-latest"
python_version: "3.12"
# Intermediate versions of vllm to check basic support for as well
- vllm_version:
name: "vLLM:0.11.1"
repo: "git+https://github.com/vllm-project/vllm --tag v0.11.1"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.11.2"
repo: "git+https://github.com/vllm-project/vllm --tag v0.11.2"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.12.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.12.0"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.13.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.13.0"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.14.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.14.0"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.14.1"
repo: "git+https://github.com/vllm-project/vllm --tag v0.14.1"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.15.0"
repo: "git+https://github.com/vllm-project/vllm --tag v0.15.0"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
flags: "--timeout=300"
hf_model_2: "sentence-transformers/all-roberta-large-v1"
hf_model_2_rev: "cf74d8acd4f198de950bf004b262e6accfed5d2c"
os: "ubuntu-latest"
python_version: "3.12"
- vllm_version:
name: "vLLM:0.17.1"
repo: "git+https://github.com/vllm-project/vllm --tag v0.15.1"
repo: "git+https://github.com/vllm-project/vllm --tag v0.17.1"
test_suite:
name: "backward compat"
markers: "cpu and basic and not quantized and not sb"
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license = {text = "Apache 2"}
dependencies = [
"fms-model-optimizer[fp8]>=0.8.0",
"ibm-fms>=1.7.0,<2.0",
"vllm>=0.11.0,<0.16.1",
"vllm>=0.17.0,<0.18.1",
]
requires-python = ">=3.11"
dynamic = ["version"]
Expand Down Expand Up @@ -70,7 +70,7 @@ environments = [
]

[tool.uv.sources]
vllm = { git = "https://github.com/vllm-project/vllm", rev = "v0.16.0" }
vllm = { git = "https://github.com/vllm-project/vllm", rev = "v0.18.0" }

[tool.ty.rules]
possibly-missing-attribute = "ignore"
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
SpyrePlatform._used_with_cli = False
yield
if should_do_global_cleanup_after_test:
# Workaround torch.accelerator.empty_cache for torch 2.7.1 and vllm v0.18.0 compatibility
setattr(torch.accelerator, "empty_cache", lambda: None) # noqa
cleanup_dist_env_and_memory()


Expand Down
4 changes: 0 additions & 4 deletions tests/e2e/test_chunked_prefill_tkv_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@ def make_scheduler_output(
scheduled_cached_reqs=scheduled_cached_reqs,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=finished_req_ids,
kv_connector_metadata=None,
**extra_args,
Expand All @@ -134,7 +131,6 @@ def make_new_request_data(req_id, prompt_len):
prompt_token_ids=[42] * prompt_len,
sampling_params=SamplingParams(),
pooling_params=None,
eos_token_id=None,
)
return NewRequestData.from_request(req, block_ids=[])

Expand Down
1 change: 0 additions & 1 deletion tests/scheduling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def create_request_for_scheduler_test(
request_id=str(request_id),
sampling_params=sampling_params,
prompt_token_ids=prompt,
eos_token_id=None,
arrival_time=0,
lora_request=None,
pooling_params=None,
Expand Down
10 changes: 2 additions & 8 deletions tests/spyre_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,8 @@
from vllm_spyre.platform import SpyrePlatform
from vllm_spyre import envs

try:
# old
from vllm.utils import FlexibleArgumentParser, get_open_port
except ImportError:
# new
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import get_open_port
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import get_open_port

from vllm.v1.request import Request

Expand Down Expand Up @@ -448,7 +443,6 @@ def create_random_request(
request_id=str(request_id),
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
eos_token_id=None,
arrival_time=0,
lora_request=None,
pooling_params=None,
Expand Down
9 changes: 1 addition & 8 deletions tests/utils/test_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@
from vllm_spyre.config.model_registry import get_model_registry
from spyre_util import environ_checkpoint, REFERENCE_MODELS

try:
# old
from vllm.utils import FlexibleArgumentParser
except ImportError:
# new
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.argparse_utils import FlexibleArgumentParser

global_default = 192

Expand Down Expand Up @@ -69,8 +64,6 @@ def sendnn_configured() -> bool:
"32",
"-tp",
"4",
"--swap-space", # to prevent a validation error in the 16GB memory test env.
"1",
]

if model_name == "ibm-granite/granite-3.3-8b-instruct":
Expand Down
29 changes: 24 additions & 5 deletions tests/utils/test_platform_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,29 @@
from SamplingParams during request validation.
"""

from unittest.mock import MagicMock
import pytest

from vllm import SamplingParams
from vllm.inputs.data import token_inputs
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import StructuredOutputsParams
from vllm_spyre.platform import SpyrePlatform

pytestmark = pytest.mark.skip_global_cleanup


@pytest.fixture(autouse=True)
def mock_spyre_config():
    """Swap in a stand-in ``SpyrePlatform._config`` around every test.

    The stand-in is a ``MagicMock`` whose ``model_config.max_model_len`` is
    set to 512. It is yielded to the test, and the value that
    ``SpyrePlatform._config`` held beforehand is put back afterwards.
    """
    stub_config = MagicMock()
    stub_config.model_config.max_model_len = 512

    # Remember the current class-level config so it can be restored later.
    saved_config = SpyrePlatform._config
    SpyrePlatform._config = stub_config

    yield stub_config

    # Teardown: undo the class-level override installed above.
    SpyrePlatform._config = saved_config


class TestStructuredOutputValidation:
"""Test that platform validation strips structured outputs from requests."""

Expand All @@ -24,7 +38,8 @@ def test_strips_structured_outputs(self):

assert params.structured_outputs is not None

SpyrePlatform.validate_request("Test prompt", params)
processed_inputs = token_inputs(prompt_token_ids=[1, 2, 3])
SpyrePlatform.validate_request(processed_inputs, params)

assert params.structured_outputs is None

Expand All @@ -34,7 +49,8 @@ def test_logs_warning_when_stripping(self, caplog_vllm_spyre):
max_tokens=20, structured_outputs=StructuredOutputsParams(json_object=True)
)

SpyrePlatform.validate_request("Test prompt", params)
processed_inputs = token_inputs(prompt_token_ids=[1, 2, 3])
SpyrePlatform.validate_request(processed_inputs, params)

assert len(caplog_vllm_spyre.records) > 0
warning_record = caplog_vllm_spyre.records[0]
Expand All @@ -55,7 +71,8 @@ def test_strips_different_structured_output_types(self, structured_output):

assert params.structured_outputs is not None

SpyrePlatform.validate_request("Test prompt", params)
processed_inputs = token_inputs(prompt_token_ids=[1, 2, 3])
SpyrePlatform.validate_request(processed_inputs, params)

assert params.structured_outputs is None

Expand All @@ -77,7 +94,8 @@ def test_preserves_other_sampling_params(self):
"top_k": params.top_k,
}

SpyrePlatform.validate_request("Test prompt", params)
processed_inputs = token_inputs(prompt_token_ids=[1, 2, 3])
SpyrePlatform.validate_request(processed_inputs, params)

# Verify other params are unchanged
assert params.max_tokens == original_values["max_tokens"]
Expand All @@ -92,7 +110,8 @@ def test_does_not_affect_pooling_params(self):
pooling_params = PoolingParams()

# Should not raise any errors and should return early
SpyrePlatform.validate_request("Test prompt", pooling_params)
processed_inputs = token_inputs(prompt_token_ids=[1, 2, 3])
SpyrePlatform.validate_request(processed_inputs, pooling_params)

# PoolingParams don't have structured_outputs, so just verify no exception
assert True # If we got here, the early return worked
Expand Down
Loading
Loading