Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
3ee9b9c
Change coordinate and control graph to reset flags
Dec 1, 2025
e5fc05d
Formatting
Dec 1, 2025
c096700
Remove wait zero after reset
Dec 2, 2025
5ce1959
Map ScratchPolicy to m_scratchAllocator
Dec 4, 2025
4a6c692
Enable multiple scratch allocation
Dec 4, 2025
d92cf90
Resolve merge conflicts
Dec 4, 2025
261b980
Add unit tests
Dec 5, 2025
002c05b
Formatting, add comments
Dec 5, 2025
c99bb7a
Minor changes
Dec 5, 2025
4575181
Add streamk_fp4 suite
Dec 6, 2025
33e427d
Fix for-loop range
Dec 6, 2025
3694e65
Fix for-loop range
Dec 8, 2025
2d7d21a
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 8, 2025
fd85c81
Remove streamk suite, will add that separately
Dec 8, 2025
a2979c0
Update test, allocateScratch() returns size
Dec 10, 2025
f8ea067
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 10, 2025
c43dd23
Formatting
Dec 10, 2025
ee18ee9
Merge branch 'users/liu-yiqian/reset-flag' of github.com:ROCm/rocm-li…
Dec 10, 2025
6ea17d9
Only wave0_workitem0 writes to flag
Dec 11, 2025
c19d7f0
Print info when flag is not reset to 0, remove unnecessary WaitZero in…
Dec 11, 2025
3b9ec64
Enable rocroller streamk in hipblaslt
Dec 11, 2025
96262ce
Update SolutionIndexParameters
Dec 11, 2025
fa87f64
Use prob.Synchronizer for flags
Dec 11, 2025
f85411b
Enable streamk
Dec 11, 2025
7aa9059
Disable workgroupMapping when streamK is on
Dec 11, 2025
a769025
Not use streamk for f6 kernels
Dec 12, 2025
6cb5d65
Resolve conflicts
Dec 12, 2025
90ff6bd
Merge branch 'develop' into users/liu-yiqian/reset-flag
liu-yiqian Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
ShowValue(gemm->kernelType.scaleBMode));

std::optional<Operations::OperationTag> tagTensorScaleA, tagLoadScaleA, tagBlockScaleA,
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM;
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagSKGrid, tagWGM;
std::map<Operations::ScratchPolicy, Operations::OperationTag> tagScratch;

if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate)
{
Expand Down Expand Up @@ -265,12 +266,19 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
DataDirection::ReadOnly,
rocRoller::NUMWGS);

tagScratch = command->allocateTag();
command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal),
*tagScratch,
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
// Create Scratch operations for each ScratchPolicy
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto policy = static_cast<Operations::ScratchPolicy>(i);
auto tag = command->allocateTag();
command->addOperation(Operations::Scratch(tag, policy));
command->allocateArgument(VariableType(DataType::UInt32, PointerType::PointerGlobal),
tag,
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::getScratchName(policy));
tagScratch[policy] = tag;
}
}

if(gemm->workgroupMappingDim != -1)
Expand Down Expand Up @@ -489,8 +497,7 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
if(tagTensorScaleB)
gemmKernel->tagTensorScaleB = *tagTensorScaleB;

if(tagScratch)
gemmKernel->tagScratch = *tagScratch;
gemmKernel->tagScratch = tagScratch;

if(tagSKGrid)
gemmKernel->tagSKGrid = *tagSKGrid;
Expand Down Expand Up @@ -523,7 +530,14 @@ size_t workspaceRequired(std::shared_ptr<GemmKernel> gemm, const RocblasltContra

auto runtimeArgs = commandArgs.runtimeArguments();

return gemm->commandKernel->scratchSpaceRequired(runtimeArgs);
// Sum scratch requirements for all policies
size_t total = 0;
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto policy = static_cast<Operations::ScratchPolicy>(i);
total += gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs);
}
return total;
}

CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
Expand Down Expand Up @@ -622,23 +636,53 @@ rocblaslt_status runGemmKernel(std::shared_ptr<GemmKernel> gemm,
}
return rocblaslt_status_invalid_value;
}

auto commandArgs = createCommandArguments(gemm, prob, DEFAULT_WGM);

// Add scratch space
if(workSpaceRequired > 0)
// Track allocated scratch memory for each policy
std::array<void*, static_cast<size_t>(Operations::ScratchPolicy::Count)> scratchPtrs = {};
Comment thread
bnemanich marked this conversation as resolved.
Outdated
std::array<size_t, static_cast<size_t>(Operations::ScratchPolicy::Count)> scratchSizes = {};

if(gemm->params->streamK)
{
commandArgs.setArgument(
gemm->tagScratch, ArgumentType::Value, static_cast<unsigned char*>(prob.workspace));
auto runtimeArgs = commandArgs.runtimeArguments();
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto policy = static_cast<Operations::ScratchPolicy>(i);
scratchSizes[i] = gemm->commandKernel->scratchSpaceRequired(policy, runtimeArgs);
if(scratchSizes[i] > 0)
{
commandArgs.setArgument(
gemm->tagScratch.at(policy), ArgumentType::Value,
static_cast<unsigned char*>(scratchPtrs[i]));
}
}
}

