Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions tests/e2e/test_speculative_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def _test_correctness_helper(
ref_llm = LLM(model=model_name,
max_model_len=1024,
max_num_seqs=4,
tensor_parallel_size=_get_tensor_parallel_size())
tensor_parallel_size=_get_tensor_parallel_size(),
async_scheduling=0)
ref_outputs = ref_llm.generate(test_prompts, sampling_config)

del ref_llm
Expand All @@ -116,7 +117,8 @@ def _test_correctness_helper(
speculative_config=speculative_config,
max_model_len=1024,
max_num_seqs=4,
tensor_parallel_size=_get_tensor_parallel_size())
tensor_parallel_size=_get_tensor_parallel_size(),
async_scheduling=0)
spec_outputs = spec_llm.generate(test_prompts, sampling_config)

matches = 0
Expand Down Expand Up @@ -198,7 +200,8 @@ def _test_performance_helper(
max_model_len=1024,
max_num_seqs=1,
enable_prefix_caching=False,
tensor_parallel_size=_get_tensor_parallel_size())
tensor_parallel_size=_get_tensor_parallel_size(),
async_scheduling=0)

start_time = time.time()
_ = ref_llm.generate(test_prompts, sampling_config)
Expand All @@ -215,7 +218,8 @@ def _test_performance_helper(
max_model_len=1024,
max_num_seqs=1,
tensor_parallel_size=_get_tensor_parallel_size(),
enable_prefix_caching=False)
enable_prefix_caching=False,
async_scheduling=0)

start_time = time.time()
_ = spec_llm.generate(test_prompts, sampling_config)
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/vllm/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torchax
from jax.sharding import Mesh
from torchax.interop import torch_view
from vllm.attention.backends.abstract import AttentionType
from vllm.v1.attention.backend import AttentionType

from tpu_inference.layers.common.attention_metadata import AttentionMetadata
from tpu_inference.layers.vllm.attention import (PallasAttentionBackend,
Expand Down
1 change: 0 additions & 1 deletion tests/layers/vllm/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

P = PartitionSpec
MODELS = [
"MiniMaxAI/MiniMax-M2",
"Qwen/Qwen3-0.6B-FP8",
]

Expand Down
2 changes: 1 addition & 1 deletion tests/runner/test_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import numpy as np
import pytest
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingType
from vllm.v1.attention.backend import AttentionType
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor,
MLAAttentionSpec, SlidingWindowSpec)
Expand Down
8 changes: 4 additions & 4 deletions tpu_inference/layers/vllm/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from jax.sharding import Mesh
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.registry import (AttentionBackendEnum,
register_backend)
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.v1.attention.backend import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.v1.attention.backends.registry import (AttentionBackendEnum,
register_backend)

from tpu_inference import utils
from tpu_inference.layers.common.attention_interface import attention
Expand Down
9 changes: 5 additions & 4 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
from tpu_inference.logger import init_logger

if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import BlockSize, ModelConfig, VllmConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
BlockSize = None
ModelConfig = None
Expand Down Expand Up @@ -54,7 +55,7 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
**kwargs) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# Invoke @register_backend in the module.
import tpu_inference.layers.vllm.attention # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion tpu_inference/runner/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import vllm.envs as envs
from jax.sharding import NamedSharding, PartitionSpec
from torchax.ops.mappings import t2j_dtype
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MLAAttentionSpec,
SlidingWindowSpec)
Expand Down