diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/Parallel.py b/projects/hipblaslt/tensilelite/Tensile/Common/Parallel.py index 1a2bf9e1190a..f2cfed554e46 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/Parallel.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/Parallel.py @@ -22,43 +22,58 @@ # ################################################################################ -import concurrent.futures -import itertools +import multiprocessing import os +import re import sys import time - -from joblib import Parallel, delayed +from functools import partial +from typing import Any, Callable from .Utilities import tqdm -def joblibParallelSupportsGenerator(): - import joblib - from packaging.version import Version +def get_inherited_job_limit() -> int: + # 1. Check CMAKE_BUILD_PARALLEL_LEVEL (CMake 3.12+) + if 'CMAKE_BUILD_PARALLEL_LEVEL' in os.environ: + try: + return int(os.environ['CMAKE_BUILD_PARALLEL_LEVEL']) + except ValueError: + pass - joblibVer = joblib.__version__ - return Version(joblibVer) >= Version("1.4.0") + # 2. Parse MAKEFLAGS for -jN + makeflags = os.environ.get('MAKEFLAGS', '') + match = re.search(r'-j\s*(\d+)', makeflags) + if match: + return int(match.group(1)) + return -1 -def CPUThreadCount(enable=True): - from .GlobalParameters import globalParameters +def CPUThreadCount(enable=True): if not enable: return 1 - else: + from .GlobalParameters import globalParameters + + # Priority order: + # 1. Inherited from build system (CMAKE_BUILD_PARALLEL_LEVEL or MAKEFLAGS) + # 2. Explicit --jobs flag + # 3. Auto-detect + inherited_limit = get_inherited_job_limit() + cpuThreads = inherited_limit if inherited_limit > 0 else globalParameters["CpuThreads"] + + if cpuThreads < 1: if os.name == "nt": - # Windows supports at most 61 workers because the scheduler uses - # WaitForMultipleObjects directly, which has the limit (the limit - # is actually 64, but some handles are needed for accounting). 
- cpu_count = min(os.cpu_count(), 61) + cpuThreads = os.cpu_count() else: - cpu_count = len(os.sched_getaffinity(0)) - cpuThreads = globalParameters["CpuThreads"] - if cpuThreads == -1: - return cpu_count + cpuThreads = len(os.sched_getaffinity(0)) - return min(cpu_count, cpuThreads) + if os.name == "nt": + # Windows supports at most 61 workers because the scheduler uses + # WaitForMultipleObjects directly, which has the limit (the limit + # is actually 64, but some handles are needed for accounting). + cpuThreads = min(cpuThreads, 61) + return max(1, cpuThreads) def pcallWithGlobalParamsMultiArg(f, args, newGlobalParameters): @@ -71,19 +86,15 @@ def pcallWithGlobalParamsSingleArg(f, arg, newGlobalParameters): return f(arg) -def apply_print_exception(item, *args): - # print(item, args) +def worker_function(args, function, multiArg): + """Worker function that executes in the pool process.""" try: - if len(args) > 0: - func = item - args = args[0] - return func(*args) + if multiArg: + return function(*args) else: - func, item = item - return func(item) + return function(args) except Exception: import traceback - traceback.print_exc() raise finally: @@ -98,154 +109,121 @@ def OverwriteGlobalParameters(newGlobalParameters): GlobalParameters.globalParameters.update(newGlobalParameters) -def ProcessingPool(enable=True, maxTasksPerChild=None): - import multiprocessing - import multiprocessing.dummy - - threadCount = CPUThreadCount() - - if (not enable) or threadCount <= 1: - return multiprocessing.dummy.Pool(1) - - if multiprocessing.get_start_method() == "spawn": - from . 
import GlobalParameters - - return multiprocessing.Pool( - threadCount, - initializer=OverwriteGlobalParameters, - maxtasksperchild=maxTasksPerChild, - initargs=(GlobalParameters.globalParameters,), - ) - else: - return multiprocessing.Pool(threadCount, maxtasksperchild=maxTasksPerChild) +def progress_logger(iterable, total, message, min_log_interval=5.0): + """ + Generator that wraps an iterable and logs progress with time-based throttling. + Only logs progress if at least min_log_interval seconds have passed since last log. + Only prints completion message if task took >= min_log_interval seconds. -def ParallelMap(function, objects, message="", enable=True, method=None, maxTasksPerChild=None): + Yields (index, item) tuples. """ - Generally equivalent to list(map(function, objects)), possibly executing in parallel. - - message: A message describing the operation to be performed. - enable: May be set to false to disable parallelism. - method: A function which can fetch the mapping function from a processing pool object. - Leave blank to use .map(), other possiblities: - - `lambda x: x.starmap` - useful if `function` takes multiple parameters. - - `lambda x: x.imap` - lazy evaluation - - `lambda x: x.imap_unordered` - lazy evaluation, does not preserve order of return value. - """ - from .GlobalParameters import globalParameters + start_time = time.time() + last_log_time = start_time + log_interval = 1 + (total // 100) - threadCount = CPUThreadCount(enable) - pool = ProcessingPool(enable, maxTasksPerChild) - - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. - # This works for method=None, and for starmap. - mapFunc = map - if method is not None: - # itertools provides starmap which can fill in for pool.starmap. It provides imap on Python 2.7. - # If this works, we will use it, otherwise we will fallback to the "dummy" pool for single threaded - # operation. 
- try: - mapFunc = method(itertools) - except NameError: - mapFunc = None - - if mapFunc is not None: - return list(mapFunc(function, tqdm(objects, message))) - - mapFunc = pool.map - if method: - mapFunc = method(pool) - - objects = zip(itertools.repeat(function), objects) - function = apply_print_exception - - countMessage = "" - try: - countMessage = " for {} tasks".format(len(objects)) - except TypeError: - pass + for idx, item in enumerate(iterable): + if idx % log_interval == 0: + current_time = time.time() + if (current_time - last_log_time) >= min_log_interval: + print(f"{message}\t{idx+1: 5d}/{total: 5d}") + last_log_time = current_time + yield idx, item - if message != "": - message += ": " + elapsed = time.time() - start_time + final_idx = idx + 1 if 'idx' in locals() else 0 - print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) - sys.stdout.flush() - currentTime = time.time() - rv = mapFunc(function, objects) - totalTime = time.time() - currentTime - print("{0}Done. ({1:.1f} secs elapsed)".format(message, totalTime)) - sys.stdout.flush() - pool.close() - return rv + if elapsed >= min_log_interval or last_log_time > start_time: + print(f"{message} done in {elapsed:.1f}s!\t{final_idx: 5d}/{total: 5d}") -def ParallelMapReturnAsGenerator(function, objects, message="", enable=True, multiArg=True): - from .GlobalParameters import globalParameters +def imap_with_progress(pool, func, iterable, total, message, chunksize): + results = [] + for _, result in progress_logger(pool.imap(func, iterable, chunksize=chunksize), total, message): + results.append(result) + return results - threadCount = CPUThreadCount(enable) - print("{0}Launching {1} threads...".format(message, threadCount)) - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. 
- callFunc = lambda args: function(*args) if multiArg else lambda args: function(args) - return [callFunc(args) for args in tqdm(objects, message)] +def _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild): + # separate fn because yield makes the entire fn a generator even if unreachable + ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') - with concurrent.futures.ProcessPoolExecutor(max_workers=threadCount) as executor: - resultFutures = (executor.submit(function, *arg if multiArg else arg) for arg in objects) - for result in concurrent.futures.as_completed(resultFutures): - yield result.result() + with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, + initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: + for _, result in progress_logger(pool.imap_unordered(worker, objects, chunksize=chunksize), objLen, message): + yield result def ParallelMap2( - function, objects, message="", enable=True, multiArg=True, return_as="list", procs=None + function: Callable, + objects: Any, + message: str = "", + enable: bool = True, + multiArg: bool = True, + minChunkSize: int = 1, + maxWorkers: int = -1, + maxtasksperchild: int = 1024, + return_as: str = "list" ): + """Executes a function over a list of objects in parallel or sequentially. + + This function is generally equivalent to ``list(map(function, objects))``. However, it provides + additional functionality to run in parallel, depending on the 'enable' flag and available CPU + threads. + + Args: + function: The function to apply to each item in 'objects'. If 'multiArg' is True, 'function' + should accept multiple arguments. + objects: An iterable of objects to be processed by 'function'. If 'multiArg' is True, each + item in 'objects' should be an iterable of arguments for 'function'. + message: Optional; a message describing the operation. Default is an empty string. 
+ enable: Optional; if False, disables parallel execution and runs sequentially. Default is True. + multiArg: Optional; if True, treats each item in 'objects' as multiple arguments for + 'function'. Default is True. + return_as: Optional; "list" (default) or "generator_unordered" for streaming results + + Returns: + A list or generator containing the results of applying **function** to each item in **objects**. """ - Generally equivalent to list(map(function, objects)), possibly executing in parallel. + from .GlobalParameters import globalParameters - message: A message describing the operation to be performed. - enable: May be set to false to disable parallelism. - multiArg: True if objects represent multiple arguments - (differentiates multi args vs single collection arg) - """ - if return_as in ("generator", "generator_unordered") and not joblibParallelSupportsGenerator(): - return ParallelMapReturnAsGenerator(function, objects, message, enable, multiArg) + threadCount = CPUThreadCount(enable) - from .GlobalParameters import globalParameters + if not hasattr(objects, "__len__"): + objects = list(objects) - threadCount = procs if procs else CPUThreadCount(enable) + objLen = len(objects) + if objLen == 0: + return [] if return_as == "list" else iter([]) - threadCount = CPUThreadCount(enable) + f = (lambda x: function(*x)) if multiArg else function + if objLen == 1: + print(f"{message}: (1 task)") + result = [f(x) for x in objects] + return result if return_as == "list" else iter(result) - if threadCount <= 1 and globalParameters["ShowProgressBar"]: - # Provide a progress bar for single-threaded operation. 
- return [function(*args) if multiArg else function(args) for args in tqdm(objects, message)] + extra_message = ( + f": {threadCount} thread(s)" + f", {objLen} tasks" + if objLen + else "" + ) - countMessage = "" - try: - countMessage = " for {} tasks".format(len(objects)) - except TypeError: - pass - - if message != "": - message += ": " - print("{0}Launching {1} threads{2}...".format(message, threadCount, countMessage)) - sys.stdout.flush() - currentTime = time.time() - - pcall = pcallWithGlobalParamsMultiArg if multiArg else pcallWithGlobalParamsSingleArg - pargs = zip(objects, itertools.repeat(globalParameters)) - - if joblibParallelSupportsGenerator(): - rv = Parallel(n_jobs=threadCount, timeout=99999, return_as=return_as)( - delayed(pcall)(function, a, params) for a, params in pargs - ) + print(f"ParallelMap {message}{extra_message}") + + if threadCount <= 1: + result = [f(x) for x in objects] + return result if return_as == "list" else iter(result) + + if maxWorkers > 0: + threadCount = min(maxWorkers, threadCount) + + chunksize = max(minChunkSize, objLen // 2000) + worker = partial(worker_function, function=function, multiArg=multiArg) + if return_as == "generator_unordered": + # yield results as they complete without buffering + return _ParallelMap_generator(worker, objects, objLen, message, chunksize, threadCount, globalParameters, maxtasksperchild) else: - rv = Parallel(n_jobs=threadCount, timeout=99999)( - delayed(pcall)(function, a, params) for a, params in pargs - ) - - totalTime = time.time() - currentTime - print("{0}Done. 
({1:.1f} secs elapsed)".format(message, totalTime)) - sys.stdout.flush() - return rv + ctx = multiprocessing.get_context('forkserver' if os.name != 'nt' else 'spawn') + with ctx.Pool(processes=threadCount, maxtasksperchild=maxtasksperchild, + initializer=OverwriteGlobalParameters, initargs=(globalParameters,)) as pool: + return list(imap_with_progress(pool, worker, objects, objLen, message, chunksize)) diff --git a/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py b/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py index 11ce15e4ff1f..bf03af25a7b1 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py +++ b/projects/hipblaslt/tensilelite/Tensile/Common/Utilities.py @@ -24,6 +24,7 @@ import functools import math +import operator import os import sys import time @@ -269,8 +270,20 @@ def state(obj): def state_key_ordering(cls): - def tup(obj): - return tuple([getattr(obj, k) for k in cls.StateKeys]) + # Use operator.attrgetter for efficiency if __slots__ is defined + if hasattr(cls, '__slots__'): + # attrgetter is faster for slotted classes + getter = operator.attrgetter(*cls.StateKeys) + if len(cls.StateKeys) == 1: + # attrgetter returns scalar for single key, we need tuple + def tup(obj): + return (getter(obj),) + else: + tup = getter + else: + # Fallback for regular classes + def tup(obj): + return tuple([getattr(obj, k) for k in cls.StateKeys]) def lt(a, b): return tup(a) < tup(b) diff --git a/projects/hipblaslt/tensilelite/Tensile/Contractions.py b/projects/hipblaslt/tensilelite/Tensile/Contractions.py index 0c8d8aa6cd0f..32f2841fc2b6 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Contractions.py +++ b/projects/hipblaslt/tensilelite/Tensile/Contractions.py @@ -37,9 +37,60 @@ from math import ceil MIN_K_FOR_GSU = 32 + +# Interning helpers to reduce memory usage by reusing identical objects +_free_index_cache = {} +def intern_free_index(isA, i=None, c=None, d=None, a=None, b=None): + key = (isA, i, c, d, a, b) + if key not in 
_free_index_cache: + obj = FreeIndex(isA, i, c, d) + obj.a = a + if b is not None: + obj.b = b + _free_index_cache[key] = obj + return _free_index_cache[key] + +_batch_index_cache = {} +def intern_batch_index(a=None, b=None, c=None, d=None): + key = (a, b, c, d) + if key not in _batch_index_cache: + obj = BatchIndex(c=c, d=d) + obj.a = a + obj.b = b + _batch_index_cache[key] = obj + return _batch_index_cache[key] + +_bound_index_cache = {} +def intern_bound_index(a=None, b=None, aMirror=False, bMirror=False): + key = (a, b, aMirror, bMirror) + if key not in _bound_index_cache: + obj = BoundIndex(aMirror=aMirror, bMirror=bMirror) + obj.a = a + obj.b = b + _bound_index_cache[key] = obj + return _bound_index_cache[key] + +_size_mapping_cache = {} +def intern_size_mapping(size_mapping): + """Intern a SizeMapping instance to reduce redundancy.""" + # Build hashable key from StateKeys, converting lists to tuples + key_parts = [] + for attr in size_mapping.StateKeys: + val = getattr(size_mapping, attr) + # Convert lists to tuples for hashing + if isinstance(val, list): + val = tuple(val) + key_parts.append(val) + key = tuple(key_parts) + + if key not in _size_mapping_cache: + _size_mapping_cache[key] = size_mapping + return _size_mapping_cache[key] + @state_key_ordering class FreeIndex: StateKeys = ['isA', 'i', 'c', 'd'] + __slots__ = ['isA', 'i', 'c', 'd', 'a', 'b'] def __init__(self, isA, i=None, c=None, d=None): self.isA = isA @@ -50,6 +101,7 @@ def __init__(self, isA, i=None, c=None, d=None): @state_key_ordering class BatchIndex: StateKeys = ['a', 'b', 'c', 'd'] + __slots__ = ['a', 'b', 'c', 'd'] def __init__(self, a=None, b=None, c=None, d=None): self.a = a self.b = b @@ -59,6 +111,7 @@ def __init__(self, a=None, b=None, c=None, d=None): @state_key_ordering class BoundIndex: StateKeys = ['a', 'b', 'aMirror', 'bMirror'] + __slots__ = ['a', 'b', 'aMirror', 'bMirror'] def __init__(self, a=None, b=None, aMirror=False, bMirror=False): self.a = a self.b = b @@ -107,6 
+160,23 @@ def FromOriginalState(cls, d): for ib, ic in enumerate(d['IndexAssignmentsB']): indices[ic].b = ib + # Now intern all indices with their final state (including .a and .b) + for i, idx in enumerate(indices): + if isinstance(idx, FreeIndex): + indices[i] = intern_free_index(idx.isA, idx.i, idx.c, idx.d, + getattr(idx, 'a', None), getattr(idx, 'b', None)) + elif isinstance(idx, BatchIndex): + indices[i] = intern_batch_index(getattr(idx, 'a', None), getattr(idx, 'b', None), + idx.c, idx.d) + elif isinstance(idx, BoundIndex): + indices[i] = intern_bound_index(getattr(idx, 'a', None), getattr(idx, 'b', None), + idx.aMirror, idx.bMirror) + + # Update the lists with interned versions + freeIndices = [idx for idx in indices if isinstance(idx, FreeIndex)] + batchIndices = [idx for idx in indices if isinstance(idx, BatchIndex)] + boundIndices = [idx for idx in indices if isinstance(idx, BoundIndex)] + for idx in indices: assert idx is not None idxState = state(idx) @@ -596,6 +666,7 @@ class SizeMapping: 'nonTemporalA', 'nonTemporalB', ] + __slots__ = StateKeys @classmethod def FromOriginalState(cls, d): @@ -706,7 +777,7 @@ def FromSolutionStruct( assembler: Assembler, isaInfoMap: Dict[str, IsaInfo] ): - return cls.FromOriginalState( + rv = cls.FromOriginalState( solution._state, splitGSU, printSolutionRejectionReason, @@ -715,6 +786,9 @@ def FromSolutionStruct( isaInfoMap, solution.srcName ) + # Store reference to original solution instead of creating duplicate + rv.originalSolution = solution + return rv @classmethod def FromOriginalState( @@ -751,7 +825,7 @@ def FromOriginalState( info = cls.ReadOriginalInfo(d) rv.libraryLogicIndex = int(info.get("SolutionIndex", -1)) - rv.sizeMapping = SizeMapping.FromOriginalState(d) + rv.sizeMapping = intern_size_mapping(SizeMapping.FromOriginalState(d)) rv.internalArgsSupport = InternalArgsSupport.FromOriginalState(d) @@ -788,7 +862,7 @@ def FromOriginalState( @classmethod def ReadOriginalInfo(cls, d): - return dict([(key, 
str(value)) for (key, value) in list(d.items()) if key != 'ProblemType']) + return {key: str(value) for key, value in d.items() if key != 'ProblemType'} def __init__(self, **kwargs): self.name = None diff --git a/projects/hipblaslt/tensilelite/Tensile/CustomKernels.py b/projects/hipblaslt/tensilelite/Tensile/CustomKernels.py index ffceb636f5e4..127b3386a1d7 100644 --- a/projects/hipblaslt/tensilelite/Tensile/CustomKernels.py +++ b/projects/hipblaslt/tensilelite/Tensile/CustomKernels.py @@ -24,7 +24,9 @@ from . import CUSTOM_KERNEL_PATH from Tensile.Common.ValidParameters import checkParametersAreValid, validParameters, newMIValidParameters +from Tensile.CustomYamlLoader import DEFAULT_YAML_LOADER +from functools import lru_cache import yaml import os @@ -58,10 +60,13 @@ def getCustomKernelConfigAndAssembly(name, directory=CUSTOM_KERNEL_PATH): return (config, assembly) +# getCustomKernelConfig will get called repeatedly on the same file +# 20x logic loading speedup for aquavanjaram_Cijk_Ailk_Bljk_F8NH_HHS_BH_Bias_HAS_SAB_SAV_freesize_custom_GSUs +@lru_cache def readCustomKernelConfig(name, directory=CUSTOM_KERNEL_PATH): rawConfig, _ = getCustomKernelConfigAndAssembly(name, directory) try: - return yaml.safe_load(rawConfig)["custom.config"] + return yaml.load(rawConfig, Loader=DEFAULT_YAML_LOADER)["custom.config"] except yaml.scanner.ScannerError as e: raise RuntimeError("Failed to read configuration for custom kernel: {0}\nDetails:\n{1}".format(name, e)) diff --git a/projects/hipblaslt/tensilelite/Tensile/CustomYamlLoader.py b/projects/hipblaslt/tensilelite/Tensile/CustomYamlLoader.py index e03f456fbecf..9fdf38d8e5d2 100644 --- a/projects/hipblaslt/tensilelite/Tensile/CustomYamlLoader.py +++ b/projects/hipblaslt/tensilelite/Tensile/CustomYamlLoader.py @@ -1,6 +1,7 @@ # Copyright © Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT +import sys import yaml from pathlib import Path @@ -46,6 +47,11 @@ def parse_mapping(loader: yaml.Loader): return ret def is_float(value): + if not value: + return False + first_char = value[0] + if not (first_char in '+-.' or first_char.isdigit()): + return False try: float(value) return True @@ -56,21 +62,31 @@ def parse_scalar(loader: yaml.Loader): assert loader.check_event(yaml.ScalarEvent) evt = loader.get_event() value: str = evt.value - value_lower: str = value.lower() - if value_lower in ('true', 'yes',): + if not value: + if not evt.style: + return None + return value + + first_char = value[0] + if first_char in '+-' or first_char.isdigit(): + stripped = value.lstrip('+-') + if stripped.isdigit(): + return int(value) + elif is_float(value): + return float(value) + + value_folded = value.casefold() + + if value_folded in ('true', 'yes'): return True - elif value_lower in ('false', 'no',): + elif value_folded in ('false', 'no'): return False - elif value_lower in ('null', '', '~'): + elif value_folded in ('null', '~'): if not evt.style: return None - elif value_lower.lstrip('+-').isnumeric(): - return int(value_lower) - elif is_float(value_lower): - return float(value_lower) - return value + return sys.intern(value) def load_yaml_stream(yaml_path: Path, loader_type: yaml.Loader): with open(yaml_path, 'r') as f: diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py index 99535e246650..d105a1f4eed4 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Naming.py @@ -30,22 +30,24 @@ from .Problem import ProblemType -def getKeyNoInternalArgs(state, splitGSU: bool): - state_copy = deepcopy(state) - state_copy["ProblemType"]["GroupedGemm"] = False - if splitGSU: - state_copy["GlobalSplitU"] = "M" if (state_copy["GlobalSplitU"] > 1 or state_copy["GlobalSplitU"] 
== -1) else state_copy["GlobalSplitU"] - elif state["GlobalSplitU"] != 0: - state_copy["GlobalSplitU"] = "M" - state_copy["WorkGroupMapping"] = "M" - state_copy["WorkGroupMappingXCC"] = "M" - state_copy["WorkGroupMappingXCCGroup"] = "M" - state_copy["StaggerU"] = "M" - state_copy["StaggerUStride"] = "M" - state_copy["StaggerUMapping"] = "M" - state_copy["GlobalSplitUCoalesced"] = "M" - state_copy["GlobalSplitUWorkGroupMappingRoundRobin"] = "M" - return state_copy +_INTERNAL_ARG_KEYS = frozenset([ + "GlobalSplitUCoalesced", "GlobalSplitUWorkGroupMappingRoundRobin", + "StaggerU", "StaggerUMapping", "StaggerUStride", + "WorkGroupMapping", "WorkGroupMappingXCC", "WorkGroupMappingXCCGroup" +]) + + +def getKeyNoInternalArgs(state, splitGSU: bool) -> str: + """ + Returns a string that uniquely identifies solutions that differ only in + "internal args" (WorkGroupMapping, StaggerU, GlobalSplitUCoalesced, etc). + + The name includes these params but with normalized values ("M"), so solutions + with different internal args get the same key. 
+ """ + # Previously this returned a deepcopy of the entire state dict with internal args set to "M" + # Now that _getName supports normalizing these to "M" we can avoid deepcopying the huge state dict + return _getName(state, getRequiredParametersFull(), splitGSU, ignoreInternalArgs=False, normalizeInternalArgs=True) @lru_cache(maxsize=None) @@ -54,7 +56,7 @@ def getParameterNameAbbreviation( name: str ): @ lru_cache(maxsize=None) -def getPrimitiveParameterValueAbbreviation(key, value): +def getPrimitiveParameterValueAbbreviation(value): if isinstance(value, str): return getParameterNameAbbreviation(value) elif isinstance(value, bool): @@ -81,7 +83,7 @@ def getParameterValueAbbreviation(key, value): return f"{value[0]}{value[1]}{value[2]:x}" compositieTypes = (dict, list, tuple,) if not isinstance(value, compositieTypes): - return getPrimitiveParameterValueAbbreviation(key, value) + return getPrimitiveParameterValueAbbreviation(value) elif isinstance(value, tuple): return ''.join(str(v) for v in value) elif isinstance(value, list): @@ -92,18 +94,29 @@ def getParameterValueAbbreviation(key, value): raise Exception(f"Parameter {key}={value} is new object type ({type(value)})") -def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInternalArgs): +def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInternalArgs, normalizeInternalArgs=False): + """ + Generate a solution/kernel name from state parameters. 
+ Args: + state: Solution state dict + requiredParameters: Set of parameter names to include in the name + splitGSU: Whether to handle GlobalSplitU specially + ignoreInternalArgs: If True, exclude internal args from name (kernel mode) + normalizeInternalArgs: If True, include internal args but set to "M" (deduplication mode) + """ if "CustomKernelName" in state and state["CustomKernelName"]: return state["CustomKernelName"] gsuBackup = state["GlobalSplitU"] ggBackup = state["ProblemType"]["GroupedGemm"] - if ignoreInternalArgs: + if ignoreInternalArgs or normalizeInternalArgs: state["ProblemType"]["GroupedGemm"] = False if splitGSU: state["GlobalSplitU"] = "M" if (state["GlobalSplitU"] > 1 or state["GlobalSplitU"] == -1) else state["GlobalSplitU"] + elif normalizeInternalArgs and state["GlobalSplitU"] != 0: + state["GlobalSplitU"] = "M" requiredParametersTemp = set(requiredParameters.union(["GlobalSplitU"])) @@ -138,6 +151,10 @@ def _getName(state, requiredParameters: frozenset, splitGSU: bool, ignoreInterna components.append('SN') for key in sorted(state.keys()): if key[0] != '_' and key != "CustomKernelName" and key in requiredParametersTemp: + # When normalizing, use "M" for internal args instead of actual value + if normalizeInternalArgs and key in _INTERNAL_ARG_KEYS: + components.append(f'{getParameterNameAbbreviation(key)}M') + else: components.append(f'{getParameterNameAbbreviation(key)}{getParameterValueAbbreviation(key, state[key])}') state["GlobalSplitU"] = gsuBackup diff --git a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py index 7bebb58e89aa..f365f16a6c6d 100644 --- a/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py +++ b/projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py @@ -3347,7 +3347,7 @@ def getAttributes(self): return self._state def __hash__(self): - return hash(str(self) + self._state.get("codeObjectFile", "")) + return 
hash(str(self)) #return hash(self.getAttributes()) def __eq__(self, other): diff --git a/projects/hipblaslt/tensilelite/Tensile/TensileCreateLibrary/Run.py b/projects/hipblaslt/tensilelite/Tensile/TensileCreateLibrary/Run.py index 835ed9c01916..86204b75558f 100644 --- a/projects/hipblaslt/tensilelite/Tensile/TensileCreateLibrary/Run.py +++ b/projects/hipblaslt/tensilelite/Tensile/TensileCreateLibrary/Run.py @@ -24,10 +24,13 @@ import rocisa +import collections import functools import glob +import gc import itertools import os +import resource import shutil from pathlib import Path from timeit import default_timer as timer @@ -78,12 +81,30 @@ from .ParseArguments import parseArguments +def getMemoryUsage(): + """Get peak and current memory usage in MB.""" + rusage = resource.getrusage(resource.RUSAGE_SELF) + peak_memory_mb = rusage.ru_maxrss / 1024 # KB to MB on Linux + + # Get current memory from /proc/self/status + current_memory_mb = 0 + try: + with open('/proc/self/status') as f: + for line in f: + if line.startswith('VmRSS:'): + current_memory_mb = int(line.split()[1]) / 1024 # KB to MB + break + except: + current_memory_mb = peak_memory_mb # Fallback + + return (peak_memory_mb, current_memory_mb) + + class KernelCodeGenResult(NamedTuple): err: int src: str header: Optional[str] name: str - targetObjFilename: str isa: IsaVersion wavefrontSize: int cuoccupancy: int @@ -106,15 +127,37 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC asmFilename = getKernelFileBase(splitGSU, kernel) err, src = kernelWriter.getSourceFileString(kernel) header = kernelWriter.getHeaderFileString(kernel) - objFilename = kernel._state.get("codeObjectFile", None) pgr = int(kernel["PrefetchGlobalRead"]) return KernelCodeGenResult( - err, src, header, asmFilename, objFilename, tuple(kernel["ISA"]), \ + err, src, header, asmFilename, tuple(kernel["ISA"]), \ kernel["WavefrontSize"], kernel["CUOccupancy"], \ pgr, kernel["MathClocksUnrolledLoop"] ) +def 
processAndAssembleKernelTCL(kernelWriterAssembly, rocisa_data, splitGSU, kernel, assemblyTmpPath, assembler): + """ + Pipeline function for TCL mode that: + 1. Generates kernel source + 2. Writes .s file to disk + 3. Assembles to .o file + 4. Deletes .s file + """ + result = processKernelSource(kernelWriterAssembly, rocisa_data, splitGSU, kernel) + return writeAndAssembleKernel(result, assemblyTmpPath, assembler) + + +def writeMasterSolutionLibrary(name_lib_tuple, newLibraryDir, splitGSU, libraryFormat): + """ + Write a master solution library to disk. + Module-level function to support multiprocessing. + """ + name, lib = name_lib_tuple + filename = os.path.join(newLibraryDir, name) + lib.applyNaming(splitGSU) + LibraryIO.write(filename, state(lib), libraryFormat) + + def removeInvalidSolutionsAndKernels(results, kernels, solutions, errorTolerant, printLevel: bool, splitGSU: bool): removeKernels = [] removeKernelNames = [] @@ -166,12 +209,12 @@ def passPostKernelInfoToSolution(results, kernels, solutions, splitGSU: bool): resultDict = {} for kernIdx, r in enumerate(results): kName = getKernelNameMin(kernels[kernIdx], splitGSU) - resultDict["%s"%kName] = r + resultDict[kName] = r for solution in solutions: solutionKernels = solution.getKernels() for kernel in solutionKernels: kName = getKernelNameMin(kernel, splitGSU) - result = resultDict["%s"%kName] + result = resultDict[kName] solution._state["CUOccupancy"] = result.cuoccupancy solution._state["PrefetchGlobalRead"] = result.pgr solution._state["MathClocksUnrolledLoop"] = result.mathclk @@ -189,6 +232,24 @@ def writeAssembly(asmPath: Union[Path, str], result: KernelCodeGenResult): return path, isa, wfsize, minResult +def writeAndAssembleKernel(result: KernelCodeGenResult, asmPath: Union[Path, str], assembler): + """Write assembly file and immediately assemble it to .o file""" + if result.err: + printExit(f"Failed to build kernel {result.name} because it has error code {result.err}") + + path = Path(asmPath) / 
f"{result.name}.s" + with open(path, "w", encoding="utf-8") as f: + f.write(result.src) + + # Assemble .s -> .o + assembler(isaToGfx(result.isa), result.wavefrontSize, str(path), str(path.with_suffix(".o"))) + + # Delete assembly file immediately to save disk space + path.unlink() + + return KernelMinResult(result.err, result.cuoccupancy, result.pgr, result.mathclk) + + def writeHelpers( outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H ): @@ -257,9 +318,8 @@ def writeSolutionsAndKernels( duplicates = 0 for k in asmKernels: base = getKernelFileBase(splitGSU, k) + k["BaseName"] = base k.duplicate = True if base in visited else False - if not k.duplicate: - k["BaseName"] = base duplicates += k.duplicate print2(f"Duplicate: {base}") visited.add(base) @@ -268,13 +328,14 @@ def writeSolutionsAndKernels( numAsmKernels = len(asmKernels) numKernels = len(asmKernels) assert numKernels == numAsmKernels, "Only assembly kernels are supported in TensileLite" - asmIter = zip( - itertools.repeat(kernelWriterAssembly), - itertools.repeat(rocisa.rocIsa.getInstance().getData()), - itertools.repeat(splitGSU), - asmKernels + + processKernelFn = functools.partial( + processKernelSource, + kernelWriterAssembly=kernelWriterAssembly, + data=rocisa.rocIsa.getInstance().getData(), + splitGSU=splitGSU ) - asmResults = ParallelMap2(processKernelSource, asmIter, "Generating assembly kernels", return_as="list") + asmResults = ParallelMap2(processKernelFn, asmKernels, "Generating assembly kernels", return_as="list", multiArg=False) removeInvalidSolutionsAndKernels( asmResults, asmKernels, solutions, errorTolerant, getVerbosity(), splitGSU ) @@ -282,19 +343,21 @@ def writeSolutionsAndKernels( asmResults, asmKernels, solutions, splitGSU ) - def assemble(ret): - p, isa, wavefrontsize, result = ret - asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) - - unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) - 
compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) + # Use functools.partial to bind assemblyTmpPath and assembler + writeAndAssembleFn = functools.partial( + writeAndAssembleKernel, + asmPath=assemblyTmpPath, + assembler=asmToolchain.assembler + ) ret = ParallelMap2( - compose(assemble, unaryWriteAssembly), + writeAndAssembleFn, asmResults, "Writing assembly kernels", return_as="list", multiArg=False, ) + del asmResults + gc.collect() writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H) srcKernelFile = Path(outputPath) / "Kernels.cpp" @@ -309,13 +372,21 @@ def assemble(ret): yappi.get_thread_stats().print_all(out=f) if not generateSourcesAndExit: + # Build cofile_objects from solutions for the old code path (no lazy loading) + # All solutions go into the default .co file (None) per architecture + cofile_objects = collections.defaultdict(lambda: collections.defaultdict(list)) + for solution in solutions: + isa = tuple(solution.getKernels()[0]['ISA']) + # Use a dummy index of 0 since we're not tracking solution indices in the old code path + cofile_objects[isa][None].append((0, solution)) + codeObjectFiles += buildAssemblyCodeObjectFiles( asmToolchain.linker, asmToolchain.bundler, - asmKernels, destLibPath, assemblyTmpPath, compress, + cofile_objects, ) buildSourceCodeObjectFiles( srcToolchain.compiler, @@ -339,6 +410,7 @@ def writeSolutionsAndKernelsTCL( kernelHelperObjs, kernelWriterAssembly, cmdlineArchs: List[str], + cofile_objects, compress=True, ): outputPath = Path(outputPath) @@ -367,41 +439,49 @@ def writeSolutionsAndKernelsTCL( visited.add(base) print1(f"Number of duplicate kernels: {duplicates}") - uniqueAsmKernels = [k for k in asmKernels if not k.duplicate] + # Also set BaseName on ALL solutions' kernels (including those with duplicate kernels) + # since solutions that weren't in the unique kernels list won't have BaseName set yet + # FIXME: this is jank + for solution in solutions: + for 
kernel in solution.getKernels(): + if "BaseName" not in kernel: + base = getKernelFileBase(splitGSU, kernel) + kernel["BaseName"] = base - def assemble(ret): - p, isa, wavefrontsize, result = ret - asmToolchain.assembler(isaToGfx(isa), wavefrontsize, str(p), str(p.with_suffix(".o"))) - return result + uniqueAsmKernels = [k for k in asmKernels if not k.duplicate] - unaryProcessKernelSource = functools.partial( - processKernelSource, + processKernelFn = functools.partial( + processAndAssembleKernelTCL, kernelWriterAssembly, rocisa.rocIsa.getInstance().getData(), splitGSU, + assemblyTmpPath=assemblyTmpPath, + assembler=asmToolchain.assembler ) - unaryWriteAssembly = functools.partial(writeAssembly, assemblyTmpPath) - compose = lambda *F: functools.reduce(lambda f, g: lambda x: f(g(x)), F) - ret = ParallelMap2( - compose(assemble, unaryWriteAssembly, unaryProcessKernelSource), + results = ParallelMap2( + processKernelFn, uniqueAsmKernels, "Generating assembly kernels", multiArg=False, return_as="list" ) + del processKernelFn + gc.collect() + passPostKernelInfoToSolution( - ret, uniqueAsmKernels, solutions, splitGSU + results, uniqueAsmKernels, solutions, splitGSU ) - # result.src is very large so let garbage collector know to clean up - del ret + del results + gc.collect() + buildAssemblyCodeObjectFiles( asmToolchain.linker, asmToolchain.bundler, - asmKernels, destLibPath, assemblyTmpPath, compress, + cofile_objects, ) writeHelpers(outputPath, kernelHelperObjs, KERNEL_HELPER_FILENAME_CPP, KERNEL_HELPER_FILENAME_H) @@ -493,41 +573,45 @@ def generateKernelHelperObjects(solutions: List[Solution], cxxCompiler: str, isa return sorted(khos, key=sortByEnum, reverse=True) # Ensure that we write Enum kernel helpers are first in list +def libraryIter(lib: MasterSolutionLibrary): + if len(lib.solutions): + for i, s in enumerate(lib.solutions.items()): + yield (i, *s) + else: + for _, lazyLib in lib.lazyLibraries.items(): + yield from libraryIter(lazyLib) + + @timing def 
generateLogicDataAndSolutions(logicFiles, args, assembler: Assembler, isaInfoMap): + # NB: Be careful with the two solution types + # Contractions.Solution's originalSolution field contains SolutionStruct.Solution + # masterLibrary.lazyLibraries[i].solutions and masterLibrary.solutions both contain + # Contraction Solutions, which then refer to a SolutionStructs.Solution in their originalSolution - if ";" in args["Architecture"]: - archs = args["Architecture"].split(";") # user arg list format - else: - archs = args["Architecture"].split("_") # workaround for cmake list in list issue - - solutions = [] - masterLibraries = {} + masterLibraries: dict[str, MasterSolutionLibrary] = {} nextSolIndex = 0 splitGSU = False printSolutionRejectionReason = True printIndexAssignmentInfo = False - fIter = zip( - logicFiles, - itertools.repeat(assembler), - itertools.repeat(splitGSU), - itertools.repeat(printSolutionRejectionReason), - itertools.repeat(printIndexAssignmentInfo), - itertools.repeat(isaInfoMap), - itertools.repeat(args["LazyLibraryLoading"]), + parseLogicFn = functools.partial( + LibraryIO.parseLibraryLogicFile, + assembler=assembler, + splitGSU=splitGSU, + printSolutionRejectionReason=printSolutionRejectionReason, + printIndexAssignmentInfo=printIndexAssignmentInfo, + isaInfoMap=isaInfoMap, + lazyLibraryLoading=args["LazyLibraryLoading"] ) - def libraryIter(lib: MasterSolutionLibrary): - if len(lib.solutions): - for i, s in enumerate(lib.solutions.items()): - yield (i, *s) - else: - for _, lazyLib in lib.lazyLibraries.items(): - yield from libraryIter(lazyLib) - for library in ParallelMap2( - LibraryIO.parseLibraryLogicFile, fIter, "Loading Logics...", return_as="generator_unordered" + parseLogicFn, logicFiles, "Loading Logics...", + return_as="generator_unordered", + minChunkSize=12, + maxWorkers=64, + maxtasksperchild=1, + multiArg=False, ): _, architectureName, _, _, _, newLibrary = library @@ -539,10 +623,13 @@ def libraryIter(lib: MasterSolutionLibrary): 
else: masterLibraries[architectureName] = newLibrary masterLibraries[architectureName].version = args["CodeObjectVersion"] + del library, newLibrary, _ + + gc.collect() # Sort masterLibraries to make global soln index values deterministic solnReIndex = 0 - masterLibraries = dict(sorted(masterLibraries.items())) + masterLibraries: dict[str, MasterSolutionLibrary] = dict(sorted(masterLibraries.items())) for _, masterLibrary in masterLibraries.items(): for _, sol in masterLibrary.solutions.items(): sol.index = solnReIndex @@ -557,6 +644,12 @@ def libraryIter(lib: MasterSolutionLibrary): } for _, sol in lib.solutions.items(): sol.index = solnReIndex + if "BaseName" in sol.originalSolution._state: + # FIXME: clearing BaseName here since it's often a) in logic yaml b) wrong + # gfx{1200,1201}_Cijk_Alik_Bljk_F8BS_BH_Bias_HA_S_SABV_SAV_UserArgs.yaml + # both have it set to the same value which is not what gets computed at runtime + # Maybe we can do something smarter below? + del sol.originalSolution._state["BaseName"] solnReIndex += 1 if args["GenSolTable"]: @@ -572,39 +665,94 @@ def libraryIter(lib: MasterSolutionLibrary): if key != "fallback": value.merge(masterLibraries["fallback"]) masterLibraries.pop("fallback") - solIndex = [] + + # Validate lazy loading invariant: solutions and lazyLibraries are mutually exclusive + lazyLoading = args["LazyLibraryLoading"] + for archName, masterLibrary in masterLibraries.items(): + hasSolutions = len(masterLibrary.solutions) > 0 + hasLazyLibraries = len(masterLibrary.lazyLibraries) > 0 + + if lazyLoading and hasSolutions: + printExit(f"Architecture {archName}: LazyLibraryLoading is enabled but masterLibrary.solutions is not empty ({len(masterLibrary.solutions)} solutions found)") + if not lazyLoading and hasLazyLibraries: + printExit(f"Architecture {archName}: LazyLibraryLoading is disabled but masterLibrary.lazyLibraries is not empty ({len(masterLibrary.lazyLibraries)} lazy libraries found)") + + codeObjectFilesIndex = {} + # YAML 
files with different CUCount values (104CU vs 110CU) may contain identical solutions
+    # that generate the same .S and .o files, but need to be linked into different .co files
+    # Example: aldebaran_Cijk_Ailk_Bjlk_SB_Bias_HA_SAV.yaml from 104CU (SolutionIndex 210, 1122) and
+    #          110CU (SolutionIndex 1362) all contain the identical solution
+    # "Cijk_Ailk_Bjlk_S_B_Bias_HA_S_SAV_UserArgs_MT128x128x16_MI32x32x1_SN_LDSB0_AA0..."
+    # which produces the same .S/.o file but must be linked into both:
+    #   - TensileLibrary_SS_SS_HA_Bias_SAV_Type_SS_Contraction_l_Ailk_Bjlk_Cijk_Dijk_CU104_gfx90a.co
+    #   - TensileLibrary_SS_SS_HA_Bias_SAV_Type_SS_Contraction_l_Ailk_Bjlk_Cijk_Dijk_gfx90a.co
+    # We build cofile_objects grouped by ISA and .co file name, with (sol.index, solution) tuples
+    # that will be sorted by index before linking to ensure correct kernel ordering per .co file.
+    cofile_objects = collections.defaultdict(lambda: collections.defaultdict(list))
+    # {isa: {cofile_name: [(sol.index, solution), ...]}}
+
+    solutions: list[Solution] = []
+    # When tracking cofile_objects we reuse an existing same named Solution instance
+    # if present here
+    seenSolutions: dict[str, Solution] = {}
+    seenCodeObjectSolutions: set[tuple[bool, str, str]] = set()
+
    for _, masterLibrary in masterLibraries.items():
        for _, sol in masterLibrary.solutions.items():
-            solutions.append(sol.originalSolution)
-            solIndex.append(sol.index)
+            tensileSolution: Solution = sol.originalSolution
+            solutionName = str(tensileSolution)
+            isa = tensileSolution["ISA"]
+            if (False, isa, solutionName) in seenCodeObjectSolutions:
+                continue
+            seenCodeObjectSolutions.add((False, isa, solutionName))
+            solutions.append(tensileSolution)
+            # Track that this solution goes in the default .co file (no codeObjectFile attribute)
+            cofile_objects[isa][None].append((sol.index, seenSolutions.setdefault(solutionName, tensileSolution)))

        for name, lib in masterLibrary.lazyLibraries.items():
            for _, sol in lib.solutions.items():
-            
sol.originalSolution._state["codeObjectFile"] = name
-                solutions.append(sol.originalSolution)
-                solIndex.append(sol.index)
-
-    # Get the solution index and it's codeObjectFile name
-    codeObjectFilesIndex = {}
-    for solution, index in zip(solutions, solIndex):
-        if "codeObjectFile" in solution._state and solution._state["codeObjectFile"] is not None:
-            if solution._state["codeObjectFile"] in codeObjectFilesIndex:
-                codeObjectFilesIndex[solution._state["codeObjectFile"]] = min(index, codeObjectFilesIndex[solution._state["codeObjectFile"]])
-            else:
-                codeObjectFilesIndex[solution._state["codeObjectFile"]] = index
-
-    # Reorder to int: name format
+                tensileSolution: Solution = sol.originalSolution
+                solutionName = str(tensileSolution)
+                if (True, name, solutionName) in seenCodeObjectSolutions:
+                    continue
+                seenCodeObjectSolutions.add((True, name, solutionName))
+                solutions.append(tensileSolution)
+                isa = tensileSolution["ISA"]
+                # Track which .co file(s) this solution needs to be linked into
+                cofile_objects[isa][name].append((sol.index, seenSolutions.setdefault(solutionName, tensileSolution)))
+                # Build lazy library mapping directly
+                if name not in codeObjectFilesIndex:
+                    codeObjectFilesIndex[name] = sol.index
+                else:
+                    codeObjectFilesIndex[name] = min(codeObjectFilesIndex[name], sol.index)
+    del seenCodeObjectSolutions
+
+    # codeObjectFilesIndex uses sol.index values assigned during YAML parsing (lines 619-631)
+    # BEFORE deduplication. These indices are stable identifiers for lazy loading, not array positions.
+    # The runtime uses them via upper_bound() for range-based lookup: "indices 0-499 are in file A,
+    # indices 500-999 are in file B". 
After deduplication, some indices may not have corresponding + # solutions (if they were duplicates), but the mapping remains valid codeObjectFilesIndex = {v: k for k, v in codeObjectFilesIndex.items()} # Reorder to maintain ascending order by index codeObjectFilesIndex = dict(sorted(codeObjectFilesIndex.items())) # remove duplicates while preserving order numSoln = len(solutions) - solutions = dict.fromkeys(solutions).keys() + solutions = seenSolutions.values() - print1(f"Number of solutions parsed: {numSoln}") + print1(f"Number of solutions parsed: {solnReIndex}") + print1(f"Number of unique solutions accounting for .co name: {numSoln}") print1(f"Number of unique solutions: {len(solutions)}") - return solutions, masterLibraries, codeObjectFilesIndex + # Count solutions that appear in multiple .co files across all ISAs + solution_cofile_counts = collections.defaultdict(set) + for isa_cofiles in cofile_objects.values(): + for cofile_name, sol_list in isa_cofiles.items(): + for sol_idx, sol in sol_list: + solution_cofile_counts[str(sol)].add(cofile_name) + num_multi_co = sum(1 for cofiles in solution_cofile_counts.values() if len(cofiles) > 1) + print1(f"Number of solutions needing multiple .co files: {num_multi_co}") + + return solutions, masterLibraries, codeObjectFilesIndex, cofile_objects ################################################################################ @@ -707,7 +855,7 @@ def validLogicFile(p: Path): print2("# %s" % logicFile) start_glds = timer() - solutions, masterLibraries, libraryMapping = generateLogicDataAndSolutions( + solutions, masterLibraries, libraryMapping, cofile_objects = generateLogicDataAndSolutions( logicFiles, arguments, asmToolchain.assembler, isaInfoMap ) stop_glds = timer() @@ -730,10 +878,14 @@ def validLogicFile(p: Path): kernelHelperObjs, kernelWriterAssembly, archs, + cofile_objects, compress=arguments["UseCompression"], ) stop_wsk = timer() print(f"Time to generate kernels (s): {(stop_wsk-start_wsk):3.2f}") + 
numKernelHelperObjs = len(kernelHelperObjs) + del kernelWriterAssembly, kernelHelperObjs + gc.collect() archs = [ # is this really different than the other archs above? isaToGfx(arch) @@ -749,15 +901,12 @@ def validLogicFile(p: Path): for kernel in solutionKernels: kName = getKeyNoInternalArgs(kernel, False) if kName not in solDict: - solDict["%s"%kName] = kernel - - def writeMsl(name, lib): - filename = os.path.join(newLibraryDir, name) - lib.applyNaming(splitGSU) - LibraryIO.write(filename, state(lib), arguments["LibraryFormat"]) + solDict[kName] = kernel filename = os.path.join(newLibraryDir, "TensileLiteLibrary_lazy_Mapping") LibraryIO.write(filename, libraryMapping, "msgpack") + del libraryMapping + gc.collect() start_msl = timer() for archName, newMasterLibrary in masterLibraries.items(): @@ -772,14 +921,24 @@ def writeMsl(name, lib): for name, lib in newMasterLibrary.lazyLibraries.items(): for k, s in lib.solutions.items(): kName = getKeyNoInternalArgs(s.originalSolution, splitGSU) - s.sizeMapping.CUOccupancy = solDict["%s"%kName]["CUOccupancy"] + s.sizeMapping.CUOccupancy = solDict[kName]["CUOccupancy"] + + writeFn = functools.partial( + writeMasterSolutionLibrary, + newLibraryDir=newLibraryDir, + splitGSU=splitGSU, + libraryFormat=arguments["LibraryFormat"] + ) - ParallelMap2(writeMsl, + ParallelMap2(writeFn, newMasterLibrary.lazyLibraries.items(), "Writing master solution libraries", + multiArg=False, return_as="list") stop_msl = timer() print(f"Time to write master solution libraries (s): {(stop_msl-start_msl):3.2f}") + del masterLibraries, solutions, kernels, solDict + gc.collect() if not arguments["KeepBuildTmp"]: buildTmp = Path(arguments["OutputPath"]).parent / "library" / "build_tmp" @@ -796,8 +955,11 @@ def writeMsl(name, lib): print("") stop = timer() + peak_memory_mb, current_memory_mb = getMemoryUsage() print(f"Total time (s): {(stop-start):3.2f}") print(f"Total kernels processed: {numKernels}") print(f"Kernels processed per second: 
{(numKernels/(stop-start)):3.2f}") - print(f"KernelHelperObjs: {len(kernelHelperObjs)}") + print(f"KernelHelperObjs: {numKernelHelperObjs}") + print(f"Peak memory usage (MB): {peak_memory_mb:,.1f}") + print(f"Current memory usage (MB): {current_memory_mb:,.1f}") diff --git a/projects/hipblaslt/tensilelite/Tensile/TensileMergeLibrary.py b/projects/hipblaslt/tensilelite/Tensile/TensileMergeLibrary.py index e33c617b6f8a..ba163e991842 100644 --- a/projects/hipblaslt/tensilelite/Tensile/TensileMergeLibrary.py +++ b/projects/hipblaslt/tensilelite/Tensile/TensileMergeLibrary.py @@ -303,8 +303,7 @@ def avoidRegressions(originalDir, incrementalDir, outputPath, forceMerge, noEff= logicsFiles[origFile] = origFile logicsFiles[incFile] = incFile - iters = zip(logicsFiles.keys()) - logicsList = ParallelMap2(loadData, iters, "Loading Logics...", return_as="list") + logicsList = ParallelMap2(loadData, logicsFiles.keys(), "Loading Logics...", return_as="list", multiArg=False) logicsDict = {} for i, _ in enumerate(logicsList): logicsDict[logicsList[i][0]] = logicsList[i][1] diff --git a/projects/hipblaslt/tensilelite/Tensile/TensileUpdateLibrary.py b/projects/hipblaslt/tensilelite/Tensile/TensileUpdateLibrary.py index 5ff265d0edcc..c1803a634995 100644 --- a/projects/hipblaslt/tensilelite/Tensile/TensileUpdateLibrary.py +++ b/projects/hipblaslt/tensilelite/Tensile/TensileUpdateLibrary.py @@ -26,7 +26,7 @@ from .Tensile import addCommonArguments, argUpdatedGlobalParameters from .Common import assignGlobalParameters, print1, restoreDefaultGlobalParameters, HR, \ - globalParameters, architectureMap, ensurePath, ParallelMap, __version__ + globalParameters, architectureMap, ensurePath, ParallelMap2, __version__ import argparse import copy @@ -149,7 +149,7 @@ def TensileUpdateLibrary(userArgs): for logicFile in logicFiles: print("# %s" % logicFile) fIter = zip(logicFiles, itertools.repeat(args.logic_path), itertools.repeat(outputPath)) - libraries = ParallelMap(UpdateLogic, fIter, 
"Updating logic files", method=lambda x: x.starmap) + libraries = ParallelMap2(UpdateLogic, fIter, "Updating logic files", multiArg=True, return_as="list") def main(): diff --git a/projects/hipblaslt/tensilelite/Tensile/Toolchain/Assembly.py b/projects/hipblaslt/tensilelite/Tensile/Toolchain/Assembly.py index a8b91e8d6210..20bcaca72b46 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Toolchain/Assembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/Toolchain/Assembly.py @@ -30,7 +30,7 @@ from pathlib import Path from typing import List, Union, NamedTuple -from Tensile.Common import print2 +from Tensile.Common import print1, print2 from Tensile.Common.Architectures import isaToGfx from ..SolutionStructs import Solution @@ -52,53 +52,112 @@ def makeAssemblyToolchain(assembler_path, bundler_path, co_version, build_id_kin def buildAssemblyCodeObjectFiles( linker: Linker, bundler: Bundler, - kernels: List[Solution], destDir: Union[Path, str], asmDir: Union[Path, str], compress: bool=True, + cofile_objects: dict=None, ): """Builds code object files from assembly files Args: - toolchain: The assembly toolchain object to use for building. - kernels: A list of the kernel objects to build. - writer: The KernelWriterAssembly object to use. + linker: The linker object for combining .o files. + bundler: The bundler object for compressing .co files. destDir: The destination directory for the code object files. asmDir: The directory containing the assembly files. compress: Whether to compress the code object files. + cofile_objects: Mapping from ISA to dict of cofile_name to list of (sol.index, solution) tuples. 
+ Format: {isa: {cofile_name: [(sol.index, solution), ...]}} """ + if cofile_objects is None: + raise RuntimeError("cofile_objects must be provided to buildAssemblyCodeObjectFiles") + extObj = ".o" extCo = ".co" extCoRaw = ".co.raw" - - archKernelMap = collections.defaultdict(list) - for k in kernels: - archKernelMap[tuple(k['ISA'])].append(k) + asmDir = Path(asmDir) + destDir = Path(destDir) coFiles = [] - for arch, archKernels in archKernelMap.items(): - if len(archKernels) == 0: - continue - - gfx = isaToGfx(arch) - - objectFiles = [str(asmDir / (k["BaseName"] + extObj)) for k in archKernels if 'codeObjectFile' not in k] - coFileMap = collections.defaultdict(set) - if len(objectFiles): - coFileMap[asmDir / ("TensileLibrary_"+ gfx + extCoRaw)] = objectFiles - for kernel in archKernels: - coName = kernel.get("codeObjectFile", None) - if coName: - coFileMap[asmDir / (coName + extCoRaw)].add(str(asmDir / (kernel["BaseName"] + extObj))) - - for coFileRaw, objFiles in coFileMap.items(): - linker(objFiles, str(coFileRaw)) - coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo) - if compress: - bundler.compress(str(coFileRaw), str(coFile), gfx) - else: - shutil.move(coFileRaw, coFile) - coFiles.append(coFile) + + lostSolutions = [] + # Build a global map of .o files to their reference counts and which .co files reference them + objFileRefCount = collections.Counter() + objFileToCoFiles = collections.defaultdict(list) # Track which .co files reference each .o file + for isa, cofile_map in cofile_objects.items(): + gfx = isaToGfx(isa) + for cofile_name, sol_list in cofile_map.items(): + # Determine .co filename + if cofile_name is None: + coFileName = f"TensileLibrary_{gfx}.co" + else: + coFileName = f"{cofile_name}.co" + + # Deduplicate solutions in this .co file by basename + seen_basenames = set() + for sol_idx, sol in sol_list: + basename = sol.getKernels()[0].get("BaseName", None) + if basename is None: + basename = "MISSING!.o" + lostSolutions += [(sol_idx, 
cofile_name, sol)]
+                if basename not in seen_basenames:
+                    seen_basenames.add(basename)
+                    objFilePath = asmDir / (basename + extObj)
+                    objFileRefCount[str(objFilePath)] += 1
+                    objFileToCoFiles[str(objFilePath)].append(coFileName)
+
+    sharedObjFiles = {objFile: count for objFile, count in objFileRefCount.items() if count > 1}
+    if sharedObjFiles:
+        print1(f"Found {len(sharedObjFiles)} .o files shared across multiple code objects:")
+        for objFile in list(sharedObjFiles.keys())[:10]: # Show first 10 examples
+            coNames = objFileToCoFiles[objFile]
+            basename = Path(objFile).name
+            print1(f" {basename} -> {', '.join(coNames)}")
+
+    if lostSolutions:
+        for a, b, c in lostSolutions[:10]:
+            print(a, b, c)
+        raise Exception("Some solutions are missing a BaseName. First 10 printed.")
+
+    # Now process each ISA and .co file
+    for isa, cofile_map in cofile_objects.items():
+        gfx = isaToGfx(isa)
+
+        for cofile_name, sol_list in cofile_map.items():
+            # Sort by solution index to ensure correct kernel ordering
+            sol_list.sort(key=lambda x: x[0])
+
+            # Build list of .o files, deduplicating by basename within this .co file
+            objFiles = []
+            seen_basenames = set()
+            for sol_idx, sol in sol_list:
+                basename = sol.getKernels()[0].get("BaseName", "MISSING!.o")
+                if basename not in seen_basenames:
+                    seen_basenames.add(basename)
+                    objFilePath = asmDir / (basename + extObj)
+                    objFiles.append(str(objFilePath))
+
+            # Determine output filename
+            if cofile_name is None:
+                coFileRaw = asmDir / f"TensileLibrary_{gfx}{extCoRaw}"
+            else:
+                coFileRaw = asmDir / f"{cofile_name}{extCoRaw}"
+
+            # Link the .o files into a .co file
+            linker(objFiles, str(coFileRaw))
+
+            # Delete .o files after linking once usage count reaches 0
+            for objFile in objFiles:
+                objFileRefCount[objFile] -= 1
+                if objFileRefCount[objFile] == 0:
+                    Path(objFile).unlink()
+
+            # Compress/move the .co file to destination
+            coFile = destDir / coFileRaw.name.replace(extCoRaw, extCo)
+            if compress:
+                bundler.compress(str(coFileRaw), 
str(coFile), gfx) + else: + shutil.move(coFileRaw, coFile) + coFiles.append(coFile) return coFiles diff --git a/projects/hipblaslt/tensilelite/Tensile/Toolchain/Component.py b/projects/hipblaslt/tensilelite/Tensile/Toolchain/Component.py index 67fa35e2d88d..dde83af4c322 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Toolchain/Component.py +++ b/projects/hipblaslt/tensilelite/Tensile/Toolchain/Component.py @@ -355,6 +355,7 @@ def _response_file_args(self, srcPaths: List[str], destPath: str) -> List[str]: when invoking the linker, LLVM allows the provision of arguments via a "response file" Reference: https://llvm.org/docs/CommandLine.html#response-files """ + # FIXME: this prevents threading as clang_args.txt is overwritten with open(Path.cwd() / "clang_args.txt", "wt") as file: file.write(" ".join(srcPaths).replace('\\', '\\\\') if os_name == "nt" else " ".join(srcPaths)) return [*(self.default_args), "-o", destPath, "@clang_args.txt"] diff --git a/projects/hipblaslt/tensilelite/requirements.txt b/projects/hipblaslt/tensilelite/requirements.txt index 60c4c1144537..5c8fd66a8842 100644 --- a/projects/hipblaslt/tensilelite/requirements.txt +++ b/projects/hipblaslt/tensilelite/requirements.txt @@ -2,8 +2,6 @@ dataclasses; python_version == '3.6' packaging pyyaml msgpack -joblib>=1.4.0; python_version >= '3.8' -joblib>=1.1.1; python_version < '3.8' simplejson ujson orjson