Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 136 additions & 158 deletions projects/hipblaslt/tensilelite/Tensile/Common/Parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,58 @@
#
################################################################################

import concurrent.futures
import itertools
import multiprocessing
import os
import re
import sys
import time

from joblib import Parallel, delayed
from functools import partial
from typing import Any, Callable

from .Utilities import tqdm


def joblibParallelSupportsGenerator():
import joblib
from packaging.version import Version
def get_inherited_job_limit() -> int:
    """Return a parallel-job limit inherited from the build system.

    Checks, in order:
      1. CMAKE_BUILD_PARALLEL_LEVEL (respected by CMake 3.12+).
      2. A ``-jN`` flag inside MAKEFLAGS (GNU make sub-make invocations).

    Returns:
        The inherited job count as a positive int, or -1 when no usable
        limit is found (unset, non-numeric, or non-positive values).
    """
    # 1. CMAKE_BUILD_PARALLEL_LEVEL: empty, non-numeric, or non-positive
    # values are ignored rather than propagated as a bogus limit.
    level = os.environ.get('CMAKE_BUILD_PARALLEL_LEVEL')
    if level is not None:
        try:
            jobs = int(level)
            if jobs > 0:
                return jobs
        except ValueError:
            pass

    # 2. MAKEFLAGS: GNU make exports its flags (e.g. "-j16 --jobserver-auth=...")
    # to sub-processes; pick up an explicit -jN if present.
    # NOTE(review): make may pass a bare "-j" (unlimited) or rely solely on the
    # jobserver, and generators like ninja set no MAKEFLAGS at all — in those
    # cases no numeric limit is recoverable here and we fall through.
    match = re.search(r'-j\s*(\d+)', os.environ.get('MAKEFLAGS', ''))
    if match:
        jobs = int(match.group(1))
        if jobs > 0:
            return jobs

    # No inherited limit found.
    return -1

def CPUThreadCount(enable=True):
    """Return the number of worker processes to use for parallel operations.

    Priority order for the count:
      1. A limit inherited from the build system (CMAKE_BUILD_PARALLEL_LEVEL
         or a ``-jN`` in MAKEFLAGS), so nested builds do not oversubscribe.
      2. The explicit globalParameters["CpuThreads"] setting (--jobs flag).
      3. Auto-detection from the current CPU affinity / core count.

    Args:
        enable: When False, parallelism is disabled and 1 is returned.

    Returns:
        A positive int; clamped to 61 on Windows.
    """
    if not enable:
        return 1

    from .GlobalParameters import globalParameters

    inherited_limit = get_inherited_job_limit()
    cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"]

    if cpuThreads < 1:
        # Auto-detect. sched_getaffinity respects taskset/cgroup restrictions
        # but is unavailable on Windows, where we fall back to cpu_count().
        if os.name == "nt":
            cpuThreads = os.cpu_count()
        else:
            cpuThreads = len(os.sched_getaffinity(0))

    if os.name == "nt":
        # Windows supports at most 61 workers because the scheduler uses
        # WaitForMultipleObjects directly, which has the limit (the limit
        # is actually 64, but some handles are needed for accounting).
        cpuThreads = min(cpuThreads, 61)
    return max(1, cpuThreads)


def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters):
Expand All @@ -71,19 +86,15 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters):
return f(arg)


def apply_print_exception(item, *args):
# print(item, args)
def worker_function(args, function, multiArg):
"""Worker function that executes in the pool process."""
try:
if len(args) > 0:
func = item
args = args[0]
return func(*args)
if multiArg:
return function(*args)
else:
func, item = item
return func(item)
return function(args)
except Exception:
import traceback

traceback.print_exc()
raise
finally:
Expand All @@ -98,154 +109,121 @@ def OverwriteGlobalParameters(newGlobalParameters):
GlobalParameters.globalParameters.update(newGlobalParameters)


def ProcessingPool(enable=True, maxTasksPerChild=None):
import multiprocessing
import multiprocessing.dummy

threadCount = CPUThreadCount()

if (not enable) or threadCount <= 1:
return multiprocessing.dummy.Pool(1)

if multiprocessing.get_start_method() == "spawn":
from . import GlobalParameters

return multiprocessing.Pool(
threadCount,
initializer=OverwriteGlobalParameters,
maxtasksperchild=maxTasksPerChild,
initargs=(GlobalParameters.globalParameters,),
)
else:
return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild)
def progress_logger(iterable, total, message, min_log_interval=5.0):
    """
    Generator that wraps an iterable and logs progress with time-based throttling.

    A progress line is printed only when at least ``min_log_interval`` seconds
    have passed since the previous log, and the clock is consulted only every
    ~1% of ``total`` items to keep per-item overhead low.  The completion
    message is printed only if the task took >= min_log_interval seconds, or
    if at least one progress line was already printed.

    Yields:
        (index, item) tuples for each element of ``iterable``.
    """
    start_time = time.time()
    last_log_time = start_time
    # Check the clock roughly every 1% of items (and at least every item).
    check_stride = 1 + (total // 100)
    # Explicit counter: avoids the fragile "'idx' in locals()" introspection to
    # detect whether the loop ran at all (empty iterable -> completed == 0).
    completed = 0

    for idx, item in enumerate(iterable):
        completed = idx + 1
        if idx % check_stride == 0:
            current_time = time.time()
            if (current_time - last_log_time) >= min_log_interval:
                print(f"{message}\t{idx+1: 5d}/{total: 5d}")
                last_log_time = current_time
        yield idx, item

    elapsed = time.time() - start_time
    # Only announce completion for long tasks or when progress was logged.
    if elapsed >= min_log_interval or last_log_time > start_time:
        print(f"{message} done in {elapsed:.1f}s!\t{completed: 5d}/{total: 5d}")


def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True):
from .GlobalParameters import globalParameters
def imap_with_progress(pool, func, iterable, total, message, chunksize):
    """Run ``pool.imap`` through the progress logger and collect the results.

    Preserves input order (imap, not imap_unordered) and returns a list.
    """
    stream = pool.imap(func, iterable, chunksize=chunksize)
    return [item for _, item in progress_logger(stream, total, message)]

