Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 6 additions & 44 deletions tests/v1/test_oracle.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import pytest

import vllm.envs as envs
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs

MODEL = "meta-llama/Llama-3.2-1B-Instruct"


def test_reject_bad_config(monkeypatch):
    # NOTE(review): only the env-var setup is visible here; the assertions
    # that exercised the rejected config appear to have been removed.
    with monkeypatch.context() as m:
        # Pin the engine-selection flag to "0" (V0) for this test's scope.
        m.setenv("VLLM_USE_V1", "0")


def test_unsupported_configs(monkeypatch):
    """Unsupported configs must raise NotImplementedError under VLLM_USE_V1=1."""
    with monkeypatch.context() as m:
        # Explicitly opt in to the V1 engine so an unsupported feature
        # becomes a hard error rather than a silent fallback.
        m.setenv("VLLM_USE_V1", "1")

        with pytest.raises(NotImplementedError):
            AsyncEngineArgs(
                model=MODEL,
                # Speculative decoding with the target model as its own
                # draft model is expected to be rejected here.
                speculative_config={
                    "model": MODEL,
                },
            ).create_engine_config()


def test_enable_by_default_fallback(monkeypatch):
with monkeypatch.context() as m:
if os.getenv("VLLM_USE_V1", None):
m.delenv("VLLM_USE_V1")

