Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from tests.conftest import cleanup
from vllm import LLM
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    """Fixture producing the baseline LLM generator for a comparison test.

    Merges the shared kwargs with the baseline-specific kwargs and delegates
    to create_llm_generator.
    """
    llm_generator = create_llm_generator(common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         baseline_llm_kwargs, seed)
    return llm_generator


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture producing the LLM generator under test.

    Merges the shared kwargs with the test-specific kwargs and delegates to
    create_llm_generator.
    """
    llm_generator = create_llm_generator(common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         test_llm_kwargs, seed)
    return llm_generator


def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    """Yield a single LLM built from the merged kwarg dicts, then tear down.

    Later dicts override earlier ones on key collisions. After the consumer
    is done with the yielded LLM, the local reference is dropped and
    cleanup() is invoked so resources can be reclaimed between tests.
    """
    kwargs = dict(common_llm_kwargs)
    kwargs.update(per_test_common_llm_kwargs)
    kwargs.update(distinct_llm_kwargs)

    def generator_inner():
        llm = LLM(**kwargs)

        # Seed RNGs only after engine construction.
        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    for llm in generator_inner():
        yield llm
        # Drop our reference promptly so the inner generator's cleanup() can
        # actually release the engine.
        del llm
50 changes: 50 additions & 0 deletions tests/spec_decode/e2e/test_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest

from vllm import SamplingParams


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",
        "speculative_model": "facebook/opt-125m",
        "num_speculative_tokens": 5,

        # Required for spec decode.
        "use_v2_block_manager": True
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_config(test_llm_generator):
    """Verify that enabling speculative decoding on the GPU backend raises
    the expected assertion, since it is not yet supported there.
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Greedy decoding with a long, EOS-ignoring budget; generation is not
    # expected to actually run — engine setup should fail first.
    sampling_params = SamplingParams(
        max_tokens=1024,
        ignore_eos=True,
        temperature=0.0,
    )

    with pytest.raises(
            AssertionError,
            match="Speculative decoding not yet supported for GPU backend"):
        get_token_ids_from_llm_generator(test_llm_generator, prompts,
                                         sampling_params)


def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    """Generate completions for *prompts* on each LLM the generator yields.

    Returns the token ids of the first (top) completion per prompt, taken
    from the last LLM yielded. The local LLM reference is deleted each
    iteration so the generator's teardown can reclaim it.
    """
    for llm in llm_generator:
        request_outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = []
        for request_output in request_outputs:
            token_ids.append(request_output.outputs[0].token_ids)
        del llm

    return token_ids
18 changes: 8 additions & 10 deletions tests/spec_decode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,16 @@ def create_worker(cls: type,
block_size=block_size,
enforce_eager=enforce_eager,
)

(model_config, cache_config, parallel_config, scheduler_config,
device_config, _, _) = engine_args.create_engine_configs()
engine_config = engine_args.create_engine_config()

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())

