500 changes: 500 additions & 0 deletions vllm/attention/backends/rwkv5linear_attn.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions vllm/attention/selector.py
@@ -9,6 +9,7 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.transformers_utils.configs.RWKV5 import useLinear
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu

@@ -24,6 +25,7 @@ class _Backend(enum.Enum):
FLASHINFER = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
LINEAR = enum.auto()


def backend_name_to_enum(backend_name: str) -> _Backend:
@@ -146,8 +148,13 @@ def get_attn_backend(
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
elif backend == _Backend.LINEAR:
logger.info("Using Pallas backend.")
from vllm.attention.backends.rwkv5linear_attn import RWKVFlashAttentionBackend
return RWKVFlashAttentionBackend
else:
raise ValueError("Invalid attention backend.")



def which_attn_to_use(
@@ -163,6 +170,10 @@
# Default case.
selected_backend = _Backend.FLASH_ATTN

if useLinear:
print("Using Linear Attention")
return _Backend.LINEAR

# Check whether a particular choice of backend was
# previously forced.
#
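Taken together, the two hunks above route selection through the new enum member. Below is a minimal, self-contained sketch of that flow, not the real vLLM module: use_linear stands in for the useLinear flag imported from the RWKV5 config, and the backend classes are reduced to strings.

import enum
import logging

logger = logging.getLogger(__name__)

class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
    LINEAR = enum.auto()

use_linear = True  # stand-in for vllm.transformers_utils.configs.RWKV5.useLinear

def which_attn_to_use() -> _Backend:
    # The RWKV5 flag takes priority over every other selection heuristic.
    if use_linear:
        logger.info("Using Linear Attention")
        return _Backend.LINEAR
    return _Backend.FLASH_ATTN  # default case

def get_attn_backend() -> str:
    backend = which_attn_to_use()
    if backend == _Backend.LINEAR:
        # The real code lazily imports RWKVFlashAttentionBackend from
        # vllm.attention.backends.rwkv5linear_attn at this point.
        return "RWKVFlashAttentionBackend"
    return "FlashAttentionBackend"

print(get_attn_backend())  # -> RWKVFlashAttentionBackend while use_linear is True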
1 change: 1 addition & 0 deletions vllm/config.py
@@ -618,6 +618,7 @@ def __init__(
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override
self.num_cpu_blocks_override = None
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
2 changes: 1 addition & 1 deletion vllm/core/scheduler.py
@@ -891,7 +891,7 @@ def _schedule_prefills(
assert num_new_tokens == num_prompt_tokens

prompt_limit = self._get_prompt_limit(seq_group)
if num_new_tokens > prompt_limit:
if num_new_tokens > prompt_limit and prompt_limit > 0:
logger.warning(
"Input prompt (%d tokens) is too long"
" and exceeds limit of %d", num_new_tokens, prompt_limit)
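The extra prompt_limit > 0 clause effectively turns a non-positive prompt limit into "no limit", which is what an infinite-context (RWKV-style) model would report. A tiny sketch of the guard under that assumption:

def exceeds_prompt_limit(num_new_tokens: int, prompt_limit: int) -> bool:
    # Enforce the limit only when it is positive; 0 or a negative value disables the check.
    return prompt_limit > 0 and num_new_tokens > prompt_limit

assert exceeds_prompt_limit(4096, 2048) is True   # too long, triggers the warning
assert exceeds_prompt_limit(4096, 0) is False     # unlimited context, scheduled normally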
8 changes: 8 additions & 0 deletions vllm/engine/llm_engine.py
@@ -494,6 +494,14 @@ def _initialize_kv_caches(self) -> None:
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override

if self.cache_config.num_cpu_blocks_override is not None:
num_cpu_blocks_override = self.cache_config.num_cpu_blocks_override
logger.info(
"Overriding num_cpu_blocks=%d with "
"num_cpu_blocks_override=%d", num_cpu_blocks,
num_cpu_blocks_override)
num_cpu_blocks = num_cpu_blocks_override

self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

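With the assignment corrected to use the override value, the CPU-block path mirrors the GPU-block path directly above it. A reduced sketch of that behaviour (simplified names, not the engine code itself):

from typing import Optional

def apply_block_override(profiled_blocks: int, override: Optional[int]) -> int:
    # Keep the profiled value unless the user explicitly overrides it,
    # exactly as done for num_gpu_blocks_override.
    return override if override is not None else profiled_blocks

assert apply_block_override(512, None) == 512
assert apply_block_override(512, 2048) == 2048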
2 changes: 2 additions & 0 deletions vllm/engine/metrics.py
@@ -297,6 +297,8 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
[1, 2, 5, 10, 20, 50, 100]
"""
mantissa_lst = [1, 2, 5]
if max_value <= 0:
max_value = 100 # for infinite context models
exponent = 0
buckets: List[int] = []
while True:
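Without the clamp, a non-positive max_value would make the loop return an empty bucket list; falling back to 100 gives infinite-context models the default [1, 2, 5, ..., 100] histogram buckets. A sketch of the resulting behaviour, with the loop body reconstructed from the docstring and the lines shown above (the remainder of the function is not part of this diff, so treat it as an approximation):

from typing import List

def build_1_2_5_buckets(max_value: int) -> List[int]:
    mantissa_lst = [1, 2, 5]
    if max_value <= 0:
        max_value = 100  # fallback for infinite-context models reporting 0 (or -1)
    exponent = 0
    buckets: List[int] = []
    while True:
        for m in mantissa_lst:
            value = m * 10 ** exponent
            if value <= max_value:
                buckets.append(value)
            else:
                return buckets
        exponent += 1

print(build_1_2_5_buckets(0))    # [1, 2, 5, 10, 20, 50, 100]
print(build_1_2_5_buckets(100))  # unchanged behaviour for positive values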
6 changes: 3 additions & 3 deletions vllm/entrypoints/openai/serving_engine.py
@@ -247,7 +247,7 @@ def _validate_input(

# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if token_num > self.max_model_len:
if token_num > self.max_model_len and self.max_model_len > 0:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
@@ -264,13 +264,13 @@ def _validate_input(
prompt_token_ids=input_ids)

if request.max_tokens is None:
if token_num >= self.max_model_len:
if token_num >= self.max_model_len and self.max_model_len > 0:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.")
elif token_num + request.max_tokens > self.max_model_len:
elif token_num + request.max_tokens > self.max_model_len and self.max_model_len > 0:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
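All three guards follow the same pattern: the context-length check is skipped whenever max_model_len is not a positive number, again treating a non-positive value as "unlimited", consistent with the scheduler change above. A compact sketch of that rule (hypothetical helper, not part of the PR):

def violates_context_limit(token_num: int, max_tokens: int, max_model_len: int) -> bool:
    if max_model_len <= 0:
        return False  # unlimited context: never reject on length
    return token_num + max_tokens > max_model_len

assert violates_context_limit(4000, 200, 4096) is True
assert violates_context_limit(4000, 200, 0) is False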