Skip to content
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
e6b1861
Rebase Muhammad's changes with develop
NaveenElumalaiAMD Nov 17, 2025
949bab2
Address the build failure
NaveenElumalaiAMD Nov 18, 2025
eff1634
Added the negative occupancy and deterministic tests to test_origami.cpp
NaveenElumalaiAMD Nov 19, 2025
9912b98
Address Muhammad's comments
NaveenElumalaiAMD Nov 19, 2025
445f649
Changes to origami tests
NaveenElumalaiAMD Nov 19, 2025
007f871
Fix python binding build
yenong-amd Nov 20, 2025
970563d
Reset CMake variables to its default value
NaveenElumalaiAMD Nov 20, 2025
9e3406a
minor changes in ContractionSolution.cpp
NaveenElumalaiAMD Nov 20, 2025
d1abdf1
name change: rank_config to rank_configs
NaveenElumalaiAMD Nov 20, 2025
67c04da
Fix python tests.
bethune-bryant Nov 20, 2025
518112d
Fix python tests
yenong-amd Nov 21, 2025
ba8ef6f
Address CI failures
NaveenElumalaiAMD Nov 24, 2025
b0a1b20
Ranking test
yenong-amd Nov 21, 2025
774e713
Ranking tests
yenong-amd Nov 22, 2025
a89829a
Fix refactoring results
yenong-amd Nov 24, 2025
c6b140b
change WGM to int
yenong-amd Nov 24, 2025
5d85502
select_ranked_configs to rank_configs and readding topk.
neoblizz Nov 25, 2025
18456c3
fix prediction lib serialization
yenong-amd Nov 25, 2025
96afcae
updated python bindings
yenong-amd Nov 25, 2025
e23fe4b
Address Bryant and Ali comments
NaveenElumalaiAMD Nov 25, 2025
23eb930
Match ranking tests BF16, FP32, TF32
yenong-amd Nov 25, 2025
476c083
Port test_variance_issue.cpp to test_origami.cpp
NaveenElumalaiAMD Nov 26, 2025
55c24a9
Remove gtest
NaveenElumalaiAMD Nov 26, 2025
8c72297
Fix dot2 instructions
yenong-amd Nov 26, 2025
5f2ff3b
Merge branch 'develop' into users/nelumala/origami/origami-refactor
yenong-amd Nov 26, 2025
71bc494
clang-formatted
NaveenElumalaiAMD Nov 26, 2025
ceaa2c2
Adding a logging utility.
neoblizz Nov 26, 2025
f1082fc
Separate impl and header + logging + cleanup.
neoblizz Nov 26, 2025
0a838ab
New cpp files.
neoblizz Nov 26, 2025
c7fdec4
Clean-up: Remove duplicated extract functions (moved to logger).
neoblizz Nov 26, 2025
2061c9e
Use config.logger.log
neoblizz Nov 26, 2025
436ac98
Missing include.
neoblizz Nov 26, 2025
9d1a9b2
Clean-up.
neoblizz Nov 26, 2025
f79e6a2
More logging.
neoblizz Nov 26, 2025
8968a4c
Remove old debug.
neoblizz Nov 26, 2025
2f130ae
Fix rocroller's build.
neoblizz Nov 26, 2025
44154e0
Trying w/out constexpr.
neoblizz Nov 27, 2025
7bdcc4a
Fixes for WGM and dot2
minsukim-amd Nov 28, 2025
6cb2863
remove duplicate lines
yenong-amd Dec 1, 2025
fd13db3
Change WGM and WGMXCC to int, matching sizemapping
yenong-amd Dec 1, 2025
0fd4fb1
Removed heuristics
yenong-amd Dec 1, 2025
2554d3a
Merge branch 'develop' into users/nelumala/origami/origami-refactor
neoblizz Dec 2, 2025
a248f55
Fix autoWGM bug
yenong-amd Dec 2, 2025
a02d8f3
Merge branch 'develop' into users/nelumala/origami/origami-refactor
yenong-amd Dec 2, 2025
9000801
Put heuristics back
minsukim-amd Dec 2, 2025
28646a5
Separate sorting for square
yenong-amd Dec 2, 2025
6b53310
minor fix
minsukim-amd Dec 2, 2025
5f9c3f5
Fix bug in cvt computation
yenong-amd Dec 3, 2025
7775816
Read wgm from sizemapping for non-temporal A/B
yenong-amd Dec 3, 2025
803fc6f
Fix defaultWGM
minsukim-amd Dec 3, 2025
8a47ea7
remove debug print, fix for defaultwgm
minsukim-amd Dec 3, 2025
ece677c
fix for origami config hash key
minsukim-amd Dec 4, 2025
8247963
update tie-breaker for WGM, NT, occupancy
minsukim-amd Dec 4, 2025
3c968b0
Remove batch==1 condition in NT heuristics. Remove WGM, NT, occupancy…
minsukim-amd Dec 4, 2025
5ac1a36
Fix conversion from rocisa DataType to origami data_type_t
AlexBrownAMD Dec 4, 2025
d26b372
Fix mutable in config
yenong-amd Dec 5, 2025
de49995
align function names in origami_module
minsukim-amd Dec 5, 2025
72c2fa0
add initializers to problem_t and config_t
minsukim-amd Dec 5, 2025
aa0f0c5
Add default value mappings in origami_module. Fix a bug in origami test
minsukim-amd Dec 5, 2025
99dee66
set default value max_cus=0
minsukim-amd Dec 5, 2025
aa0fce7
updated copyright
minsukim-amd Dec 5, 2025
ba05c37
updated function descriptions
minsukim-amd Dec 5, 2025
75e3944
Merge branch 'develop' into users/nelumala/origami/origami-refactor
minsukim-amd Dec 5, 2025
1e6e1c9
fix a bug in origami_module
minsukim-amd Dec 5, 2025
a56152b
remove dupcliated copyright
minsukim-amd Dec 5, 2025
7147fa7
relocate shebang
minsukim-amd Dec 5, 2025
1a25067
update README.md
minsukim-amd Dec 5, 2025
56a1edb
Merge branch 'develop' into users/nelumala/origami/origami-refactor
neoblizz Dec 6, 2025
2d68ed0
Merge branch 'develop' into users/nelumala/origami/origami-refactor
minsukim-amd Dec 8, 2025
eae075f
Merge branch 'develop' into users/nelumala/origami/origami-refactor
neoblizz Dec 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

