Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3ee9b9c
Change coordinate and control graph to reset flags
Dec 1, 2025
e5fc05d
Formatting
Dec 1, 2025
c096700
Remove wait zero after reset
Dec 2, 2025
5ce1959
Map ScratchPolicy to m_scratchAllocator
Dec 4, 2025
4a6c692
Enable multiple scratch allocation
Dec 4, 2025
d92cf90
Resolve merge conflicts
Dec 4, 2025
261b980
Add unit tests
Dec 5, 2025
002c05b
Formatting, add comments
Dec 5, 2025
c99bb7a
Minor changes
Dec 5, 2025
4575181
Add streamk_fp4 suite
Dec 6, 2025
33e427d
Fix for-loop range
Dec 6, 2025
3694e65
Fix for-loop range
Dec 8, 2025
2d7d21a
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 8, 2025
fd85c81
Remove streamk suite, will add that separately
Dec 8, 2025
a2979c0
Update test, allocateScratch() returns size
Dec 10, 2025
f8ea067
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 10, 2025
c43dd23
Formatting
Dec 10, 2025
ee18ee9
Merge branch 'users/liu-yiqian/reset-flag' of github.com:ROCm/rocm-li…
Dec 10, 2025
6ea17d9
Only wave0_workitem0 writes to flag
Dec 11, 2025
c19d7f0
Print info when flag is not reset to 0, remove unnecessary WaitZero in…
Dec 11, 2025
3b9ec64
Enable rocroller streamk in hipblaslt
Dec 11, 2025
96262ce
Update SolutionIndexParameters
Dec 11, 2025
fa87f64
Use prob.Synchronizer for flags
Dec 11, 2025
f85411b
Enable streamk
Dec 11, 2025
7aa9059
Disable workgroupMapping when streamK is on
Dec 11, 2025
a769025
Not use streamk for f6 kernels
Dec 12, 2025
6cb5d65
Resolve conflicts
Dec 12, 2025
90ff6bd
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ std::string genKernelName(std::shared_ptr<SolutionParameters> gemm)
if(gemm->kernelType.scaleTypeB.mode != Operations::ScaleMode::None)
rv << "SB_" << genScaleModeString(gemm->kernelType.scaleTypeB.mode) << "_";

if(gemm->streamK)
{
rv << "SK_";
}
rv << "WGT_";
rocRoller::streamJoin(
rv, std::vector{gemm->workgroupTile.m, gemm->workgroupTile.n, gemm->workgroupTile.k}, "x");
Expand Down Expand Up @@ -170,7 +174,8 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
ShowValue(gemm->kernelType.scaleTypeB.mode));

std::optional<Operations::OperationTag> tagTensorScaleA, tagLoadScaleA, tagBlockScaleA,
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM;
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagSKGrid, tagWGM;
std::map<Operations::ScratchPolicy, Operations::OperationTag> tagScratch;

if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
Expand Down Expand Up @@ -265,12 +270,19 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
DataDirection::ReadOnly,
rocRoller::NUMWGS);

tagScratch = command->allocateTag();
command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal),
*tagScratch,
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
// Create Scratch operations for each ScratchPolicy
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto policy = static_cast<Operations::ScratchPolicy>(i);
auto tag = command->allocateTag();
command->addOperation(Operations::Scratch(tag, policy));
command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal),
tag,
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::getScratchName(policy));
tagScratch[policy] = tag;
}
}

if(gemm->workgroupMappingDim != -1)
Expand Down Expand Up @@ -489,8 +501,7 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
if(tagTensorScaleB)
gemmKernel->tagTensorScaleB = *tagTensorScaleB;

if(tagScratch)
gemmKernel->tagScratch = *tagScratch;
gemmKernel->tagScratch = tagScratch;

if(tagSKGrid)
gemmKernel->tagSKGrid = *tagSKGrid;
Expand Down Expand Up @@ -523,7 +534,9 @@ size_t workspaceRequired(std::shared_ptr<GemmKernel> gemm, const RocblasltContra

auto runtimeArgs = commandArgs.runtimeArguments();

return gemm->commandKernel->scratchSpaceRequired(runtimeArgs);
// Only return scratch space for ScratchPolicy::None (uses prob.workspace)
return gemm->commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::None, runtimeArgs);
}

CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
Expand Down Expand Up @@ -622,14 +635,34 @@ rocblaslt_status runGemmKernel(std::shared_ptr<GemmKernel> gemm,
}
return rocblaslt_status_invalid_value;
}

auto commandArgs = createCommandArguments(gemm, prob, DEFAULT_WGM);

// Add scratch space
if(workSpaceRequired > 0)

if(gemm->params->streamK)
{
commandArgs.setArgument(
gemm->tagScratch, ArgumentType::Value, static_cast<unsigned char*>(prob.workspace));
auto runtimeArgs = commandArgs.runtimeArguments();

// Use prob.workspace for ScratchPolicy::None
auto noneScratchSize = gemm->commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::None, runtimeArgs);
if(noneScratchSize > 0 && prob.workspace != nullptr)
{
commandArgs.setArgument(
gemm->tagScratch.at(Operations::ScratchPolicy::None),
ArgumentType::Value,
static_cast<unsigned char*>(prob.workspace));
}

