Skip to content
Merged
49 changes: 48 additions & 1 deletion tensorrt_llm/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import datetime
import enum
import json
import os
import weakref
from pathlib import Path
from queue import Queue
from typing import Dict, List, Optional, Tuple, Union

import psutil
import torch

from tensorrt_llm.logger import logger
Expand All @@ -19,7 +21,7 @@
from ..llmapi.llm_args import BaseLlmArgs, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import _SyncQueue, logger_debug
from ..llmapi.utils import _SyncQueue, get_numa_aware_cpu_affinity, logger_debug
from ..lora_manager import LoraManager
from ..metrics import RequestEventTiming
from ..prompt_adapter_manager import PromptAdapterManager
Expand Down Expand Up @@ -91,13 +93,58 @@ def __init__(
if global_mpi_size() > 1:
logger.set_rank(self.global_rank)

def _configure_affinity(self, device_id):
    '''Probe and configure the CPU affinity of the worker based on NUMA topology.

    Args:
        device_id: The CUDA device ID to determine optimal CPU affinity.

    Note:
        If the process already has constrained affinity, a warning is logged.
        Configuration is handled as follows:
            TLLM_NUMA_AWARE_WORKER_AFFINITY = <unset>
                -> affinity is auto-configured only if it is unconstrained
            TLLM_NUMA_AWARE_WORKER_AFFINITY = 1
                -> affinity is unconditionally auto-configured
            TLLM_NUMA_AWARE_WORKER_AFFINITY = 0 or any other value
                -> affinity is unconditionally _not_ auto-configured
    '''

    # Get the current affinity setting
    pid = os.getpid()
    process = psutil.Process(pid)
    cpu_affinity = process.cpu_affinity()

    # Compare as sets so the check does not depend on the order in which
    # the CPUs are reported.
    all_cpus = set(range(psutil.cpu_count()))
    constrained_affinity = set(cpu_affinity) != all_cpus

    # If the process is affined to a constrained set of CPUs, warn the user
    # so as to ensure that this is what is intended
    if constrained_affinity:
        logger.warning(
            f"Worker process {pid} is affined to run on the following CPUs: "
            f"{cpu_affinity} (subset of all logical CPUs). This may harm "
            f"performance if set incorrectly.")

    # If affinity is unconstrained and the user hasn't explicitly
    # prohibited it or the user has explicitly requested it, choose the
    # optimal affinity based upon the NUMA topology
    numa_aware_affinity = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")
    auto_configure = (numa_aware_affinity == "1"
                      or (numa_aware_affinity is None
                          and not constrained_affinity))
    if auto_configure:
        process.cpu_affinity(get_numa_aware_cpu_affinity(device_id))

def _get_comm_ranks_device_id(self):
    '''Select this worker's CUDA device and gather ranks and device IDs.

    The device is chosen as ``self.global_rank`` modulo the number of CUDA
    devices reported by torch, and is made the current torch device. After
    both gathers complete, CPU affinity is configured for that device.

    Returns:
        Tuple ``(comm_ranks, device_ids)`` holding the results of the two
        ``allgather`` calls on the MPI communicator.
    '''
    num_devices = torch.cuda.device_count()
    device_id = self.global_rank % num_devices
    torch.cuda.set_device(device_id)

    # Make sure C++ executor would use same devices/ranks as py_executor
    my_rank = global_mpi_rank()
    comm_ranks = mpi_comm().allgather(my_rank)
    device_ids = mpi_comm().allgather(device_id)

    self._configure_affinity(device_id)

    return comm_ranks, device_ids

def setup_engine(self):
Expand Down
3 changes: 3 additions & 0 deletions tensorrt_llm/executor/ray_gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ def _get_comm_ranks_device_id(self):

torch.distributed.all_gather_object(comm_ranks, global_rank)
torch.distributed.all_gather_object(device_ids, self.device_id)

self._configure_affinity(self.device_id)

return comm_ranks, device_ids

def enqueue_request(self,
Expand Down
12 changes: 1 addition & 11 deletions tensorrt_llm/executor/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
from ..llmapi.mpi_session import set_mpi_session_cpp
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import VizTracer, set_global_tracer
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue,
clear_sched_affinity, logger_debug,
from ..llmapi.utils import (AsyncQueue, ManagedThread, _SyncQueue, logger_debug,
print_traceback_on_error)
from ..sampling_params import BatchedLogitsProcessor
from .base_worker import BaseWorker
Expand Down Expand Up @@ -245,15 +244,6 @@ def worker_main(
mpi_comm().barrier()
logger_debug(f"Worker {mpi_rank()} entering worker_main...\n", "green")

pid = os.getpid()
cpus = os.sched_getaffinity(pid)
if cpus:
logger.warning(
f"Found worker process {pid} was bound to {cpus}, this may harm "
"performance.", )
logger.warning(f"Will clear the cpu affinity")
clear_sched_affinity(pid)

result_queue: Optional[IpcQueue] = None
result_queues: Optional[List[IpcQueue]] = None

Expand Down
63 changes: 49 additions & 14 deletions tensorrt_llm/llmapi/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import collections
import ctypes
import datetime
import hashlib
import inspect
import io
import math
import os
import re
import sys
Expand Down Expand Up @@ -513,24 +515,57 @@ def get(self, timeout=None):
time.sleep(0.01)


def set_sched_setaffinity(required_cores: int):
    ''' Set the CPU affinity of the current process to the required number of
    cores.

    The least-utilized cores (per psutil's per-CPU usage sample) are chosen.

    Known issue: This may race with other processes that also set the affinity.
    '''
    cpu_percentages = psutil.cpu_percent(percpu=True)
    # sort the cores by usage
    free_cores = sorted(range(len(cpu_percentages)),
                        key=lambda i: cpu_percentages[i])

    pid = os.getpid()
    os.sched_setaffinity(pid, set(free_cores[:required_cores]))


def get_numa_aware_cpu_affinity(device_id):
    '''Query NVML for NUMA-aware CPU affinity for the specified CUDA device.

    Args:
        device_id: The CUDA device ID to query for optimal CPU affinity.

    Returns:
        List of CPU IDs representing the optimal CPU affinity mask for the
        device. On a system without a second NUMA node, this is every
        logical CPU (i.e. unconstrained affinity).

    Raises:
        pynvml.NVMLError: If NVML operations fail or device_id is invalid.
    '''
    cpu_count = psutil.cpu_count()

    # If this is not a NUMA system there is nothing to optimize, so report
    # unconstrained CPU affinity.
    if not os.path.isdir("/sys/devices/system/node/node1"):
        return list(range(cpu_count))

    # Initialize NVML outside the try block so that shutdown is attempted
    # only when the import and initialization actually succeeded.
    import pynvml
    pynvml.nvmlInit()

    try:
        # Get the number of bits per ulong
        c_ulong_bits = ctypes.sizeof(ctypes.c_ulong) * 8

        # Determine how large our cpu set array from NVML needs to be
        cpu_set_size = math.ceil(cpu_count / c_ulong_bits)

        # Get the optimal CPU affinity for this device according to the NUMA
        # topology
        # NOTE(review): device_id is passed straight through as an NVML
        # index; confirm this matches the intended GPU when the set of
        # visible CUDA devices is restricted.
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        affinity_masks = pynvml.nvmlDeviceGetCpuAffinity(handle, cpu_set_size)

        # Convert CPU masks to python list: CPU i is bit (i % bits) of
        # word (i // bits).
        return [
            cpu_id for cpu_id in range(cpu_count)
            if affinity_masks[cpu_id // c_ulong_bits]
            & (1 << (cpu_id % c_ulong_bits))
        ]
    finally:
        try:
            pynvml.nvmlShutdown()
        except pynvml.NVMLError:
            pass  # Ignore shutdown errors


def clear_sched_affinity(pid: int):
    ''' Clear the CPU affinity of the process `pid` by allowing it to run on
    every logical CPU. '''
    os.sched_setaffinity(pid, set(range(psutil.cpu_count())))


def generate_api_docs_as_docstring(model: Type[BaseModel],
Expand Down