From 92c9d64c73b6fd197afc9e0a0f6df61c535c5005 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Sun, 16 Feb 2025 20:42:10 +0800 Subject: [PATCH 01/29] bind OpenMP threads to CPU cores (#4) --- python/sglang/srt/model_executor/model_runner.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8b9a367f492..2cc9d450ff9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -211,6 +211,13 @@ def __init__( # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) + # Init OpenMP threads binding + omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all") + if omp_cpuids == "all": + self.local_omp_cpuid = "all" + else: + self.local_omp_cpuid = omp_cpuids.split("|")[tp_rank] + # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -477,6 +484,10 @@ def init_torch_distributed(self): set_mscclpp_all_reduce(self.server_args.enable_mscclpp) if not self.is_draft_worker: + # Bind OpenMP threads to CPU cores + if self.device == "cpu" and self.local_omp_cpuid != "all": + torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + # Only initialize the distributed environment on the target model worker. init_distributed_environment( backend=backend, From e7a8212f6bccce2ee04cc8fee0b6eb0ddfab7642 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Tue, 20 May 2025 12:46:29 +0800 Subject: [PATCH 02/29] De-couple vLLM CPU and remove dependency of `sudo apt-get install libnuma-dev` (#74) * port utils.cpp for numa binding from vllm into sglang * fix build * add pybind for init_cpu_threads_env * use conda prefix to find libnuma # how to ensure libnuma is installed via conda * replace init_cpu_threads_env in vllm with that in sgl-kernel * add vllm into srt_cpu * fix format --- .../sglang/srt/model_executor/model_runner.py | 8 +- sgl-kernel/csrc/cpu/numa_utils.cpp | 90 +++++++++++++++++++ sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 7 ++ 3 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 sgl-kernel/csrc/cpu/numa_utils.cpp diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2cc9d450ff9..91f72f7e41a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -484,9 +484,11 @@ def init_torch_distributed(self): set_mscclpp_all_reduce(self.server_args.enable_mscclpp) if not self.is_draft_worker: - # Bind OpenMP threads to CPU cores - if self.device == "cpu" and self.local_omp_cpuid != "all": - torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + if self.device == "cpu": + import sgl_kernel.common_ops + # Bind OpenMP threads to CPU cores + if self.local_omp_cpuid != "all": + sgl_kernel.common_ops.init_cpu_threads_env(self.local_omp_cpuid) # Only initialize the distributed environment on the target model worker. init_distributed_environment( diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp new file mode 100644 index 00000000000..fe2ac083fed --- /dev/null +++ b/sgl-kernel/csrc/cpu/numa_utils.cpp @@ -0,0 +1,90 @@ +#include +#include +#include +#include + +#include "common.h" + +std::string init_cpu_threads_env(const std::string& cpu_ids) { + bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); + TORCH_CHECK(omp_cpu_mask->size > 0); + std::vector omp_cpu_ids; + omp_cpu_ids.reserve(omp_cpu_mask->size); + + constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); + + for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { + unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; + int i = 0; + while (group_mask) { + if (group_mask & 1) { + omp_cpu_ids.emplace_back(offset + i); + } + ++i; + group_mask >>= 1; + } + } + + // Memory node binding + if (numa_available() != -1) { + int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); + bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); + bitmask* src_mask = numa_get_membind(); + + int pid = getpid(); + + // move all existing pages to the specified numa node. + *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); + int page_num = numa_migrate_pages(pid, src_mask, mask); + if (page_num == -1) { + TORCH_CHECK(false, + "numa_migrate_pages failed. errno: " + std::to_string(errno)); + } + + // restrict memory allocation node. + numa_set_membind(mask); + numa_set_strict(1); + } + + // OMP threads binding + omp_set_num_threads((int)omp_cpu_ids.size()); + at::set_num_threads((int)omp_cpu_ids.size()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), at::get_num_threads()); + TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); + + std::vector> thread_core_mapping; + thread_core_mapping.reserve(omp_cpu_ids.size()); + omp_lock_t writelock; + omp_init_lock(&writelock); + +#pragma omp parallel for schedule(static, 1) + for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(omp_cpu_ids[i], &mask); + int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); + if (ret == -1) { + TORCH_CHECK(false, + "sched_setaffinity failed. errno: " + std::to_string(errno)); + } + + omp_set_lock(&writelock); + thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + omp_unset_lock(&writelock); + } + + omp_destroy_lock(&writelock); + + numa_free_nodemask(omp_cpu_mask); + + std::stringstream ss; + ss << "OMP threads binding of Process " << getpid() << ":\n"; + std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), + [](auto&& a, auto&& b) { return a.second < b.second; }); + for (auto&& item : thread_core_mapping) { + ss << "\t" + << "OMP tid: " << item.first << ", core " << item.second << "\n"; + } + + return ss.str(); +} diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 7c26c354fd6..cfcbd813d25 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -227,6 +227,9 @@ std::tuple rotary_embedding_cpu( at::Tensor& cos_sin_cache, bool is_neox); +// CPU and memory binding +std::string init_cpu_threads_env(const std::string& cpu_ids); + TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // activation m.def("silu_and_mul_cpu(Tensor input) -> Tensor"); @@ -353,6 +356,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, " "bool is_neox) -> (Tensor, Tensor)"); m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu); + + // CPU and memory binding + m.def("init_cpu_threads_env(str cpu_ids) -> str"); + m.impl("init_cpu_threads_env", torch::kCPU, &init_cpu_threads_env); } REGISTER_EXTENSION(common_ops) From 35ac92a81e190c8f70422a2d3f240e8518d331b6 Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Wed, 21 May 2025 14:55:59 +0800 Subject: [PATCH 03/29] Automatically bind threads if `SGLANG_CPU_OMP_THREADS_BIND` is not set (#77) * set local_omp_cpuid automatically * set self.local_omp_cpuid in __init__ * refine warning message * add more comments for example output of util functions * add try except for lscpu * refine warning message --- python/sglang/srt/cpu_utils.py | 80 +++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 23 +++++- 2 files changed, 100 insertions(+), 3 deletions(-) create mode 100644 python/sglang/srt/cpu_utils.py diff --git a/python/sglang/srt/cpu_utils.py b/python/sglang/srt/cpu_utils.py new file mode 100644 index 00000000000..e689aee69fc --- /dev/null +++ b/python/sglang/srt/cpu_utils.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import subprocess +from collections import defaultdict + +def parse_lscpu_topology(): + try: + # Get CPU topology: CPU,Core,Socket,Node + output = subprocess.check_output( + ["lscpu", "-p=CPU,Core,Socket,Node"], text=True + ) + except Exception as e: + raise RuntimeError(f"Unexpected error running 'lscpu': {e}") + + # Parse only data lines (skip comments) + cpu_info = [] + for line in output.splitlines(): + if not line.startswith("#"): + cpu, core, socket, node = map(int, line.strip().split(",")) + cpu_info.append((cpu, core, socket, node)) + + # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...] + return cpu_info + + +def get_physical_cpus_by_numa(): + cpu_info = parse_lscpu_topology() + + # Map NUMA node -> set of (core_id, socket) to avoid duplicates + # 0: {(0,0): 0, (1, 0): 1,...} + # ... + # 5: {(214,1): 214, (215,1): 215} + physical_by_node = defaultdict(dict) # node -> core_id -> cpu_id + + for cpu, core, socket, node in cpu_info: + key = (core, socket) + if key not in physical_by_node[node]: + physical_by_node[node][ + key + ] = cpu # pick first CPU seen for that physical core + + # Convert to list of physical CPUs per node + # 0: [0,1,2,...,42] + # ... + # 2: [86,87,...,127] + # ... + # 5: [214,215,...,255] + node_to_cpus = {} + for node, core_to_cpu in physical_by_node.items(): + cpus = sorted(core_to_cpu.values()) + node_to_cpus[node] = cpus + + return node_to_cpus + + +# Compress sorted list of integers into range strings like 0-2,3,4-6 +def compress_ranges(cpu_list): + if not cpu_list: + return "" + ranges = [] + start = prev = cpu_list[0] + for cpu in cpu_list[1:]: + if cpu == prev + 1: + prev = cpu + else: + ranges.append(f"{start}-{prev}" if start != prev else str(start)) + start = prev = cpu + ranges.append(f"{start}-{prev}" if start != prev else str(start)) + return ",".join(ranges) + + +# Only physical cores are used. Logical cores are excluded. +def get_cpu_ids_by_node(): + node_to_cpus = get_physical_cpus_by_numa() + # Sort by NUMA node index + cpu_ids = [ + compress_ranges(sorted(node_to_cpus[node])) for node in sorted(node_to_cpus) + ] + # ['0-42', '43-85', '86-127', '128-170', '171-213', '214-255'] + return cpu_ids diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 91f72f7e41a..118b48b58a8 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -30,6 +30,7 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS +from sglang.srt.cpu_utils import get_cpu_ids_by_node from sglang.srt.distributed import ( get_tp_group, get_world_group, @@ -214,7 +215,24 @@ def __init__( # Init OpenMP threads binding omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all") if omp_cpuids == "all": - self.local_omp_cpuid = "all" + cpu_ids_by_node = get_cpu_ids_by_node() + n_numa_node = len(cpu_ids_by_node) + + assert self.tp_size <= n_numa_node, ( + f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, " + f"tp_size {self.tp_size} should be smaller than number of numa node on the machine {n_numa_node}. " + f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. " + f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, " + f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. " + f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. " + f"If you do need tp_size to be large than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. " + f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2." + ) + if self.tp_size < n_numa_node: + logger.warning( + f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so use only {self.tp_size} numa nodes are used." + ) + self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank] else: self.local_omp_cpuid = omp_cpuids.split("|")[tp_rank] @@ -487,8 +505,7 @@ def init_torch_distributed(self): if self.device == "cpu": import sgl_kernel.common_ops # Bind OpenMP threads to CPU cores - if self.local_omp_cpuid != "all": - sgl_kernel.common_ops.init_cpu_threads_env(self.local_omp_cpuid) + sgl_kernel.common_ops.init_cpu_threads_env(self.local_omp_cpuid) # Only initialize the distributed environment on the target model worker. init_distributed_environment( From 1ac1ee5f0a05bbce41d92981df34861bc9b72d89 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 22 May 2025 14:11:14 +0000 Subject: [PATCH 04/29] use torch.ops --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 118b48b58a8..c15ff6625b1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -503,9 +503,9 @@ def init_torch_distributed(self): if not self.is_draft_worker: if self.device == "cpu": - import sgl_kernel.common_ops + import sgl_kernel # Bind OpenMP threads to CPU cores - sgl_kernel.common_ops.init_cpu_threads_env(self.local_omp_cpuid) + torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) # Only initialize the distributed environment on the target model worker. init_distributed_environment( From d6ffa4b4d99b750c16dc59a72458ef69b3cfa5a7 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 22 May 2025 14:21:27 +0000 Subject: [PATCH 05/29] add numa related change in CMakeList --- sgl-kernel/csrc/cpu/CMakeLists.txt | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt index 355a6ab4764..bed16b40478 100755 --- a/sgl-kernel/csrc/cpu/CMakeLists.txt +++ b/sgl-kernel/csrc/cpu/CMakeLists.txt @@ -38,6 +38,13 @@ else() endif() link_directories(${PLAT_LIB_DIR}) +# Conda library path support +if(DEFINED ENV{CONDA_PREFIX}) + set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib") + message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}") + link_directories(${CONDA_LIB_DIR}) +endif() + file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") add_compile_options( @@ -48,7 +55,7 @@ add_compile_options( ) Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) -target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES}) +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} numa) target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS}) install(TARGETS common_ops From 2b4f8923fe0e255c11b575d063d34349ad089a4a Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 22 May 2025 14:31:47 +0000 Subject: [PATCH 06/29] use CatchAll as dispatch key for init_cpu_threads_env since it does not have tensor input --- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index cfcbd813d25..5245b468a75 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -359,7 +359,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // CPU and memory binding m.def("init_cpu_threads_env(str cpu_ids) -> str"); - m.impl("init_cpu_threads_env", torch::kCPU, &init_cpu_threads_env); +} + +TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) { + m.impl("init_cpu_threads_env", init_cpu_threads_env); } REGISTER_EXTENSION(common_ops) From 8daac4504f9657d9333724402f782b452f182234 Mon Sep 17 00:00:00 2001 From: y Date: Thu, 22 May 2025 07:58:43 +0000 Subject: [PATCH 07/29] move binding code to init_threads_binding --- .../sglang/srt/model_executor/model_runner.py | 51 ++++++++++--------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c15ff6625b1..81c912e762f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -212,29 +212,9 @@ def __init__( # CPU offload set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3)) - # Init OpenMP threads binding - omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all") - if omp_cpuids == "all": - cpu_ids_by_node = get_cpu_ids_by_node() - n_numa_node = len(cpu_ids_by_node) - - assert self.tp_size <= n_numa_node, ( - f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, " - f"tp_size {self.tp_size} should be smaller than number of numa node on the machine {n_numa_node}. " - f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. " - f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, " - f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. " - f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. " - f"If you do need tp_size to be large than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. " - f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2." - ) - if self.tp_size < n_numa_node: - logger.warning( - f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so use only {self.tp_size} numa nodes are used." - ) - self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank] - else: - self.local_omp_cpuid = omp_cpuids.split("|")[tp_rank] + # Init OpenMP threads binding for CPU + if self.device == "cpu": + self.init_threads_binding() # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() @@ -504,6 +484,7 @@ def init_torch_distributed(self): if not self.is_draft_worker: if self.device == "cpu": import sgl_kernel + # Bind OpenMP threads to CPU cores torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) @@ -1300,6 +1281,30 @@ def init_cuda_graphs(self): f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) + def init_threads_binding(self): + omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all") + if omp_cpuids == "all": + cpu_ids_by_node = get_cpu_ids_by_node() + n_numa_node = len(cpu_ids_by_node) + + assert self.tp_size <= n_numa_node, ( + f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, " + f"tp_size {self.tp_size} should be smaller than number of numa node on the machine {n_numa_node}. " + f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. " + f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, " + f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. " + f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. " + f"If you do need tp_size to be large than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. " + f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2." + ) + if self.tp_size < n_numa_node: + logger.warning( + f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so use only {self.tp_size} numa nodes are used." + ) + self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank] + else: + self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank] + def apply_torch_tp(self): logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.") from sglang.srt.model_parallel import tensor_parallel From 04cafa000cbae79f8abcdf24cad7b5ab78d0ff4d Mon Sep 17 00:00:00 2001 From: y Date: Thu, 22 May 2025 07:59:12 +0000 Subject: [PATCH 08/29] fix format --- python/sglang/srt/cpu_utils.py | 1 + sgl-kernel/csrc/cpu/numa_utils.cpp | 13 ++++++------- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/cpu_utils.py b/python/sglang/srt/cpu_utils.py index e689aee69fc..fd457f91b6b 100644 --- a/python/sglang/srt/cpu_utils.py +++ b/python/sglang/srt/cpu_utils.py @@ -3,6 +3,7 @@ import subprocess from collections import defaultdict + def parse_lscpu_topology(): try: # Get CPU topology: CPU,Core,Socket,Node diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp index fe2ac083fed..e5eb64bcfbd 100644 --- a/sgl-kernel/csrc/cpu/numa_utils.cpp +++ b/sgl-kernel/csrc/cpu/numa_utils.cpp @@ -1,7 +1,8 @@ #include +#include #include + #include -#include #include "common.h" @@ -37,8 +38,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); int page_num = numa_migrate_pages(pid, src_mask, mask); if (page_num == -1) { - TORCH_CHECK(false, - "numa_migrate_pages failed. errno: " + std::to_string(errno)); + TORCH_CHECK(false, "numa_migrate_pages failed. errno: " + std::to_string(errno)); } // restrict memory allocation node. @@ -64,8 +64,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { CPU_SET(omp_cpu_ids[i], &mask); int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); if (ret == -1) { - TORCH_CHECK(false, - "sched_setaffinity failed. errno: " + std::to_string(errno)); + TORCH_CHECK(false, "sched_setaffinity failed. errno: " + std::to_string(errno)); } omp_set_lock(&writelock); @@ -79,8 +78,8 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { std::stringstream ss; ss << "OMP threads binding of Process " << getpid() << ":\n"; - std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), - [](auto&& a, auto&& b) { return a.second < b.second; }); + std::sort( + thread_core_mapping.begin(), thread_core_mapping.end(), [](auto&& a, auto&& b) { return a.second < b.second; }); for (auto&& item : thread_core_mapping) { ss << "\t" << "OMP tid: " << item.first << ", core " << item.second << "\n"; diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 5245b468a75..17e2f824c8f 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -362,7 +362,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { } TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) { - m.impl("init_cpu_threads_env", init_cpu_threads_env); + m.impl("init_cpu_threads_env", init_cpu_threads_env); } REGISTER_EXTENSION(common_ops) From e9719ad2cc231d4a9c016584d35b1422e3d29b3a Mon Sep 17 00:00:00 2001 From: Chunyuan WU Date: Fri, 23 May 2025 09:28:48 +0800 Subject: [PATCH 09/29] replace gettid with syscall (#78) --- sgl-kernel/csrc/cpu/numa_utils.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp index e5eb64bcfbd..27c66b689de 100644 --- a/sgl-kernel/csrc/cpu/numa_utils.cpp +++ b/sgl-kernel/csrc/cpu/numa_utils.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include @@ -68,7 +70,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { } omp_set_lock(&writelock); - thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + thread_core_mapping.emplace_back(syscall(SYS_gettid), omp_cpu_ids[i]); omp_unset_lock(&writelock); } From 6e6c97a27671adcb8ef92dd191a9b05762d063b7 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 23 May 2025 14:24:02 +0000 Subject: [PATCH 10/29] fix format --- sgl-kernel/csrc/cpu/numa_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp index 27c66b689de..fe391534048 100644 --- a/sgl-kernel/csrc/cpu/numa_utils.cpp +++ b/sgl-kernel/csrc/cpu/numa_utils.cpp @@ -1,7 +1,7 @@ #include +#include #include #include -#include #include #include From 40968a0ee361f3e76b064a4066d5b3991072659d Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 23 May 2025 14:45:43 +0000 Subject: [PATCH 11/29] add check if libnuma is in conda dir --- sgl-kernel/csrc/cpu/CMakeLists.txt | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt index bed16b40478..d83aaeb5d3d 100755 --- a/sgl-kernel/csrc/cpu/CMakeLists.txt +++ b/sgl-kernel/csrc/cpu/CMakeLists.txt @@ -43,6 +43,15 @@ if(DEFINED ENV{CONDA_PREFIX}) set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib") message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}") link_directories(${CONDA_LIB_DIR}) + + # Look for libnuma in Conda's lib directory + find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}") + if(NUMA_LIB) + message(STATUS "Found libnuma: ${NUMA_LIB}") + else() + message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n" + "Please install it using: conda install libnuma numactl\n") + endif() endif() file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") @@ -55,7 +64,7 @@ add_compile_options( ) Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) -target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} numa) +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB}) target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS}) install(TARGETS common_ops From 6319532ce56bcc96d208c5ef4acf014e5d6ccd84 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Fri, 23 May 2025 15:21:12 +0000 Subject: [PATCH 12/29] add UT --- test/srt/cpu/test_binding.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 test/srt/cpu/test_binding.py diff --git a/test/srt/cpu/test_binding.py b/test/srt/cpu/test_binding.py new file mode 100644 index 00000000000..cbd4b37a59e --- /dev/null +++ b/test/srt/cpu/test_binding.py @@ -0,0 +1,26 @@ +import re +import unittest + +import sgl_kernel +import torch +kernel = torch.ops.sgl_kernel + +from sglang.test.test_utils import CustomTestCase + + +class TestGemm(CustomTestCase): + def test_binding(self): + start_id = 0 + n_cpu = 6 + end_id = start_id + n_cpu - 1 + cpu_ids = f"{start_id}-{end_id}" + output = kernel.init_cpu_threads_env(cpu_ids) + + bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) + self.assertEqual(len(bindings), n_cpu) + + expected_cores = list(map(str, range(start_id, n_cpu))) + self.assertEqual(bindings, expected_cores) + +if __name__ == "__main__": + unittest.main() From 9b2f0011c856b543b8080442334e5b9ce26940d9 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Tue, 3 Jun 2025 09:53:35 +0000 Subject: [PATCH 13/29] fix UT format --- test/srt/cpu/test_binding.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/srt/cpu/test_binding.py b/test/srt/cpu/test_binding.py index cbd4b37a59e..75de255400e 100644 --- a/test/srt/cpu/test_binding.py +++ b/test/srt/cpu/test_binding.py @@ -3,6 +3,7 @@ import sgl_kernel import torch + kernel = torch.ops.sgl_kernel from sglang.test.test_utils import CustomTestCase @@ -15,12 +16,13 @@ def test_binding(self): end_id = start_id + n_cpu - 1 cpu_ids = f"{start_id}-{end_id}" output = kernel.init_cpu_threads_env(cpu_ids) - + bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) self.assertEqual(len(bindings), n_cpu) expected_cores = list(map(str, range(start_id, n_cpu))) self.assertEqual(bindings, expected_cores) + if __name__ == "__main__": unittest.main() From e3057c109ae8fe0d0883219dbb0480070d656cda Mon Sep 17 00:00:00 2001 From: y Date: Wed, 28 May 2025 13:37:45 +0000 Subject: [PATCH 14/29] add assert when calling init_cpu_threads_env --- python/sglang/srt/model_executor/model_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 81c912e762f..9d2507ab9a0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -486,6 +486,9 @@ def init_torch_distributed(self): import sgl_kernel # Bind OpenMP threads to CPU cores + assert ( + cpu_has_amx_support() + ), "init_cpu_threads_env failed since intel amx backend is not available" torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) # Only initialize the distributed environment on the target model worker. From 6bdd01dd4583b2c2f54ad813f3f68e78cf932fa6 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Jun 2025 13:18:19 +0000 Subject: [PATCH 15/29] use TORCH_WARN when checking return of numa_migrate_pages --- sgl-kernel/csrc/cpu/numa_utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp index fe391534048..2699d0e236d 100644 --- a/sgl-kernel/csrc/cpu/numa_utils.cpp +++ b/sgl-kernel/csrc/cpu/numa_utils.cpp @@ -40,7 +40,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); int page_num = numa_migrate_pages(pid, src_mask, mask); if (page_num == -1) { - TORCH_CHECK(false, "numa_migrate_pages failed. errno: " + std::to_string(errno)); + TORCH_WARN(false, "numa_migrate_pages failed. errno: " + std::to_string(errno)); } // restrict memory allocation node. From 3701e261282089e685e68c6d2ad0ae9faf11cb63 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Jun 2025 14:43:57 +0000 Subject: [PATCH 16/29] join node_ids to string instead of using start-prev --- python/sglang/srt/cpu_utils.py | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/cpu_utils.py b/python/sglang/srt/cpu_utils.py index fd457f91b6b..978604af369 100644 --- a/python/sglang/srt/cpu_utils.py +++ b/python/sglang/srt/cpu_utils.py @@ -54,28 +54,13 @@ def get_physical_cpus_by_numa(): return node_to_cpus -# Compress sorted list of integers into range strings like 0-2,3,4-6 -def compress_ranges(cpu_list): - if not cpu_list: - return "" - ranges = [] - start = prev = cpu_list[0] - for cpu in cpu_list[1:]: - if cpu == prev + 1: - prev = cpu - else: - ranges.append(f"{start}-{prev}" if start != prev else str(start)) - start = prev = cpu - ranges.append(f"{start}-{prev}" if start != prev else str(start)) - return ",".join(ranges) - - # Only physical cores are used. Logical cores are excluded. def get_cpu_ids_by_node(): node_to_cpus = get_physical_cpus_by_numa() # Sort by NUMA node index cpu_ids = [ - compress_ranges(sorted(node_to_cpus[node])) for node in sorted(node_to_cpus) + ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus) ] - # ['0-42', '43-85', '86-127', '128-170', '171-213', '214-255'] + + # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23'] return cpu_ids From ead33942e638047f636283fbb35adf0474f6691f Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 06:51:20 +0000 Subject: [PATCH 17/29] add {CONDA_PREFIX}/include into include_directories --- sgl-kernel/csrc/cpu/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt index d83aaeb5d3d..aa77fbad480 100755 --- a/sgl-kernel/csrc/cpu/CMakeLists.txt +++ b/sgl-kernel/csrc/cpu/CMakeLists.txt @@ -43,6 +43,8 @@ if(DEFINED ENV{CONDA_PREFIX}) set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib") message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}") link_directories(${CONDA_LIB_DIR}) + set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include") + include_directories(${CONDA_INCLUDE_DIR}) # Look for libnuma in Conda's lib directory find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}") From 89fbe17f84b6fb0223c05e56a2d029ebe07ae9de Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 07:10:43 +0000 Subject: [PATCH 18/29] move functions in cpu_utils.py to utils.py --- python/sglang/srt/cpu_utils.py | 66 ------------------- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/utils.py | 63 ++++++++++++++++++ 3 files changed, 64 insertions(+), 67 deletions(-) delete mode 100644 python/sglang/srt/cpu_utils.py diff --git a/python/sglang/srt/cpu_utils.py b/python/sglang/srt/cpu_utils.py deleted file mode 100644 index 978604af369..00000000000 --- a/python/sglang/srt/cpu_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import subprocess -from collections import defaultdict - - -def parse_lscpu_topology(): - try: - # Get CPU topology: CPU,Core,Socket,Node - output = subprocess.check_output( - ["lscpu", "-p=CPU,Core,Socket,Node"], text=True - ) - except Exception as e: - raise RuntimeError(f"Unexpected error running 'lscpu': {e}") - - # Parse only data lines (skip comments) - cpu_info = [] - for line in output.splitlines(): - if not line.startswith("#"): - cpu, core, socket, node = map(int, line.strip().split(",")) - cpu_info.append((cpu, core, socket, node)) - - # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...] - return cpu_info - - -def get_physical_cpus_by_numa(): - cpu_info = parse_lscpu_topology() - - # Map NUMA node -> set of (core_id, socket) to avoid duplicates - # 0: {(0,0): 0, (1, 0): 1,...} - # ... - # 5: {(214,1): 214, (215,1): 215} - physical_by_node = defaultdict(dict) # node -> core_id -> cpu_id - - for cpu, core, socket, node in cpu_info: - key = (core, socket) - if key not in physical_by_node[node]: - physical_by_node[node][ - key - ] = cpu # pick first CPU seen for that physical core - - # Convert to list of physical CPUs per node - # 0: [0,1,2,...,42] - # ... - # 2: [86,87,...,127] - # ... - # 5: [214,215,...,255] - node_to_cpus = {} - for node, core_to_cpu in physical_by_node.items(): - cpus = sorted(core_to_cpu.values()) - node_to_cpus[node] = cpus - - return node_to_cpus - - -# Only physical cores are used. Logical cores are excluded. -def get_cpu_ids_by_node(): - node_to_cpus = get_physical_cpus_by_numa() - # Sort by NUMA node index - cpu_ids = [ - ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus) - ] - - # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23'] - return cpu_ids diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9d2507ab9a0..15dc419566d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -30,7 +30,6 @@ from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import AttentionArch, ModelConfig from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS -from sglang.srt.cpu_utils import get_cpu_ids_by_node from sglang.srt.distributed import ( get_tp_group, get_world_group, @@ -103,6 +102,7 @@ enable_show_time_cost, get_available_gpu_memory, get_bool_env_var, + get_cpu_ids_by_node, init_custom_process_group, is_cuda, is_fa3_default_architecture, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 8a91c2fc4d0..7d415e20403 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -40,6 +40,7 @@ import time import traceback import warnings +from collections import defaultdict from contextlib import contextmanager from enum import Enum from functools import lru_cache @@ -2545,3 +2546,65 @@ def align(x: int, y: int) -> int: # COPIED FROM DeepGEMM def ceil_div(x: int, y: int) -> int: return (x + y - 1) // y + + +def parse_lscpu_topology(): + try: + # Get CPU topology: CPU,Core,Socket,Node + output = subprocess.check_output( + ["lscpu", "-p=CPU,Core,Socket,Node"], text=True + ) + except Exception as e: + raise RuntimeError(f"Unexpected error running 'lscpu': {e}") + + # Parse only data lines (skip comments) + cpu_info = [] + for line in output.splitlines(): + if not line.startswith("#"): + cpu, core, socket, node = map(int, line.strip().split(",")) + cpu_info.append((cpu, core, socket, node)) + + # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...] + return cpu_info + + +def get_physical_cpus_by_numa(): + cpu_info = parse_lscpu_topology() + + # Map NUMA node -> set of (core_id, socket) to avoid duplicates + # 0: {(0,0): 0, (1, 0): 1,...} + # ... + # 5: {(214,1): 214, (215,1): 215} + physical_by_node = defaultdict(dict) # node -> core_id -> cpu_id + + for cpu, core, socket, node in cpu_info: + key = (core, socket) + if key not in physical_by_node[node]: + physical_by_node[node][ + key + ] = cpu # pick first CPU seen for that physical core + + # Convert to list of physical CPUs per node + # 0: [0,1,2,...,42] + # ... + # 2: [86,87,...,127] + # ... + # 5: [214,215,...,255] + node_to_cpus = {} + for node, core_to_cpu in physical_by_node.items(): + cpus = sorted(core_to_cpu.values()) + node_to_cpus[node] = cpus + + return node_to_cpus + + +# Only physical cores are used. Logical cores are excluded. +def get_cpu_ids_by_node(): + node_to_cpus = get_physical_cpus_by_numa() + # Sort by NUMA node index + cpu_ids = [ + ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus) + ] + + # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23'] + return cpu_ids From 584194dd7ad48de07bb462bfe76e45b6708a09f3 Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 07:46:33 +0000 Subject: [PATCH 19/29] set ENV CONDA_PREFIX in Dockerfile --- docker/Dockerfile.xeon | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon index 2f03b648536..fffc782dcc7 100644 --- a/docker/Dockerfile.xeon +++ b/docker/Dockerfile.xeon @@ -27,6 +27,7 @@ RUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/re ENV PATH=/sgl-workspace/miniforge3/bin:/sgl-workspace/miniforge3/condabin:${PATH} ENV PIP_ROOT_USER_ACTION=ignore +ENV CONDA_PREFIX=/sgl-workspace/miniforge3 RUN pip install intel-openmp From 705a6a68e24bd865ea0aa9a6e6de4aa5bb5319ab Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Thu, 12 Jun 2025 16:03:44 +0000 Subject: [PATCH 20/29] only bind to CPUs that are allowed --- python/sglang/srt/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 7d415e20403..67851c8c3c8 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2584,6 +2584,9 @@ def get_physical_cpus_by_numa(): key ] = cpu # pick first CPU seen for that physical core + # Retrieves CPUs that the current process is allowed to run on + cpus_allowed_list = psutil.Process().cpu_affinity() + # Convert to list of physical CPUs per node # 0: [0,1,2,...,42] # ... @@ -2593,7 +2596,8 @@ def get_physical_cpus_by_numa(): node_to_cpus = {} for node, core_to_cpu in physical_by_node.items(): cpus = sorted(core_to_cpu.values()) - node_to_cpus[node] = cpus + allowed_cpus = set(cpus).intersection(cpus_allowed_list) + node_to_cpus[node] = allowed_cpus return node_to_cpus From 1cb1f6044e3ed1707e5de5ef81a52ca875de2445 Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 08:11:25 +0000 Subject: [PATCH 21/29] remove dead code --- python/sglang/srt/model_executor/model_runner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 15dc419566d..f25d766243b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -483,8 +483,6 @@ def init_torch_distributed(self): if not self.is_draft_worker: if self.device == "cpu": - import sgl_kernel - # Bind OpenMP threads to CPU cores assert ( cpu_has_amx_support() From d943a8c204284e5b70de929bdd380313f240124d Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 08:20:44 +0000 Subject: [PATCH 22/29] update UT --- test/srt/cpu/test_binding.py | 8 ++++---- test/srt/run_suite.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/srt/cpu/test_binding.py b/test/srt/cpu/test_binding.py index 75de255400e..d3cc329af70 100644 --- a/test/srt/cpu/test_binding.py +++ b/test/srt/cpu/test_binding.py @@ -11,16 +11,16 @@ class TestGemm(CustomTestCase): def test_binding(self): - start_id = 0 + start_id = 1 n_cpu = 6 - end_id = start_id + n_cpu - 1 - cpu_ids = f"{start_id}-{end_id}" + + expected_cores = list(map(str, range(start_id, start_id + n_cpu))) + cpu_ids = ",".join(expected_cores) output = kernel.init_cpu_threads_env(cpu_ids) bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) self.assertEqual(len(bindings), n_cpu) - expected_cores = list(map(str, range(start_id, n_cpu))) self.assertEqual(bindings, expected_cores) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index b28174f6d87..bfc8af1b262 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -183,6 +183,7 @@ class TestFile: ], "per-commit-cpu": [ TestFile("cpu/test_activation.py"), + TestFile("cpu/test_binding.py"), TestFile("cpu/test_decode.py"), TestFile("cpu/test_extend.py"), TestFile("cpu/test_gemm.py"), From 6f4656be90c7e570c2d9c638d0296ae0dc6d629f Mon Sep 17 00:00:00 2001 From: y Date: Thu, 12 Jun 2025 08:25:11 +0000 Subject: [PATCH 23/29] add UTs that can pass back --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index bfc8af1b262..a9c72b9524c 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -193,6 +193,7 @@ class TestFile: TestFile("cpu/test_qkv_proj_with_rope.py"), TestFile("cpu/test_rope.py"), TestFile("cpu/test_shared_expert.py"), + TestFile("cpu/test_topk.py"), ], "nightly": [ TestFile("test_nightly_gsm8k_eval.py"), From 2a8106458488ee278e230db76ee4a2c4bfa60068 Mon Sep 17 00:00:00 2001 From: srinarayan-srikanthan Date: Fri, 13 Jun 2025 07:51:34 -0700 Subject: [PATCH 24/29] prioritize torch cpu wheel host --- docker/Dockerfile.xeon | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon index fffc782dcc7..9b84aec496b 100644 --- a/docker/Dockerfile.xeon +++ b/docker/Dockerfile.xeon @@ -4,6 +4,7 @@ SHELL ["/bin/bash", "-c"] ARG VER_SGLANG=main ARG VER_TORCH=2.6.0 ARG VER_TORCHVISION=0.21.0 +ARG VER_TRITON=3.1.0 RUN apt-get update && \ apt-get full-upgrade -y && \ @@ -29,13 +30,15 @@ ENV PATH=/sgl-workspace/miniforge3/bin:/sgl-workspace/miniforge3/condabin:${PATH ENV PIP_ROOT_USER_ACTION=ignore ENV CONDA_PREFIX=/sgl-workspace/miniforge3 -RUN pip install intel-openmp +RUN pip config set global.index-url https://download.pytorch.org/whl/cpu && \ + pip config set global.extra-index-url https://pypi.org/simple && \ + pip install intel-openmp RUN git clone https://github.com/sgl-project/sglang.git && \ cd sglang && \ git checkout ${VER_SGLANG} && \ pip install -e "python[all_cpu]" && \ - pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} --index-url https://download.pytorch.org/whl/cpu --force-reinstall && \ + pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} triton==${VER_TRITON} --force-reinstall && \ cd sgl-kernel && \ cp pyproject_cpu.toml pyproject.toml && \ pip install -v . From 2278dcc06fea6e6204731833ab35893b3182718b Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Mon, 16 Jun 2025 11:05:51 +0000 Subject: [PATCH 25/29] upgrade torch and dependencies to 2.7.1 stack --- docker/Dockerfile.xeon | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon index 9b84aec496b..087e12ccaef 100644 --- a/docker/Dockerfile.xeon +++ b/docker/Dockerfile.xeon @@ -2,9 +2,9 @@ FROM ubuntu:24.04 SHELL ["/bin/bash", "-c"] ARG VER_SGLANG=main -ARG VER_TORCH=2.6.0 -ARG VER_TORCHVISION=0.21.0 -ARG VER_TRITON=3.1.0 +ARG VER_TORCH=2.7.1 +ARG VER_TORCHVISION=0.22.1 +ARG VER_TRITON=3.3.1 RUN apt-get update && \ apt-get full-upgrade -y && \ From 1e55a16b8c64cea9e1f1026d65b4ebcc0aabdc29 Mon Sep 17 00:00:00 2001 From: y Date: Wed, 25 Jun 2025 09:54:57 +0000 Subject: [PATCH 26/29] move sgl-kernel changes to #7524 --- docker/Dockerfile.xeon | 12 +-- sgl-kernel/csrc/cpu/CMakeLists.txt | 20 +---- sgl-kernel/csrc/cpu/numa_utils.cpp | 91 --------------------- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 10 --- test/srt/cpu/test_binding.py | 28 ------- test/srt/run_suite.py | 2 - 6 files changed, 5 insertions(+), 158 deletions(-) delete mode 100644 sgl-kernel/csrc/cpu/numa_utils.cpp delete mode 100644 test/srt/cpu/test_binding.py diff --git a/docker/Dockerfile.xeon b/docker/Dockerfile.xeon index 087e12ccaef..2f03b648536 100644 --- a/docker/Dockerfile.xeon +++ b/docker/Dockerfile.xeon @@ -2,9 +2,8 @@ FROM ubuntu:24.04 SHELL ["/bin/bash", "-c"] ARG VER_SGLANG=main -ARG VER_TORCH=2.7.1 -ARG VER_TORCHVISION=0.22.1 -ARG VER_TRITON=3.3.1 +ARG VER_TORCH=2.6.0 +ARG VER_TORCHVISION=0.21.0 RUN apt-get update && \ apt-get full-upgrade -y && \ @@ -28,17 +27,14 @@ RUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/re ENV PATH=/sgl-workspace/miniforge3/bin:/sgl-workspace/miniforge3/condabin:${PATH} ENV PIP_ROOT_USER_ACTION=ignore -ENV CONDA_PREFIX=/sgl-workspace/miniforge3 -RUN pip config set global.index-url https://download.pytorch.org/whl/cpu && \ - pip config set global.extra-index-url https://pypi.org/simple && \ - pip install intel-openmp +RUN pip install intel-openmp RUN git clone https://github.com/sgl-project/sglang.git && \ cd sglang && \ git checkout ${VER_SGLANG} && \ pip install -e "python[all_cpu]" && \ - pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} triton==${VER_TRITON} --force-reinstall && \ + pip install torch==${VER_TORCH} torchvision==${VER_TORCHVISION} --index-url https://download.pytorch.org/whl/cpu --force-reinstall && \ cd sgl-kernel && \ cp pyproject_cpu.toml pyproject.toml && \ pip install -v . diff --git a/sgl-kernel/csrc/cpu/CMakeLists.txt b/sgl-kernel/csrc/cpu/CMakeLists.txt index aa77fbad480..355a6ab4764 100755 --- a/sgl-kernel/csrc/cpu/CMakeLists.txt +++ b/sgl-kernel/csrc/cpu/CMakeLists.txt @@ -38,24 +38,6 @@ else() endif() link_directories(${PLAT_LIB_DIR}) -# Conda library path support -if(DEFINED ENV{CONDA_PREFIX}) - set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib") - message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}") - link_directories(${CONDA_LIB_DIR}) - set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include") - include_directories(${CONDA_INCLUDE_DIR}) - - # Look for libnuma in Conda's lib directory - find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}") - if(NUMA_LIB) - message(STATUS "Found libnuma: ${NUMA_LIB}") - else() - message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n" - "Please install it using: conda install libnuma numactl\n") - endif() -endif() - file(GLOB SOURCES "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp") add_compile_options( @@ -66,7 +48,7 @@ add_compile_options( ) Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) -target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB}) +target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES}) target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS}) install(TARGETS common_ops diff --git a/sgl-kernel/csrc/cpu/numa_utils.cpp b/sgl-kernel/csrc/cpu/numa_utils.cpp deleted file mode 100644 index 2699d0e236d..00000000000 --- a/sgl-kernel/csrc/cpu/numa_utils.cpp +++ /dev/null @@ -1,91 +0,0 @@ -#include -#include -#include -#include -#include - -#include - -#include "common.h" - -std::string init_cpu_threads_env(const std::string& cpu_ids) { - bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); - TORCH_CHECK(omp_cpu_mask->size > 0); - std::vector omp_cpu_ids; - omp_cpu_ids.reserve(omp_cpu_mask->size); - - constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); - - for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { - unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; - int i = 0; - while (group_mask) { - if (group_mask & 1) { - omp_cpu_ids.emplace_back(offset + i); - } - ++i; - group_mask >>= 1; - } - } - - // Memory node binding - if (numa_available() != -1) { - int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); - bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); - bitmask* src_mask = numa_get_membind(); - - int pid = getpid(); - - // move all existing pages to the specified numa node. - *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); - int page_num = numa_migrate_pages(pid, src_mask, mask); - if (page_num == -1) { - TORCH_WARN(false, "numa_migrate_pages failed. errno: " + std::to_string(errno)); - } - - // restrict memory allocation node. - numa_set_membind(mask); - numa_set_strict(1); - } - - // OMP threads binding - omp_set_num_threads((int)omp_cpu_ids.size()); - at::set_num_threads((int)omp_cpu_ids.size()); - TORCH_CHECK_EQ(omp_cpu_ids.size(), at::get_num_threads()); - TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); - - std::vector> thread_core_mapping; - thread_core_mapping.reserve(omp_cpu_ids.size()); - omp_lock_t writelock; - omp_init_lock(&writelock); - -#pragma omp parallel for schedule(static, 1) - for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { - cpu_set_t mask; - CPU_ZERO(&mask); - CPU_SET(omp_cpu_ids[i], &mask); - int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); - if (ret == -1) { - TORCH_CHECK(false, "sched_setaffinity failed. errno: " + std::to_string(errno)); - } - - omp_set_lock(&writelock); - thread_core_mapping.emplace_back(syscall(SYS_gettid), omp_cpu_ids[i]); - omp_unset_lock(&writelock); - } - - omp_destroy_lock(&writelock); - - numa_free_nodemask(omp_cpu_mask); - - std::stringstream ss; - ss << "OMP threads binding of Process " << getpid() << ":\n"; - std::sort( - thread_core_mapping.begin(), thread_core_mapping.end(), [](auto&& a, auto&& b) { return a.second < b.second; }); - for (auto&& item : thread_core_mapping) { - ss << "\t" - << "OMP tid: " << item.first << ", core " << item.second << "\n"; - } - - return ss.str(); -} diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 17e2f824c8f..7c26c354fd6 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -227,9 +227,6 @@ std::tuple rotary_embedding_cpu( at::Tensor& cos_sin_cache, bool is_neox); -// CPU and memory binding -std::string init_cpu_threads_env(const std::string& cpu_ids); - TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // activation m.def("silu_and_mul_cpu(Tensor input) -> Tensor"); @@ -356,13 +353,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, " "bool is_neox) -> (Tensor, Tensor)"); m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu); - - // CPU and memory binding - m.def("init_cpu_threads_env(str cpu_ids) -> str"); -} - -TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) { - m.impl("init_cpu_threads_env", init_cpu_threads_env); } REGISTER_EXTENSION(common_ops) diff --git a/test/srt/cpu/test_binding.py b/test/srt/cpu/test_binding.py deleted file mode 100644 index d3cc329af70..00000000000 --- a/test/srt/cpu/test_binding.py +++ /dev/null @@ -1,28 +0,0 @@ -import re -import unittest - -import sgl_kernel -import torch - -kernel = torch.ops.sgl_kernel - -from sglang.test.test_utils import CustomTestCase - - -class TestGemm(CustomTestCase): - def test_binding(self): - start_id = 1 - n_cpu = 6 - - expected_cores = list(map(str, range(start_id, start_id + n_cpu))) - cpu_ids = ",".join(expected_cores) - output = kernel.init_cpu_threads_env(cpu_ids) - - bindings = re.findall(r"OMP tid: \d+, core (\d+)", output) - self.assertEqual(len(bindings), n_cpu) - - self.assertEqual(bindings, expected_cores) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index a9c72b9524c..b28174f6d87 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -183,7 +183,6 @@ class TestFile: ], "per-commit-cpu": [ TestFile("cpu/test_activation.py"), - TestFile("cpu/test_binding.py"), TestFile("cpu/test_decode.py"), TestFile("cpu/test_extend.py"), TestFile("cpu/test_gemm.py"), @@ -193,7 +192,6 @@ class TestFile: TestFile("cpu/test_qkv_proj_with_rope.py"), TestFile("cpu/test_rope.py"), TestFile("cpu/test_shared_expert.py"), - TestFile("cpu/test_topk.py"), ], "nightly": [ TestFile("test_nightly_gsm8k_eval.py"), From 7a0b0fb9dd904b9a9020b4d4e6836025f37af7c3 Mon Sep 17 00:00:00 2001 From: "Wu, Chunyuan" Date: Mon, 30 Jun 2025 10:02:49 +0000 Subject: [PATCH 27/29] refine warning message --- python/sglang/srt/model_executor/model_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f25d766243b..ed9198b32a9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1290,17 +1290,17 @@ def init_threads_binding(self): assert self.tp_size <= n_numa_node, ( f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, " - f"tp_size {self.tp_size} should be smaller than number of numa node on the machine {n_numa_node}. " + f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. " f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. " f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, " f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. " f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. " - f"If you do need tp_size to be large than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. " + f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. " f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2." ) if self.tp_size < n_numa_node: logger.warning( - f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so use only {self.tp_size} numa nodes are used." + f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used." ) self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank] else: From 8d94c8d03a537c824c00be13f755594846982f5a Mon Sep 17 00:00:00 2001 From: y Date: Tue, 1 Jul 2025 02:19:51 +0000 Subject: [PATCH 28/29] change assert cpu_has_amx_support() to be an if condition and add warning --- python/sglang/srt/model_executor/model_runner.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ed9198b32a9..7cc00e05057 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -484,10 +484,12 @@ def init_torch_distributed(self): if not self.is_draft_worker: if self.device == "cpu": # Bind OpenMP threads to CPU cores - assert ( - cpu_has_amx_support() - ), "init_cpu_threads_env failed since intel amx backend is not available" - torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) + if _is_cpu_amx_available: + torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) + else: + logger.warning( + "init_cpu_threads_env is skipped since intel amx backend is not available" + ) # Only initialize the distributed environment on the target model worker. init_distributed_environment( From 79286f794d34f6fdd2067a9917c1d74d6c85fb70 Mon Sep 17 00:00:00 2001 From: y Date: Tue, 1 Jul 2025 02:44:07 +0000 Subject: [PATCH 29/29] refine code --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7cc00e05057..72f9725b919 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -483,8 +483,8 @@ def init_torch_distributed(self): if not self.is_draft_worker: if self.device == "cpu": - # Bind OpenMP threads to CPU cores if _is_cpu_amx_available: + # Bind OpenMP threads to CPU cores torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid) else: logger.warning(