diff --git a/launcher/launch.py b/launcher/launch.py
index 33415ae04..c99ece38f 100644
--- a/launcher/launch.py
+++ b/launcher/launch.py
@@ -2,10 +2,12 @@
 
 from __future__ import print_function
 import os
+import re
 import subprocess
 import threading
 import sys
 import time
+from functools import reduce
 
 
 class PropagatingThread(threading.Thread):
@@ -36,6 +38,111 @@ def join(self):
 COMMON_REQUIRED_ENVS = ["DMLC_ROLE", "DMLC_NUM_WORKER", "DMLC_NUM_SERVER",
                         "DMLC_PS_ROOT_URI", "DMLC_PS_ROOT_PORT"]
 WORKER_REQUIRED_ENVS = ["DMLC_WORKER_ID"]
+NUMA_PATH = "/sys/devices/system/node"
+
+
+def get_numa_info():
+    """Return the CPU ids of each NUMA node listed under NUMA_PATH.
+
+    Returns:
+        A list with one entry per ``node*`` directory; each entry is the
+        sorted list of ids parsed from that directory's ``cpu<N>`` entries.
+        The list is empty when NUMA_PATH does not exist.
+    """
+    ret = []
+    if os.path.exists(NUMA_PATH):
+        items = os.listdir(NUMA_PATH)
+        nodes = list(filter(lambda name: name.startswith("node"), items))
+        if nodes:
+            for node in nodes:
+                items = os.listdir(os.path.join(NUMA_PATH, node))
+                # raw string: "\d" in a plain literal is an invalid escape
+                cpus = [re.findall(r"cpu\d+", cpu) for cpu in items]
+                cpus = list(filter(lambda x: x, cpus))
+                cpu_ids = [int(cpu[0].split('cpu')[1]) for cpu in cpus]
+                cpu_ids = sorted(cpu_ids)
+                ret.append(cpu_ids)
+    else:
+        print("NUMA PATH %s NOT FOUND" % NUMA_PATH)
+    return ret
+
+
+def allocate_cpu(local_size):
+    """Assign a set of CPU ranges to each of the local_size local processes.
+
+    Returns:
+        None when no NUMA information is available. Otherwise a list of
+        length local_size, ordered as the non-root processes followed by the
+        root process; each entry is a list of contiguous CPU-id runs, or None
+        when no CPUs could be reserved for that process.
+    """
+    def _get_allocation(nodes, quota):
+        """Take the first `quota` CPUs from the first node that has enough.
+
+        The CPUs are removed from that node in place and returned as a list
+        of contiguous runs; an empty list means no node has `quota` CPUs left.
+        """
+        if quota < 1:
+            raise ValueError("quota should be no less than 1")
+        ret = []
+        for node in nodes:
+            if len(node) < quota:
+                continue
+            split_index = []
+            for i in range(1, quota):
+                if node[i] != node[i-1] + 1:
+                    split_index.append(i)
+            quota_bck = quota
+            last_idx = 0
+            for idx in split_index:
+                ret.append(node[last_idx:idx])
+                quota -= idx - last_idx
+                last_idx = idx
+            ret.append(node[last_idx:last_idx+quota])
+            # the CPUs handed out above are no longer available on this node
+            del node[:quota_bck]
+            return ret
+        return ret
+
+    def _get_quota(nodes, local_size):
+        """Return the per-process CPU quotas, root process last."""
+        # reduce needs an explicit initial value: without it the running
+        # total (an int) would be passed to len() from the third node on
+        cpu_nums = reduce(lambda total, node: total + len(node), nodes, 0)
+
+        # default quota is the number of cpus for non-root processes
+        default_quota = int(os.getenv("BYTEPS_NUMA_DEFAULT_QUOTA", 6))
+        while default_quota >= 1 and default_quota * local_size > cpu_nums:
+            default_quota -= 2
+
+        # root quota is the number of cpus for the root process
+        # root does more work, thus using more cpus
+        root_quota = cpu_nums - default_quota * (local_size - 1)
+        if int(os.getenv("BYTEPS_NUMA_ROOT_QUOTA", 0)):
+            root_quota = int(os.getenv("BYTEPS_NUMA_ROOT_QUOTA", 0))
+
+        node_size = len(nodes[0])
+        while root_quota >= 1 and root_quota > node_size:
+            root_quota -= 2
+        return [default_quota] * (local_size - 1) + [root_quota]
+
+    nodes = get_numa_info()
+    if not nodes:
+        return None
+    quota_list = _get_quota(nodes, local_size)
+    ret = []
+    for quota in quota_list:
+        allocation = None
+        while quota > 0:
+            allocation = _get_allocation(nodes, quota)
+            if allocation:
+                break
+            quota -= 2
+        # always append, so ret[i] stays aligned with local rank i even when
+        # nothing could be reserved for an earlier process
+        ret.append(allocation or None)
+
+    return ret
 
 
 def check_env():
@@ -55,7 +162,7 @@ def check_env():
             os._exit(0)
 
 
-def worker(local_rank, local_size, command):
+def worker(local_rank, local_size, command, allocation=None):
     my_env = os.environ.copy()
     my_env["BYTEPS_LOCAL_RANK"] = str(local_rank)
     my_env["BYTEPS_LOCAL_SIZE"] = str(local_size)
@@ -64,6 +171,27 @@ def worker(local_rank, local_size, command):
             command = "python " + command
         command = "gdb -ex 'run' -ex 'bt' -batch --args " + command
 
+    if allocation:
+        print("enable NUMA finetune...")
+        # subprocess.DEVNULL needs Python 3.3+; os.devnull works everywhere
+        with open(os.devnull, "w") as devnull:
+            try:
+                retval = subprocess.call(
+                    ["dpkg", "-s", "numactl"],
+                    stdout=devnull, stderr=subprocess.STDOUT)
+            except OSError:
+                # dpkg itself could not be executed; treat as "not found"
+                retval = 1
+        if retval == 0:
+            # pin the process to the allocated CPU ranges, e.g. "0-5,12-13"
+            cpu_ranges = ",".join(
+                "{}-{}".format(cpu_set[0], cpu_set[-1])
+                for cpu_set in allocation)
+            command = "numactl --physcpubind " + cpu_ranges + " " + command
+            print("Command: %s\n" % command)
+        else:
+            print("Warning: numactl not found. try `sudo apt-get install numactl`.")
+
     if os.environ.get("BYTEPS_TRACE_ON", "") == "1":
         print("\n!!!Enable profiling for WORKER_ID: %s and local_rank: %d!!!" %
               (os.environ.get("DMLC_WORKER_ID"), local_rank))
@@ -89,10 +217,19 @@ def launch_bps():
         else:
             local_size = 1
         t = [None] * local_size
+
+        # one entry per local rank; None means "do not pin this worker"
+        allocations = None
+        if os.environ.get("BYTEPS_NUMA_ON", "") == "1":
+            allocations = allocate_cpu(local_size)
+        if not allocations:
+            # NUMA tuning disabled, or no NUMA information available
+            allocations = [None] * local_size
+
         for i in range(local_size):
             command = ' '.join(sys.argv[1:])
             t[i] = PropagatingThread(target=worker, args=[
-                i, local_size, command])
+                i, local_size, command, allocations[i]])
             t[i].daemon = True
             t[i].start()
 