Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/utils/test_cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_prefix_caching_is_on_by_default(monkeypatch: pytest.MonkeyPatch) -> Non
*common_args,
]
)
assert engine_args.enable_prefix_caching
assert engine_args.enable_prefix_caching is None
vllm_config = engine_args.create_engine_config()
assert engine_args.enable_prefix_caching
assert vllm_config.cache_config.enable_prefix_caching
Expand Down
69 changes: 69 additions & 0 deletions vllm_spyre/compat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
from dataclasses import fields
from functools import lru_cache
from typing import Callable
from packaging.version import Version
import torch
import transformers
import functools
from typing import Any
from types import ModuleType


def dataclass_fields(cls: type) -> list[str]:
Expand All @@ -22,3 +28,66 @@ def has_argument(func: Callable, param_name: str) -> bool:
):
return True
return False


def is_pytorch_lt_2_8() -> bool:
    """Return True when the installed torch release predates 2.8.0."""
    installed = Version(torch.__version__)
    return installed < Version("2.8.0")


def maybe_patch_torch_2_7():
    """Disable vLLM's linear-layer replacement on torch < 2.8.

    Workaround for https://github.com/pytorch/pytorch/issues/160886 (hit on
    torch 2.7.1): ``replace_linear_class`` is monkey-patched to return the
    given linear module unchanged, so no linear-layer substitution happens.
    No-op on torch >= 2.8.
    """
    # Workaround issue with torch 2.7.1 https://github.com/pytorch/pytorch/issues/160886
    # For now, we just disable the replacement of the linear layers
    if is_pytorch_lt_2_8():
        import vllm.model_executor.models.transformers.base as transformer_utils

        # functools.wraps keeps the original function's metadata so
        # introspection by name/signature still works on the patched version.
        @functools.wraps(transformer_utils.replace_linear_class)
        def replace_linear_class(
            linear: Any,
            style: Any = "replicate",
            quant_config: Any = None,
            *,
            prefix: str = "",
        ) -> Any:
            # Identity: hand back the original linear layer untouched.
            return linear

        transformer_utils.replace_linear_class = replace_linear_class  # ty: ignore


def is_transformers_lt_5() -> bool:
    """Return True when the installed transformers release predates 5.0.0."""
    installed = Version(transformers.__version__)
    return installed < Version("5.0.0")


def maybe_patch_transformers_4_57(patch_backend: bool = False):
    """Apply compat patches so transformers 4.x encoder models run under vLLM.

    No-op on transformers >= 5.0.

    Args:
        patch_backend: when True, also neutralize the vLLM transformers
            backend's minimum-version gate (``Base.check_version``).
    """
    if is_transformers_lt_5():
        if patch_backend:
            from vllm.model_executor.models.transformers.base import Base

            # Replace the version gate with a no-op so 4.x models are not
            # rejected. NOTE(review): if the original check_version is a
            # classmethod, assigning a plain function changes how `cls` is
            # bound at call sites — confirm callers invoke it on instances.
            @functools.wraps(Base.check_version)
            def check_version(cls, min_version: str, feature: str):
                pass

            Base.check_version = check_version  # ty: ignore

        def patch_model(model_module: ModuleType, attn_classes_attr: str, class_names: list[str]):
            # Route the "vllm" attention implementation to the stock SDPA
            # one, and mark the listed model classes as non-causal encoders
            # that accept an external attention backend.
            attn_classes = getattr(model_module, attn_classes_attr)
            attn_classes["vllm"] = attn_classes["sdpa"]

            for class_name in class_names:
                model_class = getattr(model_module, class_name)
                model_class.is_causal = False
                model_class._supports_attention_backend = True

        patch_model(
            transformers.models.bert.modeling_bert, "BERT_SELF_ATTENTION_CLASSES", ["BertModel"]
        )
        patch_model(
            transformers.models.roberta.modeling_roberta,
            "ROBERTA_SELF_ATTENTION_CLASSES",
            ["RobertaModel", "RobertaForMaskedLM", "RobertaForSequenceClassification"],
        )

        patch_model(
            transformers.models.xlm_roberta.modeling_xlm_roberta,
            "XLM_ROBERTA_SELF_ATTENTION_CLASSES",
            ["XLMRobertaModel", "XLMRobertaForMaskedLM", "XLMRobertaForSequenceClassification"],
        )
39 changes: 37 additions & 2 deletions vllm_spyre/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,21 @@
from vllm.config import ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import AttentionSelectorConfig
else:
ModelConfig = None
VllmConfig = None
SamplingParams = None
PoolingParams = None
from vllm.platforms import Platform, PlatformEnum
AttentionBackendEnum = None
AttentionSelectorConfig = None

from vllm.v1.attention.backend import AttentionType
from vllm.platforms import Platform, PlatformEnum
import vllm_spyre.envs as envs_spyre
from vllm_spyre.compilation_utils import handle_disable_compilation
from vllm_spyre.compat_utils import maybe_patch_transformers_4_57

