Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 143 additions & 2 deletions python/sglang/srt/disaggregation/nixl/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class KVArgsRegisterInfo:
decode_tp_size: int
decode_tp_rank: int
dst_kv_item_len: int
dst_state_item_lens: list[int] = dataclasses.field(default_factory=list)
dst_state_dim_per_tensor: list[int] = dataclasses.field(default_factory=list)

@classmethod
def from_zmq(cls, msg: List[bytes]):
Expand All @@ -93,6 +95,15 @@ def from_zmq(cls, msg: List[bytes]):
else:
dst_state_data_ptrs = []

dst_state_item_lens = []
dst_state_dim_per_tensor = []
if len(msg) > 12 and len(msg[12]) > 0:
dst_state_item_lens = list(struct.unpack(f"{len(msg[12]) // 4}I", msg[12]))
if len(msg) > 13 and len(msg[13]) > 0:
dst_state_dim_per_tensor = list(
struct.unpack(f"{len(msg[13]) // 4}I", msg[13])
)

return cls(
room=str(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
Expand All @@ -106,6 +117,8 @@ def from_zmq(cls, msg: List[bytes]):
decode_tp_size=int(msg[9].decode("ascii")),
decode_tp_rank=int(msg[10].decode("ascii")),
dst_kv_item_len=int(msg[11].decode("ascii")),
dst_state_item_lens=dst_state_item_lens,
dst_state_dim_per_tensor=dst_state_dim_per_tensor,
)


Expand Down Expand Up @@ -671,6 +684,106 @@ def _send_mamba_state(
raise Exception("Failed to post Mamba state transfer")
return xfer_handle

def _send_mamba_state_slice(
self,
peer_name: str,
prefill_state_indices: List[int],
dst_state_data_ptrs: list[int],
dst_state_indices: List[int],
dst_gpu_id: int,
notif: str,
dst_state_item_lens: list[int],
dst_state_dim_per_tensor: list[int],
decode_tp_size: int,
decode_tp_rank: int,
):
"""Transfer Mamba states with TP slice support via RDMA.

When prefill and decode have different attn_tp_size, we slice the
TP-sharded dimension (3rd dim) of conv_state and temporal_state
accordingly, mirroring Mooncake's _send_mamba_state_slice.
"""
logger.warning_once(
"Using Mamba state slice transfer for different TP sizes. "
f"Prefill attn_tp_size={self.attn_tp_size}, "
f"Decode attn_tp_size={decode_tp_size}."
)
assert len(prefill_state_indices) == 1, "Mamba should have single state index"

prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
src_state_dim_per_tensor = getattr(self.kv_args, "state_dim_per_tensor", [])

if not src_state_dim_per_tensor or not dst_state_dim_per_tensor:
return self._send_mamba_state(
peer_name,
prefill_state_indices,
dst_state_data_ptrs,
dst_state_indices,
dst_gpu_id,
notif,
)

local_tp_rank_in_group = self.kv_args.engine_rank % self.attn_tp_size
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size

src_addrs = []
dst_addrs = []

for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
src_item_len = prefill_state_item_lens[i]
dst_item_len = dst_state_item_lens[i]
src_dim = src_state_dim_per_tensor[i]
dst_dim = dst_state_dim_per_tensor[i]

src_bytes_per_dim = src_item_len // src_dim
dst_bytes_per_dim = dst_item_len // dst_dim

if self.attn_tp_size > decode_tp_size:
src_dim_start = 0
num_dims_to_send = src_dim
writers_per_decode = self.attn_tp_size // decode_tp_size
local_writer_idx = local_tp_rank_in_group % writers_per_decode
dst_dim_start = local_writer_idx * src_dim
else:
src_dim_start = (dst_tp_rank_in_group * dst_dim) % src_dim
num_dims_to_send = dst_dim
dst_dim_start = 0

src_dim_offset = src_dim_start * src_bytes_per_dim
dst_dim_offset = dst_dim_start * dst_bytes_per_dim
bytes_to_send = num_dims_to_send * src_bytes_per_dim

src_addr = (
prefill_state_data_ptrs[i]
+ src_item_len * int(prefill_state_indices[0])
+ src_dim_offset
)
dst_addr = (
dst_state_ptr
+ dst_item_len * int(dst_state_indices[0])
+ dst_dim_offset
)
src_addrs.append((src_addr, bytes_to_send, self.kv_args.gpu_id))
dst_addrs.append((dst_addr, bytes_to_send, dst_gpu_id))

src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")

xfer_handle = self.agent.initialize_xfer(
"WRITE",
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
)
if not xfer_handle:
raise Exception("Failed to create Mamba state slice transfer")
state = self.agent.transfer(xfer_handle)
if state == "ERR":
raise Exception("Failed to post Mamba state slice transfer")
return xfer_handle

def maybe_send_extra(
self,
peer_name: str,
Expand All @@ -680,14 +793,26 @@ def maybe_send_extra(
dst_gpu_id: int,
notif: str,
decode_tp_size: int,
decode_tp_rank: int = 0,
dst_state_item_lens: list[int] | None = None,
dst_state_dim_per_tensor: list[int] | None = None,
Comment on lines +796 to +798

@ShangmingCai ShangmingCai Apr 7, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we could send KVArgsRegisterInfo in, instead of passing every single parameter, for better readability and future maintenance.

):
"""Send state or extra pool data with type-specific handling."""
state_type = getattr(self.kv_args, "state_type", "none")

if state_type == "mamba":
if self.attn_tp_size != decode_tp_size:
raise RuntimeError(
"PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet."
return self._send_mamba_state_slice(
peer_name,
prefill_state_indices,
dst_state_data_ptrs,
dst_state_indices,
dst_gpu_id,
notif,
dst_state_item_lens or [],
dst_state_dim_per_tensor or [],
decode_tp_size,
decode_tp_rank,
)
return self._send_mamba_state(
peer_name,
Expand Down Expand Up @@ -791,6 +916,9 @@ def add_transfer_request(
dst_info.gpu_id,
f"{req.room}_state_{self.kv_args.pp_rank}",
decode_tp_size,
decode_tp_rank=dst_info.decode_tp_rank,
dst_state_item_lens=dst_info.dst_state_item_lens,
dst_state_dim_per_tensor=dst_info.dst_state_dim_per_tensor,
)
if state_xfer_handle is not None:
handles.append(state_xfer_handle)
Expand Down Expand Up @@ -1068,6 +1196,17 @@ def _register_kv_args(self):
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
)

packed_state_item_lens = b"".join(
struct.pack("I", item_len)
for item_len in self.kv_mgr.kv_args.state_item_lens
)
state_dim_per_tensor = getattr(
self.kv_mgr.kv_args, "state_dim_per_tensor", []
)
packed_state_dim_per_tensor = b"".join(
struct.pack("I", dim) for dim in state_dim_per_tensor
)

with lock:
sock.send_multipart(
[
Expand All @@ -1084,6 +1223,8 @@ def _register_kv_args(self):
str(self.kv_mgr.attn_tp_size).encode("ascii"),
str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
packed_state_item_lens,
packed_state_dim_per_tensor,
]
)

Expand Down
Loading