From 3ee9b9c8c829b298f195ce51707f03fa0f829423 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Mon, 1 Dec 2025 23:02:40 +0000 Subject: [PATCH 01/22] Change coordinate and control graph to reset flags --- .../Transformations/AddStreamK.cpp | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index 6851e9a03ca..cea84ea7275 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -578,6 +578,26 @@ namespace rocRoller auto doWhileTag = graph.control.addElement( DoWhileOp{(DF(flagRegister) == zero), "Global sync spin loop"}); + // Duplicate flag and next workgroup coordinates for reset operation + auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); + graph.coordinates.addElement(Join(), {workgroup, plusOneTag, forReceiveTileLoopCoord}, {resetNextWorkgroupTag}); + + auto resetFlagsScratchTag = graph.coordinates.addElement( + *graph.coordinates.get(flagsScratchTag)); + graph.coordinates.addElement(Duplicate(), {resetFlagsScratchTag}, {flagsScratchTag}); + graph.coordinates.addElement(PassThrough(), {resetNextWorkgroupTag}, {resetFlagsScratchTag}); + + // Reset flag operations + auto assignResetFlagTag = graph.control.addElement(Assign{Register::Type::Scalar, zero}); + graph.mapper.connect(assignResetFlagTag, flagRegister, NaryArgument::DEST); + + auto resetFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); + graph.mapper.connect(resetFlagTag, resetFlagsScratchTag); + graph.mapper.connect(resetFlagTag, flagRegister); + + auto waitZeroAfterResetTag = graph.control.addElement(WaitZero()); + auto barrierBeforeResetTag = graph.control.addElement(Barrier()); + auto accumulatorTile = graph.coordinates.get(accumulatorTileTag); uint numRegisters = 
accumulatorTile->elements() / (product(context->kernel()->workgroupSize()) * loopInfo.xLoopSize @@ -636,7 +656,8 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - graph.control.chain(doWhileTag, loadAddForX, postWaitZeroTag); + graph.control.chain(doWhileTag, barrierBeforeResetTag, assignResetFlagTag, + resetFlagTag, waitZeroAfterResetTag, loadAddForX, postWaitZeroTag); return {preWaitZeroTag, receiveTileTag, setPlusOneTag}; } From e5fc05d41f454735a40ccea954f9e7917f42afd5 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Mon, 1 Dec 2025 23:03:26 +0000 Subject: [PATCH 02/22] Formatting --- .../Transformations/AddStreamK.cpp | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index cea84ea7275..bf3c515f279 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -580,15 +580,18 @@ namespace rocRoller // Duplicate flag and next workgroup coordinates for reset operation auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); - graph.coordinates.addElement(Join(), {workgroup, plusOneTag, forReceiveTileLoopCoord}, {resetNextWorkgroupTag}); + graph.coordinates.addElement( + Join(), {workgroup, plusOneTag, forReceiveTileLoopCoord}, {resetNextWorkgroupTag}); - auto resetFlagsScratchTag = graph.coordinates.addElement( - *graph.coordinates.get(flagsScratchTag)); + auto resetFlagsScratchTag + = graph.coordinates.addElement(*graph.coordinates.get(flagsScratchTag)); graph.coordinates.addElement(Duplicate(), {resetFlagsScratchTag}, {flagsScratchTag}); - graph.coordinates.addElement(PassThrough(), {resetNextWorkgroupTag}, {resetFlagsScratchTag}); + 
graph.coordinates.addElement( + PassThrough(), {resetNextWorkgroupTag}, {resetFlagsScratchTag}); // Reset flag operations - auto assignResetFlagTag = graph.control.addElement(Assign{Register::Type::Scalar, zero}); + auto assignResetFlagTag + = graph.control.addElement(Assign{Register::Type::Scalar, zero}); graph.mapper.connect(assignResetFlagTag, flagRegister, NaryArgument::DEST); auto resetFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); @@ -656,8 +659,13 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - graph.control.chain(doWhileTag, barrierBeforeResetTag, assignResetFlagTag, - resetFlagTag, waitZeroAfterResetTag, loadAddForX, postWaitZeroTag); + graph.control.chain(doWhileTag, + barrierBeforeResetTag, + assignResetFlagTag, + resetFlagTag, + waitZeroAfterResetTag, + loadAddForX, + postWaitZeroTag); return {preWaitZeroTag, receiveTileTag, setPlusOneTag}; } From c096700d06271071956a5fa7d546668ab3c079d5 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Tue, 2 Dec 2025 22:41:40 +0000 Subject: [PATCH 03/22] Remove wait zero after reset --- .../lib/source/KernelGraph/Transformations/AddStreamK.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index bf3c515f279..e1543aa42f0 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -598,7 +598,6 @@ namespace rocRoller graph.mapper.connect(resetFlagTag, resetFlagsScratchTag); graph.mapper.connect(resetFlagTag, flagRegister); - auto waitZeroAfterResetTag = graph.control.addElement(WaitZero()); auto barrierBeforeResetTag = graph.control.addElement(Barrier()); auto accumulatorTile = graph.coordinates.get(accumulatorTileTag); @@ -663,7 +662,6 @@ 
namespace rocRoller barrierBeforeResetTag, assignResetFlagTag, resetFlagTag, - waitZeroAfterResetTag, loadAddForX, postWaitZeroTag); From 5ce195969c46556dac04ff87568082fdacb164d0 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 4 Dec 2025 17:10:35 +0000 Subject: [PATCH 04/22] Map ScratchPolicy to m_scratchAllocator --- .../lib/include/rocRoller/CommandSolution.hpp | 7 +- .../lib/include/rocRoller/Context.hpp | 18 +++-- .../include/rocRoller/DataTypes/DataTypes.hpp | 13 +++ .../include/rocRoller/KernelGraph/Utils.hpp | 6 +- .../rocroller/lib/source/CommandSolution.cpp | 5 +- shared/rocroller/lib/source/Context.cpp | 16 ++-- shared/rocroller/lib/source/DataTypes.cpp | 19 +++++ .../Transformations/AddStreamK.cpp | 6 +- .../lib/source/KernelGraph/Utils.cpp | 8 +- shared/rocroller/test/unit/GEMMTest.cpp | 80 +++++++++++++++++-- 10 files changed, 149 insertions(+), 29 deletions(-) diff --git a/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp b/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp index c1d1eb1aada..5ddd3d5f810 100644 --- a/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp +++ b/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp @@ -307,13 +307,16 @@ namespace rocRoller /** * @brief Returns the total number of bytes required for scratch space + * for the specified scratch policy. * - * If this value is greather than 0, the user is required to allocate this + * If this value is greater than 0, the user is required to allocate this * amount of device memory and pass it into the kernel. 
* + * @param policy The scratch policy to query + * @param args The runtime arguments * @return size_t */ - size_t scratchSpaceRequired(RuntimeArguments const& args) const; + size_t scratchSpaceRequired(ScratchPolicy policy, RuntimeArguments const& args) const; /** * @brief Returns the workgroup size diff --git a/shared/rocroller/lib/include/rocRoller/Context.hpp b/shared/rocroller/lib/include/rocRoller/Context.hpp index 5fe9039e9d0..0a5be9c3305 100644 --- a/shared/rocroller/lib/include/rocRoller/Context.hpp +++ b/shared/rocroller/lib/include/rocRoller/Context.hpp @@ -28,6 +28,7 @@ #include #include +#include #include #include #include @@ -128,17 +129,20 @@ namespace rocRoller /** * @brief Returns an expression representing how much scratch space is required (in bytes) + * for the specified scratch policy. * + * @param policy The scratch policy to query * @return Expression::ExpressionPtr */ - Expression::ExpressionPtr getScratchAmount() const; + Expression::ExpressionPtr getScratchAmount(ScratchPolicy policy) const; /** - * @brief Allocate more scratch space + * @brief Allocate more scratch space for the specified scratch policy. * + * @param policy The scratch policy to allocate for * @param size Number of bytes requested */ - void allocateScratch(Expression::ExpressionPtr size); + void allocateScratch(ScratchPolicy policy, Expression::ExpressionPtr size); /** * @brief Get register scope manager. 
@@ -173,10 +177,10 @@ namespace rocRoller std::shared_ptr m_argLoader; std::shared_ptr m_instructions; std::shared_ptr m_mem; - LabelAllocatorPtr m_labelAllocator; - std::shared_ptr m_ldsAllocator; - Expression::ExpressionPtr m_scratchAllocator; - std::shared_ptr m_copier; + LabelAllocatorPtr m_labelAllocator; + std::shared_ptr m_ldsAllocator; + std::map m_scratchAllocators; + std::shared_ptr m_copier; std::shared_ptr m_brancher; std::shared_ptr m_crasher; std::shared_ptr m_random; diff --git a/shared/rocroller/lib/include/rocRoller/DataTypes/DataTypes.hpp b/shared/rocroller/lib/include/rocRoller/DataTypes/DataTypes.hpp index 3ebf32c3224..f5e26521d1b 100644 --- a/shared/rocroller/lib/include/rocRoller/DataTypes/DataTypes.hpp +++ b/shared/rocroller/lib/include/rocRoller/DataTypes/DataTypes.hpp @@ -200,6 +200,19 @@ namespace rocRoller std::string toString(NaryArgument n); std::ostream& operator<<(std::ostream& stream, NaryArgument n); + /** + * Scratch memory policy; distinguishes between different scratch memory types. 
+ */ + enum class ScratchPolicy : int + { + SyncFlags = 0, //< Scratch space for synchronization flags + TileData, //< Scratch space for tile data + Count + }; + + std::string toString(ScratchPolicy s); + std::ostream& operator<<(std::ostream& stream, ScratchPolicy s); + inline constexpr DataType getIntegerType(bool isSigned, int sizeBytes); // Case insensitive and with special cases diff --git a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp index 9326101d16a..0ce19ae2ff9 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp @@ -263,11 +263,15 @@ namespace rocRoller * * @param size * @param varType + * @param policy The scratch policy to use for allocation * @param context * @return User */ rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, VariableType varType, ContextPtr context); + Expression::ExpressionPtr size, + VariableType varType, + ScratchPolicy policy, + ContextPtr context); /** * @brief Replace operation with a new operation. 
diff --git a/shared/rocroller/lib/source/CommandSolution.cpp b/shared/rocroller/lib/source/CommandSolution.cpp index 20b30370dde..62a198f6fc8 100644 --- a/shared/rocroller/lib/source/CommandSolution.cpp +++ b/shared/rocroller/lib/source/CommandSolution.cpp @@ -644,9 +644,10 @@ namespace rocRoller return m_context; } - size_t CommandKernel::scratchSpaceRequired(RuntimeArguments const& args) const + size_t CommandKernel::scratchSpaceRequired(ScratchPolicy policy, + RuntimeArguments const& args) const { - auto amount = m_context->getScratchAmount(); + auto amount = m_context->getScratchAmount(policy); auto times = evaluationTimes(amount); AssertFatal(times[Expression::EvaluationTime::Translate] diff --git a/shared/rocroller/lib/source/Context.cpp b/shared/rocroller/lib/source/Context.cpp index 32272a1b430..024570f67f4 100644 --- a/shared/rocroller/lib/source/Context.cpp +++ b/shared/rocroller/lib/source/Context.cpp @@ -47,8 +47,12 @@ namespace rocRoller { Context::Context() - : m_scratchAllocator(Expression::literal(0u)) { + // Initialize scratch allocators for each policy with zero + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + m_scratchAllocators[static_cast(i)] = Expression::literal(0u); + } } ContextPtr Context::ForDefaultHipDevice(std::string const& kernelName, @@ -287,14 +291,16 @@ namespace rocRoller m_kernel = assemblyKernel; } - Expression::ExpressionPtr Context::getScratchAmount() const + Expression::ExpressionPtr Context::getScratchAmount(ScratchPolicy policy) const { - return m_scratchAllocator; + auto it = m_scratchAllocators.find(policy); + AssertFatal(it != m_scratchAllocators.end(), "Scratch policy not found", ShowValue(policy)); + return it->second; } - void Context::allocateScratch(Expression::ExpressionPtr size) + void Context::allocateScratch(ScratchPolicy policy, Expression::ExpressionPtr size) { - m_scratchAllocator = simplify(m_scratchAllocator + size); + m_scratchAllocators[policy] = simplify(m_scratchAllocators[policy] + 
size); } void Context::scheduleCopy(Instruction const& inst) diff --git a/shared/rocroller/lib/source/DataTypes.cpp b/shared/rocroller/lib/source/DataTypes.cpp index 479f48e1dc4..29a35e6cdf2 100644 --- a/shared/rocroller/lib/source/DataTypes.cpp +++ b/shared/rocroller/lib/source/DataTypes.cpp @@ -306,6 +306,25 @@ namespace rocRoller return stream << toString(n); } + std::string toString(ScratchPolicy s) + { + switch(s) + { + case ScratchPolicy::SyncFlags: + return "SyncFlags"; + case ScratchPolicy::TileData: + return "TileData"; + + case ScratchPolicy::Count:; + } + return "Invalid"; + } + + std::ostream& operator<<(std::ostream& stream, ScratchPolicy s) + { + return stream << toString(s); + } + std::string toString(PointerType const& p) { switch(p) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index e1543aa42f0..e9f413893a9 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -292,7 +292,8 @@ namespace rocRoller auto strideX = sizeY; auto strideY = literal(1u); - auto globalScratch = newScratchCoordinate(simplify(sizeX * sizeY), varType, context); + auto globalScratch + = newScratchCoordinate(simplify(sizeX * sizeY), varType, ScratchPolicy::TileData, context); auto globalScratchTag = graph.coordinates.addElement(globalScratch); std::vector jammedSizes = {loopInfo.xLoopSize, loopInfo.yLoopSize}; @@ -1143,7 +1144,8 @@ namespace rocRoller resultVariableType(numRemainPartialResults)); // Create scratch space for flags - auto flagsScratch = newScratchCoordinate(argInfo.numWGs, DataType::UInt32, context); + auto flagsScratch = newScratchCoordinate( + argInfo.numWGs, DataType::UInt32, ScratchPolicy::SyncFlags, context); auto flagsScratchTag = graph.coordinates.addElement(flagsScratch); // Create scratch space for partially accumulated tiles diff --git 
a/shared/rocroller/lib/source/KernelGraph/Utils.cpp b/shared/rocroller/lib/source/KernelGraph/Utils.cpp index e0a39f8c976..c34e8771235 100644 --- a/shared/rocroller/lib/source/KernelGraph/Utils.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Utils.cpp @@ -774,13 +774,17 @@ namespace rocRoller } rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, VariableType varType, ContextPtr context) + Expression::ExpressionPtr size, + VariableType varType, + ScratchPolicy policy, + ContextPtr context) { - auto currentOffset = context->getScratchAmount(); + auto currentOffset = context->getScratchAmount(policy); auto newCoordinate = CT::User(size, currentOffset); // TODO Audit bytes/bits // Can we move size inside the CeilDivide? context->allocateScratch( + policy, size * Expression::literal(CeilDivide(DataTypeInfo::Get(varType).elementBits, 8u))); return newCoordinate; diff --git a/shared/rocroller/test/unit/GEMMTest.cpp b/shared/rocroller/test/unit/GEMMTest.cpp index aefd1db5586..06ffd74e8ec 100644 --- a/shared/rocroller/test/unit/GEMMTest.cpp +++ b/shared/rocroller/test/unit/GEMMTest.cpp @@ -29,6 +29,7 @@ #include #endif /* ROCROLLER_USE_HIP */ +#include #include #include @@ -455,13 +456,31 @@ namespace GEMMDriverTest command->addOperation(rocRoller::Operations::T_Store_Tiled(tagCvt, tagTensorD)); } - auto tagScratch = command->allocateTag(); + auto tileDataScratchTag = command->allocateTag(); command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, + tileDataScratchTag, ArgumentType::Value, DataDirection::ReadWrite, rocRoller::SCRATCH); + auto syncFlagsScratchTag = command->allocateTag(); + command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + syncFlagsScratchTag, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::SCRATCH); + + // Operations::OperationTag scratchTag[static_cast(ScratchPolicy::Count)]; + // for(size_t i = 0; i < 
static_cast(ScratchPolicy::Count); ++i) + // { + // scratchTag[i] = command->allocateTag(); + // command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + // scratchTag[i], + // ArgumentType::Value, + // DataDirection::ReadWrite, rocRoller::SCRATCH); + // } + // std::cout << "create scratchTag" << std::endl; + Operations::OperationTag tagNumWGs; if(gemm.streamK) { @@ -673,10 +692,31 @@ namespace GEMMDriverTest commandArgs.setArgument(tagNumWGs, ArgumentType::Value, gemm.numWGs); } - auto scratchSpaceRequired - = commandKernel.scratchSpaceRequired(commandArgs.runtimeArguments()); - auto deviceScratch = make_shared_device(scratchSpaceRequired, 0); - commandArgs.setArgument(tagScratch, ArgumentType::Value, deviceScratch.get()); + auto tileDataScratchSpaceRequired = commandKernel.scratchSpaceRequired( + ScratchPolicy::TileData, commandArgs.runtimeArguments()); + auto deviceTileDataScratch = make_shared_device(tileDataScratchSpaceRequired, 0); + commandArgs.setArgument(tileDataScratchTag, ArgumentType::Value, deviceTileDataScratch.get()); + + auto syncFlagsScratchSpaceRequired = commandKernel.scratchSpaceRequired( + ScratchPolicy::SyncFlags, commandArgs.runtimeArguments()); + auto deviceSyncFlagsScratch = make_shared_device(syncFlagsScratchSpaceRequired, 0); + commandArgs.setArgument(syncFlagsScratchTag, ArgumentType::Value, deviceSyncFlagsScratch.get()); + + // std::shared_ptr deviceScratch[static_cast(ScratchPolicy::Count)]; + // size_t scratchSpaceRequired[static_cast(ScratchPolicy::Count)]; + // for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + // { + // scratchSpaceRequired[i] = commandKernel.scratchSpaceRequired( + // static_cast(i), commandArgs.runtimeArguments()); + // // if(scratchSpaceRequired[i] > 0) + // { + // std::cout << "scratch space required: " << scratchSpaceRequired[i] << std::endl; + // std::cout << "index: " << i << std::endl; + // deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); + // 
commandArgs.setArgument(scratchTag[i], ArgumentType::Value, deviceScratch[i].get()); + // } + // } + std::cout << "set commmand arguments" << std::endl; if(gemm.workgroupMappingDim != -1) { @@ -789,16 +829,40 @@ namespace GEMMDriverTest for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(TD)), HasHipSuccess(0)); - ASSERT_THAT(hipMemset(deviceScratch.get(), 0, scratchSpaceRequired), - HasHipSuccess(0)); + // for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + // { + // // if(scratchSpaceRequired[i] > 0) + // { + // ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), HasHipSuccess(0)); + // } + // } + ASSERT_THAT(hipMemset(deviceTileDataScratch.get(), 0, tileDataScratchSpaceRequired), HasHipSuccess(0)); + ASSERT_THAT(hipMemset(deviceSyncFlagsScratch.get(), 0, syncFlagsScratchSpaceRequired), HasHipSuccess(0)); + + std::cout << "launch kernel" << std::endl; commandKernel.launchKernel(commandArgs.runtimeArguments()); + std::cout << "copy device result" << std::endl; ASSERT_THAT( hipMemcpy( d_result.data(), deviceD.get(), M * N * sizeof(TD), hipMemcpyDeviceToHost), HasHipSuccess(0)); + std::cout << "verify sync flags scratch" << std::endl; + + // Verify SyncFlags scratch is all zeros after kernel + // auto syncFlagsIdx = static_cast(ScratchPolicy::SyncFlags); + auto syncFlagsResult = std::vector(syncFlagsScratchSpaceRequired); + ASSERT_THAT(hipMemcpy(syncFlagsResult.data(), + // deviceScratch[syncFlagsIdx].get(), + deviceSyncFlagsScratch.get(), + syncFlagsScratchSpaceRequired, + hipMemcpyDeviceToHost), + HasHipSuccess(0)); + EXPECT_TRUE(std::all_of(syncFlagsResult.begin(), syncFlagsResult.end(), [](uint8_t v) { return v == 0; })) + << "SyncFlags scratch should be all zeros after kernel execution"; + auto tol = gemmAcceptableError( M, N, K, m_context->targetArchitecture().target()); auto res = compare(d_result, h_result, tol); From 4a6c69257488aa98ca12d8c76fea808cd8a7d16a Mon 
Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 4 Dec 2025 19:54:53 +0000 Subject: [PATCH 05/22] Enable multiple scratch allocation --- .../client/include/client/GEMMSolution.hpp | 3 +- .../include/client/StreamKGEMMSolution.hpp | 25 ++-- shared/rocroller/client/src/gemm.cpp | 30 ++++- .../lib/include/rocRoller/Context.hpp | 26 ++-- .../KernelGraph/CoordinateGraph/Dimension.hpp | 5 +- .../include/rocRoller/KernelGraph/Utils.hpp | 10 +- .../lib/include/rocRoller/KernelOptions.hpp | 7 + .../KernelGraph/CoordinateGraph/Dimension.cpp | 6 +- .../Transformations/AddStreamK.cpp | 7 +- .../lib/source/KernelGraph/Utils.cpp | 13 +- .../common/include/common/CommonGraphs.hpp | 3 + .../test/common/src/CommonGraphs.cpp | 17 ++- shared/rocroller/test/unit/GEMMFusion.cpp | 47 ++++++- shared/rocroller/test/unit/GEMMTest.cpp | 121 +++++++----------- 14 files changed, 189 insertions(+), 131 deletions(-) diff --git a/shared/rocroller/client/include/client/GEMMSolution.hpp b/shared/rocroller/client/include/client/GEMMSolution.hpp index fb2d0aa0c62..c311c192261 100644 --- a/shared/rocroller/client/include/client/GEMMSolution.hpp +++ b/shared/rocroller/client/include/client/GEMMSolution.hpp @@ -28,6 +28,7 @@ #include "GEMMParameters.hpp" +#include #include using namespace rocRoller; @@ -61,7 +62,7 @@ namespace rocRoller { } - virtual Operations::OperationTag getScratchTag() const + virtual Operations::OperationTag getScratchTag(ScratchPolicy scratchPolicy) const { return {}; } diff --git a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp index 1e989b27965..0b42e22e8cd 100644 --- a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp +++ b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp @@ -41,14 +41,15 @@ namespace rocRoller { class StreamKGEMMSolution : public DataParallelGEMMSolution { - Operations::OperationTag m_scratchTag, m_numWGsTag; + std::map m_scratchTags; + 
Operations::OperationTag m_numWGsTag; public: using DataParallelGEMMSolution::DataParallelGEMMSolution; - Operations::OperationTag getScratchTag() const override + Operations::OperationTag getScratchTag(ScratchPolicy scratchPolicy) const override { - return m_scratchTag; + return m_scratchTags.at(scratchPolicy); } protected: @@ -63,13 +64,17 @@ namespace rocRoller DataDirection::ReadOnly, rocRoller::NUMWGS); - m_scratchTag = command->allocateTag(); - command->allocateArgument( - VariableType(DataType::UInt32, PointerType::PointerGlobal), - m_scratchTag, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + m_scratchTags[policy] = command->allocateTag(); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[policy], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); + } return command; } diff --git a/shared/rocroller/client/src/gemm.cpp b/shared/rocroller/client/src/gemm.cpp index b9c3dad3dfe..a94267aff4f 100644 --- a/shared/rocroller/client/src/gemm.cpp +++ b/shared/rocroller/client/src/gemm.cpp @@ -25,6 +25,7 @@ *******************************************************************************/ #include "rocRoller/Serialization/YAML.hpp" +#include #include #include #include @@ -344,14 +345,17 @@ namespace rocRoller::Client::GEMMClient auto runtimeArgs = commandArgs.runtimeArguments(); // Note: the lifetime of deviceScratch needs to exceed kernel executions - std::shared_ptr deviceScratch; + std::shared_ptr deviceScratch[static_cast(ScratchPolicy::Count)]; + + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) { - auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(runtimeArgs); + auto policy = static_cast(i); + auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(policy, runtimeArgs); if(scratchSpaceRequired > 0) { - deviceScratch = 
make_shared_device(scratchSpaceRequired, 0); + deviceScratch[i] = make_shared_device(scratchSpaceRequired, 0); commandArgs.setArgument( - gemm->getScratchTag(), ArgumentType::Value, deviceScratch.get()); + gemm->getScratchTag(policy), ArgumentType::Value, deviceScratch[i].get()); } } @@ -457,6 +461,24 @@ namespace rocRoller::Client::GEMMClient auto [correct, rnorm] = validate( hostA, hostB, hostC, hostD, hostScaleA, hostScaleB, problemParams, arch); + // Verify SyncFlags scratch is all zeros after kernel + auto syncFlagsIdx = static_cast(ScratchPolicy::SyncFlags); + if(deviceScratch[syncFlagsIdx]) + { + auto syncFlagsSize + = commandKernel->scratchSpaceRequired(ScratchPolicy::SyncFlags, runtimeArgs); + std::vector syncFlagsResult(syncFlagsSize); + AssertFatal(hipMemcpy(syncFlagsResult.data(), + deviceScratch[syncFlagsIdx].get(), + syncFlagsSize, + hipMemcpyDeviceToHost) + == (hipError_t)HIP_SUCCESS); + AssertFatal(std::all_of(syncFlagsResult.begin(), + syncFlagsResult.end(), + [](uint8_t v) { return v == 0; }), + "SyncFlags scratch should be all zeros after kernel execution"); + } + result.checked = true; result.correct = correct; result.rnorm = rnorm; diff --git a/shared/rocroller/lib/include/rocRoller/Context.hpp b/shared/rocroller/lib/include/rocRoller/Context.hpp index 0a5be9c3305..00fcfbfae36 100644 --- a/shared/rocroller/lib/include/rocRoller/Context.hpp +++ b/shared/rocroller/lib/include/rocRoller/Context.hpp @@ -172,19 +172,19 @@ namespace rocRoller std::array, static_cast(Register::Type::Count)> m_allocators; - std::shared_ptr m_observer; - AssemblyKernelPtr m_kernel; - std::shared_ptr m_argLoader; - std::shared_ptr m_instructions; - std::shared_ptr m_mem; - LabelAllocatorPtr m_labelAllocator; - std::shared_ptr m_ldsAllocator; - std::map m_scratchAllocators; - std::shared_ptr m_copier; - std::shared_ptr m_brancher; - std::shared_ptr m_crasher; - std::shared_ptr m_random; - std::shared_ptr m_scope; + std::shared_ptr m_observer; + AssemblyKernelPtr m_kernel; 
+ std::shared_ptr m_argLoader; + std::shared_ptr m_instructions; + std::shared_ptr m_mem; + LabelAllocatorPtr m_labelAllocator; + std::shared_ptr m_ldsAllocator; + std::map m_scratchAllocators; + std::shared_ptr m_copier; + std::shared_ptr m_brancher; + std::shared_ptr m_crasher; + std::shared_ptr m_random; + std::shared_ptr m_scope; std::string m_assemblyFileName; KernelOptions m_kernelOptions; diff --git a/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp b/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp index 17005bcb424..55e9362f5db 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp @@ -153,8 +153,11 @@ namespace rocRoller * * @param size How many elements make up the User dimension. * @param offset Location of data within the scratch space + * @param argName Name of the argument for this scratch space */ - User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset); + User(Expression::ExpressionPtr size, + Expression::ExpressionPtr offset, + std::string const& argName); std::string name() const override; }; diff --git a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp index 0ce19ae2ff9..7ae8e823c2b 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp @@ -267,11 +267,11 @@ namespace rocRoller * @param context * @return User */ - rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, - VariableType varType, - ScratchPolicy policy, - ContextPtr context); + rocRoller::KernelGraph::CoordinateGraph::User + newScratchCoordinate(Expression::ExpressionPtr size, + VariableType varType, + ScratchPolicy policy, + ContextPtr context); /** * @brief Replace operation 
with a new operation. diff --git a/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp b/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp index 595c0998658..713760a155c 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -47,6 +48,12 @@ namespace rocRoller const std::string NUMWGS = "numWGs"; const std::string WGM = "WGM"; + // Helper to get scratch argument name for a specific policy + inline std::string getScratchName(ScratchPolicy policy) + { + return rocRoller::SCRATCH + "_" + toString(policy); + } + class KernelOptions { public: diff --git a/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp b/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp index 61030ea2708..09cd3a9ec80 100644 --- a/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp +++ b/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp @@ -117,9 +117,11 @@ namespace rocRoller return name() + stag; } - User::User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset) + User::User(Expression::ExpressionPtr size, + Expression::ExpressionPtr offset, + std::string const& argName) : BaseDimension(size, Expression::literal(1u), offset) - , argumentName(rocRoller::SCRATCH) + , argumentName(argName) { } diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index e9f413893a9..f7d109e792c 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -292,8 +292,8 @@ namespace rocRoller auto strideX = sizeY; auto strideY = literal(1u); - auto globalScratch - = newScratchCoordinate(simplify(sizeX * sizeY), varType, ScratchPolicy::TileData, context); + auto globalScratch = 
newScratchCoordinate( + simplify(sizeX * sizeY), varType, ScratchPolicy::TileData, context); auto globalScratchTag = graph.coordinates.addElement(globalScratch); std::vector jammedSizes = {loopInfo.xLoopSize, loopInfo.yLoopSize}; @@ -579,11 +579,12 @@ namespace rocRoller auto doWhileTag = graph.control.addElement( DoWhileOp{(DF(flagRegister) == zero), "Global sync spin loop"}); - // Duplicate flag and next workgroup coordinates for reset operation + // Create coordinate to indicate flag index to reset auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); graph.coordinates.addElement( Join(), {workgroup, plusOneTag, forReceiveTileLoopCoord}, {resetNextWorkgroupTag}); + // Duplicate flags scratch coordinate for reset operation auto resetFlagsScratchTag = graph.coordinates.addElement(*graph.coordinates.get(flagsScratchTag)); graph.coordinates.addElement(Duplicate(), {resetFlagsScratchTag}, {flagsScratchTag}); diff --git a/shared/rocroller/lib/source/KernelGraph/Utils.cpp b/shared/rocroller/lib/source/KernelGraph/Utils.cpp index c34e8771235..184417f7e9a 100644 --- a/shared/rocroller/lib/source/KernelGraph/Utils.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Utils.cpp @@ -29,6 +29,7 @@ #include #include +#include namespace rocRoller { @@ -773,14 +774,14 @@ namespace rocRoller return rv; } - rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, - VariableType varType, - ScratchPolicy policy, - ContextPtr context) + rocRoller::KernelGraph::CoordinateGraph::User + newScratchCoordinate(Expression::ExpressionPtr size, + VariableType varType, + ScratchPolicy policy, + ContextPtr context) { auto currentOffset = context->getScratchAmount(policy); - auto newCoordinate = CT::User(size, currentOffset); + auto newCoordinate = CT::User(size, currentOffset, getScratchName(policy)); // TODO Audit bytes/bits // Can we move size inside the CeilDivide? 
context->allocateScratch( diff --git a/shared/rocroller/test/common/include/common/CommonGraphs.hpp b/shared/rocroller/test/common/include/common/CommonGraphs.hpp index b78c57efc88..cc6a549de61 100644 --- a/shared/rocroller/test/common/include/common/CommonGraphs.hpp +++ b/shared/rocroller/test/common/include/common/CommonGraphs.hpp @@ -30,6 +30,7 @@ #pragma once +#include #include #include @@ -229,6 +230,8 @@ namespace rocRollerTest rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; rocRoller::Operations::OperationTag m_tagNumWGs; + std::map m_scratchTags; + CommandPtr m_command; }; diff --git a/shared/rocroller/test/common/src/CommonGraphs.cpp b/shared/rocroller/test/common/src/CommonGraphs.cpp index dec5327d6df..858fd105023 100644 --- a/shared/rocroller/test/common/src/CommonGraphs.cpp +++ b/shared/rocroller/test/common/src/CommonGraphs.cpp @@ -28,6 +28,7 @@ #include #include +#include #include namespace rocRollerTest::Graphs @@ -285,12 +286,16 @@ namespace rocRollerTest::Graphs rocRoller::NUMWGS); } - auto tagScratch = m_command->allocateTag(); - m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + m_scratchTags[policy] = m_command->allocateTag(); + m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[policy], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); + } } CommandPtr GEMM::getCommand() diff --git a/shared/rocroller/test/unit/GEMMFusion.cpp b/shared/rocroller/test/unit/GEMMFusion.cpp index 552b80e3865..806145ed7d1 100644 --- a/shared/rocroller/test/unit/GEMMFusion.cpp +++ b/shared/rocroller/test/unit/GEMMFusion.cpp @@ -29,6 +29,7 @@ #include #endif /* ROCROLLER_USE_HIP */ +#include #include #include @@ -38,6 +39,7 @@ #include #include 
#include +#include #include #include #include @@ -226,7 +228,7 @@ namespace GEMMDriverTest tagScratch, ArgumentType::Value, DataDirection::ReadWrite, - rocRoller::SCRATCH); + getScratchName(ScratchPolicy::TileData)); auto params = std::make_shared(); params->setManualKernelDimension(2); @@ -329,10 +331,19 @@ namespace GEMMDriverTest { commandArgs.setArgument(command->getNextTag(), ArgumentType::Value, gemm.numWGs); } - auto scratchSpaceRequired - = commandKernel.scratchSpaceRequired(commandArgs.runtimeArguments()); - auto deviceScratch = make_shared_device(scratchSpaceRequired, 0); - commandArgs.setArgument(tagScratch, ArgumentType::Value, deviceScratch.get()); + std::shared_ptr deviceScratch[static_cast(ScratchPolicy::Count)]; + size_t scratchSpaceRequired[static_cast(ScratchPolicy::Count)]; + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + scratchSpaceRequired[i] = commandKernel.scratchSpaceRequired( + static_cast(i), commandArgs.runtimeArguments()); + if(scratchSpaceRequired[i] > 0) + { + deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); + commandArgs.setArgument( + command->getNextTag(), ArgumentType::Value, deviceScratch[i].get()); + } + } // Host result std::vector h_result(M * N, 0.0); @@ -365,8 +376,14 @@ namespace GEMMDriverTest for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(T)), HasHipSuccess(0)); - ASSERT_THAT(hipMemset(deviceScratch.get(), 0, scratchSpaceRequired), - HasHipSuccess(0)); + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + if(scratchSpaceRequired[i] > 0) + { + ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), + HasHipSuccess(0)); + } + } commandKernel.launchKernel(commandArgs.runtimeArguments()); m_context = commandKernel.getContext(); @@ -376,6 +393,22 @@ namespace GEMMDriverTest d_result.data(), deviceD.get(), M * N * sizeof(T), hipMemcpyDeviceToHost), HasHipSuccess(0)); + // Verify SyncFlags 
scratch is all zeros after kernel + auto syncFlagsIdx = static_cast(ScratchPolicy::SyncFlags); + if(scratchSpaceRequired[syncFlagsIdx] > 0) + { + std::vector syncFlagsResult(scratchSpaceRequired[syncFlagsIdx]); + ASSERT_THAT(hipMemcpy(syncFlagsResult.data(), + deviceScratch[syncFlagsIdx].get(), + scratchSpaceRequired[syncFlagsIdx], + hipMemcpyDeviceToHost), + HasHipSuccess(0)); + EXPECT_TRUE(std::all_of(syncFlagsResult.begin(), + syncFlagsResult.end(), + [](uint8_t v) { return v == 0; })) + << "SyncFlags scratch should be all zeros after kernel execution"; + } + auto tol = gemmAcceptableError( M, N, K, m_context->targetArchitecture().target()); auto res = compare(d_result, h_result, tol); diff --git a/shared/rocroller/test/unit/GEMMTest.cpp b/shared/rocroller/test/unit/GEMMTest.cpp index 06ffd74e8ec..e5c02b06f47 100644 --- a/shared/rocroller/test/unit/GEMMTest.cpp +++ b/shared/rocroller/test/unit/GEMMTest.cpp @@ -30,6 +30,7 @@ #endif /* ROCROLLER_USE_HIP */ #include +#include #include #include @@ -40,6 +41,7 @@ #include #include #include +#include #include #include #include @@ -456,30 +458,18 @@ namespace GEMMDriverTest command->addOperation(rocRoller::Operations::T_Store_Tiled(tagCvt, tagTensorD)); } - auto tileDataScratchTag = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tileDataScratchTag, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); - - auto syncFlagsScratchTag = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - syncFlagsScratchTag, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); - - // Operations::OperationTag scratchTag[static_cast(ScratchPolicy::Count)]; - // for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) - // { - // scratchTag[i] = command->allocateTag(); - // command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - // 
scratchTag[i], - // ArgumentType::Value, - // DataDirection::ReadWrite, rocRoller::SCRATCH); - // } - // std::cout << "create scratchTag" << std::endl; + std::map scratchTag; + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + scratchTag[policy] = command->allocateTag(); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + scratchTag[policy], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); + } Operations::OperationTag tagNumWGs; if(gemm.streamK) @@ -692,31 +682,18 @@ namespace GEMMDriverTest commandArgs.setArgument(tagNumWGs, ArgumentType::Value, gemm.numWGs); } - auto tileDataScratchSpaceRequired = commandKernel.scratchSpaceRequired( - ScratchPolicy::TileData, commandArgs.runtimeArguments()); - auto deviceTileDataScratch = make_shared_device(tileDataScratchSpaceRequired, 0); - commandArgs.setArgument(tileDataScratchTag, ArgumentType::Value, deviceTileDataScratch.get()); - - auto syncFlagsScratchSpaceRequired = commandKernel.scratchSpaceRequired( - ScratchPolicy::SyncFlags, commandArgs.runtimeArguments()); - auto deviceSyncFlagsScratch = make_shared_device(syncFlagsScratchSpaceRequired, 0); - commandArgs.setArgument(syncFlagsScratchTag, ArgumentType::Value, deviceSyncFlagsScratch.get()); - - // std::shared_ptr deviceScratch[static_cast(ScratchPolicy::Count)]; - // size_t scratchSpaceRequired[static_cast(ScratchPolicy::Count)]; - // for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) - // { - // scratchSpaceRequired[i] = commandKernel.scratchSpaceRequired( - // static_cast(i), commandArgs.runtimeArguments()); - // // if(scratchSpaceRequired[i] > 0) - // { - // std::cout << "scratch space required: " << scratchSpaceRequired[i] << std::endl; - // std::cout << "index: " << i << std::endl; - // deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); - // commandArgs.setArgument(scratchTag[i], ArgumentType::Value, 
deviceScratch[i].get()); - // } - // } - std::cout << "set commmand arguments" << std::endl; + std::map> deviceScratch; + std::map scratchSpaceRequired; + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + scratchSpaceRequired[policy] + = commandKernel.scratchSpaceRequired(policy, commandArgs.runtimeArguments()); + deviceScratch[policy] = make_shared_device( + std::max(scratchSpaceRequired[policy], size_t(1)), 0); + commandArgs.setArgument( + scratchTag[policy], ArgumentType::Value, deviceScratch[policy].get()); + } if(gemm.workgroupMappingDim != -1) { @@ -829,39 +806,37 @@ namespace GEMMDriverTest for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(TD)), HasHipSuccess(0)); - // for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) - // { - // // if(scratchSpaceRequired[i] > 0) - // { - // ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), HasHipSuccess(0)); - // } - // } - ASSERT_THAT(hipMemset(deviceTileDataScratch.get(), 0, tileDataScratchSpaceRequired), HasHipSuccess(0)); - ASSERT_THAT(hipMemset(deviceSyncFlagsScratch.get(), 0, syncFlagsScratchSpaceRequired), HasHipSuccess(0)); - - std::cout << "launch kernel" << std::endl; + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + ASSERT_THAT(hipMemset(deviceScratch[policy].get(), + 0, + std::max(scratchSpaceRequired[policy], size_t(1))), + HasHipSuccess(0)); + } commandKernel.launchKernel(commandArgs.runtimeArguments()); - std::cout << "copy device result" << std::endl; ASSERT_THAT( hipMemcpy( d_result.data(), deviceD.get(), M * N * sizeof(TD), hipMemcpyDeviceToHost), HasHipSuccess(0)); - std::cout << "verify sync flags scratch" << std::endl; - // Verify SyncFlags scratch is all zeros after kernel - // auto syncFlagsIdx = static_cast(ScratchPolicy::SyncFlags); - auto syncFlagsResult = std::vector(syncFlagsScratchSpaceRequired); - 
ASSERT_THAT(hipMemcpy(syncFlagsResult.data(), - // deviceScratch[syncFlagsIdx].get(), - deviceSyncFlagsScratch.get(), - syncFlagsScratchSpaceRequired, - hipMemcpyDeviceToHost), - HasHipSuccess(0)); - EXPECT_TRUE(std::all_of(syncFlagsResult.begin(), syncFlagsResult.end(), [](uint8_t v) { return v == 0; })) - << "SyncFlags scratch should be all zeros after kernel execution"; + if(scratchSpaceRequired[ScratchPolicy::SyncFlags] > 0) + { + std::vector syncFlagsResult( + scratchSpaceRequired[ScratchPolicy::SyncFlags]); + ASSERT_THAT(hipMemcpy(syncFlagsResult.data(), + deviceScratch[ScratchPolicy::SyncFlags].get(), + scratchSpaceRequired[ScratchPolicy::SyncFlags], + hipMemcpyDeviceToHost), + HasHipSuccess(0)); + EXPECT_TRUE(std::all_of(syncFlagsResult.begin(), + syncFlagsResult.end(), + [](uint8_t v) { return v == 0; })) + << "SyncFlags scratch should be all zeros after kernel execution"; + } auto tol = gemmAcceptableError( M, N, K, m_context->targetArchitecture().target()); From 261b9802ec40f4e04b01362679a6e57326315c40 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Fri, 5 Dec 2025 06:13:41 +0000 Subject: [PATCH 06/22] Add unit tests --- .../Transformations/AddStreamK.cpp | 24 +++ .../rocroller/test/catch/AddStreamKTest.cpp | 161 ++++++++++++++++++ .../test/catch/ScratchOperationTest.cpp | 91 ++++++++++ .../test/common/src/CommonGraphs.cpp | 1 + shared/rocroller/test/unit/GEMMFusion.cpp | 18 +- shared/rocroller/test/unit/GEMMTest.cpp | 2 - 6 files changed, 289 insertions(+), 8 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index ef95d6e66e3..cc4c785231d 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -585,6 +585,30 @@ namespace rocRoller auto doWhileTag = graph.control.addElement( DoWhileOp{(DF(flagRegister) == zero), "Global sync spin 
loop"}); + // The coordinate graph for load, store, and reset flags: + // + // Workgroup + // | + // PassThrough + // | + // v + // flagsScratchTag <-Duplicate-- resetFlagsScratchTag + // | ^ + // PassThrough PassThrough + // | | + // v resetNextWorkgroupTag + // nextWorkgroupTag ^ + // | Join + // Split / | \ + // / | \ / | \ + // v v v / | \ + // Workgroup plusOne forReceiveTileLoop + // + // Note: nextWorkgroupTag and resetNextWorkgroupTag both connect to the same + // neighbors (Workgroup, plusOne, forReceiveTileLoop) but in opposite directions + // (Split vs Join). This allows loading flags from WG+1+i and resetting flags + // at the same index. + // Create coordinate to indicate flag index to reset auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); graph.coordinates.addElement( diff --git a/shared/rocroller/test/catch/AddStreamKTest.cpp b/shared/rocroller/test/catch/AddStreamKTest.cpp index 96cab26c283..4078fa32e0e 100644 --- a/shared/rocroller/test/catch/AddStreamKTest.cpp +++ b/shared/rocroller/test/catch/AddStreamKTest.cpp @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -621,3 +622,163 @@ TEST_CASE("AddStreamK with unroll K", "[streamk][kernel-graph]") } } } + +TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") +{ + using namespace rocRoller; + using namespace rocRoller::Operations; + using namespace KernelGraph; + using namespace ControlGraph; + using namespace CoordinateGraph; + + auto context = TestContext::ForDefaultTarget(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); + + // Use standard mode which creates both tile data and flags scratch + auto mode = StreamKMode::Standard; + + example.setTileSize(128, 256, 8); + example.setMFMA(32, 32, 2, 1); + example.setUseLDS(false, false, false); + example.setPrefetch(false, 0, 0, false); + example.setStreamK(mode); + + auto numWGs = example.getFlattenedWorkgroupSize(); + auto numWGsExpr = 
std::make_shared(numWGs); + + auto kgraph = example.getKernelGraph(); + auto params = example.getCommandParameters(); + + // Apply transforms including AddStreamK + std::vector transforms; + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared(false)); + transforms.push_back(std::make_shared(params)); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared(context.get())); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared( + context.get(), params, rocRoller::XLOOP, rocRoller::KLOOP, numWGsExpr)); + transforms.push_back(std::make_shared(context.get())); + + for(auto& t : transforms) + kgraph = kgraph.transform(t); + + SECTION("Tile data uses None policy") + { + // After AddStreamK, the None policy should have non-zero scratch allocation + // (used for tile data exchange between workgroups) + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto valueNone = Expression::evaluate(amountNone); + // Tile data scratch should be allocated + CHECK(getUnsignedInt(valueNone) > 0); + } + + SECTION("Flags use ZeroedBeforeAndAfter policy") + { + // After AddStreamK, the ZeroedBeforeAndAfter policy should have non-zero scratch + // allocation (used for synchronization flags that must be zeroed before/after kernel) + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + auto valueZeroed = Expression::evaluate(amountZeroed); + // Flags scratch should be allocated for sync purposes + CHECK(getUnsignedInt(valueZeroed) > 0); + } + + SECTION("Different policies have different allocations") + { + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + + auto 
valueNone = Expression::evaluate(amountNone); + auto valueZeroed = Expression::evaluate(amountZeroed); + + // Both should have allocations, but they should be independent + // (flags are typically smaller - one uint32 per WG, while tile data is larger) + CHECK(getUnsignedInt(valueNone) > 0); + CHECK(getUnsignedInt(valueZeroed) > 0); + // Tile data should typically be larger than flags + CHECK(getUnsignedInt(valueNone) > getUnsignedInt(valueZeroed)); + } + + SECTION("Only has one scratch coordinate for tile data") + { + + auto findScratchNone = [&](int tag) { + auto maybeUser = kgraph.coordinates.get(tag); + if(!maybeUser) + return false; + return maybeUser->argumentName == getScratchName(ScratchPolicy::None); + }; + + auto scratchNoneCoordinates = kgraph.coordinates.findElements(findScratchNone).to(); + CHECK(scratchNoneCoordinates.size() == 1); + + + } + + SECTION("Load, store, and reset flags scratch space correctly") + { + auto findScratchZeroed = [&](int tag) { + auto maybeUser = kgraph.coordinates.get(tag); + if(!maybeUser) + return false; + return maybeUser->argumentName == getScratchName(ScratchPolicy::ZeroedBeforeAndAfter); + }; + + auto scratchZeroedCoordinates + = kgraph.coordinates.findElements(findScratchZeroed).to(); + CHECK(scratchZeroedCoordinates.size() == 2); + + // Check the reset flags coordinate connects to the original flags coordinate via a duplicate edge + auto resetFlagsCoordinate = -1, originalFlagsCoordinate = -1; + for (const auto& tag : scratchZeroedCoordinates) + { + std::cout << "tag: " << tag << std::endl; + // Duplicate edge connects the reset flags coordinate to the original flags coordinate + auto isDuplicate = isEdge; + auto outDuplicates = kgraph.coordinates.getOutputNodeIndices(tag, isDuplicate) + .to(); + + CHECK((outDuplicates.size() == 1 || outDuplicates.empty())); + if (outDuplicates.size() == 1) + { + resetFlagsCoordinate = tag; + originalFlagsCoordinate = outDuplicates[0]; + } + + } + std::cout << "resetFlagsCoordinate: " 
<< resetFlagsCoordinate << std::endl; + std::cout << "originalFlagsCoordinate: " << originalFlagsCoordinate << std::endl; + CHECK(resetFlagsCoordinate != -1); + CHECK(originalFlagsCoordinate != -1); + + // Duplicate coordinate should have a higher tag than the original flags coordinate + CHECK(resetFlagsCoordinate > originalFlagsCoordinate); + + + auto isPassThroughEdge = isEdge; + auto isJoinEdge = isEdge; + auto isSplitEdge = isEdge; + auto maybeNextWorkgroupTag = kgraph.coordinates.getOutputNodeIndices(originalFlagsCoordinate, isPassThroughEdge).to(); + CHECK(maybeNextWorkgroupTag.size() == 1); + auto nextWorkgroupTag = maybeNextWorkgroupTag[0]; + auto maybeSplit = kgraph.coordinates.getOutputNodeIndices(nextWorkgroupTag, isSplitEdge).to(); + CHECK(maybeSplit.size() == 3); + // CHECK(maybeSplit[0] == maybeSplit[1]); + // CHECK(maybeSplit[1] == maybeSplit[2]); + auto maybeResetNextWorkgroupTag0 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[0], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag1 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[1], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag2 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[2], isJoinEdge).to(); + CHECK(maybeResetNextWorkgroupTag0.size() == 1); + CHECK(maybeResetNextWorkgroupTag1.size() == 1); + CHECK(maybeResetNextWorkgroupTag2.size() == 1); + CHECK(maybeResetNextWorkgroupTag0[0] == maybeResetNextWorkgroupTag1[0]); + CHECK(maybeResetNextWorkgroupTag1[0] == maybeResetNextWorkgroupTag2[0]); + auto maybeResetFlagsCoordinate = kgraph.coordinates.getOutputNodeIndices(maybeResetNextWorkgroupTag0[0], isPassThroughEdge).to(); + CHECK(maybeResetFlagsCoordinate.size() == 1); + CHECK(maybeResetFlagsCoordinate[0] == resetFlagsCoordinate); + } +} diff --git a/shared/rocroller/test/catch/ScratchOperationTest.cpp b/shared/rocroller/test/catch/ScratchOperationTest.cpp index 621b0877d92..b23ec3aec07 100644 --- a/shared/rocroller/test/catch/ScratchOperationTest.cpp +++ 
b/shared/rocroller/test/catch/ScratchOperationTest.cpp @@ -24,6 +24,9 @@ * *******************************************************************************/ +#include +#include +#include #include #include #include @@ -32,6 +35,9 @@ #include #include +#include "TestContext.hpp" + +#include #include using namespace rocRoller; @@ -226,4 +232,89 @@ namespace ScratchOperationTest CHECK(scratchOpTag2.uninitialized() == false); } } + + TEST_CASE("Scratch allocator per policy", "[scratch][context]") + { + auto context = TestContext::ForDefaultTarget(); + + SECTION("Context initializes all policies to zero") + { + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto amount = context->getScratchAmount(policy); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 0); + } + } + + SECTION("allocateScratch accumulates size per policy") + { + auto size1 = rocRoller::Expression::literal(100u); + auto size2 = rocRoller::Expression::literal(200u); + + context->allocateScratch(ScratchPolicy::None, size1); + context->allocateScratch(ScratchPolicy::None, size2); + + auto amount = context->getScratchAmount(ScratchPolicy::None); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 300); + } + + SECTION("getScratchAmount returns correct amount per policy") + { + auto size = rocRoller::Expression::literal(500u); + context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, size); + + auto amount = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 500); + } + + SECTION("Different policies have independent allocations") + { + auto sizeNone = rocRoller::Expression::literal(100u); + auto sizeZeroed = rocRoller::Expression::literal(200u); + + context->allocateScratch(ScratchPolicy::None, sizeNone); + 
context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, sizeZeroed); + + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + + auto valueNone = rocRoller::Expression::evaluate(amountNone); + auto valueZeroed = rocRoller::Expression::evaluate(amountZeroed); + + CHECK(rocRoller::getUnsignedInt(valueNone) == 100); + CHECK(rocRoller::getUnsignedInt(valueZeroed) == 200); + } + } + + TEST_CASE("Utilities to get scratch policy name", "[scratch][utils]") + { + SECTION("Returns correct name for None policy") + { + auto name = rocRoller::getScratchName(ScratchPolicy::None); + CHECK(name == "SCRATCH_None"); + } + + SECTION("Returns correct name for ZeroedBeforeAndAfter policy") + { + auto name = rocRoller::getScratchName(ScratchPolicy::ZeroedBeforeAndAfter); + CHECK(name == "SCRATCH_ZeroedBeforeAndAfter"); + } + + SECTION("All policies have unique names") + { + std::set names; + for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto name = rocRoller::getScratchName(policy); + CHECK(names.find(name) == names.end()); + names.insert(name); + } + CHECK(names.size() == static_cast(ScratchPolicy::Count)); + } + } } diff --git a/shared/rocroller/test/common/src/CommonGraphs.cpp b/shared/rocroller/test/common/src/CommonGraphs.cpp index 24834a593f7..f94235d9457 100644 --- a/shared/rocroller/test/common/src/CommonGraphs.cpp +++ b/shared/rocroller/test/common/src/CommonGraphs.cpp @@ -290,6 +290,7 @@ namespace rocRollerTest::Graphs { auto policy = static_cast(i); m_scratchTags[policy] = m_command->allocateTag(); + m_command->addOperation(rocRoller::Operations::Scratch(m_scratchTags[policy], policy)); m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), m_scratchTags[policy], ArgumentType::Value, diff --git a/shared/rocroller/test/unit/GEMMFusion.cpp b/shared/rocroller/test/unit/GEMMFusion.cpp index 
e77d0e61d6b..34887b384f5 100644 --- a/shared/rocroller/test/unit/GEMMFusion.cpp +++ b/shared/rocroller/test/unit/GEMMFusion.cpp @@ -223,12 +223,18 @@ namespace GEMMDriverTest rocRoller::Operations::Tensor(2, dataType, oneStridesN)); // E command->addOperation(rocRoller::Operations::T_Store_Tiled(tagRelu, tagTensorRelu)); - auto tagScratch = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - getScratchName(Operations::ScratchPolicy::None)); + Operations::OperationTag tagScratch[static_cast(Operations::ScratchPolicy::Count)]; + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + tagScratch[i] = command->allocateTag(); + command->addOperation(rocRoller::Operations::Scratch(tagScratch[i], policy)); + command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + tagScratch[i], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); + } auto params = std::make_shared(); params->setManualKernelDimension(2); diff --git a/shared/rocroller/test/unit/GEMMTest.cpp b/shared/rocroller/test/unit/GEMMTest.cpp index 07667816e9f..3f370da9db8 100644 --- a/shared/rocroller/test/unit/GEMMTest.cpp +++ b/shared/rocroller/test/unit/GEMMTest.cpp @@ -29,8 +29,6 @@ #include #endif /* ROCROLLER_USE_HIP */ -#include -#include #include #include "GEMMF8F6F4.hpp" From 002c05be8b24cd44647adc47a342408a6877a39d Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Fri, 5 Dec 2025 06:15:39 +0000 Subject: [PATCH 07/22] Formatting, add comments --- .../Transformations/AddStreamK.cpp | 40 ++++++++-------- .../rocroller/test/catch/AddStreamKTest.cpp | 47 ++++++++++--------- shared/rocroller/test/unit/GEMMFusion.cpp | 15 +++--- shared/rocroller/test/unit/GEMMTestBase.hpp | 15 +++--- 4 files changed, 60 insertions(+), 57 deletions(-) diff --git 
a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index cc4c785231d..109cd73f8bf 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -585,29 +585,29 @@ namespace rocRoller auto doWhileTag = graph.control.addElement( DoWhileOp{(DF(flagRegister) == zero), "Global sync spin loop"}); - // The coordinate graph for load, store, and reset flags: - // - // Workgroup - // | - // PassThrough - // | - // v - // flagsScratchTag <-Duplicate-- resetFlagsScratchTag - // | ^ - // PassThrough PassThrough - // | | - // v resetNextWorkgroupTag - // nextWorkgroupTag ^ - // | Join - // Split / | \ + // The coordinate graph for load, store, and reset flags: + // + // Workgroup + // | + // PassThrough + // | + // v + // flagsScratchTag <-Duplicate-- resetFlagsScratchTag + // | ^ + // PassThrough PassThrough + // | | + // v resetNextWorkgroupTag + // nextWorkgroupTag ^ + // | Join + // Split / | \ // / | \ / | \ // v v v / | \ // Workgroup plusOne forReceiveTileLoop - // - // Note: nextWorkgroupTag and resetNextWorkgroupTag both connect to the same - // neighbors (Workgroup, plusOne, forReceiveTileLoop) but in opposite directions - // (Split vs Join). This allows loading flags from WG+1+i and resetting flags - // at the same index. + // + // Note: nextWorkgroupTag and resetNextWorkgroupTag both connect to the same + // neighbors (Workgroup, plusOne, forReceiveTileLoop) but in opposite directions + // (Split vs Join). This allows loading flags from WG+1+i and resetting flags + // at the same index. 
// Create coordinate to indicate flag index to reset auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); diff --git a/shared/rocroller/test/catch/AddStreamKTest.cpp b/shared/rocroller/test/catch/AddStreamKTest.cpp index 4078fa32e0e..ff3640ebb46 100644 --- a/shared/rocroller/test/catch/AddStreamKTest.cpp +++ b/shared/rocroller/test/catch/AddStreamKTest.cpp @@ -713,10 +713,9 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") return maybeUser->argumentName == getScratchName(ScratchPolicy::None); }; - auto scratchNoneCoordinates = kgraph.coordinates.findElements(findScratchNone).to(); + auto scratchNoneCoordinates + = kgraph.coordinates.findElements(findScratchNone).to(); CHECK(scratchNoneCoordinates.size() == 1); - - } SECTION("Load, store, and reset flags scratch space correctly") @@ -727,28 +726,27 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") return false; return maybeUser->argumentName == getScratchName(ScratchPolicy::ZeroedBeforeAndAfter); }; - + auto scratchZeroedCoordinates = kgraph.coordinates.findElements(findScratchZeroed).to(); CHECK(scratchZeroedCoordinates.size() == 2); // Check the reset flags coordinate connects to the original flags coordinate via a duplicate edge auto resetFlagsCoordinate = -1, originalFlagsCoordinate = -1; - for (const auto& tag : scratchZeroedCoordinates) + for(const auto& tag : scratchZeroedCoordinates) { std::cout << "tag: " << tag << std::endl; // Duplicate edge connects the reset flags coordinate to the original flags coordinate auto isDuplicate = isEdge; - auto outDuplicates = kgraph.coordinates.getOutputNodeIndices(tag, isDuplicate) - .to(); + auto outDuplicates + = kgraph.coordinates.getOutputNodeIndices(tag, isDuplicate).to(); CHECK((outDuplicates.size() == 1 || outDuplicates.empty())); - if (outDuplicates.size() == 1) + if(outDuplicates.size() == 1) { - resetFlagsCoordinate = tag; + resetFlagsCoordinate = tag; originalFlagsCoordinate = outDuplicates[0]; 
} - } std::cout << "resetFlagsCoordinate: " << resetFlagsCoordinate << std::endl; std::cout << "originalFlagsCoordinate: " << originalFlagsCoordinate << std::endl; @@ -758,26 +756,33 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") // Duplicate coordinate should have a higher tag than the original flags coordinate CHECK(resetFlagsCoordinate > originalFlagsCoordinate); - + // Check the graph of flags scratch coordinates matches AddStreamK implementation auto isPassThroughEdge = isEdge; - auto isJoinEdge = isEdge; - auto isSplitEdge = isEdge; - auto maybeNextWorkgroupTag = kgraph.coordinates.getOutputNodeIndices(originalFlagsCoordinate, isPassThroughEdge).to(); + auto isJoinEdge = isEdge; + auto isSplitEdge = isEdge; + auto maybeNextWorkgroupTag + = kgraph.coordinates.getOutputNodeIndices(originalFlagsCoordinate, isPassThroughEdge) + .to(); CHECK(maybeNextWorkgroupTag.size() == 1); auto nextWorkgroupTag = maybeNextWorkgroupTag[0]; - auto maybeSplit = kgraph.coordinates.getOutputNodeIndices(nextWorkgroupTag, isSplitEdge).to(); + auto maybeSplit = kgraph.coordinates.getOutputNodeIndices(nextWorkgroupTag, isSplitEdge) + .to(); CHECK(maybeSplit.size() == 3); - // CHECK(maybeSplit[0] == maybeSplit[1]); - // CHECK(maybeSplit[1] == maybeSplit[2]); - auto maybeResetNextWorkgroupTag0 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[0], isJoinEdge).to(); - auto maybeResetNextWorkgroupTag1 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[1], isJoinEdge).to(); - auto maybeResetNextWorkgroupTag2 = kgraph.coordinates.getOutputNodeIndices(maybeSplit[2], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag0 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[0], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag1 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[1], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag2 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[2], isJoinEdge).to(); CHECK(maybeResetNextWorkgroupTag0.size() == 1); 
CHECK(maybeResetNextWorkgroupTag1.size() == 1); CHECK(maybeResetNextWorkgroupTag2.size() == 1); CHECK(maybeResetNextWorkgroupTag0[0] == maybeResetNextWorkgroupTag1[0]); CHECK(maybeResetNextWorkgroupTag1[0] == maybeResetNextWorkgroupTag2[0]); - auto maybeResetFlagsCoordinate = kgraph.coordinates.getOutputNodeIndices(maybeResetNextWorkgroupTag0[0], isPassThroughEdge).to(); + auto maybeResetFlagsCoordinate + = kgraph.coordinates + .getOutputNodeIndices(maybeResetNextWorkgroupTag0[0], isPassThroughEdge) + .to(); CHECK(maybeResetFlagsCoordinate.size() == 1); CHECK(maybeResetFlagsCoordinate[0] == resetFlagsCoordinate); } diff --git a/shared/rocroller/test/unit/GEMMFusion.cpp b/shared/rocroller/test/unit/GEMMFusion.cpp index 34887b384f5..e46dd16a40a 100644 --- a/shared/rocroller/test/unit/GEMMFusion.cpp +++ b/shared/rocroller/test/unit/GEMMFusion.cpp @@ -226,14 +226,15 @@ namespace GEMMDriverTest Operations::OperationTag tagScratch[static_cast(Operations::ScratchPolicy::Count)]; for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) { - auto policy = static_cast(i); + auto policy = static_cast(i); tagScratch[i] = command->allocateTag(); command->addOperation(rocRoller::Operations::Scratch(tagScratch[i], policy)); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch[i], - ArgumentType::Value, - DataDirection::ReadWrite, - getScratchName(policy)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + tagScratch[i], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); } auto params = std::make_shared(); @@ -338,7 +339,7 @@ namespace GEMMDriverTest commandArgs.setArgument(command->getNextTag(), ArgumentType::Value, gemm.numWGs); } std::shared_ptr - deviceScratch[static_cast(Operations::ScratchPolicy::Count)]; + deviceScratch[static_cast(Operations::ScratchPolicy::Count)]; size_t scratchSpaceRequired[static_cast(Operations::ScratchPolicy::Count)]; 
for(size_t i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) { diff --git a/shared/rocroller/test/unit/GEMMTestBase.hpp b/shared/rocroller/test/unit/GEMMTestBase.hpp index 441ef962b2b..65036c39e9e 100644 --- a/shared/rocroller/test/unit/GEMMTestBase.hpp +++ b/shared/rocroller/test/unit/GEMMTestBase.hpp @@ -435,7 +435,7 @@ namespace GEMMTests } std::map scratchTags; - Operations::OperationTag tagNumWGs; + Operations::OperationTag tagNumWGs; if(gemm.streamK) { tagNumWGs = command->allocateTag(); @@ -676,8 +676,7 @@ namespace GEMMTests policy, commandArgs.runtimeArguments()); if(scratchSpaceRequired[i] > 0) { - deviceScratch[i] - = make_shared_device(scratchSpaceRequired[i], 0); + deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); commandArgs.setArgument( scratchTags.at(policy), ArgumentType::Value, deviceScratch[i].get()); } @@ -799,9 +798,8 @@ namespace GEMMTests { if(scratchSpaceRequired[i] > 0) { - ASSERT_THAT( - hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), - HasHipSuccess(0)); + ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), + HasHipSuccess(0)); } } @@ -831,9 +829,8 @@ namespace GEMMTests scratchSpaceRequired[zeroedIdx], hipMemcpyDeviceToHost), HasHipSuccess(0)); - EXPECT_TRUE(std::all_of(zeroedResult.begin(), - zeroedResult.end(), - [](uint8_t v) { return v == 0; })) + EXPECT_TRUE(std::all_of( + zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; })) << "ZeroedBeforeAndAfter scratch should be all zeros after kernel " "execution"; } From c99bb7abb52266b0a39544f85ad68617461e70fb Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Fri, 5 Dec 2025 17:59:09 +0000 Subject: [PATCH 08/22] Minor changes --- .../include/client/StreamKGEMMSolution.hpp | 36 +++++++++++------ .../Transformations/AddStreamK.cpp | 6 +-- .../rocroller/test/catch/AddStreamKTest.cpp | 40 ++++++++++++++----- shared/rocroller/test/unit/GEMMTestBase.hpp | 35 ++++++++++------ 4 files changed, 80 insertions(+), 
37 deletions(-) diff --git a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp index 85ebe3155a2..ce3375f6c08 100644 --- a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp +++ b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp @@ -65,18 +65,30 @@ namespace rocRoller DataDirection::ReadOnly, rocRoller::NUMWGS); - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - auto policy = static_cast(i); - m_scratchTags[policy] = command->allocateTag(); - command->addOperation(Operations::Scratch(m_scratchTags[policy], policy)); - command->allocateArgument( - VariableType(DataType::UInt32, PointerType::PointerGlobal), - m_scratchTags[policy], - ArgumentType::Value, - DataDirection::ReadWrite, - getScratchName(policy)); - } + // Create a scratch operation for tile data + m_scratchTags[Operations::ScratchPolicy::None] = command->allocateTag(); + command->addOperation( + Operations::Scratch(m_scratchTags[Operations::ScratchPolicy::None], + Operations::ScratchPolicy::None)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::None], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::None)); + + // Create a scratch operation for flags + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] + = command->allocateTag(); + command->addOperation(Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); return command; } diff --git 
a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index 109cd73f8bf..bf8b3c88791 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -600,9 +600,9 @@ namespace rocRoller // nextWorkgroupTag ^ // | Join // Split / | \ - // / | \ / | \ - // v v v / | \ - // Workgroup plusOne forReceiveTileLoop + // / | \ / | \ + // v v v / | \ + // Workgroup plusOne forReceiveTileLoop // // Note: nextWorkgroupTag and resetNextWorkgroupTag both connect to the same // neighbors (Workgroup, plusOne, forReceiveTileLoop) but in opposite directions diff --git a/shared/rocroller/test/catch/AddStreamKTest.cpp b/shared/rocroller/test/catch/AddStreamKTest.cpp index ff3640ebb46..067383c7436 100644 --- a/shared/rocroller/test/catch/AddStreamKTest.cpp +++ b/shared/rocroller/test/catch/AddStreamKTest.cpp @@ -623,7 +623,7 @@ TEST_CASE("AddStreamK with unroll K", "[streamk][kernel-graph]") } } -TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") +TEST_CASE("AddStreamK scratch policy usage", "[streamk][kernel-graph][scratch]") { using namespace rocRoller; using namespace rocRoller::Operations; @@ -633,9 +633,7 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") auto context = TestContext::ForDefaultTarget(); auto example = rocRollerTest::Graphs::GEMM(DataType::Float); - - // Use standard mode which creates both tile data and flags scratch - auto mode = StreamKMode::Standard; + auto mode = StreamKMode::Standard; example.setTileSize(128, 256, 8); example.setMFMA(32, 32, 2, 1); @@ -731,11 +729,10 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") = kgraph.coordinates.findElements(findScratchZeroed).to(); CHECK(scratchZeroedCoordinates.size() == 2); - // Check the reset flags coordinate connects to the original flags coordinate via a duplicate edge 
+ // Verify the reset flags coordinate connects to the original flags coordinate via a duplicate edge auto resetFlagsCoordinate = -1, originalFlagsCoordinate = -1; for(const auto& tag : scratchZeroedCoordinates) { - std::cout << "tag: " << tag << std::endl; // Duplicate edge connects the reset flags coordinate to the original flags coordinate auto isDuplicate = isEdge; auto outDuplicates @@ -748,15 +745,13 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") originalFlagsCoordinate = outDuplicates[0]; } } - std::cout << "resetFlagsCoordinate: " << resetFlagsCoordinate << std::endl; - std::cout << "originalFlagsCoordinate: " << originalFlagsCoordinate << std::endl; CHECK(resetFlagsCoordinate != -1); CHECK(originalFlagsCoordinate != -1); // Duplicate coordinate should have a higher tag than the original flags coordinate CHECK(resetFlagsCoordinate > originalFlagsCoordinate); - // Check the graph of flags scratch coordinates matches AddStreamK implementation + // Verify the graph of flags scratch coordinates matches AddStreamK implementation auto isPassThroughEdge = isEdge; auto isJoinEdge = isEdge; auto isSplitEdge = isEdge; @@ -785,5 +780,32 @@ TEST_CASE("AddStreamK scratch policy usage", "[streamk][scratch]") .to(); CHECK(maybeResetFlagsCoordinate.size() == 1); CHECK(maybeResetFlagsCoordinate[0] == resetFlagsCoordinate); + + // Verify there are two store flags operations + auto storeFlagsOps + = kgraph.control.findElements(kgraph.control.isElemType()).to(); + CHECK(storeFlagsOps.size() == 2); + + // Verify one StoreSGPR is stores the flags and the other resets the flags + auto storeFlagsTag = -1, resetFlagsTag = -1; + for(auto const tag : storeFlagsOps) + { + auto flagCoordinateTag = kgraph.mapper.get(tag); + if(flagCoordinateTag == resetFlagsCoordinate) + { + resetFlagsTag = tag; + } + else + { + storeFlagsTag = tag; + } + } + CHECK(storeFlagsTag != -1); + CHECK(resetFlagsTag != -1); + + // Verify the reset flags happens after the store flags + 
auto order + = kgraph.control.compareNodes(rocRoller::UpdateCache, storeFlagsTag, resetFlagsTag); + CHECK(order == NodeOrdering::LeftFirst); } } diff --git a/shared/rocroller/test/unit/GEMMTestBase.hpp b/shared/rocroller/test/unit/GEMMTestBase.hpp index 65036c39e9e..5f605cd146b 100644 --- a/shared/rocroller/test/unit/GEMMTestBase.hpp +++ b/shared/rocroller/test/unit/GEMMTestBase.hpp @@ -445,19 +445,28 @@ namespace GEMMTests DataDirection::ReadOnly, rocRoller::NUMWGS); - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - auto policy = static_cast(i); - scratchTags[policy] = command->allocateTag(); - command->addOperation( - rocRoller::Operations::Scratch(scratchTags.at(policy), policy)); - command->allocateArgument( - VariableType(DataType::UInt32, PointerType::PointerGlobal), - scratchTags.at(policy), - ArgumentType::Value, - DataDirection::ReadWrite, - getScratchName(policy)); - } + scratchTags[Operations::ScratchPolicy::None] = command->allocateTag(); + command->addOperation( + rocRoller::Operations::Scratch(scratchTags.at(Operations::ScratchPolicy::None), + Operations::ScratchPolicy::None)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + scratchTags.at(Operations::ScratchPolicy::None), + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::None)); + + scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] + = command->allocateTag(); + command->addOperation(rocRoller::Operations::Scratch( + scratchTags.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + scratchTags.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); } Operations::OperationTag tagWGM; From 
4575181bf41a47e60295e0cba42a71e77f9a73c9 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Sat, 6 Dec 2025 02:40:31 +0000 Subject: [PATCH 09/22] Add streamk_fp4 suite --- .../rocroller/scripts/lib/rrperf/rrsuites.py | 65 +++++++++++++++++++ .../test/common/src/CommonGraphs.cpp | 29 +++++---- 2 files changed, 83 insertions(+), 11 deletions(-) diff --git a/shared/rocroller/scripts/lib/rrperf/rrsuites.py b/shared/rocroller/scripts/lib/rrperf/rrsuites.py index 3747ffe131e..f336749f802 100644 --- a/shared/rocroller/scripts/lib/rrperf/rrsuites.py +++ b/shared/rocroller/scripts/lib/rrperf/rrsuites.py @@ -139,6 +139,32 @@ M=256, N=256, K=16384, mac_m=64, mac_n=64, mac_k=64, types=fp32 ) +FP4GEMM_7680x8448x8448 = dict( + M=7680, + N=8448, + K=8448, + types=TypeParameters( + type_A="fp4", + type_B="fp4", + type_C="float", + type_D="float", + type_acc="float", + ), +) + +FP4GEMM_2048x2048x396288 = dict( + M=2048, + N=2048, + K=396288, + types=TypeParameters( + type_A="fp4", + type_B="fp4", + type_C="float", + type_D="float", + type_acc="float", + ), +) + def update_parameters(*args, **kwargs): rv = {} @@ -570,6 +596,44 @@ def streamk_sweep(): ) +def streamk_fp4_sweep(): + for twoTile, twoTileDPFirst in [(True, False), (False, True), (False, False)]: + for base in [FP4GEMM_2048x2048x396288]: + # 16x16x128 wave dimensions + for mac_m in [64, 128, 256]: + for mac_n in [64, 128, 256]: + for mac_k in [128]: + for wave_m in [16, 32]: + wave_n = 16 if wave_m == 16 else 32 + wave_k = 128 if wave_m == 16 else 64 + if ( + twoTile or twoTileDPFirst + ) and mac_m * mac_n * mac_k >= (wave_m * wave_m * mac_k): + # currently these run out of VGPRs. 
+ pass + else: + yield mkGEMM( + base, + mac_m=mac_m, + mac_n=mac_n, + mac_k=mac_k, + wave_m=wave_m, + wave_n=wave_n, + wave_k=wave_k, + workgroup_size_x=128, + workgroup_size_y=2, + visualize=False, + prefetch=False, + streamKTwoTile=twoTile, + streamKTwoTileDPFirst=twoTileDPFirst, + types=TypeParameters( + base["types"], + trans_A="T", + trans_B="N", + ), + ) + + def streamk_smallMN_largeK_fp32(): for twoTile, twoTileDPFirst in [(True, False), (False, True), (False, False)]: yield mkGEMM( @@ -2061,6 +2125,7 @@ def all(): yield from fp8_kernels() yield from mxfp8_kernels() yield from mx_gemms_f8f6f4() + yield from streamk_fp4_sweep() yield from sgemm() yield from hgemm() diff --git a/shared/rocroller/test/common/src/CommonGraphs.cpp b/shared/rocroller/test/common/src/CommonGraphs.cpp index f94235d9457..0319ae1cad0 100644 --- a/shared/rocroller/test/common/src/CommonGraphs.cpp +++ b/shared/rocroller/test/common/src/CommonGraphs.cpp @@ -286,17 +286,24 @@ namespace rocRollerTest::Graphs rocRoller::NUMWGS); } - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - auto policy = static_cast(i); - m_scratchTags[policy] = m_command->allocateTag(); - m_command->addOperation(rocRoller::Operations::Scratch(m_scratchTags[policy], policy)); - m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - m_scratchTags[policy], - ArgumentType::Value, - DataDirection::ReadWrite, - getScratchName(policy)); - } + m_scratchTags[Operations::ScratchPolicy::None] = m_command->allocateTag(); + m_command->addOperation(rocRoller::Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::None], Operations::ScratchPolicy::None)); + m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::None], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::None)); + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] = 
m_command->allocateTag(); + m_command->addOperation(rocRoller::Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + m_command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); } CommandPtr GEMM::getCommand() From 33e427d8634679663b64602137264b2cbdf55431 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Sat, 6 Dec 2025 05:56:56 +0000 Subject: [PATCH 10/22] Fix for-loop range --- .../lib/source/KernelGraph/Transformations/AddStreamK.cpp | 7 +++++-- shared/rocroller/scripts/lib/rrperf/rrsuites.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index bf8b3c88791..06719a056c9 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -1226,7 +1226,10 @@ namespace rocRoller int postAccumulationCond; if(accumInfo.accumulatorTile != -1) { - auto remainAccumTiles = numAccumTiles - DF(lastAccumTile) + one; + auto accumTileIdxEnd + = (argInfo.numSKTilesPerWG * wgExpr + DF(forTileIncr) + DF(forAccumIncr) - one) + % numAccumTiles; + auto remainAccumTiles = numAccumTiles - accumTileIdxEnd - one; auto numRemainPartialResults = (remainAccumTiles + argInfo.numSKTilesPerWG - one) / argInfo.numSKTilesPerWG; @@ -1261,7 +1264,7 @@ namespace rocRoller // Add send and receive auto hasFirstAccumTile = DF(firstAccumTile) == zero; auto doesntHaveFirstAccumTile = DF(firstAccumTile) != zero; - auto doesntHaveLastAccumTile = DF(lastAccumTile) < (numAccumTiles - one); + auto doesntHaveLastAccumTile = accumTileIdxEnd < (numAccumTiles 
- one); sendInfo = sendTile(graph, doesntHaveFirstAccumTile, diff --git a/shared/rocroller/scripts/lib/rrperf/rrsuites.py b/shared/rocroller/scripts/lib/rrperf/rrsuites.py index f336749f802..9c951883396 100644 --- a/shared/rocroller/scripts/lib/rrperf/rrsuites.py +++ b/shared/rocroller/scripts/lib/rrperf/rrsuites.py @@ -599,7 +599,6 @@ def streamk_sweep(): def streamk_fp4_sweep(): for twoTile, twoTileDPFirst in [(True, False), (False, True), (False, False)]: for base in [FP4GEMM_2048x2048x396288]: - # 16x16x128 wave dimensions for mac_m in [64, 128, 256]: for mac_n in [64, 128, 256]: for mac_k in [128]: @@ -624,6 +623,7 @@ def streamk_fp4_sweep(): workgroup_size_y=2, visualize=False, prefetch=False, + streamK=True, streamKTwoTile=twoTile, streamKTwoTileDPFirst=twoTileDPFirst, types=TypeParameters( @@ -644,6 +644,7 @@ def streamk_smallMN_largeK_fp32(): prefetch=False, # TODO: Fix k loop unrolling with stream k # prefetchInFlight=2, # prefetchLDSFactor=2, + streamK=True, streamKTwoTile=twoTile, streamKTwoTileDPFirst=twoTileDPFirst, types=TypeParameters( From 3694e656d1eb4b4ce13db3d7dab23bd296456b40 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Mon, 8 Dec 2025 16:50:32 +0000 Subject: [PATCH 11/22] Fix for-loop range --- .../lib/source/KernelGraph/Transformations/AddStreamK.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index 06719a056c9..a9704ac08af 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -1226,7 +1226,10 @@ namespace rocRoller int postAccumulationCond; if(accumInfo.accumulatorTile != -1) { - auto accumTileIdxEnd - = (argInfo.numSKTilesPerWG * wgExpr + DF(forTileIncr) + DF(forAccumIncr) - one) - % numAccumTiles; - auto remainAccumTiles = numAccumTiles - accumTileIdxEnd - one;
+ auto remainAccumTiles = numAccumTiles - DF(lastAccumTile) - one; auto numRemainPartialResults = (remainAccumTiles + argInfo.numSKTilesPerWG - one) / argInfo.numSKTilesPerWG; @@ -1264,7 +1261,7 @@ namespace rocRoller // Add send and receive auto hasFirstAccumTile = DF(firstAccumTile) == zero; auto doesntHaveFirstAccumTile = DF(firstAccumTile) != zero; - auto doesntHaveLastAccumTile = accumTileIdxEnd < (numAccumTiles - one); + auto doesntHaveLastAccumTile = DF(lastAccumTile) < (numAccumTiles - one); sendInfo = sendTile(graph, doesntHaveFirstAccumTile, From fd85c81f3ceb83b1273ab114d6b8915efd16cd6e Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Mon, 8 Dec 2025 21:46:04 +0000 Subject: [PATCH 12/22] Remove streamk suite, will add that separately --- .../rocroller/scripts/lib/rrperf/rrsuites.py | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/shared/rocroller/scripts/lib/rrperf/rrsuites.py b/shared/rocroller/scripts/lib/rrperf/rrsuites.py index 4a6f331bc80..33b9d5ba03e 100644 --- a/shared/rocroller/scripts/lib/rrperf/rrsuites.py +++ b/shared/rocroller/scripts/lib/rrperf/rrsuites.py @@ -139,32 +139,6 @@ M=256, N=256, K=16384, mac_m=64, mac_n=64, mac_k=64, types=fp32 ) -FP4GEMM_7680x8448x8448 = dict( - M=7680, - N=8448, - K=8448, - types=TypeParameters( - type_A="fp4", - type_B="fp4", - type_C="float", - type_D="float", - type_acc="float", - ), -) - -FP4GEMM_2048x2048x396288 = dict( - M=2048, - N=2048, - K=396288, - types=TypeParameters( - type_A="fp4", - type_B="fp4", - type_C="float", - type_D="float", - type_acc="float", - ), -) - def update_parameters(*args, **kwargs): rv = {} @@ -618,65 +592,6 @@ def streamk_sweep(): ) -def streamk_fp4_sweep(): - for twoTile, twoTileDPFirst in [(True, False), (False, True), (False, False)]: - for base in [FP4GEMM_2048x2048x396288]: - for mac_m in [64, 128, 256]: - for mac_n in [64, 128, 256]: - for mac_k in [128]: - for wave_m in [16, 32]: - wave_n = 16 if wave_m == 16 else 32 - wave_k = 128 if wave_m == 16 else 
64 - if ( - twoTile or twoTileDPFirst - ) and mac_m * mac_n * mac_k >= (wave_m * wave_m * mac_k): - # currently these run out of VGPRs. - pass - else: - yield mkGEMM( - base, - mac_m=mac_m, - mac_n=mac_n, - mac_k=mac_k, - wave_m=wave_m, - wave_n=wave_n, - wave_k=wave_k, - workgroup_size_x=128, - workgroup_size_y=2, - visualize=False, - prefetch=False, - streamK=True, - streamKTwoTile=twoTile, - streamKTwoTileDPFirst=twoTileDPFirst, - types=TypeParameters( - base["types"], - trans_A="T", - trans_B="N", - ), - ) - - -def streamk_smallMN_largeK_fp32(): - for twoTile, twoTileDPFirst in [(True, False), (False, True), (False, False)]: - yield mkGEMM( - SGEMM_256x256x16384, - workgroup_size_x=128, - workgroup_size_y=2, - visualize=False, - prefetch=False, # TODO: Fix k loop unrolling with stream k - # prefetchInFlight=2, - # prefetchLDSFactor=2, - streamK=True, - streamKTwoTile=twoTile, - streamKTwoTileDPFirst=twoTileDPFirst, - types=TypeParameters( - SGEMM_256x256x16384["types"], - trans_A="T", - trans_B="N", - ), - ) - - def streamk(): common_overrides = dict( workgroup_size_x=128, @@ -2152,7 +2067,6 @@ def all(): yield from fp8_kernels() yield from mxfp8_kernels() yield from mx_gemms_f8f6f4() - yield from streamk_fp4_sweep() yield from sgemm() yield from hgemm() @@ -2161,7 +2075,6 @@ def all(): yield from streamk_sweep() yield from scalar_is_zero() yield from smallMN_largeK_fp32() - yield from streamk_smallMN_largeK_fp32() yield from codegen() From a2979c02af5e9ca0a4ae661da6b8aadb7157857f Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Wed, 10 Dec 2025 17:30:36 +0000 Subject: [PATCH 13/22] Update test, allocateScratch() returns size --- .../lib/include/rocRoller/Context.hpp | 22 ++++--- shared/rocroller/lib/source/Context.cpp | 21 +++--- .../lib/source/KernelGraph/Utils.cpp | 5 +- .../test/catch/ScratchOperationTest.cpp | 26 +++++--- .../rocroller/test/unit/GEMMStreamKTest.cpp | 64 ++++++++++++------- 5 files changed, 84 insertions(+), 54 deletions(-) diff --git 
a/shared/rocroller/lib/include/rocRoller/Context.hpp b/shared/rocroller/lib/include/rocRoller/Context.hpp index 6a3a5b5da4a..accacca33ac 100644 --- a/shared/rocroller/lib/include/rocRoller/Context.hpp +++ b/shared/rocroller/lib/include/rocRoller/Context.hpp @@ -28,7 +28,6 @@ #include #include -#include #include #include #include @@ -128,6 +127,16 @@ namespace rocRoller void setKernel(AssemblyKernelPtr); + /** + * @brief Allocate scratch space for the specified scratch policy. + * + * @param policy The scratch policy to allocate for + * @param size Number of bytes requested + * @return Expression::ExpressionPtr The offset before this allocation + */ + Expression::ExpressionPtr allocateScratch(Operations::ScratchPolicy policy, + Expression::ExpressionPtr size); + /** * @brief Returns an expression representing how much scratch space is required (in bytes) * for the specified scratch policy. @@ -137,14 +146,6 @@ namespace rocRoller */ Expression::ExpressionPtr getScratchAmount(Operations::ScratchPolicy policy) const; - /** - * @brief Allocate more scratch space for the specified scratch policy. - * - * @param policy The scratch policy to allocate for - * @param size Number of bytes requested - */ - void allocateScratch(Operations::ScratchPolicy policy, Expression::ExpressionPtr size); - /** * @brief Get register scope manager. 
*/ @@ -180,7 +181,8 @@ namespace rocRoller std::shared_ptr m_mem; LabelAllocatorPtr m_labelAllocator; std::shared_ptr m_ldsAllocator; - std::map m_scratchAllocators; + std::array(Operations::ScratchPolicy::Count)> + m_scratchSizes; std::shared_ptr m_copier; std::shared_ptr m_brancher; std::shared_ptr m_crasher; diff --git a/shared/rocroller/lib/source/Context.cpp b/shared/rocroller/lib/source/Context.cpp index d4f30fff622..c29d3139b2f 100644 --- a/shared/rocroller/lib/source/Context.cpp +++ b/shared/rocroller/lib/source/Context.cpp @@ -48,11 +48,10 @@ namespace rocRoller { Context::Context() { - // Initialize scratch allocators for each policy with zero - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + // Initialize scratch sizes for each policy with zero + for(size_t i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) { - m_scratchAllocators[static_cast(i)] - = Expression::literal(0u); + m_scratchSizes[i] = Expression::literal(0u); } } @@ -292,16 +291,18 @@ namespace rocRoller m_kernel = assemblyKernel; } - Expression::ExpressionPtr Context::getScratchAmount(Operations::ScratchPolicy policy) const + Expression::ExpressionPtr Context::allocateScratch(Operations::ScratchPolicy policy, + Expression::ExpressionPtr size) { - auto it = m_scratchAllocators.find(policy); - AssertFatal(it != m_scratchAllocators.end(), "Scratch policy not found", ShowValue(policy)); - return it->second; + auto idx = static_cast(policy); + auto currentOffset = m_scratchSizes[idx]; + m_scratchSizes[idx] = simplify(m_scratchSizes[idx] + size); + return currentOffset; } - void Context::allocateScratch(Operations::ScratchPolicy policy, Expression::ExpressionPtr size) + Expression::ExpressionPtr Context::getScratchAmount(Operations::ScratchPolicy policy) const { - m_scratchAllocators[policy] = simplify(m_scratchAllocators[policy] + size); + return m_scratchSizes[static_cast(policy)]; } void Context::scheduleCopy(Instruction const& inst) diff --git 
a/shared/rocroller/lib/source/KernelGraph/Utils.cpp b/shared/rocroller/lib/source/KernelGraph/Utils.cpp index aca74b388ba..73809aaa6c7 100644 --- a/shared/rocroller/lib/source/KernelGraph/Utils.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Utils.cpp @@ -824,13 +824,12 @@ namespace rocRoller Operations::ScratchPolicy policy, ContextPtr context) { - auto currentOffset = context->getScratchAmount(policy); - auto newCoordinate = CT::User(size, currentOffset, getScratchName(policy)); // TODO Audit bytes/bits // Can we move size inside the CeilDivide? - context->allocateScratch( + auto currentOffset = context->allocateScratch( policy, size * Expression::literal(CeilDivide(DataTypeInfo::Get(varType).elementBits, 8u))); + auto newCoordinate = CT::User(size, currentOffset, getScratchName(policy)); return newCoordinate; } diff --git a/shared/rocroller/test/catch/ScratchOperationTest.cpp b/shared/rocroller/test/catch/ScratchOperationTest.cpp index b23ec3aec07..a793989dd58 100644 --- a/shared/rocroller/test/catch/ScratchOperationTest.cpp +++ b/shared/rocroller/test/catch/ScratchOperationTest.cpp @@ -239,7 +239,7 @@ namespace ScratchOperationTest SECTION("Context initializes all policies to zero") { - for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) { auto policy = static_cast(i); auto amount = context->getScratchAmount(policy); @@ -248,24 +248,34 @@ namespace ScratchOperationTest } } - SECTION("allocateScratch accumulates size per policy") + SECTION("allocateScratch returns offset before allocation and accumulates size") { auto size1 = rocRoller::Expression::literal(100u); auto size2 = rocRoller::Expression::literal(200u); - context->allocateScratch(ScratchPolicy::None, size1); - context->allocateScratch(ScratchPolicy::None, size2); + auto offset1 = context->allocateScratch(ScratchPolicy::None, size1); + auto offset2 = context->allocateScratch(ScratchPolicy::None, size2); + // First allocation should 
return offset 0 + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset1)) == 0); + // Second allocation should return offset 100 (after first allocation) + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset2)) == 100); + + // Total should now be 300 auto amount = context->getScratchAmount(ScratchPolicy::None); auto value = rocRoller::Expression::evaluate(amount); CHECK(rocRoller::getUnsignedInt(value) == 300); } - SECTION("getScratchAmount returns correct amount per policy") + SECTION("allocateScratch returns correct offset per policy") { - auto size = rocRoller::Expression::literal(500u); - context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, size); + auto size = rocRoller::Expression::literal(500u); + auto offset = context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, size); + + // First allocation should return offset 0 + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset)) == 0); + // Query total should now be 500 auto amount = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); auto value = rocRoller::Expression::evaluate(amount); CHECK(rocRoller::getUnsignedInt(value) == 500); @@ -307,7 +317,7 @@ namespace ScratchOperationTest SECTION("All policies have unique names") { std::set names; - for(int i = 0; i < static_cast(ScratchPolicy::Count); ++i) + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) { auto policy = static_cast(i); auto name = rocRoller::getScratchName(policy); diff --git a/shared/rocroller/test/unit/GEMMStreamKTest.cpp b/shared/rocroller/test/unit/GEMMStreamKTest.cpp index 2eb5d92f8d6..67a6f4ab50e 100644 --- a/shared/rocroller/test/unit/GEMMStreamKTest.cpp +++ b/shared/rocroller/test/unit/GEMMStreamKTest.cpp @@ -36,8 +36,12 @@ namespace GEMMTests using namespace rocRoller; namespace SolutionParams = rocRoller::Parameters::Solution; + // ProblemConfig: (dataTypeAB, macM, macN, macK, m, n, k, numWGs) + using ProblemConfig = std::tuple; + class 
StreamKMultipleFixupsTestGPU - : public BaseGEMMContextFixture> @@ -67,46 +71,56 @@ namespace GEMMTests { }; - TEST_P(StreamKMultipleFixupsTestGPU, GPU_BasicGEMMFP16) + TEST_P(StreamKMultipleFixupsTestGPU, GPU_BasicGEMM) { if(m_context->targetArchitecture().target().isCDNA1GPU()) { - GTEST_SKIP() << "Skipping GPU_BasicGEMMFP16 test: CDNA1 not supported"; + GTEST_SKIP() << "Skipping GPU_BasicGEMM test: CDNA1 not supported"; } - GEMMProblem gemm; + auto [problemConfig, mode, loadPathA, loadPathB, storeLDSD] + = std::get<1>(GetParam()); + auto [dataTypeAB, macM, macN, macK, m, n, k, numWGs] = problemConfig; - hipDeviceProp_t deviceProperties; - ASSERT_THAT(hipGetDeviceProperties(&deviceProperties, 0), HasHipSuccess(0)); + GEMMProblem gemm; - gemm.macM = 128; - gemm.macN = 128; - gemm.macK = 16; + gemm.macM = macM; + gemm.macN = macN; + gemm.macK = macK; + gemm.m = m; + gemm.n = n; + gemm.k = k; + gemm.numWGs = numWGs; - gemm.waveK = 8; + if(dataTypeAB == DataType::Half) + { + gemm.waveK = 8; + } gemm.workgroupSizeX = 128; gemm.workgroupSizeY = 2; - gemm.numWGs = 128; - - auto numTilesM = 1; - auto numTilesN = 2; - auto numTilesK = 249; - - gemm.m = numTilesM * gemm.macM; - gemm.n = numTilesN * gemm.macN; - gemm.k = numTilesK * gemm.macK; - // assert that the number of output tiles is smaller than number of WGs // which means there is not enough data-parallel tiles, and has to split // K dimension into multiple tiles ASSERT_GE(gemm.numWGs, gemm.m * gemm.n / gemm.macM / gemm.macN); - std::tie(gemm.streamK, gemm.loadPathA, gemm.loadPathB, gemm.storeLDSD) - = std::get<1>(GetParam()); + gemm.streamK = mode; + gemm.loadPathA = loadPathA; + gemm.loadPathB = loadPathB; + gemm.storeLDSD = storeLDSD; - basicGEMM(gemm); + switch(dataTypeAB) + { + case DataType::Half: + basicGEMM(gemm); + break; + case DataType::Float: + basicGEMM(gemm); + break; + default: + Throw(fmt::format("Unexpected data type: {}. 
", toString(dataTypeAB))); + } } TEST_P(StreamKWGMTestGPU, GPU_BasicGEMMStreamKWorkgroupMapping) @@ -220,6 +234,10 @@ namespace GEMMTests ::testing::Combine( currentGPUISA(), ::testing::Combine( + ::testing::Values( + // ProblemConfig: (dataTypeAB, macM, macN, macK, m, n, k, numWGs) + ProblemConfig{rocRoller::DataType::Half, 128, 128, 16, 128, 256, 15936, 128}, + ProblemConfig{rocRoller::DataType::Float, 64, 64, 64, 256, 256, 16384, 256}), /* problemConfig */ ::testing::Values( StreamKMode::Standard, StreamKMode::TwoTile, StreamKMode::TwoTileDPFirst), ::testing::Values(SolutionParams::LoadPath::BufferToLDSViaVGPR, From c43dd23867213ffa1cff07595974a3f2776f9a94 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Wed, 10 Dec 2025 17:32:24 +0000 Subject: [PATCH 14/22] Formatting --- .../lib/include/rocRoller/Context.hpp | 26 +++++++++---------- shared/rocroller/lib/source/Context.cpp | 6 ++--- .../rocroller/test/unit/GEMMStreamKTest.cpp | 14 +++++++--- 3 files changed, 26 insertions(+), 20 deletions(-) diff --git a/shared/rocroller/lib/include/rocRoller/Context.hpp b/shared/rocroller/lib/include/rocRoller/Context.hpp index accacca33ac..39eb1781312 100644 --- a/shared/rocroller/lib/include/rocRoller/Context.hpp +++ b/shared/rocroller/lib/include/rocRoller/Context.hpp @@ -174,20 +174,20 @@ namespace rocRoller std::array, static_cast(Register::Type::Count)> m_allocators; - std::shared_ptr m_observer; - AssemblyKernelPtr m_kernel; - std::shared_ptr m_argLoader; - std::shared_ptr m_instructions; - std::shared_ptr m_mem; - LabelAllocatorPtr m_labelAllocator; - std::shared_ptr m_ldsAllocator; + std::shared_ptr m_observer; + AssemblyKernelPtr m_kernel; + std::shared_ptr m_argLoader; + std::shared_ptr m_instructions; + std::shared_ptr m_mem; + LabelAllocatorPtr m_labelAllocator; + std::shared_ptr m_ldsAllocator; std::array(Operations::ScratchPolicy::Count)> - m_scratchSizes; - std::shared_ptr m_copier; - std::shared_ptr m_brancher; - std::shared_ptr m_crasher; - std::shared_ptr 
m_random; - std::shared_ptr m_scope; + m_scratchSizes; + std::shared_ptr m_copier; + std::shared_ptr m_brancher; + std::shared_ptr m_crasher; + std::shared_ptr m_random; + std::shared_ptr m_scope; std::string m_assemblyFileName; KernelOptions m_kernelOptions; diff --git a/shared/rocroller/lib/source/Context.cpp b/shared/rocroller/lib/source/Context.cpp index c29d3139b2f..0356cfdbf4c 100644 --- a/shared/rocroller/lib/source/Context.cpp +++ b/shared/rocroller/lib/source/Context.cpp @@ -292,10 +292,10 @@ namespace rocRoller } Expression::ExpressionPtr Context::allocateScratch(Operations::ScratchPolicy policy, - Expression::ExpressionPtr size) + Expression::ExpressionPtr size) { - auto idx = static_cast(policy); - auto currentOffset = m_scratchSizes[idx]; + auto idx = static_cast(policy); + auto currentOffset = m_scratchSizes[idx]; m_scratchSizes[idx] = simplify(m_scratchSizes[idx] + size); return currentOffset; } diff --git a/shared/rocroller/test/unit/GEMMStreamKTest.cpp b/shared/rocroller/test/unit/GEMMStreamKTest.cpp index 67a6f4ab50e..add998e8683 100644 --- a/shared/rocroller/test/unit/GEMMStreamKTest.cpp +++ b/shared/rocroller/test/unit/GEMMStreamKTest.cpp @@ -78,9 +78,8 @@ namespace GEMMTests GTEST_SKIP() << "Skipping GPU_BasicGEMM test: CDNA1 not supported"; } - auto [problemConfig, mode, loadPathA, loadPathB, storeLDSD] - = std::get<1>(GetParam()); - auto [dataTypeAB, macM, macN, macK, m, n, k, numWGs] = problemConfig; + auto [problemConfig, mode, loadPathA, loadPathB, storeLDSD] = std::get<1>(GetParam()); + auto [dataTypeAB, macM, macN, macK, m, n, k, numWGs] = problemConfig; GEMMProblem gemm; @@ -237,7 +236,14 @@ namespace GEMMTests ::testing::Values( // ProblemConfig: (dataTypeAB, macM, macN, macK, m, n, k, numWGs) ProblemConfig{rocRoller::DataType::Half, 128, 128, 16, 128, 256, 15936, 128}, - ProblemConfig{rocRoller::DataType::Float, 64, 64, 64, 256, 256, 16384, 256}), /* problemConfig */ + ProblemConfig{rocRoller::DataType::Float, + 64, + 64, + 64, + 256, 
+ 256, + 16384, + 256}), /* problemConfig */ ::testing::Values( StreamKMode::Standard, StreamKMode::TwoTile, StreamKMode::TwoTileDPFirst), ::testing::Values(SolutionParams::LoadPath::BufferToLDSViaVGPR, From 6ea17d9affbaa5ded7624979e5909c20e8b490f3 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 04:52:29 +0000 Subject: [PATCH 15/22] Only wave0_workitem0 writes to flag --- .../Transformations/AddStreamK.cpp | 54 ++++++++++++++++--- 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index a9704ac08af..d003a58bea1 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -492,24 +492,42 @@ namespace rocRoller // TODO: Improve setting of arch-specific buffer options BufferInstructionOptions bufOpts{.glc = true}; + // if(!(context->targetArchitecture().target().isCDNA1GPU() + // || context->targetArchitecture().target().isCDNA2GPU())) + // { + // bufOpts.sc1 = true; + // } auto storeFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); graph.mapper.connect(storeFlagTag, flagsScratchTag); graph.mapper.connect(storeFlagTag, flagRegister); + // Create workitem coordinate and expression for wave 0 check + // Only workitem 0 (wave 0) should write to the flag + auto workitemTag = graph.coordinates.addElement(Workitem(0)); + auto workitemDF = std::make_shared( + Expression::DataFlowTag{workitemTag, Register::Type::Vector, DataType::UInt32}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0FlagStoreTag = graph.control.addElement( + ConditionalOp{isWave0Expr, "Wave0 Store Flag"}); + // Add to control auto preWaitZeroTag = graph.control.addElement(WaitZero()); auto postWaitZeroTag = graph.control.addElement(WaitZero()); graph.control.addElement(Sequence(), 
{preWaitZeroTag}, {sendTileTag}); graph.control.addElement(Body(), {sendTileTag}, {forX}); - graph.control.chain( - forX, waitZeroTag, barrierTag, assignFlagTag, storeFlagTag, postWaitZeroTag); + graph.control.chain(forX, waitZeroTag, barrierTag, wave0FlagStoreTag); + graph.control.addElement(Body(), {wave0FlagStoreTag}, {assignFlagTag}); + auto waitAfterStoreFlagTag = graph.control.addElement(WaitZero()); + graph.control.chain(assignFlagTag, storeFlagTag, waitAfterStoreFlagTag); + graph.control.chain(assignFlagTag, storeFlagTag); + graph.control.addElement(Sequence(), {wave0FlagStoreTag}, {postWaitZeroTag}); return {preWaitZeroTag, sendTileTag}; } /** - * Create send-tile block, which is roughly: + * Create receive-tile block, which is roughly: * * WaitZero() * if receiveTileExpr: @@ -519,6 +537,8 @@ namespace rocRoller * do: * LoadSGPR(flag[nextWG]) * while flag[nextWG] == 0 + * Barrier() + * Assign(flag[nextWG] = 0) * partiallyAccumulatedTile = LoadTiled() * fullyAccumulatedTile = Assign(localPartiallyAccumulatedTile) * fullyAccumulatedTile = Assign(fullyAccumulatedTile + partiallyAccumulatedTile) @@ -570,6 +590,12 @@ namespace rocRoller // TODO: Improve setting of arch-specific buffer options BufferInstructionOptions bufOpts{.glc = true}; + // if(!(context->targetArchitecture().target().isCDNA1GPU() + // || context->targetArchitecture().target().isCDNA2GPU())) + // { + // bufOpts.sc1 = true; + // } + auto flagRegister = graph.coordinates.addElement(VGPR()); auto loadFlagTag = graph.control.addElement(LoadSGPR(DataType::UInt32, bufOpts)); @@ -630,6 +656,15 @@ namespace rocRoller graph.mapper.connect(resetFlagTag, resetFlagsScratchTag); graph.mapper.connect(resetFlagTag, flagRegister); + // Create workitem coordinate and expression for wave 0 check + // Only workitem 0 (wave 0) should write to the flag + auto workitemTag = graph.coordinates.addElement(Workitem(0)); + auto workitemDF = std::make_shared( + Expression::DataFlowTag{workitemTag, 
Register::Type::Vector, DataType::UInt32}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0ResetFlagTag = graph.control.addElement( + ConditionalOp{isWave0Expr, "Wave0 Reset Flag"}); + auto barrierBeforeResetTag = graph.control.addElement(Barrier()); auto accumulatorTile = graph.coordinates.get(accumulatorTileTag); @@ -690,12 +725,17 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); + auto waitBeforeResetTag = graph.control.addElement(WaitZero()); + auto waitAfterResetTag = graph.control.addElement(WaitZero()); graph.control.chain(doWhileTag, + waitBeforeResetTag, barrierBeforeResetTag, - assignResetFlagTag, - resetFlagTag, - loadAddForX, - postWaitZeroTag); + wave0ResetFlagTag); + graph.control.addElement(Body(), {wave0ResetFlagTag}, {assignResetFlagTag}); + auto waitAfterRestFlagStoreTag = graph.control.addElement(WaitZero()); + graph.control.chain(assignResetFlagTag, resetFlagTag, waitAfterRestFlagStoreTag); + graph.control.chain( + wave0ResetFlagTag, waitAfterResetTag, loadAddForX, postWaitZeroTag); return {preWaitZeroTag, receiveTileTag, setPlusOneTag}; } From c19d7f0fc117dcec9cedd5dd436ba802ac83140e Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 06:02:53 +0000 Subject: [PATCH 16/22] Print info when flag is not rest to 0, remove unnecessary WaitZero in AddStreamK --- .../Transformations/AddStreamK.cpp | 44 ++++++++----------- .../rocroller/test/unit/GEMMStreamKTest.cpp | 4 +- shared/rocroller/test/unit/GEMMTestBase.hpp | 19 ++++++-- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index d003a58bea1..197c999a691 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ 
b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -428,9 +428,10 @@ namespace rocRoller * StoreTile() * WaitZero() * Barrier() - * flag = Assign(SGPR, 1u); - * StoreSGPR(flag) - * WaitZero() + * if wave0: + * flag = Assign(SGPR, 1u); + * StoreSGPR(flag) + * WaitZero() */ SendInfo sendTile(KernelGraph& graph, ExpressionPtr sendTileExpr, @@ -506,13 +507,12 @@ namespace rocRoller auto workitemTag = graph.coordinates.addElement(Workitem(0)); auto workitemDF = std::make_shared( Expression::DataFlowTag{workitemTag, Register::Type::Vector, DataType::UInt32}); - auto isWave0Expr = (workitemDF == Expression::literal(0u)); - auto wave0FlagStoreTag = graph.control.addElement( - ConditionalOp{isWave0Expr, "Wave0 Store Flag"}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0FlagStoreTag + = graph.control.addElement(ConditionalOp{isWave0Expr, "Wave0 Store Flag"}); // Add to control - auto preWaitZeroTag = graph.control.addElement(WaitZero()); - auto postWaitZeroTag = graph.control.addElement(WaitZero()); + auto preWaitZeroTag = graph.control.addElement(WaitZero()); graph.control.addElement(Sequence(), {preWaitZeroTag}, {sendTileTag}); graph.control.addElement(Body(), {sendTileTag}, {forX}); @@ -520,8 +520,6 @@ namespace rocRoller graph.control.addElement(Body(), {wave0FlagStoreTag}, {assignFlagTag}); auto waitAfterStoreFlagTag = graph.control.addElement(WaitZero()); graph.control.chain(assignFlagTag, storeFlagTag, waitAfterStoreFlagTag); - graph.control.chain(assignFlagTag, storeFlagTag); - graph.control.addElement(Sequence(), {wave0FlagStoreTag}, {postWaitZeroTag}); return {preWaitZeroTag, sendTileTag}; } @@ -538,7 +536,10 @@ namespace rocRoller * LoadSGPR(flag[nextWG]) * while flag[nextWG] == 0 * Barrier() - * Assign(flag[nextWG] = 0) + * if wave0: + * Assign(flag[nextWG] = 0) + * StoreSGPR(flag[nextWG]) + * WaitZero() * partiallyAccumulatedTile = LoadTiled() * fullyAccumulatedTile = Assign(localPartiallyAccumulatedTile) * 
fullyAccumulatedTile = Assign(fullyAccumulatedTile + partiallyAccumulatedTile) @@ -590,12 +591,6 @@ namespace rocRoller // TODO: Improve setting of arch-specific buffer options BufferInstructionOptions bufOpts{.glc = true}; - // if(!(context->targetArchitecture().target().isCDNA1GPU() - // || context->targetArchitecture().target().isCDNA2GPU())) - // { - // bufOpts.sc1 = true; - // } - auto flagRegister = graph.coordinates.addElement(VGPR()); auto loadFlagTag = graph.control.addElement(LoadSGPR(DataType::UInt32, bufOpts)); @@ -661,9 +656,9 @@ namespace rocRoller auto workitemTag = graph.coordinates.addElement(Workitem(0)); auto workitemDF = std::make_shared( Expression::DataFlowTag{workitemTag, Register::Type::Vector, DataType::UInt32}); - auto isWave0Expr = (workitemDF == Expression::literal(0u)); - auto wave0ResetFlagTag = graph.control.addElement( - ConditionalOp{isWave0Expr, "Wave0 Reset Flag"}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0ResetFlagTag + = graph.control.addElement(ConditionalOp{isWave0Expr, "Wave0 Reset Flag"}); auto barrierBeforeResetTag = graph.control.addElement(Barrier()); @@ -725,17 +720,16 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - auto waitBeforeResetTag = graph.control.addElement(WaitZero()); - auto waitAfterResetTag = graph.control.addElement(WaitZero()); + // auto waitBeforeResetTag = graph.control.addElement(WaitZero()); graph.control.chain(doWhileTag, - waitBeforeResetTag, + // waitBeforeResetTag, barrierBeforeResetTag, wave0ResetFlagTag); graph.control.addElement(Body(), {wave0ResetFlagTag}, {assignResetFlagTag}); auto waitAfterRestFlagStoreTag = graph.control.addElement(WaitZero()); - graph.control.chain(assignResetFlagTag, resetFlagTag, waitAfterRestFlagStoreTag); graph.control.chain( - wave0ResetFlagTag, waitAfterResetTag, loadAddForX, postWaitZeroTag); + assignResetFlagTag, 
resetFlagTag, waitAfterRestFlagStoreTag); + graph.control.chain(wave0ResetFlagTag, loadAddForX, postWaitZeroTag); return {preWaitZeroTag, receiveTileTag, setPlusOneTag}; } diff --git a/shared/rocroller/test/unit/GEMMStreamKTest.cpp b/shared/rocroller/test/unit/GEMMStreamKTest.cpp index add998e8683..b6bf3c9b911 100644 --- a/shared/rocroller/test/unit/GEMMStreamKTest.cpp +++ b/shared/rocroller/test/unit/GEMMStreamKTest.cpp @@ -112,10 +112,10 @@ namespace GEMMTests switch(dataTypeAB) { case DataType::Half: - basicGEMM(gemm); + basicGEMM(gemm, false, false, 100); break; case DataType::Float: - basicGEMM(gemm); + basicGEMM(gemm, false, false, 100); break; default: Throw(fmt::format("Unexpected data type: {}. ", toString(dataTypeAB))); diff --git a/shared/rocroller/test/unit/GEMMTestBase.hpp b/shared/rocroller/test/unit/GEMMTestBase.hpp index 5f605cd146b..6c64d34c347 100644 --- a/shared/rocroller/test/unit/GEMMTestBase.hpp +++ b/shared/rocroller/test/unit/GEMMTestBase.hpp @@ -838,10 +838,23 @@ namespace GEMMTests scratchSpaceRequired[zeroedIdx], hipMemcpyDeviceToHost), HasHipSuccess(0)); - EXPECT_TRUE(std::all_of( - zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; })) + + bool allZeros = true; + for(size_t i = 0; i < zeroedResult.size(); ++i) + { + if(zeroedResult[i] != 0) + { + allZeros = false; + // Print as uint32 since flags are UInt32 + size_t flagIndex = i / sizeof(uint32_t); + std::cerr << "Non-zero at byte " << i << " (flag index " << flagIndex + << "): " << static_cast(zeroedResult[i]) << std::endl; + } + } + EXPECT_TRUE(allZeros) << "ZeroedBeforeAndAfter scratch should be all zeros after kernel " - "execution"; + "execution (size=" + << scratchSpaceRequired[zeroedIdx] << " bytes)"; } if(debuggable && !res.ok) From 3b9ec64cdac091968e940630cb6044f2f3eb1b7a Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 11:10:05 -0600 Subject: [PATCH 17/22] Enable rocroller streamk in hipblaslt --- .../rocblaslt/src/rocroller/gemm.cpp | 74 
+++++++++++++++---- .../rocblaslt/src/rocroller/include/gemm.hpp | 5 +- .../rocroller/include/solution_selection.hpp | 1 + .../src/rocroller/parameter_selection.cpp | 17 +++-- .../src/rocroller/solution_selection.cpp | 13 ++++ 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp index 324b3c7e8e1..923aae474d9 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp @@ -170,7 +170,8 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge ShowValue(gemm->kernelType.scaleBMode)); std::optional tagTensorScaleA, tagLoadScaleA, tagBlockScaleA, - tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM; + tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagSKGrid, tagWGM; + std::map tagScratch; if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate) { @@ -265,12 +266,19 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge DataDirection::ReadOnly, rocRoller::NUMWGS); - tagScratch = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - *tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); + // Create Scratch operations for each ScratchPolicy + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto tag = command->allocateTag(); + command->addOperation(Operations::Scratch(tag, policy)); + command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + tag, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::getScratchName(policy)); + tagScratch[policy] = tag; + } } if(gemm->workgroupMappingDim != -1) @@ -489,8 +497,7 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge if(tagTensorScaleB) 
gemmKernel->tagTensorScaleB = *tagTensorScaleB; - if(tagScratch) - gemmKernel->tagScratch = *tagScratch; + gemmKernel->tagScratch = tagScratch; if(tagSKGrid) gemmKernel->tagSKGrid = *tagSKGrid; @@ -523,7 +530,14 @@ size_t workspaceRequired(std::shared_ptr gemm, const RocblasltContra auto runtimeArgs = commandArgs.runtimeArguments(); - return gemm->commandKernel->scratchSpaceRequired(runtimeArgs); + // Sum scratch requirements for all policies + size_t total = 0; + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + total += gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs); + } + return total; } CommandArguments createCommandArguments(std::shared_ptr gemm, @@ -622,23 +636,53 @@ rocblaslt_status runGemmKernel(std::shared_ptr gemm, } return rocblaslt_status_invalid_value; } - auto commandArgs = createCommandArguments(gemm, prob, DEFAULT_WGM); - // Add scratch space - if(workSpaceRequired > 0) + // Track allocated scratch memory for each policy + std::array(Operations::ScratchPolicy::Count)> scratchPtrs = {}; + std::array(Operations::ScratchPolicy::Count)> scratchSizes = {}; + + if(gemm->params->streamK) { - commandArgs.setArgument( - gemm->tagScratch, ArgumentType::Value, static_cast(prob.workspace)); + auto runtimeArgs = commandArgs.runtimeArguments(); + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + scratchSizes[i] = gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs); + if(scratchSizes[i] > 0) + { + commandArgs.setArgument( + gemm->tagScratch.at(policy), ArgumentType::Value, + static_cast(scratchPtrs[i])); + } + } } auto runtimeArgs = commandArgs.runtimeArguments(); if(!gemm->commandKernel->matchesPredicates(runtimeArgs, LogLevel::Error)) { + // Free allocated scratch memory before returning + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + if(scratchPtrs[i] != nullptr) + { + 
AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i)); + } + } return rocblaslt_status_invalid_value; } gemm->commandKernel->launchKernel(runtimeArgs, prob.stream); + + // Free allocated scratch memory after kernel completes + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + if(scratchPtrs[i] != nullptr) + { + AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i)); + } + } + return rocblaslt_status_success; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp index 6f0cd7dd440..72308103f66 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp @@ -32,8 +32,11 @@ #include #include #include +#include #include +#include + /** * @brief GemmKernel * @@ -57,7 +60,7 @@ struct GemmKernel rocRoller::Operations::OperationTag tagTensorScaleA; rocRoller::Operations::OperationTag tagTensorScaleB; - rocRoller::Operations::OperationTag tagScratch; + std::map tagScratch; rocRoller::Operations::OperationTag tagSKGrid; rocRoller::Operations::OperationTag tagWGM; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp index 3fd0babe898..11e58a2cdbf 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp @@ -69,6 +69,7 @@ struct SolutionIndexParameters { WorkGroupTileSize workgroupTile; bool workgroupMapping; + bool streamK; }; int parametersToIndex(const SolutionIndexParameters& params); diff --git 
a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index 284060c2f6a..e784677a06c 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -196,13 +196,16 @@ std::shared_ptr gemm->workgroupRemapXCC = true; } - // TODO: StreamK is not currently working with prefetching or workgroup mapping - if(gemm->streamK) - { - gemm->prefetch = false; - gemm->workgroupMappingDim = -1; - gemm->workgroupRemapXCC = false; - } + // Pass StreamK flag from solution index parameters + gemm->streamK = solutionIndexParameters.streamK; + + // // StreamK is not currently working with prefetching or workgroup mapping + // if(gemm->streamK) + // { + // gemm->prefetch = false; + // gemm->workgroupMappingDim = -1; + // gemm->workgroupRemapXCC = false; + // } return gemm; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index 2c9549e7bd0..f8c8790299d 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -243,6 +243,15 @@ std::vector chooseSolutionIndexParameters( { params.back().workgroupMapping = false; } + + // Enable StreamK when number of output tiles < number of CUs + size_t numTilesM = prob.m / wgt.m; + size_t numTilesN = prob.n / wgt.n; + size_t numTiles = numTilesM * numTilesN * prob.batch_count; + if(numTiles < analytical_hardware.N_CU) + { + params.back().streamK = true; + } } } @@ -261,6 +270,8 @@ int parametersToIndex(const SolutionIndexParameters& params) result |= ((params.workgroupTile.m / REQUIRED_MULTIPLE_M_N) << pos); 
pos += MAX_BITS_WORKGROUPTILE_M; result |= ((params.workgroupMapping ? 1 : 0) << pos); + pos += 1; + result |= ((params.streamK ? 1 : 0) << pos); // Set top bit indicating it is a rocRoller index result |= (1 << 31); @@ -290,6 +301,8 @@ SolutionIndexParameters indexToParameters(int index) = ((index >> pos) & mask(MAX_BITS_WORKGROUPTILE_M)) * REQUIRED_MULTIPLE_M_N; pos += MAX_BITS_WORKGROUPTILE_M; result.workgroupMapping = (index >> pos) & 1; + pos += 1; + result.streamK = (index >> pos) & 1; return result; } From 96262cec254c787a307c5bfa0dfcccb7fb6e7f54 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 11:23:51 -0600 Subject: [PATCH 18/22] Update SolutionIndexParameters --- .../amd_detail/rocblaslt/src/rocroller/solution_selection.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index f8c8790299d..d0447361728 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -237,7 +237,7 @@ std::vector chooseSolutionIndexParameters( || !std::has_single_bit(static_cast(wgt.n)))) continue; - params.push_back({wgt, true}); + params.push_back({wgt, true, false}); if (prob.k < USE_WORKGROUP_MAPPING_K_SIZE) { From fa87f64db32ddd0c3b9fa69057e16ed2965999d5 Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 15:21:39 -0600 Subject: [PATCH 19/22] Use prob.Synchronizer for flags --- .../rocblaslt/src/rocroller/gemm.cpp | 64 ++++++++----------- .../src/rocroller/parameter_selection.cpp | 5 +- .../Transformations/AddStreamK.cpp | 8 +-- 3 files changed, 31 insertions(+), 46 deletions(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp 
b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp index 923aae474d9..7ea346fa412 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp @@ -118,6 +118,10 @@ std::string genKernelName(std::shared_ptr gemm) if(gemm->kernelType.scaleBMode != Operations::ScaleMode::None) rv << "SB_" << genScaleModeString(gemm->kernelType.scaleBMode) << "_"; + if(gemm->streamK) + { + rv << "SK_"; + } rv << "WGT_"; rocRoller::streamJoin( rv, std::vector{gemm->workgroupTile.m, gemm->workgroupTile.n, gemm->workgroupTile.k}, "x"); @@ -530,14 +534,9 @@ size_t workspaceRequired(std::shared_ptr gemm, const RocblasltContra auto runtimeArgs = commandArgs.runtimeArguments(); - // Sum scratch requirements for all policies - size_t total = 0; - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - auto policy = static_cast(i); - total += gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs); - } - return total; + // Only return scratch space for ScratchPolicy::None (uses prob.workspace) + return gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::None, runtimeArgs); } CommandArguments createCommandArguments(std::shared_ptr gemm, @@ -638,23 +637,31 @@ rocblaslt_status runGemmKernel(std::shared_ptr gemm, } auto commandArgs = createCommandArguments(gemm, prob, DEFAULT_WGM); - // Track allocated scratch memory for each policy - std::array(Operations::ScratchPolicy::Count)> scratchPtrs = {}; - std::array(Operations::ScratchPolicy::Count)> scratchSizes = {}; if(gemm->params->streamK) { auto runtimeArgs = commandArgs.runtimeArguments(); - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + + // Use prob.workspace for ScratchPolicy::None + auto noneScratchSize = gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::None, runtimeArgs); + if(noneScratchSize > 0 && prob.workspace 
!= nullptr) { - auto policy = static_cast(i); - scratchSizes[i] = gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs); - if(scratchSizes[i] > 0) - { - commandArgs.setArgument( - gemm->tagScratch.at(policy), ArgumentType::Value, - static_cast(scratchPtrs[i])); - } + commandArgs.setArgument( + gemm->tagScratch.at(Operations::ScratchPolicy::None), + ArgumentType::Value, + static_cast(prob.workspace)); + } + + // Use prob.Synchronizer for ScratchPolicy::ZeroedBeforeAndAfter + auto zeroedScratchSize = gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs); + if(zeroedScratchSize > 0 && prob.Synchronizer != nullptr) + { + commandArgs.setArgument( + gemm->tagScratch.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + ArgumentType::Value, + static_cast(prob.Synchronizer)); } } @@ -662,27 +669,10 @@ rocblaslt_status runGemmKernel(std::shared_ptr gemm, if(!gemm->commandKernel->matchesPredicates(runtimeArgs, LogLevel::Error)) { - // Free allocated scratch memory before returning - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - if(scratchPtrs[i] != nullptr) - { - AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i)); - } - } return rocblaslt_status_invalid_value; } gemm->commandKernel->launchKernel(runtimeArgs, prob.stream); - // Free allocated scratch memory after kernel completes - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) - { - if(scratchPtrs[i] != nullptr) - { - AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i)); - } - } - return rocblaslt_status_success; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index e784677a06c..e4fbf4f580e 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ 
b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -37,6 +37,7 @@ std::string SolutionParameters::toString() const result << "MachineInstruction:" << machineInstruction.m << "x" << machineInstruction.n << "x" << machineInstruction.k << std::endl; result << "WorkgroupSize:" << workgroupSizeX << "x" << workgroupSizeY << std::endl; + result << "StreamK: " << streamK << std::endl; result << "LoadA: " << loadPathA << std::endl; result << "LoadB: " << loadPathB << std::endl; result << "LDS Usage"; @@ -197,9 +198,9 @@ std::shared_ptr } // Pass StreamK flag from solution index parameters - gemm->streamK = solutionIndexParameters.streamK; + gemm->streamK = false; - // // StreamK is not currently working with prefetching or workgroup mapping + // StreamK is not currently working with prefetching or workgroup mapping // if(gemm->streamK) // { // gemm->prefetch = false; diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index 197c999a691..d8b7f74716b 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -493,11 +493,7 @@ namespace rocRoller // TODO: Improve setting of arch-specific buffer options BufferInstructionOptions bufOpts{.glc = true}; - // if(!(context->targetArchitecture().target().isCDNA1GPU() - // || context->targetArchitecture().target().isCDNA2GPU())) - // { - // bufOpts.sc1 = true; - // } + auto storeFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); graph.mapper.connect(storeFlagTag, flagsScratchTag); graph.mapper.connect(storeFlagTag, flagRegister); @@ -720,9 +716,7 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - // auto waitBeforeResetTag = 
graph.control.addElement(WaitZero()); graph.control.chain(doWhileTag, - // waitBeforeResetTag, barrierBeforeResetTag, wave0ResetFlagTag); graph.control.addElement(Body(), {wave0ResetFlagTag}, {assignResetFlagTag}); From f85411b2e4f6b6281f30ee3c7df878376e1a662b Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 21:34:37 +0000 Subject: [PATCH 20/22] Enable streamk --- .../amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index e4fbf4f580e..d10168af79a 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -198,7 +198,7 @@ std::shared_ptr } // Pass StreamK flag from solution index parameters - gemm->streamK = false; + gemm->streamK = true; // StreamK is not currently working with prefetching or workgroup mapping // if(gemm->streamK) From 7aa9059ecbce9cc545e0dc744b47ef063fc439eb Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Thu, 11 Dec 2025 21:46:37 +0000 Subject: [PATCH 21/22] Disable workgroupMapping when streamK is on --- .../src/rocroller/parameter_selection.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index d10168af79a..66884ff351e 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -198,15 +198,14 @@ std::shared_ptr } // Pass StreamK flag from solution index parameters - gemm->streamK = 
true; - - // StreamK is not currently working with prefetching or workgroup mapping - // if(gemm->streamK) - // { - // gemm->prefetch = false; - // gemm->workgroupMappingDim = -1; - // gemm->workgroupRemapXCC = false; - // } + gemm->streamK = solutionIndexParameters.streamK; + + // StreamK is not currently working with workgroup mapping due to register pressure + if(gemm->streamK) + { + gemm->workgroupMappingDim = -1; + gemm->workgroupRemapXCC = false; + } return gemm; } From a7690257f4975beeae807425b5a5998b40fb290a Mon Sep 17 00:00:00 2001 From: yiqialiu Date: Fri, 12 Dec 2025 19:07:19 +0000 Subject: [PATCH 22/22] Not use streamk for f6 kernels --- .../rocblaslt/src/rocroller/solution_selection.cpp | 5 +++-- .../KernelGraph/Transformations/AddStreamK.cpp | 4 +--- shared/rocroller/test/unit/GEMMTestBase.hpp | 12 ++++++++---- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index d0447361728..597e0dc525d 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -244,11 +244,12 @@ std::vector chooseSolutionIndexParameters( params.back().workgroupMapping = false; } - // Enable StreamK when number of output tiles < number of CUs + // Enable StreamK when number of output tiles < number of CUs and not f6 data type size_t numTilesM = prob.m / wgt.m; size_t numTilesN = prob.n / wgt.n; size_t numTiles = numTilesM * numTilesN * prob.batch_count; - if(numTiles < analytical_hardware.N_CU) + auto isF6 = (kernelType.typeA == rocRoller::DataType::FP6 || kernelType.typeA == rocRoller::DataType::BF6 || kernelType.typeB == rocRoller::DataType::FP6 || kernelType.typeB == rocRoller::DataType::BF6); + if(numTiles < analytical_hardware.N_CU && !isF6) { 
params.back().streamK = true; } diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index d8b7f74716b..17a7fadc8ee 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -716,9 +716,7 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - graph.control.chain(doWhileTag, - barrierBeforeResetTag, - wave0ResetFlagTag); + graph.control.chain(doWhileTag, barrierBeforeResetTag, wave0ResetFlagTag); graph.control.addElement(Body(), {wave0ResetFlagTag}, {assignResetFlagTag}); auto waitAfterRestFlagStoreTag = graph.control.addElement(WaitZero()); graph.control.chain( diff --git a/shared/rocroller/test/unit/GEMMTestBase.hpp b/shared/rocroller/test/unit/GEMMTestBase.hpp index 6c64d34c347..d30c9bcca47 100644 --- a/shared/rocroller/test/unit/GEMMTestBase.hpp +++ b/shared/rocroller/test/unit/GEMMTestBase.hpp @@ -803,12 +803,16 @@ namespace GEMMTests for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(TD)), HasHipSuccess(0)); - for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + if(iteration == 0) { - if(scratchSpaceRequired[i] > 0) + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) { - ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), - HasHipSuccess(0)); + if(scratchSpaceRequired[i] > 0) + { + ASSERT_THAT( + hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), + HasHipSuccess(0)); + } } }