# Should default to V1 for supported config.
_ = AsyncEngineArgs(
def test_unsupported_configs():
    """create_engine_config() must reject this unsupported speculative setup."""
    # Passing the target model as its own speculative draft model (with no
    # recognised speculative method) is expected to raise.
    with pytest.raises(NotImplementedError):
        AsyncEngineArgs(
            model=MODEL,
            # enforce_eager avoids compilation work in this negative test.
            enforce_eager=True,
            speculative_config={
                "model": MODEL,
            },
        ).create_engine_config()
assert envs.VLLM_USE_V1
m.delenv("VLLM_USE_V1")


def test_v1_llm_by_default(monkeypatch):
    """When VLLM_USE_V1 is unset, LLM should select the V1 engine."""
    with monkeypatch.context() as m:
        # Clear any user-provided VLLM_USE_V1 so default selection runs.
        if os.getenv("VLLM_USE_V1", None):
            m.delenv("VLLM_USE_V1")

        # Should default to V1 for supported config.
        llm = LLM(MODEL, enforce_eager=True, enable_lora=True)
        print(llm.generate("Hello my name is"))
        # The V1 engine is identified by its `engine_core` attribute.
        assert hasattr(llm.llm_engine, "engine_core")
        # Remove the flag that engine construction set globally, so later
        # tests are not affected.
        m.delenv("VLLM_USE_V1")
105 changes: 15 additions & 90 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1290,15 +1290,7 @@ def create_engine_config(
"""
Create the VllmConfig.

NOTE: for autoselection of V0 vs V1 engine, we need to
create the ModelConfig first, since ModelConfig's attrs
(e.g. the model arch) are needed to make the decision.

This function set VLLM_USE_V1=X if VLLM_USE_V1 is
unspecified by the user.

If VLLM_USE_V1 is specified by the user but the VllmConfig
is incompatible, we raise an error.
NOTE: If VllmConfig is incompatible, we raise an error.
"""
current_platform.pre_register_and_update()

Expand All @@ -1324,22 +1316,7 @@ def create_engine_config(
self.model = model_config.model
self.tokenizer = model_config.tokenizer

# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"
# and fall back to V0 for experimental or unsupported features.
# * If VLLM_USE_V1=1, we enable V1 for supported + experimental
# features and raise error for unsupported features.
# * If VLLM_USE_V1=0, we disable V1.
use_v1 = False
try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1")
if try_v1 and self._is_v1_supported_oracle(model_config):
use_v1 = True

# If user explicitly set VLLM_USE_V1, sanity check we respect it.
if envs.is_set("VLLM_USE_V1"):
assert use_v1 == envs.VLLM_USE_V1
# Otherwise, set the VLLM_USE_V1 variable globally.
else:
envs.set_vllm_use_v1(use_v1)
self._check_feature_supported(model_config)

# Set default arguments for V1 Engine.
self._set_default_args(usage_context, model_config)
Expand Down Expand Up @@ -1708,30 +1685,20 @@ def create_engine_config(

return config

def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"""Oracle for whether to use V0 or V1 Engine by default."""

#############################################################
# Unsupported Feature Flags on V1.

def _check_feature_supported(self, model_config: ModelConfig):
"""Raise an error if the feature is not supported."""
if self.logits_processor_pattern != EngineArgs.logits_processor_pattern:
_raise_or_fallback(
feature_name="--logits-processor-pattern", recommend_to_remove=False
)
return False
_raise_unsupported_error(feature_name="--logits-processor-pattern")

# No Concurrent Partial Prefills so far.
if (
self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
or self.max_long_partial_prefills
!= SchedulerConfig.max_long_partial_prefills
):
_raise_or_fallback(
feature_name="Concurrent Partial Prefill", recommend_to_remove=False
)
return False
_raise_unsupported_error(feature_name="Concurrent Partial Prefill")

# V1 supports N-gram, Medusa, and Eagle speculative decoding.
# N-gram, Medusa, and Eagle are supported for speculative decoding.
if self.speculative_config is not None:
# speculative_config could still be a dict at this point
if isinstance(self.speculative_config, dict):
Expand All @@ -1746,35 +1713,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"such as ngram, medusa, eagle, or mtp."
)

V1_BACKENDS = [
"FLASH_ATTN",
"PALLAS",
"TRITON_ATTN",
"TRITON_MLA",
"CUTLASS_MLA",
"FLASHMLA",
"FLASH_ATTN_MLA",
"FLASHINFER",
"FLASHINFER_MLA",
"ROCM_AITER_MLA",
"TORCH_SDPA",
"FLEX_ATTENTION",
"TREE_ATTN",
"XFORMERS",
"ROCM_ATTN",
"ROCM_AITER_UNIFIED_ATTN",
]
if (
envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS
):
name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}"
_raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False

#############################################################
# Experimental Features - allow users to opt in.

if self.pipeline_parallel_size > 1:
supports_pp = getattr(
self.distributed_executor_backend, "supports_pp", False
Expand All @@ -1790,18 +1728,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"executor or multiprocessing executor or external "
"launcher"
)
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False
_raise_unsupported_error(feature_name=name)

if current_platform.is_cpu() and model_config.get_sliding_window() is not None:
_raise_or_fallback(
feature_name="sliding window (CPU backend)", recommend_to_remove=False
)
return False

#############################################################

return True
_raise_unsupported_error(feature_name="sliding window (CPU backend)")

def _set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
Expand Down Expand Up @@ -2000,17 +1930,12 @@ def add_cli_args(
return parser


def _raise_or_fallback(feature_name: str, recommend_to_remove: bool):
    """Raise if the user pinned VLLM_USE_V1=1; otherwise warn about V0 fallback.

    Args:
        feature_name: Human-readable name of the unsupported feature or flag,
            interpolated into the error/warning message.
        recommend_to_remove: If True, append a hint suggesting the user drop
            the feature in favor of the V1 engine.

    Raises:
        NotImplementedError: Only when the user explicitly set VLLM_USE_V1=1.
    """
    # User explicitly opted in to V1: unsupported feature is a hard error.
    if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
        raise NotImplementedError(
            f"VLLM_USE_V1=1 is not supported with {feature_name}."
        )
    # Otherwise, warn that we are silently falling back to the V0 engine.
    msg = f"{feature_name} is not supported by the V1 Engine. "
    msg += "Falling back to V0. "
    if recommend_to_remove:
        msg += f"We recommend to remove {feature_name} from your config "
        msg += "in favor of the V1 Engine."
    logger.warning(msg)
def _raise_unsupported_error(feature_name: str):
msg = (
f"{feature_name} is not supported. We recommend to "
f"remove {feature_name} from your config."
)
raise NotImplementedError(msg)


def human_readable_int(value):
Expand Down