auto runtimeArgs = commandArgs.runtimeArguments();

if(!gemm->commandKernel->matchesPredicates(runtimeArgs, LogLevel::Error))
{
// Free allocated scratch memory before returning
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
if(scratchPtrs[i] != nullptr)
{
AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i));
}
}
return rocblaslt_status_invalid_value;
}

gemm->commandKernel->launchKernel(runtimeArgs, prob.stream);

// Free allocated scratch memory after kernel completes
for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
if(scratchPtrs[i] != nullptr)
{
AssertFatal(hipFree(scratchPtrs[i]), "Failed to free scratch memory" + ShowValue(i));
}
}

return rocblaslt_status_success;
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
#include <rocRoller/CommandSolution.hpp>
#include <rocRoller/Expression.hpp>
#include <rocRoller/KernelGraph/CoordinateGraph/Dimension.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>
#include <rocRoller/TensorDescriptor.hpp>

#include <map>

/**
* @brief GemmKernel
*
Expand All @@ -57,7 +60,7 @@ struct GemmKernel
rocRoller::Operations::OperationTag tagTensorScaleA;
rocRoller::Operations::OperationTag tagTensorScaleB;

rocRoller::Operations::OperationTag tagScratch;
std::map<rocRoller::Operations::ScratchPolicy, rocRoller::Operations::OperationTag> tagScratch;
rocRoller::Operations::OperationTag tagSKGrid;
rocRoller::Operations::OperationTag tagWGM;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ struct SolutionIndexParameters
{
WorkGroupTileSize workgroupTile;
bool workgroupMapping;
bool streamK;
};

int parametersToIndex(const SolutionIndexParameters& params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,16 @@ std::shared_ptr<SolutionParameters>
gemm->workgroupRemapXCC = true;
}

// TODO: StreamK is not currently working with prefetching or workgroup mapping
if(gemm->streamK)
{
gemm->prefetch = false;
gemm->workgroupMappingDim = -1;
gemm->workgroupRemapXCC = false;
}
// Pass StreamK flag from solution index parameters
gemm->streamK = solutionIndexParameters.streamK;

// // StreamK is not currently working with prefetching or workgroup mapping
// if(gemm->streamK)
// {
// gemm->prefetch = false;
// gemm->workgroupMappingDim = -1;
// gemm->workgroupRemapXCC = false;
// }

return gemm;
}
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,21 @@ std::vector<SolutionIndexParameters> chooseSolutionIndexParameters(
|| !std::has_single_bit(static_cast<uint>(wgt.n))))
continue;

params.push_back({wgt, true});
params.push_back({wgt, true, false});

if (prob.k < USE_WORKGROUP_MAPPING_K_SIZE)
{
params.back().workgroupMapping = false;
}

// Enable StreamK when number of output tiles < number of CUs
size_t numTilesM = prob.m / wgt.m;
size_t numTilesN = prob.n / wgt.n;
size_t numTiles = numTilesM * numTilesN * prob.batch_count;
if(numTiles < analytical_hardware.N_CU)
{
params.back().streamK = true;
}
}
}

Expand All @@ -261,6 +270,8 @@ int parametersToIndex(const SolutionIndexParameters& params)
result |= ((params.workgroupTile.m / REQUIRED_MULTIPLE_M_N) << pos);
pos += MAX_BITS_WORKGROUPTILE_M;
result |= ((params.workgroupMapping ? 1 : 0) << pos);
pos += 1;
result |= ((params.streamK ? 1 : 0) << pos);

// Set top bit indicating it is a rocRoller index
result |= (1 << 31);
Expand Down Expand Up @@ -290,6 +301,8 @@ SolutionIndexParameters indexToParameters(int index)
= ((index >> pos) & mask(MAX_BITS_WORKGROUPTILE_M)) * REQUIRED_MULTIPLE_M_N;
pos += MAX_BITS_WORKGROUPTILE_M;
result.workgroupMapping = (index >> pos) & 1;
pos += 1;
result.streamK = (index >> pos) & 1;

return result;
}
4 changes: 3 additions & 1 deletion shared/rocroller/client/include/client/GEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "GEMMParameters.hpp"

#include <rocRoller/Operations/CommandArgument_fwd.hpp>
#include <rocRoller/Operations/Scratch_fwd.hpp>

using namespace rocRoller;

Expand Down Expand Up @@ -61,7 +62,8 @@ namespace rocRoller
{
}

virtual Operations::OperationTag getScratchTag() const
virtual Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const
{
return {};
}
Expand Down
34 changes: 26 additions & 8 deletions shared/rocroller/client/include/client/StreamKGEMMSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,16 @@ namespace rocRoller
{
class StreamKGEMMSolution : public DataParallelGEMMSolution
{
Operations::OperationTag m_scratchTag, m_numWGsTag;
std::map<Operations::ScratchPolicy, Operations::OperationTag> m_scratchTags;
Operations::OperationTag m_numWGsTag;

public:
using DataParallelGEMMSolution::DataParallelGEMMSolution;

Operations::OperationTag getScratchTag() const override
Operations::OperationTag
getScratchTag(Operations::ScratchPolicy scratchPolicy) const override
{
return m_scratchTag;
return m_scratchTags.at(scratchPolicy);
}

protected:
Expand All @@ -63,15 +65,31 @@ namespace rocRoller
DataDirection::ReadOnly,
rocRoller::NUMWGS);

m_scratchTag = command->allocateTag();
command->addOperation(rocRoller::Operations::Scratch(
m_scratchTag, Operations::ScratchPolicy::None));
// Create a scratch operation for tile data
m_scratchTags[Operations::ScratchPolicy::None] = command->allocateTag();
command->addOperation(
Operations::Scratch(m_scratchTags[Operations::ScratchPolicy::None],
Operations::ScratchPolicy::None));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTag,
m_scratchTags[Operations::ScratchPolicy::None],
ArgumentType::Value,
DataDirection::ReadWrite,
rocRoller::SCRATCH);
getScratchName(Operations::ScratchPolicy::None));

// Create a scratch operation for flags
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter]
= command->allocateTag();
command->addOperation(Operations::Scratch(
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
Operations::ScratchPolicy::ZeroedBeforeAndAfter));
command->allocateArgument(
VariableType(DataType::UInt32, PointerType::PointerGlobal),
m_scratchTags[Operations::ScratchPolicy::ZeroedBeforeAndAfter],
ArgumentType::Value,
DataDirection::ReadWrite,
getScratchName(Operations::ScratchPolicy::ZeroedBeforeAndAfter));

return command;
}

Expand Down
31 changes: 27 additions & 4 deletions shared/rocroller/client/src/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
*******************************************************************************/

