Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions shared/rocroller/client/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ target_sources(rocroller-gemm
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/GEMMSolution.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/GraphInspector.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/GraphInspector_impl.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/PreSwizzle.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/StreamKGEMMSolution.hpp"
"${CMAKE_CURRENT_SOURCE_DIR}/include/client/visualize.hpp"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,23 @@ namespace rocRoller
m_tagLoadScaleA = command->addOperation(
rocRoller::Operations::T_Load_Tiled(m_tagTensorScaleA.value()));

auto scaleInputA = m_tagLoadScaleA;

if(solutionParams.types.scaleSkipPermlane)
{
AssertFatal(solutionParams.types.scaleShuffleTileA.size() == 3,
ShowValue(solutionParams.types.scaleShuffleTileA));

scaleInputA
= command->addOperation(rocRoller::Operations::SubTileTranspose(
*m_tagLoadScaleA, solutionParams.types.scaleShuffleTileA));
}

m_tagBlockScaleA = mulInputA
= command->addOperation(rocRoller::Operations::BlockScale(
m_tagA,
2,
m_tagLoadScaleA,
scaleInputA,
{1,
static_cast<unsigned long>(solutionParams.types.scaleBlockSize)}));
}
Expand All @@ -160,11 +172,22 @@ namespace rocRoller
m_tagLoadScaleB = command->addOperation(
rocRoller::Operations::T_Load_Tiled(m_tagTensorScaleB.value()));

auto scaleInputB = m_tagLoadScaleB;

if(solutionParams.types.scaleSkipPermlane)
{
AssertFatal(solutionParams.types.scaleShuffleTileB.size() == 3);

scaleInputB
= command->addOperation(rocRoller::Operations::SubTileTranspose(
*m_tagLoadScaleB, solutionParams.types.scaleShuffleTileB));
}

m_tagBlockScaleB = mulInputB
= command->addOperation(rocRoller::Operations::BlockScale(
m_tagB,
2,
m_tagLoadScaleB,
scaleInputB,
{static_cast<unsigned long>(solutionParams.types.scaleBlockSize),
1}));
}
Expand Down
4 changes: 4 additions & 0 deletions shared/rocroller/client/include/client/GEMMParameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ namespace rocRoller

bool scaleSkipPermlane = false;

// Three entries each, in order: M/N tile size, K tile size, K subtile size.
std::vector<size_t> scaleShuffleTileA;
std::vector<size_t> scaleShuffleTileB;
Comment thread
sdquiring marked this conversation as resolved.

std::string kernelNamePart() const;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ namespace rocRoller::Serialization

iot::mapRequired(io, "scaleBlockSize", params.scaleBlockSize);
iot::mapRequired(io, "scaleSkipPermlane", params.scaleSkipPermlane);

iot::mapRequired(io, "scaleShuffleTileA", params.scaleShuffleTileA);
iot::mapRequired(io, "scaleShuffleTileB", params.scaleShuffleTileB);
}

static void mapping(IO& io, Client::GEMMClient::TypeParameters& params, EmptyContext& ctx)
Expand Down
5 changes: 5 additions & 0 deletions shared/rocroller/client/include/client/GraphInspector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ namespace rocRoller

KernelGraph::CoordinateGraph::Transformer& tx();

/**
 * @brief Returns the kernel graph handle held by this inspector (m_kgraph).
 */
KernelGraph::KernelGraphPtr graph()
{
    return m_kgraph;
}

private:
void assignLiteralSizesAndStrides();

Expand Down
87 changes: 87 additions & 0 deletions shared/rocroller/client/include/client/PreSwizzle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*******************************************************************************
*
* MIT License
*
* Copyright 2024-2025 AMD ROCm(TM) Software
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#pragma once

#include <vector>

#include <rocRoller/TensorDescriptor.hpp>

#include <client/GEMMParameters.hpp>

namespace rocRoller::Client
{

/**
 * @brief Reorders the data of a 2D tensor into a pre-swizzled layout.
 *
 * The input is viewed as a 6-dimensional tensor whose sizes are derived from
 * `desc` and `tile`:
 *
 *   {subTileK, tileK / subTileK, desc.size(0) / tileK,
 *    tileMN / subTileK, subTileK, desc.size(1) / tileMN}
 *
 * and is then handed to `shuffleDims` together with a destination descriptor
 * built from the same sizes with dimension order {4, 1, 2, 3, 0, 5}.
 *
 * @tparam T     Element type of the data.
 * @param input  Data to reorder; must contain exactly
 *               `desc.totalAllocatedElements()` elements.
 * @param desc   Descriptor of `input`; must have exactly 2 dimensions.
 *               Dimension 0 is divided by the K tile size and dimension 1 by
 *               the M/N tile size.
 * @param tile   Exactly three tile sizes, read as {M/N tile, K tile, K subtile}.
 *               Every entry must be non-zero.
 * @return       The result of `shuffleDims(input, dst, src)`.
 *
 * Violated preconditions are reported through AssertFatal.
 */
template <typename T>
inline std::vector<T> preSwizzle(std::vector<T> const&      input,
                                 TensorDescriptor const&    desc,
                                 std::vector<size_t> const& tile)
{
    AssertFatal(tile.size() == 3, ShowValue(tile.size()), ShowValue(tile));
    AssertFatal(desc.dimensions() == 2,
                "Batch dimension not yet supported.",
                ShowValue(desc.dimensions()),
                ShowValue(desc));
    AssertFatal(desc.totalAllocatedElements() == input.size(),
                ShowValue(desc),
                ShowValue(input.size()));

    auto const tileMN   = tile[0];
    auto const tileK    = tile[1];
    auto const subTileK = tile[2];

    // Every tile size is used as a divisor below; integer division by zero is
    // undefined behaviour, so reject it up front with a clear message.
    AssertFatal(tileMN != 0 && tileK != 0 && subTileK != 0,
                "Tile sizes must be non-zero.",
                ShowValue(tile));

    size_t const instPerTileK   = tileK / subTileK;
    size_t const instKPerTileMN = tileMN / subTileK;

    std::vector<size_t> srcSizes = {subTileK,
                                    instPerTileK,
                                    desc.size(0) / (tileK),
                                    instKPerTileMN,
                                    subTileK,
                                    desc.size(1) / (tileMN)};

    TensorDescriptor src(desc.dataType(), srcSizes);

    // The divisions above truncate; if any of them was inexact the element
    // count of the 6D view no longer matches the original tensor.
    AssertFatal(src.totalAllocatedElements() == desc.totalAllocatedElements(),
                ShowValue(src.totalAllocatedElements()),
                ShowValue(desc.totalAllocatedElements()),
                ShowValue(src.totalAllocatedElements() / desc.totalAllocatedElements()),
                ShowValue(src),
                ShowValue(desc));

    auto dst
        = TensorDescriptor::ShuffledNoPadding(desc.dataType(), srcSizes, {4, 1, 2, 3, 0, 5});

    AssertFatal(src.totalAllocatedElements() == dst.totalAllocatedElements(),
                ShowValue(src.totalAllocatedElements()),
                ShowValue(dst.totalAllocatedElements()),
                ShowValue(src),
                ShowValue(dst));

    return shuffleDims(input, dst, src);
}

}
4 changes: 3 additions & 1 deletion shared/rocroller/client/src/GEMMParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ namespace rocRoller
rv << "_" << t;

if(scaleSkipPermlane)
rv << "_PRE_SW";
{
rv << "_PreSW_AB";
}

return rv.str();
}
Expand Down
Loading