Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions tests/aftu/graph_compare_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from subprocess import PIPE, STDOUT, CalledProcessError, TimeoutExpired, run
from typing import Optional

from spyre_util import ModelInfo
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf)

Expand Down Expand Up @@ -157,14 +158,13 @@ def run_inference_py_and_get_graphs(
return aftu_graphs


def get_model_path(model: ModelInfo):
    """Return a filesystem location for the weights of ``model``.

    Args:
        model: Object exposing ``name`` (a local directory or a model
            identifier) and ``revision`` (forwarded to the download helper).

    Returns:
        ``model.name`` unchanged when it is an existing directory; otherwise
        whatever ``download_weights_from_hf`` returns for that name and
        revision.
    """
    # A directory on disk is used as-is; nothing is downloaded.
    if os.path.isdir(model.name):
        return model.name

    # Get location of model from HF cache.
    return download_weights_from_hf(
        model_name_or_path=model.name,
        cache_dir=None,
        allow_patterns=["*.safetensors", "*.bin", "*.pt"],
        revision=model.revision)
9 changes: 5 additions & 4 deletions tests/aftu/test_compare_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
get_aftu_script_dir, get_model_path,
run_inference_py_and_get_graphs)
from output_util import generate_spyre_vllm_output
from spyre_util import DecodeWarmupShapes, get_chicken_soup_prompts
from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts
from vllm import SamplingParams


