diff --git a/tests/utils/test_cli_args.py b/tests/utils/test_cli_args.py
index 1a663690b..aa7b6027a 100644
--- a/tests/utils/test_cli_args.py
+++ b/tests/utils/test_cli_args.py
@@ -78,7 +78,7 @@ def test_prefix_caching_is_on_by_default(monkeypatch: pytest.MonkeyPatch) -> Non
             *common_args,
         ]
     )
-    assert engine_args.enable_prefix_caching
+    assert engine_args.enable_prefix_caching is None
     vllm_config = engine_args.create_engine_config()
     assert engine_args.enable_prefix_caching
     assert vllm_config.cache_config.enable_prefix_caching
diff --git a/vllm_spyre/compat_utils.py b/vllm_spyre/compat_utils.py
index edc2f5540..391a439b0 100644
--- a/vllm_spyre/compat_utils.py
+++ b/vllm_spyre/compat_utils.py
@@ -2,6 +2,12 @@
 from dataclasses import fields
 from functools import lru_cache
 from typing import Callable
+from packaging.version import Version
+import torch
+import transformers
+import functools
+from typing import Any
+from types import ModuleType
 
 
 def dataclass_fields(cls: type) -> list[str]:
@@ -22,3 +28,66 @@ def has_argument(func: Callable, param_name: str) -> bool:
     ):
         return True
     return False
+
+
+def is_pytorch_lt_2_8() -> bool:
+    return Version(torch.__version__) < Version("2.8.0")
+
+
+def maybe_patch_torch_2_7():
+    # Workaround for an issue with torch 2.7.1: https://github.com/pytorch/pytorch/issues/160886
+    # For now, we just disable the replacement of the linear layers.
+    if is_pytorch_lt_2_8():
+        import vllm.model_executor.models.transformers.base as transformer_utils
+
+        @functools.wraps(transformer_utils.replace_linear_class)
+        def replace_linear_class(
+            linear: Any,
+            style: Any = "replicate",
+            quant_config: Any = None,
+            *,
+            prefix: str = "",
+        ) -> Any:
+            return linear
+
+        transformer_utils.replace_linear_class = replace_linear_class  # ty: ignore
+
+
+def is_transformers_lt_5() -> bool:
+    return Version(transformers.__version__) < Version("5.0.0")
+
+
+def maybe_patch_transformers_4_57(patch_backend: bool = False):
+    if is_transformers_lt_5():
+        if patch_backend:
+            from vllm.model_executor.models.transformers.base import Base
+
+            @functools.wraps(Base.check_version)
+            def check_version(cls, min_version: str, feature: str):
+                pass
+
+            Base.check_version = check_version  # ty: ignore
+
+        def patch_model(model_module: ModuleType, attn_classes_attr: str, class_names: list[str]):
+            attn_classes = getattr(model_module, attn_classes_attr)
+            attn_classes["vllm"] = attn_classes["sdpa"]
+
+            for class_name in class_names:
+                model_class = getattr(model_module, class_name)
+                model_class.is_causal = False
+                model_class._supports_attention_backend = True
+
+        patch_model(
+            transformers.models.bert.modeling_bert, "BERT_SELF_ATTENTION_CLASSES", ["BertModel"]
+        )
+        patch_model(
+            transformers.models.roberta.modeling_roberta,
+            "ROBERTA_SELF_ATTENTION_CLASSES",
+            ["RobertaModel", "RobertaForMaskedLM", "RobertaForSequenceClassification"],
+        )
+
+        patch_model(
+            transformers.models.xlm_roberta.modeling_xlm_roberta,
+            "XLM_ROBERTA_SELF_ATTENTION_CLASSES",
+            ["XLMRobertaModel", "XLMRobertaForMaskedLM", "XLMRobertaForSequenceClassification"],
+        )
diff --git a/vllm_spyre/platform.py b/vllm_spyre/platform.py
index b40af9335..1cd696fb9 100644
--- a/vllm_spyre/platform.py
+++ b/vllm_spyre/platform.py
@@ -26,15 +26,21 @@
     from vllm.config import ModelConfig, VllmConfig
     from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
+    from vllm.v1.attention.backends.registry import AttentionBackendEnum
+    from vllm.v1.attention.selector import AttentionSelectorConfig
 else:
     ModelConfig = None
     VllmConfig = None
     SamplingParams = None
     PoolingParams = None
-from vllm.platforms import Platform, PlatformEnum
+    AttentionBackendEnum = None
+    AttentionSelectorConfig = None
+from vllm.v1.attention.backend import AttentionType
+from vllm.platforms import Platform, PlatformEnum
 
 import vllm_spyre.envs as envs_spyre
 from vllm_spyre.compilation_utils import handle_disable_compilation
+from vllm_spyre.compat_utils import maybe_patch_transformers_4_57
 
 logger = init_logger(__name__)
@@ -83,8 +89,19 @@ def get_device_name(cls, device_id: int = 0) -> str:
 
     @classmethod
     def import_kernels(cls) -> None:
+        maybe_patch_transformers_4_57()
         pass  # suppress warning
 
+    @classmethod
+    def get_attn_backend_cls(
+        cls,
+        selected_backend: "AttentionBackendEnum",
+        attn_selector_config: "AttentionSelectorConfig",
+    ) -> str:
+        assert attn_selector_config.attn_type == AttentionType.ENCODER_ONLY
+        logger.info("Using Torch SDPA backend.")
+        return "vllm_spyre.v1.attention.backends.spyre_sdpa.SpyreSDPABackend"
+
     @classmethod
     def is_async_output_supported(cls, enforce_eager: bool | None) -> bool:
         """
@@ -133,6 +150,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if is_pooling:
             os.environ["FLEX_OVERWRITE_NMB_FRAME"] = "false"
             os.environ["COMPILATION_MODE"] = "offline"
+            if vllm_config.model_config.model_impl == "auto":
+                vllm_config.model_config.model_impl = "transformers"
+
+            archs = vllm_config.model_config.hf_config.architectures
+            if archs is not None and archs[0] in ("XLMRobertaForMaskedLM", "RobertaForMaskedLM"):
+                archs[0] = "TransformersEmbeddingModel"
 
         if is_decoder:
             scheduler_config.scheduler_cls = (
@@ -199,6 +222,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND,
         )
 
+        # avoid circular imports
+        from vllm.config.compilation import CompilationMode
+        from vllm_spyre.model_executor.model_loader.spyre import BACKEND_LIST
+
+        # verify compilation config
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "eager":
+            vllm_config.compilation_config.mode = CompilationMode.NONE
+        else:
+            assert envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST
+            vllm_config.compilation_config.mode = CompilationMode.STOCK_TORCH_COMPILE
+            vllm_config.compilation_config.backend = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND
+
         # TODO: try to support async scheduling
         scheduler_config.async_scheduling = False
 
@@ -408,8 +443,8 @@ def _get_matching_warmup_shapes(
 
     @classmethod
     def pre_register_and_update(cls, parser: FlexibleArgumentParser | None = None) -> None:
         if parser is not None:
-            parser.set_defaults(enable_prefix_caching=True)
             parser.set_defaults(max_num_batched_tokens=cls.DEFAULT_CHUNK_SIZE)
+            parser.set_defaults(model_impl="transformers")
 
     @classmethod
     def _check_threading_config(cls, worker_count: int):
diff --git a/vllm_spyre/v1/attention/__init__.py b/vllm_spyre/v1/attention/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm_spyre/v1/attention/backends/__init__.py b/vllm_spyre/v1/attention/backends/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm_spyre/v1/attention/backends/spyre_sdpa.py b/vllm_spyre/v1/attention/backends/spyre_sdpa.py
new file mode 100644
index 000000000..63a319d77
--- /dev/null
+++ b/vllm_spyre/v1/attention/backends/spyre_sdpa.py
@@ -0,0 +1,213 @@
+"""Attention layer with simple KV caching without paging.
+Uses SDPA for attention.
+"""
+
+from dataclasses import dataclass
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from vllm.config import VllmConfig
+from vllm.v1.attention.backend import (
+    AttentionMetadata,
+    AttentionBackend,
+    AttentionImpl,
+    AttentionLayer,
+    AttentionMetadataBuilder,
+    AttentionType,
+    CommonAttentionMetadata,
+    MultipleOf,
+)
+from vllm.v1.attention.backends.registry import register_backend, AttentionBackendEnum
+from typing import ClassVar
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@register_backend(AttentionBackendEnum.CUSTOM)
+class SpyreSDPABackend(AttentionBackend):
+    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        return [64]
+
+    @staticmethod
+    def get_name() -> str:
+        return "CUSTOM"
+
+    @staticmethod
+    def get_impl_cls() -> type["SpyreSDPABackendImpl"]:
+        return SpyreSDPABackendImpl
+
+    @staticmethod
+    def get_builder_cls() -> type["SpyreSDPAMetadataBuilder"]:
+        return SpyreSDPAMetadataBuilder
+
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        # TODO: copied from flash attn, need to verify
+        return head_size % 8 == 0 and head_size <= 256
+
+    @classmethod
+    def supports_attn_type(cls, attn_type: str) -> bool:
+        return attn_type == AttentionType.ENCODER_ONLY
+
+
+@dataclass
+class SpyreSDPAMetadata(AttentionMetadata):
+    prompt_padding: torch.Tensor
+    padded_num_seqs: int
+    padded_seq_len: int
+
+
+class SpyreSDPAMetadataBuilder(AttentionMetadataBuilder[SpyreSDPAMetadata]):
+    def __init__(
+        self,
+        kv_cache_spec: "AttentionSpec",
+        layer_names: list[str],
+        vllm_config: "VllmConfig",
+        device: torch.device,
+    ):
+        self.kv_cache_spec = kv_cache_spec
+        self.layer_names = layer_names
+        self.vllm_config = vllm_config
+        self.device = device
+
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> SpyreSDPAMetadata:
+        assert not common_attn_metadata.causal, "Causal attention is not supported"
+
+        padded_num_seqs = common_attn_metadata.query_start_loc.shape[0]
+
+        ret = SpyreSDPAMetadata(
+            prompt_padding=common_attn_metadata.query_start_loc,
+            padded_num_seqs=padded_num_seqs,
+            padded_seq_len=common_attn_metadata.num_actual_tokens // padded_num_seqs,
+        )
+        return ret
+
+
+class SpyreSDPABackendImpl(AttentionImpl[SpyreSDPAMetadata]):
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int | None = None,
+        alibi_slopes: list[float] | None = None,
+        sliding_window: int | None = None,
+        kv_cache_dtype: str = "auto",
+        logits_soft_cap: float | None = None,
+        attn_type: str = AttentionType.DECODER,
+        kv_sharing_target_layer_name: str | None = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        assert num_kv_heads is not None
+        self.num_kv_heads = num_kv_heads
+
+        self.kv_cache_dtype = kv_cache_dtype
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        # Reject features this backend does not support
+        if alibi_slopes is not None:
+            raise NotImplementedError("Alibi slopes are not supported.")
+        if sliding_window is not None:
+            raise NotImplementedError("Sliding window is not supported.")
+        if logits_soft_cap is not None:
+            raise NotImplementedError("Logits soft cap is not supported.")
+        if kv_cache_dtype != "auto":
+            raise NotImplementedError("FP8 KV cache dtype is not supported.")
+        self.attn_type = attn_type
+        if attn_type != AttentionType.ENCODER_ONLY:
+            raise NotImplementedError(
+                "Only Encoder self-attention is implemented for SpyreSDPABackendImpl"
+            )
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: SpyreSDPAMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        """Forward pass with torch SDPA (no paged attention).
+
+        Args:
+            query: shape = [batch_size * num_tokens, num_heads * head_size]
+            key: shape = [batch_size * num_tokens, num_kv_heads, head_size]
+            value: shape = [batch_size * num_tokens, num_kv_heads, head_size]
+            kv_cache = [] # disabled
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [batch_size * num_tokens, num_heads * head_size]
+        """
+        assert output is None
+        assert kv_cache.numel() == 0, "Only encoder attention is supported"
+        assert key is not None and value is not None
+        bsize = attn_metadata.padded_num_seqs
+        seq_len = attn_metadata.padded_seq_len
+
+        # Reshape the query, key, and value tensors.
+        query = query.view(bsize, seq_len, self.num_heads, self.head_size)
+        key = key.view(bsize, seq_len, self.num_kv_heads, self.head_size)
+        value = value.view(bsize, seq_len, self.num_kv_heads, self.head_size)
+
+        query = query.transpose(2, 1)
+        key = key.transpose(2, 1)
+        value = value.transpose(2, 1)
+
+        attn_output = self._sdpa_forward(query, key, value, attn_metadata)
+
+        return attn_output.view(bsize * seq_len, self.num_heads * self.head_size)
+
+    def _sdpa_forward(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata
+    ) -> torch.Tensor:
+        _, nheads, qlen, _ = query.shape
+        kvlen = key.shape[2]
+        assert self.num_kv_heads == key.shape[1]
+
+        if self.num_kv_heads != self.num_heads:
+            key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
+            value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
+
+        mask_list = []
+
+        idx = torch.arange(kvlen, device=key.device)
+        for prompt_padding in attn_metadata.prompt_padding:
+            mask = idx >= prompt_padding
+            mask = mask.unsqueeze(0).expand(qlen, kvlen)
+            mask_list.append(mask)
+
+        masks = torch.stack(mask_list)
+        masks = masks.unsqueeze(1)
+        masks = masks.expand(-1, nheads, -1, -1)
+
+        out = scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            is_causal=False,
+            scale=self.scale,
+            dropout_p=0.0,
+            attn_mask=masks,
+        )
+
+        out = out.transpose(2, 1).contiguous()
+        return out
diff --git a/vllm_spyre/v1/worker/spyre_model_runner.py b/vllm_spyre/v1/worker/spyre_model_runner.py
index 606311b68..2017a9b66 100644
--- a/vllm_spyre/v1/worker/spyre_model_runner.py
+++ b/vllm_spyre/v1/worker/spyre_model_runner.py
@@ -3,37 +3,53 @@
 import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
+from collections import defaultdict
 from logging import DEBUG
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union
+from typing import TYPE_CHECKING, cast, Any, Generic, TypeVar, NamedTuple, TypeAlias, Protocol
+from copy import deepcopy, copy
+import numpy
 import torch
-from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
-from vllm.config import DeviceConfig, VllmConfig, set_current_vllm_config
+from transformers import AutoTokenizer
+from vllm.config import DeviceConfig, VllmConfig, get_layers_from_vllm_config
+from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler.activations import (
-    get_cross_encoder_act_fn,
-)
-from vllm.model_executor.layers.pooler.seqwise.poolers import (
-    pooler_for_classify,
-    pooler_for_embed,
-)
 from vllm.sampling_params import SamplingType
 from vllm.tasks import SupportedTask
 from vllm.utils.hashing import get_hash_fn_by_name
 from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.attention.layer import Attention
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import KVCacheBlock, get_request_block_hasher, init_none_hash
 from vllm.v1.core.sched.output import CachedRequestData
 from vllm.v1.core.single_type_kv_cache_manager import FullAttentionManager
-from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
+from vllm.v1.kv_cache_interface import (
+    AttentionSpec,
+    EncoderOnlyAttentionSpec,
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+    KVCacheSpec,
+)
+from vllm.v1.attention.backend import (
+    AttentionBackend,
+    AttentionMetadata,
+    AttentionType,
+    CommonAttentionMetadata,
+)
 from vllm.v1.metrics.stats import PrefixCacheStats
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput, SamplerOutput
-from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.request import Request
+from vllm.v1.worker.utils import AttentionGroup
+from vllm.model_executor.models.interfaces_base import VllmModelForPooling
+from vllm.model_executor.model_loader import get_model_loader
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.models.transformers.legacy import LegacyMixin
+from vllm.model_executor.models.transformers.base import Base as TransformersBase
 
-import vllm_spyre.envs as envs_spyre
 import vllm_spyre.utils as utils_spyre
+import vllm_spyre.envs as envs_spyre
 from vllm_spyre.model_executor.model_loader.spyre import (
     BACKEND_LIST,
     SpyreAttentionMetadata,
@@ -42,6 +58,11 @@
 from vllm_spyre.platform import SpyrePlatform
 from vllm_spyre.utils import exact_div
 from vllm_spyre.v1.sample.spyre_logits_processor import build_logitsprocs_for_cb
+from vllm_spyre.compat_utils import (
+    maybe_patch_transformers_4_57,
+    is_transformers_lt_5,
+    maybe_patch_torch_2_7,
+)
 
 # yapf conflicts with ruff for this block
 # yapf: disable
@@ -63,6 +84,10 @@
     NewRequestData = None
     SamplingMetadata = None
 
+
+PerLayerAttnMetadata: TypeAlias = dict[str, AttentionMetadata]
+# list when ubatching is enabled
+
 logger = init_logger(__name__)
@@ -155,9 +180,6 @@ def __init__(
         self.device = torch.device(self.device_config.device)
         self.pin_memory = is_pin_memory_available()
 
-        # Lazy initialization: after load_model.
-        self._model: SpyreCausalLM | None = None
-
         # Flag to be turned off after warmup is complete
         self.warmup_mode = True
@@ -171,10 +193,12 @@ def build_input_batch(self) -> InputBatchT:
         raise NotImplementedError
 
+    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+        pass
+
     @property
-    def model(self) -> SpyreCausalLM:
-        assert self._model is not None, "model accessed before loading"
-        return self._model
+    def model(self) -> torch.nn.Module:
+        raise NotImplementedError
 
     @property
     def is_multimodal(self) -> bool:
@@ -229,12 +253,6 @@ def complete_warmup(self):
         """Turn off warmup mode once the warmup is complete"""
         self.warmup_mode = False
 
-    def build_attn_metadata(self, model_input: ModelInputsT) -> SpyreAttentionMetadata:
-        # TODO: probably sooner we will need a more sophisticated way to switch
-        # build attention metadata based on model/attention. But for now, a
-        # simple method override is good enough.
-        return None  # ty: ignore
-
     @abstractmethod
     def update_states(self, scheduler_output: SchedulerOutput):
         raise NotImplementedError
@@ -248,28 +266,33 @@ def execute_model(
     ) -> ModelRunnerOutput:
         raise NotImplementedError
 
+    def _make_compatible_sampled_token_ids(
+        self, sampled_token_ids: torch.Tensor
+    ) -> list[list[int]] | list[numpy.ndarray]:
+        """Some versions of vllm required a list of numpy arrays as output.
+        This was ultimately rejected, see:
+        https://github.com/vllm-project/vllm/pull/29121
 
-class PoolerAdapter(torch.nn.Module):
-    def __init__(self, pooler: torch.nn.Module):
-        super().__init__()
-        self.pooler = pooler
+        This can be removed once the *lower bound* of the vllm dependency is
+        >= 0.12.0
+        """
+        # ty ignore comments here because the typing is all dependent on the specific version of
+        # vllm installed
+        if ModelRunnerOutput.__dataclass_fields__["sampled_token_ids"].type == list[numpy.ndarray]:
+            sampled_token_ids = [x for x in sampled_token_ids.numpy()]  # ty: ignore
+        else:
+            sampled_token_ids = sampled_token_ids.tolist()  # ty: ignore
+        return sampled_token_ids  # ty: ignore
 
-    def forward(
-        self,
-        hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]],
-        pooling_metadata: PoolingMetadata,
-    ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        # Because we're using transformers to load the pooler
-        # and classifier layers and the assumption there is that
-        # we have a right padded batch, we need to split
-        # and at the batch dimension.
-        if isinstance(hidden_states, torch.Tensor):
-            hidden_states = torch.split(hidden_states, pooling_metadata.prompt_lens.tolist())
-        return [self.pooler(h.unsqueeze(dim=0)) for h in hidden_states]
 
+class PoolingModel(VllmModelForPooling, Protocol):
+    def __call__(self, *args, **kwargs) -> torch.Tensor:
+        pass
 
-def _cls(input: torch.Tensor) -> torch.Tensor:
-    return input[:, 0]
+    def eval(
+        self,
+    ) -> None:
+        pass
 
 
 class SpyrePoolingModelRunner(
@@ -289,6 +312,12 @@
         self._position_ids: torch.Tensor = None
         self.use_token_type_ids = False
 
+        # Attention layers that are only in the KVCacheConfig of the runner
+        # (e.g., KV sharing, encoder-only attention), but not in the
+        # KVCacheConfig of the scheduler.
+        self.runner_only_attn_layers: set[str] = set()
+        self.attn_groups: list[list[AttentionGroup]] = []
+
     @property
     def model(self) -> torch.nn.Module:
         return self._model  # ty: ignore[invalid-return-type]
@@ -302,43 +331,135 @@ def build_input_batch(self) -> PoolingInputBatch:
             vocab_size=self.model_config.get_vocab_size(),
         )
 
-    def load_model(self) -> None:
-        assert len(self.model_config.architectures) == 1
-        task = (
-            "classify"
-            if self.model_config.architectures[0].endswith("ForSequenceClassification")
-            else "embed"
-        )
+    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+        # No KV cache for encoder-only attention
+        return {}
+
+    # I've kept the organization of methods the same as in the
+    # GPU model runner, so that it's easier to recognize the
+    # pattern
+    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
+        kv_cache_config = deepcopy(kv_cache_config)
+        self.kv_cache_config = kv_cache_config
+        self.may_add_encoder_only_layers_to_kv_cache_config()
+        self.initialize_attn_backend(kv_cache_config)
+        self.initialize_metadata_builders(kv_cache_config)
+
+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        encoder_only_attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            assert attn_module.attn_type == AttentionType.ENCODER_ONLY
+            attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=attn_module.num_kv_heads,
+                head_size=attn_module.head_size,
+                dtype=torch.float16,
+            )
+            encoder_only_attn_specs[attn_spec].append(layer_name)
+            self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(encoder_only_attn_specs) == 1, (
+                "Only support one encoder-only attention spec now"
+            )
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)
+            )
 
-        if task == "embed":
-            self._model = AutoModel.from_pretrained(self.model_config.model)
-        elif task == "classify":
-            class_model = AutoModelForSequenceClassification.from_pretrained(
-                self.model_config.model
+    def initialize_metadata_builders(self, kv_cache_config: KVCacheConfig) -> None:
+        for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
+            for attn_group in self.attn_groups[kv_cache_group_id]:
+                attn_group.create_metadata_builders(
+                    self.vllm_config,
+                    self.device,
+                    None,
+                )
+
+    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
+        """
+        Initialize the attention backends and attention metadata builders.
+        """
+        assert len(self.attn_groups) == 0, "Attention backends are already initialized"
+
+        class AttentionGroupKey(NamedTuple):
+            attn_backend: type[AttentionBackend]
+            kv_cache_spec: KVCacheSpec
+
+        def get_attn_backends_for_group(
+            kv_cache_group_spec: KVCacheGroupSpec,
+        ) -> tuple[dict[AttentionGroupKey, list[str]], set[type[AttentionBackend]]]:
+            layer_type = cast(type[Any], AttentionLayerBase)
+            layers = get_layers_from_vllm_config(
+                self.vllm_config, layer_type, kv_cache_group_spec.layer_names
             )
-            if hasattr(class_model, "bert"):
-                self._model = class_model.bert
-                self._pooler = PoolerAdapter(self.model.pooler)  # ty:ignore[invalid-argument-type]
-            elif hasattr(class_model, "roberta"):
-                self._model = class_model.roberta
-                self._pooler = PoolerAdapter(_cls)  # ty:ignore[invalid-argument-type]
-            else:
-                raise ValueError(
-                    f"Unsupported model {self.model_config.model}: Expected "
-                    "Bert or Roberta for sequence classification"
+            attn_backends = {}
+            attn_backend_layers = defaultdict(list)
+
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_backend = layers[layer_name].get_attn_backend()
+
+                full_cls_name = attn_backend.full_cls_name()
+                layer_kv_cache_spec = kv_cache_group_spec.kv_cache_spec
+                key = (full_cls_name, layer_kv_cache_spec)
+                attn_backends[key] = AttentionGroupKey(attn_backend, layer_kv_cache_spec)
+                attn_backend_layers[key].append(layer_name)
+            return (
+                {attn_backends[k]: v for k, v in attn_backend_layers.items()},
+                set(group_key.attn_backend for group_key in attn_backends.values()),
+            )
+
+        def create_attn_groups(
+            attn_backends_map: dict[AttentionGroupKey, list[str]],
+            kv_cache_group_id: int,
+        ) -> list[AttentionGroup]:
+            attn_groups: list[AttentionGroup] = []
+            for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
+                attn_group = AttentionGroup(
+                    attn_backend,
+                    layer_names,
+                    kv_cache_spec,
+                    kv_cache_group_id,
                 )
-            self.classifier = class_model.classifier
-            # Disable pooler because in transformers it's
-            # always run even tough we don't use the outputs
-            # directly.
-            self._model.pooler = None
+                attn_groups.append(attn_group)
+            return attn_groups
 
-        model_class_name = type(self.model).__name__
-        self.is_roberta = "roberta" in model_class_name.lower()
+        for i, kv_cache_group_spec in enumerate(kv_cache_config.kv_cache_groups):
+            attn_backends = get_attn_backends_for_group(kv_cache_group_spec)
+            self.attn_groups.append(create_attn_groups(attn_backends[0], i))
+
+    def load_model(self) -> None:
+        maybe_patch_transformers_4_57(patch_backend=True)
+        maybe_patch_torch_2_7()
 
-        self.model.eval()
+        model_loader = get_model_loader(self.load_config)
+        self.vllm_model: PoolingModel = model_loader.load_model(
+            vllm_config=self.vllm_config, model_config=self.model_config
+        )
+        self.vllm_model.eval()
         torch.set_grad_enabled(False)
+
+        def _find_compilable(module: torch.nn.Module) -> torch.nn.Module | None:
+            if isinstance(module, TorchCompileWithNoGuardsWrapper):
+                return module
+            for child_module in module.children():
+                if (mod := _find_compilable(child_module)) is not None:
+                    return mod
+            return None
+
+        if is_transformers_lt_5():
+            assert isinstance(self.vllm_model, TransformersBase)
+            self._model = self.vllm_model.model
+            self._compilable = self._model
+        else:
+            self._model = self.vllm_model
+            self._compilable = _find_compilable(self.model)
+
         if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
             # Lazy import to avoid load torch_sendnn runtime before it is really
             # necessary. This solve issues of running forked tests that share
@@ -351,40 +472,26 @@ def load_model(self) -> None:
             with utils_spyre.stagger_region(
                 envs_spyre.VLLM_SPYRE_MAX_LOAD_PROCESSES, self.parallel_config.world_size, self.rank
             ):
-                # Not clear how to make the type checking happy with the torch.compile return
-                self._model = torch.compile(  # ty: ignore[invalid-assignment]
-                    self.model,
+                assert self._compilable is not None
+                self._compilable.compile(
                     mode="default",
                     dynamic=False,
                     backend=envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND,
+                    fullgraph=True,
                 )
 
-        if task == "classify":
-            tokenizer = AutoTokenizer.from_pretrained(self.model_config.model)
+        self.use_token_type_ids = False
+        if "score" in self.vllm_model.pooler.get_supported_tasks() and (
+            tokenizer := AutoTokenizer.from_pretrained(self.model_config.model)
+        ):
             output = tokenizer(text="foo", text_pair="bar")
             self.use_token_type_ids = "token_type_ids" in output
             if self.use_token_type_ids:
                 self.sep_token_id = tokenizer.sep_token_id
 
-        pooler_config = self.model_config.pooler_config
-        assert pooler_config is not None, "Pooler config is require for pooling models"
-
-        if task == "embed":
-            with set_current_vllm_config(self.vllm_config):
-                self.pooler = pooler_for_embed(pooler_config=pooler_config)
-        elif task == "classify":
-            with set_current_vllm_config(self.vllm_config):
-                self.pooler = pooler_for_classify(
-                    pooler_config=pooler_config,
-                    pooling=self._pooler,
-                    classifier=self.classifier,
-                    act_fn=get_cross_encoder_act_fn(self.model_config.hf_config),
-                )
-
     @property
     def vocab_size(self) -> int:
-        # self.model here is probably a transformers model class
-        return self.model.config.vocab_size  # ty: ignore[invalid-return-type]
+        return self.model_config.get_vocab_size()
 
     def _prepare_pad_input_ids(
         self,
@@ -559,7 +666,11 @@ def _prepare_prompt(
         if self.use_token_type_ids:
             token_type_ids = self._token_types(input_tokens)
 
-        if self.is_roberta:
+        if (
+            is_transformers_lt_5()
+            and isinstance(self.vllm_model, LegacyMixin)
+            and self.vllm_model.is_roberta
+        ):
             position_ids += self.pad_token_id + 1
             position_ids *= mask
 
@@ -608,7 +719,57 @@ def _mark_input_tensors(self, model_input: PoolingForwardInputs) -> None:
         torch._dynamo.mark_static(model_input.token_type_ids, 1)
 
     def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
-        return tuple(self.pooler.get_supported_tasks())
+        return tuple(self.vllm_model.pooler.get_supported_tasks())
+
+    def build_attn_metadata(
+        self,
+        input: PoolingForwardInputs,
+    ) -> PerLayerAttnMetadata:
+        num_tokens_padded = input.input_tokens.numel()
+        num_reqs_padded = self.input_batch.padded_batch_size
+
+        prompt_tokens = torch.from_numpy(self.input_batch._get_num_prompt_tokens())
+        num_reqs = len(prompt_tokens)
+        max_seq_len = max(prompt_tokens)
+
+        padded_prompt_len = input.input_tokens.shape[1]
+        prompt_paddings = padded_prompt_len - prompt_tokens
+
+        query_start_locs = torch.zeros(num_reqs_padded, dtype=torch.int32)
+        query_start_locs[:num_reqs] = prompt_paddings
+
+        cm_base = CommonAttentionMetadata(
+            query_start_loc=query_start_locs,
+            query_start_loc_cpu=query_start_locs,
+            seq_lens=prompt_tokens,
+            num_reqs=self.input_batch.num_reqs,
+            num_actual_tokens=num_tokens_padded,
+            max_query_len=max_seq_len,
+            max_seq_len=max_seq_len,
+            block_table_tensor=torch.zeros(0),
+            slot_mapping=torch.zeros(0),
+            causal=False,
+        )
+
+        # Prepare the attention metadata for each KV cache group and make layers
+        # in the same group share the same metadata.
+        attn_metadata: PerLayerAttnMetadata = {}
+        for kv_cache_gid, _ in enumerate(self.kv_cache_config.kv_cache_groups):
+            cm = copy(cm_base)  # shallow copy
+
+            for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
+                attn_group = self.attn_groups[kv_cache_gid][attn_gid]
+                builder = attn_group.get_metadata_builder()
+
+                attn_metadata_i = builder.build(
+                    common_prefix_len=0,
+                    common_attn_metadata=cm,
+                )
+
+                for layer_name in attn_group.layer_names:
+                    attn_metadata[layer_name] = attn_metadata_i
+
+        return attn_metadata
 
     @SpyrePlatform.inference_mode()
     def execute_model(
@@ -632,16 +793,37 @@
         if self.use_token_type_ids:
            model_kwargs["token_type_ids"] = model_input.token_type_ids
 
-        # Execute the model
-        with set_forward_context(attn_metadata, self.vllm_config):
+        def call_model_transformers_4_57() -> torch.Tensor:
             outputs = self.model(
                 input_ids=model_input.input_tokens,
                 position_ids=model_input.input_positions,
                 attention_mask=model_input.input_masks,
                 **model_kwargs,
             )
+            return outputs["last_hidden_state"]
+
+        def call_model_transformers_5() -> torch.Tensor:
+            assert model_input.input_tokens is not None
+            assert model_input.input_positions is not None
+            batch, seqlen = model_input.input_tokens.shape
+            # We need to flatten the batch dimension here to be
+            # compatible with the transformers backend. The vllm
+            # code doesn't care.
+            if (token_type_ids := model_kwargs.get("token_type_ids")) is not None:
+                model_kwargs["token_type_ids"] = token_type_ids.view(batch * seqlen)
+
+            hidden_states = self.model(
+                input_ids=model_input.input_tokens.view(batch * seqlen),
+                positions=model_input.input_positions.view(batch * seqlen),
+                **model_kwargs,
+            )
+            return hidden_states.view(batch, seqlen, -1)
 
-        hidden_states = outputs["last_hidden_state"]
+        with set_forward_context(attn_metadata, self.vllm_config):
+            if is_transformers_lt_5():
+                hidden_states = call_model_transformers_4_57()
+            else:
+                hidden_states = call_model_transformers_5()
 
         # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
@@ -665,7 +847,7 @@
             # we're left padding
             hidden_state_list.append(hidden_state[-prompt_len:])
 
-        raw_pooler_output = self.pooler(
+        raw_pooler_output = self.vllm_model.pooler(
             hidden_states=torch.cat(hidden_state_list), pooling_metadata=pooling_metadata
         )
 
@@ -740,6 +922,11 @@
 
         self.prefix_cache_stats = None
 
+    @property
+    def model(self) -> SpyreCausalLM:
+        assert self._model is not None, "model accessed before loading"
+        return self._model
+
     def load_model(self) -> None:
         self._model = SpyreCausalLM(
             vllm_config=self.vllm_config,
@@ -751,7 +938,7 @@ def vocab_size(self) -> int:
         model_cfg = self.model.fms_model.config
         if self.model.is_multimodal:
             return self.model.mm_model_utils.resolve_multimodal_vocab_size()
-        return model_cfg.src_vocab_size  # ty: ignore[invalid-return-type]
+        return model_cfg.src_vocab_size
 
     @property
     def enable_prefix_caching(self):
@@ -784,7 +971,7 @@ def pre_warmup(self) -> None:
         self._set_blocks(num_blocks=n_blocks_warmup)
         # TODO: fixup the typing here. Things are getting tripped up by having all of our "model"
         # classes inherit from `nn.Module` when maybe they don't need to
-        self.model.set_past_key_value_states(num_blocks=n_blocks_warmup)  # ty: ignore[call-non-callable]
+        self.model.set_past_key_value_states(num_blocks=n_blocks_warmup)
 
     # Future code:
@@ -806,7 +993,7 @@ def complete_warmup(self) -> None:
         self._set_blocks(num_blocks=n_blocks_avail)
         # TODO: fixup the typing here. Things are getting tripped up by having all of our "model"
         # classes inherit from `nn.Module` when maybe they don't need to
-        self.model.set_past_key_value_states(num_blocks=n_blocks_avail)  # ty: ignore[call-non-callable]
+        self.model.set_past_key_value_states(num_blocks=n_blocks_avail)
 
     def get_total_spyre_blocks(self) -> int:
         """Returns the total number of KV cache blocks available for spyre.
diff --git a/vllm_spyre/v1/worker/spyre_worker.py b/vllm_spyre/v1/worker/spyre_worker.py
index 7effb0550..30952c3d7 100644
--- a/vllm_spyre/v1/worker/spyre_worker.py
+++ b/vllm_spyre/v1/worker/spyre_worker.py
@@ -215,10 +215,10 @@ def determine_available_memory(self) -> int:
         # This can probably be fixed in a nicer way.
         return 2 * accurate_fake_kv_cache_size
 
-    def initialize_from_config(self, kv_cache_configs: list[KVCacheConfig]) -> None:
+    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Construct the KV cache from the provided configs.
         Currently, we do not support paged attention or kv caching"""
-        pass
+        self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def __init__(
         self,
@@ -570,7 +570,9 @@ def _warmup_spyre_fixed_size(self, prompt_len, special_token_ids, batch_size):
 
         sampling_params, pooling_params = None, None
 
-        pooling_params = PoolingParams(task="embed")  # for warmup any task will do
+        supported_tasks = self.model_runner.get_supported_tasks()
+
+        pooling_params = PoolingParams(task=supported_tasks[0])  # ty: ignore
 
         # Set up dummy requests for prefill steps
         dummy_requests = [
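
For reference, the boolean mask built in `_sdpa_forward` above marks left-padded key positions as non-attendable: `idx >= prompt_padding` is True exactly at the real tokens, and `scaled_dot_product_attention` treats True entries of a boolean `attn_mask` as positions that may be attended. A minimal standalone sketch of the same construction (the sizes and padding counts are illustrative, not taken from the patch):

import torch
from torch.nn.functional import scaled_dot_product_attention

# Illustrative batch: 2 left-padded sequences, padded length 4, 3 heads.
bsize, nheads, seq_len, head_size = 2, 3, 4, 8
# prompt_padding[i] = number of left-pad tokens in sequence i (hypothetical values)
prompt_padding = torch.tensor([1, 3])

q = torch.randn(bsize, nheads, seq_len, head_size)
k = torch.randn(bsize, nheads, seq_len, head_size)
v = torch.randn(bsize, nheads, seq_len, head_size)

idx = torch.arange(seq_len)
# True where the key position is a real token, False where it is left padding.
masks = torch.stack([(idx >= p).unsqueeze(0).expand(seq_len, seq_len) for p in prompt_padding])
masks = masks.unsqueeze(1).expand(-1, nheads, -1, -1)  # [bsize, nheads, q_len, kv_len]

out = scaled_dot_product_attention(q, k, v, attn_mask=masks, is_causal=False)
assert out.shape == (bsize, nheads, seq_len, head_size)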
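
The runner and the metadata builder hand the per-sequence left padding through `CommonAttentionMetadata.query_start_loc`: `build_attn_metadata` stores `padded_prompt_len - prompt_tokens` there, and `SpyreSDPAMetadataBuilder.build` reads that same tensor back as `prompt_padding`, recovering the padded sequence length as `num_actual_tokens // padded_num_seqs`. Rows added purely for batch padding keep a padding count of zero, so their mask rows stay all-True and SDPA never sees a fully-masked row. A toy trace of that arithmetic (the values are illustrative):

import torch

# Illustrative batch: padded to 3 sequences of padded length 8 (24 tokens total).
padded_num_seqs, padded_prompt_len = 3, 8
prompt_tokens = torch.tensor([5, 8])  # real lengths of the 2 actual requests
prompt_paddings = padded_prompt_len - prompt_tokens

query_start_locs = torch.zeros(padded_num_seqs, dtype=torch.int32)
query_start_locs[: len(prompt_tokens)] = prompt_paddings  # -> [3, 0, 0]

num_actual_tokens = padded_num_seqs * padded_prompt_len
# What the builder recovers from the repurposed query_start_loc tensor:
assert num_actual_tokens // query_start_locs.shape[0] == padded_prompt_len
print(query_start_locs.tolist())  # per-sequence left padding; dummy rows stay 0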
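
The compat shims in compat_utils.py all follow the same pattern: gate on the installed version, then swap a module- or class-level callable for a replacement wrapped with `functools.wraps` so the original's metadata survives. A self-contained sketch of that pattern (the module and function here are hypothetical stand-ins, not vllm's real API):

import functools
import types

# Hypothetical stand-in for the module being patched.
mod = types.SimpleNamespace()

def replace_linear_class(linear, style="replicate", quant_config=None, *, prefix=""):
    """Pretend original: would swap the layer for a parallelized variant."""
    raise RuntimeError("original implementation")

mod.replace_linear_class = replace_linear_class

@functools.wraps(mod.replace_linear_class)
def passthrough(linear, style="replicate", quant_config=None, *, prefix=""):
    # Disable the replacement: hand the original layer back untouched.
    return linear

mod.replace_linear_class = passthrough

assert mod.replace_linear_class("layer") == "layer"
assert mod.replace_linear_class.__name__ == "replace_linear_class"  # metadata preserved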