nixl refactor: new transfer design#40731
Conversation
There was a problem hiding this comment.
Code Review
This pull request refactors the NIXL connector to use a plan-based transfer design, which unifies Dense and Mamba implementations into a model-agnostic execution path. By pre-generating an EngineTransferPlan during the handshake, the hot path is simplified and model-specific branching is removed. The review feedback highlights several opportunities to improve robustness, specifically by replacing assertions with explicit error handling for block count mismatches and adding divisibility checks when calculating logical blocks and chunk sizes to prevent potential data corruption or crashes in heterogeneous tensor parallel configurations.
| num_local_blocks = len(local_block_ids[i]) | ||
| if not self._is_mamba_group[i]: | ||
| assert num_local_blocks <= num_remote_blocks | ||
| # Partial prefix cache hit: just read uncomputed blocks. | ||
| # Skip mamba groups — their blocks represent full state (conv+ssm), | ||
| # not per-token data, so trimming would corrupt the transfer. | ||
| if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: | ||
| assert num_local_blocks <= len(remote_group) | ||
| if num_local_blocks < len(remote_group): |
There was a problem hiding this comment.
The assertion assert num_local_blocks <= len(remote_group) can cause a hard crash of the worker if the producer sends fewer blocks than the consumer expects (e.g., due to a race condition or misconfiguration). It is safer to handle this as a transfer failure for the specific request, allowing the engine to continue processing other requests. Since _read_blocks is called within a try-except block in _read_blocks_for_req, raising a ValueError will be caught and handled gracefully.
| num_local_blocks = len(local_block_ids[i]) | |
| if not self._is_mamba_group[i]: | |
| assert num_local_blocks <= num_remote_blocks | |
| # Partial prefix cache hit: just read uncomputed blocks. | |
| # Skip mamba groups — their blocks represent full state (conv+ssm), | |
| # not per-token data, so trimming would corrupt the transfer. | |
| if num_local_blocks < num_remote_blocks and not self._is_mamba_group[i]: | |
| assert num_local_blocks <= len(remote_group) | |
| if num_local_blocks < len(remote_group): | |
| num_local_blocks = len(local_block_ids[i]) | |
| if num_local_blocks > len(remote_group): | |
| raise ValueError( | |
| f"Group {i}: local block count ({num_local_blocks}) " | |
| f"exceeds remote block count ({len(remote_group)})") | |
| if num_local_blocks < len(remote_group): |
| ratio = physical_blocks_per_logical | ||
| logical_blocks = num_blocks // ratio |
There was a problem hiding this comment.
The calculation of logical_blocks = num_blocks // ratio assumes that the total number of kernel blocks is a perfect multiple of the physical-to-logical ratio. If this invariant is violated, the descriptor IDs for SSM regions will be incorrectly computed, leading to memory corruption or incorrect data transfer. An explicit check should be added to verify this invariant.
| ratio = physical_blocks_per_logical | |
| logical_blocks = num_blocks // ratio | |
| ratio = physical_blocks_per_logical | |
| if num_blocks % ratio != 0: | |
| raise ValueError(f"num_blocks {num_blocks} is not a multiple of " | |
| f"physical_blocks_per_logical {ratio}") | |
| logical_blocks = num_blocks // ratio |
| if j < num_fa_descs: | ||
| chunk = local_len // fa_num_splits | ||
| handle.append((addr + fa_slot * chunk, chunk, dev)) | ||
| else: | ||
| chunk = local_len // ssm_num_splits | ||
| handle.append((addr + p_idx * chunk, chunk, dev)) |
There was a problem hiding this comment.
Using integer division to compute chunk sizes for split handles can lead to silent data loss if the total length is not perfectly divisible by the number of splits. While standard vLLM configurations usually satisfy this, heterogeneous TP scenarios with non-standard world sizes could trigger this issue. It is safer to explicitly check for divisibility.
| if j < num_fa_descs: | |
| chunk = local_len // fa_num_splits | |
| handle.append((addr + fa_slot * chunk, chunk, dev)) | |
| else: | |
| chunk = local_len // ssm_num_splits | |
| handle.append((addr + p_idx * chunk, chunk, dev)) | |
| if j < num_fa_descs: | |
| if local_len % fa_num_splits != 0: | |
| raise ValueError(f"FA descriptor length {local_len} is not " | |
| f"divisible by split count {fa_num_splits}") | |
| chunk = local_len // fa_num_splits | |
| handle.append((addr + fa_slot * chunk, chunk, dev)) | |
| else: | |
| if local_len % ssm_num_splits != 0: | |
| raise ValueError(f"SSM descriptor length {local_len} is not " | |
| f"divisible by split count {ssm_num_splits}") | |
| chunk = local_len // ssm_num_splits | |
| handle.append((addr + p_idx * chunk, chunk, dev)) |
4852347 to
fcf7418
Compare
fcf7418 to
a6e5266
Compare
NickLucche
left a comment
There was a problem hiding this comment.
Left some comments, thanks @ZhanqiuHu !
| # ------------------------------------------------------------------ | ||
| # Plan executors (static — no self access) | ||
| # ------------------------------------------------------------------ |
There was a problem hiding this comment.
this feels a little bit too "claudy"
| @dataclass(frozen=True) | ||
| class RegionPlan: | ||
| """Geometry for one descriptor region. | ||
|
|
||
| Everything needed to build NIXL descriptors and compute descriptor | ||
| IDs is baked in. The caller plugs in ``base_addr`` and | ||
| ``device_id`` when constructing the final descriptor tuples. |
There was a problem hiding this comment.
I am not fully convinced about this abstraction.
I am afraid this may actually be harder to work with rather than a basic region described by (base_addr, len).
Like do we care about keeping track of things like
- page_stride
- offset_in_page
once the starting address of the region has been computed?
'Cause if we don't, we might as well just store the address and len directly
There was a problem hiding this comment.
This is purely my personal preference, but I was thinking of moving the
geometry computation (stride, offset, descriptor size) from worker.py
into transfer_plan.py to reduce the worker code. RegionPlan packs
the output of that. Originally the functions included base_addr and
device_id, but then I wanted to reduce the arguments to
_build_fa_regions, build_fa_local_regions, and
build_mamba_local_regions, and base_addr and device_id are not used for block geometry computation.
I was also thinking of adding parameters like descs_per_block and desc_stride_bytes to RegionPlan,
so we can handle different cache groups with different block size ratios
(e.g, Gemma4 HeteroTP where SWA and FA have different tokens-per-block).
| handle.append((addr + p_idx * chunk, chunk, dev)) | ||
| result.append(handle) | ||
|
|
||
| return result |
There was a problem hiding this comment.
not that it matters much in terms of speed, but this whole method could yield handle here an be a generator
There was a problem hiding this comment.
Sounds good, will update _build_local_splits_from_plan to yield handle.
| all_descs: list[np.ndarray] = [] | ||
| for i, group in enumerate(block_ids): | ||
| group_arr = np.asarray(group) | ||
| spec_type = plan.group_spec_types[i] | ||
| if _is_attention_spec(spec_type): | ||
| fa_region_ids = np.arange(num_fa_regions)[:, None] | ||
| all_descs.append( | ||
| (fa_region_ids * num_blocks + group_arr[None, :]).flatten() | ||
| ) |
There was a problem hiding this comment.
I think we lost a nice optimization for non mamba models that used vectorized np ops only here
if not self._has_mamba:
block_ids = np.concatenate(block_ids)[None, :]
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
else:
# NOTE (NickLucche) SSM and Attention blocks regions can be exchanged
@ZhanqiuHu could you double check
There was a problem hiding this comment.
Yes, that's truth. I was thinking of removing the branch. I can add it back like this:
@staticmethod
def _compute_desc_ids_from_plan(
plan: EngineTransferPlan,
block_ids: BlockIds,
dst_num_blocks: int,
block_size_ratio: float | None,
physical_blocks_per_logical: int,
) -> np.ndarray:
"""Compute NIXL descriptor IDs for given block IDs."""
num_fa_regions = len(plan.fa_regions)
num_ssm_regions = len(plan.ssm_regions)
num_blocks = dst_num_blocks
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
ratio = physical_blocks_per_logical
logical_blocks = num_blocks // ratio
num_fa_descs = num_fa_regions * num_blocks
# All-attention fast path: single vectorized broadcast.
if num_ssm_regions == 0:
block_arr = np.concatenate(block_ids)[None, :]
region_ids = np.arange(num_fa_regions)[:, None]
return (region_ids * num_blocks + block_arr).flatten()
# NOTE (NickLucche) With HMA, every kv group has the same number
# of layers and layers from different groups share the same kv
# tensor. Therefore we compute desc IDs per group using the
# right stride:
# FA descs have num_blocks entries per region (kernel granularity),
# SSM descs have logical_blocks entries per region (no kernel
# splitting).
all_descs: list[np.ndarray] = []
for i, group in enumerate(block_ids):
group_arr = np.asarray(group)
if _is_attention_spec(plan.group_spec_types[i]):
fa_region_ids = np.arange(num_fa_regions)[:, None]
all_descs.append(
(fa_region_ids * num_blocks + group_arr[None, :]).flatten()
)
elif _is_ssm_spec(plan.group_spec_types[i]):
# NOTE (NickLucche) SSM and Attention block regions can
# be exchanged arbitrarily by manager. Therefore, descs
# are laid out as:
# [descs_fa (all regions) | descs_ssm (all regions)].
# num_fa_descs offset must be computed per-engine since
# P and D can have different num_blocks (and thus
# different FA desc counts).
ssm_region_ids = np.arange(num_ssm_regions)[:, None]
all_descs.append(
(
ssm_region_ids * logical_blocks
+ group_arr[None, :]
+ num_fa_descs
).flatten()
)
else:
raise ValueError(
f"Unknown spec type {plan.group_spec_types[i]} at index {i}"
)
return np.concatenate(all_descs)
| # Per-group ordered source ranks. Position = local piece index. | ||
| source_ranks_per_group: tuple[tuple[int, ...], ...] | ||
|
|
||
| # Superset of all source ranks (union of all groups). | ||
| all_source_ranks: tuple[int, ...] |
There was a problem hiding this comment.
I dont think it's very clear what "source_rank" is here..
There was a problem hiding this comment.
source_rank is suppose to mean the remote TP ranks that this local rank reads from.
source_ranks_per_group is the source_rank for each kv cache group (e.g., FA source ranks will be < mamba source ranks if FA is replicated and Mamba is sharded).
Should we renamed it to something else?
There was a problem hiding this comment.
let's just add this as comment above here
| group_spec_types: tuple[type[KVCacheSpec], ...], | ||
| local_physical_blocks_per_logical: int, | ||
| ) -> EngineTransferPlan: | ||
| """Generate transfer plan for dense (attention-only) models.""" |
There was a problem hiding this comment.
nit: very minor, but we may want to choose some other name or clarify that for dense we still encompass all non-mamba, including SW and DSA
There was a problem hiding this comment.
How about generate_pure_attention_plan() and
generate_ssm_attention_hybrid_plan()?
There was a problem hiding this comment.
let's ask claude for some more options here
| def _build_local_descs( | ||
| self, | ||
| base_addresses: list[int], | ||
| block_size_ratio: int, | ||
| ) -> list[tuple[int, int, int]]: |
There was a problem hiding this comment.
this method is only used once in register_local_xfer_handler.
I don't think there's a lot of value in added clarity in separating this snippet as a staticmethod
There was a problem hiding this comment.
Agree, will inline it.
| def _compute_read_specs_from_plan( | ||
| plan: EngineTransferPlan, | ||
| local_block_ids: BlockIds, | ||
| remote_block_ids: BlockIds, | ||
| ) -> list[ReadSpec]: | ||
| """Compute read specs from plan. |
There was a problem hiding this comment.
I am also not sure whether this should be a function, or an inline for
There was a problem hiding this comment.
Agree, will inline it.
| # ..but we still need to notify the other remote ranks that we | ||
| # have the blocks we need so they can update the request state. |
There was a problem hiding this comment.
nit: this continuation comment has lost its first part now. It was meant to be
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
but we can just rephrase
There was a problem hiding this comment.
The first part is actually before
if self.use_mla and tp_ratio < 0:
read_specs = read_specs[:1]
The change is already in main (main ref), probably introduced by a previous PR.
| self._read_blocks( | ||
| request_id=req_id, | ||
| dst_engine_id=meta.remote.engine_id, | ||
| remote_request_id=meta.remote.request_id, | ||
| local_block_ids=local_ids, | ||
| remote_block_ids=remote_ids, | ||
| local_block_ids=local_block_ids, | ||
| remote_block_ids=remote_block_ids, |
There was a problem hiding this comment.
not fully sure, but should we change _read_blocks interface to just be
def _read_blocks(
self,
dst_engine_id: str,
request_id: str,
read_spec: ReadSpec,
remote_request_id: str,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
There was a problem hiding this comment.
Yes, I think we can do
def _read_blocks(
self,
read_spec: ReadSpec,
dst_engine_id: str,
remote_request_id: str,
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
local_block_ids = read_spec.local_block_ids
remote_block_ids = read_spec.remote_block_ids
remote_rank = read_spec.remote_rank
ZhanqiuHu
left a comment
There was a problem hiding this comment.
Addressed comments
| # Per-group ordered source ranks. Position = local piece index. | ||
| source_ranks_per_group: tuple[tuple[int, ...], ...] | ||
|
|
||
| # Superset of all source ranks (union of all groups). | ||
| all_source_ranks: tuple[int, ...] |
There was a problem hiding this comment.
let's just add this as comment above here
| group_spec_types: tuple[type[KVCacheSpec], ...], | ||
| local_physical_blocks_per_logical: int, | ||
| ) -> EngineTransferPlan: | ||
| """Generate transfer plan for dense (attention-only) models.""" |
There was a problem hiding this comment.
let's ask claude for some more options here
| all_descs: list[np.ndarray] = [] | ||
| for i, group in enumerate(block_ids): | ||
| group_arr = np.asarray(group) | ||
| spec_type = plan.group_spec_types[i] | ||
| if _is_attention_spec(spec_type): | ||
| fa_region_ids = np.arange(num_fa_regions)[:, None] | ||
| all_descs.append( | ||
| (fa_region_ids * num_blocks + group_arr[None, :]).flatten() | ||
| ) |
| Get the block length for one K/V element (K and V have the same size). | ||
|
|
||
| For FA and other backends, this is equal to the length of the whole | ||
| block, as K and V are in separate regions. | ||
| For FlashInfer, this is half the length of the whole block, as K and V | ||
| share the same region. | ||
| Similarly, for SSM-based models, state and conv are interleaved, but crucially | ||
| the their size differs. | ||
| Reference diagram: | ||
| KVCacheTensor (Shared) | ||
| / \\ | ||
| / \\ | ||
| / \\ | ||
| Attention (FlashInfer) View Mamba View | ||
| | | | ||
| | | | ||
| +-------------------+ +-------------------+ | ||
| | KVCacheTensor | | KVCacheTensor | | ||
| | | | | | ||
| |<----- page ------>| |<----- page ------->| | ||
| | size | | size | | ||
| | Key 0 | Val 0 | |Conv 0 | SSM 0 | | ||
| | Key 1 | Val 1 | |Conv 1 | SSM 1 | | ||
| | ... | ... | | ... | ... | | ||
| | Key N-2 | Val N-2 | |Conv N-2| SSM N-2 | | ||
| | Key N-1 | Val N-1 | |Conv N-1| SSM N-1 | | ||
| +-------------------+ +--------------------+ | ||
| |1st_split-2nd_split| |1st_split-2nd_split | | ||
| """ |
| @dataclass(frozen=True) | ||
| class MambaEngineTransferInfo(EngineTransferInfo): | ||
| """Extends ``EngineTransferInfo`` with Mamba-hybrid transfer geometry. |
a6e5266 to
51f297a
Compare
|
This pull request has merge conflicts that must be resolved before it can be |
51f297a to
7881c46
Compare
NickLucche
left a comment
There was a problem hiding this comment.
I think this is a particularly complicated part of the codebase, thanks for the great work here @ZhanqiuHu trying to improve clarity and maintainability of it!
Happy to finally approve this PR. Left some comments, mostly around preserving some of the context we had in comments. Will push something to help the work around the clock
| self._physical_blocks_per_logical_kv_block = ( | ||
| self.block_size // kernel_block_size | ||
| ) |
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """TP mapping computation for NIXL KV cache transfers.""" |
There was a problem hiding this comment.
I think this file is tp_mapping now :)
| # With homogeneous TP, D pulls the whole kv cache from corresponding | ||
| # rank. With heterogeneous TP, prepare the descriptors by splitting the | ||
| # P KV cache along kv_head dim, of D worker's kv_head size (D>P). | ||
| # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. | ||
|
|
||
| # Register all remote blocks, but only the corresponding kv heads. |
There was a problem hiding this comment.
this might still be good context
| ) | ||
|
|
||
| if transfer_topo.is_kv_layout_blocks_first: | ||
| # With FlashInfer index V separately to allow head splitting. |
| # Separate and interleave K/V regions to maintain the same | ||
| # descs ordering. This is needed for selecting contiguous heads | ||
| # when split across TP ranks. |
| # Mamba conv state is always TP-sharded, even when attention KV | ||
| # is replicated (num_kv_heads < tp_size). |
| # ---- Mamba-HMA per-engine state (only used when self._has_mamba) ---- | ||
| # NOTE (ZhanqiuHu): _physical_blocks_per_logical MUST be per-engine. | ||
| # physical_blocks_per_logical = ceil((conv_bytes + ssm_bytes) / block_len) | ||
| # where conv/ssm bytes are per-TP-rank (dimension-sharded). With | ||
| # heterogeneous TP the per-rank sizes differ, so the ratio differs: | ||
| # e.g. Nemotron 30B: P(TP=4) → 131, D(TP=1) → 261. |
There was a problem hiding this comment.
I think this is also good context
f23d68c to
98c3207
Compare
Signed-off-by: Zhanqiu Hu <zhu@redhat.com> Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com> Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: Zhanqiu Hu <zhu@redhat.com> Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com>
Head branch was pushed to by a user without write access
98c3207 to
1232865
Compare
|
It seems that this PR breaks DSv4 PD with dynamo Have we verified the implementation with an end-to-end integration load test? We're investigating on parallel. |
I tested this end-to-end with Qwen and Nemotron and Deepseek-V2, haven't test with Deepseek v4 yet. Will look into this. By the way do you have the setup and full traces? Thanks! |
|
might be a compat version bump, thanks for reporting @ywang96 will check this out |
|
@ZhanqiuHu Here is an MRE using Slurm, launching 1 prefill and 1 decode worker Script"""Launch one DSv4-Flash prefill worker and one decode worker via Slurm.
Please change the hard-coded nodes accordingly.
Install NIXL
uv pip install 'nixl[cu13]==0.10.1' 'nixl-cu12==0.10.1' 'nixl-cu13==0.10.1' 'cupy-cuda12x==14.0.1'
"""
import contextlib
from dataclasses import dataclass
import json
import os
import signal
import shlex
import subprocess
import sys
import time
import uuid
from pathlib import Path
import requests
MODEL = "deepseek-ai/DeepSeek-V4-Flash"
SERVER_HOST = "0.0.0.0"
LOG_DIR = Path(__file__).parent
GPUS_PER_WORKER = 4
CPUS_PER_WORKER = 32
READY_TIMEOUT_S = 900
REQUEST_TIMEOUT_S = 300
@dataclass(frozen=True)
class Worker:
name: str
node: str
port: int
nixl_port: int
@property
def host(self) -> str:
return self.node
@property
def base_url(self) -> str:
return f"http://{self.host}:{self.port}"
@property
def log_path(self) -> Path:
return LOG_DIR / f"{self.name}.log"
### PLEASE UPDATE THIS ###
PREFILL = Worker(name="prefill", node="node-01", port=8100, nixl_port=5610)
DECODE = Worker(name="decode", node="node-02", port=8200, nixl_port=5620)
PROMPT = (
"Write a detailed explanation of paged attention for transformer inference, "
"including how KV cache blocks are allocated, transferred, and reused across "
"prefill and decode workers in a disaggregated serving system. Include enough "
"detail that this prompt exceeds the KV cache block threshold."
)
VLLM_CMD = f"""
{shlex.quote(sys.executable)}
-m vllm.entrypoints.openai.api_server
--model {shlex.quote(MODEL)}
--host {SERVER_HOST}
--trust-remote-code
--load-format dummy
--data-parallel-size {GPUS_PER_WORKER}
--enable-expert-parallel
--max-model-len 512
--max-num-batched-tokens 512
--block-size 256
--gpu-memory-utilization 0.80
--kv-cache-dtype fp8
--tokenizer-mode deepseek_v4
--moe-backend deep_gemm_mega_moe
--enforce-eager
--compilation-config '{{"mode":0}}'
--attention-config '{{"use_fp4_indexer_cache":true}}'
--kv-transfer-config '{{"kv_connector":"NixlConnector","kv_role":"kv_both"}}'
"""
def launch(worker: Worker) -> subprocess.Popen:
cmd = [
"srun",
"--nodes=1",
"--ntasks=1",
f"--nodelist={worker.node}",
f"--gpus-per-task={GPUS_PER_WORKER}",
f"--cpus-per-task={CPUS_PER_WORKER}",
"--export=ALL",
*shlex.split(VLLM_CMD),
"--port",
str(worker.port),
]
env = os.environ | {
"VLLM_NIXL_SIDE_CHANNEL_HOST": worker.host,
"VLLM_NIXL_SIDE_CHANNEL_PORT": str(worker.nixl_port),
"VLLM_WORKER_MULTIPROC_METHOD": os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"),
"UCX_TLS": os.getenv("UCX_TLS", "cuda_copy,cuda_ipc,tcp"),
"UCX_MEMTYPE_CACHE": os.getenv("UCX_MEMTYPE_CACHE", "n"),
"UCX_MEMTYPE_REG_WHOLE": os.getenv("UCX_MEMTYPE_REG_WHOLE", "n"),
}
print(f"\n{worker.name}: {shlex.join(cmd)}", flush=True)
print(
f"{worker.name} srun: node={worker.node} gpus_per_task={GPUS_PER_WORKER} "
f"cpus_per_task={CPUS_PER_WORKER} "
f"{worker.name} env: VLLM_NIXL_SIDE_CHANNEL_HOST={env['VLLM_NIXL_SIDE_CHANNEL_HOST']} "
f"VLLM_NIXL_SIDE_CHANNEL_PORT={env['VLLM_NIXL_SIDE_CHANNEL_PORT']}",
flush=True,
)
return subprocess.Popen(
cmd,
env=env,
stdout=worker.log_path.open("w", encoding="utf-8"),
stderr=subprocess.STDOUT,
text=True,
start_new_session=True,
)
def wait_ready(worker: Worker, proc: subprocess.Popen) -> None:
url = f"{worker.base_url}/v1/models"
deadline = time.monotonic() + READY_TIMEOUT_S
while time.monotonic() < deadline:
if proc.poll() is not None:
raise RuntimeError(f"{worker.name} exited early; see {worker.log_path}")
try:
requests.get(url, timeout=5).raise_for_status()
print(f"{worker.name} ready: {url}", flush=True)
return
except requests.RequestException:
time.sleep(2)
raise TimeoutError(f"{worker.name} did not become ready; see {worker.log_path}")
def raise_for_status(response: requests.Response) -> None:
try:
response.raise_for_status()
except requests.HTTPError as exc:
body = response.text[:4000]
raise RuntimeError(f"{response.request.method} {response.url} failed: {exc}\n{body}") from exc
def stop(proc: subprocess.Popen | None) -> None:
if proc is None or proc.poll() is not None:
return
with contextlib.suppress(ProcessLookupError):
os.killpg(proc.pid, signal.SIGTERM)
try:
proc.wait(timeout=20)
except subprocess.TimeoutExpired:
with contextlib.suppress(ProcessLookupError):
os.killpg(proc.pid, signal.SIGKILL)
def main() -> None:
LOG_DIR.mkdir(parents=True, exist_ok=True)
print(f"logs: {LOG_DIR}", flush=True)
print(
f"prefill: node={PREFILL.node} host={PREFILL.host} port={PREFILL.port} nixl_port={PREFILL.nixl_port}",
flush=True,
)
print(
f"decode: node={DECODE.node} host={DECODE.host} port={DECODE.port} nixl_port={DECODE.nixl_port}",
flush=True,
)
prefill = decode = None
try:
prefill = launch(PREFILL)
decode = launch(DECODE)
wait_ready(PREFILL, prefill)
wait_ready(DECODE, decode)
request_id = str(uuid.uuid4())
prefill_response = requests.post(
f"{PREFILL.base_url}/v1/completions",
headers={"X-Request-Id": request_id},
timeout=REQUEST_TIMEOUT_S,
json={
"model": MODEL,
"prompt": PROMPT,
"max_tokens": 1,
"temperature": 0,
"kv_transfer_params": {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
},
},
)
raise_for_status(prefill_response)
prefill_json = prefill_response.json()
print("\nprefill response:", json.dumps(prefill_json, indent=2)[:4000])
kv_transfer_params = prefill_json.get("kv_transfer_params")
if not kv_transfer_params:
raise RuntimeError(f"prefill response did not include kv_transfer_params: {prefill_json}")
decode_response = requests.post(
f"{DECODE.base_url}/v1/completions",
headers={"X-Request-Id": request_id},
timeout=REQUEST_TIMEOUT_S,
json={
"model": MODEL,
"prompt": "decode side placeholder prompt",
"max_tokens": 8,
"temperature": 0,
"kv_transfer_params": kv_transfer_params,
},
)
print("\ndecode status:", decode_response.status_code)
print(decode_response.text[:4000])
raise_for_status(decode_response)
finally:
print(f"\nlogs: {LOG_DIR}", flush=True)
stop(prefill)
stop(decode)
if __name__ == "__main__":
main()Here is another MRE generated by Codex, looking into the internals. I'm not too familiar with this code so I'm not entirely sure if it makes sense. But maybe you will understand """Minimal repro for the suspected NIXL PR #40731 MLA classification bug.
"""
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.nixl.tp_mapping import (
_is_attention_spec,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
MambaSpec,
UniformTypeKVCacheSpecs,
)
def classify_like_pr40731(spec: KVCacheSpec) -> type[KVCacheSpec]:
return type(spec)
def classify_with_uniform_unwrap(spec: KVCacheSpec) -> type[KVCacheSpec]:
if isinstance(spec, UniformTypeKVCacheSpecs):
inner_specs = tuple(spec.kv_cache_specs.values())
if inner_specs and all(isinstance(s, AttentionSpec) for s in inner_specs):
return AttentionSpec
if inner_specs and all(isinstance(s, MambaSpec) for s in inner_specs):
return MambaSpec
return type(spec)
def find_fa_group(group_spec_types: tuple[type[KVCacheSpec], ...]) -> int:
return next(
i for i, spec_type in enumerate(group_spec_types) if _is_attention_spec(spec_type)
)
def main() -> None:
mla_spec = MLAAttentionSpec(
block_size=64,
num_kv_heads=1,
head_size=512,
dtype=torch.bfloat16,
cache_dtype_str="fp8_ds_mla",
compress_ratio=1,
model_version="deepseek_v4",
)
uniform_mla_spec = UniformTypeKVCacheSpecs(
block_size=64,
kv_cache_specs={
"model.layers.0.self_attn": mla_spec,
"model.layers.1.self_attn": mla_spec,
},
)
old_group_spec_types = (classify_like_pr40731(uniform_mla_spec),)
print(
"old_group_spec_types =",
[spec_type.__name__ for spec_type in old_group_spec_types],
)
try:
old_idx = find_fa_group(old_group_spec_types)
except StopIteration:
print("old path reproduced StopIteration")
else:
raise AssertionError(f"old path unexpectedly found FA group {old_idx}")
fixed_group_spec_types = (classify_with_uniform_unwrap(uniform_mla_spec),)
print(
"fixed_group_spec_types =",
[spec_type.__name__ for spec_type in fixed_group_spec_types],
)
fixed_idx = find_fa_group(fixed_group_spec_types)
print("fixed path found FA group", fixed_idx)
if __name__ == "__main__":
main() |
|
@gau-nernst Thanks! I will take a look. IIRC |
yep when registering kv tensors |
…roject#40731 PR vllm-project#40449 fixed the MLA broadcast notification to use `meta.remote.request_id` (the prefill-side id) instead of `req_id` (the decode-side id). The refactor in vllm-project#40731 inadvertently reverted this back to `req_id`, causing "Potentially invalid KV blocks for unrecognized request" errors on non-reading P-ranks when P_TP > D_TP. Fixes vllm-project#41974 Signed-off-by: Zhanqiu Hu <zhanqiuhu@gmail.com>
[Bugfix] restore MLA notif request_id accidentally reverted in vllm-project#40731
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com> Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
…roject#40731 PR vllm-project#40449 fixed the MLA broadcast notification to use `meta.remote.request_id` (the prefill-side id) instead of `req_id` (the decode-side id). The refactor in vllm-project#40731 inadvertently reverted this back to `req_id`, causing "Potentially invalid KV blocks for unrecognized request" errors on non-reading P-ranks when P_TP > D_TP. Fixes vllm-project#41974 Signed-off-by: Zhanqiu Hu <zhanqiuhu@gmail.com> Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com> Signed-off-by: Libin Tang <libin.tang@intel.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com>
Signed-off-by: ZhanqiuHu <zhu@redhat.com> Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: NickLucche <nlucches@redhat.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Refactor 3/N