logger = init_logger(__name__)

Expand Down Expand Up @@ -83,8 +89,19 @@ def get_device_name(cls, device_id: int = 0) -> str:

@classmethod
def import_kernels(cls) -> None:
maybe_patch_transformers_4_57()
pass # suppress warning

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        """Return the fully-qualified class path of the Spyre SDPA backend.

        Only encoder-only attention is supported; ``selected_backend`` is
        ignored — the Spyre SDPA backend is always chosen.
        """
        # NOTE(review): assert is stripped under `python -O`; consider
        # raising an explicit error for unsupported attention types.
        assert attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
        logger.info("Using Torch SDPA backend.")
        return "vllm_spyre.v1.attention.backends.spyre_sdpa.SpyreSDPABackend"

@classmethod
def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
"""
Expand Down Expand Up @@ -133,6 +150,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
if is_pooling:
os.environ["FLEX_OVERWRITE_NMB_FRAME"] = "false"
os.environ["COMPILATION_MODE"] = "offline"
if vllm_config.model_config.model_impl == "auto":
vllm_config.model_config.model_impl = "transformers"

archs = vllm_config.model_config.hf_config.architectures
if archs is not None and archs[0] in ("XLMRobertaForMaskedLM", "RobertaForMaskedLM"):
archs[0] = "TransformersEmbeddingModel"

if is_decoder:
scheduler_config.scheduler_cls = (
Expand Down Expand Up @@ -199,6 +222,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND,
)

# avoid circular imports
from vllm.config.compilation import CompilationMode
from vllm_spyre.model_executor.model_loader.spyre import BACKEND_LIST

# verify compilation config
if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "eager":
vllm_config.compilation_config.mode = CompilationMode.NONE
else:
assert envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST
vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
vllm_config.compilation_config.backend = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND

# TODO: try to support async scheduling
scheduler_config.async_scheduling = False

Expand Down Expand Up @@ -408,8 +443,8 @@ def _get_matching_warmup_shapes(
    @classmethod
    def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
        """Install Spyre-specific CLI defaults before engine args are parsed."""
        if parser is not None:
            # NOTE(review): the test change in tests/utils/test_cli_args.py
            # expects enable_prefix_caching to default to None — confirm this
            # default is meant to remain.
            parser.set_defaults(enable_prefix_caching=True)
            parser.set_defaults(max_num_batched_tokens=cls.DEFAULT_CHUNK_SIZE)
            # Default to the transformers model implementation on Spyre.
            parser.set_defaults(model_impl="transformers")

@classmethod
def _check_threading_config(cls, worker_count: int):
Expand Down
Empty file.
Empty file.
213 changes: 213 additions & 0 deletions vllm_spyre/v1/attention/backends/spyre_sdpa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Attention layer with simple KV caching without paging.
Uses SDPA for attention
"""

from dataclasses import dataclass

import torch
from torch.nn.functional import scaled_dot_product_attention

from vllm.config import VllmConfig
from vllm.v1.attention.backend import (
AttentionMetadata,
AttentionBackend,
AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.registry import register_backend, AttentionBackendEnum
from typing import ClassVar
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.logger import init_logger

logger = init_logger(__name__)


@register_backend(AttentionBackendEnum.CUSTOM)
class SpyreSDPABackend(AttentionBackend):
    """Non-paged torch-SDPA attention backend for encoder-only models."""

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16]

    @staticmethod
    def get_name() -> str:
        return "CUSTOM"

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        # Single fixed kernel block size.
        return [64]

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Encoder self-attention only; see SpyreSDPABackendImpl.
        return attn_type == AttentionType.ENCODER_ONLY

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        # TODO: copied from flash attn, need to verify
        return head_size <= 256 and head_size % 8 == 0

    @staticmethod
    def get_impl_cls() -> type["SpyreSDPABackendImpl"]:
        return SpyreSDPABackendImpl

    @staticmethod
    def get_builder_cls() -> type["SpyreSDPAMetadataBuilder"]:
        return SpyreSDPAMetadataBuilder


@dataclass
class SpyreSDPAMetadata(AttentionMetadata):
    """Attention metadata for statically padded encoder batches."""

    # Taken from query_start_loc in the builder; consumed per-sequence as a
    # padding boundary in SpyreSDPABackendImpl._sdpa_forward.
    # NOTE(review): looks like cumulative start offsets — confirm semantics.
    prompt_padding: torch.Tensor
    # Number of sequences in the padded batch.
    padded_num_seqs: int
    # Common padded length of every sequence in the batch.
    padded_seq_len: int


