diff --git a/tests/e2e/test_speculative_decoding.py b/tests/e2e/test_speculative_decoding.py
index af31adbf36..eaf1a21290 100644
--- a/tests/e2e/test_speculative_decoding.py
+++ b/tests/e2e/test_speculative_decoding.py
@@ -104,7 +104,8 @@ def _test_correctness_helper(
     ref_llm = LLM(model=model_name,
                   max_model_len=1024,
                   max_num_seqs=4,
-                  tensor_parallel_size=_get_tensor_parallel_size())
+                  tensor_parallel_size=_get_tensor_parallel_size(),
+                  async_scheduling=0)
     ref_outputs = ref_llm.generate(test_prompts, sampling_config)
 
     del ref_llm
@@ -116,7 +117,8 @@ def _test_correctness_helper(
                    speculative_config=speculative_config,
                    max_model_len=1024,
                    max_num_seqs=4,
-                   tensor_parallel_size=_get_tensor_parallel_size())
+                   tensor_parallel_size=_get_tensor_parallel_size(),
+                   async_scheduling=0)
     spec_outputs = spec_llm.generate(test_prompts, sampling_config)
 
     matches = 0
@@ -198,7 +200,8 @@ def _test_performance_helper(
                   max_model_len=1024,
                   max_num_seqs=1,
                   enable_prefix_caching=False,
-                  tensor_parallel_size=_get_tensor_parallel_size())
+                  tensor_parallel_size=_get_tensor_parallel_size(),
+                  async_scheduling=0)
 
     start_time = time.time()
     _ = ref_llm.generate(test_prompts, sampling_config)
@@ -215,7 +218,8 @@ def _test_performance_helper(
                    max_model_len=1024,
                    max_num_seqs=1,
                    tensor_parallel_size=_get_tensor_parallel_size(),
-                   enable_prefix_caching=False)
+                   enable_prefix_caching=False,
+                   async_scheduling=0)
 
     start_time = time.time()
     _ = spec_llm.generate(test_prompts, sampling_config)
diff --git a/tests/layers/vllm/test_attention.py b/tests/layers/vllm/test_attention.py
index 9e4b14288e..15d9fe38f1 100644
--- a/tests/layers/vllm/test_attention.py
+++ b/tests/layers/vllm/test_attention.py
@@ -22,7 +22,7 @@
 import torchax
 from jax.sharding import Mesh
 from torchax.interop import torch_view
-from vllm.attention.backends.abstract import AttentionType
+from vllm.v1.attention.backend import AttentionType
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.attention import (PallasAttentionBackend,
diff --git a/tests/layers/vllm/test_fp8.py b/tests/layers/vllm/test_fp8.py
index d87780883c..7151a47f35 100644
--- a/tests/layers/vllm/test_fp8.py
+++ b/tests/layers/vllm/test_fp8.py
@@ -43,7 +43,6 @@
 P = PartitionSpec
 
 MODELS = [
-    "MiniMaxAI/MiniMax-M2",
     "Qwen/Qwen3-0.6B-FP8",
 ]
 
diff --git a/tests/runner/test_kv_cache_manager.py b/tests/runner/test_kv_cache_manager.py
index 7a770d99fe..70d61122fb 100644
--- a/tests/runner/test_kv_cache_manager.py
+++ b/tests/runner/test_kv_cache_manager.py
@@ -19,11 +19,11 @@
 import numpy as np
 import pytest
 import torch
-from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig, VllmConfig)
 from vllm.sampling_params import SamplingType
+from vllm.v1.attention.backend import AttentionType
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor,
                                         MLAAttentionSpec, SlidingWindowSpec)
diff --git a/tpu_inference/layers/vllm/attention.py b/tpu_inference/layers/vllm/attention.py
index 707bb81180..b5e25541fa 100644
--- a/tpu_inference/layers/vllm/attention.py
+++ b/tpu_inference/layers/vllm/attention.py
@@ -9,12 +9,12 @@
 from jax.sharding import Mesh
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
-from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionLayer, AttentionType)
-from vllm.attention.backends.registry import (AttentionBackendEnum,
-                                              register_backend)
 from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv, next_power_of_2
+from vllm.v1.attention.backend import (AttentionBackend, AttentionImpl,
+                                       AttentionLayer, AttentionType)
+from vllm.v1.attention.backends.registry import (AttentionBackendEnum,
+                                                 register_backend)
 
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 50ea86a78e..52f6d10740 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -14,11 +14,12 @@
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
-    from vllm.attention.backends.registry import AttentionBackendEnum
-    from vllm.attention.selector import AttentionSelectorConfig
-    from vllm.config import BlockSize, ModelConfig, VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
+    from vllm.config.cache import BlockSize
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams, SamplingType
+    from vllm.v1.attention.backends.registry import AttentionBackendEnum
+    from vllm.v1.attention.selector import AttentionSelectorConfig
 else:
     BlockSize = None
     ModelConfig = None
@@ -54,7 +55,8 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
                              attn_selector_config: "AttentionSelectorConfig",
                              **kwargs) -> str:
-        from vllm.attention.backends.registry import AttentionBackendEnum
+        from vllm.v1.attention.backends.registry import AttentionBackendEnum
+
         # Invoke @register_backend in the module.
         import tpu_inference.layers.vllm.attention  # noqa: F401
 
diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py
index e5f0ddf410..173e44bcbf 100644
--- a/tpu_inference/runner/kv_cache_manager.py
+++ b/tpu_inference/runner/kv_cache_manager.py
@@ -21,10 +21,10 @@
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
-from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import get_layers_from_vllm_config
 from vllm.utils.math_utils import cdiv
+from vllm.v1.attention.backend import AttentionType
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
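For code that must import these symbols across both vLLM layouts touched by this patch, a small compatibility shim is one option. This is only a sketch and not part of the patch; it assumes the new vllm.v1.attention.* paths and the old vllm.attention.backends.* paths each exist in their respective vLLM versions, exactly as they appear in the hunks above.

    # Hypothetical compatibility shim (not part of this patch): prefer the new
    # vllm.v1 import layout used throughout the diff, fall back to the old one.
    try:
        from vllm.v1.attention.backend import AttentionType
        from vllm.v1.attention.backends.registry import AttentionBackendEnum
    except ImportError:  # older vLLM without the vllm.v1 attention layout
        from vllm.attention.backends.abstract import AttentionType
        from vllm.attention.backends.registry import AttentionBackendEnum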