Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
92c9d64
bind OpenMP threads to CPU cores (#4)
chunyuan-w Feb 16, 2025
e7a8212
De-couple vLLM CPU and remove dependency of `sudo apt-get install lib…
chunyuan-w May 20, 2025
35ac92a
Automatically bind threads if `SGLANG_CPU_OMP_THREADS_BIND` is not se…
chunyuan-w May 21, 2025
1ac1ee5
use torch.ops
chunyuan-w May 22, 2025
d6ffa4b
add numa related change in CMakeList
chunyuan-w May 22, 2025
2b4f892
use CatchAll as dispatch key for init_cpu_threads_env since it does n…
chunyuan-w May 22, 2025
8daac45
move binding code to init_threads_binding
chunyuan-w May 22, 2025
04cafa0
fix format
chunyuan-w May 22, 2025
e9719ad
replace gettid with syscall (#78)
chunyuan-w May 23, 2025
6e6c97a
fix format
chunyuan-w May 23, 2025
40968a0
add check if libnuma is in conda dir
chunyuan-w May 23, 2025
6319532
add UT
chunyuan-w May 23, 2025
9b2f001
fix UT format
chunyuan-w Jun 3, 2025
e3057c1
add assert when calling init_cpu_threads_env
chunyuan-w May 28, 2025
6bdd01d
use TORCH_WARN when checking return of numa_migrate_pages
chunyuan-w Jun 12, 2025
3701e26
join node_ids to string instead of using start-prev
chunyuan-w Jun 12, 2025
ead3394
add {CONDA_PREFIX}/include into include_directories
chunyuan-w Jun 12, 2025
89fbe17
move functions in cpu_utils.py to utils.py
chunyuan-w Jun 12, 2025
584194d
set ENV CONDA_PREFIX in Dockerfile
chunyuan-w Jun 12, 2025
705a6a6
only bind to CPUs that are allowed
chunyuan-w Jun 12, 2025
1cb1f60
remove dead code
chunyuan-w Jun 12, 2025
d943a8c
update UT
chunyuan-w Jun 12, 2025
6f4656b
add UTs that can pass back
chunyuan-w Jun 12, 2025
2a81064
prioritize torch cpu wheel host
srinarayan-srikanthan Jun 13, 2025
2278dcc
upgrade torch and dependencies to 2.7.1 stack
chunyuan-w Jun 16, 2025
1e55a16
move sgl-kernel changes to #7524
chunyuan-w Jun 25, 2025
7a0b0fb
refine warning message
chunyuan-w Jun 30, 2025
8cb323c
Merge branch 'main' into chunyuan/pr_core_bind
zhyncs Jun 30, 2025
8d94c8d
change assert cpu_has_amx_support() to be an if condition and add war…
chunyuan-w Jul 1, 2025
79286f7
refine code
chunyuan-w Jul 1, 2025
70ac1d1
Merge branch 'main' into chunyuan/pr_core_bind
zhyncs Jul 1, 2025
a2c15a7
Merge branch 'main' into chunyuan/pr_core_bind
zhyncs Jul 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
enable_show_time_cost,
get_available_gpu_memory,
get_bool_env_var,
get_cpu_ids_by_node,
init_custom_process_group,
is_cuda,
is_fa3_default_architecture,
Expand Down Expand Up @@ -211,6 +212,10 @@ def __init__(
# CPU offload
set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

# Init OpenMP threads binding for CPU
if self.device == "cpu":
self.init_threads_binding()

# Get memory before model loading
min_per_gpu_memory = self.init_torch_distributed()

Expand Down Expand Up @@ -477,6 +482,15 @@ def init_torch_distributed(self):
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

if not self.is_draft_worker:
if self.device == "cpu":
if _is_cpu_amx_available:
# Bind OpenMP threads to CPU cores
torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
else:
logger.warning(
"init_cpu_threads_env is skipped since intel amx backend is not available"
)

# Only initialize the distributed environment on the target model worker.
init_distributed_environment(
backend=backend,
Expand Down Expand Up @@ -1273,6 +1287,30 @@ def init_cuda_graphs(self):
f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
)

def init_threads_binding(self):
    """Select the CPU cores that this TP rank's OpenMP threads will be bound to.

    The result is stored as a string in ``self.local_omp_cpuid``.

    * When ``SGLANG_CPU_OMP_THREADS_BIND`` is unset or equal to ``"all"``,
      the entry at index ``self.tp_rank`` of ``get_cpu_ids_by_node()`` is used.
    * Otherwise the variable is split on ``"|"`` and the group at index
      ``self.tp_rank`` is used.

    Raises:
        AssertionError: in automatic mode, if ``self.tp_size`` exceeds the
            number of entries returned by ``get_cpu_ids_by_node()``.
        ValueError: in explicit mode, if the variable has no (or an empty)
            core group for this tp rank.
    """
    omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
    if omp_cpuids == "all":
        cpu_ids_by_node = get_cpu_ids_by_node()
        n_numa_node = len(cpu_ids_by_node)

        assert self.tp_size <= n_numa_node, (
            f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
            f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
            f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
            f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
            f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
            f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
            f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
            f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
        )
        if self.tp_size < n_numa_node:
            logger.warning(
                f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
            )
        self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
    else:
        bind_groups = omp_cpuids.split("|")
        # Fail with an actionable message instead of a bare IndexError when
        # the user supplied fewer groups than there are tp ranks.
        if self.tp_rank >= len(bind_groups):
            raise ValueError(
                f"SGLANG_CPU_OMP_THREADS_BIND={omp_cpuids!r} defines "
                f"{len(bind_groups)} core group(s), but tp rank {self.tp_rank} "
                f"(tp_size {self.tp_size}) has no group assigned. "
                f"Please provide one '|'-separated core group per tp rank, "
                f"for example SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63 with -tp 2."
            )
        local_group = bind_groups[self.tp_rank].strip()
        # An input such as "0-3|" would otherwise yield an empty binding.
        if not local_group:
            raise ValueError(
                f"SGLANG_CPU_OMP_THREADS_BIND={omp_cpuids!r} has an empty "
                f"core group for tp rank {self.tp_rank}."
            )
        self.local_omp_cpuid = local_group

def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
from sglang.srt.model_parallel import tensor_parallel
Expand Down
67 changes: 67 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import time
import traceback
import warnings
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum
from functools import lru_cache
Expand Down Expand Up @@ -2545,3 +2546,69 @@ def align(x: int, y: int) -> int:
# COPIED FROM DeepGEMM
def ceil_div(x: int, y: int) -> int:
    """Return ``(x + y - 1) // y``, i.e. ``x / y`` rounded up for positive ``y``."""
    # Adding (y - 1) before flooring pushes any non-zero remainder up by one.
    padded_numerator = x + y - 1
    return padded_numerator // y


def parse_lscpu_topology(lscpu_output=None):
    """Return the CPU topology reported by ``lscpu`` as a list of tuples.

    Args:
        lscpu_output: Optional pre-captured text in the format printed by
            ``lscpu -p=CPU,Core,Socket,Node``. When ``None`` (the default)
            the command is executed to obtain it.

    Returns:
        A list of ``(cpu, core, socket, node)`` integer tuples, one per data
        line, in the order the lines appear, e.g.
        ``[(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...]``.

    Raises:
        RuntimeError: if running ``lscpu`` fails.
        ValueError: if a data line does not consist of four comma-separated
            fields or contains a non-integer CPU/Core/Socket value.
    """
    if lscpu_output is None:
        try:
            # Get CPU topology: CPU,Core,Socket,Node
            lscpu_output = subprocess.check_output(
                ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
            )
        except Exception as e:
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(f"Unexpected error running 'lscpu': {e}") from e

    cpu_info = []
    for raw_line in lscpu_output.splitlines():
        line = raw_line.strip()
        # Skip blank lines and the '#'-prefixed header comments.
        if not line or line.startswith("#"):
            continue

        fields = line.split(",")
        if len(fields) != 4:
            raise ValueError(
                f"Unexpected lscpu line (expected CPU,Core,Socket,Node): {line!r}"
            )

        cpu, core, socket = (int(field) for field in fields[:3])
        # lscpu can print an empty Node field (e.g. when no NUMA information
        # is available); treat such CPUs as belonging to node 0.
        node_field = fields[3].strip()
        node = int(node_field) if node_field else 0
        cpu_info.append((cpu, core, socket, node))

    return cpu_info


def get_physical_cpus_by_numa(cpu_info=None, cpus_allowed=None):
    """Group one allowed CPU id per physical core by NUMA node.

    Args:
        cpu_info: Optional iterable of ``(cpu, core, socket, node)`` tuples.
            Defaults to the result of ``parse_lscpu_topology()``.
        cpus_allowed: Optional iterable of CPU ids the process may run on.
            Defaults to ``psutil.Process().cpu_affinity()``.

    Returns:
        A dict mapping NUMA node id to a set of CPU ids. Each physical core,
        identified by ``(core, socket)``, contributes the first *allowed* CPU
        seen for it. Nodes without any allowed CPU are omitted.
    """
    if cpu_info is None:
        cpu_info = parse_lscpu_topology()
    if cpus_allowed is None:
        # Retrieves CPUs that the current process is allowed to run on
        cpus_allowed = psutil.Process().cpu_affinity()
    allowed = set(cpus_allowed)

    # node -> (core_id, socket) -> cpu_id
    # 0: {(0,0): 0, (1, 0): 1,...}
    # ...
    # 5: {(214,1): 214, (215,1): 215}
    physical_by_node = defaultdict(dict)

    for cpu, core, socket, node in cpu_info:
        # Filter by affinity *before* choosing the core's representative, so
        # a core is not lost just because its first sibling is disallowed.
        if cpu not in allowed:
            continue
        # Keep the first allowed CPU seen for that physical core.
        physical_by_node[node].setdefault((core, socket), cpu)

    # Only nodes that received at least one allowed CPU exist as keys here,
    # so callers never see an empty CPU set for a node.
    return {
        node: set(core_to_cpu.values())
        for node, core_to_cpu in physical_by_node.items()
    }


# Only physical cores are used. Logical cores are excluded.
def get_cpu_ids_by_node():
    """Return one comma-separated CPU-id string per NUMA node.

    The list is ordered by ascending node id and the ids inside each string
    are in ascending order, e.g.
    ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
    """
    node_to_cpus = get_physical_cpus_by_numa()

    cpu_ids = []
    # Walk the nodes in ascending id order so index i maps to the i-th node.
    for node in sorted(node_to_cpus):
        ordered_ids = sorted(node_to_cpus[node])
        cpu_ids.append(",".join(str(cpu_id) for cpu_id in ordered_ids))

    return cpu_ids
Loading