diff --git a/python/sglang/srt/disaggregation/conn.py b/python/sglang/srt/disaggregation/conn.py index f61add327d1..a2054a4dfbd 100644 --- a/python/sglang/srt/disaggregation/conn.py +++ b/python/sglang/srt/disaggregation/conn.py @@ -74,7 +74,7 @@ class KVPoll: class KVManager: # TODO: make it general and support multiple transfer backend before merging def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode): - self.engine = MooncakeTransferEngine() + self.engine = MooncakeTransferEngine(args.gpu_id) self.kv_args = args self.disaggregation_mode = disaggregation_mode self.request_pool: RequestPoolType = {} diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index e0918a083e0..a34ffea0136 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -99,6 +99,7 @@ def __init__( def _init_kv_manager(self) -> KVManager: kv_args = KVArgs() kv_args.engine_rank = self.tp_rank + kv_args.gpu_id = self.scheduler.gpu_id kv_data_ptrs, kv_data_lens, kv_item_lens = ( self.token_to_kv_pool.get_contiguous_buf_infos() ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 25ab54bb818..1e586bac36f 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -56,6 +56,7 @@ def __init__( tp_size: int, bootstrap_port: int, gloo_group: ProcessGroup, + scheduler: Scheduler, ): self.token_to_kv_pool = token_to_kv_pool self.aux_dtype = aux_dtype @@ -68,6 +69,7 @@ def __init__( self.queue: List[Req] = [] self.gloo_group = gloo_group self.bootstrap_port = bootstrap_port + self.scheduler = scheduler def allocate_token_id(self, idx: int, token_id: int): assert token_id >= 0, f"token_id: {token_id} is negative" @@ -84,6 +86,7 @@ def _init_kv_manager(self) -> KVManager: kv_args.kv_data_ptrs = kv_data_ptrs kv_args.kv_data_lens = kv_data_lens kv_args.kv_item_lens = kv_item_lens + kv_args.gpu_id = self.scheduler.gpu_id # Define req -> input ids buffer kv_args.aux_data_ptrs = [ diff --git a/python/sglang/srt/disaggregation/rdma_device_utils.py b/python/sglang/srt/disaggregation/rdma_device_utils.py new file mode 100644 index 00000000000..bfee3a087b2 --- /dev/null +++ b/python/sglang/srt/disaggregation/rdma_device_utils.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# coding:utf-8 +""" +@author: nivic ybyang7 +@license: Apache Licence +@file: ib_devices +@time: 2025/04/03 +@contact: ybyang7@iflytek.com +@site: +@software: PyCharm + +# Code is far away from bugs with the god animal protecting + I love animals. They taste delicious. + ┏┓ ┏┓ + ┏┛┻━━━┛┻┓ + ┃ ☃ ┃ + ┃ ┳┛ ┗┳ ┃ + ┃ ┻ ┃ + ┗━┓ ┏━┛ + ┃ ┗━━━┓ + ┃ God Bless ┣┓ + ┃ No BUG! ┏┛ + ┗┓┓┏━┳┓┏┛ + ┃┫┫ ┃┫┫ + ┗┻┛ ┗┻┛ +""" +import os + +# Copyright (c) 2022. Lorem ipsum dolor sit amet, consectetur adipiscing elit. +# Morbi non lorem porttitor neque feugiat blandit. Ut vitae ipsum eget quam lacinia accumsan. +# Etiam sed turpis ac ipsum condimentum fringilla. Maecenas magna. +# Proin dapibus sapien vel ante. Aliquam erat volutpat. Pellentesque sagittis ligula eget metus. +# Vestibulum commodo. Ut rhoncus gravida arcu. +import pyverbs.device as d +import pynvml + + +def get_device_list(prefix, gpu_no=0, roce_version=2, port_num=1): + """ + Get a list of RDMA devices matching the specified prefix. + + Args: + prefix (str): Device name prefix to filter (e.g., 'mlx') + gpu_no (int): GPU device number (default: 0) + roce_version (int): RoCE version to use (default: 2) + port_num (int): Port number to query (default: 1) + + Returns: + dict: Dictionary mapping RDMA device names to their PCI addresses + """ + lst = d.get_device_list() + if len(lst) == 0: + print("No IB devices") + return [] + device_list = {} + for dev in lst: + if dev.name.decode().startswith(prefix): + with d.Context(name=dev.name.decode()) as ctx: + gid_tbl_len = ctx.query_port(port_num).gid_tbl_len + if gid_tbl_len > 0: + ctx.query_gid(port_num=port_num, index=roce_version) + # Get PCI address from sysfs + dev_path = f"/sys/class/infiniband/{dev.name.decode()}/device" + if os.path.exists(dev_path): + pci_addr = os.readlink(dev_path).split("/")[-1] # Format like "0000:19:00.0" + device_list[dev.name.decode()] = pci_addr + + return device_list + + +def get_gpu_pci_address(gpu_no): + """ + Get the PCI address of a specified GPU device. + + Args: + gpu_no (int): GPU device number + + Returns: + str: PCI address of the GPU device + """ + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_no) + pci_info = pynvml.nvmlDeviceGetPciInfo(handle) + pynvml.nvmlShutdown() + return pci_info.busId + + +def get_net_device_from_rdma(rdma_dev): + """ + Get the network interface name corresponding to a RoCE device. + + Args: + rdma_dev (str): RDMA device name + + Returns: + str: Network interface name or None if not found + """ + net_path = f"/sys/class/infiniband/{rdma_dev}/device/net" + if os.path.exists(net_path): + return os.listdir(net_path)[0] # Read network interface name + return None + + +def normalize_pci_addr(pci_addr): + """ + Standardize PCI address format. + + Args: + pci_addr (str): PCI address to normalize + + Returns: + str: Normalized PCI address in format "0000:08:00.0" + """ + parts = pci_addr.split(":") + if len(parts) == 3: # Format like "00000000:08:00.0" + return f"{int(parts[0], 16):04x}:{parts[1]}:{parts[2]}" # Convert to "0000:08:00.0" + return pci_addr # Return original format + + +def find_best_rdma_device_for_gpu(gpu_no, prefix="mlx"): + """ + Find the most affinity RoCE network card for a given GPU. + + Args: + gpu_no (int): GPU device number + prefix (str): RDMA device name prefix (default: "mlx") + + Returns: + tuple: (best_rdma_dev, net_dev) containing the best RDMA device and its network interface + """ + gpu_pci = normalize_pci_addr(get_gpu_pci_address(gpu_no)) + roce_devices = {k: normalize_pci_addr(v) for k, v in get_device_list(prefix).items()} + + best_rdma_dev = None + min_distance = float("inf") + + for rdma_dev, rdma_pci in roce_devices.items(): + if rdma_pci[:5] == gpu_pci[:5]: # Ensure same NUMA node + distance = abs(int(rdma_pci.split(":")[1], 16) - int(gpu_pci.split(":")[1], 16)) + if distance < min_distance: + min_distance = distance + best_rdma_dev = rdma_dev + + if best_rdma_dev: + net_dev = get_net_device_from_rdma(best_rdma_dev) + return best_rdma_dev, net_dev + + +if __name__ == '__main__': + gpu_no = 0 # GPU device number to query + rdma_dev, net_dev = find_best_roce_for_gpu(gpu_no) + print(f"GPU {gpu_no} most affinity RDMA device: {rdma_dev}, corresponding network interface: {net_dev}") diff --git a/python/sglang/srt/disaggregation/transfer_engine/mooncake.py b/python/sglang/srt/disaggregation/transfer_engine/mooncake.py index bdba72579a4..b0ecfb67243 100644 --- a/python/sglang/srt/disaggregation/transfer_engine/mooncake.py +++ b/python/sglang/srt/disaggregation/transfer_engine/mooncake.py @@ -4,6 +4,9 @@ import uuid from dataclasses import dataclass +from sglang.srt.utils import get_local_ip_by_remote +from sglang.srt.disaggregation.rdma_device_utils import find_best_rdma_device_for_gpu + logger = logging.getLogger(__name__) @@ -27,19 +30,36 @@ def from_file(file_path: str) -> "MooncakeTransferEngineConfig": ) @staticmethod - def load_from_env() -> "MooncakeTransferEngineConfig": + def load_config(gpu_id=None) -> "MooncakeTransferEngineConfig": """Load config from a file specified in the environment variable.""" config_file_path = os.getenv("MOONCAKE_CONFIG_PATH") if config_file_path is None: - raise ValueError( - "The environment variable 'MOONCAKE_CONFIG_PATH' is not set." - ) + logger.info("No config set for 'MOONCAKE_CONFIG_PATH', specified env is preferred") + return MooncakeTransferEngineConfig.auto_config(gpu_id) return MooncakeTransferEngineConfig.from_file(config_file_path) + @staticmethod + def load_auto_config(gpu_id) -> "MooncakeTransferEngineConfig": + """Load config from a file specified in the environment variable.""" + metadata_server = os.getenv("MOONCAKE_METADATA_SERVER", None) + if metadata_server is None: + raise ValueError( + "The environment variable 'MOONCAKE_METADATA_SERVER' is not set." + ) + local_hostname = os.getenv("MOONCAKE_LOCAL_HOSTNAME", default=get_local_ip_by_remote()) + protocol = os.getenv("MOONCAKE_PROTOCOL", default="rdma") + default_ib_device, _ = find_best_rdma_device_for_gpu(gpu_id) + device_name = os.getenv("MOONCAKE_RDMA_DEVICE_NAME", default=default_ib_device) + return MooncakeTransferEngineConfig( + local_hostname=local_hostname, + metadata_server=metadata_server, + protocol=protocol, + device_name=device_name, + ) class MooncakeTransferEngine: - def __init__(self): + def __init__(self, gpu_id=0): try: from mooncake.engine import TransferEngine except ImportError as e: @@ -52,7 +72,7 @@ def __init__(self): self.engine = TransferEngine() try: - self.config = MooncakeTransferEngineConfig.load_from_env() + self.config = MooncakeTransferEngineConfig.load_auto_config(gpu_id) logger.info("Mooncake Configuration loaded successfully.") except ValueError as e: logger.error(e) @@ -61,8 +81,6 @@ def __init__(self): logger.error("An error occurred while loading the configuration: %s", exc) raise - self.config = MooncakeTransferEngineConfig.load_from_env() - session_suffix = "_" + str(uuid.uuid4()) self.session_id = self.config.local_hostname + session_suffix self.initialize( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 359573fc692..87a6b161e59 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -590,6 +590,7 @@ def init_disaggregation(self): tp_size=self.tp_size, bootstrap_port=self.server_args.disaggregation_bootstrap_port, gloo_group=self.tp_worker.get_attention_tp_cpu_group(), + scheduler=self, ) # The prefill requests that are in the middle of kv sending self.disagg_prefill_infight_queue: List[Req] = [] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d68fa489bee..9501792ca13 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1828,3 +1828,13 @@ def fast_topk(values, topk, dim): else: # Use topk for efficiency with larger k values return torch.topk(values, topk, dim=dim) + +def get_local_ip_by_remote(addr="8.8.8.8:8888"): + """ + Get Local IP Connecting Remote Addr + """ + + host, port = addr.split(":") + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.connect((host, int(port.strip()))) # connecting fake server to get ip host + return s.getsockname()[0] \ No newline at end of file