diff --git a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py index dd404ca48f8..65eb09acf85 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py +++ b/projects/hipblaslt/tensilelite/Tensile/Components/CustomSchedule.py @@ -20,7 +20,8 @@ # CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ################################################################################ -from rocisa.code import Module, TextBlock, StructuredModule, KernelBody +from rocisa.code import KernelBody, Label, Macro, Module, RegSet, SrdUpperValue, \ + StructuredModule, TextBlock, ValueEndif, ValueIf, ValueElseIf, ValueSet, SignatureBase from rocisa.container import vgpr, sgpr, SMEMModifiers, replaceHolder, EXEC,\ VOP3PModifiers, ContinuousRegister from rocisa.instruction import BufferLoadB128, BufferLoadB32, BufferLoadB64, \ @@ -156,16 +157,15 @@ def addToStream(key, indexList, InstructionList): InstStreams = convOptToStream(opt1) - - module.add(TextBlock(".macro MAINLOOP ID useGR=1 usePLR=1 useGRInc=1 useLoop=1\n")) + macro = Macro("MAINLOOP", ["ID", "useGR=1", "usePLR=1", "useGRInc=1", "useLoop=1"]) #module.add(SBarrier(comment="debug")) lastIter = numLoopIter - 1 for miIndex in range(-1, len(mfmaCode)): if miIndex >= 0: - module.addComment0("mfmaIndex:%u"%(miIndex)) - module.add(mfmaCode[miIndex]) + macro.addComment0("mfmaIndex:%u"%(miIndex)) + macro.add(mfmaCode[miIndex]) def scheduleInst(indexList, instructionList): ret = [None]*len(indexList) @@ -197,55 +197,54 @@ def scheduleInst1(instList, macroGuard=""): if len(instList) == 1: if instList[0] != None: if macroGuard != "": - module.add(TextBlock(macroGuard)) - module.add(instList[0]) + macro.add(ValueIf(macroGuard)) + macro.add(instList[0]) if macroGuard != "": - module.add(TextBlock(".endif\n")) - + macro.add(ValueEndif()) for k,ts in ToSched.items(): if k in ['GRIncA', 'GRIncB']: # check for global read inc - scheduleInst1(ts, ".if \\useGRInc == 1\n") + scheduleInst1(ts, "\\useGRInc == 1") elif k in ['GRA', 'GRB', 'LWSA', 'LWSB']: # check for global reads - scheduleInst1(ts, ".if \\useGR == 1\n") + scheduleInst1(ts, "\\useGR == 1") elif k in ['LRA%u'%lastIter, 'LRB%u'%lastIter, 'LRSA', 'LRSB']: # check for next prefetch - scheduleInst1(ts, ".if \\usePLR == 1\n") + scheduleInst1(ts, "\\usePLR == 1") elif k in ['LCC']: # check for next prefetch - scheduleInst1(ts, ".if \\useLoop == 1\n") + scheduleInst1(ts, "\\useLoop == 1") else: scheduleInst1(ts) if needIfMacro: for codepath in range(numCodePath): if codepath == 0: - module.add(TextBlock(".if \\ID == %u\n"%codepath)) + macro.add(ValueIf("\\ID == %u"%codepath)) else: - module.add(TextBlock(".elseif \\ID == %u\n"%codepath)) + macro.add(ValueElseIf("\\ID == %u\n"%codepath)) def scheduleInst2(instList, macroGuard=""): if len(instList) == numCodePath: if instList[codepath] != None: if macroGuard != "": - module.add(TextBlock(macroGuard)) - module.add(instList[codepath]) + macro.add(ValueIf(macroGuard)) + macro.add(instList[codepath]) if macroGuard != "": - module.add(TextBlock(".endif\n")) + macro.add(ValueEndif()) for k,ts in ToSched.items(): if k in ['GRIncA', 'GRIncB']: # check for global read inc - scheduleInst2(ts, ".if \\useGRInc == 1\n") + scheduleInst2(ts, "\\useGRInc == 1\n") elif k in ['GRA', 'GRB', 'LWSA', 'LWSB']: # check for global reads - scheduleInst2(ts, ".if \\useGR == 1\n") + scheduleInst2(ts, "\\useGR == 1\n") elif k in ['LRA%u'%lastIter, 'LRB%u'%lastIter, 'LRSA', 'LRSB']: # check for next prefetch - scheduleInst2(ts, ".if \\usePLR == 1\n") + scheduleInst2(ts, "\\usePLR == 1\n") elif k in ['LCC']: # check for next prefetch - scheduleInst2(ts, ".if \\useLoop == 1\n") + scheduleInst2(ts, "\\useLoop == 1\n") else: scheduleInst2(ts) if codepath == numCodePath - 1: - module.add(TextBlock(".endif\n")) + macro.add(ValueEndif()) - module.add(TextBlock(".endm\n")) + module.add(macro) return module, numCodePath diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py index bc261dd77b0..f8f2dfeb352 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriter.py @@ -33,7 +33,7 @@ DSLoadB32, DSLoadB64, DSLoadB64TrB16, DSLoadInstruction, DSLoadU16, \ DSLoadU8, DSStore2B32, DSStore2B64, DSStoreB128, DSStoreB16, DSStoreB256, \ DSStoreB32, DSStoreB64, DSStoreB8, DSStoreInstruction, FlatLoadB128, FlatLoadB32, \ - FlatLoadB64, FlatStoreB128, FlatStoreB32, FlatStoreB64, Instruction, \ + FlatLoadB64, FlatStoreB128, FlatStoreB32, FlatStoreB64, Instruction, MacroInstruction, \ MFMAInstruction, SBarrier, SBranch, SCBranchSCC0, SCBranchSCC1, SCBranchVCCNZ, SCmpLeU32, \ SMFMAInstruction, SNop, SSetPrior, SSetRegIMM32B32, SSubU32, SWaitCnt, SWaitAlu, \ SLongBranchPositive, VFmaMixF32, VMadMixF32, VMovB32 @@ -2095,12 +2095,12 @@ def noLoadLoopBody( self, kernel, tensorParametersA, tensorParametersB, pack, is module.add(SWaitCnt(dscnt=0, vlcnt=0, vscnt=-1, comment="Wait for all PGR to complete")) module.add(SBarrier(comment="")) module.addComment0("Code-path 0, useGR=0, usePLR=1, useGRInc=1, useLoop = 0") - module.add(TextBlock("MAINLOOP 0 0 1 1 0\n")) + module.add(MacroInstruction(name="MAINLOOP", args=[0,0,1,1,0])) else: module.add(SWaitCnt(dscnt=0, vlcnt=0, vscnt=-1, comment="Wait for all PGR to complete")) module.add(SBarrier(comment="")) module.addComment0("Code-path 0, useGR=0, usePLR=0, useGRInc=0, useLoop = 0") - module.add(TextBlock("MAINLOOP 0 0 0 0 0\n")) + module.add(MacroInstruction(name="MAINLOOP", args=[0,0,0,0,0])) return module module = Module("noLoadLoopBody") expand = kernel["ExpandPointerSwap"] diff --git a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py index a0ed48d8560..fef11d022f1 100644 --- a/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py +++ b/projects/hipblaslt/tensilelite/Tensile/KernelWriterAssembly.py @@ -1444,7 +1444,7 @@ def checkResources(self, kernel, mkb: KernelBody): % (self.states.kernelName, self.states.overflowedResources, msg, \ self.vgprPool.size(), self.sgprPool.size())) mkb.body.add(SEndpgm(comment="overflowed resources"), 0) - mkb.body.add(ValueIf(value=0), 1) + mkb.body.add(ValueIf(value="0"), 1) ############################################################################## # code phrase for load batched address from array of buffer pointer @@ -13879,7 +13879,7 @@ def simdSpecDispatch(self, kernel, numCodePath): module.add(SWaitCnt(dscnt=0, vlcnt=-1, vscnt=-1, comment="Wait for all LR in pre-loop to complete")) if numCodePath == 1: - module.add(TextBlock("MAINLOOP 0\n")) + module.add(MacroInstruction(name="MAINLOOP", args=[0])) module.add(SCBranchSCC0(labelName="label_LoopBegin%s"%(loopChar), comment="" )) module.add(Label("LoopEnd%s"%(loopChar), "" )) return module @@ -13911,7 +13911,7 @@ def simdSpecDispatch(self, kernel, numCodePath): for l in range(numCodePath): module.addComment0("SIMD %u code-path"%l) module.add(loopLabelBegin[l]) - module.add(TextBlock("MAINLOOP %u\n"%l)) + module.add(MacroInstruction(name="MAINLOOP", args=[l])) module.add(SCBranchSCC0(labelName=loopLabelBegin[l].getLabelName(), comment="" )) tmpSgpr1 = self.sgprPool.checkOutAligned(2, 2) sgprPC = ContinuousRegister(tmpSgpr1, 3) diff --git a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/custom_mainloop_scheduling.yaml b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/custom_mainloop_scheduling.yaml index 78d3d9bd8e3..8d0f8886001 100644 --- a/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/custom_mainloop_scheduling.yaml +++ b/projects/hipblaslt/tensilelite/Tensile/Tests/common/gemm/gfx950/custom_mainloop_scheduling.yaml @@ -19,6 +19,7 @@ GlobalParameters: DeviceLDS: 163840 MaxLDS: 163840 PrintSolutionRejectionReason: True + DisableAsmComments : True BenchmarkProblems: ######################################## diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/code.hpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/code.hpp index a4285188858..22ae0d1025c 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/include/code.hpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/include/code.hpp @@ -608,6 +608,55 @@ namespace rocisa } }; + struct ValueEndif : public Item + { + std::string comment; + + ValueEndif(const std::string& comment = "") + : Item("ValueEndif") + , comment(comment) + { + } + + std::string toString() const override + { + return formatStr( + false, ".endif", comment, rocIsa::getInstance().getOutputOptions().outputNoComment); + } + }; + + struct ValueIf : public Item + { + std::string value; + + ValueIf(const std::string& value) + : Item("ValueIf") + , value(value) + { + } + + std::string toString() const override + { + return ".if " + value + "\n"; + } + }; + + struct ValueElseIf : public Item + { + std::string value; + + ValueElseIf(const std::string& value) + : Item("ValueElseIf") + , value(value) + { + } + + std::string toString() const override + { + return ".elseif " + value + "\n"; + } + }; + struct Macro : public Item { std::vector> itemList; @@ -635,7 +684,8 @@ namespace rocisa { // This is a workaround if(dynamic_cast(item.get()) || dynamic_cast(item.get()) - || dynamic_cast(item.get())) + || dynamic_cast(item.get()) || dynamic_cast(item.get()) + || dynamic_cast(item.get()) || dynamic_cast(item.get())) { item->parent = this; itemList.push_back(item); @@ -731,39 +781,6 @@ namespace rocisa } }; - struct ValueEndif : public Item - { - std::string comment; - - ValueEndif(const std::string& comment = "") - : Item("ValueEndif") - , comment(comment) - { - } - - std::string toString() const override - { - return formatStr( - false, ".endif", comment, rocIsa::getInstance().getOutputOptions().outputNoComment); - } - }; - - struct ValueIf : public Item - { - int value; - - ValueIf(int value) - : Item("ValueIf") - , value(value) - { - } - - std::string toString() const override - { - return ".if " + std::to_string(value); - } - }; - struct ValueSet : public Item { std::optional ref; diff --git a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/code.cpp b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/code.cpp index b25b42677c6..715bb4034da 100644 --- a/projects/hipblaslt/tensilelite/rocisa/rocisa/src/code.cpp +++ b/projects/hipblaslt/tensilelite/rocisa/rocisa/src/code.cpp @@ -249,13 +249,26 @@ void init_code(nb::module_ m) }); nb::class_(m_code, "ValueIf") - .def(nb::init(), nb::arg("value")) + .def(nb::init(), nb::arg("value")) .def("__str__", &rocisa::ValueIf::toString) .def("__deepcopy__", [](const rocisa::ValueIf& self, nb::dict&) { return new rocisa::ValueIf(self); }) .def("__getstate__", [](const rocisa::ValueIf& self) { return self.value; }) - .def("__setstate__", - [](rocisa::ValueIf& self, int value) { new(&self) rocisa::ValueIf(value); }); + .def("__setstate__", [](rocisa::ValueIf& self, const std::string& value) { + new(&self) rocisa::ValueIf(value); + }); + + nb::class_(m_code, "ValueElseIf") + .def(nb::init(), nb::arg("value")) + .def("__str__", &rocisa::ValueElseIf::toString) + .def("__deepcopy__", + [](const rocisa::ValueElseIf& self, nb::dict&) { + return new rocisa::ValueElseIf(self); + }) + .def("__getstate__", [](const rocisa::ValueElseIf& self) { return self.value; }) + .def("__setstate__", [](rocisa::ValueElseIf& self, const std::string& value) { + new(&self) rocisa::ValueElseIf(value); + }); nb::class_(m_code, "ValueSet") .def(nb::init(),