Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3ee9b9c
Change coordinate and control graph to reset flags
Dec 1, 2025
e5fc05d
Formatting
Dec 1, 2025
c096700
Remove wait zero after reset
Dec 2, 2025
5ce1959
Map ScratchPolicy to m_scratchAllocator
Dec 4, 2025
4a6c692
Enable multiple scratch allocation
Dec 4, 2025
d92cf90
Resolve merge conflicts
Dec 4, 2025
261b980
Add unit tests
Dec 5, 2025
002c05b
Formatting, add comments
Dec 5, 2025
c99bb7a
Minor changes
Dec 5, 2025
4575181
Add streamk_fp4 suite
Dec 6, 2025
33e427d
Fix for-loop range
Dec 6, 2025
3694e65
Fix for-loop range
Dec 8, 2025
2d7d21a
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 8, 2025
fd85c81
Remove streamk suite, will add that separately
Dec 8, 2025
a2979c0
Update test, allocateScratch() returns size
Dec 10, 2025
f8ea067
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 10, 2025
c43dd23
Formatting
Dec 10, 2025
ee18ee9
Merge branch 'users/liu-yiqian/reset-flag' of github.com:ROCm/rocm-li…
Dec 10, 2025
6ea17d9
Only wave0_workitem0 writes to flag
Dec 11, 2025
c19d7f0
Print info when flag is not reset to 0, remove unnecessary WaitZero in…
Dec 11, 2025
3b9ec64
Enable rocroller streamk in hipblaslt
Dec 11, 2025
96262ce
Update SolutionIndexParameters
Dec 11, 2025
fa87f64
Use prob.Synchronizer for flags
Dec 11, 2025
f85411b
Enable streamk
Dec 11, 2025
7aa9059
Disable workgroupMapping when streamK is on
Dec 11, 2025
a769025
Not use streamk for f6 kernels
Dec 12, 2025
6cb5d65
Resolve conflicts
Dec 12, 2025
90ff6bd
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion shared/rocroller/client/include/client/GEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "GEMMParameters.hpp"

#include <rocRoller/Operations/CommandArgument_fwd.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>

using namespace rocRoller;

Expand Down Expand Up @@ -61,7 +62,8 @@ namespace rocRoller
{
}

virtual Operations::OperationTag getScratchTag() const
virtual Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const
{
return {};
}
Expand Down
34 changes: 26 additions & 8 deletions shared/rocroller/client/include/client/StreamKGEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ namespace rocRoller
{
class StreamKGEMMSolution : public DataParallelGEMMSolution
{
Operations::OperationTag m_scratchTag, m_numWGsTag;
std::map<Operations::ScratchPolicy, Operations::OperationTag> m_scratchTags;
Operations::OperationTag m_numWGsTag;

public:
using DataParallelGEMMSolution::DataParallelGEMMSolution;

Operations::OperationTag getScratchTag() const override
Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const override
{
return m_scratchTag;
return m_scratchTags.at(scratchPolicy);
}

protected:
Expand All @@ -63,15 +65,31 @@ namespace rocRoller
DataDirection::ReadOnly,
rocRoller::NUMWGS);

m_scratchTag = command->allocateTag();
command->addOperation(rocRoller::Operations::Scratch(
m_scratchTag, Operations::ScratchPolicy::None));
// Create a scratch operation for tile data
m_scratchTags[Operations::ScratchPolicy::None] = command->allocateTag();
command->addOperation(
Operations::Scratch(m_scratchTags[Operations::ScratchPolicy::None],
Operations::ScratchPolicy::None));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTag,
m_scratchTags[Operations::ScratchPolicy::None],
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
getScratchName(Operations::ScratchPolicy::None));

// Create a scratch operation for flags
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter]
= command->allocateTag();
command->addOperation(Operations::Scratch(
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
Operations::ScratchPolicy::ZeroedBeforeAndAfter));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
ArgumentType::Value,
DataDirection::ReadWrite,
getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter));

return command;
}

Expand Down
31 changes: 27 additions & 4 deletions shared/rocroller/client/src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*******************************************************************************/

#include "rocRoller/Serialization/YAML.hpp"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <string>
Expand Down Expand Up @@ -344,14 +345,18 @@ namespace rocRoller::Client::GEMMClient
auto runtimeArgs = commandArgs.runtimeArguments();

// Note: the lifetime of deviceScratch needs to exceed kernel executions
std::shared_ptr<uint8_t> deviceScratch;
std::shared_ptr<uint8_t>
deviceScratch[static_cast<size_t>(Operations::ScratchPolicy::Count)];

for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(runtimeArgs);
auto policy = static_cast<Operations::ScratchPolicy>(i);
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(policy, runtimeArgs);
if(scratchSpaceRequired > 0)
{
deviceScratch = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
deviceScratch[i] = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
commandArgs.setArgument(
gemm->getScratchTag(), ArgumentType::Value, deviceScratch.get());
gemm->getScratchTag(policy), ArgumentType::Value, deviceScratch[i].get());
}
}