#include <rocRoller/DataTypes/DataTypes.hpp>

#include <origami/utils.hpp>
#include "origami/types.hpp"

/**
* @brief Convert rocRoller::Datatype to analytical::DataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,55 +28,52 @@
#include "gemm.hpp"
#include "runtime_args_selection.hpp"

#include <origami/streamk.hpp>

const int DEFAULT_DYNAMIC_MODE = 6;
#include "origami/streamk.hpp"

int chooseStreamKGridSize(std::shared_ptr<GemmKernel> gemm,
const RocblasltContractionProblem& prob)
{
const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0);
const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0);

const origami::grid_selection_t DEFAULT_DYNAMIC_MODE = origami::grid_selection_t::k_split_aware;

//setting max_cu's
size_t max_cus = analytical_hardware.N_CU;

size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeA).elementBits;
size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeB).elementBits;
size_t elementSizeD_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeD).elementBits;
size_t elementSizeAcc = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeAcc).elementBytes;

origami::data_type_t dataType;
if (elementSizeA_bits < elementSizeB_bits)
dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB);
else
dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA);
origami::problem_t origami_problem = {
.size = {prob.m, prob.n, prob.k},
.batch = prob.batch_count,
.a_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA),
.b_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB),
.mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? gemm->params->kernelType.typeB : gemm->params->kernelType.typeA),
};
origami::config_t origami_config = {
.mt = {
static_cast<size_t>(gemm->params->workgroupTile.m),
static_cast<size_t>(gemm->params->workgroupTile.n),
static_cast<size_t>(gemm->params->workgroupTile.k)
},
.occupancy = gemm->occupancy,
.workspace_size = prob.workspaceSize,
.workspace_size_per_elem_c = elementSizeAcc,
};

auto reduction_type = origami::streamk::select_reduction(origami_problem,
analytical_hardware,
origami_config,
DEFAULT_DYNAMIC_MODE);

auto reduction_type = origami::streamk::select_reduction(prob.m, prob.n, prob.k, prob.batch_count,
gemm->params->workgroupTile.m, gemm->params->workgroupTile.n, gemm->params->workgroupTile.k, analaytical_hardware, DEFAULT_DYNAMIC_MODE);
// Override reduction type to tree reduction for now.
// When Parallel reduction is available, this line can be removed
reduction_type = origami::streamk::reduction_type::Tree;
origami_config.reduction_strategy = reduction_type;

auto result = origami::streamk::select_grid(prob.m,
prob.n,
prob.k,
prob.batch_count,
prob.trans_a == HIPBLAS_OP_T,
prob.trans_b == HIPBLAS_OP_T,
elementSizeA_bits,
elementSizeB_bits,
elementSizeD_bits,
dataType,
prob.workspaceSize,
gemm->params->workgroupTile.m,
gemm->params->workgroupTile.n,
gemm->params->workgroupTile.k,
gemm->params->machineInstruction.m,
gemm->params->machineInstruction.n,
gemm->params->machineInstruction.k,
DEFAULT_WGM,
elementSizeAcc,
gemm->occupancy,
analaytical_hardware,
DEFAULT_DYNAMIC_MODE,
reduction_type);
auto result = origami::streamk::select_grid_size(origami_problem,
analytical_hardware,
origami_config,
DEFAULT_DYNAMIC_MODE,
max_cus);

return result;
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include "runtime_args_selection.hpp"
#include "solution_selection.hpp"

#include <origami/utils.hpp>
#include "origami/origami.hpp"

const int MAX_BITS_WORKGROUPTILE_M = 8;
const int MAX_BITS_WORKGROUPTILE_N = 8;
Expand All @@ -44,7 +44,9 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096;
* compile-time known.
*/

