diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index ccb1cf08fd0..917518b0492 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -102,6 +102,7 @@
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_cpu_ids_by_node,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -211,6 +212,10 @@ def __init__(
         # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
+        # Initialize OpenMP thread binding for CPU
+        if self.device == "cpu":
+            self.init_threads_binding()
+
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
@@ -477,6 +482,15 @@ def init_torch_distributed(self):
             set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
 
         if not self.is_draft_worker:
+            if self.device == "cpu":
+                if _is_cpu_amx_available:
+                    # Bind OpenMP threads to CPU cores
+                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+                else:
+                    logger.warning(
+                        "init_cpu_threads_env is skipped because the Intel AMX backend is not available"
+                    )
+
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
@@ -1273,6 +1287,30 @@ def init_cuda_graphs(self):
             f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )
 
+    def init_threads_binding(self):
+        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        if omp_cpuids == "all":
+            cpu_ids_by_node = get_cpu_ids_by_node()
+            n_numa_node = len(cpu_ids_by_node)
+
+            assert self.tp_size <= n_numa_node, (
+                f"SGLANG_CPU_OMP_THREADS_BIND is not set; in this case, tp_size ({self.tp_size}) "
+                f"must be less than or equal to the number of NUMA nodes on the machine ({n_numa_node}). "
+                f"By default, each tp rank is bound to the physical cores of one NUMA node. "
+                f"For example, on a machine with 2 NUMA nodes where cores 0-31 are on node 0 and cores 32-63 are on node 1, "
+                f"running with -tp 2 binds tp rank 0 to cores 0-31 and tp rank 1 to cores 32-63, "
+                f"which is equivalent to setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
+                f"To use a tp_size larger than the number of NUMA nodes, set the CPU cores for each tp rank explicitly, e.g. SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 with -tp 4. "
+                f"To bind each tp rank to only part of a NUMA node, set e.g. SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 with -tp 2."
+            )
+            if self.tp_size < n_numa_node:
+                logger.warning(
+                    f"This machine has {n_numa_node} NUMA nodes, but tp_size is set to {self.tp_size}, so only {self.tp_size} NUMA nodes are used."
+                )
+            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
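To make the `SGLANG_CPU_OMP_THREADS_BIND` format above concrete: rank slices are separated by `|`, and judging from the patch, each rank's slice can be a range like `0-31` (as in the examples) or a comma-separated list like the default produced by `get_cpu_ids_by_node`. The sketch below is illustrative only and not part of this patch; `cpu_ids_for_rank` is a hypothetical helper showing how one tp rank's slice could be expanded into explicit core IDs. The patch itself passes the raw per-rank string straight to `torch.ops.sgl_kernel.init_cpu_threads_env`, which handles the actual parsing and thread binding.

```python
import os

def cpu_ids_for_rank(tp_rank: int) -> list[int]:
    # "0-15|16-31" -> rank 0 gets "0-15", rank 1 gets "16-31"
    rank_spec = os.environ["SGLANG_CPU_OMP_THREADS_BIND"].split("|")[tp_rank]
    ids: list[int] = []
    for part in rank_spec.split(","):
        if "-" in part:
            lo, hi = map(int, part.split("-"))
            ids.extend(range(lo, hi + 1))  # ranges are inclusive, lscpu-style
        else:
            ids.append(int(part))
    return ids

os.environ["SGLANG_CPU_OMP_THREADS_BIND"] = "0-15|16-31"
print(cpu_ids_for_rank(1))  # [16, 17, ..., 31]
```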
+                )
+            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 8a91c2fc4d0..67851c8c3c8 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -40,6 +40,7 @@
 import time
 import traceback
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from functools import lru_cache
@@ -2545,3 +2546,69 @@ def align(x: int, y: int) -> int:
 # COPIED FROM DeepGEMM
 def ceil_div(x: int, y: int) -> int:
     return (x + y - 1) // y
+
+
+def parse_lscpu_topology():
+    try:
+        # Query lscpu for the topology as CPU,Core,Socket,Node tuples
+        output = subprocess.check_output(
+            ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Failed to run 'lscpu': {e}")
+
+    # Parse only data lines (skip comment lines starting with '#')
+    cpu_info = []
+    for line in output.splitlines():
+        if not line.startswith("#"):
+            cpu, core, socket, node = map(int, line.strip().split(","))
+            cpu_info.append((cpu, core, socket, node))
+
+    # e.g. [(0,0,0,0), (1,1,0,0), ..., (43,43,0,1), ..., (256,0,0,0), ...]
+    return cpu_info
+
+
+def get_physical_cpus_by_numa():
+    cpu_info = parse_lscpu_topology()
+
+    # Map NUMA node -> (core_id, socket) -> cpu_id so that each physical core
+    # is counted once even when SMT exposes multiple logical CPUs, e.g.
+    # 0: {(0,0): 0, (1,0): 1, ...}
+    # 5: {(214,1): 214, (215,1): 215}
+    physical_by_node = defaultdict(dict)  # node -> core_id -> cpu_id
+
+    for cpu, core, socket, node in cpu_info:
+        key = (core, socket)
+        if key not in physical_by_node[node]:
+            # Keep the first logical CPU seen for each physical core
+            physical_by_node[node][key] = cpu
+
+    # CPUs that the current process is allowed to run on
+    cpus_allowed_list = psutil.Process().cpu_affinity()
+
+    # Collect the allowed physical CPUs for each node, e.g.
+    # 0: [0,1,2,...,42]
+    # ...
+    # 2: [86,87,...,127]
+    # ...
+    # 5: [214,215,...,255]
+    node_to_cpus = {}
+    for node, core_to_cpu in physical_by_node.items():
+        cpus = sorted(core_to_cpu.values())
+        allowed_cpus = set(cpus).intersection(cpus_allowed_list)
+        node_to_cpus[node] = allowed_cpus
+
+    return node_to_cpus
+
+
+# Only physical cores are used; logical (SMT) cores are excluded.
+def get_cpu_ids_by_node():
+    node_to_cpus = get_physical_cpus_by_numa()
+    # Sort by NUMA node index and render each node's CPUs as a
+    # comma-separated string
+    cpu_ids = [
+        ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus)
+    ]
+
+    # e.g. ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
+    return cpu_ids
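As a usage note, the default ("all") path in `init_threads_binding` resolves roughly as in the sketch below. This is illustrative only and not part of the patch; it assumes sglang is installed and, for the example values, a machine with 2 NUMA nodes whose physical cores are 0-31 and 32-63.

```python
# Sketch of how the default binding resolves when
# SGLANG_CPU_OMP_THREADS_BIND is unset.
from sglang.srt.utils import get_cpu_ids_by_node

cpu_ids_by_node = get_cpu_ids_by_node()
# e.g. ['0,1,...,31', '32,33,...,63'] -- one string of physical core IDs per NUMA node
for tp_rank, cpuid in enumerate(cpu_ids_by_node):
    # Each tp rank is assigned the cores of its own NUMA node
    print(f"tp rank {tp_rank} binds OpenMP threads to cores {cpuid}")
```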