Expand Down Expand Up @@ -457,6 +462,24 @@ namespace rocRoller::Client::GEMMClient
auto [correct, rnorm] = validate<A, B, C, D>(
hostA, hostB, hostC, hostD, hostScaleA, hostScaleB, problemParams, arch);

// Verify ZeroedBeforeAndAfter scratch is all zeros after kernel
// The check reads the buffer back to the host and inspects every byte; it is
// skipped entirely when no device buffer is held for this policy.
auto zeroedIdx = static_cast<size_t>(Operations::ScratchPolicy::ZeroedBeforeAndAfter);
if(deviceScratch[zeroedIdx])
{
// Query the byte count for this policy; it sizes both the host staging
// vector and the device-to-host copy below.
auto zeroedSize = commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs);
std::vector<uint8_t> zeroedResult(zeroedSize);
// A failed copy is fatal: the zero-check cannot be trusted without the data.
AssertFatal(hipMemcpy(zeroedResult.data(),
deviceScratch[zeroedIdx].get(),
zeroedSize,
hipMemcpyDeviceToHost)
== (hipError_t)HIP_SUCCESS);
// Any non-zero byte means the buffer was not left zeroed after the kernel ran.
AssertFatal(
std::all_of(
zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; }),
"ZeroedBeforeAndAfter scratch should be all zeros after kernel execution");
}

result.checked = true;
result.correct = correct;
result.rnorm = rnorm;
Expand Down
8 changes: 6 additions & 2 deletions shared/rocroller/lib/include/rocRoller/CommandSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,17 @@ namespace rocRoller

/**
* @brief Returns the total number of bytes required for scratch space
* for the specified scratch policy.
*
* If this value is greather than 0, the user is required to allocate this
* If this value is greater than 0, the user is required to allocate this
* amount of device memory and pass it into the kernel.
*
* @param policy The scratch policy to query
* @param args The runtime arguments
* @return size_t
*/
size_t scratchSpaceRequired(RuntimeArguments const& args) const;
size_t scratchSpaceRequired(Operations::ScratchPolicy policy,
RuntimeArguments const& args) const;

/**
* @brief Returns the workgroup size
Expand Down
35 changes: 21 additions & 14 deletions shared/rocroller/lib/include/rocRoller/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include <rocRoller/KernelGraph/RegisterTagManager_fwd.hpp>
#include <rocRoller/KernelGraph/ScopeManager_fwd.hpp>
#include <rocRoller/KernelOptions.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>
#include <rocRoller/ScheduledInstructions_fwd.hpp>
#include <rocRoller/Scheduling/Scheduling_fwd.hpp>
#include <rocRoller/Utilities/Random_fwd.hpp>
Expand Down Expand Up @@ -127,18 +128,23 @@ namespace rocRoller
void setKernel(AssemblyKernelPtr);

/**
* @brief Returns an expression representing how much scratch space is required (in bytes)
* @brief Allocate scratch space for the specified scratch policy.
*
* @return Expression::ExpressionPtr
* @param policy The scratch policy to allocate for
* @param size Number of bytes requested
* @return Expression::ExpressionPtr The offset before this allocation
*/
Expression::ExpressionPtr getScratchAmount() const;
Expression::ExpressionPtr allocateScratch(Operations::ScratchPolicy policy,
Expression::ExpressionPtr size);

/**
* @brief Allocate more scratch space
* @brief Returns an expression representing how much scratch space is required (in bytes)
* for the specified scratch policy.
*
* @param size Number of bytes requested
* @param policy The scratch policy to query
* @return Expression::ExpressionPtr
*/
void allocateScratch(Expression::ExpressionPtr size);
Expression::ExpressionPtr getScratchAmount(Operations::ScratchPolicy policy) const;

/**
* @brief Get register scope manager.
Expand Down Expand Up @@ -168,14 +174,15 @@ namespace rocRoller
std::array<std::shared_ptr<Register::Allocator>, static_cast<size_t>(Register::Type::Count)>
m_allocators;

std::shared_ptr<Scheduling::IObserver> m_observer;
AssemblyKernelPtr m_kernel;
std::shared_ptr<ArgumentLoader> m_argLoader;
std::shared_ptr<ScheduledInstructions> m_instructions;
std::shared_ptr<MemoryInstructions> m_mem;
LabelAllocatorPtr m_labelAllocator;
std::shared_ptr<LDSAllocator> m_ldsAllocator;
Expression::ExpressionPtr m_scratchAllocator;
std::shared_ptr<Scheduling::IObserver> m_observer;
AssemblyKernelPtr m_kernel;
std::shared_ptr<ArgumentLoader> m_argLoader;
std::shared_ptr<ScheduledInstructions> m_instructions;
std::shared_ptr<MemoryInstructions> m_mem;
LabelAllocatorPtr m_labelAllocator;
std::shared_ptr<LDSAllocator> m_ldsAllocator;
std::array<Expression::ExpressionPtr, static_cast<size_t>(Operations::ScratchPolicy::Count)>
m_scratchSizes;
std::shared_ptr<CopyGenerator> m_copier;
std::shared_ptr<BranchGenerator> m_brancher;
std::shared_ptr<CrashKernelGenerator> m_crasher;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,11 @@ namespace rocRoller
*
* @param size How many elements make up the User dimension.
* @param offset Location of data within the scratch space
* @param argName Name of the argument for this scratch space
*/
User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset);
User(Expression::ExpressionPtr size,
Expression::ExpressionPtr offset,
std::string const& argName);

