Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 8 additions & 81 deletions python/sglang/srt/disaggregation/conn.py
Original file line number Diff line number Diff line change
@@ -1,81 +1,8 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Optional

import numpy as np
import numpy.typing as npt

logger = logging.getLogger(__name__)


class KVArgs:
    """Attribute bag describing the buffers and device passed to ``KVManager``.

    Only annotations are declared; no attribute has a class-level value, so
    every field must be assigned on the instance before it is read.
    """

    # NOTE(review): presumably the rank of the owning engine -- confirm.
    engine_rank: int

    # KV cache buffers.
    # NOTE(review): the three lists look parallel (one entry per buffer:
    # pointer, total length, per-item length) -- confirm against the callers.
    kv_data_ptrs: list[int]
    kv_data_lens: list[int]
    kv_item_lens: list[int]

    # Auxiliary (metadata) buffers, described the same way as the KV buffers.
    aux_data_ptrs: list[int]
    aux_data_lens: list[int]
    aux_item_lens: list[int]

    # Name of the InfiniBand device to use.
    ib_device: str


class KVManager:
    """Stub KV manager: it can be constructed but holds no state."""

    def __init__(self, args: KVArgs):
        """Accept ``args`` and ignore it; nothing is stored on the instance."""


class KVPoll:
    """Integer status codes returned by the ``poll()`` methods in this module.

    These are plain class attributes (not an ``Enum``), so they compare and
    behave as ordinary ``int`` values.
    """

    Failed = 0
    Bootstrapping = 1
    # Returned by the stub pollers while the handshake is assumed complete
    # but no data has been submitted yet.
    WaitingForInput = 2
    Transferring = 3
    # Returned by the stub pollers once the transfer is assumed complete.
    Success = 4


