diff --git a/vllm/attention/backends/rwkv5linear_attn.py b/vllm/attention/backends/rwkv5linear_attn.py new file mode 100644 index 000000000000..446c07c329f2 --- /dev/null +++ b/vllm/attention/backends/rwkv5linear_attn.py @@ -0,0 +1,500 @@ +"""Attention layer with FlashAttention.""" +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +# from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionMetadataBuilder, AttentionState) +from vllm.attention.backends.utils import PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.worker.model_runner import ModelInputForGPUBuilder +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase + + +class RWKVAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "RWKVAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. 
+ assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def begin_forward(self, model_input: ModelRunnerInputBase) -> None: + return super().begin_forward(model_input) + + def get_graph_input_buffers(self, attn_metadata: Any, is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + return super().get_graph_input_buffers(attn_metadata, is_encoder_decoder_model) + + def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], attn_metadata: Any, is_encoder_decoder_model: bool = False) -> None: + return super().prepare_graph_input_buffers(input_buffers, attn_metadata, is_encoder_decoder_model) + +class RWKVFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32,40, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "linear-flash-attn" + + @staticmethod + def get_state_cls() -> type[AttentionState]: + return RWKVAttentionState + + @staticmethod + def get_impl_cls() -> Type["LinearFlashAttentionImpl"]: + return LinearFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return LinearFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashLinearAttentionMetadataBuilder"]: + return FlashLinearAttentionMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (min(num_blocks,16),block_size, head_size+2, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + # src_value_cache = src_kv_cache[1] + # dst_value_cache = dst_kv_cache[1] + # ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class LinearFlashAttentionMetadata(AttentionMetadata): + """Metadata for RWKVFlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum query length in the batch. 
None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["LinearFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["LinearFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["LinearFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = LinearFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefills, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefills], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["LinearFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = LinearFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + 
block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class LinearFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + ) -> None: + + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + + support_head_sizes = RWKVFlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: LinearFlashAttentionMetadata, + kv_scale: float = 1.0, + ) -> torch.Tensor: + + num_tokens, hidden_size = query.shape + output = query + + return output.view(num_tokens, hidden_size) + + +class FlashLinearAttentionMetadataBuilder( + AttentionMetadataBuilder[LinearFlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. 
+ """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id] + + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + False, 1, 1, False, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + 1, 0, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. 
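As a side note on the slot_mapping entries that _add_seq_group above accumulates (build() continues below): each entry follows vLLM's paged-KV convention, routing a token position through the sequence's block table to a flat slot index, with PAD_SLOT_ID used for profile runs and padding. A minimal sketch with illustrative numbers, not the real compute_slot_mapping helper:

```python
from typing import List

PAD_SLOT_ID = -1  # sentinel for padded / profile-run tokens

def slot_for_position(pos: int, block_table: List[int], block_size: int,
                      is_profile_run: bool = False) -> int:
    """Illustrative only: map a token position to a flat KV-cache slot."""
    if is_profile_run or not block_table:
        return PAD_SLOT_ID
    physical_block = block_table[pos // block_size]
    return physical_block * block_size + pos % block_size

# With block_size=16 and block table [7, 9], position 20 lives in physical
# block 9 at offset 4, i.e. slot 9 * 16 + 4 == 148.
assert slot_for_position(20, [7, 9], block_size=16) == 148
```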
+ input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.from_numpy(input_block_tables).to( + device=device, non_blocking=True) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + return LinearFlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 30aa7cb311af..78504dc25832 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -9,6 +9,7 @@ import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger +from vllm.transformers_utils.configs.RWKV5 import useLinear from vllm.platforms import current_platform from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu @@ -24,6 +25,7 @@ class _Backend(enum.Enum): FLASHINFER = enum.auto() PALLAS = enum.auto() IPEX = enum.auto() + LINEAR = enum.auto() def backend_name_to_enum(backend_name: str) -> _Backend: @@ -146,8 +148,13 @@ def get_attn_backend( logger.info("Using Pallas backend.") from vllm.attention.backends.pallas import PallasAttentionBackend return PallasAttentionBackend + elif backend == _Backend.LINEAR: + logger.info("Using Pallas backend.") + from vllm.attention.backends.rwkv5linear_attn import RWKVFlashAttentionBackend + return RWKVFlashAttentionBackend else: raise ValueError("Invalid attention backend.") + def which_attn_to_use( @@ -163,6 +170,10 @@ def which_attn_to_use( # Default case. selected_backend = _Backend.FLASH_ATTN + if useLinear: + print("Using Linear Attention") + return _Backend.LINEAR + # Check whether a particular choice of backend was # previously forced. 
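For reference, the query_start_loc and seq_start_loc tensors produced by build() above are exclusive prefix sums over the per-sequence lengths, matching the [4, 6] -> [0, 4, 10] example in the metadata comments. A standalone sketch with illustrative values:

```python
import torch

query_lens = [4, 6]  # two packed sequences
query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)

query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

print(query_start_loc.tolist())  # [0, 4, 10]
# Tokens of sequence i occupy the half-open range
# [query_start_loc[i], query_start_loc[i + 1]) in the flattened 1D batch.
```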
# 
diff --git a/vllm/config.py b/vllm/config.py index 3139c5a08bfb..3520ecc14e08 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -618,6 +618,7 @@ def __init__( self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * GiB_bytes self.num_gpu_blocks_override = num_gpu_blocks_override + self.num_cpu_blocks_override = None self.cache_dtype = cache_dtype self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5b7587d15084..8e54d61019d5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -891,7 +891,7 @@ def _schedule_prefills( assert num_new_tokens == num_prompt_tokens prompt_limit = self._get_prompt_limit(seq_group) - if num_new_tokens > prompt_limit: + if num_new_tokens > prompt_limit and prompt_limit > 0: logger.warning( "Input prompt (%d tokens) is too long" " and exceeds limit of %d", num_new_tokens, prompt_limit)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3550759f85dd..eaf342b2c3e8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -494,6 +494,14 @@ def _initialize_kv_caches(self) -> None: num_gpu_blocks_override) num_gpu_blocks = num_gpu_blocks_override + if self.cache_config.num_cpu_blocks_override is not None: + num_cpu_blocks_override = self.cache_config.num_cpu_blocks_override + logger.info( + "Overriding num_cpu_blocks=%d with " + "num_cpu_blocks_override=%d", num_cpu_blocks, + num_cpu_blocks_override) + num_cpu_blocks = num_cpu_blocks_override + self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 74277cae7c8e..a44e8b37a6a8 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -297,6 +297,8 @@ def build_1_2_5_buckets(max_value: int) -> List[int]: [1, 2, 5, 10, 20, 50, 100] """ mantissa_lst = [1, 2, 5] + if max_value <= 0: + max_value = 100 # for infinite context models exponent = 0 buckets: List[int] = [] while True:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1a0669d8d12c..011036dd3cc6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -247,7 +247,7 @@ def _validate_input( # Note: EmbeddingRequest doesn't have max_tokens if isinstance(request, EmbeddingRequest): - if token_num > self.max_model_len: + if token_num > self.max_model_len and self.max_model_len > 0: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " @@ -264,13 +264,13 @@ prompt_token_ids=input_ids) if request.max_tokens is None: - if token_num >= self.max_model_len: + if token_num >= self.max_model_len and self.max_model_len > 0: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the messages, " f"Please reduce the length of the messages.") - elif token_num + request.max_tokens > self.max_model_len: + elif token_num + request.max_tokens > self.max_model_len and self.max_model_len > 0: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. 
However, you requested " diff --git a/vllm/model_executor/models/Rwkv5ForCausalLM.py b/vllm/model_executor/models/Rwkv5ForCausalLM.py new file mode 100644 index 000000000000..e831d560488c --- /dev/null +++ b/vllm/model_executor/models/Rwkv5ForCausalLM.py @@ -0,0 +1,402 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/RWKV/modeling_RWKV.py +# Copyright 2023 The vLLM team. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import RwkvConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.backends.rwkv5linear_attn import LinearFlashAttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SampleLogprobs + + +class RWKVAttention(nn.Module): + + def __init__( + self, + config: RwkvConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.receptance = ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=False, quant_config=quant_config) + self.key = ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=False, quant_config=quant_config) + + self.value = ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=False, quant_config=quant_config) + self.gate = ColumnParallelLinear(self.hidden_size, self.hidden_size, bias=False, quant_config=quant_config) + self.output = RowParallelLinear(self.hidden_size, self.hidden_size, bias=False, quant_config=quant_config) + self.time_mix_key = 
nn.Parameter(torch.zeros(1,1,self.hidden_size)) + self.time_mix_value = nn.Parameter(torch.zeros(1,1,self.hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.zeros(1,1,self.hidden_size)) + self.time_mix_gate = nn.Parameter(torch.zeros(1,1,self.hidden_size)) + self.time_decay = nn.Parameter(torch.zeros(self.num_heads, self.head_dim)) + self.time_faaaa = nn.Parameter(torch.zeros(self.num_heads, self.head_dim)) + self.ln_x = nn.GroupNorm(self.num_heads,self.hidden_size//tensor_model_parallel_world_size) + self.head_size_divisor = 8 + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + position_ids:torch.Tensor + ) -> torch.Tensor: + x = hidden_states + + + blocknum = attn_metadata.slot_mapping // 16 + blockidx = attn_metadata.slot_mapping % 16 + blocknum = blocknum.to(x.device) + blockidx = blockidx.to(x.device) + + if(attn_metadata.num_decode_tokens > 0): + if kv_cache != None: + state = kv_cache[blocknum,blockidx,0,:] + kv_cache[blocknum,blockidx,0,:] = x.chunk(get_tensor_model_parallel_world_size(),-1)[get_tensor_model_parallel_rank()].reshape_as(state) + state = tensor_model_parallel_all_gather(state.reshape(-1,x.shape[-1]//get_tensor_model_parallel_world_size())) + else: + print(x.shape) + ott = torch.arange(x.shape[0]).to(x.device) + ott = ott-1 + state = x[ott] + state[position_ids==0]*=0 # for start of sequence + if(kv_cache != None and kv_cache.shape[0] > 0): + mm = kv_cache[blocknum[ott[position_ids==0]],blockidx[ott[position_ids==0]],0,:] + mm[:] = x[ott[position_ids==0]-1].chunk(get_tensor_model_parallel_world_size(),-1)[get_tensor_model_parallel_rank()].reshape_as(mm) + state = state.reshape_as(x) + + # state[position_ids.query_start_loc] = cache + # cache = state[position_ids.query_start_loc + position_ids.seq_lens] + + xx = state - x + xk = x + xx * (1-self.time_mix_key[0]) + xv = x + xx * (1-self.time_mix_value[0]) + xr = x + xx * (1-self.time_mix_receptance[0]) + xg = x + xx * (1-self.time_mix_gate[0]) + + k,_ = self.key(xk) + v,_ = self.value(xv) + r,_ = self.receptance(xr) + g,_ = self.gate(xg) + g = torch.nn.functional.silu(g) + + k = k.view(-1,self.num_heads, self.head_dim,1) + v = v.view(-1,self.num_heads, 1, self.head_dim) + r = r.view(-1,self.num_heads, self.head_dim, 1).transpose(-1,-2) + + at = (k*v) + u = self.time_faaaa.reshape(1,self.num_heads,self.head_dim,1).transpose(-1,-2) + w = self.time_decay.reshape(self.num_heads,self.head_dim,1).exp().neg().exp() + # print(at.shape, r.shape, u.shape) + ur = (u*r) + # print(ur.shape) + out = ur@at + + # if(kv_cache != None): + # print(attn_metadata.num_prefill_tokens) + + T = attn_metadata.num_prefill_tokens + # print(attn_metadata) + + if (T == 0): T = attn_metadata.num_decode_tokens + + if(attn_metadata.num_prefill_tokens != 0): + # print(kv_cache.shape if kv_cache != None else None) + s = kv_cache[blocknum,blockidx,2:,:].transpose(-3,-2) if kv_cache != None and kv_cache.shape[0] > 0 else torch.zeros(1,self.num_heads, self.head_dim, self.head_dim, device=at.device, dtype=at.dtype) + # print(kv_cache.shape if kv_cache != None else None) + for t in range(T): + print(out[t].shape, r[t].shape, s.shape) + out[t] += r[t] @ s[0] + s[0] *= w + s[0] += at[t] + + if(kv_cache != None and kv_cache.shape[0] > 0): + kv_cache[blocknum,blockidx,2:,:,] = s.transpose(-3,-2) + + else: + # print(kv_cache.shape if kv_cache != None else None) + + + for t in range(T): + s = kv_cache[blocknum[t],blockidx[t],2:,:,].transpose(-3,-2) if kv_cache != None else 
torch.zeros(self.num_heads, self.head_dim, self.head_dim, device=at.device, dtype=at.dtype) + + # print(out[t].shape, r[t].shape, s.shape) + out[t] += r[t] @ s + s *= w + s += at[t] + + if(kv_cache != None): + kv_cache[blocknum[t],blockidx[t],2:,:,] = s.transpose(-3,-2) + out = out.view(-1, self.num_heads * self.head_dim) + out = self.ln_x(out/self.head_size_divisor) + hidden_states, _ = self.output(out*g) + + + return kv_cache, hidden_states + + +class RWKVMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: RwkvConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + + self.time_mix_key = nn.Parameter(torch.zeros(1,1,hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.zeros(1,1,hidden_size)) + + self.key = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config + ) + self.value = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config + ) + self.receptance = ColumnParallelLinear( + hidden_size, + hidden_size, + bias=True, + quant_config=quant_config + ) + + def forward(self, x, kv_cache, attn_metadata:LinearFlashAttentionMetadata, position_ids:torch.Tensor): + + blocknum = attn_metadata.slot_mapping // 16 + blockidx = attn_metadata.slot_mapping % 16 + + blocknum = blocknum.to(x.device) + blockidx = blockidx.to(x.device) + if(attn_metadata.num_decode_tokens > 0): + if kv_cache != None: + state = kv_cache[blocknum,blockidx,1,:,:] + kv_cache[blocknum,blockidx,1,:] = x.chunk(get_tensor_model_parallel_world_size(),-1)[get_tensor_model_parallel_rank()].reshape_as(state) + state = tensor_model_parallel_all_gather(state.reshape(-1,x.shape[-1]//get_tensor_model_parallel_world_size())) + else: + ott = torch.arange(x.shape[0]).to(x.device) + ott = ott-1 + state = x[ott] + state[position_ids==0]*=0 # for start of sequence + if kv_cache != None and kv_cache.shape[0] > 0: + kv_cache[blocknum[ott[position_ids==0]],blockidx[ott[position_ids==0]],1,:] = x[ott[position_ids==0]-1].chunk(get_tensor_model_parallel_world_size(),-1)[get_tensor_model_parallel_rank()].reshape_as(kv_cache[blocknum[ott[position_ids==0]],blockidx[ott[position_ids==0]],1,:]) + state = state.reshape_as(x) + + xx = state - x + xk = x + xx * (1-self.time_mix_key[0]) + xr = x + xx * (1-self.time_mix_receptance[0]) + + k,_ = self.key(xk) + k = torch.relu(k) ** 2 + kv,_ = self.value(k) + rr,_ = self.receptance(xr) + + return tensor_model_parallel_all_gather(torch.sigmoid(rr)) * kv + + +class RWKVBlock(nn.Module): + + def __init__( + self, + config: RwkvConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + position = 0 + ): + super().__init__() + hidden_size = config.hidden_size + + if position == 0: + self.pre_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + + # self.ln0 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.ln1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attention = RWKVAttention(config, cache_config, quant_config) + self.ln2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.feed_forward = RWKVMLP(int(config.hidden_size*3.5), config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + position_ids: torch.Tensor + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln1(hidden_states) + kv_cache,attn_output = 
self.attention( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + position_ids=position_ids + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln2(hidden_states) + feed_forward_hidden_states = self.feed_forward(hidden_states, kv_cache, attn_metadata, position_ids) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class RWKV5Model(nn.Module): + + def __init__( + self, + config: RwkvConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.embeddings = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.blocks = nn.ModuleList([ + RWKVBlock(config, cache_config, quant_config, _) + for _ in range(config.num_hidden_layers) + ]) + self.ln_out = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.head = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: LinearFlashAttentionMetadata, + ) -> torch.Tensor: + # print(position_ids.size(),position_ids) + # print(attn_metadata) + inputs_embeds = self.embeddings(input_ids) + + + hidden_states = self.blocks[0].pre_ln(inputs_embeds) + + for i in range(len(self.blocks)): + layer = self.blocks[i] + hidden_states = layer(hidden_states, kv_caches[i], attn_metadata, position_ids) + + hidden_states = self.ln_out(hidden_states) + + return hidden_states + + +class Rwkv5ForCausalLM(nn.Module): + + def __init__( + self, + config: RwkvConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + print(config) + print(cache_config) + print(quant_config) + cache_config.num_gpu_blocks_override = 16 + cache_config.num_cpu_blocks_override = 16 + self.rwkv = RWKV5Model(config, cache_config, quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + *args, + **kwargs + ) -> torch.Tensor: + hidden_states = self.rwkv(input_ids, positions, kv_caches, + attn_metadata) + print(hidden_states) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.rwkv.head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SampleLogprobs]: + next_tokens = self.sampler(logits, sampling_metadata) + print(next_tokens) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + print(params_dict.keys()) + + for name, loaded_weight in weights: + if not name.startswith("rwkv."): + name = "rwkv." 
+ name + + if("time_decay" in name or "time_faaaa" in name or "ln_x" in name): + print("Splitting:" + name) + loaded_weight = loaded_weight.chunk(get_tensor_model_parallel_world_size(),0)[get_tensor_model_parallel_rank()] + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + print(name, param.size(), loaded_weight.size()) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index ad6cf659c3e6..bb4d0041e488 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -10,6 +10,8 @@ logger = init_logger(__name__) _GENERATION_MODELS = { + + "Rwkv5ForCausalLM": ("Rwkv5ForCausalLM","Rwkv5ForCausalLM"), "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 0f20e8d0c821..c7715e84ca49 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -19,6 +19,8 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + MLPSpeculatorConfig, + RWKV5Config, EAGLEConfig, ExaoneConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, @@ -50,6 +52,7 @@ "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) "jais": JAISConfig, "mlp_speculator": MLPSpeculatorConfig, + "rwkv5": RWKV5Config, "medusa": MedusaConfig, "eagle": EAGLEConfig, "exaone": ExaoneConfig, diff --git a/vllm/transformers_utils/configs/RWKV5.py b/vllm/transformers_utils/configs/RWKV5.py new file mode 100644 index 000000000000..d59b6d600613 --- /dev/null +++ b/vllm/transformers_utils/configs/RWKV5.py @@ -0,0 +1,52 @@ +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The vLLM team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig +useLinear = False +import os + +class RWKV5Config(PretrainedConfig): + model_type = "rwkv5" + + + def __init__( + self, + **kwargs, + ) -> None: + global useLinear + useLinear = True + + print("RWKV5Config", kwargs) + # exit() + if(kwargs.get("num_attention_heads", False)): + print(kwargs) + kwargs["num_attention_heads"] = kwargs["attention_hidden_size"] // kwargs["head_size"] + kwargs["num_kv_heads"] = kwargs["num_attention_heads"] + kwargs["max_seq_len"] = -1 + + super().__init__(**kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return False + + diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 462cd964325d..b4215925e94a 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -12,6 +12,7 @@ from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig +from vllm.transformers_utils.configs.RWKV5 import RWKV5Config from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, Qwen2VLVisionConfig) @@ -30,6 +31,7 @@ "ExaoneConfig", "MllamaConfig", "MLPSpeculatorConfig", + "RWKV5Config", "NemotronConfig", "SolarConfig", "UltravoxConfig", diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 252440c7b7e0..49b972db5dfe 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -81,6 +81,7 @@ def _allocate_kv_cache( # null block in CpuGpuBlockAllocator requires at least that # block to be zeroed-out. # We zero-out everything for simplicity. + print("kv_cache_shape", kv_cache_shape) kv_cache.append( torch.zeros(kv_cache_shape, dtype=self.dtype,
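For readers of RWKVAttention.forward above: the backend reuses the paged KV cache as a recurrent-state store (per slot, index 0 of the third cache dimension holds the attention token-shift vector, index 1 the FFN token-shift vector, and indices 2 onward the per-head head_dim x head_dim wkv state), and the per-token loops implement the RWKV-5 recurrence y_t = r_t (diag(u) k_t v_t^T + S_t), S_{t+1} = diag(w) S_t + k_t v_t^T with w = exp(-exp(time_decay)). A minimal per-head reference of that recurrence, with illustrative tensors rather than the cache-backed state:

```python
import torch

def rwkv5_head_recurrence(r, k, v, w, u, state):
    """Illustrative per-head RWKV-5 recurrence (mirrors the loops above).

    r, k, v: (T, head_dim); w, u: (head_dim,); state: (head_dim, head_dim).
    Returns the per-token outputs and the updated state.
    """
    T, D = r.shape
    out = torch.empty(T, D, dtype=r.dtype)
    for t in range(T):
        kv = k[t].unsqueeze(-1) @ v[t].unsqueeze(0)   # (D, D) outer product
        out[t] = r[t] @ (torch.diag(u) @ kv + state)  # current-token bonus + history
        state = w.unsqueeze(-1) * state + kv          # per-channel decay, then accumulate
    return out, state

# Illustrative shapes only (3 tokens, head_dim=4):
T, D = 3, 4
r, k, v = (torch.randn(T, D) for _ in range(3))
w = torch.exp(-torch.exp(torch.zeros(D)))  # same transform as time_decay.exp().neg().exp()
u = torch.zeros(D)                          # plays the role of time_faaaa
out, state = rwkv5_head_recurrence(r, k, v, w, u, torch.zeros(D, D))
print(out.shape, state.shape)               # torch.Size([3, 4]) torch.Size([4, 4])
```

During decode, the forward pass above reads the state for each request from kv_cache[block, slot, 2:, ...] (transposing dims on read/write), applies one step of this recurrence per token, and writes the updated state back, while the slot-0 and slot-1 vectors supply the previous token's hidden state for the time-mix interpolation.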