constexpr std::array<WorkGroupTileSize, 34> possibleTileSizes = {{
constexpr size_t possibleTileSizesCount = 34;

constexpr std::array<WorkGroupTileSize, possibleTileSizesCount> possibleTileSizes = {{
{256, 256, 128},
{256, 192, 128},
{256, 128, 128},
Expand Down Expand Up @@ -82,10 +84,10 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096;
}};

template <rocRoller::DataType typeA, rocRoller::DataType typeB>
constexpr auto generateTileList() {
std::array<origami::tile_tuple, possibleTileSizes.size()> tileList{};
auto generateTileList() {
std::array<origami::config_t, possibleTileSizesCount> tileList{};

for (size_t i = 0; i < possibleTileSizes.size(); ++i) {
for (size_t i = 0; i < possibleTileSizesCount; ++i) {
const auto& wgt = possibleTileSizes[i];
auto MI = pickMI(typeA, typeB, wgt);

Expand All @@ -96,27 +98,33 @@ constexpr auto generateTileList() {

int unroll = preferredUnrolling(typeA, typeB, wgt);

int non_temporal_a = 0;
int non_temporal_b = 0;

tileList[i] = std::make_tuple(
wgt.m, wgt.n, wgtk * unroll,
MI.m, MI.n, MI.k,
1, // occupancy
DEFAULT_WGM,
non_temporal_a,
non_temporal_b
);
origami::config_t origami_config = {
.mt = {
static_cast<size_t>(wgt.m),
static_cast<size_t>(wgt.n),
static_cast<size_t>(wgtk * unroll)
},
.mi = {
static_cast<size_t>(MI.m),
static_cast<size_t>(MI.n),
static_cast<size_t>(MI.k)
},
.occupancy = 1,
.cache_hints_a = 0,
.cache_hints_b = 0,
};

tileList[i] = origami_config;
}

return tileList;
}

using TileListGeneratorFn = std::vector<origami::tile_tuple>(*)();
using TileListGeneratorFn = std::vector<origami::config_t>(*)();

template <rocRoller::DataType A, rocRoller::DataType B>
std::vector<origami::tile_tuple> generateTileListWrapper() {
constexpr auto arr = generateTileList<A, B>();
std::vector<origami::config_t> generateTileListWrapper() {
auto arr = generateTileList<A, B>();
return {arr.begin(), arr.end()};
}

Expand Down Expand Up @@ -144,7 +152,7 @@ const std::map<std::pair<rocRoller::DataType, rocRoller::DataType>, TileListGene
INSTANTIATE_TILE_LIST_FOR(FP6)
};

std::vector<origami::tile_tuple> getTileListForKernelType(KernelType kernelType)
std::vector<origami::config_t> getTileListForKernelType(KernelType kernelType)
{
auto key = std::make_pair(kernelType.typeA, kernelType.typeB);
auto it = tileListGenerators.find(key);
Expand All @@ -170,43 +178,42 @@ std::vector<SolutionIndexParameters> chooseSolutionIndexParameters(
{
std::vector<SolutionIndexParameters> params;

std::vector<origami::tile_tuple> tile_list = getTileListForKernelType(kernelType);
std::vector<origami::config_t> origami_config_list = getTileListForKernelType(kernelType);

size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(kernelType.typeA).elementBits;
size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(kernelType.typeB).elementBits;
size_t elementSizeC_bits = rocRoller::DataTypeInfo::Get(kernelType.typeC).elementBits;

origami::data_type_t dataType;
if (elementSizeA_bits < elementSizeB_bits)
dataType = rocroller_type_to_analytical_type(kernelType.typeB);
else
dataType = rocroller_type_to_analytical_type(kernelType.typeA);

const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0);

int WGM = std::sqrt(std::floor(analaytical_hardware.N_CU / analaytical_hardware.NUM_XCD));

auto selected_tiles = origami::select_best_macro_tile_size(
prob.m,
prob.n,
prob.k,
prob.batch_count,
prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T,
prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T,
analaytical_hardware,
tile_list,
elementSizeA_bits,
elementSizeB_bits,
elementSizeC_bits,
dataType,
kernelType.scaleABlockRowSize * kernelType.scaleABlockColSize, //Handle A vs B block size.
0.8,
false,
WGM);

for(auto const& selected_tile : selected_tiles)

const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0);

origami::problem_t origami_problem = {
.size = {prob.m, prob.n, prob.k},
.batch = prob.batch_count,
.a_transpose = (prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N,
.b_transpose = (prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N,
.a_dtype = rocroller_type_to_analytical_type(kernelType.typeA),
.b_dtype = rocroller_type_to_analytical_type(kernelType.typeB),
.mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? kernelType.typeB : kernelType.typeA),
.a_mx_block_size = kernelType.scaleABlockRowSize * kernelType.scaleABlockColSize,
.b_mx_block_size = kernelType.scaleBBlockRowSize * kernelType.scaleBBlockColSize,
};

int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD));
for (auto& config : origami_config_list) {
config.workgroup_mapping = defaultWGM;
}

auto prediction_result = origami::rank_configs(
origami_problem,
analytical_hardware,
origami_config_list
);

for(auto const& result : prediction_result)
{
WorkGroupTileSize wgt{(int)std::get<1>(selected_tile), (int)std::get<2>(selected_tile), (int)std::get<3>(selected_tile)};
auto mt_m = static_cast<int>(result.config.mt.m);
auto mt_n = static_cast<int>(result.config.mt.n);
auto mt_k = static_cast<int>(result.config.mt.k);
WorkGroupTileSize wgt{mt_m, mt_n, mt_k};
int unrollAmount = preferredUnrolling(kernelType.typeA, kernelType.typeB, wgt);
wgt.k /= unrollAmount;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
#include <Tensile/Task.hpp>
#include <Tensile/Utils.hpp>

#include <origami/streamk.hpp>
#include "origami/origami.hpp"
#include "origami/streamk.hpp"

#define TENSILE_COMMON_KERNEL_ARGS_SIZE 16

Expand Down Expand Up @@ -166,7 +167,7 @@ namespace TensileLite

struct StreamKSettings
{
origami::streamk::reduction_type reduction = origami::streamk::reduction_type::Tree;
origami::reduction_t reduction = origami::reduction_t::tree;
size_t grid = 0;
};

Expand All @@ -183,7 +184,7 @@ namespace TensileLite
using Problem = ContractionProblemGemm;
using Inputs = ContractionInputs;
using GroupedInputs = ContractionGroupedInputs;
using ParamsCache = CacheMap<std::pair<int32_t, uint32_t>, Problem>;
using ParamsCache = CacheMap<std::pair<int32_t, int32_t>, Problem>;

/**
* Indicate a solution is equally or estimatedly matched.
Expand Down Expand Up @@ -218,6 +219,11 @@ namespace TensileLite
}
virtual bool isFallbackForHW(Hardware const&) const;

bool isStreamK() const
{
return sizeMapping.streamK > 0;
}

//! Estimates based on problem size, solution tile, and machine hardware
//! charz:
struct StaticPerformanceModel
Expand Down Expand Up @@ -290,8 +296,8 @@ namespace TensileLite
void calculateGrid(dim3& workGroupSize,
dim3& numWorkGroups,
ContractionSolution::Problem const& problem) const;
origami::streamk::reduction_type getSKReduction(Problem const& problem, Hardware const& hardware) const;
size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::streamk::reduction_type& reductionStrat) const;
origami::reduction_t getSKReduction(Problem const& problem, Hardware const& hardware) const;
size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::reduction_t reductionStrat) const;
size_t partialTileSize(size_t skGrid) const;

static float computeGranularity(float x);
Expand Down Expand Up @@ -566,9 +572,9 @@ namespace TensileLite
uint32_t magicNumber(int magicDivAlg, uint32_t x, uint32_t* magicShift) const;
uint32_t smallMagicNumber(uint32_t x) const;

std::pair<int32_t, uint32_t> calculateAutoWGM(Problem const& problem,
Hardware const* hardware,
uint32_t skgrid) const;
std::pair<int32_t, int32_t> calculateAutoWGM(Problem const& problem,
Hardware const* hardware,
uint32_t skgrid) const;
uint32_t calculateAutoGSU(Problem const& problem, Hardware const* hardware) const;
};

Expand Down
Loading
Loading