class KVSender:
    """Fake sender: it transfers nothing and only tracks whether ``send`` ran."""

    def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
        """Start in the "nothing sent yet" state; the arguments are unused."""
        self.has_sent = False

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Do nothing (placeholder for the real initialisation step)."""

    def send(self, kv_indices: npt.NDArray[np.int32]):
        """Record that a send was requested; ``kv_indices`` is not inspected."""
        self.has_sent = True

    def poll(self) -> KVPoll:
        """Report ``Success`` after ``send`` was called, else ``WaitingForInput``."""
        if self.has_sent is not False:
            # The transfer is assumed to complete instantly.
            return KVPoll.Success
        # The handshake is assumed to complete instantly.
        return KVPoll.WaitingForInput

    def failure_exception(self):
        """Always raise the generic fake sender error."""
        raise Exception("Fake KVSender Exception")


class KVReceiver:
    """Fake receiver: it receives nothing and only tracks whether ``init`` ran."""

    def __init__(
        self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
    ):
        """Start in the "not initialised" state; the arguments are unused."""
        self.has_init = False

    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """Record that initialisation happened; the arguments are not inspected."""
        self.has_init = True

    def poll(self) -> KVPoll:
        """Report ``Success`` after ``init`` was called, else ``WaitingForInput``."""
        if self.has_init is not False:
            # The transfer is assumed to complete instantly.
            return KVPoll.Success
        # The handshake is assumed to complete instantly.
        return KVPoll.WaitingForInput

    def failure_exception(self):
        """Always raise the generic fake receiver error."""
        raise Exception("Fake KVReceiver Exception")


class KVBootstrapServer:
    """Stub bootstrap server: both methods are empty placeholders."""

    def __init__(self, port: int):
        """Accept ``port`` and ignore it; nothing is stored or started."""

    def poll(self) -> KVPoll:
        """Do nothing and implicitly return ``None`` despite the annotation."""
# Identifiers of the supported KV-transfer backends.
MOONCAKE = "mooncake"
PYVERBS = "pyverbs"

# Backend whose public names are re-exported below; currently hard-coded.
mode = PYVERBS

# Star-import the selected backend so its names are importable from here.
if mode == MOONCAKE:
    from sglang.srt.disaggregation.mooncake_conn import *
elif mode == PYVERBS:
    from sglang.srt.disaggregation.verbs_conn import *
else:
    # Without this branch an unrecognised value would leave nothing
    # re-exported, and the mistake would only surface later as an
    # ImportError in whichever module tries to import a name from here.
    raise ValueError(
        f"Unknown disaggregation backend {mode!r}; "
        f"expected {MOONCAKE!r} or {PYVERBS!r}"
    )
4 changes: 3 additions & 1 deletion python/sglang/srt/disaggregation/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.verbs_conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
poll_and_all_reduce,
Expand Down Expand Up @@ -115,6 +115,8 @@ def _init_kv_manager(self) -> KVManager:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_args.gpu_id = self.scheduler.gpu_id

kv_manager = KVManager(kv_args)
return kv_manager

Expand Down
38 changes: 38 additions & 0 deletions python/sglang/srt/disaggregation/design.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Simple PD based on pyverbs

## Overall Architecture

### Sequence Diagram

The overall sequence is as follows ![img](./pd_1.png)

Detailed interpretation

![img](./sequence.png)

### RDMA Connection Establishment Process

* Prefill Server

1. Start BootstrapServer (global startup)

2. When a request comes in, create a Sender object
2.1 Sender initialization -> Enter Bootstrapping phase
2.2 Each worker (tp) of P communicates with BootstrapServer to query the peer D's port and IP based on room_id and engine rank
2.3 After obtaining the peer's rdma socket port, enter WaitingForInput phase
2.4 Sender `init` method: initialize the RdmaClient and connect it to the peer RdmaServer through its socket port to establish the RDMA connection, then exchange metadata-buffer information: obtain peer D's metadata buffer together with the memory-address array and rkey array of the D-side segments to be written. Enter the Transferring phase
2.5 After the forward pass, send: based on the computed kv_indices, calculate the base address + layer cache length for each layer's cache, create local MR objects, bind them with the exchanged addresses and rkeys, and use SendWR to write the remote GPU memory (IBV_WR_RDMA_WRITE mode, which does not require a server-side recv)
2.6 Poll the local Send_CQ; once all KV cache MRs are written successfully, write the metadata buffer (IBV_WR_RDMA_WRITE_WITH_IMM mode, which requires a server-side recv)
2.7 After all SendWRs are sent, TransferComplete

* Decode Server

0. When request comes in, pre-allocate kv space

1. Register its rank and a random port (used as the RdmaServer port) with the BootstrapServer and bind the socket. If successful, enter WaitingForInput

2. Decode calls the init method, passing in kv_indices and aux_index. At this point it exchanges information with the RdmaClient over the socket, mainly sending its metadata address and rkey plus the pre-allocated addresses, rkeys and lengths to the P node. Enter the Transferring phase

3. Post the receive for recv_metadata_mr; RDMA then waits for the P side to send the first-word address

4. Poll until the metadata write succeeds; once it does, the state is TransferComplete
64 changes: 64 additions & 0 deletions python/sglang/srt/disaggregation/group_indics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#!/usr/bin/env python
# coding:utf-8
"""
@author: nivic ybyang7
@license: Apache Licence
@file: group_indics.py
@time: 2025/04/02
@contact: [email protected]
@site:
@software: PyCharm

# code is far away from bugs with the god animal protecting
I love animals. They taste delicious.
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ☃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ 神兽保佑 ┣┓
┃ 永无BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
import numpy as np

def group_by_continuity_numpy(arr1, arr2):
    """Split two parallel sequences into runs where ``arr1`` increases by 1.

    A new group starts at every position where ``arr1[i] - arr1[i - 1] != 1``.
    ``arr2`` is cut at exactly the same positions, so the i-th group of each
    result covers the same index range.

    Args:
        arr1: Sequence whose consecutive differences define the group
            boundaries.
        arr2: Sequence of the same length as ``arr1``, split alongside it.

    Returns:
        A tuple ``(groups1, groups2)`` of lists of lists. Empty input yields
        ``([[]], [[]])`` because ``np.split`` always returns at least one
        section.

    Raises:
        ValueError: If ``arr1`` and ``arr2`` differ in length; cutting them at
            shared positions would otherwise silently misalign the pairs.
    """
    keys = np.asarray(arr1)
    values = np.asarray(arr2)

    if keys.shape[:1] != values.shape[:1]:
        raise ValueError(
            f"arr1 and arr2 must have the same length, "
            f"got {keys.shape[:1]} and {values.shape[:1]}"
        )

    # Indices where the difference to the previous element is not 1; the +1
    # turns "position in the diff array" into "first index of the next group".
    split_indices = np.nonzero(np.diff(keys) != 1)[0] + 1

    # Cut both arrays at the same boundaries.
    grouped_keys = np.split(keys, split_indices)
    grouped_values = np.split(values, split_indices)

    return [list(g) for g in grouped_keys], [list(g) for g in grouped_values]


