Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions projects/hipblaslt/tensilelite/Tensile/Components/StreamK.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,23 @@ def skExtraIters(self, writer, kernel, sSkExtraIters, sTmp):
return module

@abc.abstractmethod
def computeLoadSrd(self, writer, kernel, tc, sTmp):
def computeLoadSrd(self, writer, kernel, tP, sTmp):
pass

def computeLoadSrdCommon(self, writer, kernel, tc, sTmp):
def computeLoadSrdCommon(self, writer, kernel, tP, sTmp):
module = Module("StreamK Common computeLoadSrd")

tc = tP["tensorChar"]
_DepthU = kernel["_DepthU%s" % tc]
# swizzle
if (tP["isSwizzled"] and tc == 'A'):
_DepthU = (_DepthU * 16) # MI_M = 16
elif (tP["isSwizzled"] and tc == 'B'):
_DepthU = (_DepthU * 16) # MI_N = 16

tileStart = sTmp + 2
# StreamK partial tile - offset to tile start index
module.add(SMulI32(dst=sgpr(sTmp), src0=sgpr("StreamKLocalStart"), src1=kernel["DepthU"], comment="StreamK tile start offset"))
module.add(SMulI32(dst=sgpr(sTmp), src0=sgpr("StreamKLocalStart"), src1=_DepthU, comment="StreamK tile start offset"))
strideL = writer.strideRef(tc, kernel["ProblemType"]["IndicesSummation"][0])
module.add(writer.s_mul_u64_u32(sgpr(sTmp), sgpr(sTmp+1), sgpr(sTmp), strideL, comment="StreamK tile start offset"))
# Overflow check removed
Expand Down Expand Up @@ -255,9 +263,16 @@ def graAddressesCommon(self, writer, kernel, tP, vTmp):
module = Module("StreamK Common graAddresses")

tc = tP["tensorChar"]
_DepthU = kernel["_DepthU%s" % tc]
# swizzle
if (tP["isSwizzled"] and tc == 'A'):
_DepthU = (_DepthU * 16) # MI_M = 16
Comment thread
nakajee marked this conversation as resolved.
Outdated
elif (tP["isSwizzled"] and tc == 'B'):
_DepthU = (_DepthU * 16) # MI_N = 16

# StreamK partial tile - offset to tile start index
tmpOffset = writer.sgprPool.checkOut(2, "skStartOffset")
module.add(SMulI32(dst=sgpr(tmpOffset), src0=sgpr("StreamKLocalStart"), src1=int(kernel["DepthU"] * tP["bpe"]), comment="StreamK tile start offset"))
module.add(SMulI32(dst=sgpr(tmpOffset), src0=sgpr("StreamKLocalStart"), src1=int(_DepthU * tP["bpe"]), comment="StreamK tile start offset"))
strideL = writer.strideRef(tc, kernel["ProblemType"]["IndicesSummation"][0])
module.add(writer.s_mul_u64_u32(sgpr(tmpOffset), sgpr(tmpOffset+1), sgpr(tmpOffset), strideL, "StreamK tile start offset"))
# Overflow check removed
Expand Down Expand Up @@ -1745,7 +1760,7 @@ def graWorkGroup(self, writer, kernel, tPA, tPB):
module = Module("StreamK Off graWorkGroup")
return module

def computeLoadSrd(self, writer, kernel, tc, sTmp):
def computeLoadSrd(self, writer, kernel, tP, sTmp):
module = Module("StreamK Off computeLoadSrd")
return module

Expand Down Expand Up @@ -1861,9 +1876,9 @@ def graWorkGroup(self, writer, kernel, tPA, tPB):

return module

def computeLoadSrd(self, writer, kernel, tc, sTmp):
def computeLoadSrd(self, writer, kernel, tP, sTmp):
module = Module("StreamK Basic computeLoadSrd")
module.add(self.computeLoadSrdCommon(writer, kernel, tc, sTmp))
module.add(self.computeLoadSrdCommon(writer, kernel, tP, sTmp))
return module

def computeStoreSrdStart(self, writer, kernel):
Expand Down Expand Up @@ -1986,9 +2001,9 @@ def graWorkGroup(self, writer, kernel, tPA, tPB):

return module

def computeLoadSrd(self, writer, kernel, tc, sTmp):
def computeLoadSrd(self, writer, kernel, tP, sTmp):
module = Module("StreamK TwoTileOriginal computeLoadSrd")
module.add(self.computeLoadSrdCommon(writer, kernel, tc, sTmp))
module.add(self.computeLoadSrdCommon(writer, kernel, tP, sTmp))
return module

def computeStoreSrdStart(self, writer, kernel):
Expand Down Expand Up @@ -2256,9 +2271,9 @@ def graWorkGroup(self, writer, kernel, tPA, tPB):

return module

def computeLoadSrd(self, writer, kernel, tc, sTmp):
def computeLoadSrd(self, writer, kernel, tP, sTmp):
module = Module("StreamK TwoTileDPFirst computeLoadSrd")
module.add(self.computeLoadSrdCommon(writer, kernel, tc, sTmp))
module.add(self.computeLoadSrdCommon(writer, kernel, tP, sTmp))
return module

def computeStoreSrdStart(self, writer, kernel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3617,7 +3617,7 @@ def computeLoadSrd(self, kernel, tP, tc, indices, bpe):
strideF, comment="tlu=0, scaled tile-offset by stride"))

skComponent = Component.StreamK.find(self)
module.add(skComponent.computeLoadSrd(self, kernel, tc, stmp))
module.add(skComponent.computeLoadSrd(self, kernel, tP, stmp))

gsuComponent = Component.GSU.find(self)
module.add(gsuComponent.computeLoadSrd(self, kernel, tP, stmp, tileStart))
Expand Down
Loading