Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4ea0189
workaround
vadiklyutiy Feb 24, 2026
152ccc2
zeroing kv-cache block after allocation
vadiklyutiy Feb 25, 2026
27ab3fa
Merge branch 'main' into vadim/issue35138
vadiklyutiy Feb 25, 2026
af817b7
optimize
vadiklyutiy Feb 25, 2026
053d305
make a bit better code
vadiklyutiy Feb 26, 2026
3fc05cd
Merge branch 'main' into vadim/issue35138
vadiklyutiy Feb 27, 2026
13271d9
fixes
vadiklyutiy Mar 2, 2026
3654a39
fixes
vadiklyutiy Mar 2, 2026
a507809
Merge branch 'main' into vadim/issue35138
vadiklyutiy Mar 2, 2026
59fcf43
zeroing only full attn
vadiklyutiy Mar 3, 2026
0dd399a
Merge branch 'main' into vadim/issue35138
vadiklyutiy Mar 3, 2026
888e947
fix
vadiklyutiy Mar 3, 2026
387cf77
fix
vadiklyutiy Mar 3, 2026
72d4aaf
fix bugs
vadiklyutiy Mar 4, 2026
3115077
fix pin memory
vadiklyutiy Mar 4, 2026
61e6b3c
fix cumem bug
vadiklyutiy Mar 4, 2026
56ae39a
code style
vadiklyutiy Mar 4, 2026
6e91a13
limit to hybrid model only
vadiklyutiy Mar 4, 2026
624fdea
Merge branch 'main' into vadim/issue35138
vadiklyutiy Mar 4, 2026
e26d486
pre-commit fix
vadiklyutiy Mar 4, 2026
a000a63
Merge branch 'main' into vadim/issue35138
vadiklyutiy Mar 6, 2026
2fea792
Merge branch 'main' into vadim/issue35138
vadiklyutiy Mar 9, 2026
267646a
revert workaround for pre-commit in hermes_tool_parser.py
vadiklyutiy Mar 9, 2026
61feb0c
move finding of block dim to attn backend
vadiklyutiy Mar 9, 2026
9c92ec2
resolve PR comment
vadiklyutiy Mar 9, 2026
145c7fa
move zeroing from vllm/utils/math_utils.py to vllm/v1/worker/utils.py
vadiklyutiy Mar 10, 2026
67fcd34
call _init_kv_zero_meta for model with mamba
vadiklyutiy Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/utils/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ def round_up(x: int, y: int) -> int:
def round_down(x: int, y: int) -> int:
    """Round down x to the nearest multiple of y.

    Relies on the Python identity x == (x // y) * y + x % y, so
    subtracting the remainder lands exactly on the lower multiple
    (also correct for negative x, since % follows the divisor's sign).
    """
    return x - x % y


def largest_power_of_2_divisor(n: int) -> int:
    """Return the largest power-of-2 that divides *n* (isolate lowest set bit).

    Uses the two's-complement identity ``-n == ~(n - 1)``: masking n with
    ``~(n - 1)`` clears every bit above the lowest set bit. Returns 0 when
    n == 0 (no set bit to isolate).
    """
    return n & ~(n - 1)
20 changes: 20 additions & 0 deletions vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,26 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
raise NotImplementedError

@classmethod
def get_kv_cache_block_dim(
    cls,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "auto",
) -> int:
    """Discover which tensor dim is the block index, since different
    backends lay out dims differently.

    Probes the backend's shape function with an improbable sentinel for
    the number of blocks and locates it in the resulting tuple.
    """
    # Sentinel chosen to be distinguishable from any plausible value of
    # block_size / num_kv_heads / head_size.
    sentinel = 1234567
    dims = cls.get_kv_cache_shape(
        sentinel, block_size, num_kv_heads, head_size,
        cache_dtype_str=cache_dtype_str,
    )
    return dims.index(sentinel)