def groups_by_continuity_numpy(arr1):
    """Split a sequence into runs in which each element is the previous one + 1.

    Args:
        arr1: Sequence to split.

    Returns:
        A list of lists, one per run. Empty input yields ``[[]]`` because
        ``np.split`` always returns at least one section.
    """
    values = np.array(arr1)

    # A run ends wherever the step to the next element is not exactly 1;
    # shifting by one gives the index at which the following run starts.
    is_break = np.diff(values) != 1
    boundaries = np.nonzero(is_break)[0] + 1

    runs = []
    for chunk in np.split(values, boundaries):
        runs.append(list(chunk))
    return runs

if __name__ == "__main__":
    # Manual smoke test: print the grouping of a small sample for both helpers.
    sample_keys = [1, 9, 3, 4, 5, 6, 7]
    sample_values = [2, 2, 3, 4, 5, 6, 7]
    paired_groups = group_by_continuity_numpy(sample_keys, sample_values)
    print(paired_groups)
    single_groups = groups_by_continuity_numpy(sample_keys)
    print(single_groups)
99 changes: 99 additions & 0 deletions python/sglang/srt/disaggregation/ib_devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python
# coding:utf-8
"""
@author: nivic ybyang7
@license: Apache Licence
@file: ib_devices
@time: 2025/04/03
@contact: [email protected]
@site:
@software: PyCharm

# code is far away from bugs with the god animal protecting
I love animals. They taste delicious.
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ☃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ 神兽保佑 ┣┓
┃ 永无BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""
import os

# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit.
# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan.
# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna.
# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus.
# Vestibulum commodo. Ut rhoncus gravida arcu.
import pyverbs.device as d
import pynvml
def get_device_list(prefix, gpu_no=0, roce_version=2, port_num=1):
    """Map RDMA device names starting with ``prefix`` to their PCI addresses.

    Args:
        prefix: Only devices whose name starts with this string are considered.
        gpu_no: Unused; kept so existing callers keep working.
        roce_version: GID table index passed to ``query_gid``.
        port_num: Port whose GID table is inspected.

    Returns:
        A dict ``{device_name: pci_address}``. It is empty when no device is
        present or none qualifies. An empty dict (rather than the former empty
        list) keeps the return type uniform so callers can always use
        ``.items()``.
    """
    devices = d.get_device_list()
    if not devices:
        print("No IB devices")
        return {}

    device_list = {}
    for dev in devices:
        name = dev.name.decode()
        if not name.startswith(prefix):
            continue
        with d.Context(name=name) as ctx:
            gid_tbl_len = ctx.query_port(port_num).gid_tbl_len
            if gid_tbl_len > 0:
                # The result is discarded.
                # NOTE(review): presumably this is only a probe that raises
                # when the requested GID index is unusable -- confirm.
                ctx.query_gid(port_num=port_num, index=roce_version)
                # Read the PCI address from sysfs: the "device" entry is a
                # symlink whose last path component looks like "0000:19:00.0".
                dev_path = f"/sys/class/infiniband/{name}/device"
                if os.path.exists(dev_path):
                    pci_addr = os.readlink(dev_path).split("/")[-1]
                    device_list[name] = pci_addr

    return device_list

def get_gpu_pci_address(gpu_no):
    """Return the PCI bus id of the GPU with index ``gpu_no``, as reported by NVML.

    Args:
        gpu_no: NVML device index of the GPU.

    Returns:
        The ``busId`` field of the device's PCI info, as ``str``.

    Raises:
        pynvml.NVMLError: Propagated from NVML, e.g. for an invalid index.
    """
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_no)
        bus_id = pynvml.nvmlDeviceGetPciInfo(handle).busId
    finally:
        # Always balance nvmlInit(), even when one of the queries above raises.
        pynvml.nvmlShutdown()

    # Some pynvml releases expose C string fields as bytes; callers split the
    # result on ":" and therefore need text.
    if isinstance(bus_id, bytes):
        bus_id = bus_id.decode()
    return bus_id

def get_net_device_from_rdma(rdma_dev):
    """Return the network interface name that belongs to an RDMA device.

    The name is looked up in ``/sys/class/infiniband/<rdma_dev>/device/net``.

    Args:
        rdma_dev: RDMA device name.

    Returns:
        The interface name, or ``None`` when the sysfs directory is missing,
        unreadable, or empty. If several interfaces are listed, the
        alphabetically first one is returned so the choice does not depend on
        directory listing order.
    """
    net_path = f"/sys/class/infiniband/{rdma_dev}/device/net"
    try:
        names = sorted(os.listdir(net_path))
    except OSError:
        return None
    return names[0] if names else None
def normalize_pci_addr(pci_addr):
    """Bring a PCI address into the form ``dddd:bb:dd.f`` with lowercase hex.

    The domain is re-rendered as four hex digits, e.g.
    ``00000000:08:00.0 -> 0000:08:00.0``. The remaining parts are lower-cased
    so addresses from different sources compare equal as strings.

    Args:
        pci_addr: Address in ``domain:bus:device.function`` form.

    Returns:
        The normalized address, or ``pci_addr`` unchanged when it does not
        consist of exactly three ``:``-separated parts.

    Raises:
        ValueError: If the domain part is not a hexadecimal number.
    """
    parts = pci_addr.split(":")
    if len(parts) != 3:
        # Unknown layout: hand it back untouched.
        return pci_addr
    domain, bus, dev_fn = parts
    return f"{int(domain, 16):04x}:{bus}:{dev_fn}".lower()

def find_best_roce_for_gpu(gpu_no, prefix="mlx"):
    """Pick the RDMA device whose PCI address is closest to the given GPU.

    Only devices in the same PCI domain as the GPU (first five characters of
    the normalized address, e.g. ``"0000:"``) are considered. This compares
    PCI domains; it does not consult NUMA topology. Among those, the device
    with the smallest absolute difference in PCI bus number wins; on a tie
    the first one encountered is kept.

    Args:
        gpu_no: Index of the GPU.
        prefix: Name prefix of the RDMA devices to consider.

    Returns:
        A tuple ``(rdma_device_name, net_device_name)``. ``net_device_name``
        may be ``None`` when no network interface is found for the device.
        When no RDMA device qualifies, ``(None, None)`` is returned so callers
        can always unpack the result.
    """
    gpu_pci = normalize_pci_addr(get_gpu_pci_address(gpu_no))
    # `or {}` also covers an empty non-dict result from get_device_list.
    candidates = get_device_list(prefix) or {}

    best_rdma_dev = None
    min_distance = float("inf")

    for rdma_dev, raw_pci in candidates.items():
        rdma_pci = normalize_pci_addr(raw_pci)
        if rdma_pci[:5] != gpu_pci[:5]:
            # Different PCI domain: not a candidate.
            continue
        distance = abs(
            int(rdma_pci.split(":")[1], 16) - int(gpu_pci.split(":")[1], 16)
        )
        if distance < min_distance:
            min_distance = distance
            best_rdma_dev = rdma_dev

    if best_rdma_dev is None:
        return None, None
    return best_rdma_dev, get_net_device_from_rdma(best_rdma_dev)

if __name__ == "__main__":
    # Manual check: report the RDMA device and network interface chosen for
    # the GPU index queried below.
    gpu_no = 0
    best_match = find_best_roce_for_gpu(gpu_no)
    rdma_dev, net_dev = best_match
    print(f"GPU {gpu_no} 最亲和的 RDMA 设备: {rdma_dev}, 对应网卡: {net_dev}")
Loading