class SpyreSDPAMetadataBuilder(AttentionMetadataBuilder[SpyreSDPAMetadata]):
    """Builds SpyreSDPAMetadata for statically padded encoder batches."""

    def __init__(
        self,
        kv_cache_spec: "AttentionSpec",
        layer_names: list[str],
        vllm_config: "VllmConfig",
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.layer_names = layer_names
        self.vllm_config = vllm_config
        self.device = device

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> SpyreSDPAMetadata:
        """Derive padded batch geometry from the common attention metadata.

        Raises (via assert) when the batch is marked causal — only
        bidirectional encoder attention is supported.
        """
        assert not common_attn_metadata.causal, "Causal attention is not supported"

        # NOTE(review): query_start_loc is conventionally a cumulative-offset
        # tensor of length num_seqs + 1; using its length as the sequence
        # count (and its values as per-sequence padding boundaries) assumes
        # the Spyre runner populates it differently — confirm.
        padded_num_seqs = common_attn_metadata.query_start_loc.shape[0]

        ret = SpyreSDPAMetadata(
            prompt_padding=common_attn_metadata.query_start_loc,
            # Assumes every sequence is padded to the same length, so the
            # token count divides evenly across sequences.
            padded_num_seqs=padded_num_seqs,
            padded_seq_len=common_attn_metadata.num_actual_tokens // padded_num_seqs,
        )
        return ret


class SpyreSDPABackendImpl(AttentionImpl[SpyreSDPAMetadata]):
    """Encoder-only attention via torch SDPA; no KV cache, no paging."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Validate and store attention configuration.

        Raises:
            NotImplementedError: for alibi slopes, sliding window, logits
                soft cap, non-"auto" KV cache dtype, or any attention type
                other than ENCODER_ONLY.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        assert num_kv_heads is not None
        self.num_kv_heads = num_kv_heads

        self.kv_cache_dtype = kv_cache_dtype

        # GQA/MQA: each KV head serves num_queries_per_kv query heads.
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # Reject every feature this simple SDPA path does not implement.
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if logits_soft_cap is not None:
            raise NotImplementedError("Logits soft cap is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        self.attn_type = attn_type
        if attn_type != AttentionType.ENCODER_ONLY:
            raise NotImplementedError(
                "Only Encoder self-attention is implemented for SpyreSDPABackendImpl"
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: SpyreSDPAMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with torch SDPA (no KV caching or paging).

        Args:
            query: shape = [batch_size * num_tokens, num_heads * head_size]
            key: shape = [batch_size * num_tokens, num_kv_heads, head_size]
            value: shape = [batch_size * num_tokens, num_kv_heads, head_size]
            kv_cache: must be empty — caching is disabled for encoder-only.
            attn_metadata: padded batch geometry and padding boundaries.
        Returns:
            shape = [batch_size * num_tokens, num_heads * head_size]
        """
        assert output is None
        assert kv_cache.numel() == 0, "Only encoder attention is supported"
        assert key is not None and value is not None
        bsize = attn_metadata.padded_num_seqs
        seq_len = attn_metadata.padded_seq_len

        # Un-flatten tokens into (batch, seq, heads, head_size) ...
        query = query.view(bsize, seq_len, self.num_heads, self.head_size)
        key = key.view(bsize, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(bsize, seq_len, self.num_kv_heads, self.head_size)

        # ... then to SDPA's expected (batch, heads, seq, head_size) layout.
        query = query.transpose(2, 1)
        key = key.transpose(2, 1)
        value = value.transpose(2, 1)

        attn_output = self._sdpa_forward(query, key, value, attn_metadata)

        # Flatten back to the caller's [tokens, hidden] layout.
        return attn_output.view(bsize * seq_len, self.num_heads * self.head_size)

    def _sdpa_forward(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata
    ) -> torch.Tensor:
        """Run SDPA over (batch, heads, seq, head_size) tensors.

        Position j of sequence i is attendable iff j >= prompt_padding[i].
        NOTE(review): that keeps positions at/after the boundary — confirm
        prompt_padding marks left-padding rather than sequence length.
        """
        _, nheads, qlen, _ = query.shape
        kvlen = key.shape[2]
        assert self.num_kv_heads == key.shape[1]

        if self.num_kv_heads != self.num_heads:
            # Expand KV heads so every query head has a matching KV head.
            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)

        # Build the boolean mask (True = attend) in one vectorized comparison
        # instead of a Python loop over sequences; the loop forced one
        # host/device round-trip per sequence. Result is identical.
        idx = torch.arange(kvlen, device=key.device)
        masks = idx.unsqueeze(0) >= attn_metadata.prompt_padding.unsqueeze(1)
        # (num_seqs, kvlen) -> (num_seqs, nheads, qlen, kvlen), no copy.
        masks = masks[:, None, None, :].expand(-1, nheads, qlen, -1)

        out = scaled_dot_product_attention(
            query,
            key,
            value,
            is_causal=False,
            scale=self.scale,
            dropout_p=0.0,
            attn_mask=masks,
        )

        # Back to (batch, seq, heads, head_size); contiguous so the view in
        # forward() is valid.
        out = out.transpose(2, 1).contiguous()
        return out
Loading