Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 181 additions & 0 deletions pkgs/development/rocm-modules/6/hipblaslt/Tensile-interning.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
diff --git a/tensilelite/Tensile/Common/Utilities.py b/tensilelite/Tensile/Common/Utilities.py
index 0a9d9db5b3..cb9779eaac 100644
--- a/tensilelite/Tensile/Common/Utilities.py
+++ b/tensilelite/Tensile/Common/Utilities.py
@@ -24,6 +24,7 @@

import functools
import math
+import operator
import os
import sys
import time
@@ -269,8 +270,20 @@ def state(obj):


def state_key_ordering(cls):
- def tup(obj):
- return tuple([getattr(obj, k) for k in cls.StateKeys])
+ # Use operator.attrgetter for efficiency if __slots__ is defined
+ if hasattr(cls, '__slots__'):
+ # attrgetter is faster for slotted classes
+ getter = operator.attrgetter(*cls.StateKeys)
+ if len(cls.StateKeys) == 1:
+ # attrgetter returns scalar for single key, we need tuple
+ def tup(obj):
+ return (getter(obj),)
+ else:
+ tup = getter
+ else:
+ # Fallback for regular classes
+ def tup(obj):
+ return tuple([getattr(obj, k) for k in cls.StateKeys])

def lt(a, b):
return tup(a) < tup(b)
diff --git a/tensilelite/Tensile/Contractions.py b/tensilelite/Tensile/Contractions.py
index c0d4e851b1..3f2c2e98c6 100644
--- a/tensilelite/Tensile/Contractions.py
+++ b/tensilelite/Tensile/Contractions.py
@@ -37,9 +37,60 @@ from Tensile.Toolchain.Component import Assembler
from math import ceil

MIN_K_FOR_GSU = 32
+
+# Interning helpers to reduce memory usage by reusing identical objects
+_free_index_cache = {}
+def intern_free_index(isA, i=None, c=None, d=None, a=None, b=None):
+ key = (isA, i, c, d, a, b)
+ if key not in _free_index_cache:
+ obj = FreeIndex(isA, i, c, d)
+ obj.a = a
+ if b is not None:
+ obj.b = b
+ _free_index_cache[key] = obj
+ return _free_index_cache[key]
+
+_batch_index_cache = {}
+def intern_batch_index(a=None, b=None, c=None, d=None):
+ key = (a, b, c, d)
+ if key not in _batch_index_cache:
+ obj = BatchIndex(c=c, d=d)
+ obj.a = a
+ obj.b = b
+ _batch_index_cache[key] = obj
+ return _batch_index_cache[key]
+
+_bound_index_cache = {}
+def intern_bound_index(a=None, b=None, aMirror=False, bMirror=False):
+ key = (a, b, aMirror, bMirror)
+ if key not in _bound_index_cache:
+ obj = BoundIndex(aMirror=aMirror, bMirror=bMirror)
+ obj.a = a
+ obj.b = b
+ _bound_index_cache[key] = obj
+ return _bound_index_cache[key]
+
+_size_mapping_cache = {}
+def intern_size_mapping(size_mapping):
+ """Intern a SizeMapping instance to reduce redundancy."""
+ # Build hashable key from StateKeys, converting lists to tuples
+ key_parts = []
+ for attr in size_mapping.StateKeys:
+ val = getattr(size_mapping, attr)
+ # Convert lists to tuples for hashing
+ if isinstance(val, list):
+ val = tuple(val)
+ key_parts.append(val)
+ key = tuple(key_parts)
+
+ if key not in _size_mapping_cache:
+ _size_mapping_cache[key] = size_mapping
+ return _size_mapping_cache[key]
+
@state_key_ordering
class FreeIndex:
StateKeys = ['isA', 'i', 'c', 'd']
+ __slots__ = ['isA', 'i', 'c', 'd', 'a', 'b']

