Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions projects/hipblaslt/library/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@ target_sources(hipblaslt
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/parameter_selection.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/rocroller_host.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/runtime_args_selection.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp"
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ target_sources(hipblaslt-rocroller
"${CMAKE_CURRENT_SOURCE_DIR}/parameter_selection.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime_args_selection.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/rocroller_host.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/solution_cache.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/solution_selection.cpp"
)
target_compile_features(hipblaslt-rocroller PRIVATE cxx_std_20)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ can contain other parameters besides the tile size. A list of `SolutionIndexPara
in sorted order, based on what Origami predicts is the best performing kernel.

If a kernel has already been generated for the specific `SolutionIndexParameters` instance that was selected,
the kernel can be found in the cache and returned.
the kernel can be found in the `SolutionCache` and returned.

Otherwise, the rest of the `SolutionParameters` need to be selected.
`SolutionParameters` contain all of the parameters required to generate a kernel. These parameters
are selected based on the `KernelType` and the `SolutionIndexParameters`.

Once all of the `SolutionParameters` have been selected, the kernel is generated using rocRoller. The kernel
is then saved for reuse in the cache and returned.
is then saved for reuse in the `SolutionCache` and returned.

# Calling a rocRoller Kernel

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ std::string genKernelName(std::shared_ptr<SolutionParameters> gemm)
gemm->kernelType.typeAcc})
rv << toString(t) << "_";

if(gemm->kernelType.scaleAMode != Operations::ScaleMode::None)
rv << "SA_" << genScaleModeString(gemm->kernelType.scaleAMode) << "_";
if(gemm->kernelType.scaleBMode != Operations::ScaleMode::None)
rv << "SB_" << genScaleModeString(gemm->kernelType.scaleBMode) << "_";
if(gemm->kernelType.scaleTypeA.mode != Operations::ScaleMode::None)
rv << "SA_" << genScaleModeString(gemm->kernelType.scaleTypeA.mode) << "_";
if(gemm->kernelType.scaleTypeB.mode != Operations::ScaleMode::None)
rv << "SB_" << genScaleModeString(gemm->kernelType.scaleTypeB.mode) << "_";

rv << "WGT_";
rocRoller::streamJoin(
Expand Down Expand Up @@ -158,25 +158,25 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
auto mulInputA = tagLoadA;
auto mulInputB = tagLoadB;

AssertFatal(gemm->kernelType.scaleAMode == Operations::ScaleMode::None
|| gemm->kernelType.scaleAMode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate,
AssertFatal(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::None
|| gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate,
"Scale mode not supported!",
ShowValue(gemm->kernelType.scaleAMode));
AssertFatal(gemm->kernelType.scaleBMode == Operations::ScaleMode::None
|| gemm->kernelType.scaleBMode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate,
ShowValue(gemm->kernelType.scaleTypeA.mode));
AssertFatal(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::None
|| gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::SingleScale
|| gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate,
"Scale mode not supported!",
ShowValue(gemm->kernelType.scaleBMode));
ShowValue(gemm->kernelType.scaleTypeB.mode));

std::optional<Operations::OperationTag> tagTensorScaleA, tagLoadScaleA, tagBlockScaleA,
tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM;

if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
tagTensorScaleA = command->addOperation(rocRoller::Operations::Tensor(
2,
gemm->kernelType.scaleTypeA,
gemm->kernelType.scaleTypeA.type,
gemm->kernelType.transA ? oneStridesT : oneStridesN));
tagLoadScaleA
= command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleA));
Expand All @@ -185,14 +185,14 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
tagLoadA,
2,
tagLoadScaleA,
{gemm->kernelType.scaleABlockColSize, gemm->kernelType.scaleABlockRowSize}));
{gemm->kernelType.scaleTypeA.blockColSize, gemm->kernelType.scaleTypeA.blockRowSize}));
}

if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
tagTensorScaleB = command->addOperation(rocRoller::Operations::Tensor(
2,
gemm->kernelType.scaleTypeA,
gemm->kernelType.scaleTypeA.type,
gemm->kernelType.transB ? oneStridesT : oneStridesN));
tagLoadScaleB
= command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleB));
Expand All @@ -201,7 +201,7 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
tagLoadB,
2,
tagLoadScaleB,
{gemm->kernelType.scaleBBlockColSize, gemm->kernelType.scaleBBlockRowSize}));
{gemm->kernelType.scaleTypeB.blockColSize, gemm->kernelType.scaleTypeB.blockRowSize}));
}

auto tagTensorC
Expand Down Expand Up @@ -330,11 +330,11 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
params->setDimensionInfo(tagLoadA, macTileA);
}

if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
// TODO: verify the division of scale block size is correct
auto const scaleBlockSize
= gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize;
= gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize;
auto macTileAScale = KernelGraph::CoordinateGraph::MacroTile(
{gemm->workgroupTile.m, gemm->workgroupTile.k / (int)scaleBlockSize},
LayoutType::MATRIX_A,
Expand All @@ -358,11 +358,11 @@ std::shared_ptr<GemmKernel> genGemmKernel(std::shared_ptr<SolutionParameters> ge
params->setDimensionInfo(tagLoadB, macTileB);
}

if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
// TODO: verify the division of scale block size is correct
auto const scaleBlockSize
= gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize;
= gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize;
auto macTileBScale = KernelGraph::CoordinateGraph::MacroTile(
{gemm->workgroupTile.k / (int)scaleBlockSize, gemm->workgroupTile.n},
LayoutType::MATRIX_B,
Expand Down Expand Up @@ -547,15 +547,15 @@ CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
setCommandTensorArg(commandArgs, gemm->tagTensorA, descA, (float*)nullptr);
setCommandTensorArg(commandArgs, gemm->tagTensorB, descB, (float*)nullptr);

if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
auto const scaleBlockSize = prob.scaleABlockRowSize * prob.scaleABlockColSize;
TensorDescriptor descAScale(gemm->params->kernelType.typeA,
{size_t(M), size_t(K / scaleBlockSize)},
gemm->params->kernelType.transA ? "T" : "N");
setCommandTensorArg(commandArgs, gemm->tagTensorScaleA, descAScale, (float*)nullptr);
}
if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
auto const scaleBlockSize = prob.scaleBBlockRowSize * prob.scaleBBlockColSize;
TensorDescriptor descBScale(gemm->params->kernelType.typeB,
Expand All @@ -578,12 +578,12 @@ CommandArguments createCommandArguments(std::shared_ptr<GemmKernel> gemm,
commandArgs.setArgument(gemm->tagTensorC, ArgumentType::Value, (float*)prob.C);
commandArgs.setArgument(gemm->tagTensorD, ArgumentType::Value, (float*)prob.D);

if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate)
{
commandArgs.setArgument(gemm->tagTensorScaleA, ArgumentType::Value, (uint8_t*)prob.scaleA);
}

if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate)
if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate)
{
commandArgs.setArgument(gemm->tagTensorScaleB, ArgumentType::Value, (uint8_t*)prob.scaleB);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@
#include <rocRoller/DataTypes/DataTypes.hpp>
#include <rocRoller/Operations/Command.hpp>

struct ScaleType
{
rocRoller::Operations::ScaleMode mode;
size_t blockRowSize = 32u;
size_t blockColSize = 1u;
rocRoller::DataType type = rocRoller::DataType::E8M0;

auto operator<=>(const ScaleType& other) const = default;
};

template<>
struct std::hash<ScaleType>
{
size_t operator()(const ScaleType& s) const noexcept
{
size_t modeHash = std::hash<rocRoller::Operations::ScaleMode>{}(s.mode);
size_t blockRowSizeHash = std::hash<size_t>{}(s.blockRowSize);
size_t blockColSizeHash = std::hash<size_t>{}(s.blockColSize);
size_t typeHash = std::hash<rocRoller::DataType>{}(s.type);

return modeHash ^ (blockRowSizeHash << 1) ^ (blockColSizeHash << 2) ^ (typeHash << 3);
}
};

/**
* @brief KernelType
*
Expand All @@ -48,16 +72,29 @@ struct KernelType
bool transA;
bool transB;

rocRoller::Operations::ScaleMode scaleAMode;
rocRoller::Operations::ScaleMode scaleBMode;
ScaleType scaleTypeA;
ScaleType scaleTypeB;

size_t scaleABlockRowSize = 32u;
size_t scaleABlockColSize = 1u;
size_t scaleBBlockRowSize = 1u;
size_t scaleBBlockColSize = 32u;
auto operator<=>(const KernelType& other) const = default;
};

rocRoller::DataType scaleTypeA = rocRoller::DataType::E8M0;
rocRoller::DataType scaleTypeB = rocRoller::DataType::E8M0;
template<>
struct std::hash<KernelType>
{
size_t operator()(const KernelType& k) const noexcept
{
size_t typeAHash = std::hash<rocRoller::DataType>{}(k.typeA);
size_t typeBHash = std::hash<rocRoller::DataType>{}(k.typeB);
size_t typeCHash = std::hash<rocRoller::DataType>{}(k.typeC);
size_t typeDHash = std::hash<rocRoller::DataType>{}(k.typeD);
size_t typeAccHash = std::hash<rocRoller::DataType>{}(k.typeAcc);
size_t scaleTypeAHash = std::hash<ScaleType>{}(k.scaleTypeA);
size_t scaleTypeBHash = std::hash<ScaleType>{}(k.scaleTypeB);
size_t transAHash = std::hash<bool>{}(k.transB);
Comment thread
bnemanich marked this conversation as resolved.
Outdated
size_t transBHash = std::hash<bool>{}(k.transB);

auto operator<=>(const KernelType& other) const = default;
return typeAHash ^ (typeBHash << 1) ^ (typeCHash << 2) ^
(typeDHash << 3) ^ (typeAccHash << 4) ^ (scaleTypeAHash << 5) ^
(scaleTypeBHash << 6) ^ (transAHash << 7) ^ (transBHash << 8);
}
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*! \file */
/* ************************************************************************
*
* MIT License
*
* Copyright (C) 2025 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
* ************************************************************************ */

#pragma once

#include "gemm.hpp"
#include "kernel_type.hpp"
#include "solution_selection.hpp"

class SolutionCache
{
public:
void addKernel(const KernelType& kernelType, const SolutionIndexParameters& params, std::shared_ptr<GemmKernel> kernel);
std::optional<std::shared_ptr<GemmKernel>> getKernel(const KernelType& kernelType, const SolutionIndexParameters& params);

private:
// Map of kernels that have already been generated.
// The first level of the map is indexed with a KernelType.
// The second level of the map is indexed with a hash value of a
// SolutionIndexParameters type.
// The value is a GemmKernel.
std::unordered_map<KernelType, std::unordered_map<int, std::shared_ptr<GemmKernel>>> m_generatedKernels;
};
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct WorkGroupTileSize
int m;
int n;
int k;

auto operator<=>(const WorkGroupTileSize& other) const = default;
};

/**
Expand Down Expand Up @@ -69,6 +71,8 @@ struct SolutionIndexParameters
{
WorkGroupTileSize workgroupTile;
bool workgroupMapping;

auto operator<=>(const SolutionIndexParameters& other) const = default;
};

int parametersToIndex(const SolutionIndexParameters& params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ std::shared_ptr<SolutionParameters>

// Swizzle Scale only support in certain situations
// Swizzle Scale also runs out of registers with FP8
if (kernelType.scaleAMode != rocRoller::Operations::ScaleMode::Separate ||
kernelType.scaleBMode != rocRoller::Operations::ScaleMode::Separate)
if (kernelType.scaleTypeA.mode != rocRoller::Operations::ScaleMode::Separate ||
kernelType.scaleTypeB.mode != rocRoller::Operations::ScaleMode::Separate)
{
gemm->swizzleScale = false;
gemm->prefetchScale = false;
Expand Down Expand Up @@ -161,18 +161,18 @@ std::shared_ptr<SolutionParameters>
// LDS can only be used for scaling data with certain workgroup tile sizes
auto workgroupSizeTotal = gemm->workgroupSizeX * gemm->workgroupSizeY;
auto numScaleElementsA = 0;
if(gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize != 0)
if(gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize != 0)
{
numScaleElementsA = gemm->workgroupTile.m
* (gemm->workgroupTile.k
/ (gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize));
/ (gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize));
}
auto numScaleElementsB = 0;
if(gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize != 0)
if(gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize != 0)
{
numScaleElementsB = gemm->workgroupTile.n
* (gemm->workgroupTile.k
/ (gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize));
/ (gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize));
}
if(numScaleElementsA % workgroupSizeTotal != 0)
{
Expand Down
Loading
Loading