50 changes: 49 additions & 1 deletion tensorrt_llm/executor/base_worker.py
@@ -2,11 +2,13 @@
 import datetime
 import enum
 import json
+import os
 import weakref
 from pathlib import Path
 from queue import Queue
 from typing import Dict, List, Optional, Tuple, Union

+import psutil
 import torch

 from tensorrt_llm.logger import logger
@@ -19,7 +21,8 @@
 from ..llmapi.llm_args import BaseLlmArgs, KvCacheConnectorConfig, PybindMirror
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import global_tracer
-from ..llmapi.utils import _SyncQueue, print_colored_debug
+from ..llmapi.utils import (_SyncQueue, get_numa_aware_cpu_affinity,
+                            print_colored_debug)
 from ..lora_helper import LoraConfig
 from ..lora_manager import LoraManager
 from ..metrics import RequestEventTiming
@@ -100,6 +103,48 @@ def __init__(
         if global_mpi_size() > 1:
             logger.set_rank(self.global_rank)

+    def _configure_affinity(self, device_id):
+        '''Probe and configure the CPU affinity of the worker based on NUMA topology.
+
+        Args:
+            device_id: The CUDA device ID to determine optimal CPU affinity.
+
+        Note:
+            If the process already has constrained affinity, a warning is logged.
+            Configuration is handled as follows:
+            TLLM_NUMA_AWARE_WORKER_AFFINITY = <unset>
+                -> affinity is auto-configured only if it is unconstrained
+            TLLM_NUMA_AWARE_WORKER_AFFINITY = 1
+                -> affinity is unconditionally auto-configured
+            TLLM_NUMA_AWARE_WORKER_AFFINITY = 0 or any other value
+                -> affinity is unconditionally _not_ auto-configured
+        '''
+
+        # Get the current affinity setting
+        pid = os.getpid()
+        process = psutil.Process(pid)
+        cpu_affinity = process.cpu_affinity()
+
+        all_cpus = list(range(psutil.cpu_count()))
+
+        constrained_affinity = (cpu_affinity != all_cpus)
+
+        # If the process is affined to a constrained set of CPUs, warn the user
+        # so as to ensure that this is what is intended
+        if constrained_affinity:
+            logger.warning(
+                f"Worker process {pid} is affined to run on the following CPUs: "
+                f"{cpu_affinity} (subset of all logical CPUs). This may harm "
+                f"performance if set incorrectly.")
+
+        # If affinity is unconstrained and the user hasn't explicitly
+        # prohibited it, or the user has explicitly requested it, choose the
+        # optimal affinity based upon the NUMA topology
+        numa_aware_affinity = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")
+        if ((numa_aware_affinity is None and not constrained_affinity)
+                or (numa_aware_affinity == "1")):
+            process.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
+
     def setup_engine(self):
         """
         Setup the engine for the worker.
@@ -115,6 +160,9 @@ def _get_comm_ranks_device_id():
             global_rank = global_mpi_rank()
             comm_ranks = mpi_comm().allgather(global_rank)
             device_ids = mpi_comm().allgather(device_id)
+
+            self._configure_affinity(device_id)
+
             return comm_ranks, device_ids

         def _create_py_executor():
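
The gating rule added above can be exercised in isolation. Below is a minimal sketch of the same decision logic; the helper name should_pin is illustrative and not part of this PR:

import os

import psutil


def should_pin(env_value, constrained):
    # Mirrors _configure_affinity: "1" always pins, unset pins only when the
    # current affinity is unconstrained, "0" or anything else never pins.
    if env_value == "1":
        return True
    if env_value is None:
        return not constrained
    return False


proc = psutil.Process()
constrained = proc.cpu_affinity() != list(range(psutil.cpu_count()))
print(should_pin(os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY"), constrained))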
13 changes: 1 addition & 12 deletions tensorrt_llm/executor/worker.py
@@ -1,5 +1,4 @@
 import json
-import os
 import time
 import traceback
 from concurrent.futures import ProcessPoolExecutor
@@ -19,8 +18,7 @@
 from ..llmapi.tokenizer import TokenizerBase
 from ..llmapi.tracer import VizTracer, set_global_tracer
 from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
-                            clear_sched_affinity, print_colored_debug,
-                            print_traceback_on_error)
+                            print_colored_debug, print_traceback_on_error)
 from ..lora_helper import LoraConfig
 from ..sampling_params import BatchedLogitsProcessor
 from .base_worker import BaseWorker
@@ -276,15 +274,6 @@ def worker_main(
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
                         "green")

-    pid = os.getpid()
-    cpus = os.sched_getaffinity(pid)
-    if cpus:
-        logger.warning(
-            f"Found worker process {pid} was bound to {cpus}, this may harm "
-            "performance.", )
-        logger.warning(f"Will clear the cpu affinity")
-        clear_sched_affinity(pid)
-
     result_queue: Optional[IpcQueue] = None
     result_queues: Optional[List[IpcQueue]] = None
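
Context for this deletion: os.sched_getaffinity always returns a non-empty set (every process is runnable on at least one CPU), so the removed guard fired on every start-up and unconditionally cleared any operator-set pinning. The replacement in base_worker.py warns instead and only re-pins under the documented rules. A one-line sanity check:

import os
assert os.sched_getaffinity(0) != set()  # 0 = current process; never empty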
63 changes: 49 additions & 14 deletions tensorrt_llm/llmapi/utils.py
@@ -1,7 +1,9 @@
 import asyncio
 import collections
+import ctypes
 import hashlib
 import io
+import math
 import os
 import re
 import sys
@@ -454,24 +456,57 @@ def get(self, timeout=None):
             time.sleep(0.01)


-def set_sched_setaffinity(required_cores: int):
-    ''' Set the CPU affinity of the current process to the required number of
-    cores.
+def get_numa_aware_cpu_affinity(device_id):
+    '''Query NVML for NUMA-aware CPU affinity for the specified CUDA device.

-    Known issue: This may race with other processes that also set the affinity.
-    '''
-    cpu_percentages = psutil.cpu_percent(percpu=True)
-    # sort the cores by usage
-    free_cores = sorted(range(len(cpu_percentages)),
-                        key=lambda i: cpu_percentages[i])
+    Args:
+        device_id: The CUDA device ID to query for optimal CPU affinity.

-    pid = os.getpid()
-    os.sched_setaffinity(pid, set(free_cores[:required_cores]))
+    Returns:
+        List of CPU IDs representing the optimal CPU affinity mask for the device.
+
+    Raises:
+        pynvml.NVMLError: If NVML operations fail or device_id is invalid.
+    '''
+    cpu_count = psutil.cpu_count()
+
+    # If this is not a NUMA system, or we hit an exception, default to
+    # unconstrained CPU affinity
+    cpu_affinity = list(range(cpu_count))
+
+    if not os.path.isdir("/sys/devices/system/node/node1"):
+        return cpu_affinity
+
+    try:
+        # initialize NVML
+        import pynvml
+        pynvml.nvmlInit()
+
+        # Get the number of bits per ulong
+        c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8
+
+        # Determine how large our cpu set array from NVML needs to be
+        cpu_set_size = math.ceil(cpu_count / c_ulong_bits)
+
+        # Get the optimal CPU affinity for this device according to the NUMA
+        # topology
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+        affinity_masks = pynvml.nvmlDeviceGetCpuAffinity(handle, cpu_set_size)
+
+        # Convert CPU masks to python list
+        cpu_affinity = []
+        for cpu_id in range(cpu_count):
+            mask_array_index = cpu_id // c_ulong_bits
+            mask_bit_index = cpu_id % c_ulong_bits
+            if affinity_masks[mask_array_index] & (1 << mask_bit_index):
+                cpu_affinity.append(cpu_id)
+    finally:
+        try:
+            pynvml.nvmlShutdown()
+        except:
+            pass  # Ignore shutdown errors
Comment on lines +480 to +507

⚠️ Potential issue | 🔴 Critical

Guard NUMA affinity when NVML isn’t available

If pynvml is missing or NVML init fails, the finally block still tries to call pynvml.nvmlShutdown(), but pynvml is undefined in that scenario. The resulting UnboundLocalError bubbles up into BaseWorker._configure_affinity and prevents workers from starting on any host without NVML installed—exactly the environment where we need to fall back to the unconstrained affinity declared at the top of this helper. Please short-circuit when the import or NVML calls fail so we return the default affinity safely.

Apply this diff:

-    try:
-        # initialize NVML
-        import pynvml
-        pynvml.nvmlInit()
+    try:
+        import pynvml
+    except ModuleNotFoundError:
+        return cpu_affinity
+
+    try:
+        pynvml.nvmlInit()
@@
-    finally:
-        try:
-            pynvml.nvmlShutdown()
-        except:
-            pass  # Ignore shutdown errors
+    except pynvml.NVMLError:
+        return cpu_affinity
+    finally:
+        try:
+            pynvml.nvmlShutdown()
+        except pynvml.NVMLError:
+            pass  # Ignore shutdown errors
🧰 Tools: 🪛 Ruff (0.14.4)

506: Do not use bare except (E722)

506-507: try-except-pass detected, consider logging the exception (S110)
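
The failure mode the reviewer describes is easy to reproduce standalone; in this sketch the module name is deliberately fake, standing in for a missing pynvml:

def broken_nvml_query():
    try:
        import not_a_real_module  # raises ModuleNotFoundError, name stays unbound
        not_a_real_module.nvmlInit()
    finally:
        # Referencing the never-bound local here raises UnboundLocalError,
        # masking the original error and skipping any fallback in the caller.
        not_a_real_module.nvmlShutdown()

# broken_nvml_query()  # -> UnboundLocalError, as the review comment describes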


-def clear_sched_affinity(pid: int):
-    ''' Clear the CPU affinity of the current process. '''
-    os.sched_setaffinity(pid, set(range(psutil.cpu_count())))
+    return cpu_affinity


 def generate_api_docs_as_docstring(model: Type[BaseModel],
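
The bit-unpacking loop in get_numa_aware_cpu_affinity is easiest to verify against a fixed input. A self-contained check follows, with invented mask values standing in for NVML output (assumes a 64-bit c_ulong, as on LP64 Linux):

import ctypes

c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8  # 64 on LP64 Linux

# Pretend NVML reported these words for a 128-CPU machine: all bits set in
# the low word, none in the high word, i.e. the device is local to CPUs 0-63.
affinity_masks = [0xFFFFFFFFFFFFFFFF, 0x0]
cpu_count = 128

cpu_affinity = [
    cpu_id for cpu_id in range(cpu_count)
    if affinity_masks[cpu_id // c_ulong_bits] & (1 << (cpu_id % c_ulong_bits))
]
assert cpu_affinity == list(range(64))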