#include "rocRoller/Serialization/YAML.hpp"
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <string>
Expand Down Expand Up @@ -344,14 +345,18 @@ namespace rocRoller::Client::GEMMClient
auto runtimeArgs = commandArgs.runtimeArguments();

// Note: the lifetime of deviceScratch needs to exceed kernel executions
std::shared_ptr<uint8_t> deviceScratch;
std::shared_ptr<uint8_t>
deviceScratch[static_cast<size_t>(Operations::ScratchPolicy::Count)];

for(int i = 0; i < static_cast<int>(Operations::ScratchPolicy::Count); ++i)
{
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(runtimeArgs);
auto policy = static_cast<Operations::ScratchPolicy>(i);
auto scratchSpaceRequired = commandKernel->scratchSpaceRequired(policy, runtimeArgs);
if(scratchSpaceRequired > 0)
{
deviceScratch = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
deviceScratch[i] = make_shared_device<uint8_t>(scratchSpaceRequired, 0);
commandArgs.setArgument(
gemm->getScratchTag(), ArgumentType::Value, deviceScratch.get());
gemm->getScratchTag(policy), ArgumentType::Value, deviceScratch[i].get());
}
}

Expand Down Expand Up @@ -457,6 +462,24 @@ namespace rocRoller::Client::GEMMClient
auto [correct, rnorm] = validate<A, B, C, D>(
hostA, hostB, hostC, hostD, hostScaleA, hostScaleB, problemParams, arch);

// Verify ZeroedBeforeAndAfter scratch is all zeros after kernel
auto zeroedIdx = static_cast<size_t>(Operations::ScratchPolicy::ZeroedBeforeAndAfter);
if(deviceScratch[zeroedIdx])
{
auto zeroedSize = commandKernel->scratchSpaceRequired(
Operations::ScratchPolicy::ZeroedBeforeAndAfter, runtimeArgs);
std::vector<uint8_t> zeroedResult(zeroedSize);
AssertFatal(hipMemcpy(zeroedResult.data(),
deviceScratch[zeroedIdx].get(),
zeroedSize,
hipMemcpyDeviceToHost)
== (hipError_t)HIP_SUCCESS);
AssertFatal(
std::all_of(
zeroedResult.begin(), zeroedResult.end(), [](uint8_t v) { return v == 0; }),
"ZeroedBeforeAndAfter scratch should be all zeros after kernel execution");
}

result.checked = true;
result.correct = correct;
result.rnorm = rnorm;
Expand Down
8 changes: 6 additions & 2 deletions shared/rocroller/lib/include/rocRoller/CommandSolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,17 @@ namespace rocRoller

/**
* @brief Returns the total number of bytes required for scratch space
* for the specified scratch policy.
*
* If this value is greather than 0, the user is required to allocate this
* If this value is greater than 0, the user is required to allocate this
* amount of device memory and pass it into the kernel.
*
* @param policy The scratch policy to query
* @param args The runtime arguments
* @return size_t
*/
size_t scratchSpaceRequired(RuntimeArguments const& args) const;
size_t scratchSpaceRequired(Operations::ScratchPolicy policy,
RuntimeArguments const& args) const;

/**
* @brief Returns the workgroup size
Expand Down
Loading
Loading