std::string name() const override;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,15 @@ namespace rocRoller
*
* @param size
* @param varType
* @param policy The scratch policy to use for allocation
* @param context
* @return User
*/
rocRoller::KernelGraph::CoordinateGraph::User newScratchCoordinate(
Expression::ExpressionPtr size, VariableType varType, ContextPtr context);
rocRoller::KernelGraph::CoordinateGraph::User
newScratchCoordinate(Expression::ExpressionPtr size,
VariableType varType,
Operations::ScratchPolicy policy,
ContextPtr context);

/**
* @brief Replace operation with a new operation.
Expand Down
7 changes: 7 additions & 0 deletions shared/rocroller/lib/include/rocRoller/KernelOptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <string>

#include <rocRoller/AssertOpKinds_fwd.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>
#include <rocRoller/Utilities/EnumBitset.hpp>
#include <rocRoller/Utilities/Settings_fwd.hpp>

Expand All @@ -49,6 +50,12 @@ namespace rocRoller
const std::string NUMWGS = "numWGs";
const std::string WGM = "WGM";

/**
 * @brief Build the scratch argument name associated with a scratch policy.
 *
 * The name is the SCRATCH base name, followed by an underscore and the
 * string form of the policy, so each policy gets its own argument name.
 *
 * @param policy The scratch policy whose argument name is requested
 * @return std::string The policy-specific scratch argument name
 */
inline std::string getScratchName(Operations::ScratchPolicy policy)
{
    std::string argName = rocRoller::SCRATCH;
    argName += "_";
    argName += Operations::toString(policy);
    return argName;
}

class KernelOptions
{
public:
Expand Down
5 changes: 3 additions & 2 deletions shared/rocroller/lib/source/CommandSolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,9 +646,10 @@ namespace rocRoller
return m_context;
}

size_t CommandKernel::scratchSpaceRequired(RuntimeArguments const& args) const
size_t CommandKernel::scratchSpaceRequired(Operations::ScratchPolicy policy,
RuntimeArguments const& args) const
{
auto amount = m_context->getScratchAmount();
auto amount = m_context->getScratchAmount(policy);

auto times = evaluationTimes(amount);
AssertFatal(times[Expression::EvaluationTime::Translate]
Expand Down
18 changes: 13 additions & 5 deletions shared/rocroller/lib/source/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@
namespace rocRoller
{
Context::Context()
: m_scratchAllocator(Expression::literal(0u))
{
// Initialize scratch sizes for each policy with zero
for(size_t i = 0; i < static_cast<size_t>(Operations::ScratchPolicy::Count); ++i)
{
m_scratchSizes[i] = Expression::literal(0u);
}
}

ContextPtr Context::ForDefaultHipDevice(std::string const& kernelName,
Expand Down Expand Up @@ -287,14 +291,18 @@ namespace rocRoller
m_kernel = assemblyKernel;
}

Expression::ExpressionPtr Context::getScratchAmount() const
Expression::ExpressionPtr Context::allocateScratch(Operations::ScratchPolicy policy,
Expression::ExpressionPtr size)
{
return m_scratchAllocator;
auto idx = static_cast<size_t>(policy);
auto currentOffset = m_scratchSizes[idx];
m_scratchSizes[idx] = simplify(m_scratchSizes[idx] + size);
return currentOffset;
}

void Context::allocateScratch(Expression::ExpressionPtr size)
Expression::ExpressionPtr Context::getScratchAmount(Operations::ScratchPolicy policy) const
{
m_scratchAllocator = simplify(m_scratchAllocator + size);
return m_scratchSizes[static_cast<size_t>(policy)];
}

void Context::scheduleCopy(Instruction const& inst)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ namespace rocRoller
return name() + stag;
}

User::User(Expression::ExpressionPtr size, Expression::ExpressionPtr offset)
User::User(Expression::ExpressionPtr size,
Expression::ExpressionPtr offset,
std::string const& argName)
: BaseDimension(size, Expression::literal(1u), offset)
, argumentName(rocRoller::SCRATCH)
, argumentName(argName)
{
}

Expand Down
Loading
Loading