Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ class CompilationConfig:
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer",
"vllm::rocm_sparse_attn_indexer_no_insert",
"vllm::deepseek_v4_attention",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,17 +408,26 @@ def update_state_after_alloc(
)

else:
# WRITE mode: prefill scheduler notifies the decode side that
# blocks are ready. Parse the decode's host/notify_port from
# the request_id
# WRITE mode: this branch only fires on the decode-side
# scheduler (the toy proxy sets do_remote_prefill=True only on
# decode-bound requests). The decode tells the prefill which
# blocks to RDMA-write into, so we need the *prefill's*
# host/notify_port from the request_id.
# get_peer_zmq_from_request_id() takes the *caller's* role and
# returns the peer's address; self.is_producer is False on the
# decode side, so forwarding it as is_producer resolves to the
# prefill address.
# Hardcoding True here used to make the decode send the
# block-notify message to its own notify port, where the
# consumer-role assertion in
# MoRIIOWrapper._handle_structured_message would fail.
assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None"
)

remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)

peer_zmq = get_peer_zmq_from_request_id(
request.request_id, is_producer=True
request.request_id, is_producer=self.is_producer
)
remote_host, _, remote_notify_port = parse_moriio_zmq_address(peer_zmq)

Expand Down Expand Up @@ -770,17 +779,83 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.use_mla = self.model_config.use_mla
self.built_session = False
self.built_write_session: defaultdict[str, list] = defaultdict(list)
backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
use_mla=self.use_mla,
)
# DeepSeek V4 sparse attention switches the cache layout to
# "fp8_ds_mla" inside the attention layer. The platform's attention
# selector only exposes the sparse-MLA backend (which understands that
# cache dtype) when use_sparse=True is requested. Detect that case from
# the configured cache dtype so the connector's backend probe matches
# what the model actually instantiated.
self.use_sparse = self.cache_config.cache_dtype == "fp8_ds_mla"
self.transfer_id_to_request_id: dict[TransferId, ReqId] = {}

# The platform selector cannot describe every backend a model may pick
# for itself. DeepSeek V4 in particular returns its own
# DeepseekV4FlashMLASparseBackend from the attention layer's
# get_attn_backend() and never goes through get_attn_backend() at the
# platform level. On ROCm there is no platform candidate registered for
# (use_mla=True, use_sparse=True, kv_cache_dtype=fp8_ds_mla), so the
# selector raises here even though the model is running fine via the
# ROCm FlashMLA fallbacks. self.backend_name is only used as
# informational metadata in the MoRIIO handshake (no dispatch keys off
# it), so probe optimistically and fall back to the actual backend the
# model instantiated.
# TODO: consider the integration of flashinfer or other backends.
self.backend_name = backend.get_name()
logger.debug("Detected attention backend %s", self.backend_name)
try:
backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
use_mla=self.use_mla,
use_sparse=self.use_sparse,
)
self.backend_name = backend.get_name()
except ValueError as exc:
self.backend_name = self._infer_backend_name_from_model(
vllm_config
)
logger.warning(
"Platform attention selector has no entry for "
"(use_mla=%s, use_sparse=%s, kv_cache_dtype=%s); using '%s' "
"for MoRIIO handshake metadata. Underlying selector error: %s",
self.use_mla,
self.use_sparse,
self.cache_config.cache_dtype,
self.backend_name,
exc,
)
else:
logger.debug("Detected attention backend %s", self.backend_name)

@staticmethod
def _infer_backend_name_from_model(vllm_config: "VllmConfig") -> str:
    """Recover the attention backend name from the instantiated layers.

    Used when the model bypasses the platform selector (e.g., DeepSeek V4
    returning its own backend class from the attention layer's
    get_attn_backend()).

    Walks vllm_config.compilation_config.static_forward_context and returns
    the name reported by the first entry whose get_attn_backend() yields a
    backend with a non-empty string get_name().

    Args:
        vllm_config: Config object expected to expose
            compilation_config.static_forward_context as a mapping of
            layer name to layer object.

    Returns:
        The first usable backend name, or the sentinel "UNREGISTERED" when
        the forward context is missing, is not a mapping, or no entry
        advertises a usable name. This helper never raises for those
        cases, so it is safe to call from an exception handler.
    """
    sentinel = "UNREGISTERED"

    compilation_config = getattr(vllm_config, "compilation_config", None)
    forward_context = getattr(compilation_config, "static_forward_context", None)
    try:
        # Snapshot the entries; a None or non-mapping forward context must
        # degrade to the sentinel rather than raise.
        layers = list(forward_context.values())
    except (AttributeError, TypeError):
        return sentinel

    for layer in layers:
        getter = getattr(layer, "get_attn_backend", None)
        if not callable(getter):
            continue
        try:
            name = getter().get_name()
        except Exception:
            # Deliberate best-effort probe: one misbehaving layer must not
            # prevent the remaining layers from being inspected.
            continue
        if isinstance(name, str) and name:
            return name
    return sentinel