// Use prob.Synchronizer for ScratchPolicy::ZeroedBeforeAndAfter
auto zeroedScratchSize = gemm->commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs);
if(zeroedScratchSize > 0 && prob.Synchronizer != nullptr)
{
commandArgs.setArgument(
gemm->tagScratch.at(Operations::ScratchPolicy::ZeroedBeforeAndAfter),
ArgumentType::Value,
static_cast<unsigned char*>(prob.Synchronizer));
}
}

auto runtimeArgs = commandArgs.runtimeArguments();
Expand All @@ -640,5 +673,6 @@ rocblaslt_status runGemmKernel(std::shared_ptr<GemmKernel> gemm,
}

gemm->commandKernel->launchKernel(runtimeArgs, prob.stream);

return rocblaslt_status_success;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
#include <rocRoller/CommandSolution.hpp>
#include <rocRoller/Expression.hpp>
#include <rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>
#include <rocRoller/TensorDescriptor.hpp>

#include <map>

/**
* @brief GemmKernel
*
Expand All @@ -57,7 +60,7 @@ struct GemmKernel
rocRoller::Operations::OperationTag tagTensorScaleA;
rocRoller::Operations::OperationTag tagTensorScaleB;

rocRoller::Operations::OperationTag tagScratch;
std::map<rocRoller::Operations::ScratchPolicy, rocRoller::Operations::OperationTag> tagScratch;
rocRoller::Operations::OperationTag tagSKGrid;
rocRoller::Operations::OperationTag tagWGM;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct SolutionIndexParameters
{
WorkGroupTileSize workgroupTile;
bool workgroupMapping;
bool streamK;

auto operator<=>(const SolutionIndexParameters& other) const = default;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ std::string SolutionParameters::toString() const
result << "MachineInstruction:" << machineInstruction.m << "x" << machineInstruction.n << "x"
<< machineInstruction.k << std::endl;
result << "WorkgroupSize:" << workgroupSizeX << "x" << workgroupSizeY << std::endl;
result << "StreamK: " << streamK << std::endl;
result << "LoadA: " << loadPathA << std::endl;
result << "LoadB: " << loadPathB << std::endl;
result << "LDS Usage";
Expand Down Expand Up @@ -196,10 +197,12 @@ std::shared_ptr<SolutionParameters>
gemm->workgroupRemapXCC = true;
}

// TODO: StreamK is not currently working with prefetching or workgroup mapping
// Pass StreamK flag from solution index parameters
gemm->streamK = solutionIndexParameters.streamK;

// StreamK is not currently working with workgroup mapping due to register pressure
if(gemm->streamK)
{
gemm->prefetch = false;
gemm->workgroupMappingDim = -1;
gemm->workgroupRemapXCC = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,22 @@ std::vector<SolutionIndexParameters> chooseSolutionIndexParameters(
|| !std::has_single_bit(static_cast<uint>(wgt.n))))
continue;

params.push_back({wgt, true});
params.push_back({wgt, true, false});

if (prob.k < USE_WORKGROUP_MAPPING_K_SIZE)
{
params.back().workgroupMapping = false;
}

// Enable StreamK when number of output tiles < number of CUs and not f6 data type
size_t numTilesM = prob.m / wgt.m;
size_t numTilesN = prob.n / wgt.n;
size_t numTiles = numTilesM * numTilesN * prob.batch_count;
auto isF6 = (kernelType.typeA == rocRoller::DataType::FP6 || kernelType.typeA == rocRoller::DataType::BF6 || kernelType.typeB == rocRoller::DataType::FP6 || kernelType.typeB == rocRoller::DataType::BF6);
if(numTiles < analytical_hardware.N_CU && !isF6)
{
params.back().streamK = true;
}
}
}

Expand All @@ -261,6 +271,8 @@ int parametersToIndex(const SolutionIndexParameters& params)
result |= ((params.workgroupTile.m / REQUIRED_MULTIPLE_M_N) << pos);
pos += MAX_BITS_WORKGROUPTILE_M;
result |= ((params.workgroupMapping ? 1 : 0) << pos);
pos += 1;
result |= ((params.streamK ? 1 : 0) << pos);

// Set top bit indicating it is a rocRoller index
result |= (1 << 31);
Expand Down Expand Up @@ -290,6 +302,8 @@ SolutionIndexParameters indexToParameters(int index)
= ((index >> pos) & mask(MAX_BITS_WORKGROUPTILE_M)) * REQUIRED_MULTIPLE_M_N;
pos += MAX_BITS_WORKGROUPTILE_M;
result.workgroupMapping = (index >> pos) & 1;
pos += 1;
result.streamK = (index >> pos) & 1;

return result;
}
4 changes: 3 additions & 1 deletion shared/rocroller/client/include/client/GEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "GEMMParameters.hpp"

#include <rocRoller/Operations/CommandArgument_fwd.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>

using namespace rocRoller;

Expand Down Expand Up @@ -61,7 +62,8 @@ namespace rocRoller
{
}

virtual Operations::OperationTag getScratchTag() const
virtual Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const
{
return {};
}
Expand Down
34 changes: 26 additions & 8 deletions shared/rocroller/client/include/client/StreamKGEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ namespace rocRoller
{
class StreamKGEMMSolution : public DataParallelGEMMSolution
{
Operations::OperationTag m_scratchTag, m_numWGsTag;
std::map<Operations::ScratchPolicy, Operations::OperationTag> m_scratchTags;
Operations::OperationTag m_numWGsTag;

public:
using DataParallelGEMMSolution::DataParallelGEMMSolution;

Operations::OperationTag getScratchTag() const override
Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const override
{
return m_scratchTag;
return m_scratchTags.at(scratchPolicy);
}

protected:
Expand All @@ -63,15 +65,31 @@ namespace rocRoller
DataDirection::ReadOnly,
rocRoller::NUMWGS);

m_scratchTag = command->allocateTag();
command->addOperation(rocRoller::Operations::Scratch(
m_scratchTag, Operations::ScratchPolicy::None));
// Create a scratch operation for tile data
m_scratchTags[Operations::ScratchPolicy::None] = command->allocateTag();
command->addOperation(
Operations::Scratch(m_scratchTags[Operations::ScratchPolicy::None],
Operations::ScratchPolicy::None));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTag,
m_scratchTags[Operations::ScratchPolicy::None],
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
getScratchName(Operations::ScratchPolicy::None));

// Create a scratch operation for flags
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter]
= command->allocateTag();
command->addOperation(Operations::Scratch(
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
Operations::ScratchPolicy::ZeroedBeforeAndAfter));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
ArgumentType::Value,
DataDirection::ReadWrite,
getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter));

return command;
}

Expand Down
31 changes: 27 additions & 4 deletions shared/rocroller/client/src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*******************************************************************************/