def __init__(self, isA, i=None, c=None, d=None):
self.isA = isA
@@ -50,6 +101,7 @@ class FreeIndex:
@state_key_ordering
class BatchIndex:
StateKeys = ['a', 'b', 'c', 'd']
+ __slots__ = ['a', 'b', 'c', 'd']
def __init__(self, a=None, b=None, c=None, d=None):
self.a = a
self.b = b
@@ -59,6 +111,7 @@ class BatchIndex:
@state_key_ordering
class BoundIndex:
StateKeys = ['a', 'b', 'aMirror', 'bMirror']
+ __slots__ = ['a', 'b', 'aMirror', 'bMirror']
def __init__(self, a=None, b=None, aMirror=False, bMirror=False):
self.a = a
self.b = b
@@ -107,6 +160,23 @@ class ProblemType:
for ib, ic in enumerate(d['IndexAssignmentsB']):
indices[ic].b = ib

+ # Now intern all indices with their final state (including .a and .b)
+ for i, idx in enumerate(indices):
+ if isinstance(idx, FreeIndex):
+ indices[i] = intern_free_index(idx.isA, idx.i, idx.c, idx.d,
+ getattr(idx, 'a', None), getattr(idx, 'b', None))
+ elif isinstance(idx, BatchIndex):
+ indices[i] = intern_batch_index(getattr(idx, 'a', None), getattr(idx, 'b', None),
+ idx.c, idx.d)
+ elif isinstance(idx, BoundIndex):
+ indices[i] = intern_bound_index(getattr(idx, 'a', None), getattr(idx, 'b', None),
+ idx.aMirror, idx.bMirror)
+
+ # Update the lists with interned versions
+ freeIndices = [idx for idx in indices if isinstance(idx, FreeIndex)]
+ batchIndices = [idx for idx in indices if isinstance(idx, BatchIndex)]
+ boundIndices = [idx for idx in indices if isinstance(idx, BoundIndex)]
+
for idx in indices:
assert idx is not None
idxState = state(idx)
@@ -596,6 +666,7 @@ class SizeMapping:
'nonTemporalA',
'nonTemporalB',
]
+ __slots__ = StateKeys

@classmethod
def FromOriginalState(cls, d):
@@ -751,7 +822,7 @@ class Solution:
info = cls.ReadOriginalInfo(d)
rv.libraryLogicIndex = int(info.get("SolutionIndex", -1))

- rv.sizeMapping = SizeMapping.FromOriginalState(d)
+ rv.sizeMapping = intern_size_mapping(SizeMapping.FromOriginalState(d))

rv.internalArgsSupport = InternalArgsSupport.FromOriginalState(d)

diff --git a/tensilelite/Tensile/TensileCreateLibrary/Run.py b/tensilelite/Tensile/TensileCreateLibrary/Run.py
index 730b6b1fff..b0068563a0 100644
--- a/tensilelite/Tensile/TensileCreateLibrary/Run.py
+++ b/tensilelite/Tensile/TensileCreateLibrary/Run.py
@@ -104,7 +104,6 @@ class KernelCodeGenResult(NamedTuple):
src: str
header: Optional[str]
name: str
- targetObjFilename: str
isa: IsaVersion
wavefrontSize: int
cuoccupancy: int
@@ -127,10 +126,9 @@ def processKernelSource(kernelWriterAssembly, data, splitGSU, kernel) -> KernelC
asmFilename = getKernelFileBase(splitGSU, kernel)
err, src = kernelWriter.getSourceFileString(kernel)
header = kernelWriter.getHeaderFileString(kernel)
- objFilename = kernel._state.get("codeObjectFile", None)
pgr = int(kernel["PrefetchGlobalRead"])
return KernelCodeGenResult(
- err, src, header, asmFilename, objFilename, tuple(kernel["ISA"]), \
+ err, src, header, asmFilename, tuple(kernel["ISA"]), \
kernel["WavefrontSize"], kernel["CUOccupancy"], \
pgr, kernel["MathClocksUnrolledLoop"]
)
Loading
Loading