worker = cls(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
Expand All @@ -128,9 +126,9 @@ def create_worker(cls: type,
worker.init_device()
worker.load_model()

cache_config.num_gpu_blocks = num_gpu_blocks
cache_config.num_cpu_blocks = 0
worker.init_cache_engine(cache_config)
engine_config.cache_config.num_gpu_blocks = num_gpu_blocks
engine_config.cache_config.num_cpu_blocks = 0
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()

return worker
Expand Down
17 changes: 8 additions & 9 deletions tests/worker/test_swap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@ def test_swap() -> None:
engine_args = EngineArgs(model="facebook/opt-125m",
dtype="half",
load_format="dummy")
(model_config, cache_config, parallel_config, scheduler_config,
device_config, _, _) = engine_args.create_engine_configs()
cache_config.num_gpu_blocks = 100
cache_config.num_cpu_blocks = 100
engine_config = engine_args.create_engine_config()
engine_config.cache_config.num_gpu_blocks = 100
engine_config.cache_config.num_cpu_blocks = 100

# Create the worker.
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
worker = Worker(
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
Expand All @@ -32,7 +31,7 @@ def test_swap() -> None:
# Initialize the worker.
worker.init_device()
worker.load_model()
worker.init_cache_engine(cache_config)
worker.init_cache_engine(engine_config.cache_config)
worker.warm_up_model()

# Randomly initialize the cache.
Expand Down
188 changes: 187 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import enum
import json
import os
from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union

import torch
Expand Down Expand Up @@ -617,6 +617,159 @@ def __init__(self, device: str = "auto") -> None:
self.device = torch.device(self.device_type)


class SpeculativeConfig:
    """Configuration for speculative decoding.

    The configuration is currently specialized to draft-model speculative
    decoding with top-1 proposals.
    """

    @staticmethod
    def maybe_create_spec_config(
        target_model_config: ModelConfig,
        target_parallel_config: ParallelConfig,
        target_dtype: str,
        speculative_model: Optional[str],
        num_speculative_tokens: Optional[int],
    ) -> Optional["SpeculativeConfig"]:
        """Create a SpeculativeConfig if possible, else return None.

        This function attempts to create a SpeculativeConfig object based on the
        provided parameters. If the necessary conditions are met, it returns an
        instance of SpeculativeConfig. Otherwise, it returns None.

        Args:
            target_model_config (ModelConfig): The configuration of the target
                model.
            target_parallel_config (ParallelConfig): The parallel configuration
                for the target model.
            target_dtype (str): The data type used for the target model.
            speculative_model (Optional[str]): The name of the speculative
                model, if provided.
            num_speculative_tokens (Optional[int]): The number of speculative
                tokens, if provided.

        Returns:
            Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
                the necessary conditions are met, else None.

        Raises:
            ValueError: If exactly one of speculative_model and
                num_speculative_tokens is provided.
        """
        # NOTE(review): target_dtype is currently unused here — the draft
        # model reuses target_model_config.dtype below. Presumably kept for
        # future dtype overrides; confirm before removing.

        if (speculative_model is None and num_speculative_tokens is None):
            return None

        # Both values are required together. Previously only the
        # "model set, tokens missing" case was rejected; the converse slipped
        # through and failed later inside ModelConfig(model=None, ...) with a
        # confusing error.
        if speculative_model is None or num_speculative_tokens is None:
            raise ValueError(
                "Expected both speculative_model and "
                "num_speculative_tokens to be provided, but found "
                f"{speculative_model=} and {num_speculative_tokens=}.")

        # TODO: The user should be able to specify revision/quantization/max
        # model len for the draft model. It is not currently supported.
        draft_revision = None
        draft_code_revision = None
        draft_quantization = None
        draft_max_model_len = None

        # The draft model inherits most settings (tokenizer, dtype, seed, ...)
        # from the target model.
        draft_model_config = ModelConfig(
            model=speculative_model,
            tokenizer=target_model_config.tokenizer,
            tokenizer_mode=target_model_config.tokenizer_mode,
            trust_remote_code=target_model_config.trust_remote_code,
            download_dir=target_model_config.download_dir,
            load_format=target_model_config.load_format,
            dtype=target_model_config.dtype,
            seed=target_model_config.seed,
            revision=draft_revision,
            code_revision=draft_code_revision,
            tokenizer_revision=target_model_config.tokenizer_revision,
            max_model_len=draft_max_model_len,
            quantization=draft_quantization,
            enforce_eager=target_model_config.enforce_eager,
            max_context_len_to_capture=target_model_config.
            max_context_len_to_capture,
            max_logprobs=target_model_config.max_logprobs,
        )

        draft_parallel_config = (
            SpeculativeConfig.create_draft_parallel_config(
                target_parallel_config))

        return SpeculativeConfig(
            draft_model_config,
            draft_parallel_config,
            num_speculative_tokens,
        )

    @staticmethod
    def create_draft_parallel_config(
            target_parallel_config: ParallelConfig) -> ParallelConfig:
        """Create a parallel config for use by the draft worker.

        This is mostly a copy of the target parallel config. In the future the
        draft worker can have a different parallel strategy, e.g. TP=1.
        """
        draft_parallel_config = ParallelConfig(
            pipeline_parallel_size=target_parallel_config.
            pipeline_parallel_size,
            tensor_parallel_size=target_parallel_config.tensor_parallel_size,
            worker_use_ray=target_parallel_config.worker_use_ray,
            max_parallel_loading_workers=target_parallel_config.
            max_parallel_loading_workers,
            disable_custom_all_reduce=target_parallel_config.
            disable_custom_all_reduce,
            tokenizer_pool_config=target_parallel_config.tokenizer_pool_config,
            ray_workers_use_nsight=target_parallel_config.
            ray_workers_use_nsight,
            placement_group=target_parallel_config.placement_group,
        )

        return draft_parallel_config

    def __init__(
        self,
        draft_model_config: ModelConfig,
        draft_parallel_config: ParallelConfig,
        num_speculative_tokens: int,
    ):
        """Create a SpeculativeConfig object.

        Args:
            draft_model_config: ModelConfig for the draft model.
            draft_parallel_config: ParallelConfig for the draft model.
            num_speculative_tokens: The number of tokens to sample from the
                draft model before scoring with the target model.

        Raises:
            ValueError: If num_speculative_tokens is not positive, or the
                draft model config is inconsistent with its parallel config.
        """
        self.draft_model_config = draft_model_config
        self.draft_parallel_config = draft_parallel_config
        self.num_speculative_tokens = num_speculative_tokens

        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the stored configuration; raise ValueError when invalid."""
        if self.num_speculative_tokens <= 0:
            raise ValueError("Expected num_speculative_tokens to be greater "
                             f"than zero ({self.num_speculative_tokens}).")

        if self.draft_model_config:
            self.draft_model_config.verify_with_parallel_config(
                self.draft_parallel_config)

    @property
    def num_lookahead_slots(self) -> int:
        """The number of additional slots the scheduler should allocate per
        step, in addition to the slots allocated for each known token.

        This is equal to the number of speculative tokens, as each speculative
        token must be scored.
        """
        return self.num_speculative_tokens

    def __repr__(self) -> str:
        draft_model = self.draft_model_config.model
        num_spec_tokens = self.num_speculative_tokens
        return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})"


@dataclass
class LoRAConfig:
max_lora_rank: int
Expand Down Expand Up @@ -838,3 +991,36 @@ def _get_and_verify_max_len(
"to incorrect model outputs or CUDA errors. Make sure the "
"value is correct and within the model context size.")
return int(max_model_len)


@dataclass(frozen=True)
class EngineConfig:
    """Immutable aggregate of every engine-related configuration object.

    Bundling the individual configs into a single dataclass simplifies
    passing them around the codebase.
    """

    model_config: ModelConfig
    cache_config: CacheConfig
    parallel_config: ParallelConfig
    scheduler_config: SchedulerConfig
    device_config: DeviceConfig
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]

    def __post_init__(self):
        """Cross-validate the member configs against each other."""
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

        if self.lora_config:
            self.lora_config.verify_with_model_config(self.model_config)
            self.lora_config.verify_with_scheduler_config(
                self.scheduler_config)

    def to_dict(self):
        """Return a shallow field-name -> config mapping for **kwargs use."""
        return {f.name: getattr(self, f.name) for f in fields(self)}
Loading