#include "rocRoller/Serialization/YAML.hpp"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <string>
Expand Down Expand Up @@ -344,14 +345,18 @@ namespace rocRoller::Client::GEMMClient
auto runtimeArgs = commandArgs.runtimeArguments();

// Note: the lifetime of deviceScratch needs to exceed kernel executions
std::shared_ptr<uint8_t> deviceScratch;
std::shared_ptr<uint8_t>
deviceScratch[static_cast<size_t>(Operations::ScratchPolicy::Count)];

for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(runtimeArgs);
auto policy = static_cast<Operations::ScratchPolicy>(i);
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(policy, runtimeArgs);
if(scratchSpaceRequired > 0)
{
deviceScratch = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
deviceScratch[i] = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
commandArgs.setArgument(
gemm->getScratchTag(), ArgumentType::Value, deviceScratch.get());
gemm->getScratchTag(policy), ArgumentType::Value, deviceScratch[i].get());
}
}

Expand Down Expand Up @@ -457,6 +462,24 @@ namespace rocRoller::Client::GEMMClient
auto [correct, rnorm] = validate<A, B, C, D>(
hostA, hostB, hostC, hostD, hostScaleA, hostScaleB, problemParams, arch);

// Verify ZeroedBeforeAndAfter scratch is all zeros after kernel
auto zeroedIdx = static_cast<size_t>(Operations::ScratchPolicy::ZeroedBeforeAndAfter);
if(deviceScratch[zeroedIdx])
{
auto zeroedSize = commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs);
std::vector<uint8_t> zeroedResult(zeroedSize);
AssertFatal(hipMemcpy(zeroedResult.data(),
deviceScratch[zeroedIdx].get(),
zeroedSize,
hipMemcpyDeviceToHost)
== (hipError_t)HIP_SUCCESS);
AssertFatal(
std::all_of(
zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; }),
"ZeroedBeforeAndAfter scratch should be all zeros after kernel execution");
}

result.checked = true;
result.correct = correct;
result.rnorm = rnorm;
Expand Down
8 changes: 6 additions & 2 deletions shared/rocroller/lib/include/rocRoller/CommandSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,17 @@ namespace rocRoller

/**
* @brief Returns the total number of bytes required for scratch space
* for the specified scratch policy.
*
* If this value is greather than 0, the user is required to allocate this
* If this value is greater than 0, the user is required to allocate this
* amount of device memory and pass it into the kernel.
*
* @param policy The scratch policy to query
* @param args The runtime arguments
* @return size_t
*/
size_t scratchSpaceRequired(RuntimeArguments const& args) const;
size_t scratchSpaceRequired(Operations::ScratchPolicy policy,
RuntimeArguments const& args) const;

/**
* @brief Returns the workgroup size
Expand Down
Loading