def schedule_write_blocks(
self,
Expand Down Expand Up @@ -1174,7 +1249,25 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
if layer_name not in self.layer_name_to_local_kv_cache_metadata:
self.layer_name_to_local_kv_cache_metadata[layer_name] = []

moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(kv_cache)
# MoRIIO's register_torch_tensor requires a contiguous tensor for
# RDMA pinning. Some KV cache layouts (e.g. DeepSeek V4's MLA spec
# with alignment=576) wrap the underlying contiguous int8 storage
# in a torch.as_strided() view that is non-contiguous by design
# (see vllm/v1/worker/gpu_model_runner.py::_reshape_kv_cache_tensors
# and MLAAttentionSpec.page_size_padded). For those, register the
# underlying storage as a flat 1D uint8 alias instead -- same
# memory, same data_ptr(), so all downstream addressing
# (kv_caches_base_addr, schedule_write_blocks, etc.) is unchanged.
if kv_cache.is_contiguous():
register_target = kv_cache
else:
register_target = torch.empty(
0, dtype=torch.uint8, device=kv_cache.device
).set_(kv_cache.untyped_storage())

moriio_mem_metadata = self.moriio_wrapper.register_local_tensor(
register_target
)
self.layer_name_to_local_kv_cache_metadata[layer_name].append(
moriio_mem_metadata
)
Expand Down
32 changes: 32 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
VLLM_ROCM_USE_V4_TRITON_FALLBACK: bool = True
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
VLLM_DISABLE_COMPILE_CACHE: bool = False
Expand Down Expand Up @@ -1076,6 +1077,37 @@ def _get_or_set_default() -> str:
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
),
# Master switch for the ROCm-native code paths used by
# DeepSeek-V4 (DSv4-Flash-FP8). When True (default on ROCm) the model
# selects the triton/torch fallbacks at three call sites:
#
# 1. SWA K-cache writer: torch reference
# (``_deepseek_v4_qnorm_rope_kv_insert_reference``) instead of
# upstream's HIPified ``fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert``
# C++ kernel, whose FP8 dtype is selected at compile time
# (``HIP_FP8_TYPE_OCP``) and silently corrupts every K byte on
# MI300X (FNUZ-only). This is the regression fix.
# 2. Sparse indexer: ``rocm_sparse_attn_indexer_no_insert``
# orchestration instead of upstream's
# ``rocm_aiter_sparse_attn_indexer_native``.
# 3. MLA sparse backend dispatch: route through the unified
# ``DeepseekV4FlashMLASparseBackend`` (whose ROCm kernels are
# supplied by ``flash_mla_with_kvcache_rocm`` /
# ``flash_mla_sparse_fwd_rocm`` via ``flashmla.py``) instead of
# ``DeepseekV4ROCMAiterMLASparseBackend`` /
# ``Impl`` (whose ``_sparse_attn_decode_ragged_kernel`` Triton
# kernel currently hard-codes the SM89 ``tl.float8e4b15`` dtype
# in the ``IS_FNUZ`` branch and crashes JIT-compile on
# gfx942 — see logs/0512/server_log2.txt).
#
# Set to "0" to opt back into the upstream AITER + native paths for
# bisection (note: the SWA-writer C++ kernel still produces
# deterministic garbage on MI300X, and the AITER Triton kernel has the
# ``fp8e4b15`` bug above, so env=0 is only useful for kernel debugging
# at present).
"VLLM_ROCM_USE_V4_TRITON_FALLBACK": lambda: (
os.getenv("VLLM_ROCM_USE_V4_TRITON_FALLBACK", "True").lower() in ("true", "1")
),
# Custom quick allreduce kernel for MI3* cards
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
# Recommended for large models to get allreduce
Expand Down
Loading
Loading