Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/hipblaslt/library/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ target_sources(hipblaslt
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/parameter_selection.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/rocroller_host.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/runtime_args_selection.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp"
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ target_sources(hipblaslt-rocroller
"${CMAKE_CURRENT_SOURCE_DIR}/parameter_selection.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime_args_selection.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/rocroller_host.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/solution_cache.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/solution_selection.cpp"
)
target_compile_features(hipblaslt-rocroller PRIVATE cxx_std_20)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ can contain other parameters besides the tile size. A list of `SolutionIndexPara
in sorted order, based on what Origami predicts is the best performing kernel.

If a kernel has already been generated for the specific `SolutionIndexParameters` instance that was selected,
the kernel can be found in the cache and returned.
the kernel can be found in the `SolutionCache` and returned.

Otherwise, the rest of the `SolutionParameters` need to be selected.
`SolutionParameters` contain all of the parameters required to generate a kernel. These parameters
are selected based on the `KernelType` and the `SolutionIndexParameters`.

Once all of the `SolutionParameters` have been selected, the kernel is generated using rocRoller. The kernel
is then saved for reuse in the cache and returned.
is then saved for reuse in the `SolutionCache` and returned.

# Calling a rocRoller Kernel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ std::string genKernelName(std::shared_ptr<SolutionParameters> gemm)
gemm->kernelType.typeAcc})
rv << toString(t) << "_";

if(gemm->kernelType.scaleAMode != Operations::ScaleMode::None)
rv << "SA_" << genScaleModeString(gemm->kernelType.scaleAMode) << "_";
if(gemm->kernelType.scaleBMode != Operations::ScaleMode::None)
rv << "SB_" << genScaleModeString(gemm->kernelType.scaleBMode) << "_";
if(gemm->kernelType.scaleTypeA.mode != Operations::ScaleMode::None)
rv << "SA_" << genScaleModeString(gemm->kernelType.scaleTypeA.mode) << "_";
if(gemm->kernelType.scaleTypeB.mode != Operations::ScaleMode::None)
rv << "SB_" << genScaleModeString(gemm->kernelType.scaleTypeB.mode) << "_";

rv << "WGT_";
rocRoller::streamJoin(
Expand Down Expand Up @@ -158,25 +158,25 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
auto mulInputA = tagLoadA;
auto mulInputB = tagLoadB;

AssertFatal(gemm->kernelType.scaleAMode == Operations::ScaleMode::None
|| gemm->kernelType.scaleAMode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate,
AssertFatal(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::None
|| gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate,
"Scale mode not supported!",
ShowValue(gemm->kernelType.scaleAMode));
AssertFatal(gemm->kernelType.scaleBMode == Operations::ScaleMode::None
|| gemm->kernelType.scaleBMode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate,
ShowValue(gemm->kernelType.scaleTypeA.mode));
AssertFatal(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::None
|| gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate,
"Scale mode not supported!",
ShowValue(gemm->kernelType.scaleBMode));
ShowValue(gemm->kernelType.scaleTypeB.mode));

std::optional<Operations::OperationTag> tagTensorScaleA, tagLoadScaleA, tagBlockScaleA,
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM;

if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
tagTensorScaleA = command->addOperation(rocRoller::Operations::Tensor(
2,
gemm->kernelType.scaleTypeA,
gemm->kernelType.scaleTypeA.type,
gemm->kernelType.transA ? oneStridesT : oneStridesN));
tagLoadScaleA
= command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleA));
Expand All @@ -185,14 +185,14 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
tagLoadA,
2,
tagLoadScaleA,
{gemm->kernelType.scaleABlockColSize, gemm->kernelType.scaleABlockRowSize}));
{gemm->kernelType.scaleTypeA.blockColSize, gemm->kernelType.scaleTypeA.blockRowSize}));
}

if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
tagTensorScaleB = command->addOperation(rocRoller::Operations::Tensor(
2,
gemm->kernelType.scaleTypeA,
gemm->kernelType.scaleTypeA.type,
gemm->kernelType.transB ? oneStridesT : oneStridesN));
tagLoadScaleB
= command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleB));
Expand All @@ -201,7 +201,7 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
tagLoadB,
2,
tagLoadScaleB,
{gemm->kernelType.scaleBBlockColSize, gemm->kernelType.scaleBBlockRowSize}));
{gemm->kernelType.scaleTypeB.blockColSize, gemm->kernelType.scaleTypeB.blockRowSize}));
}

auto tagTensorC
Expand Down Expand Up @@ -330,11 +330,11 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
params->setDimensionInfo(tagLoadA, macTileA);
}

if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
// TODO: verify the division of scale block size is correct
auto const scaleBlockSize
= gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize;
= gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize;
auto macTileAScale = KernelGraph::CoordinateGraph::MacroTile(
{gemm->workgroupTile.m, gemm->workgroupTile.k / (int)scaleBlockSize},
LayoutType::MATRIX_A,
Expand All @@ -358,11 +358,11 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
params->setDimensionInfo(tagLoadB, macTileB);
}

if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
// TODO: verify the division of scale block size is correct
auto const scaleBlockSize
= gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize;
= gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize;
auto macTileBScale = KernelGraph::CoordinateGraph::MacroTile(
{gemm->workgroupTile.k / (int)scaleBlockSize, gemm->workgroupTile.n},
LayoutType::MATRIX_B,
Expand Down Expand Up @@ -547,15 +547,15 @@ CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
setCommandTensorArg(commandArgs, gemm->tagTensorA, descA, (float*)nullptr);
setCommandTensorArg(commandArgs, gemm->tagTensorB, descB, (float*)nullptr);

if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
auto const scaleBlockSize = prob.scaleABlockRowSize * prob.scaleABlockColSize;
TensorDescriptor descAScale(gemm->params->kernelType.typeA,
{size_t(M), size_t(K / scaleBlockSize)},
gemm->params->kernelType.transA ? "T" : "N");
setCommandTensorArg(commandArgs, gemm->tagTensorScaleA, descAScale, (float*)nullptr);
}
if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
auto const scaleBlockSize = prob.scaleBBlockRowSize * prob.scaleBBlockColSize;
TensorDescriptor descBScale(gemm->params->kernelType.typeB,
Expand All @@ -578,12 +578,12 @@ CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
commandArgs.setArgument(gemm->tagTensorC, ArgumentType::Value, (float*)prob.C);
commandArgs.setArgument(gemm->tagTensorD, ArgumentType::Value, (float*)prob.D);

if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
commandArgs.setArgument(gemm->tagTensorScaleA, ArgumentType::Value, (uint8_t*)prob.scaleA);
}

if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
commandArgs.setArgument(gemm->tagTensorScaleB, ArgumentType::Value, (uint8_t*)prob.scaleB);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@
#include <rocRoller/DataTypes/DataTypes.hpp>
#include <rocRoller/Operations/Command.hpp>

struct ScaleType
{
rocRoller::Operations::ScaleMode mode;
size_t blockRowSize = 32u;
size_t blockColSize = 1u;
rocRoller::DataType type = rocRoller::DataType::E8M0;

auto operator<=>(const ScaleType& other) const = default;
};

template<>
struct std::hash<ScaleType>
{
size_t operator()(const ScaleType& s) const noexcept
{
size_t modeHash = std::hash<rocRoller::Operations::ScaleMode>{}(s.mode);
size_t blockRowSizeHash = std::hash<size_t>{}(s.blockRowSize);
size_t blockColSizeHash = std::hash<size_t>{}(s.blockColSize);
size_t typeHash = std::hash<rocRoller::DataType>{}(s.type);

return modeHash ^ (blockRowSizeHash << 1) ^ (blockColSizeHash << 2) ^ (typeHash << 3);
}
};

/**
* @brief KernelType
*
Expand All @@ -48,16 +72,29 @@ struct KernelType
bool transA;
bool transB;

rocRoller::Operations::ScaleMode scaleAMode;
rocRoller::Operations::ScaleMode scaleBMode;
ScaleType scaleTypeA;
ScaleType scaleTypeB;

size_t scaleABlockRowSize = 32u;
size_t scaleABlockColSize = 1u;
size_t scaleBBlockRowSize = 1u;
size_t scaleBBlockColSize = 32u;
auto operator<=>(const KernelType& other) const = default;
};

rocRoller::DataType scaleTypeA = rocRoller::DataType::E8M0;
rocRoller::DataType scaleTypeB = rocRoller::DataType::E8M0;
template<>
struct std::hash<KernelType>
{
size_t operator()(const KernelType& k) const noexcept
{
size_t typeAHash = std::hash<rocRoller::DataType>{}(k.typeA);
size_t typeBHash = std::hash<rocRoller::DataType>{}(k.typeB);
size_t typeCHash = std::hash<rocRoller::DataType>{}(k.typeC);
size_t typeDHash = std::hash<rocRoller::DataType>{}(k.typeD);
size_t typeAccHash = std::hash<rocRoller::DataType>{}(k.typeAcc);
size_t scaleTypeAHash = std::hash<ScaleType>{}(k.scaleTypeA);
size_t scaleTypeBHash = std::hash<ScaleType>{}(k.scaleTypeB);
size_t transAHash = std::hash<bool>{}(k.transA);
size_t transBHash = std::hash<bool>{}(k.transB);

auto operator<=>(const KernelType& other) const = default;
return typeAHash ^ (typeBHash << 1) ^ (typeCHash << 2) ^
(typeDHash << 3) ^ (typeAccHash << 4) ^ (scaleTypeAHash << 5) ^
(scaleTypeBHash << 6) ^ (transAHash << 7) ^ (transBHash << 8);
}
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*! \file */
/* ************************************************************************
*
* MIT License
*
* Copyright (C) 2025 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
* ************************************************************************ */

#pragma once

#include "gemm.hpp"
#include "kernel_type.hpp"
#include "solution_selection.hpp"

class SolutionCache
{
public:
void addKernel(const KernelType& kernelType, const SolutionIndexParameters& params, std::shared_ptr<GemmKernel> kernel);
std::optional<std::shared_ptr<GemmKernel>> getKernel(const KernelType& kernelType, const SolutionIndexParameters& params);

private:
// Map of kernels that have already been generated.
// The first level of the map is indexed with a KernelType.
// The second level of the map is indexed with a hash value of a
// SolutionIndexParameters type.
// The value is a GemmKernel.
std::unordered_map<KernelType, std::unordered_map<int, std::shared_ptr<GemmKernel>>> m_generatedKernels;
};
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct WorkGroupTileSize
int m;
int n;
int k;

auto operator<=>(const WorkGroupTileSize& other) const = default;
};

/**
Expand Down Expand Up @@ -69,6 +71,8 @@ struct SolutionIndexParameters
{
WorkGroupTileSize workgroupTile;
bool workgroupMapping;

auto operator<=>(const SolutionIndexParameters& other) const = default;
};

int parametersToIndex(const SolutionIndexParameters& params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ std::shared_ptr<SolutionParameters>

// Swizzle Scale only support in certain situations
// Swizzle Scale also runs out of registers with FP8
if (kernelType.scaleAMode != rocRoller::Operations::ScaleMode::Separate ||
kernelType.scaleBMode != rocRoller::Operations::ScaleMode::Separate)
if (kernelType.scaleTypeA.mode != rocRoller::Operations::ScaleMode::Separate ||
kernelType.scaleTypeB.mode != rocRoller::Operations::ScaleMode::Separate)
{
gemm->swizzleScale = false;
gemm->prefetchScale = false;
Expand Down Expand Up @@ -161,18 +161,18 @@ std::shared_ptr<SolutionParameters>
// LDS can only be used for scaling data with certain workgroup tile sizes
auto workgroupSizeTotal = gemm->workgroupSizeX * gemm->workgroupSizeY;
auto numScaleElementsA = 0;
if(gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize != 0)
if(gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize != 0)
{
numScaleElementsA = gemm->workgroupTile.m
* (gemm->workgroupTile.k
/ (gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize));
/ (gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize));
}
auto numScaleElementsB = 0;
if(gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize != 0)
if(gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize != 0)
{
numScaleElementsB = gemm->workgroupTile.n
* (gemm->workgroupTile.k
/ (gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize));
/ (gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize));
}
if(numScaleElementsA % workgroupSizeTotal != 0)
{
Expand Down
Loading