Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions sendnn_inference/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,70 @@ def import_kernels(cls) -> None:
# Workaround torch.accelerator.empty_cache for torch 2.7.1 and vllm v0.18.0 compatibility
setattr(torch.accelerator, "empty_cache", lambda: None) # noqa

@classmethod
def set_device(cls, device: torch.device) -> None:
    """Intentionally do nothing: Spyre needs no explicit device selection."""

@classmethod
def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
"""
Check if the current platform supports async output.
"""
return False

@classmethod
def get_spyre_scheduler_cls(
    cls,
    scheduler_config,
    is_pooling: bool,
) -> type:
    """Resolve which Spyre scheduler class to use.

    Mirrors upstream vLLM's ``SchedulerConfig.get_scheduler_cls()``:
    an explicitly configured ``scheduler_cls`` wins; otherwise the class
    is chosen from ``scheduler_config.async_scheduling`` and the model
    kind. The candidate modules create their classes at module level, so
    each is importable on demand.

    NOTE(review): the override check tests ``is not None`` — confirm the
    project's SchedulerConfig really defaults ``scheduler_cls`` to None
    (upstream vLLM uses a string default, which would make this branch
    always taken).

    Args:
        scheduler_config: The scheduler configuration object.
        is_pooling: True for pooling/embedding models, False for
            generative models.

    Returns:
        The scheduler class to use (str or class, as configured).
    """
    # An explicitly set custom scheduler always takes precedence.
    explicit = scheduler_config.scheduler_cls
    if explicit is not None:
        return explicit

    # Deferred imports so only the selected variant's module is loaded.
    if scheduler_config.async_scheduling:
        if is_pooling:
            from sendnn_inference.v1.core.async_scheduler import (
                AsyncPoolingSpyreScheduler,
            )

            return AsyncPoolingSpyreScheduler

        from sendnn_inference.v1.core.async_scheduler import (
            AsyncChunkedPrefillSpyreScheduler,
        )

        return AsyncChunkedPrefillSpyreScheduler

    # Sync variants are the default path.
    if is_pooling:
        from sendnn_inference.v1.core.scheduler import PoolingSpyreScheduler

        return PoolingSpyreScheduler

    from sendnn_inference.v1.core.scheduler import ChunkedPrefillSpyreScheduler

    return ChunkedPrefillSpyreScheduler

@classmethod
def get_max_batch_tkv_limit(cls) -> int:
if cls._max_batch_tkv_limit == 0:
Expand Down Expand Up @@ -219,8 +276,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
os.environ["FLEX_DEVICE"] = "COMPILE"

if is_decoder:
scheduler_config.scheduler_cls = (
"sendnn_inference.v1.core.scheduler.ChunkedPrefillSpyreScheduler"
# With MultiprocExecutor (TP>1), vLLM's default max_concurrent_batches=2
# enables step_with_batch_queue, which calls schedule() twice before
# update_from_output(). The Spyre mixin's pre-filter pattern is not safe
# under that run-ahead scenario. We install SpyreMultiprocExecutor which
# caps max_concurrent_batches=1, forcing the simpler step() path and
# restoring correct schedule→execute→update sequencing.
# (Spyre's AIU forward pass is synchronous so run-ahead has no throughput
# benefit anyway; AsyncScheduler's _update_after_schedule TTFT benefit is
# preserved through the normal step() path.)
if parallel_config.world_size > 1:
from sendnn_inference.v1.executor.spyre_executor import SpyreMultiprocExecutor

parallel_config.distributed_executor_backend = SpyreMultiprocExecutor

# Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
# This checks scheduler_cls first, then async_scheduling flag
# SchedulerConfig.scheduler_cls accepts str | type | None directly.
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
scheduler_config=scheduler_config,
is_pooling=False,
)

if (
Expand Down Expand Up @@ -249,8 +324,11 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# unsetting this config as it was only set to pass vllm scheduler's max_model_len check
vllm_config.scheduler_config.enable_chunked_prefill = False

scheduler_config.scheduler_cls = (
"sendnn_inference.v1.core.scheduler.PoolingSpyreScheduler"
# Select scheduler using get_spyre_scheduler_cls(), following upstream's pattern
# SchedulerConfig.scheduler_cls accepts str | type | None directly.
scheduler_config.scheduler_cls = cls.get_spyre_scheduler_cls(
scheduler_config=scheduler_config,
is_pooling=True,
)

# Apply model-specific configurations using the registry
Expand Down Expand Up @@ -287,8 +365,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
envs_spyre.SENDNN_INFERENCE_DYNAMO_BACKEND,
)

# TODO: try to support async scheduling
scheduler_config.async_scheduling = False
logger.info(
"Spyre async scheduling is %s",
"enabled" if scheduler_config.async_scheduling else "disabled",
)

# To disable any paged attention ops in the base scheduler, we:
# - Set the block size (in tokens) to the maximum sequence length
Expand All @@ -304,7 +384,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config.max_num_batched_tokens = (
model_config.max_model_len * scheduler_config.max_num_seqs
)
cache_config.block_size = model_config.max_model_len # ty: ignore[invalid-assignment]
cache_config.block_size = model_config.max_model_len
vllm_config.cache_config.enable_prefix_caching = False

else:
Expand Down Expand Up @@ -759,7 +839,7 @@ def maybe_ensure_sendnn_configured(cls, model_config: ModelConfig) -> None:
@classmethod
def _set_batch_tkv_limit_from_env(cls) -> None:
try:
cls._max_batch_tkv_limit = int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", "-1")) # ty: ignore
cls._max_batch_tkv_limit = int(os.getenv("VLLM_DT_MAX_BATCH_TKV_LIMIT", "-1"))
except ValueError as e:
raise ValueError("VLLM_DT_MAX_BATCH_TKV_LIMIT must be an integer") from e

Expand Down
36 changes: 36 additions & 0 deletions sendnn_inference/v1/core/async_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.v1.core.sched.async_scheduler import AsyncScheduler

from sendnn_inference.v1.core.scheduler import (
ChunkedPrefillSpyreMixin,
PoolingSpyreMixin,
)


class AsyncSpyreScheduler(AsyncScheduler):
    """Spyre-aware variant of vLLM's V1 async scheduler.

    Supports static and continuous batching while respecting AIU Spyre
    constraints.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Convenience alias: the base scheduler holds the full vLLM
        # config; keep the model config directly reachable.
        self.model_config = self.vllm_config.model_config


class AsyncPoolingSpyreScheduler(PoolingSpyreMixin, AsyncScheduler):
    """Async scheduler for pooling models with Spyre warmup-shape constraints.

    All Spyre-specific behavior comes from ``PoolingSpyreMixin``; this
    class only binds that mixin to the async base scheduler.
    """


class AsyncChunkedPrefillSpyreScheduler(ChunkedPrefillSpyreMixin, AsyncScheduler):
    """Async scheduler with Spyre chunked-prefill constraints bypassed in async mode.

    All Spyre-specific behavior comes from ``ChunkedPrefillSpyreMixin``;
    this class only binds that mixin to the async base scheduler.
    """


# Public API of this module.
# NOTE(review): AsyncSpyreScheduler is defined above but not exported —
# confirm whether that omission is intentional or the class is unused.
__all__ = [
    "AsyncPoolingSpyreScheduler",
    "AsyncChunkedPrefillSpyreScheduler",
]
Loading
Loading