diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp index 6d68edacb4e..8c53cc3f061 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp @@ -118,6 +118,10 @@ std::string genKernelName(std::shared_ptr gemm) if(gemm->kernelType.scaleTypeB.mode != Operations::ScaleMode::None) rv << "SB_" << genScaleModeString(gemm->kernelType.scaleTypeB.mode) << "_"; + if(gemm->streamK) + { + rv << "SK_"; + } rv << "WGT_"; rocRoller::streamJoin( rv, std::vector{gemm->workgroupTile.m, gemm->workgroupTile.n, gemm->workgroupTile.k}, "x"); @@ -170,7 +174,8 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge ShowValue(gemm->kernelType.scaleTypeB.mode)); std::optional tagTensorScaleA, tagLoadScaleA, tagBlockScaleA, - tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM; + tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagSKGrid, tagWGM; + std::map tagScratch; if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate) { @@ -265,12 +270,19 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge DataDirection::ReadOnly, rocRoller::NUMWGS); - tagScratch = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - *tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); + // Create Scratch operations for each ScratchPolicy + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto tag = command->allocateTag(); + command->addOperation(Operations::Scratch(tag, policy)); + command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), + tag, + ArgumentType::Value, + DataDirection::ReadWrite, + rocRoller::getScratchName(policy)); + tagScratch[policy] = tag; + } } if(gemm->workgroupMappingDim != -1) @@ -489,8 +501,7 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge if(tagTensorScaleB) gemmKernel->tagTensorScaleB = *tagTensorScaleB; - if(tagScratch) - gemmKernel->tagScratch = *tagScratch; + gemmKernel->tagScratch = tagScratch; if(tagSKGrid) gemmKernel->tagSKGrid = *tagSKGrid; @@ -523,7 +534,9 @@ size_t workspaceRequired(std::shared_ptr gemm, const RocblasltContra auto runtimeArgs = commandArgs.runtimeArguments(); - return gemm->commandKernel->scratchSpaceRequired(runtimeArgs); + // Only return scratch space for ScratchPolicy::None (uses prob.workspace) + return gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::None, runtimeArgs); } CommandArguments createCommandArguments(std::shared_ptr gemm, @@ -622,14 +635,34 @@ rocblaslt_status runGemmKernel(std::shared_ptr gemm, } return rocblaslt_status_invalid_value; } - auto commandArgs = createCommandArguments(gemm, prob, DEFAULT_WGM); - // Add scratch space - if(workSpaceRequired > 0) + + if(gemm->params->streamK) { - commandArgs.setArgument( - gemm->tagScratch, ArgumentType::Value, static_cast(prob.workspace)); + auto runtimeArgs = commandArgs.runtimeArguments(); + + // Use prob.workspace for ScratchPolicy::None + auto noneScratchSize = gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::None, runtimeArgs); + if(noneScratchSize > 0 && prob.workspace != nullptr) + { + commandArgs.setArgument( + gemm->tagScratch.at(Operations::ScratchPolicy::None), + ArgumentType::Value, + static_cast(prob.workspace)); + } + + // Use prob.Synchronizer for ScratchPolicy::ZeroedBeforeAndAfter + auto zeroedScratchSize = gemm->commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs); + if(zeroedScratchSize > 0 && prob.Synchronizer != nullptr) + { + commandArgs.setArgument( + gemm->tagScratch.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + ArgumentType::Value, + static_cast(prob.Synchronizer)); + } } auto runtimeArgs = commandArgs.runtimeArguments(); @@ -640,5 +673,6 @@ rocblaslt_status runGemmKernel(std::shared_ptr gemm, } gemm->commandKernel->launchKernel(runtimeArgs, prob.stream); + return rocblaslt_status_success; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp index 6f0cd7dd440..72308103f66 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/gemm.hpp @@ -32,8 +32,11 @@ #include #include #include +#include #include +#include + /** * @brief GemmKernel * @@ -57,7 +60,7 @@ struct GemmKernel rocRoller::Operations::OperationTag tagTensorScaleA; rocRoller::Operations::OperationTag tagTensorScaleB; - rocRoller::Operations::OperationTag tagScratch; + std::map tagScratch; rocRoller::Operations::OperationTag tagSKGrid; rocRoller::Operations::OperationTag tagWGM; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp index e4055d2f7ca..d4f92263f7e 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp @@ -71,6 +71,7 @@ struct SolutionIndexParameters { WorkGroupTileSize workgroupTile; bool workgroupMapping; + bool streamK; auto operator<=>(const SolutionIndexParameters& other) const = default; }; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index f5cbedb25a6..75d0ab75960 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -37,6 +37,7 @@ std::string SolutionParameters::toString() const result << "MachineInstruction:" << machineInstruction.m << "x" << machineInstruction.n << "x" << machineInstruction.k << std::endl; result << "WorkgroupSize:" << workgroupSizeX << "x" << workgroupSizeY << std::endl; + result << "StreamK: " << streamK << std::endl; result << "LoadA: " << loadPathA << std::endl; result << "LoadB: " << loadPathB << std::endl; result << "LDS Usage"; @@ -196,10 +197,12 @@ std::shared_ptr gemm->workgroupRemapXCC = true; } - // TODO: StreamK is not currently working with prefetching or workgroup mapping + // Pass StreamK flag from solution index parameters + gemm->streamK = solutionIndexParameters.streamK; + + // StreamK is not currently working with workgroup mapping due to register pressure if(gemm->streamK) { - gemm->prefetch = false; gemm->workgroupMappingDim = -1; gemm->workgroupRemapXCC = false; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index ae1dca5e046..63144fe76ef 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -237,12 +237,22 @@ std::vector chooseSolutionIndexParameters( || !std::has_single_bit(static_cast(wgt.n)))) continue; - params.push_back({wgt, true}); + params.push_back({wgt, true, false}); if (prob.k < USE_WORKGROUP_MAPPING_K_SIZE) { params.back().workgroupMapping = false; } + + // Enable StreamK when number of output tiles < number of CUs and not f6 data type + size_t numTilesM = prob.m / wgt.m; + size_t numTilesN = prob.n / wgt.n; + size_t numTiles = numTilesM * numTilesN * prob.batch_count; + auto isF6 = (kernelType.typeA == rocRoller::DataType::FP6 || kernelType.typeA == rocRoller::DataType::BF6 || kernelType.typeB == rocRoller::DataType::FP6 || kernelType.typeB == rocRoller::DataType::BF6); + if(numTiles < analytical_hardware.N_CU && !isF6) + { + params.back().streamK = true; + } } } @@ -261,6 +271,8 @@ int parametersToIndex(const SolutionIndexParameters& params) result |= ((params.workgroupTile.m / REQUIRED_MULTIPLE_M_N) << pos); pos += MAX_BITS_WORKGROUPTILE_M; result |= ((params.workgroupMapping ? 1 : 0) << pos); + pos += 1; + result |= ((params.streamK ? 1 : 0) << pos); // Set top bit indicating it is a rocRoller index result |= (1 << 31); @@ -290,6 +302,8 @@ SolutionIndexParameters indexToParameters(int index) = ((index >> pos) & mask(MAX_BITS_WORKGROUPTILE_M)) * REQUIRED_MULTIPLE_M_N; pos += MAX_BITS_WORKGROUPTILE_M; result.workgroupMapping = (index >> pos) & 1; + pos += 1; + result.streamK = (index >> pos) & 1; return result; } diff --git a/shared/rocroller/client/include/client/GEMMSolution.hpp b/shared/rocroller/client/include/client/GEMMSolution.hpp index 09aa729cb14..3a2c825f5e0 100644 --- a/shared/rocroller/client/include/client/GEMMSolution.hpp +++ b/shared/rocroller/client/include/client/GEMMSolution.hpp @@ -29,6 +29,7 @@ #include "GEMMParameters.hpp" #include +#include using namespace rocRoller; @@ -61,7 +62,8 @@ namespace rocRoller { } - virtual Operations::OperationTag getScratchTag() const + virtual Operations::OperationTag + getScratchTag(Operations::ScratchPolicy scratchPolicy) const { return {}; } diff --git a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp index a64789be4ea..ce3375f6c08 100644 --- a/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp +++ b/shared/rocroller/client/include/client/StreamKGEMMSolution.hpp @@ -41,14 +41,16 @@ namespace rocRoller { class StreamKGEMMSolution : public DataParallelGEMMSolution { - Operations::OperationTag m_scratchTag, m_numWGsTag; + std::map m_scratchTags; + Operations::OperationTag m_numWGsTag; public: using DataParallelGEMMSolution::DataParallelGEMMSolution; - Operations::OperationTag getScratchTag() const override + Operations::OperationTag + getScratchTag(Operations::ScratchPolicy scratchPolicy) const override { - return m_scratchTag; + return m_scratchTags.at(scratchPolicy); } protected: @@ -63,15 +65,31 @@ namespace rocRoller DataDirection::ReadOnly, rocRoller::NUMWGS); - m_scratchTag = command->allocateTag(); - command->addOperation(rocRoller::Operations::Scratch( - m_scratchTag, Operations::ScratchPolicy::None)); + // Create a scratch operation for tile data + m_scratchTags[Operations::ScratchPolicy::None] = command->allocateTag(); + command->addOperation( + Operations::Scratch(m_scratchTags[Operations::ScratchPolicy::None], + Operations::ScratchPolicy::None)); command->allocateArgument( VariableType(DataType::UInt32, PointerType::PointerGlobal), - m_scratchTag, + m_scratchTags[Operations::ScratchPolicy::None], ArgumentType::Value, DataDirection::ReadWrite, - rocRoller::SCRATCH); + getScratchName(Operations::ScratchPolicy::None)); + + // Create a scratch operation for flags + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] + = command->allocateTag(); + command->addOperation(Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + return command; } diff --git a/shared/rocroller/client/src/gemm.cpp b/shared/rocroller/client/src/gemm.cpp index 0b49c056e85..8935ff93345 100644 --- a/shared/rocroller/client/src/gemm.cpp +++ b/shared/rocroller/client/src/gemm.cpp @@ -25,6 +25,7 @@ *******************************************************************************/ #include "rocRoller/Serialization/YAML.hpp" +#include #include #include #include @@ -344,14 +345,18 @@ namespace rocRoller::Client::GEMMClient auto runtimeArgs = commandArgs.runtimeArguments(); // Note: the lifetime of deviceScratch needs to exceed kernel executions - std::shared_ptr deviceScratch; + std::shared_ptr + deviceScratch[static_cast(Operations::ScratchPolicy::Count)]; + + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) { - auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(runtimeArgs); + auto policy = static_cast(i); + auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(policy, runtimeArgs); if(scratchSpaceRequired > 0) { - deviceScratch = make_shared_device(scratchSpaceRequired, 0); + deviceScratch[i] = make_shared_device(scratchSpaceRequired, 0); commandArgs.setArgument( - gemm->getScratchTag(), ArgumentType::Value, deviceScratch.get()); + gemm->getScratchTag(policy), ArgumentType::Value, deviceScratch[i].get()); } } @@ -457,6 +462,24 @@ namespace rocRoller::Client::GEMMClient auto [correct, rnorm] = validate( hostA, hostB, hostC, hostD, hostScaleA, hostScaleB, problemParams, arch); + // Verify ZeroedBeforeAndAfter scratch is all zeros after kernel + auto zeroedIdx = static_cast(Operations::ScratchPolicy::ZeroedBeforeAndAfter); + if(deviceScratch[zeroedIdx]) + { + auto zeroedSize = commandKernel->scratchSpaceRequired( + Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs); + std::vector zeroedResult(zeroedSize); + AssertFatal(hipMemcpy(zeroedResult.data(), + deviceScratch[zeroedIdx].get(), + zeroedSize, + hipMemcpyDeviceToHost) + == (hipError_t)HIP_SUCCESS); + AssertFatal( + std::all_of( + zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; }), + "ZeroedBeforeAndAfter scratch should be all zeros after kernel execution"); + } + result.checked = true; result.correct = correct; result.rnorm = rnorm; diff --git a/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp b/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp index c1d1eb1aada..31585d08b67 100644 --- a/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp +++ b/shared/rocroller/lib/include/rocRoller/CommandSolution.hpp @@ -307,13 +307,17 @@ namespace rocRoller /** * @brief Returns the total number of bytes required for scratch space + * for the specified scratch policy. * - * If this value is greather than 0, the user is required to allocate this + * If this value is greater than 0, the user is required to allocate this * amount of device memory and pass it into the kernel. * + * @param policy The scratch policy to query + * @param args The runtime arguments * @return size_t */ - size_t scratchSpaceRequired(RuntimeArguments const& args) const; + size_t scratchSpaceRequired(Operations::ScratchPolicy policy, + RuntimeArguments const& args) const; /** * @brief Returns the workgroup size diff --git a/shared/rocroller/lib/include/rocRoller/Context.hpp b/shared/rocroller/lib/include/rocRoller/Context.hpp index 5fe9039e9d0..39eb1781312 100644 --- a/shared/rocroller/lib/include/rocRoller/Context.hpp +++ b/shared/rocroller/lib/include/rocRoller/Context.hpp @@ -53,6 +53,7 @@ #include #include #include +#include #include #include #include @@ -127,18 +128,23 @@ namespace rocRoller void setKernel(AssemblyKernelPtr); /** - * @brief Returns an expression representing how much scratch space is required (in bytes) + * @brief Allocate scratch space for the specified scratch policy. * - * @return Expression::ExpressionPtr + * @param policy The scratch policy to allocate for + * @param size Number of bytes requested + * @return Expression::ExpressionPtr The offset before this allocation */ - Expression::ExpressionPtr getScratchAmount() const; + Expression::ExpressionPtr allocateScratch(Operations::ScratchPolicy policy, + Expression::ExpressionPtr size); /** - * @brief Allocate more scratch space + * @brief Returns an expression representing how much scratch space is required (in bytes) + * for the specified scratch policy. * - * @param size Number of bytes requested + * @param policy The scratch policy to query + * @return Expression::ExpressionPtr */ - void allocateScratch(Expression::ExpressionPtr size); + Expression::ExpressionPtr getScratchAmount(Operations::ScratchPolicy policy) const; /** * @brief Get register scope manager. @@ -168,14 +174,15 @@ namespace rocRoller std::array, static_cast(Register::Type::Count)> m_allocators; - std::shared_ptr m_observer; - AssemblyKernelPtr m_kernel; - std::shared_ptr m_argLoader; - std::shared_ptr m_instructions; - std::shared_ptr m_mem; - LabelAllocatorPtr m_labelAllocator; - std::shared_ptr m_ldsAllocator; - Expression::ExpressionPtr m_scratchAllocator; + std::shared_ptr m_observer; + AssemblyKernelPtr m_kernel; + std::shared_ptr m_argLoader; + std::shared_ptr m_instructions; + std::shared_ptr m_mem; + LabelAllocatorPtr m_labelAllocator; + std::shared_ptr m_ldsAllocator; + std::array(Operations::ScratchPolicy::Count)> + m_scratchSizes; std::shared_ptr m_copier; std::shared_ptr m_brancher; std::shared_ptr m_crasher; diff --git a/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp b/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp index 17005bcb424..55e9362f5db 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp @@ -153,8 +153,11 @@ namespace rocRoller * * @param size How many elements make up the User dimension. * @param offset Location of data within the scratch space + * @param argName Name of the argument for this scratch space */ - User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset); + User(Expression::ExpressionPtr size, + Expression::ExpressionPtr offset, + std::string const& argName); std::string name() const override; }; diff --git a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp index ffad0b98afc..8d3363b960b 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelGraph/Utils.hpp @@ -319,11 +319,15 @@ namespace rocRoller * * @param size * @param varType + * @param policy The scratch policy to use for allocation * @param context * @return User */ - rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, VariableType varType, ContextPtr context); + rocRoller::KernelGraph::CoordinateGraph::User + newScratchCoordinate(Expression::ExpressionPtr size, + VariableType varType, + Operations::ScratchPolicy policy, + ContextPtr context); /** * @brief Replace operation with a new operation. diff --git a/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp b/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp index 0c6fe227057..b57e3725896 100644 --- a/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp +++ b/shared/rocroller/lib/include/rocRoller/KernelOptions.hpp @@ -33,6 +33,7 @@ #include #include +#include #include #include @@ -49,6 +50,12 @@ namespace rocRoller const std::string NUMWGS = "numWGs"; const std::string WGM = "WGM"; + // Helper to get scratch argument name for a specific policy + inline std::string getScratchName(Operations::ScratchPolicy policy) + { + return rocRoller::SCRATCH + "_" + Operations::toString(policy); + } + class KernelOptions { public: diff --git a/shared/rocroller/lib/source/CommandSolution.cpp b/shared/rocroller/lib/source/CommandSolution.cpp index 0005f1a8c94..5a1326bc149 100644 --- a/shared/rocroller/lib/source/CommandSolution.cpp +++ b/shared/rocroller/lib/source/CommandSolution.cpp @@ -646,9 +646,10 @@ namespace rocRoller return m_context; } - size_t CommandKernel::scratchSpaceRequired(RuntimeArguments const& args) const + size_t CommandKernel::scratchSpaceRequired(Operations::ScratchPolicy policy, + RuntimeArguments const& args) const { - auto amount = m_context->getScratchAmount(); + auto amount = m_context->getScratchAmount(policy); auto times = evaluationTimes(amount); AssertFatal(times[Expression::EvaluationTime::Translate] diff --git a/shared/rocroller/lib/source/Context.cpp b/shared/rocroller/lib/source/Context.cpp index 32272a1b430..0356cfdbf4c 100644 --- a/shared/rocroller/lib/source/Context.cpp +++ b/shared/rocroller/lib/source/Context.cpp @@ -47,8 +47,12 @@ namespace rocRoller { Context::Context() - : m_scratchAllocator(Expression::literal(0u)) { + // Initialize scratch sizes for each policy with zero + for(size_t i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + m_scratchSizes[i] = Expression::literal(0u); + } } ContextPtr Context::ForDefaultHipDevice(std::string const& kernelName, @@ -287,14 +291,18 @@ namespace rocRoller m_kernel = assemblyKernel; } - Expression::ExpressionPtr Context::getScratchAmount() const + Expression::ExpressionPtr Context::allocateScratch(Operations::ScratchPolicy policy, + Expression::ExpressionPtr size) { - return m_scratchAllocator; + auto idx = static_cast(policy); + auto currentOffset = m_scratchSizes[idx]; + m_scratchSizes[idx] = simplify(m_scratchSizes[idx] + size); + return currentOffset; } - void Context::allocateScratch(Expression::ExpressionPtr size) + Expression::ExpressionPtr Context::getScratchAmount(Operations::ScratchPolicy policy) const { - m_scratchAllocator = simplify(m_scratchAllocator + size); + return m_scratchSizes[static_cast(policy)]; } void Context::scheduleCopy(Instruction const& inst) diff --git a/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp b/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp index 61030ea2708..09cd3a9ec80 100644 --- a/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp +++ b/shared/rocroller/lib/source/KernelGraph/CoordinateGraph/Dimension.cpp @@ -117,9 +117,11 @@ namespace rocRoller return name() + stag; } - User::User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset) + User::User(Expression::ExpressionPtr size, + Expression::ExpressionPtr offset, + std::string const& argName) : BaseDimension(size, Expression::literal(1u), offset) - , argumentName(rocRoller::SCRATCH) + , argumentName(argName) { } diff --git a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp index 9d6aae3846b..17a7fadc8ee 100644 --- a/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Transformations/AddStreamK.cpp @@ -306,7 +306,8 @@ namespace rocRoller auto strideX = sizeY; auto strideY = literal(1u); - auto globalScratch = newScratchCoordinate(simplify(sizeX * sizeY), varType, context); + auto globalScratch = newScratchCoordinate( + simplify(sizeX * sizeY), varType, Operations::ScratchPolicy::None, context); auto globalScratchTag = graph.coordinates.addElement(globalScratch); std::vector jammedSizes = {loopInfo.xLoopSize, loopInfo.yLoopSize}; @@ -427,9 +428,10 @@ namespace rocRoller * StoreTile() * WaitZero() * Barrier() - * flag = Assign(SGPR, 1u); - * StoreSGPR(flag) - * WaitZero() + * if wave0: + * flag = Assign(SGPR, 1u); + * StoreSGPR(flag) + * WaitZero() */ SendInfo sendTile(KernelGraph& graph, ExpressionPtr sendTileExpr, @@ -491,24 +493,35 @@ namespace rocRoller // TODO: Improve setting of arch-specific buffer options BufferInstructionOptions bufOpts{.glc = true}; + auto storeFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); graph.mapper.connect(storeFlagTag, flagsScratchTag); graph.mapper.connect(storeFlagTag, flagRegister); + // Create workitem coordinate and expression for wave 0 check + // Only workitem 0 (wave 0) should write to the flag + auto workitemTag = graph.coordinates.addElement(Workitem(0)); + auto workitemDF = std::make_shared( + Expression::DataFlowTag{workitemTag, Register::Type::Vector, DataType::UInt32}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0FlagStoreTag + = graph.control.addElement(ConditionalOp{isWave0Expr, "Wave0 Store Flag"}); + // Add to control - auto preWaitZeroTag = graph.control.addElement(WaitZero()); - auto postWaitZeroTag = graph.control.addElement(WaitZero()); + auto preWaitZeroTag = graph.control.addElement(WaitZero()); graph.control.addElement(Sequence(), {preWaitZeroTag}, {sendTileTag}); graph.control.addElement(Body(), {sendTileTag}, {forX}); - graph.control.chain( - forX, waitZeroTag, barrierTag, assignFlagTag, storeFlagTag, postWaitZeroTag); + graph.control.chain(forX, waitZeroTag, barrierTag, wave0FlagStoreTag); + graph.control.addElement(Body(), {wave0FlagStoreTag}, {assignFlagTag}); + auto waitAfterStoreFlagTag = graph.control.addElement(WaitZero()); + graph.control.chain(assignFlagTag, storeFlagTag, waitAfterStoreFlagTag); return {preWaitZeroTag, sendTileTag}; } /** - * Create send-tile block, which is roughly: + * Create receive-tile block, which is roughly: * * WaitZero() * if receiveTileExpr: @@ -518,6 +531,11 @@ namespace rocRoller * do: * LoadSGPR(flag[nextWG]) * while flag[nextWG] == 0 + * Barrier() + * if wave0: + * Assign(flag[nextWG] = 0) + * StoreSGPR(flag[nextWG]) + * WaitZero() * partiallyAccumulatedTile = LoadTiled() * fullyAccumulatedTile = Assign(localPartiallyAccumulatedTile) * fullyAccumulatedTile = Assign(fullyAccumulatedTile + partiallyAccumulatedTile) @@ -584,6 +602,62 @@ namespace rocRoller auto doWhileTag = graph.control.addElement( DoWhileOp{(DF(flagRegister) == zero), "Global sync spin loop"}); + // The coordinate graph for load, store, and reset flags: + // + // Workgroup + // | + // PassThrough + // | + // v + // flagsScratchTag <-Duplicate-- resetFlagsScratchTag + // | ^ + // PassThrough PassThrough + // | | + // v resetNextWorkgroupTag + // nextWorkgroupTag ^ + // | Join + // Split / | \ + // / | \ / | \ + // v v v / | \ + // Workgroup plusOne forReceiveTileLoop + // + // Note: nextWorkgroupTag and resetNextWorkgroupTag both connect to the same + // neighbors (Workgroup, plusOne, forReceiveTileLoop) but in opposite directions + // (Split vs Join). This allows loading flags from WG+1+i and resetting flags + // at the same index. + + // Create coordinate to indicate flag index to reset + auto resetNextWorkgroupTag = graph.coordinates.addElement(Linear(nullptr, one)); + graph.coordinates.addElement( + Join(), {workgroup, plusOneTag, forReceiveTileLoopCoord}, {resetNextWorkgroupTag}); + + // Duplicate flags scratch coordinate for reset operation + auto resetFlagsScratchTag + = graph.coordinates.addElement(*graph.coordinates.get(flagsScratchTag)); + graph.coordinates.addElement(Duplicate(), {resetFlagsScratchTag}, {flagsScratchTag}); + graph.coordinates.addElement( + PassThrough(), {resetNextWorkgroupTag}, {resetFlagsScratchTag}); + + // Reset flag operations + auto assignResetFlagTag + = graph.control.addElement(Assign{Register::Type::Scalar, zero}); + graph.mapper.connect(assignResetFlagTag, flagRegister, NaryArgument::DEST); + + auto resetFlagTag = graph.control.addElement(StoreSGPR(DataType::UInt32, bufOpts)); + graph.mapper.connect(resetFlagTag, resetFlagsScratchTag); + graph.mapper.connect(resetFlagTag, flagRegister); + + // Create workitem coordinate and expression for wave 0 check + // Only workitem 0 (wave 0) should write to the flag + auto workitemTag = graph.coordinates.addElement(Workitem(0)); + auto workitemDF = std::make_shared( + Expression::DataFlowTag{workitemTag, Register::Type::Vector, DataType::UInt32}); + auto isWave0Expr = (workitemDF == Expression::literal(0u)); + auto wave0ResetFlagTag + = graph.control.addElement(ConditionalOp{isWave0Expr, "Wave0 Reset Flag"}); + + auto barrierBeforeResetTag = graph.control.addElement(Barrier()); + auto accumulatorTile = graph.coordinates.get(accumulatorTileTag); uint numRegisters = accumulatorTile->elements() / (product(context->kernel()->workgroupSize()) * loopInfo.xLoopSize @@ -642,7 +716,12 @@ namespace rocRoller graph.control.addElement(Sequence(), {boundsCheckTag}, {doWhileTag}); graph.control.addElement(Body(), {doWhileTag}, {loadFlagTag}); - graph.control.chain(doWhileTag, loadAddForX, postWaitZeroTag); + graph.control.chain(doWhileTag, barrierBeforeResetTag, wave0ResetFlagTag); + graph.control.addElement(Body(), {wave0ResetFlagTag}, {assignResetFlagTag}); + auto waitAfterRestFlagStoreTag = graph.control.addElement(WaitZero()); + graph.control.chain( + assignResetFlagTag, resetFlagTag, waitAfterRestFlagStoreTag); + graph.control.chain(wave0ResetFlagTag, loadAddForX, postWaitZeroTag); return {preWaitZeroTag, receiveTileTag, setPlusOneTag}; } @@ -1173,7 +1252,7 @@ namespace rocRoller int postAccumulationCond; if(accumInfo.accumulatorTile != -1) { - auto remainAccumTiles = numAccumTiles - DF(lastAccumTile) + one; + auto remainAccumTiles = numAccumTiles - DF(lastAccumTile) - one; auto numRemainPartialResults = (remainAccumTiles + argInfo.numSKTilesPerWG - one) / argInfo.numSKTilesPerWG; @@ -1185,7 +1264,11 @@ namespace rocRoller resultVariableType(numRemainPartialResults)); // Create scratch space for flags - auto flagsScratch = newScratchCoordinate(argInfo.numWGs, DataType::UInt32, context); + auto flagsScratch + = newScratchCoordinate(argInfo.numWGs, + DataType::UInt32, + Operations::ScratchPolicy::ZeroedBeforeAndAfter, + context); auto flagsScratchTag = graph.coordinates.addElement(flagsScratch); // Create scratch space for partially accumulated tiles diff --git a/shared/rocroller/lib/source/KernelGraph/Utils.cpp b/shared/rocroller/lib/source/KernelGraph/Utils.cpp index 0f960f6b38d..73809aaa6c7 100644 --- a/shared/rocroller/lib/source/KernelGraph/Utils.cpp +++ b/shared/rocroller/lib/source/KernelGraph/Utils.cpp @@ -29,6 +29,7 @@ #include #include +#include namespace rocRoller { @@ -817,15 +818,18 @@ namespace rocRoller return rv; } - rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate( - Expression::ExpressionPtr size, VariableType varType, ContextPtr context) + rocRoller::KernelGraph::CoordinateGraph::User + newScratchCoordinate(Expression::ExpressionPtr size, + VariableType varType, + Operations::ScratchPolicy policy, + ContextPtr context) { - auto currentOffset = context->getScratchAmount(); - auto newCoordinate = CT::User(size, currentOffset); // TODO Audit bytes/bits // Can we move size inside the CeilDivide? - context->allocateScratch( + auto currentOffset = context->allocateScratch( + policy, size * Expression::literal(CeilDivide(DataTypeInfo::Get(varType).elementBits, 8u))); + auto newCoordinate = CT::User(size, currentOffset, getScratchName(policy)); return newCoordinate; } diff --git a/shared/rocroller/test/catch/AddStreamKTest.cpp b/shared/rocroller/test/catch/AddStreamKTest.cpp index 96cab26c283..067383c7436 100644 --- a/shared/rocroller/test/catch/AddStreamKTest.cpp +++ b/shared/rocroller/test/catch/AddStreamKTest.cpp @@ -46,6 +46,7 @@ #include #include #include +#include #include #include #include @@ -621,3 +622,190 @@ TEST_CASE("AddStreamK with unroll K", "[streamk][kernel-graph]") } } } + +TEST_CASE("AddStreamK scratch policy usage", "[streamk][kernel-graph][scratch]") +{ + using namespace rocRoller; + using namespace rocRoller::Operations; + using namespace KernelGraph; + using namespace ControlGraph; + using namespace CoordinateGraph; + + auto context = TestContext::ForDefaultTarget(); + auto example = rocRollerTest::Graphs::GEMM(DataType::Float); + auto mode = StreamKMode::Standard; + + example.setTileSize(128, 256, 8); + example.setMFMA(32, 32, 2, 1); + example.setUseLDS(false, false, false); + example.setPrefetch(false, 0, 0, false); + example.setStreamK(mode); + + auto numWGs = example.getFlattenedWorkgroupSize(); + auto numWGsExpr = std::make_shared(numWGs); + + auto kgraph = example.getKernelGraph(); + auto params = example.getCommandParameters(); + + // Apply transforms including AddStreamK + std::vector transforms; + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared(false)); + transforms.push_back(std::make_shared(params)); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared(context.get())); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared(params, context.get())); + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared()); + transforms.push_back(std::make_shared( + context.get(), params, rocRoller::XLOOP, rocRoller::KLOOP, numWGsExpr)); + transforms.push_back(std::make_shared(context.get())); + + for(auto& t : transforms) + kgraph = kgraph.transform(t); + + SECTION("Tile data uses None policy") + { + // After AddStreamK, the None policy should have non-zero scratch allocation + // (used for tile data exchange between workgroups) + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto valueNone = Expression::evaluate(amountNone); + // Tile data scratch should be allocated + CHECK(getUnsignedInt(valueNone) > 0); + } + + SECTION("Flags use ZeroedBeforeAndAfter policy") + { + // After AddStreamK, the ZeroedBeforeAndAfter policy should have non-zero scratch + // allocation (used for synchronization flags that must be zeroed before/after kernel) + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + auto valueZeroed = Expression::evaluate(amountZeroed); + // Flags scratch should be allocated for sync purposes + CHECK(getUnsignedInt(valueZeroed) > 0); + } + + SECTION("Different policies have different allocations") + { + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + + auto valueNone = Expression::evaluate(amountNone); + auto valueZeroed = Expression::evaluate(amountZeroed); + + // Both should have allocations, but they should be independent + // (flags are typically smaller - one uint32 per WG, while tile data is larger) + CHECK(getUnsignedInt(valueNone) > 0); + CHECK(getUnsignedInt(valueZeroed) > 0); + // Tile data should typically be larger than flags + CHECK(getUnsignedInt(valueNone) > getUnsignedInt(valueZeroed)); + } + + SECTION("Only has one scratch coordinate for tile data") + { + + auto findScratchNone = [&](int tag) { + auto maybeUser = kgraph.coordinates.get(tag); + if(!maybeUser) + return false; + return maybeUser->argumentName == getScratchName(ScratchPolicy::None); + }; + + auto scratchNoneCoordinates + = kgraph.coordinates.findElements(findScratchNone).to(); + CHECK(scratchNoneCoordinates.size() == 1); + } + + SECTION("Load, store, and reset flags scratch space correctly") + { + auto findScratchZeroed = [&](int tag) { + auto maybeUser = kgraph.coordinates.get(tag); + if(!maybeUser) + return false; + return maybeUser->argumentName == getScratchName(ScratchPolicy::ZeroedBeforeAndAfter); + }; + + auto scratchZeroedCoordinates + = kgraph.coordinates.findElements(findScratchZeroed).to(); + CHECK(scratchZeroedCoordinates.size() == 2); + + // Verify the reset flags coordinate connects to the original flags coordinate via a duplicate edge + auto resetFlagsCoordinate = -1, originalFlagsCoordinate = -1; + for(const auto& tag : scratchZeroedCoordinates) + { + // Duplicate edge connects the reset flags coordinate to the original flags coordinate + auto isDuplicate = isEdge; + auto outDuplicates + = kgraph.coordinates.getOutputNodeIndices(tag, isDuplicate).to(); + + CHECK((outDuplicates.size() == 1 || outDuplicates.empty())); + if(outDuplicates.size() == 1) + { + resetFlagsCoordinate = tag; + originalFlagsCoordinate = outDuplicates[0]; + } + } + CHECK(resetFlagsCoordinate != -1); + CHECK(originalFlagsCoordinate != -1); + + // Duplicate coordinate should have a higher tag than the original flags coordinate + CHECK(resetFlagsCoordinate > originalFlagsCoordinate); + + // Verify the graph of flags scratch coordinates matches AddStreamK implementation + auto isPassThroughEdge = isEdge; + auto isJoinEdge = isEdge; + auto isSplitEdge = isEdge; + auto maybeNextWorkgroupTag + = kgraph.coordinates.getOutputNodeIndices(originalFlagsCoordinate, isPassThroughEdge) + .to(); + CHECK(maybeNextWorkgroupTag.size() == 1); + auto nextWorkgroupTag = maybeNextWorkgroupTag[0]; + auto maybeSplit = kgraph.coordinates.getOutputNodeIndices(nextWorkgroupTag, isSplitEdge) + .to(); + CHECK(maybeSplit.size() == 3); + auto maybeResetNextWorkgroupTag0 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[0], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag1 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[1], isJoinEdge).to(); + auto maybeResetNextWorkgroupTag2 + = kgraph.coordinates.getOutputNodeIndices(maybeSplit[2], isJoinEdge).to(); + CHECK(maybeResetNextWorkgroupTag0.size() == 1); + CHECK(maybeResetNextWorkgroupTag1.size() == 1); + CHECK(maybeResetNextWorkgroupTag2.size() == 1); + CHECK(maybeResetNextWorkgroupTag0[0] == maybeResetNextWorkgroupTag1[0]); + CHECK(maybeResetNextWorkgroupTag1[0] == maybeResetNextWorkgroupTag2[0]); + auto maybeResetFlagsCoordinate + = kgraph.coordinates + .getOutputNodeIndices(maybeResetNextWorkgroupTag0[0], isPassThroughEdge) + .to(); + CHECK(maybeResetFlagsCoordinate.size() == 1); + CHECK(maybeResetFlagsCoordinate[0] == resetFlagsCoordinate); + + // Verify there are two store flags operations + auto storeFlagsOps + = kgraph.control.findElements(kgraph.control.isElemType()).to(); + CHECK(storeFlagsOps.size() == 2); + + // Verify one StoreSGPR is stores the flags and the other resets the flags + auto storeFlagsTag = -1, resetFlagsTag = -1; + for(auto const tag : storeFlagsOps) + { + auto flagCoordinateTag = kgraph.mapper.get(tag); + if(flagCoordinateTag == resetFlagsCoordinate) + { + resetFlagsTag = tag; + } + else + { + storeFlagsTag = tag; + } + } + CHECK(storeFlagsTag != -1); + CHECK(resetFlagsTag != -1); + + // Verify the reset flags happens after the store flags + auto order + = kgraph.control.compareNodes(rocRoller::UpdateCache, storeFlagsTag, resetFlagsTag); + CHECK(order == NodeOrdering::LeftFirst); + } +} diff --git a/shared/rocroller/test/catch/ScratchOperationTest.cpp b/shared/rocroller/test/catch/ScratchOperationTest.cpp index 621b0877d92..a793989dd58 100644 --- a/shared/rocroller/test/catch/ScratchOperationTest.cpp +++ b/shared/rocroller/test/catch/ScratchOperationTest.cpp @@ -24,6 +24,9 @@ * *******************************************************************************/ +#include +#include +#include #include #include #include @@ -32,6 +35,9 @@ #include #include +#include "TestContext.hpp" + +#include #include using namespace rocRoller; @@ -226,4 +232,99 @@ namespace ScratchOperationTest CHECK(scratchOpTag2.uninitialized() == false); } } + + TEST_CASE("Scratch allocator per policy", "[scratch][context]") + { + auto context = TestContext::ForDefaultTarget(); + + SECTION("Context initializes all policies to zero") + { + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto amount = context->getScratchAmount(policy); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 0); + } + } + + SECTION("allocateScratch returns offset before allocation and accumulates size") + { + auto size1 = rocRoller::Expression::literal(100u); + auto size2 = rocRoller::Expression::literal(200u); + + auto offset1 = context->allocateScratch(ScratchPolicy::None, size1); + auto offset2 = context->allocateScratch(ScratchPolicy::None, size2); + + // First allocation should return offset 0 + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset1)) == 0); + // Second allocation should return offset 100 (after first allocation) + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset2)) == 100); + + // Total should now be 300 + auto amount = context->getScratchAmount(ScratchPolicy::None); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 300); + } + + SECTION("allocateScratch returns correct offset per policy") + { + auto size = rocRoller::Expression::literal(500u); + auto offset = context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, size); + + // First allocation should return offset 0 + CHECK(rocRoller::getUnsignedInt(rocRoller::Expression::evaluate(offset)) == 0); + + // Query total should now be 500 + auto amount = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + auto value = rocRoller::Expression::evaluate(amount); + CHECK(rocRoller::getUnsignedInt(value) == 500); + } + + SECTION("Different policies have independent allocations") + { + auto sizeNone = rocRoller::Expression::literal(100u); + auto sizeZeroed = rocRoller::Expression::literal(200u); + + context->allocateScratch(ScratchPolicy::None, sizeNone); + context->allocateScratch(ScratchPolicy::ZeroedBeforeAndAfter, sizeZeroed); + + auto amountNone = context->getScratchAmount(ScratchPolicy::None); + auto amountZeroed = context->getScratchAmount(ScratchPolicy::ZeroedBeforeAndAfter); + + auto valueNone = rocRoller::Expression::evaluate(amountNone); + auto valueZeroed = rocRoller::Expression::evaluate(amountZeroed); + + CHECK(rocRoller::getUnsignedInt(valueNone) == 100); + CHECK(rocRoller::getUnsignedInt(valueZeroed) == 200); + } + } + + TEST_CASE("Utilities to get scratch policy name", "[scratch][utils]") + { + SECTION("Returns correct name for None policy") + { + auto name = rocRoller::getScratchName(ScratchPolicy::None); + CHECK(name == "SCRATCH_None"); + } + + SECTION("Returns correct name for ZeroedBeforeAndAfter policy") + { + auto name = rocRoller::getScratchName(ScratchPolicy::ZeroedBeforeAndAfter); + CHECK(name == "SCRATCH_ZeroedBeforeAndAfter"); + } + + SECTION("All policies have unique names") + { + std::set names; + for(size_t i = 0; i < static_cast(ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + auto name = rocRoller::getScratchName(policy); + CHECK(names.find(name) == names.end()); + names.insert(name); + } + CHECK(names.size() == static_cast(ScratchPolicy::Count)); + } + } } diff --git a/shared/rocroller/test/common/include/common/CommonGraphs.hpp b/shared/rocroller/test/common/include/common/CommonGraphs.hpp index ddf65ea4e28..7d3b1560f76 100644 --- a/shared/rocroller/test/common/include/common/CommonGraphs.hpp +++ b/shared/rocroller/test/common/include/common/CommonGraphs.hpp @@ -30,6 +30,7 @@ #pragma once +#include #include #include @@ -39,6 +40,7 @@ #include #include #include +#include #include @@ -229,6 +231,9 @@ namespace rocRollerTest rocRoller::Operations::OperationTag m_tagA, m_tagB, m_tagC, m_tagD; rocRoller::Operations::OperationTag m_tagNumWGs; + std::map + m_scratchTags; + CommandPtr m_command; }; diff --git a/shared/rocroller/test/common/src/CommonGraphs.cpp b/shared/rocroller/test/common/src/CommonGraphs.cpp index b98ff13a1b6..0319ae1cad0 100644 --- a/shared/rocroller/test/common/src/CommonGraphs.cpp +++ b/shared/rocroller/test/common/src/CommonGraphs.cpp @@ -28,6 +28,7 @@ #include #include +#include #include namespace rocRollerTest::Graphs @@ -285,12 +286,24 @@ namespace rocRollerTest::Graphs rocRoller::NUMWGS); } - auto tagScratch = m_command->allocateTag(); + m_scratchTags[Operations::ScratchPolicy::None] = m_command->allocateTag(); + m_command->addOperation(rocRoller::Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::None], Operations::ScratchPolicy::None)); m_command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, + m_scratchTags[Operations::ScratchPolicy::None], ArgumentType::Value, DataDirection::ReadWrite, - rocRoller::SCRATCH); + getScratchName(Operations::ScratchPolicy::None)); + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] = m_command->allocateTag(); + m_command->addOperation(rocRoller::Operations::Scratch( + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + m_command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); } CommandPtr GEMM::getCommand() diff --git a/shared/rocroller/test/unit/GEMMFusion.cpp b/shared/rocroller/test/unit/GEMMFusion.cpp index 552b80e3865..e46dd16a40a 100644 --- a/shared/rocroller/test/unit/GEMMFusion.cpp +++ b/shared/rocroller/test/unit/GEMMFusion.cpp @@ -29,6 +29,7 @@ #include #endif /* ROCROLLER_USE_HIP */ +#include #include #include @@ -38,6 +39,7 @@ #include #include #include +#include #include #include #include @@ -221,12 +223,19 @@ namespace GEMMDriverTest rocRoller::Operations::Tensor(2, dataType, oneStridesN)); // E command->addOperation(rocRoller::Operations::T_Store_Tiled(tagRelu, tagTensorRelu)); - auto tagScratch = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); + Operations::OperationTag tagScratch[static_cast(Operations::ScratchPolicy::Count)]; + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + tagScratch[i] = command->allocateTag(); + command->addOperation(rocRoller::Operations::Scratch(tagScratch[i], policy)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + tagScratch[i], + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(policy)); + } auto params = std::make_shared(); params->setManualKernelDimension(2); @@ -329,10 +338,20 @@ namespace GEMMDriverTest { commandArgs.setArgument(command->getNextTag(), ArgumentType::Value, gemm.numWGs); } - auto scratchSpaceRequired - = commandKernel.scratchSpaceRequired(commandArgs.runtimeArguments()); - auto deviceScratch = make_shared_device(scratchSpaceRequired, 0); - commandArgs.setArgument(tagScratch, ArgumentType::Value, deviceScratch.get()); + std::shared_ptr + deviceScratch[static_cast(Operations::ScratchPolicy::Count)]; + size_t scratchSpaceRequired[static_cast(Operations::ScratchPolicy::Count)]; + for(size_t i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + scratchSpaceRequired[i] = commandKernel.scratchSpaceRequired( + static_cast(i), commandArgs.runtimeArguments()); + if(scratchSpaceRequired[i] > 0) + { + deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); + commandArgs.setArgument( + command->getNextTag(), ArgumentType::Value, deviceScratch[i].get()); + } + } // Host result std::vector h_result(M * N, 0.0); @@ -365,8 +384,14 @@ namespace GEMMDriverTest for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(T)), HasHipSuccess(0)); - ASSERT_THAT(hipMemset(deviceScratch.get(), 0, scratchSpaceRequired), - HasHipSuccess(0)); + for(size_t i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + if(scratchSpaceRequired[i] > 0) + { + ASSERT_THAT(hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), + HasHipSuccess(0)); + } + } commandKernel.launchKernel(commandArgs.runtimeArguments()); m_context = commandKernel.getContext(); @@ -376,6 +401,23 @@ namespace GEMMDriverTest d_result.data(), deviceD.get(), M * N * sizeof(T), hipMemcpyDeviceToHost), HasHipSuccess(0)); + // Verify ZeroedBeforeAndAfter scratch is all zeros after kernel + auto zeroedIdx + = static_cast(Operations::ScratchPolicy::ZeroedBeforeAndAfter); + if(scratchSpaceRequired[zeroedIdx] > 0) + { + std::vector zeroedResult(scratchSpaceRequired[zeroedIdx]); + ASSERT_THAT(hipMemcpy(zeroedResult.data(), + deviceScratch[zeroedIdx].get(), + scratchSpaceRequired[zeroedIdx], + hipMemcpyDeviceToHost), + HasHipSuccess(0)); + EXPECT_TRUE(std::all_of( + zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; })) + << "ZeroedBeforeAndAfter scratch should be all zeros after kernel " + "execution"; + } + auto tol = gemmAcceptableError( M, N, K, m_context->targetArchitecture().target()); auto res = compare(d_result, h_result, tol); diff --git a/shared/rocroller/test/unit/GEMMStreamKTest.cpp b/shared/rocroller/test/unit/GEMMStreamKTest.cpp index 2eb5d92f8d6..b6bf3c9b911 100644 --- a/shared/rocroller/test/unit/GEMMStreamKTest.cpp +++ b/shared/rocroller/test/unit/GEMMStreamKTest.cpp @@ -36,8 +36,12 @@ namespace GEMMTests using namespace rocRoller; namespace SolutionParams = rocRoller::Parameters::Solution; + // ProblemConfig: (dataTypeAB, macM, macN, macK, m, n, k, numWGs) + using ProblemConfig = std::tuple; + class StreamKMultipleFixupsTestGPU - : public BaseGEMMContextFixture> @@ -67,46 +71,55 @@ namespace GEMMTests { }; - TEST_P(StreamKMultipleFixupsTestGPU, GPU_BasicGEMMFP16) + TEST_P(StreamKMultipleFixupsTestGPU, GPU_BasicGEMM) { if(m_context->targetArchitecture().target().isCDNA1GPU()) { - GTEST_SKIP() << "Skipping GPU_BasicGEMMFP16 test: CDNA1 not supported"; + GTEST_SKIP() << "Skipping GPU_BasicGEMM test: CDNA1 not supported"; } - GEMMProblem gemm; + auto [problemConfig, mode, loadPathA, loadPathB, storeLDSD] = std::get<1>(GetParam()); + auto [dataTypeAB, macM, macN, macK, m, n, k, numWGs] = problemConfig; - hipDeviceProp_t deviceProperties; - ASSERT_THAT(hipGetDeviceProperties(&deviceProperties, 0), HasHipSuccess(0)); + GEMMProblem gemm; - gemm.macM = 128; - gemm.macN = 128; - gemm.macK = 16; + gemm.macM = macM; + gemm.macN = macN; + gemm.macK = macK; + gemm.m = m; + gemm.n = n; + gemm.k = k; + gemm.numWGs = numWGs; - gemm.waveK = 8; + if(dataTypeAB == DataType::Half) + { + gemm.waveK = 8; + } gemm.workgroupSizeX = 128; gemm.workgroupSizeY = 2; - gemm.numWGs = 128; - - auto numTilesM = 1; - auto numTilesN = 2; - auto numTilesK = 249; - - gemm.m = numTilesM * gemm.macM; - gemm.n = numTilesN * gemm.macN; - gemm.k = numTilesK * gemm.macK; - // assert that the number of output tiles is smaller than number of WGs // which means there is not enough data-parallel tiles, and has to split // K dimension into multiple tiles ASSERT_GE(gemm.numWGs, gemm.m * gemm.n / gemm.macM / gemm.macN); - std::tie(gemm.streamK, gemm.loadPathA, gemm.loadPathB, gemm.storeLDSD) - = std::get<1>(GetParam()); + gemm.streamK = mode; + gemm.loadPathA = loadPathA; + gemm.loadPathB = loadPathB; + gemm.storeLDSD = storeLDSD; - basicGEMM(gemm); + switch(dataTypeAB) + { + case DataType::Half: + basicGEMM(gemm, false, false, 100); + break; + case DataType::Float: + basicGEMM(gemm, false, false, 100); + break; + default: + Throw(fmt::format("Unexpected data type: {}. ", toString(dataTypeAB))); + } } TEST_P(StreamKWGMTestGPU, GPU_BasicGEMMStreamKWorkgroupMapping) @@ -220,6 +233,17 @@ namespace GEMMTests ::testing::Combine( currentGPUISA(), ::testing::Combine( + ::testing::Values( + // ProblemConfig: (dataTypeAB, macM, macN, macK, m, n, k, numWGs) + ProblemConfig{rocRoller::DataType::Half, 128, 128, 16, 128, 256, 15936, 128}, + ProblemConfig{rocRoller::DataType::Float, + 64, + 64, + 64, + 256, + 256, + 16384, + 256}), /* problemConfig */ ::testing::Values( StreamKMode::Standard, StreamKMode::TwoTile, StreamKMode::TwoTileDPFirst), ::testing::Values(SolutionParams::LoadPath::BufferToLDSViaVGPR, diff --git a/shared/rocroller/test/unit/GEMMTestBase.hpp b/shared/rocroller/test/unit/GEMMTestBase.hpp index 4f30b732f8b..d30c9bcca47 100644 --- a/shared/rocroller/test/unit/GEMMTestBase.hpp +++ b/shared/rocroller/test/unit/GEMMTestBase.hpp @@ -434,14 +434,8 @@ namespace GEMMTests command->addOperation(rocRoller::Operations::T_Store_Tiled(tagCvt, tagTensorD)); } - auto tagScratch = command->allocateTag(); - command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal), - tagScratch, - ArgumentType::Value, - DataDirection::ReadWrite, - rocRoller::SCRATCH); - - Operations::OperationTag tagNumWGs; + std::map scratchTags; + Operations::OperationTag tagNumWGs; if(gemm.streamK) { tagNumWGs = command->allocateTag(); @@ -450,6 +444,29 @@ namespace GEMMTests ArgumentType::Value, DataDirection::ReadOnly, rocRoller::NUMWGS); + + scratchTags[Operations::ScratchPolicy::None] = command->allocateTag(); + command->addOperation( + rocRoller::Operations::Scratch(scratchTags.at(Operations::ScratchPolicy::None), + Operations::ScratchPolicy::None)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + scratchTags.at(Operations::ScratchPolicy::None), + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::None)); + + scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter] + = command->allocateTag(); + command->addOperation(rocRoller::Operations::Scratch( + scratchTags.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + Operations::ScratchPolicy::ZeroedBeforeAndAfter)); + command->allocateArgument( + VariableType(DataType::UInt32, PointerType::PointerGlobal), + scratchTags.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter), + ArgumentType::Value, + DataDirection::ReadWrite, + getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter)); } Operations::OperationTag tagWGM; @@ -653,16 +670,28 @@ namespace GEMMTests commandArgs.setArgument(tagScalarSeed, ArgumentType::Value, srCvtSeed.value()); // Create scratch space + size_t scratchSpaceRequired[static_cast(Operations::ScratchPolicy::Count)]; + std::shared_ptr + deviceScratch[static_cast(Operations::ScratchPolicy::Count)]; + std::fill(std::begin(scratchSpaceRequired), std::end(scratchSpaceRequired), 0); + std::fill(std::begin(deviceScratch), std::end(deviceScratch), nullptr); if(gemm.streamK) { commandArgs.setArgument(tagNumWGs, ArgumentType::Value, gemm.numWGs); + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + auto policy = static_cast(i); + scratchSpaceRequired[i] = commandKernel.scratchSpaceRequired( + policy, commandArgs.runtimeArguments()); + if(scratchSpaceRequired[i] > 0) + { + deviceScratch[i] = make_shared_device(scratchSpaceRequired[i], 0); + commandArgs.setArgument( + scratchTags.at(policy), ArgumentType::Value, deviceScratch[i].get()); + } + } } - auto scratchSpaceRequired - = commandKernel.scratchSpaceRequired(commandArgs.runtimeArguments()); - auto deviceScratch = make_shared_device(scratchSpaceRequired, 0); - commandArgs.setArgument(tagScratch, ArgumentType::Value, deviceScratch.get()); - if(gemm.workgroupMappingDim != -1) { commandArgs.setArgument(tagWGM, ArgumentType::Value, gemm.workgroupMappingValue); @@ -774,8 +803,18 @@ namespace GEMMTests for(int iteration = 0; iteration < numIters; ++iteration) { ASSERT_THAT(hipMemset(deviceD.get(), 0, M * N * sizeof(TD)), HasHipSuccess(0)); - ASSERT_THAT(hipMemset(deviceScratch.get(), 0, scratchSpaceRequired), - HasHipSuccess(0)); + if(iteration == 0) + { + for(int i = 0; i < static_cast(Operations::ScratchPolicy::Count); ++i) + { + if(scratchSpaceRequired[i] > 0) + { + ASSERT_THAT( + hipMemset(deviceScratch[i].get(), 0, scratchSpaceRequired[i]), + HasHipSuccess(0)); + } + } + } commandKernel.launchKernel(commandArgs.runtimeArguments()); @@ -792,6 +831,36 @@ namespace GEMMTests res.acceptableError.relativeL2Tolerance, iteration); + // Verify ZeroedBeforeAndAfter scratch is all zeros after kernel execution + auto zeroedIdx + = static_cast(Operations::ScratchPolicy::ZeroedBeforeAndAfter); + if(scratchSpaceRequired[zeroedIdx] > 0) + { + std::vector zeroedResult(scratchSpaceRequired[zeroedIdx]); + ASSERT_THAT(hipMemcpy(zeroedResult.data(), + deviceScratch[zeroedIdx].get(), + scratchSpaceRequired[zeroedIdx], + hipMemcpyDeviceToHost), + HasHipSuccess(0)); + + bool allZeros = true; + for(size_t i = 0; i < zeroedResult.size(); ++i) + { + if(zeroedResult[i] != 0) + { + allZeros = false; + // Print as uint32 since flags are UInt32 + size_t flagIndex = i / sizeof(uint32_t); + std::cerr << "Non-zero at byte " << i << " (flag index " << flagIndex + << "): " << static_cast(zeroedResult[i]) << std::endl; + } + } + EXPECT_TRUE(allZeros) + << "ZeroedBeforeAndAfter scratch should be all zeros after kernel " + "execution (size=" + << scratchSpaceRequired[zeroedIdx] << " bytes)"; + } + if(debuggable && !res.ok) { for(size_t i = 0; i < M; i++)