Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/compile/passes/test_rope_kvcache_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def test_rope_kvcache_fusion(
}
q_unfused, k_unfused, v_unfused, dummy = model(qkv_unfused, pos_unfused)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_unfused = attn_layer.kv_cache[forward_context.virtual_engine]
kv_cache_unfused = attn_layer.kv_cache[0]
del dummy

torch._dynamo.mark_dynamic(qkv, 0)
Expand All @@ -309,7 +309,7 @@ def test_rope_kvcache_fusion(
}
q_fused, k_fused, v_fused, dummy = model_fused(qkv, pos)
attn_layer = forward_context.no_compile_layers[model.layer_name]
kv_cache_fused = attn_layer.kv_cache[forward_context.virtual_engine]
kv_cache_fused = attn_layer.kv_cache[0]
del dummy

assert fusion_pass.matched_count == 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(self, block_size: int, num_gpu_blocks: int):
self._block_hasher = get_request_block_hasher(block_size, sha256)

self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0, slot_mapping={}
no_compile_layers={}, attn_metadata={}, slot_mapping={}
)

def new_request(self, token_ids: list[int]) -> Request:
Expand Down
1 change: 0 additions & 1 deletion tests/v1/kv_connector/unit/test_lmcache_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def test_forward_context_interface():
from vllm.forward_context import ForwardContext

assumes(ForwardContext, "no_compile_layers", is_instance_of=dict)
assumes(ForwardContext, "virtual_engine")
assumes(ForwardContext, "attn_metadata")


Expand Down
8 changes: 0 additions & 8 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,6 @@ def test_multi_xfer_one_engine(
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
_before_load = time.perf_counter()
Expand Down Expand Up @@ -653,7 +652,6 @@ def test_async_load_kv(
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
_before_load = time.perf_counter()
Expand Down Expand Up @@ -888,7 +886,6 @@ def test_concurrent_load_kv(
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
_before_load = time.perf_counter()
Expand Down Expand Up @@ -1057,7 +1054,6 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
connector.start_load_kv(dummy_ctx)
Expand Down Expand Up @@ -1857,7 +1853,6 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
connector.start_load_kv(dummy_ctx)
Expand Down Expand Up @@ -2026,7 +2021,6 @@ def test_transfer_failure_logging(
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)

Expand Down Expand Up @@ -2129,7 +2123,6 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
connector.start_load_kv(dummy_ctx)
Expand Down Expand Up @@ -2182,7 +2175,6 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)
connector.start_load_kv(dummy_ctx)
Expand Down
1 change: 0 additions & 1 deletion tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,6 @@ def __init__(
self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def inject_kv_into_layer(
if kv_cache_attr is None:
continue

kv_cache_layer = kv_cache_attr[forward_context.virtual_engine]
kv_cache_layer = kv_cache_attr[0]

filename = self._generate_filename_debug(
layer_name, request.token_ids, request.mm_hashes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -778,9 +778,7 @@ def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"
continue

if layer_name not in self.kv_caches:
self.kv_caches[layer_name] = attn_layer.kv_cache[
forward_context.virtual_engine
]
self.kv_caches[layer_name] = attn_layer.kv_cache[0]

####################
# Worker side APIs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def inject_kv_into_layer(
if kv_cache is None:
continue

layer = kv_cache[forward_context.virtual_engine]
layer = kv_cache[0]

kv_cache = self.p2p_nccl_engine.recv_tensor(
request.request_id + "#" + layer_name, remote_address
Expand Down
7 changes: 0 additions & 7 deletions vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,6 @@ class ForwardContext:
for each microbatch.
Set dynamically for each forward pass
"""
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine: int # set dynamically for each forward pass
# set dynamically for each forward pass
dp_metadata: DPMetadata | None = None
# determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
Expand Down Expand Up @@ -265,7 +263,6 @@ def is_forward_context_available() -> bool:
def create_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
dp_metadata: DPMetadata | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
Expand All @@ -282,7 +279,6 @@ def create_forward_context(
return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
dp_metadata=dp_metadata,
Expand Down Expand Up @@ -313,7 +309,6 @@ def override_forward_context(forward_context: ForwardContext | None):
def set_forward_context(
attn_metadata: Any,
vllm_config: VllmConfig,
virtual_engine: int = 0,
num_tokens: int | None = None,
num_tokens_across_dp: torch.Tensor | None = None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
Expand Down Expand Up @@ -362,7 +357,6 @@ def set_forward_context(
additional_kwargs = current_platform.set_additional_forward_context(
attn_metadata=attn_metadata,
vllm_config=vllm_config,
virtual_engine=virtual_engine,
dp_metadata=dp_metadata,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
Expand All @@ -374,7 +368,6 @@ def set_forward_context(
forward_context = create_forward_context(
attn_metadata,
vllm_config,
virtual_engine,
dp_metadata,
cudagraph_runtime_mode,
batch_descriptor,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def get_attention_context(
- attn_metadata: Attention metadata for this specific layer, or None if
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine
- kv_cache: The KV cache tensor for current forward pass
- slot_mapping: The slot mapping for this specific layer

Note: attn_metadata may be None, but attn_layer and kv_cache are always
Expand All @@ -600,7 +600,7 @@ def get_attention_context(
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
kv_cache = attn_layer.kv_cache[0]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def forward(
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
slot_mapping = forward_context.slot_mapping

assert isinstance(slot_mapping, dict), (
Expand Down Expand Up @@ -940,7 +940,7 @@ def unified_mla_kv_cache_update(
return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)

attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
kv_cache = attn_layer.kv_cache[0]

slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ def forward_native(
"sink_key and sink_value have not been prepared"
)
if not self.sink_populated:
forward_context: ForwardContext = get_forward_context()
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

return super().forward(query, key, value, output_shape)
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/kda.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def _forward(
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
num_actual_tokens = attn_metadata.num_actual_tokens
constant_caches = self.kv_cache[forward_context.virtual_engine]
constant_caches = self.kv_cache[0]

q_proj_states = q_proj_states[:num_actual_tokens]
k_proj_states = k_proj_states[:num_actual_tokens]
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/linear_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _forward(
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
kv_cache = self.kv_cache[0][0]
state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor):
query_start_loc_p = attn_metadata.query_start_loc_p
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def conv_ssm_forward(
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/short_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward_cuda(
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
state_indices_tensor_p = attn_metadata.state_indices_tensor_p
state_indices_tensor_d = attn_metadata.state_indices_tensor_d
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/bailing_moe_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def _forward(

# Get KV cache and state indices
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
kv_cache = self.kv_cache[0][0]
state_indices_tensor = attn_metadata.state_indices_tensor
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/extract_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def unified_kv_cache_update(
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
kv_cache = attn_layer.kv_cache[0]

slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/olmo_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def _forward_core(
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/models/plamo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def forward_impl(
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,6 @@ def _forward_core(
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
)

has_initial_state = attn_metadata.has_initial_state
Expand All @@ -826,7 +825,7 @@ def _forward_core(
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
Expand Down Expand Up @@ -1009,13 +1008,12 @@ def _forward_core_decode_non_spec(
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
):
"""
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine]
self_kv_cache = self.kv_cache[0]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def bind_kv_cache(

# Bind kv_caches to forward context
for layer_name, kv_cache in kv_caches.items():
# NOTE: Use list because of v0 PP virtual engine.
# NOTE: Keep list wrapper for layers that index kv_cache by engine slot.
forward_context[layer_name].kv_cache = [kv_cache]


Expand Down
Loading