@staticmethod
def get_kv_cache_stride_order(
include_num_layers_dimension: bool = False,
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/core/kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,13 @@ def create_kv_cache_blocks(
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks

def take_new_block_ids(self) -> list[int]:
    """Drain and return new attention block IDs for zeroing.

    Flattens the per-manager drains in single_type_managers order; each
    manager resets its own pending list as a side effect of its drain.
    """
    return [
        block_id
        for manager in self.coordinator.single_type_managers
        for block_id in manager.take_new_block_ids()
    ]

def new_step_starts(self) -> None:
    """Called when a new step is started.

    Pure delegation: notifies the KV cache coordinator so it can reset
    any per-step bookkeeping before scheduling proceeds.
    """
    self.coordinator.new_step_starts()
5 changes: 5 additions & 0 deletions vllm/v1/core/sched/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ class SchedulerOutput:
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None

# Block IDs freshly allocated from the pool during this scheduling step.
# The worker zeros the corresponding GPU memory before the blocks are used,
# preventing stale NaN/data from corrupting attention or SSM computation.
new_block_ids_to_zero: list[int] | None = None

@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(
Expand Down
18 changes: 10 additions & 8 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import PrefixCacheStats, SchedulerStats
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
Expand Down Expand Up @@ -233,13 +233,8 @@ def __init__(
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any(
isinstance(group_spec.kv_cache_spec, MambaSpec)
for group_spec in kv_cache_config.kv_cache_groups
)

self.has_mamba_layers = has_mamba_layers(kv_cache_config)
self.has_mamba_layers = kv_cache_config.has_mamba_layers
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
self.need_mamba_block_aligned_split = (
self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
)
Expand Down Expand Up @@ -890,6 +885,12 @@ def schedule(self) -> SchedulerOutput:
self.prev_step_scheduled_req_ids.clear()
self.prev_step_scheduled_req_ids.update(num_scheduled_tokens.keys())

new_block_ids_to_zero = (
(self.kv_cache_manager.take_new_block_ids() or None)
if self.needs_kv_cache_zeroing
else None
)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
Expand All @@ -905,6 +906,7 @@ def schedule(self) -> SchedulerOutput:
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
new_block_ids_to_zero=new_block_ids_to_zero,
)

# NOTE(Kuntai): this function is designed for multiple purposes:
Expand Down
11 changes: 11 additions & 0 deletions vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
self.enable_caching = enable_caching
self.new_block_ids: list[int] = []

# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
Expand Down Expand Up @@ -208,6 +209,8 @@ def allocate_new_computed_blocks(
cdiv(num_total_computed_tokens, self.block_size) - len(req_blocks)
)
req_blocks.extend(allocated_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in allocated_blocks)

def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_tokens_main_model: int
Expand All @@ -234,8 +237,16 @@ def allocate_new_blocks(
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
if type(self.kv_cache_spec) is FullAttentionSpec:
self.new_block_ids.extend(b.block_id for b in new_blocks)
return new_blocks

def take_new_block_ids(self) -> list[int]:
    """Drain and return block IDs allocated since the last call.

    Hands back the accumulated list and atomically (w.r.t. this thread)
    replaces it with a fresh empty one, so each ID is reported once.
    """
    drained, self.new_block_ids = self.new_block_ids, []
    return drained

def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/kv_cache_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,11 @@ class KVCacheConfig:
For models with multiple types of attention, there will be multiple groups,
see `_get_kv_cache_config_uniform_page_size` for more details.
"""

@property
def has_mamba_layers(self) -> bool:
    """True when at least one KV cache group's spec is a MambaSpec."""
    for group in self.kv_cache_groups:
        if isinstance(group.kv_cache_spec, MambaSpec):
            return True
    return False

@property
def needs_kv_cache_zeroing(self) -> bool:
    """Whether freshly allocated KV cache blocks must be zeroed by the worker.

    Currently tied exactly to the presence of Mamba (SSM) layers; see
    ``has_mamba_layers``.
    """
    return self.has_mamba_layers
27 changes: 27 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@

from .utils import (
AttentionGroup,
KVBlockZeroer,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
prepare_kernel_block_sizes,
Expand Down Expand Up @@ -978,6 +979,26 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
decode_threshold=self.reorder_batch_threshold,
)

def _init_kv_zero_meta(self) -> None:
    """One-time precomputation for _zero_block_ids.

    Builds a KVBlockZeroer and delegates metadata construction to its
    init_meta with the runner's state. Called from gpu_worker.py outside
    the CuMem pool context.
    """
    # Attribute is assigned before init_meta so the zeroer object itself
    # exists regardless of metadata-construction order.
    self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory)
    self._kv_block_zeroer.init_meta(
        attn_groups_iter=self._kv_cache_spec_attn_group_iterator(),
        kernel_block_sizes=self._kernel_block_sizes,
        cache_dtype=self.cache_config.cache_dtype,
        runner_only_attn_layers=self.runner_only_attn_layers,
        static_forward_context=self.compilation_config.static_forward_context,
    )

def _zero_block_ids(self, block_ids: list[int]) -> None:
"""Zero the KV cache memory for the given block IDs."""
if hasattr(self, "_kv_block_zeroer"):
self._kv_block_zeroer.zero_block_ids(block_ids)

# Note: used for model runner override.
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties"""
Expand Down Expand Up @@ -1011,6 +1032,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
for req_id in scheduler_output.finished_req_ids:
self.input_batch.remove_request(req_id)

# Zero GPU memory for freshly allocated cache blocks to prevent
# stale NaN/data from corrupting attention or SSM computation.
if scheduler_output.new_block_ids_to_zero:
self._zero_block_ids(scheduler_output.new_block_ids_to_zero)

# Free the cached encoder outputs.
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
Expand Down Expand Up @@ -6461,6 +6487,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
kernel_block_sizes = prepare_kernel_block_sizes(
kv_cache_config, self.attn_groups
)
self._kernel_block_sizes = kernel_block_sizes

# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,14 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
else:
self.model_runner.initialize_kv_cache(kv_cache_config)

# Build KV-zero metadata outside the CuMem pool so the bookkeeping
# GPU tensors (seg_addrs, block-id buffers) use the standard PyTorch
# allocator and are not discarded during sleep/wake cycles.
if kv_cache_config.needs_kv_cache_zeroing and hasattr(
self.model_runner, "_init_kv_zero_meta"
):
self.model_runner._init_kv_zero_meta()

@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float:
warmup_sizes: list[int] = []
Expand Down
Loading