threadCount = CPUThreadCount(enable)
print("{0}Launching {1} threads...".format(message, threadCount))

if threadCount <= 1 and globalParameters["ShowProgressBar"]:
# Provide a progress bar for single-threaded operation.
callFunc = lambda args: function(*args) if multiArg else lambda args: function(args)
return [callFunc(args) for args in tqdm(objects, message)]
def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild):
    """Stream unordered results from a process pool as they complete.

    Kept as a separate function because a ``yield`` anywhere makes the whole
    enclosing function a generator, even on code paths that never reach it.
    """
    start_method = 'forkserver' if os.name != 'nt' else 'spawn'
    ctx = multiprocessing.get_context(start_method)
    pool_kwargs = dict(
        processes=threadCount,
        maxtasksperchild=maxtasksperchild,
        initializer=OverwriteGlobalParameters,
        initargs=(globalParameters,),
    )
    with ctx.Pool(**pool_kwargs) as pool:
        completions = pool.imap_unordered(worker, objects, chunksize=chunksize)
        for _, item in progress_logger(completions, objLen, message):
            yield item


def ParallelMap2(
    function: Callable,
    objects: Any,
    message: str = "",
    enable: bool = True,
    multiArg: bool = True,
    minChunkSize: int = 1,
    maxWorkers: int = -1,
    maxtasksperchild: int = 1024,
    return_as: str = "list"
):
    """Executes a function over a list of objects in parallel or sequentially.

    Generally equivalent to ``list(map(function, objects))``, but may fan the
    work out over a multiprocessing pool depending on 'enable' and the
    available CPU threads.

    Args:
        function: The function to apply to each item in 'objects'. If
            'multiArg' is True, 'function' should accept multiple arguments.
        objects: An iterable of objects to be processed by 'function'. If
            'multiArg' is True, each item should be an iterable of arguments
            for 'function'.
        message: Optional; a message describing the operation.
        enable: Optional; if False, disables parallel execution.
        multiArg: Optional; if True, unpacks each item as *args for 'function'.
        minChunkSize: Optional; lower bound for the pool chunk size.
        maxWorkers: Optional; if > 0, caps the worker count.
        maxtasksperchild: Optional; recycle each worker process after this many
            tasks (bounds per-worker memory growth).
        return_as: Optional; "list" (default) or "generator_unordered" for
            streaming results as they complete.

    Returns:
        A list or generator containing the results of applying **function** to
        each item in **objects**.
    """
    from .GlobalParameters import globalParameters

    threadCount = CPUThreadCount(enable)

    # len() is needed below; materialize lazy iterables exactly once.
    if not hasattr(objects, "__len__"):
        objects = list(objects)

    objLen = len(objects)
    if objLen == 0:
        return [] if return_as == "list" else iter([])

    f = (lambda x: function(*x)) if multiArg else function

    # A single task is never worth the pool start-up cost.
    if objLen == 1:
        print(f"{message}: (1 task)")
        result = [f(x) for x in objects]
        return result if return_as == "list" else iter(result)

    # objLen >= 2 is guaranteed here, so the task count is always shown; the
    # previous "if objLen else ''" conditional was dead (and, due to operator
    # precedence, grouped as "(A + B) if objLen else ''" anyway).
    extra_message = f": {threadCount} thread(s), {objLen} tasks"
    print(f"ParallelMap {message}{extra_message}")

    if threadCount <= 1:
        result = [f(x) for x in objects]
        return result if return_as == "list" else iter(result)

    if maxWorkers > 0:
        threadCount = min(maxWorkers, threadCount)

    # Scale chunk size with the workload to amortize per-task IPC overhead.
    chunksize = max(minChunkSize, objLen // 2000)
    worker = partial(worker_function, function=function, multiArg=multiArg)

    if return_as == "generator_unordered":
        # Yield results as they complete without buffering the full list.
        return _ParallelMap_generator(worker, objects, objLen, message, chunksize,
                                      threadCount, globalParameters, maxtasksperchild)

    # forkserver on POSIX, spawn on Windows (the only available method there).
    ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn')
    with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild,
                  initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool:
        return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize))
17 changes: 15 additions & 2 deletions projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import functools
import math
import operator
import os
import sys
import time
Expand Down Expand Up @@ -269,8 +270,20 @@ def state(obj):


def state_key_ordering(cls):
def tup(obj):
return tuple([getattr(obj, k) for k in cls.StateKeys])
# Use operator.attrgetter for efficiency if __slots__ is defined
if hasattr(cls, '__slots__'):
# attrgetter is faster for slotted classes
getter = operator.attrgetter(*cls.StateKeys)
if len(cls.StateKeys) == 1:
# attrgetter returns scalar for single key, we need tuple
def tup(obj):
return (getter(obj),)
else:
tup = getter
else:
# Fallback for regular classes
def tup(obj):
return tuple([getattr(obj, k) for k in cls.StateKeys])

def lt(a, b):
return tup(a) < tup(b)
Expand Down
Loading