Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
8de4315
Add support for openpangu_pro_moe_v2, which is characterized by its diff…
yt0428 Nov 15, 2025
b0e8806
Bugfix for param_sink_key initialization, block_table for cascade and…
yt0428 Nov 15, 2025
e38739a
Add FLASH_SINK_ATTN to AttentionBackendEnum
yt0428 Nov 15, 2025
b4dce68
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Nov 19, 2025
e03575a
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Nov 25, 2025
dff2694
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Nov 25, 2025
315e3f6
Refactor code, make attn backend focus on diffkv and move sink logic …
yt0428 Nov 26, 2025
28169a6
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Nov 26, 2025
ca7eaa0
fix pre-commit
yt0428 Nov 26, 2025
0abdaad
fix typo
yt0428 Nov 26, 2025
de538d3
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Dec 13, 2025
b565203
Refactor code. Extend FLASH_ATTN to support different KV size and cre…
yt0428 Dec 13, 2025
2470548
Fix typos
yt0428 Dec 13, 2025
a0563e7
Add assert in init of SinkFullAttentionManager
yt0428 Dec 13, 2025
a7430ab
Fix typos
yt0428 Dec 13, 2025
93a7afc
Extend SinkFullAttentionManager to handle sink block management, avo…
yt0428 Dec 17, 2025
378e208
Fix pre-commit
yt0428 Dec 17, 2025
4ad3f75
Refactor code, add FLASH_ATTN_DIFFKV backend
yt0428 Dec 22, 2025
cf58a62
Refactor code, move static sink logics to builder
yt0428 Dec 22, 2025
049d2aa
Fix minor difference
yt0428 Dec 22, 2025
94a55eb
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Dec 23, 2025
1aff6ff
Fix pre-commit
yt0428 Dec 23, 2025
6304606
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
LucasWilkinson Dec 24, 2025
e3e949a
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
DarkLight1337 Dec 26, 2025
ffa6276
Merge branch 'main' into Add_support_for_openpangu_promoe_v2
yt0428 Dec 30, 2025
5bafa03
Update init of SinkFullAttentionManager
yt0428 Dec 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ th {
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ |
| `OuroForCausalLM` | ouro | `ByteDance/Ouro-1.4B`, `ByteDance/Ouro-2.6B`, etc. | ✅︎ | |
| `PanguEmbeddedForCausalLM` |openPangu-Embedded-7B | `FreedomIntelligence/openPangu-Embedded-7B-V1.1` | ✅︎ | ✅︎ |
| `PanguProMoEV2ForCausalLM` |openpangu-pro-moe-v2 | | ✅︎ | ✅︎ |
| `PanguUltraMoEForCausalLM` |openpangu-ultra-moe-718b-model | `FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1` | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ |
Expand Down
5 changes: 5 additions & 0 deletions tests/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ def check_available_online(
"PanguEmbeddedForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Embedded-7B-V1.1", trust_remote_code=True
),
"PanguProMoEV2ForCausalLM": _HfExamplesInfo(
"",
trust_remote_code=True,
is_available_online=False,
),
"PanguUltraMoEForCausalLM": _HfExamplesInfo(
"FreedomIntelligence/openPangu-Ultra-MoE-718B-V1.1",
trust_remote_code=True,
Expand Down
3 changes: 3 additions & 0 deletions vllm/attention/backends/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"""

FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
FLASH_ATTN_DIFFKV = (
"vllm.v1.attention.backends.flash_attn_diffkv.FlashAttentionDiffKVBackend"
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
ROCM_ATTN = "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
ROCM_AITER_MLA = "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend"
Expand Down
15 changes: 11 additions & 4 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
**extra_impl_args,
) -> None:
"""
Expand Down Expand Up @@ -217,6 +218,7 @@ def __init__(

self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
Expand Down Expand Up @@ -274,8 +276,7 @@ def __init__(
kv_sharing_target_layer_name,
**extra_impl_args,
)
backend_name = self.attn_backend.get_name()
self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
Expand Down Expand Up @@ -355,18 +356,22 @@ def forward(
query, _ = self.query_quant(query, self._q_scale)

if self.use_output:
if output_shape is None:
output_shape = torch.Size(
(*query.shape[:-1], self.num_heads * self.head_size_v)
)
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
Expand Down Expand Up @@ -452,6 +457,7 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype,
)

Expand Down Expand Up @@ -794,6 +800,7 @@ def unified_attention_with_output(
output_block_scale: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)

self.impl.forward(
self,
query,
Expand Down
254 changes: 254 additions & 0 deletions vllm/attention/layers/static_sink_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch

from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.attention.layer import Attention
from vllm.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
SinkFullAttentionSpec,
)

logger = init_logger(__name__)


@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    """Wrap an attention backend so every request sees ``sink_len`` extra
    "static sink" tokens prepended to its KV history.

    The returned backend class is the underlying backend with its metadata
    builder replaced: the new builder shifts sequence lengths by ``sink_len``
    and prepends the shared sink KV blocks to every request's block table.
    Results are memoized per ``(backend, sink_len)`` pair via ``lru_cache``,
    so all layers with the same configuration share one subclass.

    Args:
        underlying_attn_backend: The backend to wrap.
        sink_len: Number of sink tokens. Assumed to be a multiple of the
            cache block size (division below is floor) — TODO confirm this
            is asserted upstream.

    Returns:
        A new ``AttentionBackend`` subclass named ``StaticSink_<original>``.
    """
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder): # type: ignore
        # Builder that rewrites per-batch metadata to account for the sink
        # tokens before delegating to the underlying builder.
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            model_config = vllm_config.model_config
            scheduler_config = vllm_config.scheduler_config
            self.sink_len = sink_len
            self.block_size = vllm_config.cache_config.block_size
            # Number of KV-cache blocks reserved for the sink tokens.
            self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
            # Upper bound on per-request blocks for a full-length sequence.
            self.max_num_blocks = cdiv(
                model_config.max_model_len, vllm_config.cache_config.block_size
            )
            # Preallocated block table with room for the sink blocks in
            # front of each request's regular blocks.
            self.block_table_with_sink = torch.zeros(
                (
                    scheduler_config.max_num_seqs,
                    self.max_num_blocks + self.num_sink_blocks,
                ),
                device=device,
                dtype=torch.int32,
            )
            # Sink KV occupies block ids 1..num_sink_blocks for every
            # request (presumably block 0 is reserved — TODO confirm);
            # this matches the slot mapping used by populate_sink_kv.
            self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            # NOTE: mutates common_attn_metadata in place (seq_lens,
            # max_seq_len, block_table_tensor) before delegating.
            # Every active sequence is lengthened by sink_len ...
            common_attn_metadata.seq_lens[:] = (
                common_attn_metadata.seq_lens + self.sink_len
            )
            # ... but rows that were empty (length 0 before the shift)
            # stay empty so padding slots do not attend to the sink.
            common_attn_metadata.seq_lens[
                common_attn_metadata.seq_lens == self.sink_len
            ] = 0
            common_attn_metadata.max_seq_len = (
                common_attn_metadata.max_seq_len + self.sink_len
            )
            max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
            num_reqs = common_attn_metadata.num_reqs
            # Copy each request's regular block ids after the sink blocks
            # and hand the combined table to the underlying builder.
            self.block_table_with_sink[
                :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
            ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
            common_attn_metadata.block_table_tensor = self.block_table_with_sink[
                :num_reqs
            ]

            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )

    return attn_backend


@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """
    Attention layer with static (learned, request-independent) sink tokens.

    The sink key/value tensors are supplied once via :meth:`update_sink_kv`
    and written into the KV cache on the first forward pass; afterwards the
    wrapped backend (see ``create_static_sink_attention_backend``) makes
    every request attend to them by prepending the sink blocks to its
    block table.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        """
        Args:
            num_heads: Number of attention (query) heads.
            head_size: Per-head dimension of query/key.
            scale: Softmax scaling factor.
            sink_len: Number of static sink tokens to prepend.
            attn_backend: Optional explicit backend; when ``None`` one is
                selected via ``get_attn_backend``.
            cache_config: Supplies KV-cache dtype and block size; defaults
                to ``"auto"`` dtype and block size 16 when absent.
            **kwargs: Forwarded to ``Attention.__init__``.
        """
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        # Wrap the chosen backend so its metadata builder accounts for the
        # sink tokens (cached: one subclass per (backend, sink_len) pair).
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,
            sink_len=sink_len,
        )
        # Both bases are initialized explicitly because of the multiple
        # inheritance (Attention + CustomOp).
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        self.block_size = block_size
        # One-shot flag: flipped by populate_sink_kv after the sink KV has
        # been written into the cache.
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        # Store the sink K/V tensors; they are written to the KV cache on
        # the first forward pass (see forward_native / populate_sink_kv).
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """Run attention, lazily populating the sink KV cache first.

        ``maybe_populate_sink`` is a registered torch custom op so the
        one-time population is preserved under torch.compile.
        """
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        # CUDA path is identical to the native path.
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        # Dispatch through CustomOp's per-platform method selection rather
        # than Attention.forward directly.
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        """Write the sink K/V tensors into the KV cache (done once).

        Slots [block_size, block_size + sink_len) correspond to block ids
        1..num_sink_blocks — the same layout the wrapped builder prepends
        to every request's block table.
        """
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return a sink-aware full-attention KV cache spec for this layer."""
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )


def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Populate the static sink KV cache for *layer_name*, at most once.

    Looks the layer up in the current forward context and delegates to its
    ``populate_sink_kv``; does nothing if the sink is already populated or
    the cache tensor is empty (e.g. during profiling runs).
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    if not (layer.sink_populated or self_kv_cache.numel() == 0):
        layer.populate_sink_kv(self_kv_cache)


def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake (meta) implementation of ``maybe_populate_sink``: a no-op used
    by torch.compile tracing, since sink population is a one-time runtime
    side effect with no traceable output."""
    return None


# Register maybe_populate_sink as a torch custom op so the one-time sink
# population survives torch.compile graph capture; the fake impl is a no-op
# used during tracing. The op mutates the KV cache tensor in place.
direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    mutates_args=["self_kv_cache"],
    fake_impl=maybe_populate_sink_fake,
)
Loading