@pytest.mark.spyre
@pytest.mark.cb
def test_compare_graphs_cb(model: str, max_num_seqs: int,
def test_compare_graphs_cb(model: ModelInfo, max_num_seqs: int,
monkeypatch: pytest.MonkeyPatch, use_llm_cache):
"""Test that the spyre worker correctly outputs
continuous batches of requests by comparing to HF"""
Expand All @@ -43,7 +43,8 @@ def test_compare_graphs_cb(model: str, max_num_seqs: int,

extra_env = {
"VLLM_DT_MAX_CONTEXT_LEN": str(max_model_len),
"VLLM_DT_MAX_BATCH_SIZE": str(max_num_seqs)
"VLLM_DT_MAX_BATCH_SIZE": str(max_num_seqs),
"VLLM_DT_MAX_BATCH_TKV_LIMIT": str(1024 * 128)
}
aftu_graphs = run_inference_py_and_get_graphs(inference_py_args,
script_dir, extra_env)
Expand Down Expand Up @@ -91,7 +92,7 @@ def test_compare_graphs_cb(model: str, max_num_seqs: int,
@pytest.mark.parametrize(
"warmup_shapes", [[(64, 4, 4)]]) # (prompt_length/new_tokens/batch_size)
def test_compare_graphs_static_batching(
model: str, warmup_shapes: DecodeWarmupShapes,
model: ModelInfo, warmup_shapes: DecodeWarmupShapes,
monkeypatch: pytest.MonkeyPatch) -> None:

# AFTU
Expand Down
15 changes: 7 additions & 8 deletions tests/e2e/test_spyre_async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from contextlib import ExitStack

import pytest
from spyre_util import DecodeWarmupShapes, get_chicken_soup_prompts
from spyre_util import DecodeWarmupShapes, ModelInfo, get_chicken_soup_prompts
from vllm import PromptType, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
Expand Down Expand Up @@ -47,7 +47,7 @@ async def generate(
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_abort(model: str, backend: str, cb: int,
async def test_abort(model: ModelInfo, backend: str, cb: int,
warmup_shapes: DecodeWarmupShapes,
output_kind: RequestOutputKind,
monkeypatch: pytest.MonkeyPatch):
Expand All @@ -70,12 +70,11 @@ async def test_abort(model: str, backend: str, cb: int,

# Async LLM API is a little different between v0 and V1
engine = AsyncLLM.from_engine_args(
AsyncEngineArgs(
model=model,
tokenizer=model,
max_model_len=128,
max_num_seqs=4,
))
AsyncEngineArgs(model=model.name,
tokenizer=model.name,
max_model_len=256,
max_num_seqs=4,
revision=model.revision))
has_unfinished_requests = \
engine.output_processor.has_unfinished_requests
after.callback(engine.shutdown)
Expand Down
26 changes: 14 additions & 12 deletions tests/e2e/test_spyre_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest
from output_util import check_output_against_hf, generate_spyre_vllm_output
from spyre_util import (DecodeWarmupShapes, create_random_request,
from spyre_util import (DecodeWarmupShapes, ModelInfo, create_random_request,
get_chicken_soup_prompts, patch_environment,
skip_unsupported_tp_size)
from vllm import EngineArgs, SamplingParams
Expand All @@ -16,7 +16,7 @@


@pytest.mark.full_model
def test_output(model: str, tp_size: int, backend: str, cb: int,
def test_output(model: ModelInfo, tp_size: int, backend: str, cb: int,
max_num_seqs: int, max_model_len: int,
warmup_shapes: DecodeWarmupShapes,
monkeypatch: pytest.MonkeyPatch, use_llm_cache) -> None:
Expand Down Expand Up @@ -70,8 +70,9 @@ def test_output(model: str, tp_size: int, backend: str, cb: int,
pytest.param(
"sendnn_decoder", marks=pytest.mark.spyre, id="sendnn_decoder")
])
def test_output_sendnn_decoder(model: str, warmup_shapes: DecodeWarmupShapes,
backend: str, monkeypatch: pytest.MonkeyPatch,
def test_output_sendnn_decoder(model: ModelInfo,
warmup_shapes: DecodeWarmupShapes, backend: str,
monkeypatch: pytest.MonkeyPatch,
use_llm_cache) -> None:
'''
Tests the deprecated sendnn_decoder backend, which should fall-back to
Expand Down Expand Up @@ -101,7 +102,7 @@ def test_output_sendnn_decoder(model: str, warmup_shapes: DecodeWarmupShapes,
prompts)


def test_batch_handling(model: str, backend: str, cb: int, warmup_shapes,
def test_batch_handling(model: ModelInfo, backend: str, cb: int, warmup_shapes,
max_num_seqs: int, max_model_len: int,
monkeypatch: pytest.MonkeyPatch, use_llm_cache):
"""Test that the spyre worker correctly handles
Expand Down Expand Up @@ -147,7 +148,7 @@ def test_batch_handling(model: str, backend: str, cb: int, warmup_shapes,
prompts)


def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
def test_full_batch_scheduling(model: ModelInfo, backend: str, monkeypatch):
"""Test that we can schedule a full batch of prompts."""

# We need to ensure here that the max number of tokens in a full batch
Expand All @@ -171,10 +172,11 @@ def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
monkeypatch.setenv("VLLM_SPYRE_DYNAMO_BACKEND", backend)

# Setup the engine
engine_args = EngineArgs(model=model,
tokenizer=model,
engine_args = EngineArgs(model=model.name,
tokenizer=model.name,
max_num_batched_tokens=max_batched_tokens,
max_num_seqs=batch_size)
max_num_seqs=batch_size,
revision=model.revision)
vllm_config = engine_args.create_engine_config()
executor_class = Executor.get_class(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
Expand All @@ -191,14 +193,14 @@ def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
request_id=i,
num_tokens=max_batched_tokens,
sampling_params=vllm_sampling_params,
model=model,
model=model.name,
))
schedule = scheduler.schedule()

assert len(schedule.scheduled_new_reqs) == batch_size


def test_max_model_len_override(model: str, backend, warmup_shapes, cb,
def test_max_model_len_override(model: ModelInfo, backend, warmup_shapes, cb,
monkeypatch):
"""Test that makes sure that --max-model-len
doesn't affect SB, instead it is picked up from
Expand All @@ -215,7 +217,7 @@ def test_max_model_len_override(model: str, backend, warmup_shapes, cb,

patch_environment(**kwargs, backend=backend, monkeypatch=monkeypatch)
vllm_config = EngineArgs(
model=model, max_model_len=max_model_len).create_engine_config()
model=model.name, max_model_len=max_model_len).create_engine_config()
model_config = vllm_config.model_config

if not cb:
Expand Down
29 changes: 14 additions & 15 deletions tests/e2e/test_spyre_cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
from openai import BadRequestError
from output_util import (check_output_against_hf, compare_results,
extract_output, generate_spyre_vllm_output)
from spyre_util import (RemoteOpenAIServer, create_seq_prompt,
from spyre_util import (ModelInfo, RemoteOpenAIServer, create_seq_prompt,
get_chicken_soup_prompts, skip_unsupported_tp_size)
from vllm import LLM, SamplingParams


@pytest.mark.cb
@pytest.mark.parametrize(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_cb_max_tokens(model: str, backend: str, max_model_len: int,
def test_cb_max_tokens(model: ModelInfo, backend: str, max_model_len: int,
max_num_seqs: int, monkeypatch: pytest.MonkeyPatch,
use_llm_cache):
"""Test that continuous batches of requests that
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_cb_max_tokens(model: str, backend: str, max_model_len: int,
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_rejects_oversized_request(
remote_openai_server: RemoteOpenAIServer,
model: str,
model: ModelInfo,
backend: str,
cb: bool,
max_model_len: int,
Expand All @@ -67,7 +67,7 @@ def test_api_cb_rejects_oversized_request(
with pytest.raises(BadRequestError,
match="This model's maximum context length is"):
client.completions.create(
model=model,
model=model.name,
prompt=overflow_prompt,
max_tokens=max_tokens,
)
Expand All @@ -79,7 +79,7 @@ def test_api_cb_rejects_oversized_request(
"backend", [pytest.param("eager", marks=pytest.mark.cpu, id="eager")])
def test_api_cb_generates_correct_max_tokens(
remote_openai_server: RemoteOpenAIServer,
model: str,
model: ModelInfo,
backend: str,
cb: bool,
max_model_len: int,
Expand All @@ -90,7 +90,7 @@ def test_api_cb_generates_correct_max_tokens(
client = remote_openai_server.get_client()
max_tokens = 10

response = client.completions.create(model=model,
response = client.completions.create(model=model.name,
prompt=get_chicken_soup_prompts(1),
max_tokens=max_tokens,
temperature=0)
Expand All @@ -110,7 +110,7 @@ def test_api_cb_generates_correct_max_tokens(
ids=lambda val: f"TP({val})",
)
def test_long_context_batches(
model: str,
model: ModelInfo,
backend: str,
tp_size: int,
monkeypatch: pytest.MonkeyPatch,
Expand All @@ -136,13 +136,12 @@ def test_long_context_batches(
(2, 9000),
]

vllm_model = LLM(
model=model,
tokenizer=model,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tp_size,
)
vllm_model = LLM(model=model.name,
tokenizer=model.name,
max_model_len=max_model_len,
max_num_seqs=max_num_seqs,
tensor_parallel_size=tp_size,
revision=model.revision)

sampling_params = SamplingParams(
max_tokens=max_tokens,
Expand All @@ -152,7 +151,7 @@ def test_long_context_batches(
)

for batch_size, token_len in batch_token_pairs:
prompt = create_seq_prompt(model, token_length=token_len)
prompt = create_seq_prompt(model.name, token_length=token_len)
prompts = [prompt] * batch_size

vllm_outputs = vllm_model.generate(prompts, sampling_params)
Expand Down
Loading