diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp index 81e08a2c904..17f22105033 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp @@ -33,7 +33,7 @@ #include -#include "origami/types.hpp" +#include /** * @brief Convert rocRoller::Datatype to analytical::DataType diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp index 372c7f6b970..ca4da79952d 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp @@ -28,52 +28,55 @@ #include "gemm.hpp" #include "runtime_args_selection.hpp" -#include "origami/streamk.hpp" +#include + +const int DEFAULT_DYNAMIC_MODE = 6; int chooseStreamKGridSize(std::shared_ptr gemm, const RocblasltContractionProblem& prob) { - const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0); - - const origami::grid_selection_t DEFAULT_DYNAMIC_MODE = origami::grid_selection_t::k_split_aware; - - //setting max_cu's - size_t max_cus = analytical_hardware.N_CU; + const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0); size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeA).elementBits; size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeB).elementBits; + size_t elementSizeD_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeD).elementBits; size_t elementSizeAcc = 
rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeAcc).elementBytes; - origami::problem_t origami_problem = { - .size = {prob.m, prob.n, prob.k}, - .batch = prob.batch_count, - .a_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA), - .b_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB), - .mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? gemm->params->kernelType.typeB : gemm->params->kernelType.typeA), - }; - origami::config_t origami_config = { - .mt = { - static_cast(gemm->params->workgroupTile.m), - static_cast(gemm->params->workgroupTile.n), - static_cast(gemm->params->workgroupTile.k) - }, - .occupancy = gemm->occupancy, - .workspace_size = prob.workspaceSize, - .workspace_size_per_elem_c = elementSizeAcc, - }; - - auto reduction_type = origami::streamk::select_reduction(origami_problem, - analytical_hardware, - origami_config, - DEFAULT_DYNAMIC_MODE); + origami::data_type_t dataType; + if (elementSizeA_bits < elementSizeB_bits) + dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB); + else + dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA); - origami_config.reduction_strategy = reduction_type; + auto reduction_type = origami::streamk::select_reduction(prob.m, prob.n, prob.k, prob.batch_count, + gemm->params->workgroupTile.m, gemm->params->workgroupTile.n, gemm->params->workgroupTile.k, analaytical_hardware, DEFAULT_DYNAMIC_MODE); + // Override reduction type to tree reduction for now. 
+ // When Parallel reduction is available, this line can be removed + reduction_type = origami::streamk::reduction_type::Tree; - auto result = origami::streamk::select_grid_size(origami_problem, - analytical_hardware, - origami_config, - DEFAULT_DYNAMIC_MODE, - max_cus); + auto result = origami::streamk::select_grid(prob.m, + prob.n, + prob.k, + prob.batch_count, + prob.trans_a == HIPBLAS_OP_T, + prob.trans_b == HIPBLAS_OP_T, + elementSizeA_bits, + elementSizeB_bits, + elementSizeD_bits, + dataType, + prob.workspaceSize, + gemm->params->workgroupTile.m, + gemm->params->workgroupTile.n, + gemm->params->workgroupTile.k, + gemm->params->machineInstruction.m, + gemm->params->machineInstruction.n, + gemm->params->machineInstruction.k, + DEFAULT_WGM, + elementSizeAcc, + gemm->occupancy, + analaytical_hardware, + DEFAULT_DYNAMIC_MODE, + reduction_type); return result; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index 63144fe76ef..74dcc55c9d0 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -29,7 +29,7 @@ #include "runtime_args_selection.hpp" #include "solution_selection.hpp" -#include "origami/origami.hpp" +#include const int MAX_BITS_WORKGROUPTILE_M = 8; const int MAX_BITS_WORKGROUPTILE_N = 8; @@ -44,9 +44,7 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096; * compile-time known. 
*/ - constexpr size_t possibleTileSizesCount = 34; - - constexpr std::array possibleTileSizes = {{ + constexpr std::array possibleTileSizes = {{ {256, 256, 128}, {256, 192, 128}, {256, 128, 128}, @@ -84,10 +82,10 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096; }}; template -auto generateTileList() { - std::array tileList{}; +constexpr auto generateTileList() { + std::array tileList{}; - for (size_t i = 0; i < possibleTileSizesCount; ++i) { + for (size_t i = 0; i < possibleTileSizes.size(); ++i) { const auto& wgt = possibleTileSizes[i]; auto MI = pickMI(typeA, typeB, wgt); @@ -98,33 +96,27 @@ auto generateTileList() { int unroll = preferredUnrolling(typeA, typeB, wgt); - origami::config_t origami_config = { - .mt = { - static_cast(wgt.m), - static_cast(wgt.n), - static_cast(wgtk * unroll) - }, - .mi = { - static_cast(MI.m), - static_cast(MI.n), - static_cast(MI.k) - }, - .occupancy = 1, - .cache_hints_a = 0, - .cache_hints_b = 0, - }; - - tileList[i] = origami_config; + int non_temporal_a = 0; + int non_temporal_b = 0; + + tileList[i] = std::make_tuple( + wgt.m, wgt.n, wgtk * unroll, + MI.m, MI.n, MI.k, + 1, // occupancy + DEFAULT_WGM, + non_temporal_a, + non_temporal_b + ); } return tileList; } -using TileListGeneratorFn = std::vector(*)(); +using TileListGeneratorFn = std::vector(*)(); template -std::vector generateTileListWrapper() { - auto arr = generateTileList(); +std::vector generateTileListWrapper() { + constexpr auto arr = generateTileList(); return {arr.begin(), arr.end()}; } @@ -152,7 +144,7 @@ const std::map, TileListGene INSTANTIATE_TILE_LIST_FOR(FP6) }; -std::vector getTileListForKernelType(KernelType kernelType) +std::vector getTileListForKernelType(KernelType kernelType) { auto key = std::make_pair(kernelType.typeA, kernelType.typeB); auto it = tileListGenerators.find(key); @@ -178,42 +170,43 @@ std::vector chooseSolutionIndexParameters( { std::vector params; - std::vector origami_config_list = getTileListForKernelType(kernelType); + std::vector 
tile_list = getTileListForKernelType(kernelType); size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(kernelType.typeA).elementBits; size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(kernelType.typeB).elementBits; - - const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0); - - origami::problem_t origami_problem = { - .size = {prob.m, prob.n, prob.k}, - .batch = prob.batch_count, - .a_transpose = (prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N, - .b_transpose = (prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N, - .a_dtype = rocroller_type_to_analytical_type(kernelType.typeA), - .b_dtype = rocroller_type_to_analytical_type(kernelType.typeB), - .mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? kernelType.typeB : kernelType.typeA), - .a_mx_block_size = kernelType.scaleTypeA.blockRowSize * kernelType.scaleTypeA.blockColSize, - .b_mx_block_size = kernelType.scaleTypeB.blockRowSize * kernelType.scaleTypeB.blockColSize, - }; - - int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD)); - for (auto& config : origami_config_list) { - config.workgroup_mapping = defaultWGM; - } - - auto prediction_result = origami::rank_configs( - origami_problem, - analytical_hardware, - origami_config_list - ); - - for(auto const& result : prediction_result) + size_t elementSizeC_bits = rocRoller::DataTypeInfo::Get(kernelType.typeC).elementBits; + + origami::data_type_t dataType; + if (elementSizeA_bits < elementSizeB_bits) + dataType = rocroller_type_to_analytical_type(kernelType.typeB); + else + dataType = rocroller_type_to_analytical_type(kernelType.typeA); + + const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0); + + int WGM = std::sqrt(std::floor(analaytical_hardware.N_CU / analaytical_hardware.NUM_XCD)); + + auto 
selected_tiles = origami::select_best_macro_tile_size( + prob.m, + prob.n, + prob.k, + prob.batch_count, + prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T, + prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T, + analaytical_hardware, + tile_list, + elementSizeA_bits, + elementSizeB_bits, + elementSizeC_bits, + dataType, + kernelType.scaleTypeA.blockRowSize * kernelType.scaleTypeA.blockColSize, //Handle A vs B block size. + 0.8, + false, + WGM); + + for(auto const& selected_tile : selected_tiles) { - auto mt_m = static_cast(result.config.mt.m); - auto mt_n = static_cast(result.config.mt.n); - auto mt_k = static_cast(result.config.mt.k); - WorkGroupTileSize wgt{mt_m, mt_n, mt_k}; + WorkGroupTileSize wgt{(int)std::get<1>(selected_tile), (int)std::get<2>(selected_tile), (int)std::get<3>(selected_tile)}; int unrollAmount = preferredUnrolling(kernelType.typeA, kernelType.typeB, wgt); wgt.k /= unrollAmount; @@ -249,7 +242,7 @@ std::vector chooseSolutionIndexParameters( size_t numTilesN = prob.n / wgt.n; size_t numTiles = numTilesM * numTilesN * prob.batch_count; auto isF6 = (kernelType.typeA == rocRoller::DataType::FP6 || kernelType.typeA == rocRoller::DataType::BF6 || kernelType.typeB == rocRoller::DataType::FP6 || kernelType.typeB == rocRoller::DataType::BF6); - if(numTiles < analytical_hardware.N_CU && !isF6) + if(numTiles < analaytical_hardware.N_CU && !isF6) { params.back().streamK = true; } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp b/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp index 6d3a494301b..57b47fa6f40 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp @@ -42,8 +42,7 @@ #include #include -#include "origami/origami.hpp" -#include "origami/streamk.hpp" +#include #define TENSILE_COMMON_KERNEL_ARGS_SIZE 16 @@ -167,7 +166,7 @@ namespace TensileLite struct StreamKSettings { - 
origami::reduction_t reduction = origami::reduction_t::tree; + origami::streamk::reduction_type reduction = origami::streamk::reduction_type::Tree; size_t grid = 0; }; @@ -184,7 +183,7 @@ namespace TensileLite using Problem = ContractionProblemGemm; using Inputs = ContractionInputs; using GroupedInputs = ContractionGroupedInputs; - using ParamsCache = CacheMap, Problem>; + using ParamsCache = CacheMap, Problem>; /** * Indicate a solution is equally or estimatedly matched. @@ -219,11 +218,6 @@ namespace TensileLite } virtual bool isFallbackForHW(Hardware const&) const; - bool isStreamK() const - { - return sizeMapping.streamK > 0; - } - //! Estimates based on problem size, solution tile, and machine hardware //! charz: struct StaticPerformanceModel @@ -296,8 +290,8 @@ namespace TensileLite void calculateGrid(dim3& workGroupSize, dim3& numWorkGroups, ContractionSolution::Problem const& problem) const; - origami::reduction_t getSKReduction(Problem const& problem, Hardware const& hardware) const; - size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::reduction_t reductionStrat) const; + origami::streamk::reduction_type getSKReduction(Problem const& problem, Hardware const& hardware) const; + size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::streamk::reduction_type& reductionStrat) const; size_t partialTileSize(size_t skGrid) const; static float computeGranularity(float x); @@ -572,9 +566,9 @@ namespace TensileLite uint32_t magicNumber(int magicDivAlg, uint32_t x, uint32_t* magicShift) const; uint32_t smallMagicNumber(uint32_t x) const; - std::pair calculateAutoWGM(Problem const& problem, - Hardware const* hardware, - uint32_t skgrid) const; + std::pair calculateAutoWGM(Problem const& problem, + Hardware const* hardware, + uint32_t skgrid) const; uint32_t calculateAutoGSU(Problem const& problem, Hardware const* hardware) const; }; diff --git 
a/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp index df1920a8f23..7b2d6b18692 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp @@ -45,9 +45,9 @@ namespace TensileLite template struct ProblemPredictionLibrary : public SolutionLibrary { - std::unordered_map> solutionmap; - std::vector origami_config_list; - std::unordered_map origami_config_map; + std::unordered_map> solutionmap; + std::vector tile_list; + std::unordered_map tile_map; static std::string Type() { @@ -156,35 +156,56 @@ namespace TensileLite batch *= problem.batchSize(i); } + bool debug = Debug::Instance().printPropertyEvaluation(); hip::HipAMDGPU const* pAMDGPU = dynamic_cast(&hardware); - + size_t elementSizeA_bits + = problem.a().elementBytes() * 8; + size_t elementSizeB_bits + = problem.b().elementBytes() * 8; + size_t elementSizeC_bits + = problem.c().elementBytes() * 8; const origami::hardware_t& analytical_hardware = *(pAMDGPU->analyticalHardware); - auto miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); - + if(origami::hardware_t::is_debug_enabled()) + { + analytical_hardware.print(); + } + int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD)); + origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); if(problem.f32XdlMathOp() == rocisa::DataType::XFloat32) // Check F32 compute type miDataType = origami::data_type_t::XFloat32; - origami::problem_t origami_problem = { - .size = {m, n, k}, - .batch = batch, - .a_transpose = problem.transA() ? origami::transpose_t::T : origami::transpose_t::N, - .b_transpose = problem.transB() ? 
origami::transpose_t::T : origami::transpose_t::N, - .a_dtype = datatypeToAnalyticalDatatype(problem.a().dataType()), - .b_dtype = datatypeToAnalyticalDatatype(problem.b().dataType()), - .c_dtype = datatypeToAnalyticalDatatype(problem.c().dataType()), - .d_dtype = datatypeToAnalyticalDatatype(problem.d().dataType()), - .mi_dtype = miDataType, - .a_mx_block_size = 0, // MX Data types come from rocroller - .b_mx_block_size = 0, // MX Data types come from rocroller - }; - - auto prediction_result = origami::rank_configs( - origami_problem, *(pAMDGPU->analyticalHardware), origami_config_list); - - for(const auto& r : prediction_result) + auto selected_tiles = origami::select_best_macro_tile_size( + m, + n, + k, + batch, + problem.transA(), + problem.transB(), + *(pAMDGPU->analyticalHardware), + tile_list, + elementSizeA_bits, + elementSizeB_bits, + elementSizeC_bits, + miDataType, + 0, // mx_block_size -> MX Data types come from rocroller. + 0.8, // L2 hit-rate (not used anymore -- should be removed) + false, + defaultWGM, + pAMDGPU->skMaxCUs); + for(const auto& tile : selected_tiles) { - auto mapiter = origami_config_map.find(r.config); + auto mapiter = tile_map.find(std::make_tuple(std::get<1>(tile), + std::get<2>(tile), + std::get<3>(tile), + std::get<4>(tile), + std::get<5>(tile), + std::get<6>(tile), + std::get<7>(tile), + std::get<8>(tile), + std::get<9>(tile), + std::get<10>(tile) + )); auto smapiter = solutionmap.find(mapiter->second); - if(mapiter != origami_config_map.end() && smapiter != solutionmap.end()) + if(mapiter != tile_map.end() && smapiter != solutionmap.end()) { auto solution = smapiter->second; if((*solution->hardwarePredicate)(hardware) diff --git a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp index db2a5d05daf..77da8d810c6 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp +++ 
b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp @@ -83,39 +83,21 @@ namespace TensileLite auto solution = slnIter->second; lib.solutionmap.insert(std::make_pair(index, solution)); - origami::dim3_t origami_mi; - if(solution->sizeMapping.matrixInstruction[0] == 0 - && solution->sizeMapping.matrixInstruction[1] == 0 - && solution->sizeMapping.matrixInstruction[2] == 0) - { - // Override dot2 instruction with vector lane widths - origami_mi = {1, 1, 64}; - } - else - { - origami_mi = { - static_cast(solution->sizeMapping.matrixInstruction[0]), - static_cast(solution->sizeMapping.matrixInstruction[1]), - static_cast( - solution->sizeMapping.matrixInstruction[2])}; - } + auto solution_tuple = std::make_tuple( + solution->sizeMapping.macroTile.x, // MT_M + solution->sizeMapping.macroTile.y, // MT_N + solution->sizeMapping.depthU, // MT_K + solution->sizeMapping.matrixInstruction[0], // MI_M + solution->sizeMapping.matrixInstruction[1], // MI_N + solution->sizeMapping.matrixInstruction[2], // MI_K + solution->sizeMapping.CUOccupancy, // Occupancy + solution->sizeMapping.workGroupMapping, // WGM + solution->sizeMapping.nonTemporalA, // Cache flag: A + solution->sizeMapping.nonTemporalB // Cache flag: B + ); - origami::config_t origami_config = { - .mt = {solution->sizeMapping.macroTile.x, - solution->sizeMapping.macroTile.y, - solution->sizeMapping.depthU}, - .mi = origami_mi, - .occupancy - = std::max(solution->sizeMapping.CUOccupancy, static_cast(1)), - .workgroup_mapping = solution->sizeMapping.workGroupMapping, - .cache_hints_a = solution->sizeMapping.nonTemporalA, - .cache_hints_b = solution->sizeMapping.nonTemporalB, - .workspace_size = std::numeric_limits::max(), - .workspace_size_per_elem_c = std::numeric_limits::max(), - }; - - lib.origami_config_list.emplace_back(origami_config); - lib.origami_config_map.insert(std::make_pair(origami_config, index)); + lib.tile_list.emplace_back(solution_tuple); + 
lib.tile_map.insert(std::make_pair(solution_tuple, index)); } } } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp index 21dd009eada..4d2519947be 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp @@ -26,57 +26,58 @@ #pragma once -#include +#include #include namespace TensileLite { - + inline origami::data_type_t datatypeToAnalyticalDatatype(rocisa::DataType type) { switch(type) { - case rocisa::DataType::Float: - return origami::data_type_t::Float; - case rocisa::DataType::Double: - return origami::data_type_t::Double; - case rocisa::DataType::ComplexFloat: - return origami::data_type_t::ComplexFloat; - case rocisa::DataType::ComplexDouble: - return origami::data_type_t::ComplexDouble; - case rocisa::DataType::Half: - return origami::data_type_t::Half; - case rocisa::DataType::Int8x4: - return origami::data_type_t::Int8x4; - case rocisa::DataType::Int32: - return origami::data_type_t::Int32; - case rocisa::DataType::BFloat16: - return origami::data_type_t::BFloat16; - case rocisa::DataType::Int8: - return origami::data_type_t::Int8; - case rocisa::DataType::Int64: - return origami::data_type_t::Int64; - case rocisa::DataType::XFloat32: - return origami::data_type_t::XFloat32; - case rocisa::DataType::Float8_fnuz: - return origami::data_type_t::Float8_fnuz; - case rocisa::DataType::BFloat8_fnuz: - return origami::data_type_t::BFloat8_fnuz; - case rocisa::DataType::Float8BFloat8_fnuz: - return origami::data_type_t::Float8BFloat8_fnuz; - case rocisa::DataType::BFloat8Float8_fnuz: - return origami::data_type_t::BFloat8Float8_fnuz; - case rocisa::DataType::Float8: - return origami::data_type_t::Float8; - case rocisa::DataType::BFloat8: - return origami::data_type_t::BFloat8; - case rocisa::DataType::Float8BFloat8: - return origami::data_type_t::Float8BFloat8; - case 
rocisa::DataType::BFloat8Float8: - return origami::data_type_t::BFloat8Float8; + case rocisa::DataType::Float: + return origami::data_type_t::Float; + case rocisa::DataType::Double: + return origami::data_type_t::Double; + case rocisa::DataType::ComplexFloat: + return origami::data_type_t::ComplexFloat; + case rocisa::DataType::ComplexDouble: + return origami::data_type_t::ComplexDouble; + case rocisa::DataType::Half: + return origami::data_type_t::Half; + case rocisa::DataType::Int8x4: + return origami::data_type_t::Int8x4; + case rocisa::DataType::Int32: + return origami::data_type_t::Int32; + case rocisa::DataType::BFloat16: + return origami::data_type_t::BFloat16; + case rocisa::DataType::Int8: + return origami::data_type_t::Int8; + case rocisa::DataType::Int64: + return origami::data_type_t::Int64; + case rocisa::DataType::XFloat32: + return origami::data_type_t::XFloat32; + case rocisa::DataType::Float8_fnuz: + return origami::data_type_t::Float8_fnuz; + case rocisa::DataType::BFloat8_fnuz: + return origami::data_type_t::BFloat8_fnuz; + case rocisa::DataType::Float8BFloat8_fnuz: + return origami::data_type_t::Float8BFloat8_fnuz; + case rocisa::DataType::BFloat8Float8_fnuz: + return origami::data_type_t::BFloat8Float8_fnuz; + case rocisa::DataType::Float8: + return origami::data_type_t::Float8; + case rocisa::DataType::BFloat8: + return origami::data_type_t::BFloat8; + case rocisa::DataType::Float8BFloat8: + return origami::data_type_t::Float8BFloat8; + case rocisa::DataType::BFloat8Float8: + return origami::data_type_t::BFloat8Float8; - default: - return origami::data_type_t::None; + default: + return origami::data_type_t::None; } } } // namespace TensileLite + \ No newline at end of file diff --git a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp index fdcf9852051..a195ccff105 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp +++ 
b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp @@ -28,7 +28,7 @@ #include #include -#include +#include #include diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index 96a51e83e89..b7e86bb966f 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -32,11 +32,10 @@ #include #include #include -#include #include -#include #include +#include #include #include @@ -50,6 +49,8 @@ namespace TensileLite { + using ReductionType = origami::streamk::reduction_type; + enum class KERNELARGTYPE { NORMAL = 0, @@ -274,32 +275,58 @@ namespace TensileLite PrintBufferValueClass betaPrint( (void*)args[i].beta, sizeof(args[i].beta), problems[i].betaType()); std::cout << "Gemm " << i << ":" << std::endl; - std::cout << " " << "m: " << args[i].m << std::endl; - std::cout << " " << "n: " << args[i].n << std::endl; - std::cout << " " << "batch: " << args[i].batch << std::endl; - std::cout << " " << "k: " << args[i].k << std::endl; - std::cout << " " << "D: " << args[i].d << std::endl; - std::cout << " " << "C: " << args[i].c << std::endl; - std::cout << " " << "A: " << args[i].a << std::endl; - std::cout << " " << "B: " << args[i].b << std::endl; - std::cout << " " << "strideD1: " << args[i].strideD1 << std::endl; - std::cout << " " << "strideD2: " << args[i].strideD2 << std::endl; - std::cout << " " << "strideC1: " << args[i].strideC1 << std::endl; - std::cout << " " << "strideC2: " << args[i].strideC2 << std::endl; - std::cout << " " << "strideA1: " << args[i].strideA1 << std::endl; - std::cout << " " << "strideA2: " << args[i].strideA2 << std::endl; - std::cout << " " << "strideB1: " << args[i].strideB1 << std::endl; - std::cout << " " << "strideB2: " << args[i].strideB2 << std::endl; - std::cout << " " << "Alpha: " << alphaPrint << std::endl; - std::cout << " " << "Beta: " << betaPrint << std::endl; 
- std::cout << " " << "scaleAlphaVec: " << args[i].scaleAlphaVec << std::endl; - std::cout << " " << "bias: " << args[i].bias << std::endl; - std::cout << " " << "e: " << args[i].e << std::endl; - std::cout << " " << "strideE1: " << args[i].strideE1 << std::endl; - std::cout << " " << "strideE2: " << args[i].strideE2 << std::endl; - std::cout << " " << "act0: " << args[i].act0 << std::endl; - std::cout << " " << "act1: " << args[i].act1 << std::endl; - std::cout << " " << "activationType: " << args[i].activationType << std::endl; + std::cout << " " + << "m: " << args[i].m << std::endl; + std::cout << " " + << "n: " << args[i].n << std::endl; + std::cout << " " + << "batch: " << args[i].batch << std::endl; + std::cout << " " + << "k: " << args[i].k << std::endl; + std::cout << " " + << "D: " << args[i].d << std::endl; + std::cout << " " + << "C: " << args[i].c << std::endl; + std::cout << " " + << "A: " << args[i].a << std::endl; + std::cout << " " + << "B: " << args[i].b << std::endl; + std::cout << " " + << "strideD1: " << args[i].strideD1 << std::endl; + std::cout << " " + << "strideD2: " << args[i].strideD2 << std::endl; + std::cout << " " + << "strideC1: " << args[i].strideC1 << std::endl; + std::cout << " " + << "strideC2: " << args[i].strideC2 << std::endl; + std::cout << " " + << "strideA1: " << args[i].strideA1 << std::endl; + std::cout << " " + << "strideA2: " << args[i].strideA2 << std::endl; + std::cout << " " + << "strideB1: " << args[i].strideB1 << std::endl; + std::cout << " " + << "strideB2: " << args[i].strideB2 << std::endl; + std::cout << " " + << "Alpha: " << alphaPrint << std::endl; + std::cout << " " + << "Beta: " << betaPrint << std::endl; + std::cout << " " + << "scaleAlphaVec: " << args[i].scaleAlphaVec << std::endl; + std::cout << " " + << "bias: " << args[i].bias << std::endl; + std::cout << " " + << "e: " << args[i].e << std::endl; + std::cout << " " + << "strideE1: " << args[i].strideE1 << std::endl; + std::cout << " " + << "strideE2: " 
<< args[i].strideE2 << std::endl; + std::cout << " " + << "act0: " << args[i].act0 << std::endl; + std::cout << " " + << "act1: " << args[i].act1 << std::endl; + std::cout << " " + << "activationType: " << args[i].activationType << std::endl; } } } @@ -517,11 +544,11 @@ namespace TensileLite template void ContractionSolution::singleCallArgs(ContractionSolution::Problem const& problem, ContractionInputs const& inputs, - uint32_t const& workspaceOffsetInByte, - Hardware const* hardware, - dim3 const& problemNumGroupTiles, - dim3 const& numWorkGroups, - KA& args, + uint32_t const& workspaceOffsetInByte, + Hardware const* hardware, + dim3 const& problemNumGroupTiles, + dim3 const& numWorkGroups, + KA& args, StreamKSettings const& sk) const { if(debugKernel) @@ -539,7 +566,7 @@ namespace TensileLite TensorDescriptor const& metadata = problem.metadata(); auto [autoWGM, autoWGMXCC] = calculateAutoWGM(problem, hardware, sk.grid); - uint32_t autoGsuVal = calculateAutoGSU(problem, hardware); + uint32_t autoGsuVal = calculateAutoGSU(problem, hardware); uint32_t gsu = problem.getParams().gsu() > 0 ? 
problem.getParams().gsu() : autoGsuVal; { @@ -572,12 +599,10 @@ namespace TensileLite } else if(problemType.stridedBatched) { - if(sizeMapping.streamK > 0 && sk.reduction == origami::reduction_t::parallel) + if(sizeMapping.streamK > 0 && sk.reduction == ReductionType::Parallel) { - args.template append("ws_d", - (uint8_t*)inputs.ws + workspaceOffsetInByte); - args.template append("ws_c", - (uint8_t*)inputs.ws + workspaceOffsetInByte); + args.template append("ws_d", (uint8_t*)inputs.ws + workspaceOffsetInByte); + args.template append("ws_c", (uint8_t*)inputs.ws + workspaceOffsetInByte); } else { @@ -615,7 +640,7 @@ namespace TensileLite // StreamK workspace + flags args.template append("ws", inputs.ws); - if(sk.reduction == origami::reduction_t::parallel) + if(sk.reduction == ReductionType::Parallel) args.template append("Flags", nullptr); else args.template append("Flags", inputs.Synchronizer); @@ -625,9 +650,8 @@ namespace TensileLite size_t startStrideAB = problemType.useInitialStridesAB ? 0 : 1; // Pass wsStride if it's not in MBSK mode - bool gsuWSStride - = gsu > 1 && sizeMapping.globalAccumulation != 3 && sizeMapping.streamK == 0; - bool skWSStride = sizeMapping.streamK > 0 && sk.reduction == origami::reduction_t::parallel; + bool gsuWSStride = gsu > 1 && sizeMapping.globalAccumulation != 3 && sizeMapping.streamK == 0; + bool skWSStride = sizeMapping.streamK > 0 && sk.reduction == ReductionType::Parallel; if(gsuWSStride || skWSStride) { size_t wsStride = startStrideCD ? 
d.sizes()[0] : 1; @@ -711,24 +735,23 @@ namespace TensileLite args.template append("totalIters", totalIters); if(sizeMapping.streamK == 1) // Basic SK - { + { uint32_t itersPerWave = CeilDivide(totalIters, numWorkGroups.x); args.template append("SKItersPerWG", itersPerWave); - } + } else if(sizeMapping.streamK >= 2) // Two-tile SK - { - if(sk.reduction == origami::reduction_t::parallel) - { - uint32_t skSplit - = sk.grid / tiles; // skTiles is skSplit in parallel reduction path + { + if(sk.reduction == ReductionType::Parallel) + { + uint32_t skSplit = sk.grid / tiles; // skTiles is skSplit in parallel reduction path uint32_t skItersPerWG = itersPerTile / skSplit; args.template append("SKItersPerWG", skItersPerWG); - args.template append("skGrid", sk.grid); - args.template append("skTiles", skSplit); - } + args.template append("skGrid", sk.grid); + args.template append("skTiles", skSplit); + } else - { + { AMDGPU const* pAMDGPU = dynamic_cast(hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); int fullTiles = pAMDGPU->skFullTiles; @@ -742,21 +765,21 @@ namespace TensileLite uint32_t skTiles = sk.grid; // If not evenly divisible, determine number of Stream-K tiles if(tiles % sk.grid != 0) - { + { // Number of data-parallel tiles on each workgroup would be: // dpTilesPerWG = bigEnough ? (tiles - skTiles) / skGrid : 0; skTiles = bigEnough ? 
sk.grid * fullTiles + tiles % sk.grid : tiles; // Cap Stream-K tiles at total number of tiles in case of large multiplier skTiles = min(skTiles, tiles); - } + } uint32_t skItersPerWG = skTiles * itersPerTile / sk.grid; args.template append("SKItersPerWG", skItersPerWG); - args.template append("skGrid", sk.grid); - args.template append("skTiles", skTiles); - } - } + args.template append("skGrid", sk.grid); + args.template append("skTiles", skTiles); + } + } } if constexpr(insertKernelArgs) @@ -935,26 +958,29 @@ namespace TensileLite / std::ceil(std::ceil(m / mt0) * std::ceil(n / mt1) * gsu / cuCount); } - std::pair ContractionSolution::calculateAutoWGM(Problem const& problem, - Hardware const* hardware, - uint32_t const skgrid) const + std::pair ContractionSolution::calculateAutoWGM( + Problem const& problem, + Hardware const* hardware, + uint32_t const skgrid) const { // Hardware - AMDGPU const* pAMDGPU = dynamic_cast(hardware); + AMDGPU const* pAMDGPU = dynamic_cast(hardware); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(hardware); // Default WGM int32_t defaultWGM; - int32_t defaultWGMXCC; + uint32_t defaultWGMXCC; // Dynamically pick the values - if(sizeMapping.streamK != 0 && skgrid != 0 && sizeMapping.workGroupMapping == 0 - && sizeMapping.workGroupMappingXCC == -1 - && sizeMapping.nonTemporalA < 4 /* Exclude NTs for now till we fix libs */ - && sizeMapping.nonTemporalB < 4 /* Exclude NTs for now till we fix libs */) - { - int32_t c_wgm = 0; - int32_t c_wgmxcc = 0; + if(sizeMapping.streamK != 0 + && skgrid != 0 + && sizeMapping.workGroupMapping == 0 + && sizeMapping.workGroupMappingXCC == -1 + && sizeMapping.nonTemporalA < 4 /* Exclude NTs for now till we fix libs */ + && sizeMapping.nonTemporalB < 4 /* Exclude NTs for now till we fix libs */) + { + int32_t c_wgm = 0; + uint32_t c_wgmxcc = 0; // Try to find cached WGM and WGMXCC std::tie(c_wgm, c_wgmxcc) = paramsCache.find(problem); @@ -963,30 +989,30 @@ namespace TensileLite auto sizes = 
problem.problemSizes(); if(sizes.size() >= 4) { - origami::problem_t origami_problem = { - .size = {sizes[0], sizes[1], sizes[3]}, - .batch = sizes[2], - }; - origami::config_t origami_config = { - .mt = {static_cast(sizeMapping.macroTile.x), - static_cast(sizeMapping.macroTile.y), - static_cast(sizeMapping.depthU)}, - .cache_hints_a = sizeMapping.nonTemporalA, - .cache_hints_b = sizeMapping.nonTemporalB, - }; - std::tie(defaultWGMXCC, defaultWGM) = origami::select_workgroup_mapping( - origami_problem, *(hipAMDGPU->analyticalHardware), origami_config, skgrid); + auto wgm_pred = origami::select_best_wgm(*(hipAMDGPU->analyticalHardware), + sizes[0], + sizes[1], + sizes[3], + sizes[2], + sizeMapping.macroTile.x, + sizeMapping.macroTile.y, + sizeMapping.depthU, + sizeMapping.nonTemporalA, + sizeMapping.nonTemporalB, + skgrid, + false); + defaultWGMXCC = std::get<0>(wgm_pred); + defaultWGM = std::get<1>(wgm_pred); // Add to cache only if dynamically calculated. paramsCache.add(std::make_pair(defaultWGM, defaultWGMXCC), problem); if(Debug::Instance().printPropertyEvaluation()) - std::cout << "Dynamic WGM " << defaultWGM << ", WGMXCC " << defaultWGMXCC - << std::endl; + std::cout << "Dynamic WGM "<< defaultWGM << ", WGMXCC " << defaultWGMXCC << std::endl; } } else { - defaultWGM = c_wgm; + defaultWGM = c_wgm; defaultWGMXCC = c_wgmxcc; } } @@ -1012,6 +1038,7 @@ namespace TensileLite defaultWGMXCC = sizeMapping.workGroupMappingXCC; } + // If WGM and WGMXCC are explicitly specified at runtime, they override default and predictions if(pAMDGPU->fixedWGM != std::numeric_limits::max()) { @@ -1041,30 +1068,30 @@ namespace TensileLite AMDGPU const* pAMDGPU = dynamic_cast(hardware); assert(pAMDGPU); - uint32_t numCUs = pAMDGPU->computeUnitCount; - uint32_t numWGs = getNumWorkGroups(problem, sizeMapping); + uint32_t numCUs = pAMDGPU->computeUnitCount; + uint32_t numWGs = getNumWorkGroups(problem, sizeMapping); // avoid zero division - if(numWGs == 0) + if (numWGs == 0) { return 1; } 
- uint32_t MT0 = sizeMapping.macroTile.x; - uint32_t MT1 = sizeMapping.macroTile.y; - uint32_t MT2 = sizeMapping.depthU; - uint32_t M = problem.freeSizeA(0); - uint32_t N = problem.freeSizeB(0); - uint32_t B = problem.batchSize(0); - uint32_t K = problem.boundSize(0); - uint32_t GSULimit1 = max(1, (uint32_t)std::floor(numCUs / numWGs)); - uint32_t GSULimit2 = max(1, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); - uint32_t gsuVal = min(GSULimit2, max(1, GSULimit1)); + uint32_t MT0 = sizeMapping.macroTile.x; + uint32_t MT1 = sizeMapping.macroTile.y; + uint32_t MT2 = sizeMapping.depthU; + uint32_t M = problem.freeSizeA(0); + uint32_t N = problem.freeSizeB(0); + uint32_t B = problem.batchSize(0); + uint32_t K = problem.boundSize(0); + uint32_t GSULimit1 = max(1, (uint32_t)std::floor(numCUs / numWGs)); + uint32_t GSULimit2 = max(1, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); + uint32_t gsuVal = min(GSULimit2, max(1, GSULimit1)); // WorkgroupNumberCheck #define MAX_WORKGROUP_NUMBER 16777216 if(gsuVal > 1) gsuVal = min(gsuVal, - MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) - / std::ceil(static_cast(N) / MT1) / B); + MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) + / std::ceil(static_cast(N) / MT1) / B); // GlobalSplitUCheckMinK if(gsuVal > 1) @@ -1073,19 +1100,17 @@ namespace TensileLite // SynchronizerSizeCheck if(gsuVal > 1 && sizeMapping.globalAccumulation == 3) // MBSK { - uint32_t synchronizerUsage - = sizeMapping.synchronizerSizePerWG * problem.getNumTiles(sizeMapping, 1) * B; + uint32_t synchronizerUsage = sizeMapping.synchronizerSizePerWG * problem.getNumTiles(sizeMapping, 1) * B; gsuVal = synchronizerUsage > 409600 ? 
1 : gsuVal; } // Avoid selecting a gsu value that would make launch grid over the limit - uint32_t tiles0 = CeilDivide(M, MT0); - uint32_t tiles1 = CeilDivide(N, MT1); - uint32_t tiles = tiles0 * tiles1 * B; - uint32_t workGroupSize = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y - * sizeMapping.workGroupSize.z; + uint32_t tiles0 = CeilDivide(M, MT0); + uint32_t tiles1 = CeilDivide(N, MT1); + uint32_t tiles = tiles0 * tiles1 * B; + uint32_t workGroupSize = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y * sizeMapping.workGroupSize.z; uint32_t maxGsuValue = (std::numeric_limits::max() / workGroupSize) / tiles; - gsuVal = min(gsuVal, maxGsuValue); + gsuVal = min(gsuVal, maxGsuValue); // avoid gsu < 1 gsuVal = max(gsuVal, 1); @@ -1149,8 +1174,7 @@ namespace TensileLite // NB: get value from param= set in runtime / vs value from sizeMapping: from logic yaml. // param: default values: [xcc = 0, xccg = 0]. So when we never set xcc/xccg in runtime: we always get from sizeMapping. // From sizeMapping = from logic yaml. If not set in Config-Yaml, use default value [1, -1] - wgmxccg - = param.wgmxccg() != 0 ? param.wgmxccg() : sizeMapping.workGroupMappingXCCGroup; + wgmxccg = param.wgmxccg() != 0 ? 
param.wgmxccg() : sizeMapping.workGroupMappingXCCGroup; if(wgmxcc >= 1 && wgmxccg == -1) { AMDGPU const* pAMDGPU = dynamic_cast(hardware); @@ -1161,7 +1185,7 @@ namespace TensileLite } else if(internalArgsSupport.version == 2 && internalArgsSupport.useSFC) { - internalArg1 = wgm; + internalArg1 = wgm; } } @@ -1185,9 +1209,9 @@ namespace TensileLite uint32_t staggerU = mask8 & sizeMapping.staggerU; if(Debug::Instance().disableStaggerU()) staggerU = 0; - staggerU = staggerU | staggerUShift; - staggerU = staggerU | staggerUMapping; - internalArg0 = internalArg0 | (staggerU << 16); + staggerU = staggerU | staggerUShift; + staggerU = staggerU | staggerUMapping; + internalArg0 = internalArg0 | (staggerU << 16); } else if(T_Debug && Debug::Instance().disableStaggerU()) std::cout << "solution doesn't support configurable staggerU" << std::endl; @@ -1201,12 +1225,12 @@ namespace TensileLite } } - void ContractionSolution::calculateGrid(dim3& workGroupSize, - dim3& numWorkGroups, + void ContractionSolution::calculateGrid(dim3& workGroupSize, + dim3& numWorkGroups, ContractionSolution::Problem const& problem) const { workGroupSize.x = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y - * sizeMapping.workGroupSize.z; + * sizeMapping.workGroupSize.z; workGroupSize.y = 1; workGroupSize.z = 1; @@ -1313,15 +1337,8 @@ namespace TensileLite std::cout << "AutoWGM: " << autoWGM << std::endl; std::cout << "AutoWGMXCC: " << autoWGMXCC << std::endl; } - kernelArgs(1, - 0, - rv.args, - getNumWorkGroups(rv), - &hardware, - problem.getParams(), - autoWGM, - autoWGMXCC, - autoGsuVal); + kernelArgs( + 1, 0, rv.args, getNumWorkGroups(rv), &hardware, problem.getParams(), autoWGM, autoWGMXCC, autoGsuVal); } singleCallArgs( problem, inputs, 0, &hardware, problemNumGroupTiles, rv.numWorkGroups, rv.args, sk); @@ -1395,8 +1412,7 @@ namespace TensileLite numWorkGroups.x = CeilDivide(numWorkGroups.x, sizeMapping.macroTile.x); numWorkGroups.y = CeilDivide(numWorkGroups.y, 
sizeMapping.macroTile.y); - uint32_t gsu - = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; + uint32_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; if(gsu > 0) numWorkGroups.y *= gsu; @@ -1456,7 +1472,7 @@ namespace TensileLite { for(int idx = 0; idx < problems.size(); idx++) { - auto problem = problems[idx]; + auto problem = problems[idx]; StreamKSettings sk; // Grouped gemm currently not supported in SK // But this code path is run to calculate to determine if solution is supported @@ -1734,10 +1750,10 @@ namespace TensileLite template void ContractionSolution::outputConversionCallArgs(ContractionSolution::Problem const& problem, ContractionInputs const& inputs, - uint32_t const& workspaceOffsetInByte, - KA& args, + uint32_t const& workspaceOffsetInByte, + KA& args, StreamKSettings const& sk, - uint32_t autoGsuVal) const + uint32_t autoGsuVal) const { TensorDescriptor const& c = problem.c(); TensorDescriptor const& d = problem.d(); @@ -1890,15 +1906,14 @@ namespace TensileLite args.template append(concatenate_if("size_", i), size); i++; } - uint32_t gsu - = sizeMapping.globalAccumulation == 1 - ? 1 - : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); + uint32_t gsu = sizeMapping.globalAccumulation == 1 + ? 1 + : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); if(sizeMapping.streamK > 0) { auto tiles = problem.getNumTiles(sizeMapping, 1); - gsu = sk.grid / tiles; + gsu = sk.grid / tiles; } args.template append(concatenate_if("gsu"), gsu); @@ -1947,17 +1962,16 @@ namespace TensileLite vw = 2; } - uint32_t gsu - = sizeMapping.globalAccumulation == 1 - ? 1 - : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); + uint32_t gsu = sizeMapping.globalAccumulation == 1 + ? 1 + : (problem.getParams().gsu() > 0 ? 
problem.getParams().gsu() : autoGsuVal); if(sizeMapping.streamK > 0) { // If using post kernel with stream-k then it is doing parallel reduciton // Calculate the splitting factor auto tiles = problem.getNumTiles(sizeMapping, 1); - gsu = sk.grid / tiles; + gsu = sk.grid / tiles; } rv.kernelName = outputConversionKernelName(problem, inputs, vw, gsu); @@ -2112,10 +2126,10 @@ namespace TensileLite problems, vw, rv.workGroupSize, rv.numWorkGroups, rv.numWorkItems, h_args); uint32_t autoGsuVal = calculateAutoGSU(problems[0], &hardware); - uint32_t gsu = sizeMapping.globalAccumulation == 1 - ? 1 - : (problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() - : autoGsuVal); + uint32_t gsu + = sizeMapping.globalAccumulation == 1 + ? 1 + : (problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() : autoGsuVal); if constexpr(std::is_same::value) { @@ -2126,7 +2140,7 @@ namespace TensileLite = this->requiredHostWorkspaceSizePerProblem * problems.size(); for(int idx = 0; idx < problems.size(); idx++) { - auto problem = problems[idx]; + auto problem = problems[idx]; StreamKSettings sk; outputConversionCallArgs( problem, inputs.grouped[idx], workspaceOffsetInByte, h_args, sk, autoGsuVal); @@ -2566,7 +2580,7 @@ namespace TensileLite std::vector rv; auto autoGsuVal = calculateAutoGSU(problem, &hardware); - auto gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; + auto gsu = problem.getParams().gsu() > 0 ? 
problem.getParams().gsu() : autoGsuVal; if(gsu > 1 && sizeMapping.globalAccumulation != 2 && sizeMapping.globalAccumulation != 3) { if(debug) @@ -2578,13 +2592,11 @@ namespace TensileLite StreamKSettings sk; if(sizeMapping.streamK > 0) { - sk.reduction = getSKReduction(problem, hardware); - auto tiles = problem.getNumTiles(sizeMapping, 1); - sk.grid = getSKGrid(problem, hardware, tiles, sk.reduction); + sk.reduction = getSKReduction(problem, hardware); + auto tiles = problem.getNumTiles(sizeMapping, 1); + sk.grid = getSKGrid(problem, hardware, tiles, sk.reduction); const bool streamKDP = Debug::Instance().useStreamKDataParrallel(); - if(sk.grid > 0 - && (sk.reduction == origami::reduction_t::parallel - || (tiles % sk.grid != 0 && !streamKDP))) + if(sk.grid > 0 && (sk.reduction == ReductionType::Parallel || (tiles % sk.grid != 0 && !streamKDP))) { // Check ideal amount of workspace for optimal performance size_t idealWorkspace = partialTileSize(sk.grid); @@ -2592,8 +2604,8 @@ namespace TensileLite // Performance will likely be lower, but the kernel can run if workspace is unavailable if(idealWorkspace > problem.workspaceSize()) { - sk.reduction = origami::reduction_t::tree; - sk.grid = tiles; + sk.reduction = ReductionType::Tree; + sk.grid = tiles; } } } @@ -2603,8 +2615,7 @@ namespace TensileLite else rv.push_back(generateSingleCall(problem, inputs, hardware, sk)); - if(((sizeMapping.globalAccumulation != 3) && gsu > 1 && sizeMapping.globalAccumulation) - || sk.reduction == origami::reduction_t::parallel) + if(((sizeMapping.globalAccumulation != 3) && gsu > 1 && sizeMapping.globalAccumulation) || sk.reduction == ReductionType::Parallel) { if(debug) rv.push_back(generateOutputConversionCall(problem, inputs, sk, autoGsuVal)); @@ -2793,8 +2804,7 @@ namespace TensileLite rv.push_back( generateSingleCallGroupedGemm(problems, inputs, hardware, h_args, dUA)); - auto gsu = problems[0].getParams().gsu() > 0 ? 
problems[0].getParams().gsu() - : calculateAutoGSU(problems[0], &hardware); + auto gsu = problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() : calculateAutoGSU(problems[0], &hardware); if((sizeMapping.globalAccumulation && gsu > 1) && (sizeMapping.globalAccumulation != 3)) { @@ -2982,12 +2992,12 @@ namespace TensileLite auto tiles = problem.getNumTiles(sizeMapping, 1); if(tiles > 0) // Grouped GEMM reports 0 tiles { - auto reductionStrat = getSKReduction(problem, hardware); - size_t skGrid = getSKGrid(problem, hardware, tiles, reductionStrat); + ReductionType reductionStrat = getSKReduction(problem, hardware); + size_t skGrid = getSKGrid(problem, hardware, tiles, reductionStrat); // Get space required for partial tiles= - if(reductionStrat == origami::reduction_t::parallel) + if(reductionStrat == ReductionType::Parallel) { - size_t splitk = skGrid / tiles; + size_t splitk = skGrid / tiles; size_t idealWorkspace = requiredWorkspaceSizeGsu(problem, hardware, splitk); if(idealWorkspace <= problem.workspaceSize()) size += idealWorkspace; @@ -3006,24 +3016,22 @@ namespace TensileLite else { // TODO: Pass GSU from problem and change value[2] to gsu if gsu != default value - size_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() - : calculateAutoGSU(problem, &hardware); + size_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : calculateAutoGSU(problem, &hardware); size += requiredWorkspaceSizeGsu(problem, hardware, gsu); } + return size; } - size_t ContractionSolution::requiredWorkspaceSizeGsu(Problem const& problem, - Hardware const& hardware, - size_t gsu) const + size_t ContractionSolution::requiredWorkspaceSizeGsu(Problem const& problem, Hardware const& hardware, size_t gsu) const { size_t size = 0; size_t gsuMultiplier = gsu > 1 ? 
gsu : 0; size_t batch = problem.d().sizes()[2]; size_t tiles = problem.getNumTiles(sizeMapping, gsu) * batch; - size_t tileSize - = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; + size_t tileSize = sizeMapping.macroTile.x * sizeMapping.macroTile.y + * sizeMapping.workspaceSizePerElemC; size_t bufSize = gsu > 1 ? tiles * tileSize : 0; size += bufSize; @@ -3032,15 +3040,19 @@ namespace TensileLite { if(problem.biasSrc() == ContractionProblemGemm::TENSOR::A) { - size += problem.freeSizeA(0) * sizeMapping.workspaceSizePerElemBias * gsuMultiplier; + size += problem.freeSizeA(0) * sizeMapping.workspaceSizePerElemBias + * gsuMultiplier; } else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::B) { - size += problem.freeSizeB(0) * sizeMapping.workspaceSizePerElemBias * gsuMultiplier; + size += problem.freeSizeB(0) * sizeMapping.workspaceSizePerElemBias + * gsuMultiplier; } - else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::D && (gsuMultiplier == 0)) + else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::D + && (gsuMultiplier == 0)) { - size += problem.d().totalLogicalElements() * problem.computeTypeElementSize() * gsu; + size += problem.d().totalLogicalElements() * problem.computeTypeElementSize() + * gsu; } } @@ -3100,8 +3112,7 @@ namespace TensileLite return h_args.size(); } - size_t ContractionSolution::requiredSynchronizerSize(Problem const& problem, - Hardware const& hardware) const + size_t ContractionSolution::requiredSynchronizerSize(Problem const& problem, Hardware const& hardware) const { if(sizeMapping.globalAccumulation == 3) { @@ -3112,10 +3123,9 @@ namespace TensileLite return 0; } - origami::reduction_t ContractionSolution::getSKReduction(Problem const& problem, - Hardware const& hardware) const + ReductionType ContractionSolution::getSKReduction(Problem const& problem, Hardware const& hardware) const { - auto reductionStrat = origami::reduction_t::tree; + ReductionType reductionStrat = 
ReductionType::Tree; AMDGPU const* pAMDGPU = dynamic_cast(&hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); @@ -3123,13 +3133,13 @@ namespace TensileLite if(!sizeMapping.customKernelName.empty()) { // Custom kernel currently only supports single-kernel reduction - reductionStrat = origami::reduction_t::tree; + reductionStrat = ReductionType::Tree; } else if(pAMDGPU->skDynamicGrid > 0) { size_t x = 1; size_t y = 1; - size_t z = 1; + size_t z = 1; size_t batch = 1; for(size_t i = 0; i < problem.freeIndicesA().size(); i++) { @@ -3147,34 +3157,30 @@ namespace TensileLite { batch *= problem.batchSize(i); } + origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(&hardware); - origami::problem_t origami_problem = { - .size = {x, y, z}, - .batch = batch, - }; - origami::config_t origami_config = { - .mt = {static_cast(sizeMapping.macroTile.x), - static_cast(sizeMapping.macroTile.y), - static_cast(sizeMapping.depthU)}, - }; - reductionStrat = origami::streamk::select_reduction( - origami_problem, + x, + y, + z, + batch, + sizeMapping.macroTile.x, + sizeMapping.macroTile.y, + sizeMapping.depthU, *(hipAMDGPU->analyticalHardware), - origami_config, - static_cast(pAMDGPU->skDynamicGrid)); + pAMDGPU->skDynamicGrid); } return reductionStrat; } - size_t ContractionSolution::getSKGrid(Problem const& problem, - Hardware const& hardware, - size_t tiles, - origami::reduction_t reductionStrat) const + size_t ContractionSolution::getSKGrid(Problem const& problem, + Hardware const& hardware, + size_t tiles, + ReductionType& reductionStrat) const { - size_t skGrid = tiles; // Fallback + size_t skGrid = tiles; // Fallback const bool streamKDP = Debug::Instance().useStreamKDataParrallel(); if(streamKDP) skGrid = tiles; @@ -3198,7 +3204,7 @@ namespace TensileLite { skGrid = pAMDGPU->skFixedGrid; } - else if(pAMDGPU->skDynamicGrid > 0) + else if (pAMDGPU->skDynamicGrid > 0) { size_t x 
= 1; size_t y = 1; @@ -3215,34 +3221,33 @@ namespace TensileLite { batch *= problem.batchSize(i); } + origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(&hardware); - origami::problem_t origami_problem = { - .size = {x, y, z}, - .batch = batch, - .a_dtype = datatypeToAnalyticalDatatype(problem.alphaType()), - .b_dtype = datatypeToAnalyticalDatatype(problem.betaType()), - .mi_dtype = datatypeToAnalyticalDatatype(problem.computeInputType()), - }; - origami::config_t origami_config = { - .mt = {static_cast(sizeMapping.macroTile.x), - static_cast(sizeMapping.macroTile.y), - static_cast(sizeMapping.depthU)}, - .mi = {static_cast(sizeMapping.matrixInstruction[0]), - static_cast(sizeMapping.matrixInstruction[1]), - static_cast(sizeMapping.matrixInstruction[2])}, - .occupancy = std::max(sizeMapping.CUOccupancy, static_cast(1)), - .workgroup_mapping = sizeMapping.workGroupMapping, - .workspace_size = problem.workspaceSize(), - .workspace_size_per_elem_c = sizeMapping.workspaceSizePerElemC, - .reduction_strategy = reductionStrat, - }; - skGrid = origami::streamk::select_grid_size( - origami_problem, - *(hipAMDGPU->analyticalHardware), - origami_config, - static_cast(pAMDGPU->skDynamicGrid), - pAMDGPU->skMaxCUs); + skGrid = origami::streamk::select_grid(x, + y, + z, + batch, + problem.transA(), + problem.transB(), + problem.a().elementBytes() * 8, + problem.b().elementBytes() * 8, + problem.c().elementBytes() * 8, + miDataType, + problem.workspaceSize(), + sizeMapping.macroTile.x, + sizeMapping.macroTile.y, + sizeMapping.depthU, + sizeMapping.matrixInstruction[0], + sizeMapping.matrixInstruction[1], + sizeMapping.matrixInstruction[2], + sizeMapping.workGroupMapping, + sizeMapping.workspaceSizePerElemC, + sizeMapping.CUOccupancy, + *(hipAMDGPU->analyticalHardware), + pAMDGPU->skDynamicGrid, + reductionStrat, + pAMDGPU->skMaxCUs); } // Limit the CUs Stream-K is launched on either max or the 
specified, // whichever is minimum. @@ -3267,11 +3272,11 @@ namespace TensileLite // For tree-reduction there are some limits for divisions to avoid overflow // If we hit one of the limits, fallback to DP size_t itersPerTile = problem.getItersPerTile(sizeMapping); - size_t itersPerWG = tiles * itersPerTile / skGrid; - if(itersPerTile >= 65536 || itersPerWG >= 65536 || (tiles * itersPerTile) >= 16777216) + size_t itersPerWG = tiles * itersPerTile / skGrid; + if(itersPerTile >=65536 || itersPerWG >= 65536 || (tiles * itersPerTile) >= 16777216) { - reductionStrat = origami::reduction_t::tree; - skGrid = tiles; + reductionStrat = ReductionType::Tree; + skGrid = tiles; } return skGrid; @@ -3281,8 +3286,7 @@ namespace TensileLite { size_t size = 0; - size_t tileSize - = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; + size_t tileSize = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; size += tileSize * skGrid; // Partials tile per WG // TODO batches // TODO round up for alignment? 
@@ -3295,13 +3299,8 @@ namespace TensileLite return x / ceil(x); } - ContractionSolution::Granularities - ContractionSolution::computeGranularities(Hardware const& hardware, - double M, - double N, - double K, - double NumBatches, - uint32_t autoGsuVal) const + ContractionSolution::Granularities ContractionSolution::computeGranularities( + Hardware const& hardware, double M, double N, double K, double NumBatches, uint32_t autoGsuVal) const { ContractionSolution::Granularities granularities; @@ -3405,8 +3404,7 @@ namespace TensileLite } double K = problem.boundSize(0); // TODO - fix for multiple summations - pp.granularities = ContractionSolution::computeGranularities( - hardware, M, N, K, NumBatches, calculateAutoGSU(problem, &hardware)); + pp.granularities = ContractionSolution::computeGranularities(hardware, M, N, K, NumBatches, calculateAutoGSU(problem, &hardware)); auto it = ideals.begin(); diff --git a/shared/origami/.clang-format b/shared/origami/.clang-format deleted file mode 100644 index 1158e834f77..00000000000 --- a/shared/origami/.clang-format +++ /dev/null @@ -1,34 +0,0 @@ -BasedOnStyle: Google -BinPackArguments: false -BinPackParameters: false -ColumnLimit: 100 -IndentWidth: 2 -BreakConstructorInitializers: BeforeComma -IncludeBlocks: Preserve - - -AlignAfterOpenBracket: Align -AlignConsecutiveAssignments: true -AlignOperands: true -AlignTrailingComments: true - -AllowAllArgumentsOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine : false -AllowShortBlocksOnASingleLine : true -AllowShortCaseLabelsOnASingleLine: true -AllowShortFunctionsOnASingleLine : All -AllowShortIfStatementsOnASingleLine: true -AllowShortLambdasOnASingleLine : All -AllowShortLoopsOnASingleLine: true -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakBeforeMultilineStrings : false -AlwaysBreakTemplateDeclarations: Yes - -BreakBeforeBinaryOperators: None -Cpp11BracedListStyle: true -IndentWrappedFunctionNames : false -KeepEmptyLinesAtTheStartOfBlocks : false 
-PointerAlignment: Left -ReflowComments : true -ExperimentalAutoDetectBinPacking: false -BreakBeforeBraces: Attach diff --git a/shared/origami/.gitignore b/shared/origami/.gitignore deleted file mode 100644 index 050a6c80a57..00000000000 --- a/shared/origami/.gitignore +++ /dev/null @@ -1,20 +0,0 @@ -# Build directory -build/ - -# CMake cache -CMakeCache.txt - -# IDE files -.vscode/ -.vs/ - -# Python bindings build -__pycache__/ -*.pyc -*.pyo -*.so - -# Temporary files -*.tmp -*.swp -*~ diff --git a/shared/origami/CMakeLists.txt b/shared/origami/CMakeLists.txt index 9a420a4ea86..6779f95a621 100644 --- a/shared/origami/CMakeLists.txt +++ b/shared/origami/CMakeLists.txt @@ -1,27 +1,5 @@ -################################################################################ -# -# MIT License -# -# Copyright 2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-# -################################################################################ +# Copyright Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT cmake_minimum_required(VERSION 3.24.4) @@ -46,12 +24,11 @@ option(ORIGAMI_BUILD_SHARED_LIBS "Build shared libraries." ${ORIGAMI_STANDALONE} option(ORIGAMI_ENABLE_PYTHON "Enable Python bindings." OFF) option(ORIGAMI_BUILD_TESTING "Enable Python binding tests." OFF) option(ORIGAMI_ENABLE_INSTALL "Configure origami installation" ON) -option(ORIGAMI_ENABLE_FETCH "Auto-fetch dependencies with FetchContent" ON) find_package(hip REQUIRED) if(ORIGAMI_BUILD_SHARED_LIBS OR (BUILD_SHARED_LIBS AND ORIGAMI_STANDALONE)) - add_library(origami SHARED) + add_library(origami SHARED) else() add_library(origami STATIC) endif() @@ -60,39 +37,34 @@ rocm_set_soversion(origami "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}") add_library(roc::origami ALIAS origami) -set_target_properties(origami PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(origami PROPERTIES + POSITION_INDEPENDENT_CODE ON +) target_compile_features(origami PUBLIC cxx_std_17) add_library(origami-headers INTERFACE) add_library(roc::origami-headers ALIAS origami-headers) -target_compile_features(origami-headers INTERFACE cxx_std_17) - -target_include_directories( - origami-headers INTERFACE $ - $ +target_include_directories(origami-headers + INTERFACE + $ + $ ) -target_sources( - origami-headers - INTERFACE $ - $ - $ - $ - $ - $ - $ +target_sources(origami-headers + INTERFACE + $ + $ + $ + $ ) target_link_libraries(origami PUBLIC roc::origami-headers) -target_sources( - origami - PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/gemm.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/hardware.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/log.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/origami.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/streamk.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/types.cpp" 
+target_sources(origami + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/gemm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/utils.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/streamk.cpp" ) target_link_libraries(origami PUBLIC hip::host) @@ -103,30 +75,32 @@ endif() if(ORIGAMI_BUILD_TESTING OR BUILD_TESTING) enable_testing() - + add_subdirectory(tests) if(ORIGAMI_ENABLE_PYTHON) find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - add_test(NAME origami_python_test COMMAND "${Python3_EXECUTABLE}" "origami_test.py" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" + add_test( + NAME origami_python_test + COMMAND "${Python3_EXECUTABLE}" "origami_test.py" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" ) - - add_test(NAME origami_python_grid_test COMMAND "${Python3_EXECUTABLE}" - "origami_grid_test.py" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" + + add_test( + NAME origami_python_grid_test + COMMAND "${Python3_EXECUTABLE}" "origami_grid_test.py" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" ) - set_tests_properties( - origami_python_test origami_python_grid_test + set_tests_properties(origami_python_test origami_python_grid_test PROPERTIES - ENVIRONMENT - "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/python:${CMAKE_CURRENT_SOURCE_DIR}/python:$ENV{PYTHONPATH}" + ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/python:${CMAKE_CURRENT_SOURCE_DIR}/python:$ENV{PYTHONPATH}" ) - set_tests_properties( - origami_python_test origami_python_grid_test PROPERTIES DEPENDS origami_python + set_tests_properties(origami_python_test origami_python_grid_test + PROPERTIES + DEPENDS origami_python ) endif() endif() @@ -135,49 +109,43 @@ if(ORIGAMI_ENABLE_INSTALL OR ORIGAMI_STANDALONE) rocm_install(TARGETS origami origami-headers) rocm_export_targets( - TARGETS - roc::origami - roc::origami-headers - DEPENDS - PACKAGE - hip - NAMESPACE - roc:: + TARGETS roc::origami roc::origami-headers + DEPENDS PACKAGE hip + NAMESPACE roc:: ) 
if(ORIGAMI_BUILD_TESTING OR BUILD_TESTING) - rocm_install(TARGETS origami-tests COMPONENT tests) - + rocm_install(TARGETS origami-tests + COMPONENT tests + ) + + rocm_install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/tests/origami_gtest.yaml" + DESTINATION "${CMAKE_INSTALL_BINDIR}" + COMPONENT tests + ) endif() rocm_install( - DIRECTORY - include/ - DESTINATION - "${CMAKE_INSTALL_INCLUDEDIR}" - COMPONENT - devel - FILES_MATCHING - PATTERN - "*.hpp" - PATTERN - "*.h" + DIRECTORY include/ + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" + COMPONENT devel + FILES_MATCHING PATTERN "*.hpp" PATTERN "*.h" ) configure_file( "${CMAKE_CURRENT_SOURCE_DIR}/cmake/origami-config.cmake.in" - "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" @ONLY + "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" + @ONLY ) rocm_install( - FILES "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" DESTINATION - "${CMAKE_INSTALL_LIBDIR}/cmake/origami" COMPONENT devel + FILES "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/origami" + COMPONENT devel ) set(BUILD_SHARED_LIBS ${ORIGAMI_BUILD_SHARED_LIBS}) - set(ORIGAMI_CONFIG_DIR "\${CPACK_PACKAGING_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" - CACHE PATH "Path placed into ldconfig file" - ) + set(ORIGAMI_CONFIG_DIR "\${CPACK_PACKAGING_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" CACHE PATH "Path placed into ldconfig file") rocm_create_package( NAME origami diff --git a/shared/origami/include/origami/gemm.hpp b/shared/origami/include/origami/gemm.hpp index 607b34fb3f4..18d427ea69d 100644 --- a/shared/origami/include/origami/gemm.hpp +++ b/shared/origami/include/origami/gemm.hpp @@ -1,197 +1,245 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the 
Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once -#include #include "origami/hardware.hpp" -#include "origami/types.hpp" - -namespace origami { - -/** - * @brief Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K - * tile. - * - * @param mt Macro tile dimensions - * @param mi Micro tile dimensions - * @return size_t Number of matrix instructions - */ -size_t compute_number_matrix_instructions(dim3_t mt, dim3_t mi); - -/** - * @brief Compute TF32 conversion overhead. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @return double Latency in cycles. 
- */ -static inline double compute_cvt_overhead(const problem_t& problem, - const hardware_t& hardware, - const config_t& config); -/** - * @brief Compute the latency to process a single macro-tile for the given problem and hardware. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @return size_t Latency in cycles. - */ -size_t compute_mt_compute_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config); - -/** - * @brief Check if MT fits in LDS - * - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param mt Macro tile dimensions - * @param a_dtype Data type of operand A - * @param b_dtype Data type of operand B - * @return bool True if MT fits in LDS, false otherwise - */ -bool check_lds_capacity(const hardware_t& hardware, - dim3_t mt, - data_type_t a_dtype, - data_type_t b_dtype); - -/** - * @brief Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. - * - * @param MT_M Macro tile dimension M - * @param MT_K Macro tile dimension K - * @return size_t Amount of data loaded from A - */ -size_t compute_A_loads(size_t MT_M, size_t MT_K); - -/** - * @brief Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. - * - * @param MT_N Macro tile dimension N - * @param MT_K Macro tile dimension K - * @return size_t Amount of data loaded from B - */ -size_t compute_B_loads(size_t MT_N, size_t MT_K); - -/** - * @brief A linear-estimation method for estimating L2-hitrate. - * - * @todo Parameterize this based on the space-filling curve algos. - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param splitting_factor - * @return double Predicted L2-hitrate. 
- */ -double estimate_l2_hit(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - std::size_t splitting_factor); - -/** - * @brief Estimate the MALL-hitrate (last-level cache.) - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param num_active_cus - * @param splitting_factor - * @return double Predicted MALL-hitrate. - */ -double estimate_mall_hit(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - std::size_t num_active_cus, - std::size_t splitting_factor); - -/** - * @brief Determine the memory latency per MT_M x MT_N x MT_K Macro Tile (L_MT). - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param num_active_cus - * @param splitting_factor - * @return double Latency in cycles. - */ -double compute_memory_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - std::size_t num_active_cus, - std::size_t splitting_factor); - -/** - * @brief Computes the latency to compute a K-COMPLETE tile. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param num_active_cus - * @param splitting_factor - * @return double Latency in cycles. - */ -double compute_tile_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - std::size_t num_active_cus, - std::size_t splitting_factor); - -/** - * @brief Computes the latency per K-complete MT wave. - * A wave is defined as the time it takes for one CU to complete one - * K-complete output tile - * - * @param problem Problem description (M, N, K, etc.) 
- * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param num_active_cus - * @param splitting_factor - * @return double Latency in cycles. - */ -double compute_timestep_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - std::size_t num_active_cus, - std::size_t splitting_factor); - -/** - * @brief Compute the total latency of a gemm based on the latency of one wave multiplied by the - * number of waves A wave is defined as the time it takes for one CU to complete one K-complete - * output tile. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param max_cus - * @return double Latency in cycles. - */ -double compute_total_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t max_cus); - -} // namespace origami +#include + +namespace origami +{ + // Placeholder for compute_reuse_in_block_gemm function. + // TODO move over L2 hit rate simulation for tie-breaking. 
+ double compute_reuse_in_block_gemm(size_t grid_m, + size_t grid_n, + size_t grid_k, + size_t A_size, + size_t B_size, + size_t C_size, + size_t nproc, + size_t capacity, + const std::vector& radix, + bool print_radix, + bool print_output, + size_t max_timesteps, + size_t max_iters); + + // Compute + std::tuple compute_CU_occupancy(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + int WGM, + size_t workspace_size, + size_t workspace_size_per_elem_c, + int occupancy, + int dynamic_grid_version, + size_t split, + size_t max_cus = 0); + + /* ---------------------------------------------------------------------------------------- */ + /* Compute-related functions */ + /* ---------------------------------------------------------------------------------------- */ + // Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. + size_t compute_number_matrix_instructions(const hardware_t& hardware, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K); + + // Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). 
+ size_t compute_mt_compute_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + data_type_t mi_datatype); + + /* ---------------------------------------------------------------------------------------- */ + /* Memory-related functions */ + /* ---------------------------------------------------------------------------------------- */ + // Check if MT fits in LDS + bool check_lds_capacity( + const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, size_t element_size_out); + + // Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. + size_t compute_A_loads(size_t MT_M, size_t MT_K); + + // Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. + size_t compute_B_loads(size_t MT_N, size_t MT_K); + + // Computes total data loads per CU per MT from A and B + // Reads happen every MT, Writes happen every K-complete tile. + size_t compute_cu_loads(size_t MT_M, size_t MT_N, size_t MT_K); + + // Estimates the l2 hit-rate + double estimate_l2_hit(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size, + int WGM, + size_t splittingFactor); + + // Estimates the mall hit-rate + double estimate_mall_hit(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor); + + // Determine the memory latency per MT_MxMT_NxMT_K Macro Tile (L_MT). 
+ double compute_memory_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t mx_block_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor); + + /* ---------------------------------------------------------------------------------------- */ + /* Tile-related functions */ + /* ---------------------------------------------------------------------------------------- */ + // Computes the latency to compute a K-COMPLETE tile. + double compute_tile_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int occupancy, + size_t numActiveCUs, + size_t splittingFactor); + + // Computes the latency per K-complete MT wave. 
+ // A wave is defined as : The time it takes for one CU to complete one K-complete output tile + double compute_wave_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int occupancy, + size_t numActiveCUs, + size_t splittingFactor); + + // Compute the total latency of a gemm based on the latency of one wave multiplied by the number of waves + // A wave is defined as : The time it takes for one CU to complete one K-complete output tile + double compute_total_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int non_temporal_a = 0, + int non_temporal_b = 0, + int occupancy = 1, + size_t split = 0, + size_t max_cus = 0); + + // Compute the performance from the latency. + // IMPORTANT : This program is NOT meant to be an analytical model for performance, but rather a way to rank different macro tile sizes. + // These performance values could be wildly inaccurate in absolute terms, but will often result in the correct ranking of MTs in relative terms.
+ double compute_perf_gflops(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + int WGM, + size_t max_cus = 0); +} // namespace origami diff --git a/shared/origami/include/origami/hardware.hpp b/shared/origami/include/origami/hardware.hpp index 0fe73d06cfe..f5dd7edb4ca 100644 --- a/shared/origami/include/origami/hardware.hpp +++ b/shared/origami/include/origami/hardware.hpp @@ -1,252 +1,763 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once -#include +#include #include -#include #include #include #include -#include "origami/types.hpp" - -namespace origami { -/** - * @brief Represents hardware characteristics and capabilities of GPU architectures. - * - */ -class hardware_t { - public: - /** - * @brief Enumeration of supported GPU architectures. - * - */ - enum class architecture_t { gfx90a, gfx942, gfx950, gfx1201, gfx1100, gfx1151, Count }; - - /** - * @brief Convert architecture name string to architecture_t enum. - * - * @param str Architecture name as string (e.g., "gfx90a", "gfx942") - * @return architecture_t Corresponding enum value, or Count if not recognized - */ - static constexpr architecture_t arch_name_to_enum(std::string_view str) noexcept { - if (str == "gfx90a") return architecture_t::gfx90a; - if (str == "gfx942") return architecture_t::gfx942; - if (str == "gfx950") return architecture_t::gfx950; - if (str == "gfx1201") return architecture_t::gfx1201; - if (str == "gfx1100") return architecture_t::gfx1100; - if (str == "gfx1151") return architecture_t::gfx1151; - return architecture_t::Count; - } - - /** - * @brief Architecture-specific constants for memory and compute characteristics. 
- * - */ - struct architecture_constants { - size_t num_xcds; ///< Number of XCDs (XCD = XGMI Complex Die) - double mem1_perf_ratio; - double mem2_perf_ratio; - double mem3_perf_ratio; - size_t parallel_mi_cu; ///< Number of parallel matrix instructions per compute unit - std::tuple - mem_bw_per_wg_coefficients; ///< Memory bandwidth coefficients per workgroup - double mem_clock_ratio; ///< Memory clock ratio relative to compute clock - - constexpr architecture_constants(size_t num_xcds, - double mem1_perf_ratio, - double mem2_perf_ratio, - double mem3_perf_ratio, - size_t parallel_mi_cu, - std::tuple mem_bw_per_wg_coefficients, - double mem_clock_ratio) // Obtained through microbenchmarking - : num_xcds(num_xcds) - , mem1_perf_ratio(mem1_perf_ratio) - , mem2_perf_ratio(mem2_perf_ratio) - , mem3_perf_ratio(mem3_perf_ratio) - , parallel_mi_cu(parallel_mi_cu) - , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) - , mem_clock_ratio(mem_clock_ratio) {} - }; - - /** - * @brief Get architecture-specific constants for a given architecture. - * - * Returns the pre-configured constants (memory performance ratios, bandwidth - * coefficients, etc.) for the specified architecture. These values are - * determined through microbenchmarking. 
- * - * @param arch Architecture enum value - * @return architecture_constants Constants for the specified architecture - */ - static constexpr architecture_constants get_arch_constants(architecture_t arch) { - switch (arch) { - case architecture_t::gfx90a: - return {1, 5.5, 1.21875121875121875122 * 1.2, 1.2, 4, std::make_tuple(0, 0.03, 0), 1.5}; - case architecture_t::gfx942: - return {8, 17, 1.21875121875121875122 * 6, 4, 4, std::make_tuple(0, 0.015, 0), 1.5}; - case architecture_t::gfx950: - return {8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(0, 0.008, 0), 1.5}; - case architecture_t::gfx1201: - return {1, 5.74, 1.21875121875121875122 * 2.41, 0.464, 2, std::make_tuple(0, 0.17, 0), 1.5}; - case architecture_t::gfx1100: - return {1, 7.12, 1.21875121875121875122 * 3.48, 0.732, 2, std::make_tuple(0, 0.11, 0), 1.5}; - case architecture_t::gfx1151: - return {1, 2.47, 1.21875121875121875122 * 0.93, 0.215, 2, std::make_tuple(0, 0.22, 0), 1.5}; - default: return {0, 0, 0, 0, 0, std::make_tuple(0, 0, 0), 0}; +namespace origami +{ + enum class data_type_t : int + { + Float, + Double, + ComplexFloat, + ComplexDouble, + Half, + Int8x4, + Int32, + BFloat16, + Int8, + Int4, + Int64, + XFloat32, + Float8_fnuz, + BFloat8_fnuz, + Float8BFloat8_fnuz, + BFloat8Float8_fnuz, + Float8, + BFloat8, + Float8BFloat8, + BFloat8Float8, + Float6, + BFloat6, + Float4, + Count, + None = Count + }; + + inline data_type_t int_to_data_type(int dt) + { + return (data_type_t)dt; + } + + inline int data_type_to_bits(data_type_t type) + { + switch(type) + { + case data_type_t::Float: + return 32; + case data_type_t::Double: + return 64; + case data_type_t::ComplexFloat: + return 64; + case data_type_t::ComplexDouble: + return 128; + case data_type_t::Half: + return 16; + case data_type_t::Int8x4: + return 32; + case data_type_t::Int32: + return 32; + case data_type_t::BFloat16: + return 16; + case data_type_t::Int8: + return 8; + case data_type_t::Int4: + return 4; + case 
data_type_t::Int64: + return 64; + case data_type_t::XFloat32: + return 32; + case data_type_t::Float8_fnuz: + return 8; + case data_type_t::BFloat8_fnuz: + return 8; + case data_type_t::Float8BFloat8_fnuz: + return 8; + case data_type_t::BFloat8Float8_fnuz: + return 8; + case data_type_t::Float8: + return 8; + case data_type_t::BFloat8: + return 8; + case data_type_t::Float8BFloat8: + return 8; + case data_type_t::BFloat8Float8: + return 8; + case data_type_t::Float6: + return 6; + case data_type_t::BFloat6: + return 6; + case data_type_t::Float4: + return 4; + default: + return -1; // Invalid type + } } - } - - /** - * @brief Map of matrix instruction latencies by architecture. - * - */ - static const std::unordered_map> - INSTRUCTION_MAP; - - architecture_t arch; ///< GPU architecture type - size_t N_CU; ///< Number of Compute Units - size_t lds_capacity; ///< Capacity of Local Data Share (LDS) in bytes - double mem1_perf_ratio; - double mem2_perf_ratio; - double mem3_perf_ratio; - size_t L2_capacity; ///< Capacity of L2 cache in bytes - size_t CU_per_L2; ///< Number of compute units per L2 cache domain - double compute_clock_ghz; ///< Compute clock frequency in GHz - size_t parallel_mi_cu; ///< Number of parallel matrix instructions per compute unit - std::tuple - mem_bw_per_wg_coefficients; ///< Memory bandwidth coefficients per workgroup - size_t NUM_XCD; ///< Number of XCDs (XGMI Complex Die) - - /** - * @brief Construct hardware_t with explicit parameters. 
- * - * @param arch GPU architecture type - * @param N_CU Number of compute units - * @param lds_capacity LDS capacity in bytes - * @param NUM_XCD Number of XCDs - * @param mem1_perf_ratio Memory level 1 performance ratio - * @param mem2_perf_ratio Memory level 2 performance ratio - * @param mem3_perf_ratio Memory level 3 performance ratio - * @param L2_capacity L2 cache capacity in bytes - * @param compute_clock_ghz Compute clock frequency in GHz - * @param parallel_mi_cu Number of parallel matrix instructions per CU - * @param mem_bw_per_wg_coefficients Memory bandwidth coefficients per workgroup - */ - hardware_t(architecture_t arch, - size_t N_CU, - size_t lds_capacity, - size_t NUM_XCD, - double mem1_perf_ratio, - double mem2_perf_ratio, - double mem3_perf_ratio, - size_t L2_capacity, - double compute_clock_ghz, - size_t parallel_mi_cu, - std::tuple mem_bw_per_wg_coefficients); - - /** - * @brief Construct hardware_t from HIP device properties. - * - * Automatically determines architecture and extracts hardware parameters - * from the provided HIP device properties structure. - * - * @param properties HIP device properties structure - */ - hardware_t(hipDeviceProp_t properties); - - /** - * @brief Copy constructor. - * - * @param other Another hardware_t instance to copy from - */ - hardware_t(const hardware_t& other); - - /** - * @brief Create hardware_t instance from HIP device properties. - * - * - * @param properties HIP device properties structure - * @return hardware_t Configured hardware instance - */ - static hardware_t get_hardware_for_properties(hipDeviceProp_t properties); - - /** - * @brief Create hardware_t instance for a specific HIP device. - * - * Queries the specified HIP device and creates a hardware_t instance - * with the appropriate architecture and parameters. 
- * - * @param deviceId HIP device ID - * @return hardware_t Configured hardware instance for the device - */ - static hardware_t get_hardware_for_device(int deviceId); - - /** - * @brief Check if the hardware described by properties is supported. - * - * Determines whether the GPU architecture represented by the device - * properties is supported by the analytical model. - * - * @param properties HIP device properties structure - * @return true if the architecture is supported, false otherwise - */ - static bool is_hardware_supported(hipDeviceProp_t properties); - - /** - * @brief Print hardware details to stdout. - * - */ - void print() const; - - /** - * @brief Get matrix instruction latency for given instruction parameters. - * - * - * @param MI_M Matrix instruction M dimension - * @param MI_N Matrix instruction N dimension - * @param MI_K Matrix instruction K dimension - * @param mi_input_type Input data type for the matrix instruction - * @return size_t Instruction latency in cycles, or 0 if not found - */ - size_t get_mi_latency(size_t MI_M, size_t MI_N, size_t MI_K, data_type_t mi_input_type) const; - - private: - /** - * @brief Extract substring before the first colon character. - * - * Helper function used for parsing architecture names from device - * property strings (e.g., extracting "gfx90a" from "gfx90a:..."). 
- * - * @param input Input string to parse - * @return std::string Substring before the first colon, or entire string if no colon found - */ - static std::string get_before_first_colon(const std::string& input); -}; -} // namespace origami + + inline std::string to_string(data_type_t type) + { + switch(type) + { + case data_type_t::Float: + return "Float"; + case data_type_t::Double: + return "Double"; + case data_type_t::ComplexFloat: + return "ComplexFloat"; + case data_type_t::ComplexDouble: + return "ComplexDouble"; + case data_type_t::Half: + return "Half"; + case data_type_t::Int8x4: + return "Int8x4"; + case data_type_t::Int32: + return "Int32"; + case data_type_t::BFloat16: + return "BFloat16"; + case data_type_t::Int8: + return "Int8"; + case data_type_t::Int4: + return "Int4"; + case data_type_t::Int64: + return "Int64"; + case data_type_t::XFloat32: + return "XFloat32"; + case data_type_t::Float8_fnuz: + return "Float8_fnuz"; + case data_type_t::BFloat8_fnuz: + return "BFloat8_fnuz"; + case data_type_t::Float8BFloat8_fnuz: + return "Float8BFloat8_fnuz"; + case data_type_t::BFloat8Float8_fnuz: + return "BFloat8Float8_fnuz"; + case data_type_t::Float8: + return "Float8"; + case data_type_t::BFloat8: + return "BFloat8"; + case data_type_t::Float8BFloat8: + return "Float8BFloat8"; + case data_type_t::BFloat8Float8: + return "BFloat8Float8"; + case data_type_t::Float6: + return "Float6"; + case data_type_t::BFloat6: + return "BFloat6"; + case data_type_t::Float4: + return "Float4"; + default: + return "Invalid"; + } + return "Invalid"; + } + + inline data_type_t string_to_data_type(std::string s) + { + if (s == "f32") + return data_type_t::Float; + if (s == "c32") + return data_type_t::ComplexFloat; + if (s == "c64") + return data_type_t::ComplexDouble; + if (s == "f64") + return data_type_t::Double; + if (s == "f16") + return data_type_t::Half; + if (s == "i32") + return data_type_t::Int32; + if (s == "bf16") + return data_type_t::BFloat16; + if (s == "i8") 
+ return data_type_t::Int8; + if (s == "i4") + return data_type_t::Int4; + if (s == "xf32") + return data_type_t::XFloat32; + if (s == "f8") + return data_type_t::Float8; + if (s == "bf8") + return data_type_t::BFloat8; + if (s == "f6") + return data_type_t::Float6; + if (s == "bf6") + return data_type_t::BFloat6; + if (s == "f4") + return data_type_t::Float4; + return data_type_t::None; + } + + struct matrix_instruction + { + size_t MI_M; + size_t MI_N; + size_t MI_K; + data_type_t mi_input_type; + + matrix_instruction() + : MI_M(0) + , MI_N(0) + , MI_K(0) + , mi_input_type(data_type_t::Float) + { + } + + matrix_instruction(size_t m, size_t n, size_t k, data_type_t mi_input_type) + : MI_M(m) + , MI_N(n) + , MI_K(k) + , mi_input_type(mi_input_type) + { + } + + matrix_instruction(const matrix_instruction& other) + : MI_M(other.MI_M) + , MI_N(other.MI_N) + , MI_K(other.MI_K) + , mi_input_type(other.mi_input_type) + { + } + + bool operator<(const matrix_instruction& other) const + { + return std::tie(MI_M, MI_N, MI_K, mi_input_type) + < std::tie(other.MI_M, other.MI_N, other.MI_K, other.mi_input_type); + } + + bool operator==(const matrix_instruction& other) const + { + return MI_M == other.MI_M && MI_N == other.MI_N && MI_K == other.MI_K + && mi_input_type == other.mi_input_type; + } + + std::size_t hash() const + { + return std::hash()(MI_M) ^ std::hash()(MI_N) + ^ std::hash()(MI_K) ^ std::hash()(mi_input_type); + } + }; +} + +// Specialize std::hash for the matrix_instruction struct to use it as an unordered_map key. 
+namespace std +{ + template <> + struct hash + { + std::size_t operator()(const origami::matrix_instruction& k) const + { + return k.hash(); + } + }; +} + +namespace origami +{ + class hardware_t + { + public: + enum class architecture_t + { + gfx90a, + gfx942, + gfx950, + gfx1201, + gfx1100, + gfx1151, + Count + }; + + static architecture_t arch_name_to_enum(const std::string& str) + { + static const std::unordered_map str_to_enum_map + = {{"gfx90a", architecture_t::gfx90a}, + {"gfx942", architecture_t::gfx942}, + {"gfx950", architecture_t::gfx950}, + {"gfx1201", architecture_t::gfx1201}, + {"gfx1100", architecture_t::gfx1100}, + {"gfx1151", architecture_t::gfx1151}}; + + auto it = str_to_enum_map.find(str); + if(it != str_to_enum_map.end()) + { + return it->second; + } + else + { + return architecture_t::Count; + } + } + + struct architecture_constants + { + size_t num_xcds; + double mem1_perf_ratio; + double mem2_perf_ratio; + double mem3_perf_ratio; + size_t parallel_mi_cu; + std::tuple mem_bw_per_wg_coefficients; + double mem_clock_ratio; + + architecture_constants(size_t num_xcds, + double mem1_perf_ratio, + double mem2_perf_ratio, + double mem3_perf_ratio, + size_t parallel_mi_cu, + std::tuple mem_bw_per_wg_coefficients, + double mem_clock_ratio) //Obtained through microbenchmarking + : num_xcds(num_xcds) + , mem1_perf_ratio(mem1_perf_ratio) + , mem2_perf_ratio(mem2_perf_ratio) + , mem3_perf_ratio(mem3_perf_ratio) + , parallel_mi_cu(parallel_mi_cu) + , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) + , mem_clock_ratio(mem_clock_ratio) + { + } + }; + + inline static const std::unordered_map ARCH_CONSTANT_MAP + = {{hardware_t::architecture_t::gfx90a, + hardware_t::architecture_constants( + 1, 5.5, 1.21875121875121875122 * 1.2, 1.2, 4, std::make_tuple(0, 0.03, 0), 1.5)}, + {hardware_t::architecture_t::gfx942, + hardware_t::architecture_constants( + 8, 17, 1.21875121875121875122 * 6, 4, 4, std::make_tuple(0, 0.015, 0), 1.5)}, + 
{hardware_t::architecture_t::gfx950, + // hardware_t::architecture_constants( + // 8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(-0.000013, 0.007070, 0.027355), 1.5)}}; + hardware_t::architecture_constants( + 8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(0, 0.008, 0), 1.5)}, + {hardware_t::architecture_t::gfx1201, + hardware_t::architecture_constants( + 1, 5.74, 1.21875121875121875122 * 2.41, 0.464, 2, std::make_tuple(0, 0.17, 0), 1.5)}, + {hardware_t::architecture_t::gfx1100, + hardware_t::architecture_constants( + 1, 7.12, 1.21875121875121875122 * 3.48, 0.732, 2, std::make_tuple(0, 0.11, 0), 1.5)}, + {hardware_t::architecture_t::gfx1151, + hardware_t::architecture_constants( + 1, 2.47, 1.21875121875121875122 * 0.93, 0.215, 2, std::make_tuple(0, 0.22, 0), 1.5)}}; + + inline static const std::unordered_map> INSTRUCTION_MAP + = {{hardware_t::architecture_t::gfx90a, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + // TODO ComplexFloat + // TODO ComplexDouble + // F16 + {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 64}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // 
v_mfma_f32_4x4x4_16b_f16 + // BF16 + {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 + {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 + // I8 + {matrix_instruction(32, 32, 8, data_type_t::Int8), 64}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_mfma_f32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + // XF32 + {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 + {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 + {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + }}, + {hardware_t::architecture_t::gfx942, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + // TODO ComplexFloat + 
// TODO ComplexDouble + // F16 + {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 + // BF16 + {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 + {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 + // F8 + {matrix_instruction(32, 32, 16, data_type_t::Float8_fnuz), 32}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(16, 16, 32, data_type_t::Float8_fnuz), 16}, // v_mfma_f32_16x16x32_f8 + // BF8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8 + // F8B8 + {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_f8_bf8 + {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_f8_bf8 + // B8F8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8_f8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8_f8 + // I8 + {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_i32_32x32x16_i8 + {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, //
v_mfma_i32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + // XF32 + {matrix_instruction(32, 32, 4, data_type_t::XFloat32), 32}, // v_mfma_f32_32x32x4_xf32 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 16}, // v_mfma_f32_16x16x8_xf32 + }}, + {hardware_t::architecture_t::gfx950, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 64}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + // TODO ComplexFloat + // TODO ComplexDouble + // F16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(32, 32, 16, data_type_t::Half), 32}, // v_mfma_f32_32x32x16_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(16, 16, 32, data_type_t::Half), 16}, // v_mfma_f32_16x16x32_f16 + // BF16 + {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(32, 32, 16, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x16_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(16, 16, 32, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x32_bf16 + // F8 + {matrix_instruction(32, 32, 64, data_type_t::Float8), 64}, // v_mfma_f32_32x32x64_f8 + {matrix_instruction(32, 32, 16, data_type_t::Float8), 32}, // v_mfma_f32_32x32x16_f8
+ {matrix_instruction(16, 16, 128, data_type_t::Float8), 32}, // v_mfma_f32_16x16x128_f8 + {matrix_instruction(16, 16, 32, data_type_t::Float8), 16}, // v_mfma_f32_16x16x32_f8 + // BF8 + {matrix_instruction(32, 32, 64, data_type_t::BFloat8), 64}, // v_mfma_f32_32x32x64_bf8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8), 32}, // v_mfma_f32_32x32x16_bf8 + {matrix_instruction(16, 16, 128, data_type_t::BFloat8), 32}, // v_mfma_f32_16x16x128_bf8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8), 16}, // v_mfma_f32_16x16x32_bf8 + // F8B8 + {matrix_instruction(32, 32, 64, data_type_t::Float8BFloat8), 64}, // v_mfma_f32_32x32x64_f8_bf8 + {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_32x32x16_f8_bf8 + {matrix_instruction(16, 16, 128, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_16x16x128_f8_bf8 + {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8), 16}, // v_mfma_f32_16x16x32_f8_bf8 + // B8F8 + {matrix_instruction(32, 32, 64, data_type_t::BFloat8Float8), 64}, // v_mfma_f32_32x32x64_bf8_f8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_32x32x16_bf8_f8 + {matrix_instruction(16, 16, 128, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_16x16x128_bf8_f8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8), 16}, // v_mfma_f32_16x16x32_bf8_f8 + // I8 + {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_i32_32x32x16_i8 + {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_i32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + // XF32 + {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 + {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 +
{matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + // F6 + {matrix_instruction(32, 32, 64, data_type_t::Float6), 32}, // v_mfma_f32_32x32x64_f6 + {matrix_instruction(16, 16, 128, data_type_t::Float6), 16}, // v_mfma_f32_16x16x128_f6 + // BF6 + {matrix_instruction(32, 32, 64, data_type_t::BFloat6), 32}, // v_mfma_f32_32x32x64_bf6 + {matrix_instruction(16, 16, 128, data_type_t::BFloat6), 16}, // v_mfma_f32_16x16x128_bf6 + // F4 + {matrix_instruction(32, 32, 64, data_type_t::Float4), 32}, // v_mfma_f32_32x32x64_f4 + {matrix_instruction(16, 16, 128, data_type_t::Float4), 16}, // v_mfma_f32_16x16x128_f4 + // DOT2 + {matrix_instruction( 1, 1, 64, data_type_t::Half), 16}, // V_DOT2_F32_F16 + {matrix_instruction( 1, 1, 64, data_type_t::BFloat16), 16}, // V_DOT2_F32_BF16 + }}, + {hardware_t::architecture_t::gfx1201, + { + // F16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_wmma_f16_16x16x16_f16/v_wmma_f32_16x16x16_f16 + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_wmma_bf16_16x16x16_bf16/v_wmma_f32_16x16x16_bf16 + // F8 + {matrix_instruction(16, 16, 16, data_type_t::Float8), 8}, // v_wmma_f32_16x16x16_fp8_fp8 + // F8B8 + {matrix_instruction(16, 16, 16, data_type_t::Float8BFloat8), 8}, // v_wmma_f32_16x16x16_fp8_bf8 + // B8F8 + {matrix_instruction(16, 16, 16, data_type_t::BFloat8Float8), 8}, // v_wmma_f32_16x16x16_bf8_fp8 + // B8 + {matrix_instruction(16, 16, 16, data_type_t::BFloat8), 8}, // v_wmma_f32_16x16x16_bf8_bf8 + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 8}, // v_wmma_i32_16x16x16_iu8 + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 8}, // v_wmma_i32_16x16x16_iu4 + {matrix_instruction(16, 16, 32, data_type_t::Int4), 8}, // v_wmma_i32_16x16x32_iu4 + }}, + {hardware_t::architecture_t::gfx1100, + { + // F16 + {matrix_instruction(16, 16, 16, 
data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 + }}, + {hardware_t::architecture_t::gfx1151, + { + // F16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 + }}}; + + architecture_t arch; + size_t N_CU; // Number of Compute Units + size_t lds_capacity; // Capacity of LDS + double mem1_perf_ratio; + double mem2_perf_ratio; + double mem3_perf_ratio; + size_t L2_capacity; // Capacity of L2 in bytes + size_t CU_per_L2; // Number of compute units per L2 domain + double compute_clock_ghz; + size_t parallel_mi_cu; // The number of parallel MI in a CU + std::tuple mem_bw_per_wg_coefficients; + size_t NUM_XCD; + + hardware_t(architecture_t arch, + size_t N_CU, + size_t lds_capacity, + size_t NUM_XCD, + double mem1_perf_ratio, + double mem2_perf_ratio, + double mem3_perf_ratio, + size_t L2_capacity, + double compute_clock_ghz, + size_t parallel_mi_cu, + std::tuple mem_bw_per_wg_coefficients) + : arch(arch) + , N_CU(N_CU) + , lds_capacity(lds_capacity) + , mem1_perf_ratio(mem1_perf_ratio) + , mem2_perf_ratio(mem2_perf_ratio) + , mem3_perf_ratio(mem3_perf_ratio) + , L2_capacity(L2_capacity) + , CU_per_L2(N_CU / NUM_XCD) + , compute_clock_ghz(compute_clock_ghz) + , parallel_mi_cu(parallel_mi_cu) + , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) + , 
NUM_XCD(NUM_XCD) + { + } + + hardware_t(hipDeviceProp_t properties) + : hardware_t(get_hardware_for_properties(properties)) + { + } + + hardware_t(const hardware_t& other) + : arch(other.arch) + , N_CU(other.N_CU) + , lds_capacity(other.lds_capacity) + , mem1_perf_ratio(other.mem1_perf_ratio) + , mem2_perf_ratio(other.mem2_perf_ratio) + , mem3_perf_ratio(other.mem3_perf_ratio) + , L2_capacity(other.L2_capacity) + , CU_per_L2(other.CU_per_L2) + , compute_clock_ghz(other.compute_clock_ghz) + , parallel_mi_cu(other.parallel_mi_cu) + , mem_bw_per_wg_coefficients(other.mem_bw_per_wg_coefficients) + , NUM_XCD(other.NUM_XCD) + { + } + + static hardware_t get_hardware_for_properties(hipDeviceProp_t properties) + { + auto arch_name = get_before_first_colon(properties.gcnArchName); + auto arch_enum = arch_name_to_enum(arch_name); + auto it = ARCH_CONSTANT_MAP.find(arch_enum); + if(it == ARCH_CONSTANT_MAP.end()) + { + throw std::runtime_error( + "Attempting to retrieve hardware constants for unsupported architecture: " + + arch_name); // Could also return default values here. 
+ } + auto constants = it->second; + return hardware_t(arch_enum, + properties.multiProcessorCount, + properties.sharedMemPerBlock, + constants.num_xcds, + 1e9 * constants.mem1_perf_ratio / properties.clockRate, + 1e9 * constants.mem2_perf_ratio + / (properties.memoryClockRate * constants.mem_clock_ratio), + 1e9 * constants.mem3_perf_ratio / properties.memoryClockRate, + properties.l2CacheSize, + properties.clockRate / 1e6, + constants.parallel_mi_cu, + constants.mem_bw_per_wg_coefficients); + } + + static hardware_t get_hardware_for_device(int deviceId) + { + hipDeviceProp_t prop; + hipError_t e = hipGetDeviceProperties(&prop, deviceId); + if(e) + { + throw std::runtime_error(hipGetErrorString(e)); + } + return get_hardware_for_properties(prop); + } + + static bool is_hardware_supported(hipDeviceProp_t properties) + { + auto arch_name = get_before_first_colon(properties.gcnArchName); + auto arch_enum = arch_name_to_enum(arch_name); + auto it = ARCH_CONSTANT_MAP.find(arch_enum); + return it != ARCH_CONSTANT_MAP.end(); + } + + // Function to print hardware details + void print() const + { + std::cout << "================== Hardware Configuration ==================\n"; + std::cout << "Number of CUs (N_CU) : " << N_CU << "\n"; + std::cout << "LDS capacity : " << lds_capacity << " bytes\n"; + std::cout << "mem1_perf_ratio : " << mem1_perf_ratio << "\n"; + std::cout << "mem2_perf_ratio : " << mem2_perf_ratio << "\n"; + std::cout << "mem3_perf_ratio : " << mem3_perf_ratio << "\n"; + std::cout << "L2 Cache capacity : " << L2_capacity << " bytes\n"; + std::cout << "CUs per L2 domain : " << CU_per_L2 << "\n"; + std::cout << "Compute clock (GHz) : " << compute_clock_ghz << "\n"; + std::cout << "Parallel MI/CU : " << parallel_mi_cu << "\n"; + std::cout << "Number of XCDs (NUM_XCD) : " << NUM_XCD << "\n"; + std::cout << "mem_bw_per_wg_coefficients: " << std::get<0>(mem_bw_per_wg_coefficients) << ", " + << std::get<1>(mem_bw_per_wg_coefficients) << ", " + << 
std::get<2>(mem_bw_per_wg_coefficients) << "\n\n"; + + std::cout << "------------------ Instruction Map -------------------------\n"; + // Loop over the instruction_map and print each entry + for(const auto& kv : INSTRUCTION_MAP.at(arch)) + { + const auto& key = kv.first; + const auto& L_MI = kv.second; + + std::cout << "Instruction: MI_M=" << key.MI_M << ", MI_N=" << key.MI_N + << ", MI_K=" << key.MI_K << ", mi_input_type=" << to_string(key.mi_input_type) + << " bytes\n" + << " -> Latency (L_MI): " << L_MI << "\n"; + } + std::cout << "===========================================================\n"; + } + // Debug tracking info + mutable std::unordered_map debug_info; + + static bool is_debug_enabled() + { + static bool debugEnvVar = read_debug_env_var(); //Used to cache the read. + return debugEnvVar; + } + + static bool is_heuristics_enabled() + { + static bool heuristicsEnvVar = read_heuristics_env_var(); //Used to cache the read. + return heuristicsEnvVar; + } + + void log_debug(const std::string& key, const std::string& value) const + { + debug_info[key] = value; + } + + void log_debug(const std::string& key, double value) const + { + debug_info[key] = std::to_string(value); + } + + void clear_debug() const + { + debug_info.clear(); + } + + void print_debug_info() const + { + std::cout << "=== Hardware Debug Info ===\n"; + for(const auto& [key, val] : debug_info) + { + std::cout << key << ": " << val << "\n"; + } + std::cout << "===========================\n"; + } + + size_t get_mi_latency(size_t MI_M, size_t MI_N, size_t MI_K, data_type_t mi_input_type) const + { + const auto& instruction_map = INSTRUCTION_MAP.at(arch); + auto key = matrix_instruction(MI_M, MI_N, MI_K, mi_input_type); + + auto it = instruction_map.find(key); + if(it != instruction_map.end()) + { + return it->second / parallel_mi_cu; + } + else + { + std::cerr << "Warning: Latency not found for MI_M=" << MI_M << ", MI_N=" << MI_N + << ", MI_K=" << MI_K << ", mi_input_type=" << 
to_string(mi_input_type) + << ". Returning latency value of 32 (really slow).\n"; + return 32 / parallel_mi_cu; // Default latency if instruction is not found + } + } + + private: + static std::string get_before_first_colon(const std::string& input) + { + size_t pos = input.find(':'); + if(pos != std::string::npos) + { + return input.substr(0, pos); + } + return input; // Return the whole string if ':' is not found + } + + // Helper function to read the debug environment variable + static bool read_debug_env_var() + { + const char* env = std::getenv("ANALYTICAL_GEMM_DEBUG"); + return env && std::string(env) == "1"; + } + + // Helper function to read the heuristics environment variable + static bool read_heuristics_env_var() + { + const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); + return !(env && std::string(env) == "0"); + } + }; +} // namespace origami diff --git a/shared/origami/include/origami/log.hpp b/shared/origami/include/origami/log.hpp deleted file mode 100644 index 62d54f2392a..00000000000 --- a/shared/origami/include/origami/log.hpp +++ /dev/null @@ -1,170 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -namespace origami { - -/** - * @brief Logger for collecting and exporting analytical metrics in JSON format. - * - * Provides a single templated logging function that stores key-value pairs - * and can export them as JSON. The caller is responsible for checking if - * debug is enabled before calling log(). - */ -class logger_t { - public: - /** - * @brief Default constructor (constexpr-compatible). - * The map is lazily allocated on first use. - */ - constexpr logger_t() : metrics_(nullptr) {} - - /** - * @brief Destructor - */ - ~logger_t() = default; - - /** - * @brief Copy constructor. - */ - logger_t(const logger_t& other) { - if (other.metrics_) { - metrics_ = std::make_unique>(*other.metrics_); - } else { - metrics_ = nullptr; - } - } - - /** - * @brief Move constructor. - */ - logger_t(logger_t&& other) noexcept = default; - - /** - * @brief Copy assignment operator. - */ - logger_t& operator=(const logger_t& other) { - if (this != &other) { - if (other.metrics_) { - metrics_ = std::make_unique>(*other.metrics_); - } else { - metrics_ = nullptr; - } - } - return *this; - } - - /** - * @brief Move assignment operator. - */ - logger_t& operator=(logger_t&& other) noexcept = default; - - /** - * @brief Ensure metrics map is allocated. 
- */ - void ensure_metrics() const { - if (!metrics_) { metrics_ = std::make_unique>(); } - } - - /** - * @brief Log a key-value pair. - * - * @tparam T Type of the value (must be convertible to JSON-compatible string) - * @param key The metric key - * @param value The metric value - */ - template - void log(const std::string& key, const T& value) { - ensure_metrics(); - (*metrics_)[key] = to_json_string(value); - } - - /** - * @brief Clear all logged metrics. - */ - void clear() { - if (metrics_) { metrics_->clear(); } - } - - /** - * @brief Print all metrics as JSON to stdout. - */ - void print() const; - - /** - * @brief Export metrics to a JSON file. - * - * @param filename Output filename - */ - void export_json(const std::string& filename) const; - - /** - * @brief Get all metrics as a map. - * - * @return Map of metric key-value pairs - */ - std::unordered_map get_metrics() const { - if (!metrics_) { return std::unordered_map(); } - return *metrics_; - } - - /** - * @brief Check if logger has any metrics. - * - * @return true if metrics map is not empty - */ - bool empty() const { return !metrics_ || metrics_->empty(); } - - private: - mutable std::unique_ptr> metrics_; - - // Convert value to JSON-compatible string - template - std::string to_json_string(const T& value) { - if constexpr (std::is_same_v>, std::string>) { - return "\"" + value + "\""; - } else if constexpr (std::is_convertible_v || - std::is_array_v>) { - return "\"" + std::string(value) + "\""; - } else if constexpr (std::is_same_v>, bool>) { - return value ? 
"true" : "false"; - } else if constexpr (std::is_arithmetic_v && !std::is_same_v) { - return std::to_string(value); - } else { - // For other types, try to_string (this may fail for some types) - static_assert(std::is_arithmetic_v, "Type must be convertible to string or arithmetic"); - return std::to_string(value); - } - } -}; - -} // namespace origami diff --git a/shared/origami/include/origami/math.hpp b/shared/origami/include/origami/math.hpp deleted file mode 100644 index 8498f1a5bac..00000000000 --- a/shared/origami/include/origami/math.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ - -#pragma once - -namespace origami { -namespace math { - -/** - * @brief Performs `(n + d - 1) / d`, but is robust against the case where - * `(n + d - 1)` would overflow. - * - */ -template -inline constexpr N safe_ceil_div(N n, D d) { - // Static cast to undo integral promotion. - return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); -} - -} // namespace math -} // namespace origami diff --git a/shared/origami/include/origami/origami.hpp b/shared/origami/include/origami/origami.hpp deleted file mode 100644 index 8b78bdba5b1..00000000000 --- a/shared/origami/include/origami/origami.hpp +++ /dev/null @@ -1,120 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ - -#pragma once - -#include -#include -#include -#include - -#include "origami/hardware.hpp" -#include "origami/types.hpp" - -namespace origami { - -/** - * @brief Based on the provided problem and configs; selects the best config. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param configs Vector of all possible valid configurations. - * @return prediction_result_t Configurations with best latency. - */ -prediction_result_t select_config(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs); - -/** - * @brief Select best workgroup-mapping for the given tile size. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param skGrid StreamK grid size. - * @return std::tuple - */ -std::tuple select_workgroup_mapping(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t skGrid); - -/** - * @brief Rank configurations based on predicted performance. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param configs List of candidate configurations to rank - * @return std::vector Configurations with latencies ranked by performance - * (best first) - */ -std::vector rank_configs(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs); - -/** - * @brief Select best configuration based only on M, N, K dimensions with default settings. 
- * - * @param M Problem dimension M - * @param N Problem dimension N - * @param K Problem dimension K - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param configs List of candidate configurations - * @return prediction_result_t Configurations with best latency. - */ -prediction_result_t select_config_mnk(std::size_t M, - std::size_t N, - std::size_t K, - const hardware_t& hardware, - const std::vector& configs); - -/** - * @brief Select top K configurations based on performance ranking. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param configs List of candidate configurations to rank - * @param topk Number of top configurations to return - * @return std::vector Top K configurations ranked by performance (best first) - */ -std::vector select_topk_configs(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs, - std::size_t topk); - -/** - * @brief Given a latency, compute the achieved throughput in gflops. - * - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param problem Problem description (M, N, K, etc.) - * @param latency Kernel latency. - * @return double Throughput in gflops/s. 
- */ -double compute_perf_gflops(const hardware_t& hardware, - const problem_t& problem, - const double latency); - -} // namespace origami diff --git a/shared/origami/include/origami/streamk.hpp b/shared/origami/include/origami/streamk.hpp index 10cc6077dbb..467bb94de9f 100644 --- a/shared/origami/include/origami/streamk.hpp +++ b/shared/origami/include/origami/streamk.hpp @@ -1,79 +1,78 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT #pragma once #include "origami/hardware.hpp" -#include "origami/types.hpp" - #include -namespace origami { -namespace streamk { -/** - * @brief Number of output tiles. 
- * - * @param mt_m Tile size in M-dimension. - * @param mt_n Tile size in N-dimension. - * @param m Matrix's m-dimension. - * @param n Matrix's n-dimension. - * @param batch Number of batches. - * @return size_t Total number of output tiles. - */ -size_t compute_number_of_output_tiles(size_t mt_m, size_t mt_n, size_t m, size_t n, size_t batch); +namespace origami +{ + namespace streamk + { + enum class reduction_type + { + // BasicReduction, + Tree, + Parallel, + // AtomicReduction, + Count, + None = Count + }; + + inline reduction_type int_to_reduction_type(int rt) + { + return (reduction_type)rt; + } + + size_t get_workspace( + size_t x, + size_t y, + size_t mt_m, + size_t mt_n, + size_t bpe_c, + size_t grid, + size_t tiles, + reduction_type reduction); + + reduction_type select_reduction( + size_t x, + size_t y, + size_t z, + size_t batch, + size_t mt_m, + size_t mt_n, + size_t mt_k, + const hardware_t& analytical_hardware, + int dynamic_grid_version); -/** - * @brief Select the best reduction strategy for StreamK. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param algorithm Grid selection algorithm - * @return reduction_t Selected reduction strategy - */ -reduction_t select_reduction(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - grid_selection_t algorithm); + const char* rtype_to_string(streamk::reduction_type r); -/** - * @brief Based on the provided kernel config, select the best grid dimension. - * - * @param problem Problem description (M, N, K, etc.) - * @param hardware Hardware characteristics (@see origami::hardware_t) - * @param config Kernel configuration. - * @param grid_selection_t grid selection algorithm (@see origami::grid_selection_t) - * @param max_cus Maximum number of CUs to use. - * @return size_t Dimensions of the grid launched. 
- */ -size_t select_grid_size(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - grid_selection_t algorithm, - size_t max_cus = 0); + size_t select_grid(size_t x, + size_t y, + size_t z, + size_t batch, + bool trans_a, + bool trans_b, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t workspace_size, + size_t mt_m, + size_t mt_n, + size_t mt_k, + size_t mi_m, + size_t mi_n, + size_t mi_k, + int workgroup_mapping, + size_t workspace_size_per_elem_c, + int occupancy, + const hardware_t& analytical_hardware, + int dynamic_grid_version, + reduction_type reduction_strategy, + size_t max_cus = 0); + // max workspace -} // namespace streamk -} // namespace origami + } // namespace streamk +} diff --git a/shared/origami/include/origami/types.hpp b/shared/origami/include/origami/types.hpp deleted file mode 100644 index 69675590943..00000000000 --- a/shared/origami/include/origami/types.hpp +++ /dev/null @@ -1,414 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "origami/log.hpp" -#include "origami/math.hpp" - -namespace origami { - -/** - * @brief Enumeration of supported data types. - * - */ -enum class data_type_t : int { - Float, - Double, - ComplexFloat, - ComplexDouble, - Half, - Int8x4, - Int32, - BFloat16, - Int8, - Int4, - Int64, - XFloat32, - Float8_fnuz, - BFloat8_fnuz, - Float8BFloat8_fnuz, - BFloat8Float8_fnuz, - Float8, - BFloat8, - Float8BFloat8, - BFloat8Float8, - Float6, - BFloat6, - Float4, - Count, - None = Count -}; - -/** - * @brief Convert integer to data_type_t enum. - * - * @param dt Integer value to convert - * @return data_type_t Corresponding data type - */ -inline data_type_t int_to_data_type(int dt) { return static_cast(dt); } - -/** - * @brief Convert data_type_t to number of bits. - * - * @param type Data type - * @return int Number of bits - */ -int datatype_to_bits(data_type_t type); - -/** - * @brief Convert data_type_t to number of bytes. - * - * @param type Data type - * @return int Number of bytes - */ -inline int data_type_to_bytes(data_type_t type) { - return math::safe_ceil_div(datatype_to_bits(type), 8); -} - -/** - * @brief Convert data_type_t to string. - * - * @param type Data type - * @return std::string String representation of data type - */ -std::string datatype_to_string(data_type_t type); - -/** - * @brief Convert string to data_type_t enum. 
- * - * @param s String value to convert - * @return data_type_t Corresponding data type - */ -data_type_t string_to_datatype(std::string s); - -/** - * @brief Struct to define a matrix instruction. - * - * Contains the dimensions and data type of a matrix instruction. - */ -struct matrix_instruction { - size_t MI_M; - size_t MI_N; - size_t MI_K; - data_type_t mi_input_type; - - matrix_instruction() : MI_M(0), MI_N(0), MI_K(0), mi_input_type(data_type_t::Float) {} - - matrix_instruction(size_t m, size_t n, size_t k, data_type_t mi_input_type) - : MI_M(m), MI_N(n), MI_K(k), mi_input_type(mi_input_type) {} - - matrix_instruction(const matrix_instruction& other) - : MI_M(other.MI_M), MI_N(other.MI_N), MI_K(other.MI_K), mi_input_type(other.mi_input_type) {} - - bool operator<(const matrix_instruction& other) const { - return std::tie(MI_M, MI_N, MI_K, mi_input_type) < - std::tie(other.MI_M, other.MI_N, other.MI_K, other.mi_input_type); - } - - bool operator==(const matrix_instruction& other) const { - return MI_M == other.MI_M && MI_N == other.MI_N && MI_K == other.MI_K && - mi_input_type == other.mi_input_type; - } - - std::size_t hash() const { - return std::hash()(MI_M) ^ std::hash()(MI_N) ^ std::hash()(MI_K) ^ - std::hash()(mi_input_type); - } -}; - -/** - * @brief Grid selection algorithms for StreamK. - * - * Different algorithms to select the grid size for kernel execution. - */ -enum class grid_selection_t : std::uint32_t { - number_of_cus = 0, ///< Use number of compute units - min_resources = 1, ///< Use minimum required resources - energy_aware = 2, ///< Energy-aware selection - reduction_cost_aware = 3, ///< Reduction cost-aware selection - data_parallel = 4, ///< Data parallel approach - analytical = 5, ///< Analytical model-based selection - k_split_aware = 6, ///< K-split aware selection - count, ///< Count of Grid selection algos - none = 0xFFFFFFFFu ///< Explicitly invalid -}; - -/** - * @brief Reduction strategy types for StreamK. 
- * - * Different algorithms for reduction operations in StreamK. - */ -enum class reduction_t : std::uint32_t { - spinlock = 0, ///< Spinlock-based reduction - tree = 1, ///< Tree-based reduction - parallel = 2, ///< Parallel reduction - atomic = 3, ///< Atomic Add-based reduction - count, ///< Count of reduction types - none = 0xFFFFFFFFu ///< Explicitly invalid / no reduction -}; - -/** - * @brief Convert integer to reduction_t enum. - * - * @param rt Integer value to convert - * @return reduction_t Corresponding reduction type - */ -inline constexpr reduction_t int_to_reduction_t(int rt) { return static_cast(rt); } - -/** - * @brief Indicates whether a matrix is supplied in transposed or not. - */ -enum class transpose_t { - T, - N, - - Count -}; - -/** - * @brief A compact 3-D dimension triple (M, N, K). - * - * Provides convenient accessors for common GEMM tiling parameters - * and helpers like mnk() for volume. - */ -struct dim3_t { - /// M dimension (rows). - std::size_t m; - - /// N dimension (columns). - std::size_t n; - - /// K dimension (reduction). - std::size_t k; - - constexpr bool operator==(const dim3_t& o) const noexcept { - return m == o.m && n == o.n && k == o.k; - } - - constexpr bool operator!=(const dim3_t& o) const noexcept { return !(*this == o); } - - /// @return Product m*n. - constexpr std::size_t mn() const noexcept { return m * n; } - - /// @return Product m*k. - constexpr std::size_t mk() const noexcept { return m * k; } - - /// @return Product n*k. - constexpr std::size_t nk() const noexcept { return n * k; } - - /// @return Product m*n*k. - constexpr std::size_t mnk() const noexcept { return m * n * k; } -}; - -/** - * @brief Runtime options for controlling debug, heuristics, and other behaviors. - * - * Provides programmatic access to runtime configuration options that can be - * set either programmatically or via environment variables. 
- */ -struct runtime_options { - /// Enable debug logging (reads from ANALYTICAL_GEMM_DEBUG env var) - bool debug_enabled; - - /// Enable heuristics (reads from ANALYTICAL_GEMM_HEURISTICS env var) - bool heuristics_enabled; - - /// Heuristics variance threshold (reads from ANALYTICAL_GEMM_HEURISTICS_VARIANCE env var) - double heuristics_variance; - - /** - * @brief Default constructor that reads from environment variables. - */ - runtime_options(); - - /** - * @brief Constructor with explicit values (does not read from environment). - */ - runtime_options(bool debug, bool heuristics, double variance); - - /** - * @brief Get the global runtime options instance. - */ - static runtime_options& get(); - - /** - * @brief Read debug setting from environment variable. - * @return true if ANALYTICAL_GEMM_DEBUG is set to "1", false otherwise - */ - static bool read_debug_from_env(); - - /** - * @brief Read heuristics setting from environment variable. - * @return true if ANALYTICAL_GEMM_HEURISTICS is set to "1", false otherwise - */ - static bool read_heuristics_from_env(); - - /** - * @brief Read heuristics variance from environment variable. - * @return double Variance value from ANALYTICAL_GEMM_HEURISTICS_VARIANCE, or 0.0 if not set - */ - static double read_heuristics_variance_from_env(); - - /** - * @brief Update runtime options from environment variables. - */ - void update_from_env(); -}; - -/** - * @brief Full kernel configuration (tile shape + execution parameters). - * - * Holds the geometric tile sizes along with occupancy, - * work-group mapping (WGM), and cache-control hints. - */ -struct config_t { - /// Macro tile and matrix-instruction shape. - dim3_t mt{0, 0, 0}; - dim3_t mi{0, 0, 0}; - - /// Occupancy (number of waves resident per CU). - int occupancy = -1; - - /// Reorder workgroup id for L2 reuse. - int workgroup_mapping = 0; - - /// Whether operand A is accessed with cache-flags. 
- int cache_hints_a = 0; - - /// Whether operand B is accessed with cache-flags. - int cache_hints_b = 0; - - /// Workspace size parameters. - std::size_t workspace_size = 0; - std::size_t workspace_size_per_elem_c = 0; - - /// Reduction strategy. - reduction_t reduction_strategy = reduction_t::none; - - /// Runtime options (if null, uses global singleton) - const runtime_options* runtime_opts{nullptr}; - - /// Logger for analytical metrics - mutable logger_t logger; - - constexpr bool operator==(const config_t& o) const noexcept { - return mt == o.mt && mi == o.mi && cache_hints_a == o.cache_hints_a && - cache_hints_b == o.cache_hints_b && workgroup_mapping == o.workgroup_mapping; - } - - std::size_t hash() const { - return std::hash()(mt.m) ^ std::hash()(mt.n) ^ std::hash()(mt.k) ^ - std::hash()(mi.m) ^ std::hash()(mi.n) ^ std::hash()(mi.k) ^ - std::hash()(cache_hints_a) ^ std::hash()(cache_hints_b) ^ - std::hash()(workgroup_mapping); - } - - void validate() const { - if (!is_valid()) { throw std::runtime_error("Invalid config_t"); } - } - - bool is_valid() const { - return mt.m > 0 && mt.n > 0 && mt.k > 0 && mi.m > 0 && mi.n > 0 && mi.k > 0 && occupancy > 0; - } -}; - -/** - * @brief Latency prediction result given kernel configuration. - * - * Combines a configuration with its estimated latency. - */ -struct prediction_result_t { - double latency; - config_t config; -}; - -/** - * @brief Struct to define the GEMM problem characteristics. - * - * Contains all the parameters needed to describe a GEMM operation, - * including matrix dimensions, data types, and operation flags. - */ -struct problem_t { - /// Size of the problem: M, N, K. - dim3_t size{0, 0, 0}; - - /// Batch size. - std::size_t batch = 1; - - /// Transpose types (TT, TN, NT, TT.) - transpose_t a_transpose = transpose_t::N; - transpose_t b_transpose = transpose_t::N; - - /// Data types: A, B, C, D. 
- data_type_t a_dtype = data_type_t::None; - data_type_t b_dtype = data_type_t::None; - data_type_t c_dtype = data_type_t::None; - data_type_t d_dtype = data_type_t::None; - - /// Compute type. - data_type_t mi_dtype = data_type_t::None; - - /// MX block size. - std::size_t a_mx_block_size = 0; - std::size_t b_mx_block_size = 0; -}; - -/** - * @brief Get runtime options from config, or global singleton if config doesn't specify. - * - * @param config Configuration struct (may contain runtime_opts pointer) - * @return const runtime_options& Reference to runtime options - */ -inline const runtime_options& get_runtime_options(const config_t& config) { - return config.runtime_opts ? *config.runtime_opts : runtime_options::get(); -} - -} // namespace origami - -// Specialization of std::hash in the std namespace for use of std::unordered_map with -// matrix_instruction and config_t as keys. -namespace std { -template <> -struct hash { - std::size_t operator()(const origami::matrix_instruction& k) const { return k.hash(); } -}; - -template <> -struct hash { - std::size_t operator()(const origami::config_t& config) const noexcept { return config.hash(); } -}; -} // namespace std diff --git a/shared/origami/include/origami/utils.hpp b/shared/origami/include/origami/utils.hpp new file mode 100644 index 00000000000..92e07c6fe8f --- /dev/null +++ b/shared/origami/include/origami/utils.hpp @@ -0,0 +1,116 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include +#include +#include +#include // For std::function + +namespace origami +{ + using result_tuple = std::tuple; // non_temporal_b + + using tile_tuple = std::tuple; // non_temporal_b + + size_t select_best_grid_size(size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + const hardware_t& hardware, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t mx_block_size, + double H_L2, + int WGM, + size_t biggest_allowable_split = 8, + size_t max_cus = 0); + + std::vector select_best_macro_tile_size(size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + const hardware_t& hardware, + const std::vector& MT_list, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t mx_block_size, + double H_L2, + bool print, + int WGM, + size_t max_cus = 0); + + std::vector sweep_macro_tile_sizes(size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + hardware_t& hardware, + size_t element_size = 2, + size_t max_MT_M = 256, + size_t max_MT_N = 256, + size_t max_MT_K = 128, + size_t step_MT_M = 32, + size_t step_MT_N = 32, + size_t step_MT_K = 32, + double H_L2 = 0.8, + const std::vector& tiles_to_add + = {}, + bool print = false); + + std::tuple select_best_wgm(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + int nta, + int ntb, + size_t skGrid, + bool print); + + double compute_tflops_from_latency(double latency_cycles, + size_t M, + size_t N, + size_t K, + double clock_GHz); + +} // namespace origami diff --git a/shared/origami/python/CMakeLists.txt b/shared/origami/python/CMakeLists.txt index f969e328585..d0ae6eac548 100644 
--- a/shared/origami/python/CMakeLists.txt +++ b/shared/origami/python/CMakeLists.txt @@ -1,27 +1,5 @@ -################################################################################ -# -# MIT License -# -# Copyright 2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ +# Copyright Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) diff --git a/shared/origami/python/README.md b/shared/origami/python/README.md index 9cd51b02596..ff7f2adfe83 100644 --- a/shared/origami/python/README.md +++ b/shared/origami/python/README.md @@ -35,7 +35,9 @@ import origami hardware = origami.getHardwareForDevice(args.device) -result = origami.rank_configs(problem, hardware, configs) +result = origami.select_best_macro_tile_size( + args.m, args.n, args.k, args.transA, args.transB, hardware, tile_list, args.element_size, args.miDataType, 0.8, args.debug, args.print, args.wgm + ) ``` ## Modifying `origami_module.cpp` diff --git a/shared/origami/python/origami_grid_test.py b/shared/origami/python/origami_grid_test.py index 20b3dac0afe..8a6c6822c33 100755 --- a/shared/origami/python/origami_grid_test.py +++ b/shared/origami/python/origami_grid_test.py @@ -1,35 +1,15 @@ +# Copyright Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + #!/usr/bin/env python3 -################################################################################ -# -# MIT License -# -# Copyright 2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. 
-# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT import argparse import origami import math - def parseArguments(): parser = argparse.ArgumentParser(description="Test StreamK Grid Selection.") parser.add_argument("-m", type=int, default=8192, help="Problem M dimension") @@ -50,26 +30,13 @@ def parseArguments(): "--type_b", type=str, default="f16", help="Size of each element in B in bits" ) parser.add_argument( - "--type_acc", - type=str, - default="f32", - help="Size of each element in partial tile in bits", - ) - parser.add_argument( - "--type_d", - type=str, - default="f16", - help="Size of each element in the output in bits", + "--type_acc", type=str, default="f32", help="Size of each element in partial tile in bits" ) parser.add_argument( - "--type_compute", type=str, default=None, help="Instruction input type" - ) - parser.add_argument( - "--workspace_size", - type=int, - default=0, - help="Amount of workspace available in bytes", + "--type_d", type=str, default="f16", help="Size of each element in the output in bits" ) + parser.add_argument("--type_compute", type=str, default=None, help="Instruction input type") + parser.add_argument("--workspace_size", type=int, default=0, help="Amount of workspace available in bytes") parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--print", action="store_true", 
help="Print hardware info") parser.add_argument( @@ -79,30 +46,17 @@ def parseArguments(): parser.add_argument("--mt_m", type=int, default=32, help="Macro-tile dimension M") parser.add_argument("--mt_n", type=int, default=32, help="Macro-tile dimension N") parser.add_argument("--mt_k", type=int, default=256, help="Macro-tile dimension K") - parser.add_argument( - "--mi_m", type=int, default=16, help="Machine Instruction dimension M" - ) - parser.add_argument( - "--mi_n", type=int, default=16, help="Machine Instruction dimension N" - ) - parser.add_argument( - "--mi_k", type=int, default=16, help="Machine Instruction dimension K" - ) + parser.add_argument("--mi_m", type=int, default=16, help="Machine Instruction dimension M") + parser.add_argument("--mi_n", type=int, default=16, help="Machine Instruction dimension N") + parser.add_argument("--mi_k", type=int, default=16, help="Machine Instruction dimension K") parser.add_argument("--occupancy", type=int, default=1, help="Occupancy of kernel") - parser.add_argument( - "--dynamic_grid_version", - type=int, - default=5, - help="Version of Dynamic Grid Selection to use", - ) + parser.add_argument("--dynamic_grid_version", type=int, default=5, help="Version of Dynamic Grid Selection to use") args = parser.parse_args() if args.type_compute is None: - if origami.datatype_to_bits( - origami.string_to_datatype(args.type_a) - ) > origami.datatype_to_bits(origami.string_to_datatype(args.type_b)): + if origami.datatype_to_bits(origami.string_to_datatype(args.type_a)) > origami.datatype_to_bits(origami.string_to_datatype(args.type_b)): args.type_compute = args.type_a else: args.type_compute = args.type_b @@ -118,52 +72,42 @@ def main(): if args.print: hardware.print() - # Create problem description - problem = origami.problem_t() - problem.size = origami.dim3_t(args.m, args.n, args.k) - problem.batch = args.batch - problem.a_transpose = ( - origami.transpose_t.T if args.trans_a else origami.transpose_t.N - ) - 
problem.b_transpose = ( - origami.transpose_t.T if args.trans_b else origami.transpose_t.N - ) - problem.a_dtype = origami.string_to_datatype(args.type_a) - problem.b_dtype = origami.string_to_datatype(args.type_b) - problem.d_dtype = origami.string_to_datatype(args.type_d) - problem.c_dtype = problem.d_dtype - problem.mi_dtype = origami.string_to_datatype(args.type_compute) - problem.a_mx_block_size = 0 - problem.b_mx_block_size = 0 - - # Create config - config = origami.config_t() - config.mt = origami.dim3_t(args.mt_m, args.mt_n, args.mt_k) - config.mi = origami.dim3_t(args.mi_m, args.mi_n, args.mi_k) - config.occupancy = args.occupancy - config.workgroup_mapping = args.wgm - - # Select reduction strategy - grid_algorithm = origami.grid_selection_t.analytical # default to analytical - if args.dynamic_grid_version == 0: - grid_algorithm = origami.grid_selection_t.number_of_cus - elif args.dynamic_grid_version == 1: - grid_algorithm = origami.grid_selection_t.min_resources - elif args.dynamic_grid_version == 2: - grid_algorithm = origami.grid_selection_t.energy_aware - elif args.dynamic_grid_version == 3: - grid_algorithm = origami.grid_selection_t.reduction_cost_aware - elif args.dynamic_grid_version == 4: - grid_algorithm = origami.grid_selection_t.data_parallel - elif args.dynamic_grid_version == 5: - grid_algorithm = origami.grid_selection_t.analytical - elif args.dynamic_grid_version == 6: - grid_algorithm = origami.grid_selection_t.k_split_aware - - reduction = origami.select_reduction(problem, hardware, config, grid_algorithm) - - winner_grid = origami.select_grid_size( - problem, hardware, config, grid_algorithm, hardware.N_CU + reduction = origami.select_reduction( + args.m, + args.n, + args.k, + args.batch, + args.mt_m, + args.mt_n, + args.mt_k, + hardware, + args.dynamic_grid_version + ) + + winner_grid = origami.select_grid( + args.m, + args.n, + args.k, + args.batch, + args.trans_a, + args.trans_b, + 
origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), + origami.string_to_datatype(args.type_compute), + args.workspace_size, + args.mt_m, + args.mt_n, + args.mt_k, + args.mi_m, # MI_M + args.mi_n, # MI_N + args.mi_k, # MI_K + args.wgm, + origami.datatype_to_bits(origami.string_to_datatype(args.type_acc)) // 8, + args.occupancy, + hardware, + args.dynamic_grid_version, + reduction ) print(f"Best reduction algo : {reduction}") diff --git a/shared/origami/python/origami_module.cpp b/shared/origami/python/origami_module.cpp index 0bb0f8a9a8b..26c8409bd28 100644 --- a/shared/origami/python/origami_module.cpp +++ b/shared/origami/python/origami_module.cpp @@ -1,290 +1,172 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "origami/hardware.hpp" +#include "origami/streamk.hpp" +#include "origami/utils.hpp" #include +#include +#include #include +#include #include -#include #include -#include -#include -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -// logger_t is defined in types.hpp -#include "origami/origami.hpp" -#include "origami/streamk.hpp" -#include "origami/types.hpp" using hardware_t = origami::hardware_t; -namespace nb = nanobind; +namespace nb = nanobind; using namespace nb::literals; -NB_MODULE(origami, m) { - nanobind::enum_(m, "architecture_t") - .value("gfx942", hardware_t::architecture_t::gfx942) - .value("gfx950", hardware_t::architecture_t::gfx950) - .export_values(); - - nanobind::enum_(m, "data_type_t") - .value("Float", origami::data_type_t::Float) - .value("ComplexFloat", origami::data_type_t::ComplexFloat) - .value("ComplexDouble", origami::data_type_t::ComplexDouble) - .value("Double", origami::data_type_t::Double) - .value("Half", origami::data_type_t::Half) - .value("Int8x4", origami::data_type_t::Int8x4) - .value("Int32", origami::data_type_t::Int32) - .value("BFloat16", origami::data_type_t::BFloat16) - .value("Int8", origami::data_type_t::Int8) - .value("Int64", origami::data_type_t::Int64) - .value("XFloat32", origami::data_type_t::XFloat32) - .value("Float8_fnuz", origami::data_type_t::Float8_fnuz) - .value("BFloat8_fnuz", origami::data_type_t::BFloat8_fnuz) - .value("Float8BFloat8_fnuz", origami::data_type_t::Float8BFloat8_fnuz) - .value("BFloat8Float8_fnuz", 
origami::data_type_t::BFloat8Float8_fnuz) - .value("Float8", origami::data_type_t::Float8) - .value("BFloat8", origami::data_type_t::BFloat8) - .value("Float8BFloat8", origami::data_type_t::Float8BFloat8) - .value("BFloat8Float8", origami::data_type_t::BFloat8Float8) - .value("Float6", origami::data_type_t::Float6) - .value("BFloat6", origami::data_type_t::BFloat6) - .value("Float4", origami::data_type_t::Float4) - .export_values(); - - // After your other nanobind::enum_ blocks - nanobind::enum_(m, "transpose_t") - .value("T", origami::transpose_t::T) - .value("N", origami::transpose_t::N) - // Optional: usually you don't expose Count, but you can if you want - // .value("Count", origami::transpose_t::Count) - .export_values(); - - m.def("int_to_data_type", &origami::int_to_data_type, "Convert int to data_type_t."); - - nanobind::enum_(m, "grid_selection_t") - .value("number_of_cus", origami::grid_selection_t::number_of_cus) - .value("min_resources", origami::grid_selection_t::min_resources) - .value("energy_aware", origami::grid_selection_t::energy_aware) - .value("reduction_cost_aware", origami::grid_selection_t::reduction_cost_aware) - .value("data_parallel", origami::grid_selection_t::data_parallel) - .value("analytical", origami::grid_selection_t::analytical) - .value("k_split_aware", origami::grid_selection_t::k_split_aware) - .export_values(); - - nanobind::enum_(m, "reduction_t") - .value("tree", origami::reduction_t::tree) - .value("parallel", origami::reduction_t::parallel) - .export_values(); - - m.def("int_to_reduction_t", &origami::int_to_reduction_t, "Convert int to reduction_t."); - - // Add new struct bindings - nanobind::class_(m, "dim3_t") - .def(nanobind::init()) - .def_rw("m", &origami::dim3_t::m) - .def_rw("n", &origami::dim3_t::n) - .def_rw("k", &origami::dim3_t::k) - .def("mn", &origami::dim3_t::mn) - .def("mk", &origami::dim3_t::mk) - .def("nk", &origami::dim3_t::nk) - .def("mnk", &origami::dim3_t::mnk); - - nanobind::class_(m, "logger_t") 
- .def(nanobind::init<>()) - .def("clear", &origami::logger_t::clear, "Clear all logged metrics") - .def("print", &origami::logger_t::print, "Print all metrics as JSON to stdout") - .def("export_json", &origami::logger_t::export_json, "Export metrics to a JSON file") - .def("get_metrics", &origami::logger_t::get_metrics, "Get all metrics as a map") - .def("empty", &origami::logger_t::empty, "Check if logger has any metrics") - // Overloads for templated log() method - .def( - "log", - [](origami::logger_t& self, const std::string& key, int value) { self.log(key, value); }, - "Log an integer value") - .def( - "log", - [](origami::logger_t& self, const std::string& key, double value) { - self.log(key, value); - }, - "Log a double value") - .def( - "log", - [](origami::logger_t& self, const std::string& key, const std::string& value) { - self.log(key, value); - }, - "Log a string value") - .def( - "log", - [](origami::logger_t& self, const std::string& key, bool value) { self.log(key, value); }, - "Log a boolean value") - .def( - "log", - [](origami::logger_t& self, const std::string& key, size_t value) { - self.log(key, value); - }, - "Log a size_t value"); - - nanobind::class_(m, "config_t") - .def(nanobind::init<>()) - .def_rw("mt", &origami::config_t::mt) - .def_rw("mi", &origami::config_t::mi) - .def_rw("occupancy", &origami::config_t::occupancy) - .def_rw("workgroup_mapping", &origami::config_t::workgroup_mapping) - .def_rw("cache_hints_a", &origami::config_t::cache_hints_a) - .def_rw("cache_hints_b", &origami::config_t::cache_hints_b) - .def_rw("workspace_size", &origami::config_t::workspace_size) - .def_rw("workspace_size_per_elem_c", &origami::config_t::workspace_size_per_elem_c) - .def_rw("logger", &origami::config_t::logger); - - nanobind::class_(m, "prediction_result_t") - .def(nanobind::init<>()) - .def_rw("latency", &origami::prediction_result_t::latency) - .def_rw("config", &origami::prediction_result_t::config); - - nanobind::class_(m, "problem_t") - 
.def(nanobind::init<>()) - .def_rw("size", &origami::problem_t::size) - .def_rw("batch", &origami::problem_t::batch) - .def_rw("a_transpose", &origami::problem_t::a_transpose) - .def_rw("b_transpose", &origami::problem_t::b_transpose) - .def_rw("a_dtype", &origami::problem_t::a_dtype) - .def_rw("b_dtype", &origami::problem_t::b_dtype) - .def_rw("c_dtype", &origami::problem_t::c_dtype) - .def_rw("d_dtype", &origami::problem_t::d_dtype) - .def_rw("mi_dtype", &origami::problem_t::mi_dtype) - .def_rw("a_mx_block_size", &origami::problem_t::a_mx_block_size) - .def_rw("b_mx_block_size", &origami::problem_t::b_mx_block_size); - - nanobind::class_(m, "hardware_t") - .def(nanobind::init>()) - .def("print", &hardware_t::print) - .def_rw("N_CU", &hardware_t::N_CU) - .def_rw("lds_capacity", &hardware_t::lds_capacity) - .def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio) - .def_rw("mem2_perf_ratio", &hardware_t::mem2_perf_ratio) - .def_rw("mem3_perf_ratio", &hardware_t::mem3_perf_ratio) - .def_rw("L2_capacity", &hardware_t::L2_capacity) - .def_rw("CU_per_L2", &hardware_t::CU_per_L2) - .def_rw("compute_clock_ghz", &hardware_t::compute_clock_ghz) - .def_rw("parallel_mi_cu", &hardware_t::parallel_mi_cu) - .def_rw("mem_bw_per_wg_coefficients", &hardware_t::mem_bw_per_wg_coefficients) - .def_rw("NUM_XCD", &hardware_t::NUM_XCD); - - m.def("get_hardware_for_device", - &hardware_t::get_hardware_for_device, - "This gets a hardware object for a device."); - - m.def("datatype_to_bits", &origami::datatype_to_bits, "Return the number of bits in a datatype"); - m.def("string_to_datatype", - &origami::string_to_datatype, - "Convert a string representation of a datatype into data_type_t enum"); - m.def("datatype_to_string", - &origami::datatype_to_string, - "Convert data_type_t enum to string representation"); - - m.def("select_config", - &origami::select_config, - "problem"_a, - "hardware"_a, - "configs"_a, - "Select best configuration based on problem and hardware"); - 
m.def("select_grid_size", - &origami::streamk::select_grid_size, - "problem"_a, - "hardware"_a, - "config"_a, - "algorithm"_a, - "max_cus"_a = 0, - "Select best grid size for the given configuration"); - m.def("select_workgroup_mapping", - &origami::select_workgroup_mapping, - "problem"_a, - "hardware"_a, - "config"_a, - "skGrid"_a, - - "Select best workgroup mapping"); - m.def("rank_configs", - &origami::rank_configs, - "problem"_a, - "hardware"_a, - "configs"_a, - "Rank configurations by performance"); - m.def("select_config_mnk", - &origami::select_config_mnk, - "M"_a, - "N"_a, - "K"_a, - "hardware"_a, - "configs"_a, - - "Select best configuration for M,N,K dimensions"); - m.def("select_topk_configs", - &origami::select_topk_configs, - "problem"_a, - "hardware"_a, - "configs"_a, - "topk"_a, - - "Select topk configurations"); - m.def("compute_perf_gflops", - &origami::compute_perf_gflops, - "hardware"_a, - "problem"_a, - "latency"_a, - - "Compute performance in GFLOPS"); - - // StreamK functions - m.def("select_reduction", - &origami::streamk::select_reduction, - "problem"_a, - "hardware"_a, - "config"_a, - "algorithm"_a, - "Select best StreamK reduction strategy"); - - // GEMM functions - m.def("compute_total_latency", - static_cast(&origami::compute_total_latency), - "problem"_a, - "hardware"_a, - "config"_a, - "max_cus"_a, - "Compute total latency"); +NB_MODULE(origami, m) +{ + nanobind::enum_(m, "architecture_t") + .value("gfx942", hardware_t::architecture_t::gfx942) + .value("gfx950", hardware_t::architecture_t::gfx950) + .export_values(); + + nanobind::enum_(m, "data_type_t") + .value("Float", origami::data_type_t::Float) + .value("ComplexFloat", origami::data_type_t::ComplexFloat) + .value("ComplexDouble", origami::data_type_t::ComplexDouble) + .value("Double", origami::data_type_t::Double) + .value("Half", origami::data_type_t::Half) + .value("Int8x4", origami::data_type_t::Int8x4) + .value("Int32", origami::data_type_t::Int32) + .value("BFloat16", 
origami::data_type_t::BFloat16) + .value("Int8", origami::data_type_t::Int8) + .value("Int64", origami::data_type_t::Int64) + .value("XFloat32", origami::data_type_t::XFloat32) + .value("Float8_fnuz", origami::data_type_t::Float8_fnuz) + .value("BFloat8_fnuz", origami::data_type_t::BFloat8_fnuz) + .value("Float8BFloat8_fnuz", origami::data_type_t::Float8BFloat8_fnuz) + .value("BFloat8Float8_fnuz", origami::data_type_t::BFloat8Float8_fnuz) + .value("Float8", origami::data_type_t::Float8) + .value("BFloat8", origami::data_type_t::BFloat8) + .value("Float8BFloat8", origami::data_type_t::Float8BFloat8) + .value("BFloat8Float8", origami::data_type_t::BFloat8Float8) + .value("Float6", origami::data_type_t::Float6) + .value("BFloat6", origami::data_type_t::BFloat6) + .value("Float4", origami::data_type_t::Float4) + .export_values(); + + m.def("int_to_data_type", + &origami::int_to_data_type, + "Convert int to data_type_t."); + + nanobind::enum_(m, "reduction_type") + .value("Tree", origami::streamk::reduction_type::Tree) + .value("Parallel", origami::streamk::reduction_type::Parallel) + .export_values(); + + m.def("int_to_reduction_type", + &origami::streamk::int_to_reduction_type, + "Convert int to reduction_type."); + + nanobind::class_(m, "hardware_t") + .def(nanobind::init>()) + .def("print", &hardware_t::print) + .def("print_debug_info", &hardware_t::print_debug_info) + .def_rw("N_CU", &hardware_t::N_CU) + .def_rw("lds_capacity", &hardware_t::lds_capacity) + .def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio) + .def_rw("mem2_perf_ratio", &hardware_t::mem2_perf_ratio) + .def_rw("mem3_perf_ratio", &hardware_t::mem3_perf_ratio) + .def_rw("L2_capacity", &hardware_t::L2_capacity) + .def_rw("CU_per_L2", &hardware_t::CU_per_L2) + .def_rw("compute_clock_ghz", &hardware_t::compute_clock_ghz) + .def_rw("parallel_mi_cu", &hardware_t::parallel_mi_cu) + .def_rw("mem_bw_per_wg_coefficients", &hardware_t::mem_bw_per_wg_coefficients) + .def_rw("NUM_XCD", &hardware_t::NUM_XCD) + 
.def_rw("debug_info", &hardware_t::debug_info); + + m.def("get_hardware_for_device", + &hardware_t::get_hardware_for_device, + "This gets a hardware object for a device."); + + m.def("datatype_to_bits", &origami::data_type_to_bits, "Return the number of bits in a datatype"); + m.def("string_to_datatype", &origami::string_to_data_type, "Convert a string representation of a datatype into data_type_t enum"); + m.def("select_best_macro_tile_size", + &origami::select_best_macro_tile_size, + "M"_a, + "N"_a, + "K"_a, + "batch"_a, + "transA"_a, + "transB"_a, + "hardware"_a, + "MT_list"_a, + "element_size_A"_a, + "element_size_B"_a, + "element_size_out"_a, + "mi_datatype"_a, + "mx_block_size"_a, + "H_L2"_a, + "print"_a, + "WGM"_a, + "max_cus"_a = 0, + "Get best macro tile sizes."); + m.def("select_reduction", &origami::streamk::select_reduction, "Select best StreamK reduction strategy"); + m.def("select_grid", &origami::streamk::select_grid, + "x"_a, + "y"_a, + "z"_a, + "batch"_a, + "trans_a"_a, + "trans_b"_a, + "element_size_A"_a, + "element_size_B"_a, + "element_size_out"_a, + "mi_datatype"_a, + "workspace_size"_a, + "mt_m"_a, + "mt_n"_a, + "mt_k"_a, + "mi_m"_a, + "mi_n"_a, + "mi_k"_a, + "workgroup_mapping"_a, + "workspace_size_per_elem_c"_a, + "occupancy"_a, + "analytical_hardware"_a, + "dynamic_grid_version"_a, + "reduction_strategy"_a, + "max_cus"_a = 0, + "Select Best StreamK Grid Size"); + m.def("compute_total_latency", &origami::compute_total_latency, + "hardware"_a, + "M"_a, + "N"_a, + "K"_a, + "batch"_a, + "transA"_a, + "transB"_a, + "MT_M"_a, + "MT_N"_a, + "MT_K"_a, + "MI_M"_a, + "MI_N"_a, + "MI_K"_a, + "element_size_A"_a, + "element_size_B"_a, + "element_size_out"_a, + "mi_datatype"_a, + "mx_block_size"_a, + "WGM"_a, + "non_temporal_a"_a = 0, + "non_temporal_b"_a = 0, + "occupancy"_a = 1, + "split"_a = 0, + "max_cus"_a = 0, + "Compute the total latency of a gemm"); + m.def("select_best_wgm", &origami::select_best_wgm, "Get best workgroup mapping."); } diff --git 
a/shared/origami/python/origami_test.py b/shared/origami/python/origami_test.py index 4946a378af3..fdf4d393150 100755 --- a/shared/origami/python/origami_test.py +++ b/shared/origami/python/origami_test.py @@ -1,29 +1,7 @@ #!/usr/bin/env python3 -################################################################################ -# -# MIT License -# -# Copyright 2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ +# Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT import argparse import origami @@ -50,92 +28,89 @@ # BFloat6 == "bf6" # Float4 == "f4" -MatInst = {"gfx950": {}, "gfx942": {}} -MatInst["gfx950"]["f32"] = [(16, 16, 4, 1), (32, 32, 2, 1)] +MatInst = {'gfx950': {}, 'gfx942': {}} +MatInst['gfx950']["f32"] = [ + (16,16,4,1), + (32,32,2,1) + ] # MatInst['gfx950']["c32"] = [] # MatInst['gfx950']["c64"] = [] -MatInst["gfx950"]["f64"] = [(16, 16, 4, 1)] -MatInst["gfx950"]["f16"] = [ +MatInst['gfx950']["f64"] = [ + (16,16,4,1) + ] +MatInst['gfx950']["f16"] = [ # (4,4,4,16), #gfx942 - # [16,16,4,4] # never use 16x16x4x4 - # [16,16,16,1] #gfx942 - # [32,32,4,2] # never use 32x32x4x2 - # [32,32,8,1] #gfx942 - (16, 16, 32, 1), # gfx950 - (32, 32, 16, 1), # gfx950 -] + #[16,16,4,4] # never use 16x16x4x4 + #[16,16,16,1] #gfx942 + #[32,32,4,2] # never use 32x32x4x2 + #[32,32,8,1] #gfx942 + (16,16,32,1), #gfx950 + (32,32,16,1) #gfx950 + ] # MatInst['gfx950']["i32"] =[] -MatInst["gfx950"]["bf16"] = MatInst["gfx950"]["f16"] +MatInst['gfx950']["bf16"] = MatInst['gfx950']["f16"] # MatInst['gfx950']["i8"] =[ # (32,32,16,1), # (16,16,32,1), # (4,4,4,16) # ] -MatInst["gfx950"]["xf32"] = [ +MatInst['gfx950']["xf32"] = [ # (4,4,4,16), #gfx942 - # [16,16,4,4] # never use 16x16x4x4 - # [16,16,16,1] #gfx942 - # [32,32,4,2] # never use 32x32x4x2 - # [32,32,8,1] #gfx942 - (16, 16, 32, 1), # gfx950 - (32, 32, 16, 1), # gfx950 -] -MatInst["gfx950"]["f8"] = [ - (4, 4, 4, 16), # gfx950, gfx942 - (16, 16, 128, 1), # gfx950 - (32, 32, 64, 1), # gfx950 -] -MatInst["gfx950"]["bf8"] = MatInst["gfx950"]["f8"] - + #[16,16,4,4] # never use 16x16x4x4 + #[16,16,16,1] #gfx942 + #[32,32,4,2] # never use 32x32x4x2 + #[32,32,8,1] #gfx942 + (16,16,32,1), #gfx950 + (32,32,16,1) #gfx950 + ] +MatInst['gfx950']["f8"] = [ + (4,4,4,16), #gfx950, gfx942 + (16,16,128,1), #gfx950 + (32,32,64,1) #gfx950 + ] +MatInst['gfx950']["bf8"] = MatInst['gfx950']["f8"] def parseArguments(): - parser = argparse.ArgumentParser( - description="""Get 
hypothetical Origami MTxDU selection for a size""" - ) + parser = argparse.ArgumentParser(description="""Get hypothetical Origami MTxDU selection for a size""") parser.add_argument("-m", type=int, default=8192) parser.add_argument("-n", type=int, default=8192) parser.add_argument("-b", type=int, default=1) parser.add_argument("-k", type=int, default=8192) parser.add_argument("--trans_a", type=bool, default=True) parser.add_argument("--trans_b", type=bool, default=False) - parser.add_argument("--device", type=int, default=0) # to get hardware specs + parser.add_argument("--device", type=int, default=0) # to get hardware specs parser.add_argument("--type_a", type=str, default="f16") parser.add_argument("--type_b", type=str, default="f16") parser.add_argument("--type_d", type=str, default="f16") parser.add_argument("--scale_block_size", type=int, default=0) parser.add_argument("--wgm", type=int, default=6) - parser.add_argument( - "--sizes", type=bool, default=False - ) # to load the sizes from a csv file. -m/-n/-b/-k will be ignored if True. - parser.add_argument( - "--path", type=str, default="./sizes.csv" - ) # path to the csv file. Fails if sizes is True, and path or file does not exist. + parser.add_argument("--sizes", type=bool, default=False) # to load the sizes from a csv file. -m/-n/-b/-k will be ignored if True. + parser.add_argument("--path", type=str, default="./sizes.csv") # path to the csv file. Fails if sizes is True, and path or file does not exist. 
parser.add_argument("--arch", type=str, default="gfx950") parser.add_argument("--print", action="store_true") return parser.parse_args() - -def createConfigList(arch, gemmType): +def createTileList(arch, gemmType): LIST_OF_WAVEs_TO_INCLUDE = [[4, 1], [2, 2], [1, 4], [1, 2], [2, 1], [1, 1]] MIN_MT0 = MIN_MT1 = 16 MAX_MT0 = MAX_MT1 = 512 - # generate all configs for each datatype: + # generate all MTs for each datatype: bm_max = 0 - configs = [] + tile_list = set() for MI in MatInst[arch][gemmType]: for bm in range(bm_max + 1): - MIBlockM = 2**bm + MIBlockM = 2 ** bm for wave in LIST_OF_WAVEs_TO_INCLUDE: waveTileM = 0 waveTileN = 0 while True: - waveTileM += 1 - waveTileN = 0 + waveTileM+=1 + waveTileN=0 MatrixInstM = MI[0] * MIBlockM MT0 = MatrixInstM * waveTileM * wave[0] if MT0 < MIN_MT0: @@ -144,7 +119,7 @@ def createConfigList(arch, gemmType): break while True: - waveTileN += 1 + waveTileN+=1 MatrixInstN = MI[1] / MIBlockM * MI[3] MT1 = int(MatrixInstN * waveTileN * wave[1]) @@ -154,118 +129,89 @@ def createConfigList(arch, gemmType): break # LDS size check for LSU - LSU = max(1, 4 // wave[0] // wave[1]) - if LSU > 1 and MT0 * MT1 * 4 * LSU > 256 * 256: + LSU = max(1, 4//wave[0]//wave[1]) + if LSU > 1 and MT0*MT1*4*LSU > 256*256: continue - if MT0 * MT1 > 256 * 256: + if MT0*MT1 > 256*256: continue for DU in [16, 32, 64, 128, 256, 512, 1024]: - # Create config_t object - config = origami.config_t() - config.mt = origami.dim3_t(MT0, MT1, DU) - config.mi = origami.dim3_t(MI[0], MI[1], MI[2]) - config.occupancy = 1 - config.workgroup_mapping = 6 - configs.append(config) - - return configs + tile_list.add((MT0, MT1, DU, MI[0], MI[1], MI[2], 1, 6, 0, 0)) + return [tile for tile in tile_list] def main(): args = parseArguments() hardware = origami.get_hardware_for_device(args.device) - configs = createConfigList(args.arch, args.type_a) - - print(" Number of unique configs: ", len(configs)) - - if args.sizes: # sizes from a file - try: - with open(args.path, "r") as 
csvfile: - csv_reader = csv.reader(csvfile) - print(f"M,N,Batch,K,MT0,MT1,DU,MI0,MI1,MI2,latency") - for row in csv_reader: - M = int(row[0]) - N = int(row[1]) - B = int(row[2]) - K = int(row[3]) - # Create problem description - problem = origami.problem_t() - problem.size = origami.dim3_t(M, N, K) - problem.batch = B - problem.a_transpose = ( - origami.transpose_t.T if args.trans_a else origami.transpose_t.N - ) - problem.b_transpose = ( - origami.transpose_t.T if args.trans_b else origami.transpose_t.N - ) - problem.a_dtype = origami.string_to_datatype(args.type_a) - problem.b_dtype = origami.string_to_datatype(args.type_b) - problem.d_dtype = origami.string_to_datatype(args.type_d) - problem.c_dtype = problem.d_dtype - problem.mi_dtype = problem.a_dtype - problem.a_mx_block_size = args.scale_block_size - problem.b_mx_block_size = args.scale_block_size - - # Select best config - best_config = origami.select_config(problem, hardware, configs) - latency = best_config.latency - - # MxNxBxK, MT0xMT1xDU, MI0xMI1xMI2, latency/cycles - print( - f"{M},{N},{B},{K},{best_config.mt.m},{best_config.mt.n},{best_config.mt.k},{best_config.mi.m},{best_config.mi.n},{best_config.mi.k},{latency:0.3f}" - ) - except FileNotFoundError: - raise FileNotFoundError( - f"Error: The size file: '{args.path}' does not exist." - ) - else: # one size from the command line. 
- # Create problem description - problem = origami.problem_t() - problem.size = origami.dim3_t(args.m, args.n, args.k) - problem.batch = args.b - problem.a_transpose = ( - origami.transpose_t.T if args.trans_a else origami.transpose_t.N - ) - problem.b_transpose = ( - origami.transpose_t.T if args.trans_b else origami.transpose_t.N - ) - problem.a_dtype = origami.string_to_datatype(args.type_a) - problem.b_dtype = origami.string_to_datatype(args.type_b) - problem.d_dtype = origami.string_to_datatype(args.type_d) - problem.c_dtype = problem.d_dtype - problem.mi_dtype = problem.a_dtype - problem.a_mx_block_size = args.scale_block_size - problem.b_mx_block_size = args.scale_block_size - - # Select best config - best_config = origami.select_config(problem, hardware, configs) - latency = best_config.latency - - print( - f"The best config for [{args.m}, {args.n}, {args.b}, {args.k}] is: MT=({best_config.config.mt.m},{best_config.config.mt.n},{best_config.config.mt.k}), MI=({best_config.config.mi.m},{best_config.config.mi.n},{best_config.config.mi.k}), latency={latency:0.3f}" + tile_list = createTileList(args.arch, args.type_a) + + print(" Number of unique MTxDU: ", len(tile_list)) + + if args.sizes: # sizes from a file + try: + with open(args.path, 'r') as csvfile: + csv_reader = csv.reader(csvfile) + print(f"M,N,Batch,K,MT0,MT1,DU,MI0,MI1,MI2,MI3,latency") + for row in csv_reader: + M = int(row[0]) + N = int(row[1]) + B = int(row[2]) + K = int(row[3]) + ret = origami.select_best_macro_tile_size( + M, + N, + K, + B, + args.trans_a, + args.trans_b, + hardware, + tile_list, + origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), + origami.string_to_datatype(args.type_a), + args.scale_block_size, + 0.8, + args.print, + args.wgm, + ) + #MxNxBxK, MT0xMT1xDU, MI0xMI1xMI2xMI3, latency/cycles + 
print(f"{M},{N},{B},{K},{ret[0][1]},{ret[0][2]},{ret[0][3]},{ret[0][4]},{ret[0][5]},{ret[0][6]},{ret[0][7]},{ret[0][0]:0.3f}") + except FileNotFoundError: + raise FileNotFoundError(f"Error: The size file: '{args.path}' does not exist.") + else: # one size from the command line. + ret = origami.select_best_macro_tile_size( + args.m, + args.n, + args.k, + args.b, + args.trans_a, + args.trans_b, + hardware, + tile_list, + origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), + origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), + origami.string_to_datatype(args.type_a), + args.scale_block_size, + 0.8, + args.print, + args.wgm, ) - - # Get top configs - ranked_configs = origami.rank_configs(problem, hardware, configs) - print(" Top 5 configs: ") - for i, config in enumerate(ranked_configs[:5]): - print( - f" {i+1}. MT=({config.config.mt.m},{config.config.mt.n},{config.config.mt.k}), MI=({config.config.mi.m},{config.config.mi.n},{config.config.mi.k}), latency={config.latency:0.3f}" - ) + print(f"The best MTxDU for [{args.m}, {args.n}, {args.b}, {args.k}] is: {ret[0]}") # Match this with the top condition + # add an option to list top 10~15 + print(" full list of MTs: \n", ret) if args.print: hardware.print() - with open("configs.log", "w") as file: - for config in configs: - file.write( - f"MT=({config.mt.m},{config.mt.n},{config.mt.k}), MI=({config.mi.m},{config.mi.n},{config.mi.k})\n" - ) + hardware.print_debug_info() + with open("MTxDU.log",'w') as file: + for tile in tile_list: + file.write(f'{tile}\n') return 0 - if __name__ == "__main__": exit(main()) diff --git a/shared/origami/python/setup.py b/shared/origami/python/setup.py index e3493bb61fe..ce2859d583d 100644 --- a/shared/origami/python/setup.py +++ b/shared/origami/python/setup.py @@ -1,27 +1,5 @@ -################################################################################ -# -# MIT License -# -# Copyright 
2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT # setup.py from setuptools import setup, Extension diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index 40facff606c..e7ee1268957 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -1,11 +1,12 @@ // Copyright Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT +#include "origami/gemm.hpp" + +#include "origami/streamk.hpp" #include -#include -#include +#include // For timing #include -#include #include #include #include @@ -13,1064 +14,1484 @@ #include #include -#include "origami/hardware.hpp" -#include "origami/math.hpp" -#include "origami/types.hpp" +namespace origami +{ + /* ---------------------------------------------------------------------------------------- */ + /* Misc. functions */ + /* ---------------------------------------------------------------------------------------- */ + // Performs `(n + d - 1) / d`, but is robust against the case where `(n + d - 1)` would + // overflow. + template + constexpr N safe_ceil_div(N n, D d) + { + // Static cast to undo integral promotion. + return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); + } + + auto lds128_penalty(size_t dim_elems, size_t element_bits) + { + const size_t bytes = dim_elems * safe_ceil_div(element_bits, 8); + const size_t mod = bytes % 128; + if(mod == 0) + return 1.0; + const double frac = double(mod) / 128.0; // 0..1 + const double base = (element_bits <= 16) ? 1.1 : 1.35; // BF16/FP16 < FP32 + return 1.0 + base * frac; // up to ~2.35x worst case + }; -#include "origami/gemm.hpp" -#include "origami/streamk.hpp" + double calculate_work_utilization( + size_t M, size_t N, size_t K, size_t MT_M, size_t MT_N, size_t MT_K) + { + if(M == 0 || N == 0 || K == 0 || MT_M == 0 || MT_N == 0 || MT_K == 0) + return 1.0; + + // Calculate the full dimensions covered by the launched grid of tiles (spatial). + const double launched_M = static_cast(safe_ceil_div(M, MT_M)) * MT_M; + const double launched_N = static_cast(safe_ceil_div(N, MT_N)) * MT_N; + + // Calculate the full depth covered by the k-loop iterations (temporal). + const double launched_K = static_cast(safe_ceil_div(K, MT_K)) * MT_K; + + // The utilization is the ratio of the useful problem volume to the total scheduled volume. 
+ const double useful_volume = static_cast(M * N * K); + const double launched_volume = launched_M * launched_N * launched_K; + + if(launched_volume < 1.0) + return 1.0; // Avoid division by zero for tiny/empty problems + + const double utilization = useful_volume / launched_volume; + + return utilization; + } + + double calculate_output_utilization( + size_t M, size_t N, size_t MT_M, size_t MT_N, size_t vector_elems = 1) + { + if(M == 0 || N == 0 || MT_M == 0 || MT_N == 0) + return 1.0; -namespace origami { -double calculate_work_utilization(const problem_t& problem, const config_t& config) { - const size_t M = problem.size.m; - const size_t N = problem.size.n; - const size_t K = problem.size.k; - - const size_t MT_M = config.mt.m; - const size_t MT_N = config.mt.n; - const size_t MT_K = config.mt.k; - - if (MT_M <= 0 || MT_N <= 0) return 1.0; - - // Calculate the full dimensions covered by the launched grid of tiles (spatial). - const double launched_M = - static_cast(math::safe_ceil_div(M, MT_M)) * static_cast(MT_M); - const double launched_N = - static_cast(math::safe_ceil_div(N, MT_N)) * static_cast(MT_N); - - // Calculate the full depth covered by the k-loop iterations (temporal). - const double launched_K = - static_cast(math::safe_ceil_div(K, MT_K)) * static_cast(MT_K); - - // The utilization is the ratio of the useful problem volume to the total scheduled volume. 
- const double useful_volume = static_cast(M * N * K); - const double launched_volume = launched_M * launched_N * launched_K; - - if (launched_volume < 1.0) return 1.0; // Avoid division by zero for tiny/empty problems - - const double utilization = useful_volume / launched_volume; - - return utilization; -} - -double calculate_output_utilization(const problem_t& problem, - const config_t& config, - size_t vector_elems = 1) { - const size_t M = problem.size.m; - const size_t N = problem.size.n; - - const size_t MT_M = config.mt.m; - const size_t MT_N = config.mt.n; - - if (MT_M <= 0 || MT_N <= 0) return 1.0; - - // Tiled coverage in M/N - const double launched_M = - static_cast(math::safe_ceil_div(M, MT_M)) * static_cast(MT_M); - const double launched_N = - static_cast(math::safe_ceil_div(N, MT_N)) * static_cast(MT_N); - - // Optional: model vectorization/alignment remainders (e.g., ld/st width) - // This assumes vectors must be fully inside bounds; tail elements are scalarized. - const size_t M_vec = (vector_elems > 1) ? math::safe_ceil_div(M, vector_elems) * vector_elems : M; - const size_t N_vec = (vector_elems > 1) ? math::safe_ceil_div(N, vector_elems) * vector_elems : N; - - const double useful = static_cast(M_vec) * static_cast(N_vec); - const double launched = launched_M * launched_N; - - if (launched < 1.0) return 1.0; - return useful / launched; -} - -// Computes the number of active compute units if there is only one wave and it is partial -// Otherwise, returns hardware.N_CU -std::tuple compute_cu_occupancy(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - grid_selection_t grid_selection, - size_t max_cus, - size_t split = 0) { - // Number of output MTs - size_t num_mts = streamk::compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); - - size_t num_wgs, num_active_cus, numWaves, splitFactor; - - if (split) // if it is given - { - split = split > 1 ? 
split : 1; - num_wgs = num_mts * split; - num_active_cus = num_wgs < hardware.N_CU ? num_wgs : hardware.N_CU; - numWaves = math::safe_ceil_div(num_wgs, hardware.N_CU); - splitFactor = split; - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("reduction type", "Origami"); + // Tiled coverage in M/N + const double launched_M = static_cast(safe_ceil_div(M, MT_M)) * MT_M; + const double launched_N = static_cast(safe_ceil_div(N, MT_N)) * MT_N; + + // Optional: model vectorization/alignment remainders (e.g., ld/st width) + // This assumes vectors must be fully inside bounds; tail elements are scalarized. + const size_t M_vec = (vector_elems > 1) ? (M / vector_elems) * vector_elems : M; + const size_t N_vec = (vector_elems > 1) ? (N / vector_elems) * vector_elems : N; + + const double useful = static_cast(M_vec) * static_cast(N_vec); + const double launched = launched_M * launched_N; + + if(launched < 1.0) + return 1.0; + return useful / launched; } - } else // as what StreamK predicts - { - auto config_with_reduction = config; - config_with_reduction.reduction_strategy = - streamk::select_reduction(problem, hardware, config, grid_selection); - - num_wgs = streamk::select_grid_size( - problem, hardware, config_with_reduction, grid_selection, max_cus); - - // output variables - num_active_cus = num_wgs < hardware.N_CU ? num_wgs : hardware.N_CU; - // There are cases in which StreamK combines multiple output MTs and assigns to 1 WG. - // That means, we artifically observe one full wave, but that is not what actually happens - // under the hood. From a theoretical point of view, these distributions change all of the - // computations in Origami. With current implementation, it is hard to capture that - // behaviour analytically. So for now, if the num_wgs is less than the num_mts, we calculate - // numWaves based on the num_mts. Otherwise, we use num_wgs to compute numWaves. - numWaves = num_wgs > num_mts ? 
math::safe_ceil_div(num_wgs, hardware.N_CU) - : math::safe_ceil_div(num_mts, hardware.N_CU); - splitFactor = math::safe_ceil_div(num_wgs, num_mts); - } - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("num_mts", num_mts); - config.logger.log("num_wgs", num_wgs); - config.logger.log("num_active_cus", num_active_cus); - config.logger.log("numWaves", numWaves); - config.logger.log("splitFactor", splitFactor); - config.logger.log("max_cus", max_cus); - } - - return std::make_tuple(num_wgs, num_active_cus, numWaves, splitFactor); -} - -/* ---------------------------------------------------------------------------------------- */ -/* Compute-related functions */ -/* ---------------------------------------------------------------------------------------- */ -// Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. -size_t compute_number_matrix_instructions(dim3_t mt, dim3_t mi) { - // Compute the number of Matrix Instructions required in each dim. - size_t num_m_instrs = math::safe_ceil_div(mt.m, mi.m); - size_t num_n_instrs = math::safe_ceil_div(mt.n, mi.n); - size_t num_k_instrs = math::safe_ceil_div(mt.k, mi.k); - - // Total number of matrix instructions. - size_t num_matrix_instrs = num_m_instrs * num_n_instrs * num_k_instrs; - - return num_matrix_instrs; -} - -// Compute arithmic intensity -double arithmetic_intensity(double m, double n, double k, double bytes_per_element) { - // Numerator: 2.0 * m * n * k - // Denominator: (m*n + n*k + m*k) * bytes_per_element - double numerator = 2.0 * m * n * k; - double denominator = (m * n + n * k + m * k) * bytes_per_element; - - return numerator / denominator; -} - -// Computes Emulated arithmetic intensity for TF32 (assumes 3xBF16). 
-double emulated_tf32_arithmetic_intensity(double m, double n, double k, double bytes_per_element) { - // Numerator: 3.0 * 2.0 * m * n * k - // Denominator: (m*n + n*k + m*k) * bytes_per_element - double numerator = 3.0 * 2.0 * m * n * k; - double denominator = (m * n + n * k + m * k) * bytes_per_element; - - return numerator / denominator; -} - -// Compute cvt overhead in x1 tf32 emulation -// TODO: We can generalize the same routine to cover more GEMMs that perform conversion -static inline double compute_cvt_overhead_x1(const problem_t& problem, - const hardware_t& hardware, - const config_t& config) { - // In X1 TF32 GEMMs, we do: - // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) - // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) - // ds_write_b64 - // That is, the extra instructions that we need to account for are the two cvt_pk ops - // per wave tile - - // However, these extra ops should not be added up to the overal tile latency becuase - // they can be run in parallel to Matix and Memory operations (given they are not dependent). - // So, We should ideally take L_tile = max{Mem, Comp, Vec (cvt latencies)}. - // Since, Vec latency is not modeled yet, we somehow model that into the current logic - // by scaling according to MFMA latencies and putting some heuristics to model the fact - // that these vector operations can be hidden (read interleaved) with the other memory - // or MFMA instructions. - - // --- Shorthands ----------------------------------------------------------- - const double MT_M = static_cast(config.mt.m); - const double MT_N = static_cast(config.mt.n); - const double MT_K = static_cast(config.mt.k); - - const double MI_M = static_cast(config.mi.m); - const double MI_N = static_cast(config.mi.n); - const double MI_K = static_cast(config.mi.k); - - const auto a_bytes = data_type_to_bytes(problem.a_dtype); - const auto b_bytes = data_type_to_bytes(problem.b_dtype); - - // TODO: Use kernel's actual wavetiles. 
- const double wave_tile_m = MT_M / 2.0; - const double wave_tile_n = MT_N / 2.0; - const double wave_tile_k = MT_K / MI_K; - - // MFMA count - const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; - const double num_mfma = 1.0 * N_MI; - // Cycle scale per MI - const double L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, problem.mi_dtype); - const double mfma_cycles = num_mfma * L_MI; - - // 2) Bytes (per K-slice), using ceil-div to whole bytes - const double bytesA = wave_tile_m * MT_K * static_cast(a_bytes); - const double bytesB = wave_tile_n * MT_K * static_cast(b_bytes); - - // 3) Modeled transfer quanta (128B lines) - // dsA = bytesA / (128 * MI_M) - // dsB = bytesB / (128 * MI_N) - // GR = dsA (global->LDS modeled equal to A-side DS) - const double dsA = (bytesA / 128.0) / MI_M; // LDS->VGPR for A - const double dsB = (bytesB / 128.0) / MI_N; // LDS->VGPR for B - const double GR = dsA; // Global->LDS reads - const double LR = dsA + dsB; // total DS->VGPR - - // 5) Exposed vs hidden CVT - // spare MFMA - const double spare_mfma = std::max(0.0, num_mfma - LR - GR); - // 2 cvt per each ds_write (this for SS_BSS -- should be revised for other datatypes) - // Each cvt has a latency of four. It is scaled by the MI Latency - // Note: change 16.0 based on mi_data_type if we want to generalize this for all - // casting GEMMs. - const double cvt = (2.0 * 4.0 / 16.0 * L_MI) * LR; - // cvt ops are interleaved in main loop and don't stall matrix or memory units. - // Heuristically, we set - const double H = (8.0 / 16.0 * L_MI) * spare_mfma + (4.0 / 16.0) * L_MI * (LR + GR); - const double overhead = std::max(cvt - H, 0.0); - - return overhead; -} - -// Compute cvt overhead in tf32 emulation -static inline double compute_cvt_overhead(const problem_t& problem, - const hardware_t& hardware, - const config_t& config) { - // Wave tile sizes - // TODO: Use kernel's actual wavetiles. 
- const double wave_tile_m = config.mt.m / 2.0; - const double wave_tile_n = config.mt.n / 2.0; - const double wave_tile_k = config.mt.k / config.mi.k; - - // MFMA count and cycles - const double N_MI = (wave_tile_m / config.mi.m) * (wave_tile_n / config.mi.n) * wave_tile_k; - - // TF32 emu: 3× BF16 MI issue slots - const double num_mfma = 3.0 * static_cast(N_MI); - - // Cycle scale per MI (use BF16 MI latency as the basic timing quantum) - const double L_MI_bf16 = - hardware.get_mi_latency(config.mi.m, config.mi.n, config.mi.k, data_type_t::BFloat16); - // const double mfma_cycles = num_mfma * L_MI_bf16; - - // 2) Bytes (per K-slice), using ceil-div to whole bytes - int a_bytes = data_type_to_bytes(problem.a_dtype); - int b_bytes = data_type_to_bytes(problem.b_dtype); - - const double bytesA = static_cast(wave_tile_m) * config.mt.k * a_bytes; - const double bytesB = static_cast(wave_tile_n) * config.mt.k * b_bytes; - - // const double mt_bytesA - // = static_cast(MT_M) * MT_K * safe_ceil_div(element_size_A, 8); - - // 3) Modeled transfer quanta (128B lines) - // dsA = bytesA / (128 * MI_M) - // dsB = bytesB / (128 * MI_N) - // GR = dsA (global->LDS modeled equal to A-side DS) - const double dsA = (bytesA / 128.0) / static_cast(config.mi.m); // LDS->VGPR for A - const double dsB = (bytesB / 128.0) / static_cast(config.mi.n); // LDS->VGPR for B - const double GR = dsA; // Global->LDS reads - const double LR = dsA + dsB; // total DS->VGPR - - // 4) Heuristic cycle weights (scaled to MI latency). - // Preserves your A=104, B=8, C=4 when L_MI_bf16 == 16. - // 24 vector instructions per 2 ds_reads (16x16x32) - // 24 vector instructions per 2 ds_reads for A and for B. - // 3 instructions per fp32 value read; number ds_read * size - const double A = (104.0 / 16.0) * L_MI_bf16; // CVT per LR-sized chunk (DS->VGPR) - const double B = (8.0 / 16.0) * L_MI_bf16; // hidden per spare MFMA slot - // MI16: 16 - 4 (12 cycles), for those 4 cycles, VGPRs are locked. 
8 cycles to do anything. - const double C = (4.0 / 16.0) * L_MI_bf16; // hidden per (LR+GR) slot // MI16 - // 32 cycles (mfma), 4 cycles, 28, 4 vgpr lock, 24 cycles left. - // 24: 6 conv instructions, 3 ds_reads, ~6 grs - - // 5) Exposed vs hidden CVT - const double spare_mfma = std::max(0.0, num_mfma - LR - GR); - const double cvt = A * dsA; // only DS->VGPR contributes CVT - const double H = B * spare_mfma + C * (LR + GR); // hidden cycles - const double overhead = std::max(cvt - H, 0.0); - - // 6) Efficiency - // const double denom = mfma_cycles + overhead; - // const double eff = (denom > 0.0) ? (mfma_cycles / denom) : 1; - - return overhead; -} - -// Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). -size_t compute_mt_compute_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config) { - // Compute the number of matrix instructions - size_t N_MI = compute_number_matrix_instructions(config.mt, config.mi); - // Latency of a single MT_MxMT_NxMT_k tile is the latency of one MI multiplied by - // number of MI per MT_MxMT_NxMT_k. 
- size_t L_MI = hardware.get_mi_latency(config.mi.m, config.mi.n, config.mi.k, problem.mi_dtype); - - // size_t mt_arith = arithmetic_intensity(MT_M, MT_N, MT_K, 2); - // printf("MT_M:%d MT_N:%d MT_K:%d arith:%d\n", MT_M, MT_N, MT_K, mt_arith); - // size_t arith = ((M * N * K * 2) / (M * K + N * K + M * N)); - size_t L_MT = L_MI * N_MI; - - return L_MT; -} - -/* ---------------------------------------------------------------------------------------- */ -/* Memory-related functions */ -/* ---------------------------------------------------------------------------------------- */ -// Check if MT fits in LDS -bool check_lds_capacity(const hardware_t& hardware, - dim3_t mt, - data_type_t a_dtype, - data_type_t b_dtype) { - // A and B size - size_t a_loads_in_bytes = mt.mk() * data_type_to_bytes(a_dtype); - size_t b_loads_in_bytes = mt.nk() * data_type_to_bytes(b_dtype); - // Size of those in bytes - size_t LDS_usage = a_loads_in_bytes + b_loads_in_bytes; - - if (LDS_usage > hardware.lds_capacity) { - return false; // Exceeds LDS capacity - } else { - return true; // Within LDS capacity - } -} - -// Compute limited achievable memory bandwidth based on active CUs -double compute_mem_bw_from_occupancy(const hardware_t& hardware, size_t num_active_cus) { - const double CUs = static_cast(num_active_cus); - - if (num_active_cus > hardware.N_CU) return 1.0; - - const double bw_limited = std::get<0>(hardware.mem_bw_per_wg_coefficients) * CUs * CUs + - std::get<1>(hardware.mem_bw_per_wg_coefficients) * CUs + - std::get<2>(hardware.mem_bw_per_wg_coefficients); - - return std::min(bw_limited, 1.0); -} - -double estimate_l2_hit(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t splitting_factor) { - // Use size_t for dimensions and counts to ensure type safety. 
- const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); - const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); - const size_t total_workgroups = workgroups_m * workgroups_n; - - // Concurrently executing workgroups are limited by the number of CUs.a - const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); - if (concurrent_workgroups == 0) - throw std::runtime_error("#Workgroups is zero in estimate l2 hit"); - - // Number of CUs that might share the same K-tiles, adjusted for K-splitting. - // This affects contention on the L2 cache partitions (XCDs). - const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor); - const size_t cu_per_xcd = - std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); - - // Initial guess for the L2 tile dimensions (a tile of workgroups). - size_t l2_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); - size_t l2_tile_m = math::safe_ceil_div(cu_per_xcd, l2_tile_n); - - // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. - if (l2_tile_m > workgroups_m) { - size_t num_wraps = (l2_tile_m / workgroups_m); - l2_tile_n += (num_wraps * config.workgroup_mapping); - l2_tile_m = workgroups_m; - } - - // Clamp initial tile dimensions to the actual grid size. - l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); - l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); - - // Calculate memory footprint in bytes. 
- const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); - const size_t b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); - auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { - size_t a_footprint = tile_m * config.mt.mk() * a_bytes; - size_t b_footprint = tile_n * config.mt.nk() * b_bytes; - return a_footprint + b_footprint; - }; - - // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. - // This is more robust than shrinking only one dimension. - while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) { - if (l2_tile_m > 1 && l2_tile_m >= l2_tile_n) { - l2_tile_m--; - } else if (l2_tile_n > 1) { - l2_tile_n--; - } else { - // Cannot shrink further. - break; + + // Computes the number of active compute units if there is only one wave and it is partial + // Otherwise, returns hardware.N_CU + std::tuple compute_CU_occupancy(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + int WGM, + size_t workspace_size, + size_t workspace_size_per_elem_c, + int occupancy, + int dynamic_grid_version, + size_t split, + size_t max_cus) + { + // Number of output MTs + size_t numMT_M = safe_ceil_div(M, MT_M); + size_t numMT_N = safe_ceil_div(N, MT_N); + size_t numMTs = numMT_M * numMT_N * batch; + + size_t numWGs, numActiveCUs, numWaves, splitFactor; + + if(split) // if it is given + { + split = split > 1 ? split : 1; + numWGs = numMTs * split; + numActiveCUs = numWGs < hardware.N_CU ? 
numWGs : hardware.N_CU; + numWaves = safe_ceil_div(numWGs, hardware.N_CU); + splitFactor = split; + + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("reduction type", "Origami"); + } + } + else // as what StreamK predicts + { + streamk::reduction_type rt = streamk::select_reduction(M, + N, + K, + batch, + MT_M, + MT_N, + MT_K, + hardware, + dynamic_grid_version); + numWGs = streamk::select_grid(M, + N, + K, + batch, + transA, + transB, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + workspace_size, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + WGM, + workspace_size_per_elem_c, + occupancy, + hardware, + dynamic_grid_version, + rt, + max_cus); + + // output variables + numActiveCUs = numWGs < hardware.N_CU ? numWGs : hardware.N_CU; + // There are cases in which StreamK combines multiple output MTs and assigns to 1 WG. + // That means, we artifically observe one full wave, but that is not what actually happens + // under the hood. From a theoretical point of view, these distributions change all of the + // computations in Origami. With current implementation, it is hard to capture that + // behaviour analytically. So for now, if the numWGs is less than the numMTs, we calculate + // numWaves based on the numMTs. Otherwise, we use numWGs to compute numWaves. + numWaves = numWGs > numMTs ? 
safe_ceil_div(numWGs, hardware.N_CU) + : safe_ceil_div(numMTs, hardware.N_CU); + splitFactor = safe_ceil_div(numWGs, numMTs); + + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("reduction type", streamk::rtype_to_string(rt)); + } + } + + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("numMTs", numMTs); + hardware.log_debug("numWGs", numWGs); + hardware.log_debug("numActiveCUs", numActiveCUs); + hardware.log_debug("numWaves", numWaves); + hardware.log_debug("splitFactor", splitFactor); + hardware.log_debug("max_cus", max_cus); + } + + return std::make_tuple(numWGs, numActiveCUs, numWaves, splitFactor); + } + + /* ---------------------------------------------------------------------------------------- */ + /* Compute-related functions */ + /* ---------------------------------------------------------------------------------------- */ + // Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. + size_t compute_number_matrix_instructions(const hardware_t& hardware, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K) + { + // Compute the number of Matrix Instructions required in each dim + size_t N_MI_M = safe_ceil_div(MT_M, MI_M); + size_t N_MI_N = safe_ceil_div(MT_N, MI_N); + size_t N_MI_K = safe_ceil_div(MT_K, MI_K); + // Total number of matrix instructions for MT_MxMT_NxMT_K tile + size_t N_MI = N_MI_M * N_MI_N * N_MI_K; + + return N_MI; } - } - - // Uncached reads are the first read of each unique element within the L2 tile. - const long long uncached_A_reads = static_cast(l2_tile_m) * config.mt.mk(); - const long long uncached_B_reads = static_cast(l2_tile_n) * config.mt.nk(); - const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; - - // Total reads are the sum of all reads performed by all workgroups in the L2 tile. - // Matrix A is reused l2_tile_n times, Matrix B is reused l2_tile_m times. 
- const long long total_A_reads = uncached_A_reads * l2_tile_n; - const long long total_B_reads = uncached_B_reads * l2_tile_m; - const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); - - const long long cached_reads = total_reads - total_uncached_reads; - - double l2_hit_rate = static_cast(cached_reads) / static_cast(total_reads); - - // Final clamping and logging. - if (get_runtime_options(config).debug_enabled) { - config.logger.log("L2Tile_M", l2_tile_m); - config.logger.log("L2Tile_N", l2_tile_n); - config.logger.log("TotalWorkgroups", total_workgroups); - config.logger.log("ConcurrentWorkgroups", concurrent_workgroups); - } - - // Clamp the hit rate to be within a realistic [0, 1] range. - return std::max(0.0, std::min(l2_hit_rate, 1.0)); -} - -// Estimate MALL hit-rate -double estimate_mall_hit(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t num_active_cus, - size_t splitting_factor) { - const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); - const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); - - if (num_active_cus == 0) throw std::runtime_error("Number of Active CUs was 0"); - - // --- Initial Tile Sizing based on Concurrency --- - // Use ceiling division for a more accurate initial guess. - size_t mall_tile_m = - math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); - size_t mall_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); - - // Handle wrap-around case if the tile is taller than the grid. - if (mall_tile_m > workgroups_m) { - size_t num_wraps = mall_tile_m / workgroups_m; - mall_tile_n += (num_wraps * config.workgroup_mapping); - mall_tile_m = workgroups_m; - } - - // Clamp initial tile dimensions to the actual grid size. 
- mall_tile_m = std::max(std::min(workgroups_m, mall_tile_m), static_cast(1)); - mall_tile_n = std::max(std::min(workgroups_n, mall_tile_n), static_cast(1)); - - // --- CRITICAL: Shrink tile to fit into MALL Capacity --- - const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); - const size_t b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); - - auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { - size_t a_footprint = tile_m * config.mt.mk() * a_bytes; - size_t b_footprint = tile_n * config.mt.nk() * b_bytes; - return a_footprint + b_footprint; - }; - - // --- Calculate Hit Rate based on the final, capacity-aware tile size --- - const long long uncached_A_reads = static_cast(mall_tile_m) * config.mt.mk(); - const long long uncached_B_reads = static_cast(mall_tile_n) * config.mt.nk(); - const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; - - const long long total_A_reads = uncached_A_reads * mall_tile_n; - const long long total_B_reads = uncached_B_reads * mall_tile_m; - const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); - - const long long cached_reads = total_reads - total_uncached_reads; - - double mall_hit_rate = static_cast(cached_reads) / static_cast(total_reads); - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("MallTile_M", mall_tile_m); - config.logger.log("MallTile_N", mall_tile_n); - config.logger.log("MallFootprint_Bytes", calculate_footprint(mall_tile_m, mall_tile_n)); - } - - // Clamp the final result to the valid [0, 1] range. - return std::max(0.0, std::min(mall_hit_rate, 1.0)); -} - -/** - * @brief L2 hit rate from a global (problem-wide) perspective using the refactored API. - * Computes in BYTES to correctly handle differing A/B dtypes. 
+ + // Compute arithmic intensity + double arithmetic_intensity(double m, double n, double k, double bytes_per_element) + { + // Numerator: 2.0 * m * n * k + // Denominator: (m*n + n*k + m*k) * bytes_per_element + double numerator = 2.0 * m * n * k; + double denominator = (m * n + n * k + m * k) * bytes_per_element; + + return numerator / denominator; + } + + // Computes Emulated arithmetic intensity for TF32 (assumes 3xBF16). + double emulated_tf32_arithmetic_intensity(double m, double n, double k, double bytes_per_element) + { + // Numerator: 3.0 * 2.0 * m * n * k + // Denominator: (m*n + n*k + m*k) * bytes_per_element + double numerator = 3.0 * 2.0 * m * n * k; + double denominator = (m * n + n * k + m * k) * bytes_per_element; + + return numerator / denominator; + } + + // Compute cvt overhead in x1 tf32 emulation + // TODO: We can generalize the same routine to cover more GEMMs that perform conversion + static inline double compute_cvt_overhead_x1(const hardware_t& hardware, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + data_type_t mi_datatype) + { + // In X1 TF32 GEMMs, we do: + // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) + // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) + // ds_write_b64 + // That is, the extra instructions that we need to account for are the two cvt_pk ops + // per wave tile + + // However, these extra ops should not be added up to the overal tile latency becuase + // they can be run in parallel to Matix and Memory operations (given they are not dependent). + // So, We should ideally take L_tile = max{Mem, Comp, Vec (cvt latencies)}. + // Since, Vec latency is not modeled yet, we somehow model that into the current logic + // by scaling according to MFMA latencies and putting some heuristics to model the fact + // that these vector operations can be hidden (read interleaved) with the other memory + // or MFMA instructions. 
+ + // TODO: Use kernel's actual wavetiles. + const double wave_tile_m = MT_M / 2.0; + const double wave_tile_n = MT_N / 2.0; + const double wave_tile_k = MT_K / MI_K; + + // MFMA count + const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; + const double num_mfma = 1.0 * static_cast(N_MI); + // Cycle scale per MI + const double L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, mi_datatype); + const double mfma_cycles = num_mfma * L_MI; + + // 2) Bytes (per K-slice), using ceil-div to whole bytes + const double bytesA + = static_cast(wave_tile_m) * MT_K * safe_ceil_div(element_size_A, 8); + const double bytesB + = static_cast(wave_tile_n) * MT_K * safe_ceil_div(element_size_B, 8); + + // 3) Modeled transfer quanta (128B lines) + // dsA = bytesA / (128 * MI_M) + // dsB = bytesB / (128 * MI_N) + // GR = dsA (global->LDS modeled equal to A-side DS) + const double dsA = (bytesA / 128.0) / static_cast(MI_M); // LDS->VGPR for A + const double dsB = (bytesB / 128.0) / static_cast(MI_N); // LDS->VGPR for B + const double GR = dsA; // Global->LDS reads + const double LR = dsA + dsB; // total DS->VGPR + + // 5) Exposed vs hidden CVT + // spare MFMA + const double spare_mfma = std::max(0.0, num_mfma - LR - GR); + // 2 cvt per each ds_write (this for SS_BSS -- should be revised for other datatypes) + // Each cvt has a latency of four. It is scaled by the MI Latency + // Note: change 16.0 based on mi_data_type if we want to generalize this for all + // casting GEMMs. + const double cvt = (2.0 * 4.0 / 16.0 * L_MI) * LR; + // cvt ops are interleaved in main loop and don't stall matrix or memory units. 
+ // Heuristically, we set + const double H = (8.0 / 16.0 * L_MI) * spare_mfma + (4.0 / 16.0) * L_MI * (LR + GR); + const double overhead = std::max(cvt - H, 0.0); + + return overhead; + } + + // Compute cvt overhead in tf32 emulation + static inline double compute_cvt_overhead(const hardware_t& hardware, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B) + { + // Wave tile sizes + // TODO: Use kernel's actual wavetiles. + const double wave_tile_m = MT_M / 2.0; + const double wave_tile_n = MT_N / 2.0; + const double wave_tile_k = MT_K / MI_K; + + // MFMA count and cycles + const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; + + // TF32 emu: 3× BF16 MI issue slots + const double num_mfma = 3.0 * static_cast(N_MI); + + // Cycle scale per MI (use BF16 MI latency as the basic timing quantum) + const double L_MI_bf16 = hardware.get_mi_latency(MI_M, MI_N, MI_K, data_type_t::BFloat16); + //const double mfma_cycles = num_mfma * L_MI_bf16; + + // 2) Bytes (per K-slice), using ceil-div to whole bytes + const double bytesA + = static_cast(wave_tile_m) * MT_K * safe_ceil_div(element_size_A, 8); + const double bytesB + = static_cast(wave_tile_n) * MT_K * safe_ceil_div(element_size_B, 8); + + // const double mt_bytesA + // = static_cast(MT_M) * MT_K * safe_ceil_div(element_size_A, 8); + + // 3) Modeled transfer quanta (128B lines) + // dsA = bytesA / (128 * MI_M) + // dsB = bytesB / (128 * MI_N) + // GR = dsA (global->LDS modeled equal to A-side DS) + const double dsA = (bytesA / 128.0) / static_cast(MI_M); // LDS->VGPR for A + const double dsB = (bytesB / 128.0) / static_cast(MI_N); // LDS->VGPR for B + const double GR = dsA; // Global->LDS reads + const double LR = dsA + dsB; // total DS->VGPR + + // 4) Heuristic cycle weights (scaled to MI latency). + // Preserves your A=104, B=8, C=4 when L_MI_bf16 == 16. 
+ // 24 vector instructions per 2 ds_reads (16x16x32) + // 24 vector instructions per 2 ds_reads for A and for B. + // 3 instructions per fp32 value read; number ds_read * size + const double A = (104.0 / 16.0) * L_MI_bf16; // CVT per LR-sized chunk (DS->VGPR) + const double B = (8.0 / 16.0) * L_MI_bf16; // hidden per spare MFMA slot + // MI16: 16 - 4 (12 cycles), for those 4 cycles, VGPRs are locked. 8 cycles to do anything. + const double C = (4.0 / 16.0) * L_MI_bf16; // hidden per (LR+GR) slot // MI16 + // 32 cycles (mfma), 4 cycles, 28, 4 vgpr lock, 24 cycles left. + // 24: 6 conv instructions, 3 ds_reads, ~6 grs + + // 5) Exposed vs hidden CVT + const double spare_mfma = std::max(0.0, num_mfma - LR - GR); + const double cvt = A * dsA; // only DS->VGPR contributes CVT + const double H = B * spare_mfma + C * (LR + GR); // hidden cycles + const double overhead = std::max(cvt - H, 0.0); + + // 6) Efficiency + //const double denom = mfma_cycles + overhead; + //const double eff = (denom > 0.0) ? (mfma_cycles / denom) : 1; + + return overhead; + } + + // Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). + size_t compute_mt_compute_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + data_type_t mi_datatype) + { + // Compute the number of matrix instructions + size_t N_MI + = compute_number_matrix_instructions(hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K); + // Latency of a single MT_MxMT_NxMT_k tile is the latency of one MI multiplied by + // number of MI per MT_MxMT_NxMT_k. 
+ size_t L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, mi_datatype); + + // size_t mt_arith = arithmetic_intensity(MT_M, MT_N, MT_K, 2); + // printf("MT_M:%d MT_N:%d MT_K:%d arith:%d\n", MT_M, MT_N, MT_K, mt_arith); + // size_t arith = ((M * N * K * 2) / (M * K + N * K + M * N)); + size_t L_MT = L_MI * N_MI; + + return L_MT; + } + + /* ---------------------------------------------------------------------------------------- */ + /* Memory-related functions */ + /* ---------------------------------------------------------------------------------------- */ + // Check if MT fits in LDS + bool check_lds_capacity( + const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, size_t element_size) + { + // A and B size + size_t Ld_A_value = compute_A_loads(MT_M, MT_K); + size_t Ld_B_value = compute_B_loads(MT_N, MT_K); + // Size of those in bytes + size_t LDS_usage = (Ld_A_value + Ld_B_value) * (element_size / 8); + + if(LDS_usage > hardware.lds_capacity) + { + return false; // Exceeds LDS capacity + } + else + { + return true; // Within LDS capacity + } + } + + // Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. + size_t compute_A_loads(size_t MT_M, size_t MT_K) + { + // Compute the size of loads from A for a single MT_MxMT_NxMT_K tiles + size_t Ld_A_value = MT_M * MT_K; + + return Ld_A_value; + } + + // Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. 
+ size_t compute_B_loads(size_t MT_N, size_t MT_K) + { + // Compute the size of loads from B for a single MT_MxMT_NxMT_K tiles + size_t Ld_B_value = MT_N * MT_K; + + return Ld_B_value; + } + + // Compute limited achievable memory bandwidth based on active CUs + double compute_mem_bw_from_occupancy(const hardware_t& hardware, size_t numActiveCUs) + { + const double CUs = static_cast(numActiveCUs); + + if(numActiveCUs > hardware.N_CU) + return 1.0; + + const double bw_limited = std::get<0>(hardware.mem_bw_per_wg_coefficients) * CUs * CUs + + std::get<1>(hardware.mem_bw_per_wg_coefficients) * CUs + + std::get<2>(hardware.mem_bw_per_wg_coefficients); + + return std::min(bw_limited, 1.0); + } + + /* + * This heuristic models data reuse by defining a "tile of workgroups" that can fit + * its working set (portions of matrices A and B) into the L2 cache. The hit rate + * is the ratio of reused data reads to total data reads within this tile. + * + * @param M, N, K, batch Problem dimensions. + * @param MT_M, MT_N, MT_K Macro-tile dimensions. + * @param element_size Size of a single data element in bits. + * @param WGM Workgroup mapping size (typically 64). + * @param splittingFactor K-splitting factor, reduces L2 contention. + * @return Estimated L2 hit rate (0.0 to 1.0). */ -double compute_l2_hit_rate_global(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t l2_capacity_bytes) { - // --- Hardware Parameters (as requested, defined locally) --- - // You would normally get l2_capacity_bytes from your hardware_t struct. - if (l2_capacity_bytes == 0) throw std::runtime_error("L2 Capacity is zero"); - - // 1. Calculate the grid dimensions in terms of macro-tiles - const size_t grid_m = math::safe_ceil_div(problem.size.m, config.mt.m); - const size_t grid_n = math::safe_ceil_div(problem.size.n, config.mt.n); - - if (grid_m == 0 || grid_n == 0) - throw std::runtime_error("estimate_l2_hit grid dimensions can not be zero"); - - // 2. 
Calculate the working set size for one full pass of global reuse - // This is the data needed by one full column of CUs (for A) and one full row (for B). - const double a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); - const double b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); - - const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; - const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; - const double total_working_set_bytes = a_working_set + b_working_set; - - // 3. CRUCIAL: Check if the working set fits in the L2 cache. - // If it doesn't, the global reuse pattern is broken by capacity misses, - // and the hit rate will be very low. - if (total_working_set_bytes > l2_capacity_bytes) { - // Return a floor value for the hit rate. The exact value can be tuned, - // but it should be low to indicate that the ideal reuse is not possible. - return 0.1; // 10% hit rate - } - - // 4. If it fits, calculate the idealized global hit rate - // Total reads if nothing was cached - const double total_A_reads = static_cast(grid_m * grid_n * config.mt.mk()); - const double total_B_reads = static_cast(grid_m * grid_n * config.mt.nk()); - - // Uncached reads are the first-time fetches for each row/column - const double uncached_A_reads = - static_cast(grid_m * config.mt.mk()); // One full column fetches A - const double uncached_B_reads = - static_cast(grid_n * config.mt.nk()); // One full row fetches B - - const double total_reads = total_A_reads + total_B_reads; - if (total_reads == 0) return 1.0; // No reads, perfect hit rate. 
- - const double cached_reads = - (total_A_reads - uncached_A_reads) + (total_B_reads - uncached_B_reads); - - return cached_reads / total_reads; -} - -inline size_t round_up_mul(size_t x, size_t m) { return (x + m - 1) / m * m; } - -size_t round_elements_to_128B(size_t elements, size_t element_size_bits) { - const size_t transaction_bits = 128u * 8u; // 1024 - const size_t g = std::gcd(element_size_bits, transaction_bits); - const size_t E_block = transaction_bits / g; // elements per 128B-aligned chunk - return round_up_mul(elements, E_block); -} - -// Determine the memory latency -double compute_memory_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t num_active_cus, - size_t splitting_factor) { - // Extract parameters from structured types - const auto a_bytes = data_type_to_bytes(problem.a_dtype); - const auto b_bytes = data_type_to_bytes(problem.b_dtype); - const auto a_bits = datatype_to_bits(problem.a_dtype); - const auto b_bits = datatype_to_bits(problem.b_dtype); - size_t batch = problem.batch; - - const bool a_trans = (problem.a_transpose == transpose_t::T); - const bool b_trans = (problem.b_transpose == transpose_t::T); - - const size_t MT_M = config.mt.m; - const size_t MT_N = config.mt.n; - const size_t MT_K = config.mt.k; - - // 1) Estimate L2 hit-rate - double H_mem1 = estimate_l2_hit(problem, hardware, config, splitting_factor); - - // Global cap on L2 hit-rate (prevents impossible cache residency claims) - // (Assumes capacity is given in KiB, convert to bytes) - double H_mem1_global = - compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024); - - H_mem1 = std::min(H_mem1, H_mem1_global); - - if (H_mem1 == 0) { H_mem1 = 0.5; } - - // 2) Estimate mall hit-rate - double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); - - // 3) Total loads are loads from A and loads from B - size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, 
a_bits); - size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, a_bits); - size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, a_bits); - - if (!a_trans && !b_trans) { - MT_N_rounded_128bytes = MT_N; - MT_K_rounded_128bytes = MT_K; - } else if (a_trans && !b_trans) { - MT_M_rounded_128bytes = MT_M; - MT_N_rounded_128bytes = MT_N; - } else if (!a_trans && b_trans) { - MT_K_rounded_128bytes = MT_K; - } - - size_t Ld_A_value = MT_M_rounded_128bytes * MT_K_rounded_128bytes; - size_t Ld_B_value = MT_N_rounded_128bytes * MT_K_rounded_128bytes; - size_t Ld_CU_bytes = (Ld_A_value * static_cast(a_bytes)) // A Bytes - + (Ld_B_value * static_cast(b_bytes)); // B Bytes - - // Logic for block scaled datatypes (Assuming BS=32 and 8-bit scales) - // TODO This is technically wrong, need separate flag to enable MX so we can differentiate FP8 - // and MX8 - if (a_bits < 8 && problem.a_mx_block_size != 0) { - // Number of scales per tile - size_t num_scales_A = math::safe_ceil_div(config.mt.mk(), problem.a_mx_block_size); - Ld_CU_bytes += num_scales_A; // One Byte per scale - } - if (b_bits < 8 && problem.b_mx_block_size != 0) { - // Number of scales per tile - size_t num_scales_B = math::safe_ceil_div(config.mt.nk(), problem.b_mx_block_size); - Ld_CU_bytes += num_scales_B; // One Byte per scale - } - - // 4) total loads by all CUs - double total_Ld = Ld_CU_bytes * static_cast(num_active_cus); - - // 5) mem1‐limited factor (simple linear model) - double mem1_bw_limited = static_cast(num_active_cus) / static_cast(hardware.N_CU); - double limited_mem1_bw = (hardware.mem1_perf_ratio * mem1_bw_limited); - - // 6) mem1 latency - double L_mem_mem1 = (limited_mem1_bw > 0) ? 
(total_Ld / (limited_mem1_bw)) : 0.0; - - // 7) mem2‐limited from occupancy (Can't Issue enough load/stores) - double bw_limited = compute_mem_bw_from_occupancy(hardware, num_active_cus); - - // 8) loads that reach each level - double Ld_mem2 = (1.0 - H_mem1) * total_Ld; - double Ld_MEM = (1.0 - H_mem2) * Ld_mem2; - - // 9) enforce whole‐problem minimum loads when we can fit M/N in the CUs. - // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). - size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); - size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); - size_t mall_m = - math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); - size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); - // Handle wrap-around case - if (mall_m > grid_m) { - size_t num_wraps = (mall_m / grid_m); - mall_n += (num_wraps * config.workgroup_mapping); - mall_m = grid_m; - } - // Clamp tile dimensions - mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); - mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); - // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. - double min_load = static_cast((mall_m * config.mt.mk() * static_cast(a_bytes)) + - (mall_n * config.mt.nk() * static_cast(b_bytes))) * - batch; // Apply batching to the minimum load itself. - // The actual loads cannot be less than this physical minimum. - Ld_MEM = std::max(Ld_MEM, min_load); - Ld_mem2 = std::max(Ld_mem2, min_load); - - // 10) mem2 latency - double limited_mem2_bw = (hardware.mem2_perf_ratio * bw_limited); - double L_mem_mem2 = (limited_mem2_bw > 0) ? (Ld_mem2 / limited_mem2_bw) : 0.0; - - // 11) MEM latency - double limited_mem_bw = (hardware.mem3_perf_ratio * bw_limited); - double L_mem_MEM = (limited_mem_bw > 0) ? 
(Ld_MEM / limited_mem_bw) : 0.0; - L_mem_MEM += 200; // Load Latency - - // 12) pick the worst‐case bound - double L_mem = std::max({L_mem_mem1, L_mem_mem2, L_mem_MEM}); - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("mem1_perf_ratio", hardware.mem1_perf_ratio); - config.logger.log("mem2_perf_ratio", hardware.mem2_perf_ratio); - config.logger.log("mem3_perf_ratio", hardware.mem3_perf_ratio); - config.logger.log("mem_bw_per_wg_coefficients(0)", - std::get<0>(hardware.mem_bw_per_wg_coefficients)); - config.logger.log("mem_bw_per_wg_coefficients(1)", - std::get<1>(hardware.mem_bw_per_wg_coefficients)); - config.logger.log("mem_bw_per_wg_coefficients(2)", - std::get<2>(hardware.mem_bw_per_wg_coefficients)); - config.logger.log("H_mem1 (mem1 hit ratio)", H_mem1); - config.logger.log("H_mem2 (mem2 hit ratio)", H_mem2); - config.logger.log("Total Load (bytes)", total_Ld); - config.logger.log("Ld_mem2 (bytes)", Ld_mem2); - config.logger.log("Ld_MEM (bytes)", Ld_MEM); - config.logger.log("L_mem_mem1 (cycles)", L_mem_mem1); - config.logger.log("L_mem_mem2 (cycles)", L_mem_mem2); - config.logger.log("L_mem_MEM (cycles)", L_mem_MEM); - config.logger.log("MT_K % 128 bytes", MT_K * static_cast(b_bytes) % 128); - config.logger.log("MT_M % 128 bytes", MT_M * static_cast(a_bytes) % 128); - config.logger.log("MT_N % 128 bytes", MT_N * static_cast(b_bytes) % 128); - config.logger.log( - "MT_N % 128 + MT_M % 128 bytes", - (MT_M * static_cast(a_bytes) % 128) + MT_N * static_cast(b_bytes) % 128); - config.logger.log( - "MT_N % 64 + MT_M % 64 bytes", - (MT_M * static_cast(a_bytes) % 64) + MT_N * static_cast(b_bytes) % 64); - config.logger.log("MT_K % 64 bytes", MT_K * static_cast(b_bytes) % 64); - config.logger.log("MT_M % 64 bytes", MT_M * static_cast(a_bytes) % 64); - config.logger.log("MT_N % 64 bytes", MT_N * static_cast(b_bytes) % 64); - config.logger.log("Tile Arithmetic Intensity", - MT_M * MT_N * MT_K / (MT_M * MT_K + MT_N * MT_K)); - } - - return L_mem; 
-} - -/* ---------------------------------------------------------------------------------------- */ -/* Tile-related functions */ -/* ---------------------------------------------------------------------------------------- */ -double compute_tile_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t num_active_cus, - size_t splitting_factor) { - // Extract parameters from structured types - const size_t K = problem.size.k; - size_t batch = problem.batch; - - const size_t MT_M = config.mt.m; - const size_t MT_N = config.mt.n; - const size_t MT_K = config.mt.k; - - const auto a_bits = datatype_to_bits(problem.a_dtype); - const auto b_bits = datatype_to_bits(problem.b_dtype); - const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); - const size_t d_bytes = static_cast(data_type_to_bytes(problem.d_dtype)); - - // 1) Compute per-tile latencies - double L_compute = compute_mt_compute_latency(problem, hardware, config); - - double L_mem = - compute_memory_latency(problem, hardware, config, num_active_cus, splitting_factor); - - // TODO Does work utilization need to be 128-byte rounded for a cache line? - double utilization = calculate_work_utilization(problem, config); - double output_utilization = calculate_output_utilization(problem, config, 1UL); - // The effective latency per useful operation increases as utilization drops. - // This penalty affects BOTH compute and memory bounds for the tile's core work. - double effective_tile_penalty = (utilization > 1e-9) ? (1.0 / (utilization)) : 1.0; - double output_utilization_penalty = - (output_utilization > 1e-9) ? 
(1.0 / (output_utilization)) : 1.0; - // 2) Work-group setup & iteration latencies - double L_WG_setup = 1; // WG_setup_Latency - - // 3) Prologue: 2.2× memory latency - double L_prologue = 1.5 * L_mem; // 1.5 chosen emprically - - // L_compute *= std::max(L_compute, L_LDS); - - // 4) Epilogue: writes from all active CUs with limited bandwidth - double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, num_active_cus); - double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; - size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, datatype_to_bits(problem.a_dtype)); - - double L_epilogue = (static_cast(num_active_cus / splitting_factor) * - MT_M_rounded_128bytes * MT_N * static_cast(d_bytes)) / - mem_bw_occ_limited; - // One compute iteration happens in the prologue - L_epilogue += L_compute * effective_tile_penalty; - // Epilogue and Prologue overhead are reduced with higher occupancy kernels. - int grid_m = static_cast(math::safe_ceil_div(problem.size.m, MT_M)); - int grid_n = static_cast(math::safe_ceil_div(problem.size.n, MT_N)); - - size_t real_occupancy = - std::min(std::max(config.occupancy, static_cast(1)), - static_cast(math::safe_ceil_div(grid_m * grid_n * batch * splitting_factor, - hardware.N_CU))); // Number of WGs per CU. - - L_prologue = L_prologue * pow(0.95, real_occupancy); // Factor chosen empirically - L_epilogue = L_epilogue * pow(0.95, real_occupancy); // Factor chosen empirically - // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N - // tiles to coherent memory - if (splitting_factor > 1) { - size_t n_partials = splitting_factor - 1; - - // Only the reduction CU reads from all splits. - double partial_read_bytes = - grid_m * grid_n * n_partials * MT_M_rounded_128bytes * MT_N * static_cast(d_bytes); - - // All CUs write (once for each partial, and once by the reduction CU for the output.) 
- double partial_write_bytes = - grid_m * grid_n * MT_M_rounded_128bytes * MT_N * static_cast(d_bytes); - - double partial_readwrite_bytes = partial_read_bytes + partial_write_bytes; - - // 64 Threads active in a SIMD. Exposed to at least latency of reducing splitting_factor - // tiles. - double partial_adds = - (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); - - double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); - L_epilogue += L_reduce + partial_adds + 10000; - } - // 4'') tf32 emu has some more overhead - double L_cvt = 0; - if ((problem.mi_dtype == data_type_t::XFloat32) && - (hardware.arch == hardware_t::architecture_t::gfx950)) { - L_cvt = compute_cvt_overhead(problem, hardware, config); - } else if ((a_bits == 32) && (b_bits == 32) && (problem.mi_dtype == data_type_t::BFloat16) && - (hardware.arch == hardware_t::architecture_t::gfx950)) // SS_BSS on GFX950 - { - L_cvt = compute_cvt_overhead_x1(problem, hardware, config); - } - - // 5) Single-tile latency (always additive) - // Calculate the fraction of the work that is useful (not padding). - - // 5) Single-tile latency (apply penalty after finding the bottleneck) - double L_tile_single = (std::max(L_compute, L_mem) * effective_tile_penalty) + L_cvt; - L_prologue *= effective_tile_penalty; - // 6) Number of K-iterations (excluding epilogue), at least 1 - // long num_iter = static_cast(((K + MT_K - 1) / MT_K)) - 1; - // num_iter = std::ceil(num_iter / splitting_factor); - // num_iter = std::max(num_iter, 1L); - const long k_per_split = static_cast(math::safe_ceil_div(K, splitting_factor)); - long num_iter = - std::max(static_cast(math::safe_ceil_div(static_cast(k_per_split), MT_K) - 1), - static_cast(1)); - // Zero Padding in the K dimension on last iteration - if (K % MT_K != 0) { - const double problem_k_quant = static_cast(K % MT_K) / static_cast(K); - L_epilogue += problem_k_quant * 50000; // Scale by remainder proportion of problem. 
50k cycle - // penalty if have to zero pad all except 1. - //(Scale Determined Empirically) - } - // L_epilogue *= output_utilization_penalty; - - // 7) Total tile latency - double L_tile_total = - (L_tile_single * static_cast(num_iter)) + L_prologue + L_epilogue * 2 + L_WG_setup + - (500 * static_cast( - num_iter)); // 7 instructions (each with 4 cycles) at the end of the loop - - if (get_runtime_options(config).debug_enabled) { - double problem_k_quant = ((K % MT_K) / (double)K); - config.logger.log("Iteration Compute Latency", L_compute); - config.logger.log("L_mem", L_mem); - config.logger.log("L_cvt", L_cvt); - config.logger.log("L_tile_single", L_tile_single); - config.logger.log("num_iter", num_iter); - config.logger.log("L_prologue", L_prologue); - config.logger.log("L_epilogue", L_epilogue); - config.logger.log("L_tile_total", L_tile_total); - config.logger.log("Effective Tile Penalty", effective_tile_penalty); - config.logger.log("Problem K quant", problem_k_quant); - config.logger.log("K quant overhead", (problem_k_quant * 50000)); - config.logger.log("Problem Tile Quant", utilization); - config.logger.log("Real Occupancy", utilization); - config.logger.log("Output Utilization Penalty", output_utilization_penalty); - config.logger.log("Output Utilization", output_utilization); - std::string bound_source; - if (L_compute >= L_mem) { - L_tile_single = L_compute + L_cvt; - bound_source = "Compute"; - } else { - L_tile_single = L_mem + L_cvt; - bound_source = "Memory"; + double estimate_l2_hit(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size, + int WGM, + size_t splittingFactor) + { + // Use size_t for dimensions and counts to ensure type safety. 
+ const size_t workgroups_m = safe_ceil_div(M, MT_M); + const size_t workgroups_n = safe_ceil_div(N, MT_N); + const size_t total_workgroups = workgroups_m * workgroups_n; + + // Concurrently executing workgroups are limited by the number of CUs.a + const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); + if(concurrent_workgroups == 0) + throw std::runtime_error("#Workgroups is zero in estimate l2 hit"); + + // Number of CUs that might share the same K-tiles, adjusted for K-splitting. + // This affects contention on the L2 cache partitions (XCDs). + const size_t effective_cus = safe_ceil_div(concurrent_workgroups, splittingFactor); + const size_t cu_per_xcd = std::max(safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); + + // Initial guess for the L2 tile dimensions (a tile of workgroups). + size_t l2_tile_n = std::min(static_cast(WGM), workgroups_n); + size_t l2_tile_m = safe_ceil_div(cu_per_xcd, l2_tile_n); + + // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. + if(l2_tile_m > workgroups_m) + { + size_t num_wraps = (l2_tile_m / workgroups_m); + l2_tile_n += (num_wraps * WGM); + l2_tile_m = workgroups_m; + } + + // Clamp initial tile dimensions to the actual grid size. + l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); + l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); + + // Calculate memory footprint in bytes. + const size_t element_bytes = safe_ceil_div(element_size, 8); + auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { + size_t a_footprint = tile_m * MT_M * MT_K * element_bytes; + size_t b_footprint = tile_n * MT_N * MT_K * element_bytes; + return a_footprint + b_footprint; + }; + + // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. + // This is more robust than shrinking only one dimension. 
+ while(calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) + { + if(l2_tile_m > 1 && l2_tile_m >= l2_tile_n) + { + l2_tile_m--; + } + else if(l2_tile_n > 1) + { + l2_tile_n--; + } + else + { + // Cannot shrink further. + break; + } + } + + // Uncached reads are the first read of each unique element within the L2 tile. + const long long uncached_A_reads = static_cast(l2_tile_m) * MT_M * MT_K; + const long long uncached_B_reads = static_cast(l2_tile_n) * MT_N * MT_K; + const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; + + // Total reads are the sum of all reads performed by all workgroups in the L2 tile. + // Matrix A is reused l2_tile_n times, Matrix B is reused l2_tile_m times. + const long long total_A_reads = uncached_A_reads * l2_tile_n; + const long long total_B_reads = uncached_B_reads * l2_tile_m; + const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); + + const long long cached_reads = total_reads - total_uncached_reads; + + double l2_hit_rate = static_cast(cached_reads) / static_cast(total_reads); + + // Final clamping and logging. + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("L2Tile_M", l2_tile_m); + hardware.log_debug("L2Tile_N", l2_tile_n); + hardware.log_debug("TotalWorkgroups", total_workgroups); + hardware.log_debug("ConcurrentWorkgroups", concurrent_workgroups); + } + + // Clamp the hit rate to be within a realistic [0, 1] range. 
+ return std::max(0.0, std::min(l2_hit_rate, 1.0)); } - config.logger.log("Iteration Bound", bound_source + " (" + std::to_string(L_tile_single) + ")"); - config.logger.log("K % MT_K", K % MT_K); - } - - return L_tile_total; -} - -// Computes the latency per K-complete MT wave -// A wave is defined as : The time it takes for one CU to complete one K-complete output tile -double compute_timestep_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t num_active_cus, - size_t splitting_factor) { - // Assume latency of a wave is latency of a single k-complete output tile. - double L_wave = compute_tile_latency(problem, hardware, config, num_active_cus, splitting_factor); - - return L_wave; -} - -// Compute the total latency of a gemm based on the latency of one wave multiplied by the number of -// waves A wave is defined as : The time it takes for one CU to complete one K-complete output tile -double compute_total_latency(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t max_cus) { - assert(config.is_valid()); - - // Extract parameters from structured types - size_t M = problem.size.m; - size_t N = problem.size.n; - size_t K = problem.size.k; - size_t batch = problem.batch; - - bool a_trans = problem.a_transpose == transpose_t::T; - bool b_trans = problem.b_transpose == transpose_t::T; - - size_t MT_M = config.mt.m; - size_t MT_N = config.mt.n; - size_t MT_K = config.mt.k; - size_t MI_M = config.mi.m; - size_t MI_N = config.mi.n; - size_t MI_K = config.mi.k; - - const int a_bits = datatype_to_bits(problem.a_dtype); - const int b_bits = datatype_to_bits(problem.b_dtype); - const int a_bytes = data_type_to_bytes(problem.a_dtype); - const int d_bytes = data_type_to_bytes(problem.d_dtype); - - if (get_runtime_options(config).debug_enabled) { - config.logger.log( - "Problem_Size", - std::to_string(int(M)) + "x" + std::to_string(int(N)) + "x" + std::to_string(int(K))); - 
config.logger.log("Batch", std::to_string(int(batch))); - config.logger.log("Macro_Tile", - std::to_string(int(MT_M)) + "x" + std::to_string(int(MT_N)) + "x" + - std::to_string(int(MT_K))); - config.logger.log("Element Size A (bits)", a_bits); - config.logger.log("Element Size B (bits)", b_bits); - } - - // 0) Short-circuit - // We don't need to compute latency for all MTs. With this, we can shortcut. - bool shortCircuit = true; - if (shortCircuit) { - // When problem dimensions are small enough that we can fit them in one tile, we should do - // so. This short circuit condition also decreases selection latency when problems are very - // small :) - // TODO 256 and 256 here should be largest M and N tile dimensions in library - if (M <= 256 && N <= 256 && K < 1024 && batch != 1 && (MT_M < M || MT_N < N)) - return std::numeric_limits::max(); - - // Use Dot2 only for M < 3 - if (MI_M == 1 && MI_N == 1 && MI_K == 64 && M > 2) return std::numeric_limits::max(); - - size_t K_mod_128bytes = K * a_bytes % 128; - size_t MT_K_mod_128bytes = MT_K * a_bytes % 128; - if (K_mod_128bytes == 0 && MT_K_mod_128bytes == 0) { - // avoid division by 0 if K == 0 - if (M <= MT_M * 2 && !b_trans && ((N * b_bits) / (M * a_bits) > 5)) { - // Use nontemporal B - if (!(config.cache_hints_b == 4)) { return std::numeric_limits::max(); } - } else if (N <= MT_N * 2 && a_trans && ((M * a_bits) / (N * b_bits) > 5)) { - // Use Non Temporal A - if (!(config.cache_hints_a == 4)) { return std::numeric_limits::max(); } - } else { - // Never use Non Temporal - if (config.cache_hints_a || config.cache_hints_b) { - return std::numeric_limits::max(); + + // Estimate MALL hit-rate + double estimate_mall_hit(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor) + { + const size_t workgroups_m = safe_ceil_div(M, MT_M); + const size_t workgroups_n = 
safe_ceil_div(N, MT_N); + + if(numActiveCUs == 0) + throw std::runtime_error("Number of Active CUs was 0"); + + // --- Initial Tile Sizing based on Concurrency --- + // Use ceiling division for a more accurate initial guess. + size_t mall_tile_m = safe_ceil_div(numActiveCUs, static_cast(WGM)); + size_t mall_tile_n = std::min(static_cast(WGM), workgroups_n); + + // Handle wrap-around case if the tile is taller than the grid. + if(mall_tile_m > workgroups_m) + { + size_t num_wraps = mall_tile_m / workgroups_m; + mall_tile_n += (num_wraps * WGM); + mall_tile_m = workgroups_m; } - } - } else if (config.cache_hints_a || config.cache_hints_b) { - return std::numeric_limits::max(); + + // Clamp initial tile dimensions to the actual grid size. + mall_tile_m = std::max(std::min(workgroups_m, mall_tile_m), static_cast(1)); + mall_tile_n = std::max(std::min(workgroups_n, mall_tile_n), static_cast(1)); + + // --- CRITICAL: Shrink tile to fit into MALL Capacity --- + const size_t element_bytes = safe_ceil_div(element_size, 8); + auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { + size_t a_footprint = tile_m * MT_M * MT_K * element_bytes; + size_t b_footprint = tile_n * MT_N * MT_K * element_bytes; + return a_footprint + b_footprint; + }; + + // --- Calculate Hit Rate based on the final, capacity-aware tile size --- + const long long uncached_A_reads = static_cast(mall_tile_m) * MT_M * MT_K; + const long long uncached_B_reads = static_cast(mall_tile_n) * MT_N * MT_K; + const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; + + const long long total_A_reads = uncached_A_reads * mall_tile_n; + const long long total_B_reads = uncached_B_reads * mall_tile_m; + const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); + + const long long cached_reads = total_reads - total_uncached_reads; + + double mall_hit_rate = static_cast(cached_reads) / static_cast(total_reads); + + if(hardware_t::is_debug_enabled()) + { + 
hardware.log_debug("MallTile_M", mall_tile_m); + hardware.log_debug("MallTile_N", mall_tile_n); + hardware.log_debug("MallFootprint_Bytes", + calculate_footprint(mall_tile_m, mall_tile_n)); + } + + // Clamp the final result to the valid [0, 1] range. + return std::max(0.0, std::min(mall_hit_rate, 1.0)); } - } - // 1-1) To compute the latency, use default WGM. And WGM can't be greater than one - int defaultWGM = static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); - auto config_with_default_wgm = config; - config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); + /** + @brief Computes the L2 hit rate from a global, + problem - wide perspective. + **/ + double compute_l2_hit_rate_global(size_t M, + size_t N, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size, + size_t l2_capacity_bytes) + { + // --- Hardware Parameters (as requested, defined locally) --- + // You would normally get l2_capacity_bytes from your hardware_t struct. + if(l2_capacity_bytes == 0) + throw std::runtime_error("L2 Capacity is zero"); + ; + + // 1. Calculate the grid dimensions in terms of macro-tiles + const size_t grid_m = safe_ceil_div(M, MT_M); + const size_t grid_n = safe_ceil_div(N, MT_N); + + if(grid_m == 0 || grid_n == 0) + throw std::runtime_error("estimate_l2_hit grid dimensions can not be zero"); + ; + + // 2. Calculate the working set size for one full pass of global reuse + // This is the data needed by one full column of CUs (for A) and one full row (for B). + const double bytes_per_element = static_cast(element_size) / 8.0; + const double a_working_set = static_cast(grid_m * MT_M * MT_K) * bytes_per_element; + const double b_working_set = static_cast(grid_n * MT_N * MT_K) * bytes_per_element; + const double total_working_set_bytes = a_working_set + b_working_set; + + // 3. CRUCIAL: Check if the working set fits in the L2 cache. + // If it doesn't, the global reuse pattern is broken by capacity misses, + // and the hit rate will be very low. 
+ if(total_working_set_bytes > l2_capacity_bytes) + { + // Return a floor value for the hit rate. The exact value can be tuned, + // but it should be low to indicate that the ideal reuse is not possible. + return 0.1; // 10% hit rate + } - // 1-2) Find CU occupancy - auto [num_wgs, num_active_cus, numWaves, splitting_factor] = compute_cu_occupancy( - problem, hardware, config_with_default_wgm, grid_selection_t::k_split_aware, max_cus); + // 4. If it fits, calculate the idealized global hit rate + // Total reads if nothing was cached + const double total_A_reads = static_cast(grid_m * grid_n * MT_M * MT_K); + const double total_B_reads = static_cast(grid_m * grid_n * MT_N * MT_K); - // 2) Compute latency of a wave - // Compute latency of a wave - double L_wave = compute_timestep_latency( - problem, hardware, config_with_default_wgm, num_active_cus, splitting_factor); + // Uncached reads are the first-time fetches for each row/column + const double uncached_A_reads + = static_cast(grid_m * MT_M * MT_K); // One full column fetches A + const double uncached_B_reads + = static_cast(grid_n * MT_N * MT_K); // One full row fetches B - // Compute latency for all waves and return it as the latency for the MT/problem - double total_latency = L_wave * numWaves; + const double total_reads = total_A_reads + total_B_reads; + if(total_reads == 0) + return 1.0; // No reads, perfect hit rate. - // 3) Customized heuristics - // TODO These are quantifying effects that don't work in the current math. 
- // TODO THESE SHOULD BE TEMPORARY FIXES AND BE MORE SOLIDLY INTEGRATED LATER - bool heuristics = get_runtime_options(config).heuristics_enabled; + const double cached_reads + = (total_A_reads - uncached_A_reads) + (total_B_reads - uncached_B_reads); - if (heuristics) { - if (MT_M == 64 && MT_N == 32 && MT_K == 32 && !b_trans && a_bits == 16) { - total_latency = total_latency * 10; + return cached_reads / total_reads; } - bool tf32_emu = ((problem.mi_dtype == data_type_t::XFloat32) && - (hardware.arch == hardware_t::architecture_t::gfx950)); + inline size_t round_up_mul(size_t x, size_t m) + { + return (x + m - 1) / m * m; + } - // Heuristics for TF32 - if (tf32_emu) { - double bytes_per_element = static_cast(a_bytes); - double arith = emulated_tf32_arithmetic_intensity(M, N, K, bytes_per_element); - double compute_threshold = 1000; // threshold empirically determined. + size_t round_elements_to_128B(size_t elements, size_t element_size_bits) + { + const size_t transaction_bits = 128u * 8u; // 1024 + const size_t g = std::gcd(element_size_bits, transaction_bits); + const size_t E_block = transaction_bits / g; // elements per 128B-aligned chunk + return round_up_mul(elements, E_block); + } - // The kernel for this is more optimized (Custom kernel NT) - if ((!a_trans && b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { - if (arith < compute_threshold) - total_latency = total_latency * 0.6; - else - total_latency = total_latency * 0.4; - } + // Determine the memory latency + double compute_memory_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size_A, + size_t element_size_B, + size_t mx_block_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor) + { + // 1) Estimate L2 hit-rate + double H_mem1 = estimate_l2_hit( + hardware, M, N, K, batch, MT_M, MT_N, MT_K, element_size_A, WGM, splittingFactor); + + double 
H_mem1_global = compute_l2_hit_rate_global( + M, N, MT_M, MT_N, MT_K, element_size_A, hardware.L2_capacity * 1024); + + H_mem1 = std::min(H_mem1, H_mem1_global); + + if(H_mem1 == 0) + { + H_mem1 = 0.5; + } - // The kernel for this is more optimized (Custom kernel NN) - if ((!a_trans && !b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { - if (arith < compute_threshold) - total_latency = total_latency * 0.8; - else - total_latency = total_latency * 0.4; - } + // 2) Estimate mall hit-rate + double H_mem2 = estimate_mall_hit(hardware, + M, + N, + K, + batch, + MT_M, + MT_N, + MT_K, + element_size_A, + WGM, + numActiveCUs, + splittingFactor); + + // 3) Total loads are loads from A and loads from B + size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, element_size_A); + size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, element_size_A); + size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, element_size_A); + if(!transA && !transB) + { + MT_N_rounded_128bytes = MT_N; + MT_K_rounded_128bytes = MT_K; + } + else if(transA && !transB) + { + MT_M_rounded_128bytes = MT_M; + MT_N_rounded_128bytes = MT_N; + } + else if(!transA && transB) + { + MT_K_rounded_128bytes = MT_K; + } - // The kernel for this is more optimized (Custom kernel TN) - if ((a_trans && !b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { - if (arith < compute_threshold) - total_latency = total_latency * 0.8; - else - total_latency = total_latency * 0.4; - } + size_t Ld_A_value = compute_A_loads(MT_M_rounded_128bytes, MT_K_rounded_128bytes); + size_t Ld_B_value = compute_B_loads(MT_N_rounded_128bytes, MT_K_rounded_128bytes); + size_t Ld_CU_bytes = (Ld_A_value * safe_ceil_div(element_size_A, 8)) // A Bytes + + (Ld_B_value * safe_ceil_div(element_size_B, 8)); // B Bytes + + // Logic for block scaled datatypes (Assuming BS=32 and 8-bit scales) + // TODO This is technically wrong, need separate flag to enable MX so we can differentiate FP8 + // and MX8 + if(element_size_A < 8 && 
mx_block_size != 0) + { + // Number of scales per tile + size_t num_scales_A = safe_ceil_div(MT_M * MT_K, mx_block_size); + Ld_CU_bytes += num_scales_A; // One Byte per scale + } + if(element_size_B < 8 && mx_block_size != 0) + { + // Number of scales per tile + size_t num_scales_B = safe_ceil_div(MT_N * MT_K, mx_block_size); + Ld_CU_bytes += num_scales_B; // One Byte per scale + } + + // 4) total loads by all CUs + double total_Ld = Ld_CU_bytes * static_cast(numActiveCUs); + + // 5) mem1‐limited factor (simple linear model) + double mem1_bw_limited + = static_cast(numActiveCUs) / static_cast(hardware.N_CU); + double limited_mem1_bw = (hardware.mem1_perf_ratio * mem1_bw_limited); + + // 6) mem1 latency + double L_mem_mem1 = (limited_mem1_bw > 0) ? (total_Ld / (limited_mem1_bw)) : 0.0; + + // 7) mem2‐limited from occupancy (Can't Issue enough load/stores) + double bw_limited = compute_mem_bw_from_occupancy(hardware, numActiveCUs); + + // 8) loads that reach each level + double Ld_mem2 = (1.0 - H_mem1) * total_Ld; + double Ld_MEM = (1.0 - H_mem2) * Ld_mem2; + + // 9) enforce whole‐problem minimum loads when we can fit M/N in the CUs. + // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). + size_t grid_m = safe_ceil_div(M, MT_M); + size_t grid_n = safe_ceil_div(N, MT_N); + size_t mall_m = safe_ceil_div(numActiveCUs, static_cast(WGM)); + size_t mall_n = std::min(static_cast(WGM), grid_n); + // Handle wrap-around case + if(mall_m > grid_m) + { + size_t num_wraps = (mall_m / grid_m); + mall_n += (num_wraps * WGM); + mall_m = grid_m; + } + // Clamp tile dimensions + mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); + mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); + // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. 
+ double min_load + = static_cast((mall_m * MT_M * MT_K * safe_ceil_div(element_size_A, 8)) + + (mall_n * MT_N * MT_K * safe_ceil_div(element_size_B, 8))) + * batch; // Apply batching to the minimum load itself. + // The actual loads cannot be less than this physical minimum. + Ld_MEM = std::max(Ld_MEM, min_load); + Ld_mem2 = std::max(Ld_mem2, min_load); + + // 10) mem2 latency + double limited_mem2_bw = (hardware.mem2_perf_ratio * bw_limited); + double L_mem_mem2 = (limited_mem2_bw > 0) ? (Ld_mem2 / limited_mem2_bw) : 0.0; + + // 11) MEM latency + double limited_mem_bw = (hardware.mem3_perf_ratio * bw_limited); + double L_mem_MEM = (limited_mem_bw > 0) ? (Ld_MEM / limited_mem_bw) : 0.0; + L_mem_MEM += 200; // Load Latency + + // 12) pick the worst‐case bound + double L_mem = std::max({L_mem_mem1, L_mem_mem2, L_mem_MEM}); + + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("mem1_perf_ratio", hardware.mem1_perf_ratio); + hardware.log_debug("mem2_perf_ratio", hardware.mem2_perf_ratio); + hardware.log_debug("mem3_perf_ratio", hardware.mem3_perf_ratio); + hardware.log_debug("mem_bw_per_wg_coefficients(0)", + std::get<0>(hardware.mem_bw_per_wg_coefficients)); + hardware.log_debug("mem_bw_per_wg_coefficients(1)", + std::get<1>(hardware.mem_bw_per_wg_coefficients)); + hardware.log_debug("mem_bw_per_wg_coefficients(2)", + std::get<2>(hardware.mem_bw_per_wg_coefficients)); + hardware.log_debug("H_mem1 (mem1 hit ratio)", H_mem1); + hardware.log_debug("H_mem2 (mem2 hit ratio)", H_mem2); + hardware.log_debug("Total Load (bytes)", total_Ld); + hardware.log_debug("Ld_mem2 (bytes)", Ld_mem2); + hardware.log_debug("Ld_MEM (bytes)", Ld_MEM); + hardware.log_debug("L_mem_mem1 (cycles)", L_mem_mem1); + hardware.log_debug("L_mem_mem2 (cycles)", L_mem_mem2); + hardware.log_debug("L_mem_MEM (cycles)", L_mem_MEM); + hardware.log_debug("MT_K % 128 bytes", MT_K * safe_ceil_div(element_size_B, 8) % 128); + hardware.log_debug("MT_M % 128 bytes", MT_M * 
safe_ceil_div(element_size_A, 8) % 128); + hardware.log_debug("MT_N % 128 bytes", MT_N * safe_ceil_div(element_size_B, 8) % 128); + hardware.log_debug("MT_N % 128 + MT_M % 128 bytes", + (MT_M * safe_ceil_div(element_size_A, 8) % 128) + + MT_N * safe_ceil_div(element_size_B, 8) % 128); + hardware.log_debug("MT_N % 64 + MT_M % 64 bytes", + (MT_M * safe_ceil_div(element_size_A, 8) % 64) + + MT_N * safe_ceil_div(element_size_B, 8) % 64); + hardware.log_debug("MT_K % 64 bytes", MT_K * safe_ceil_div(element_size_B, 8) % 64); + hardware.log_debug("MT_M % 64 bytes", MT_M * safe_ceil_div(element_size_A, 8) % 64); + hardware.log_debug("MT_N % 64 bytes", MT_N * safe_ceil_div(element_size_B, 8) % 64); + hardware.log_debug("Tile Arithmetic Intensity", + MT_M * MT_N * MT_K / (MT_M * MT_K + MT_N * MT_K)); + } + + return L_mem; + } + + /* ---------------------------------------------------------------------------------------- */ + /* Tile-related functions */ + /* ---------------------------------------------------------------------------------------- */ + double compute_tile_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int occupancy, + size_t numActiveCUs, + size_t splittingFactor) + { + // 1) Compute per-tile latencies + double L_compute = compute_mt_compute_latency(hardware, + M, + N, + K, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + mi_datatype); + + double L_mem = compute_memory_latency(hardware, + M, + N, + K, + transA, + transB, + batch, + MT_M, + MT_N, + MT_K, + element_size_A, + element_size_B, + mx_block_size, + WGM, + numActiveCUs, + splittingFactor); + + // TODO Does work utilization need to be 128-byte rounded 
for a cache line? + double utilization = calculate_work_utilization(M, N, K, MT_M, MT_N, MT_K); + double output_utilization = calculate_output_utilization(M, N, MT_M, MT_N, 1); + // The effective latency per useful operation increases as utilization drops. + // This penalty affects BOTH compute and memory bounds for the tile's core work. + double effective_tile_penalty = (utilization > 1e-9) ? (1.0 / (utilization)) : 1.0; + double output_utilization_penalty + = (output_utilization > 1e-9) ? (1.0 / (output_utilization)) : 1.0; + // 2) Work-group setup & iteration latencies + double L_WG_setup = 1; // WG_setup_Latency + + // 3) Prologue: 2.2× memory latency + double L_prologue = 1.5 * L_mem; // 1.5 chosen emprically + + // L_compute *= std::max(L_compute, L_LDS); + + // 4) Epilogue: writes from all active CUs with limited bandwidth + double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, numActiveCUs); + double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; + size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, element_size_A); + + double L_epilogue = (static_cast(numActiveCUs / splittingFactor) + * MT_M_rounded_128bytes * MT_N * safe_ceil_div(element_size_out, 8)) + / mem_bw_occ_limited; + // One compute iteration happens in the prologue + L_epilogue += L_compute * effective_tile_penalty; + // Epilogue and Prologue overhead are reduced with higher occupancy kernels. + int grid_m = static_cast(safe_ceil_div(M, MT_M)); + int grid_n = static_cast(safe_ceil_div(N, MT_N)); + int real_occupancy = std::min(occupancy, + static_cast(safe_ceil_div(grid_m * grid_n * batch * splittingFactor, + hardware.N_CU))); // Number of WGs per CU. 
+ L_prologue = L_prologue * pow(0.95, real_occupancy); // Factor chosen empirically + L_epilogue = L_epilogue * pow(0.95, real_occupancy); // Factor chosen empirically + // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N + // tiles to coherent memory + if(splittingFactor > 1) + { + size_t n_partials = splittingFactor - 1; + + // Only the reduction CU reads from all splits. + double partial_read_bytes = grid_m * grid_n * n_partials * MT_M_rounded_128bytes * MT_N + * safe_ceil_div(element_size_out, 8); + + // All CUs write (once for each partial, and once by the reduction CU for the output.) + double partial_write_bytes = grid_m * grid_n * MT_M_rounded_128bytes * MT_N + * safe_ceil_div(element_size_out, 8); + + double partial_readwrite_bytes = partial_read_bytes + partial_write_bytes; + + // 64 Threads active in a SIMD. Exposed to at least latency of reducing splittingfactor + // tiles. + double partial_adds = ((MT_M * MT_N) * splittingFactor) / (64); + // Things have to be written to memory + double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, numActiveCUs); + double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; + + double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); + L_epilogue += L_reduce + partial_adds + 10000; + } + // 4'') tf32 emu has some more overhead + double L_cvt = 0; + if((mi_datatype == data_type_t::XFloat32) + && (hardware.arch == hardware_t::architecture_t::gfx950)) + { + L_cvt = compute_cvt_overhead( + hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, element_size_A, element_size_B); + } + else if((element_size_A == 32) && (element_size_B == 32) + && (mi_datatype == data_type_t::BFloat16) + && (hardware.arch == hardware_t::architecture_t::gfx950)) // SS_BSS on GFX950 + { + L_cvt = compute_cvt_overhead_x1(hardware, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + mi_datatype); + } + + // 5) Single-tile latency (always additive) + // 
Calculate the fraction of the work that is useful (not padding). + + // 5) Single-tile latency (apply penalty after finding the bottleneck) + double L_tile_single = (std::max(L_compute, L_mem) * effective_tile_penalty) + L_cvt; + L_prologue *= effective_tile_penalty; + // 6) Number of K-iterations (excluding epilogue), at least 1 + // long num_iter = static_cast(((K + MT_K - 1) / MT_K)) - 1; + // num_iter = std::ceil(num_iter / splittingFactor); + // num_iter = std::max(num_iter, 1L); + long splittedK = static_cast(safe_ceil_div(K, splittingFactor)); + long num_iter + = std::max(static_cast(safe_ceil_div(splittedK, MT_K) - 1), static_cast(1)); + // Zero Padding in the K dimension on last iteration + if(K % MT_K != 0) + { + double problem_k_quant = ((K % MT_K) / (double)K); + L_epilogue + += problem_k_quant + * 50000; // Scale by remainder proportion of problem. 50k cycle penalty if have to zero pad all except 1. + //(Scale Determined Empirically) + } + //L_epilogue *= output_utilization_penalty; + + // 7) Total tile latency + double L_tile_total + = (L_tile_single * num_iter) + L_prologue + L_epilogue * 2 + L_WG_setup + + (500 * num_iter); // 7 instructions (each with 4 cycles) at the end of the loop + + if(MT_K == 1024) + { + L_prologue = L_prologue * 100; + } + + if(hardware_t::is_debug_enabled()) + { + double problem_k_quant = ((K % MT_K) / (double)K); + hardware.log_debug("Iteration Compute Latency", L_compute); + hardware.log_debug("L_mem", L_mem); + hardware.log_debug("L_cvt", L_cvt); + hardware.log_debug("L_tile_single", L_tile_single); + hardware.log_debug("num_iter", num_iter); + hardware.log_debug("L_prologue", L_prologue); + hardware.log_debug("L_epilogue", L_epilogue); + hardware.log_debug("L_tile_total", L_tile_total); + hardware.log_debug("Effective Tile Peanlty", effective_tile_penalty); + hardware.log_debug("Problem K quant", problem_k_quant); + hardware.log_debug("K quant overhead", (problem_k_quant * 50000)); + hardware.log_debug("Problem Tiile 
Quant", utilization); + hardware.log_debug("Real Occupancy", utilization); + hardware.log_debug("Output Utilization Penalty", output_utilization_penalty); + hardware.log_debug("Output Utilization", output_utilization); + std::string bound_source; + if(L_compute >= L_mem) + { + L_tile_single = L_compute + L_cvt; + bound_source = "Compute"; + } + else + { + L_tile_single = L_mem + L_cvt; + bound_source = "Memory"; + } + hardware.log_debug("Iteration Bound", + bound_source + " (" + std::to_string(L_tile_single) + ")"); + + hardware.log_debug("K % MT_K", K % MT_K); + } + + return L_tile_total; + } + + // Computes the latency per K-complete MT wave + // A wave is defined as : The time it takes for one CU to complete one K-complete output tile + double compute_wave_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int occupancy, + size_t numActiveCUs, + size_t splittingFactor) + { + // Assume latency of a wave is latency of a single k-complete output tile. 
+ double L_wave = compute_tile_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + mx_block_size, + WGM, + occupancy, + numActiveCUs, + splittingFactor); + + return L_wave; + } + + // Compute the total latency of a gemm based on the latency of one wave multiplied by the number of + // waves A wave is defined as : The time it takes for one CU to complete one K-complete output tile + double compute_total_latency(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t mx_block_size, + int WGM, + int non_temporal_a, + int non_temporal_b, + int occupancy, + size_t split, + size_t max_cus) + { + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("Problem_Size", + std::to_string(int(M)) + "x" + std::to_string(int(N)) + "x" + + std::to_string(int(K))); + hardware.log_debug("Macro_Tile", + std::to_string(int(MT_M)) + "x" + std::to_string(int(MT_N)) + "x" + + std::to_string(int(MT_K))); + hardware.log_debug("Element Size A (bits)", element_size_A); + hardware.log_debug("Element Size B (bits)", element_size_B); + } + + // 0) Short-circuit + // We don't need to compute latency for all MTs. With this, we can shortcut. + bool shortCircuit = true; + if(shortCircuit) + { + // When problem dimensions are small enough that we can fit them in one tile, we should do + // so. 
This short circuit condition also decreases selection latency when problems are very + // small :) + // TODO 256 and 256 here should be largest M and N tile dimensions in library + if(M <= 256 && N <= 256 && K < 1024 && batch != 1 && (MT_M < M || MT_N < N)) + return std::numeric_limits::max(); + + // Override dot2 instruction with vector lane widths + if(MI_N == 0 && MI_M == 0 && MI_K == 0) + { + // We only use Dot2 for NN layout where M < 3 + if(M > 2 || transA || transB) + return std::numeric_limits::max(); + MI_M = 1; + MI_N = 1; + MI_K = 64; + } + + + + + if(batch == 1) + { + + + + size_t K_mod_128bytes = K * safe_ceil_div(element_size_A, 8) % 128; + size_t MT_K_mod_128bytes = MT_K * safe_ceil_div(element_size_A, 8) % 128; + if(K_mod_128bytes == 0 && MT_K_mod_128bytes == 0 && batch == 1) + { + // avoid division by 0 if K == 0 + if(M <= MT_M *2 && !transB && ((N * element_size_B)/(M * element_size_A) > 5)) + { + //Use nontemporal B + if(!(non_temporal_b == 4)) + { + return std::numeric_limits::max(); + } + } + else if(N <= MT_N *2 && transA && ((M * element_size_A)/(N * element_size_B) > 5)) + { + //Use Non Temporal A + if(!(non_temporal_a == 4)) + { + return std::numeric_limits::max(); + } + } + else + { + //Never use Non Temporal + if(non_temporal_a || non_temporal_b) + { + return std::numeric_limits::max(); + } + } + } + else if(non_temporal_a || non_temporal_b) + { + return std::numeric_limits::max(); + } + + } + + + + } + + // 1-1) WGM + WGM = std::max(WGM, 1); // WGM can't be less than one. + occupancy = std::max(occupancy, 1); // occupancy can't be less than one. 
+ + // 1-2) Find CU occupancy + auto [numWGs, numActiveCUs, numWaves, splittingFactor] + = compute_CU_occupancy(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + WGM, + std::numeric_limits::max(), // workspace + std::numeric_limits::max(), // workspace per c + 0, // occupancy + 6, // dynamic_grid + 0, + max_cus); + + // 2) Compute latency of a wave + // Compute latency of a wave + double L_wave = compute_wave_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + mx_block_size, + WGM, + occupancy, + numActiveCUs, + splittingFactor); + + // Compute latency for all waves and return it as the latency for the MT/problem + double total_latency = L_wave * numWaves; + + if(MT_M == 64 && MT_N == 32 && MT_K == 32 && !transB && element_size_A == 16) + { + total_latency = total_latency * 10; + } + + // 3) Customized heuristics + // TODO These are quantifying effects that don't work in the current math. + // TODO THESE SHOULD BE TEMPORARY FIXES AND BE MORE SOLIDLY INTEGRATED LATER + bool heuristics = hardware_t::is_heuristics_enabled(); + + const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); + heuristics = !(env && std::string(env) == "0"); + + + + // heuristics = 0; + // Heuristics for TF32 + bool tf32_emu = ((mi_datatype == data_type_t::XFloat32) + && (hardware.arch == hardware_t::architecture_t::gfx950)); + if(tf32_emu && heuristics) + { + double bytes_per_element = static_cast(element_size_A) / 8.0; + double arith = emulated_tf32_arithmetic_intensity(M, N, K, bytes_per_element); + double compute_threshold = 1000; // threshold empirically determined. 
+ + // The kernel for this is more optimized (Custom kernel NT) + if((!transA && transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) + { + if(arith < compute_threshold) + total_latency = total_latency * 0.6; + else + total_latency = total_latency * 0.4; + } + + // The kernel for this is more optimized (Custom kernel NN) + if((!transA && !transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) + { + if(arith < compute_threshold) + total_latency = total_latency * 0.8; + else + total_latency = total_latency * 0.4; + } + + // The kernel for this is more optimized (Custom kernel TN) + if((transA && !transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) + { + if(arith < compute_threshold) + total_latency = total_latency * 0.8; + else + total_latency = total_latency * 0.4; + } + + // Bias large DU where K-dimension is large and M and N are small. + if((K >= (M * 16) && K >= (N * 16)) && (MT_K >= 128)) + { + total_latency = total_latency * 0.5; + } + } + + if(hardware_t::is_debug_enabled()) + { + hardware.log_debug("Total_latency (with heuristics)", total_latency); + hardware.log_debug("non_temporal_a", non_temporal_a); + hardware.log_debug("non_temporal_b", non_temporal_b); + hardware.log_debug("kernel_occupancy", occupancy); + hardware.log_debug("splitting_factor", splittingFactor); + hardware.log_debug("Input Tile Size A", MT_M * MT_K); + hardware.log_debug("Input Tile Size B", MT_N * MT_K); + hardware.log_debug("Output Tile Size", MT_M * MT_N); + hardware.log_debug("Tile M/N", MT_M / MT_N); + hardware.log_debug("Tile N/M", MT_N / MT_M); + hardware.log_debug("Problem M/N", MT_M / MT_N); + hardware.log_debug("Problem N/M", MT_N / MT_M); + size_t occupancy_percent = numActiveCUs / hardware.N_CU; + hardware.log_debug("Peak theoretical GFLOPs based on occupancy", + 1300 * occupancy_percent); + if(hardware_t::is_debug_enabled()) + { + hardware.print_debug_info(); + } + } + + return total_latency; + } - // Bias large DU where K-dimension is large and M and N are small. 
- if ((K >= (M * 16) && K >= (N * 16)) && (MT_K >= 128)) { - total_latency = total_latency * 0.5; - } + // Compute the performance from the latency. + // IMPORTANT : This program is NOT meant to be an analytical model for performance, but rather a way + // to rank different macro tile sizes. These performance values could be wildly inaccurate in + // absolute terms, but will often result in the correct ranking of MTin relative terms. + double compute_perf_gflops(const hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + int WGM, + size_t max_cus) + { + // Compute total FLOPs + double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs + // Compute total time in seconds + double cycles_per_second + = hardware.compute_clock_ghz * 1e9; // 1 GHz = 1e9 cycles per second + size_t mx_block_size = 0; + double latency_cycles = compute_total_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + mx_block_size, + WGM, + 0, + 0, + 1, + 0, + max_cus); + double total_time_seconds = latency_cycles / cycles_per_second; + // Compute performance in FLOPS + double FLOPS = total_FLOPs / total_time_seconds; + // Convert to TFLOPS + double GFLOPS = FLOPS / 1e9; // 1 TFLOP = 1e9 FLOPs + return GFLOPS; } - } - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("Total_latency (with heuristics)", total_latency); - config.logger.log("non_temporal_a", config.cache_hints_a); - config.logger.log("non_temporal_b", config.cache_hints_b); - config.logger.log("kernel_occupancy", config.occupancy); - config.logger.log("splitting_factor", splitting_factor); - config.logger.log("Input Tile 
Size A", MT_M * MT_K); - config.logger.log("Input Tile Size B", MT_N * MT_K); - config.logger.log("Output Tile Size", MT_M * MT_N); - config.logger.log("Tile M/N", MT_M / MT_N); - config.logger.log("Tile N/M", MT_N / MT_M); - config.logger.log("Problem M/N", M / N); - config.logger.log("Problem N/M", N / M); - size_t occupancy_percent = num_active_cus / hardware.N_CU; - config.logger.log("Peak theoretical GFLOPs based on occupancy", 1300 * occupancy_percent); - if (get_runtime_options(config).debug_enabled) { config.logger.print(); } - } - - return total_latency; -} - -} // namespace origami +} // namespace origami diff --git a/shared/origami/src/origami/hardware.cpp b/shared/origami/src/origami/hardware.cpp deleted file mode 100644 index 51485fb2bf1..00000000000 --- a/shared/origami/src/origami/hardware.cpp +++ /dev/null @@ -1,388 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "origami/hardware.hpp" -#include "origami/types.hpp" - -#include -#include -#include -#include - -namespace origami { - -// Static member definition -// clang-format off -const std::unordered_map> - hardware_t::INSTRUCTION_MAP = { - {hardware_t::architecture_t::gfx90a, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - - // TODO ComplexFloat - // TODO ComplexDouble - - // F16 - {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // 
v_mfma_f32_32x32x4_2b_f16 - {matrix_instruction(32, 32, 8, data_type_t::Half), 64}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 - - // BF16 - {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 - - // I8 - {matrix_instruction(32, 32, 8, data_type_t::Int8), 64}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - - // XF32 - {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 - {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 - {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - }}, - {hardware_t::architecture_t::gfx942, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 
1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - - // TODO ComplexFloat - // TODO ComplexDouble - - // F16 - {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 - {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 - - // BF16 - {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 - - // F8 - {matrix_instruction(32, 32, 16, data_type_t::Float8_fnuz), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(16, 16, 32, data_type_t::Float8_fnuz), 16}, // v_mfma_f32_16x16x32_f8 - - // BF8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8 - - // F8B8 - {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_f8_bf8 - {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_f8_bf8 - - // B8F8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8_fnuz), 32}, // 
v_mfma_f32_32x32x16_bf8_f8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8_f8 - - // I8 - {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - - // XF32 - {matrix_instruction(32, 32, 4, data_type_t::XFloat32), 32}, // v_mfma_f32_32x32x4_xf32 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 16}, // v_mfma_f32_16x16x8_xf32 - }}, - {hardware_t::architecture_t::gfx950, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 64}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - - // TODO ComplexFloat - // TODO ComplexDouble - - // F16 - {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(32, 32, 16, data_type_t::Half), 32}, // v_mfma_f32_32x32x16_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(16, 16, 32, data_type_t::Half), 16}, // v_mfma_f32_16x16x32_f16 - - // BF16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(32, 32, 16, data_type_t::BFloat16), 32}, // 
v_mfma_f32_32x32x16_bf16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(16, 16, 32, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - - // F8 - {matrix_instruction(32, 32, 64, data_type_t::Float8), 64}, // v_mfma_f32_32x32x64_f8 - {matrix_instruction(32, 32, 16, data_type_t::Float8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(16, 16, 128, data_type_t::Float8), 32}, // v_mfma_f32_16x16x128_f8 - {matrix_instruction(16, 16, 32, data_type_t::Float8), 16}, // v_mfma_f32_16x16x32_f8 - - // BF8 - {matrix_instruction(32, 32, 64, data_type_t::BFloat8), 64}, // v_mfma_f32_32x32x64_bf8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8), 32}, // v_mfma_f32_32x32x16_bf8 - {matrix_instruction(16, 16, 128, data_type_t::BFloat8), 32}, // v_mfma_f32_16x16x128_bf8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8), 16}, // v_mfma_f32_16x16x32_bf8 - - // F8B8 - {matrix_instruction(32, 32, 64, data_type_t::Float8BFloat8), 64}, // v_mfma_f32_32x32x64_f8_bf8 - {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_32x32x16_f8_bf8 - {matrix_instruction(16, 16, 128, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_16x16x128_f8_bf8 - {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8), 16}, // v_mfma_f32_16x16x32_f8_bf8 - - // B8F8 - {matrix_instruction(32, 32, 64, data_type_t::BFloat8Float8), 64}, // v_mfma_f32_32x32x64_bf8_f8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_32x32x16_bf8_f8 - {matrix_instruction(16, 16, 128, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_16x16x128_bf8_f8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8), 16}, // v_mfma_f32_16x16x32_bf8_f8 - - // I8 - {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // 
v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - - // XF32 - {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 - {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 - {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - - // F6 - {matrix_instruction(32, 32, 64, data_type_t::Float6), 32}, // v_mfma_f32_32x32x64_f6 - {matrix_instruction(16, 16, 128, data_type_t::Float6), 16}, // v_mfma_f32_16x16x128_f6 - - // BF6 - {matrix_instruction(32, 32, 64, data_type_t::BFloat6), 32}, // v_mfma_f32_32x32x64_bf6 - {matrix_instruction(16, 16, 128, data_type_t::BFloat6), 16}, // v_mfma_f32_16x16x128_bf6 - - // F4 - {matrix_instruction(32, 32, 64, data_type_t::Float4), 32}, // v_mfma_f32_32x32x64_f4 - {matrix_instruction(16, 16, 128, data_type_t::Float4), 16}, // v_mfma_f32_16x16x128_f4 - - // DOT2 - {matrix_instruction(1, 1, 64, data_type_t::Half), 16}, // V_DOT2_F32_F16 - {matrix_instruction(1, 1, 64, data_type_t::BFloat16), 16}, // V_DOT2_F32_BF16 - }}, - {hardware_t::architecture_t::gfx1201, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_wmma_f16_16x16x16_f16/v_wmma_f32_16x16x16_f16 - - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_wmma_bf16_16x16x16_bf16/v_wmma_f32_16x16x16_bf16 - - // F8 - {matrix_instruction(16, 16, 16, data_type_t::Float8), 8}, // v_wmma_f32_16x16x16_fp8_fp8 - - // F8B8 - {matrix_instruction(16, 16, 16, data_type_t::Float8BFloat8), 8}, // v_wmma_f32_16x16x16_fp8_bf8 - - // B8F8 - {matrix_instruction(16, 16, 16, data_type_t::BFloat8Float8), 8}, // v_wmma_f32_16x16x16_bf8_fp8 - - // B8 - {matrix_instruction(16, 16, 16, data_type_t::BFloat8), 
8}, // v_wmma_f32_16x16x16_bf8_bf8 - - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 8}, // v_wmma_i32_16x16x16_iu8 - - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 8}, // v_wmma_i32_16x16x16_iu4 - {matrix_instruction(16, 16, 32, data_type_t::Int4), 8}, // v_wmma_i32_16x16x32_iu4 - }}, - {hardware_t::architecture_t::gfx1100, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 - - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 - - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 - - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 - }}, - {hardware_t::architecture_t::gfx1151, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 - - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 - - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 - - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 - }}}; -// clang-format on - -hardware_t::hardware_t(architecture_t arch, - size_t N_CU, - size_t lds_capacity, - size_t NUM_XCD, - double mem1_perf_ratio, - double mem2_perf_ratio, - double mem3_perf_ratio, - size_t L2_capacity, - double compute_clock_ghz, - size_t parallel_mi_cu, - std::tuple mem_bw_per_wg_coefficients) - : arch(arch) - , N_CU(N_CU) - , lds_capacity(lds_capacity) - , mem1_perf_ratio(mem1_perf_ratio) - , mem2_perf_ratio(mem2_perf_ratio) - , mem3_perf_ratio(mem3_perf_ratio) - , L2_capacity(L2_capacity) - , CU_per_L2(N_CU / NUM_XCD) - , compute_clock_ghz(compute_clock_ghz) - , parallel_mi_cu(parallel_mi_cu) - , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) - , 
NUM_XCD(NUM_XCD) {} - -hardware_t::hardware_t(hipDeviceProp_t properties) - : hardware_t(get_hardware_for_properties(properties)) {} - -hardware_t::hardware_t(const hardware_t& other) - : arch(other.arch) - , N_CU(other.N_CU) - , lds_capacity(other.lds_capacity) - , mem1_perf_ratio(other.mem1_perf_ratio) - , mem2_perf_ratio(other.mem2_perf_ratio) - , mem3_perf_ratio(other.mem3_perf_ratio) - , L2_capacity(other.L2_capacity) - , CU_per_L2(other.CU_per_L2) - , compute_clock_ghz(other.compute_clock_ghz) - , parallel_mi_cu(other.parallel_mi_cu) - , mem_bw_per_wg_coefficients(other.mem_bw_per_wg_coefficients) - , NUM_XCD(other.NUM_XCD) {} - -hardware_t hardware_t::get_hardware_for_properties(hipDeviceProp_t properties) { - auto arch_name = get_before_first_colon(properties.gcnArchName); - auto arch_enum = arch_name_to_enum(arch_name); - if (arch_enum == architecture_t::Count) { - throw std::runtime_error( - std::string("Attempting to retrieve hardware constants for unsupported architecture: ") + - std::string(arch_name)); - } - auto constants = get_arch_constants(arch_enum); - return hardware_t( - arch_enum, - properties.multiProcessorCount, - properties.sharedMemPerBlock, - constants.num_xcds, - 1e9 * constants.mem1_perf_ratio / properties.clockRate, - 1e9 * constants.mem2_perf_ratio / (properties.memoryClockRate * constants.mem_clock_ratio), - 1e9 * constants.mem3_perf_ratio / properties.memoryClockRate, - properties.l2CacheSize, - properties.clockRate / 1e6, - constants.parallel_mi_cu, - constants.mem_bw_per_wg_coefficients); -} - -hardware_t hardware_t::get_hardware_for_device(int deviceId) { - hipDeviceProp_t prop; - hipError_t e = hipGetDeviceProperties(&prop, deviceId); - if (e) { throw std::runtime_error(hipGetErrorString(e)); } - return get_hardware_for_properties(prop); -} - -bool hardware_t::is_hardware_supported(hipDeviceProp_t properties) { - auto arch_name = get_before_first_colon(properties.gcnArchName); - auto arch_enum = arch_name_to_enum(arch_name); - 
return arch_enum != architecture_t::Count; -} - -void hardware_t::print() const { - std::cout << "================== Hardware Configuration ==================\n"; - std::cout << "Number of CUs (N_CU) : " << N_CU << "\n"; - std::cout << "LDS capacity : " << lds_capacity << " bytes\n"; - std::cout << "mem1_perf_ratio : " << mem1_perf_ratio << "\n"; - std::cout << "mem2_perf_ratio : " << mem2_perf_ratio << "\n"; - std::cout << "mem3_perf_ratio : " << mem3_perf_ratio << "\n"; - std::cout << "L2 Cache capacity : " << L2_capacity << " bytes\n"; - std::cout << "CUs per L2 domain : " << CU_per_L2 << "\n"; - std::cout << "Compute clock (GHz) : " << compute_clock_ghz << "\n"; - std::cout << "Parallel MI/CU : " << parallel_mi_cu << "\n"; - std::cout << "Number of XCDs (NUM_XCD) : " << NUM_XCD << "\n"; - std::cout << "mem_bw_per_wg_coefficients: " << std::get<0>(mem_bw_per_wg_coefficients) << ", " - << std::get<1>(mem_bw_per_wg_coefficients) << ", " - << std::get<2>(mem_bw_per_wg_coefficients) << "\n\n"; - - std::cout << "------------------ Instruction Map -------------------------\n"; - // Loop over the instruction_map and print each entry - for (const auto& kv : INSTRUCTION_MAP.at(arch)) { - const auto& key = kv.first; - const auto& L_MI = kv.second; - - std::cout << "Instruction: MI_M=" << key.MI_M << ", MI_N=" << key.MI_N << ", MI_K=" << key.MI_K - << ", mi_input_type=" << datatype_to_string(key.mi_input_type) << " bytes\n" - << " -> Latency (L_MI): " << L_MI << "\n"; - } - std::cout << "===========================================================\n"; -} - -size_t hardware_t::get_mi_latency(size_t MI_M, - size_t MI_N, - size_t MI_K, - data_type_t mi_input_type) const { - const auto& instruction_map = INSTRUCTION_MAP.at(arch); - auto key = matrix_instruction(MI_M, MI_N, MI_K, mi_input_type); - - auto it = instruction_map.find(key); - if (it != instruction_map.end()) { - return it->second / parallel_mi_cu; - } else { - if (origami::runtime_options().get().debug_enabled) - 
std::cerr << "Warning: Latency not found for MI_M=" << MI_M << ", MI_N=" << MI_N - << ", MI_K=" << MI_K << ", mi_input_type=" << datatype_to_string(mi_input_type) - << ". Returning latency value of 32 (really slow).\n"; - return 32 / parallel_mi_cu; // Default latency if instruction is not found - } -} - -std::string hardware_t::get_before_first_colon(const std::string& input) { - size_t pos = input.find(':'); - if (pos != std::string::npos) { return input.substr(0, pos); } - return input; // Return the whole string if ':' is not found -} - -} // namespace origami diff --git a/shared/origami/src/origami/log.cpp b/shared/origami/src/origami/log.cpp deleted file mode 100644 index b6e595f766e..00000000000 --- a/shared/origami/src/origami/log.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "origami/log.hpp" - -#include -#include -#include -#include - -namespace origami { - -void logger_t::print() const { - if (!metrics_ || metrics_->empty()) { - std::cout << "{}\n"; - return; - } - std::cout << "{\n"; - bool first = true; - for (const auto& [key, val] : *metrics_) { - if (!first) std::cout << ",\n"; - std::cout << " \"" << key << "\": " << val; - first = false; - } - std::cout << "\n}\n"; -} - -void logger_t::export_json(const std::string& filename) const { - if (!metrics_ || metrics_->empty()) { - std::ofstream file(filename); - if (!file.is_open()) { - std::cerr << "Error: Could not open file " << filename << " for writing\n"; - return; - } - file << "{}\n"; - file.close(); - std::cout << "Analytical metrics exported to JSON: " << filename << "\n"; - return; - } - - std::ofstream file(filename); - if (!file.is_open()) { - std::cerr << "Error: Could not open file " << filename << " for writing\n"; - return; - } - - file << std::fixed << std::setprecision(6); - file << "{\n"; - bool first = true; - for (const auto& [key, val] : *metrics_) { - if (!first) file << ",\n"; - file 
<< " \"" << key << "\": " << val; - first = false; - } - file << "\n}\n"; - - file.close(); - std::cout << "Analytical metrics exported to JSON: " << filename << "\n"; -} - -} // namespace origami - diff --git a/shared/origami/src/origami/origami.cpp b/shared/origami/src/origami/origami.cpp deleted file mode 100644 index ada2b9bb3d3..00000000000 --- a/shared/origami/src/origami/origami.cpp +++ /dev/null @@ -1,446 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include -#include -#include - -#include "origami/gemm.hpp" -#include "origami/math.hpp" -#include "origami/origami.hpp" -#include "origami/streamk.hpp" -#include "origami/types.hpp" - -namespace origami { - -std::vector select_topk_configs(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs, - std::size_t topk) { - // Use rank_configs to get configurations with latencies ranked by performance - auto ranked_configs = rank_configs(problem, hardware, configs); - - // Return only the top K configurations - std::vector topk_configs; - size_t count = std::min(topk, ranked_configs.size()); - topk_configs.reserve(count); - for (size_t i = 0; i < count; ++i) { topk_configs.push_back(ranked_configs[i]); } - return topk_configs; -} - -/** - * @brief Selects the best WGM (maximizing L2 hit rate) given fixed macro tile sizes. - * - * @param[in] problem Problem description (M, N, K, etc.) - * @param[in] hardware Hardware characteristics - * @param config Kernel configuration. - * - * @return A tuple: best predicted (wgmxcc, wgm). 
- */ - -std::tuple select_workgroup_mapping(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t skGrid) { - // Extract parameters from structured types - size_t M = problem.size.m; - size_t N = problem.size.n; - size_t K = problem.size.k; - size_t batch = problem.batch; - - size_t MT_M = config.mt.m; - size_t MT_N = config.mt.n; - size_t MT_K = config.mt.k; - - int nta = config.cache_hints_a; - int ntb = config.cache_hints_b; - - // Default is the closest we can get to a square - size_t max_CU_XCD = hardware.N_CU / hardware.NUM_XCD; - int defaultWGM = static_cast(ceil(std::sqrt(max_CU_XCD))); - - // Number of output MTs per split and batch - size_t numMT_M = math::safe_ceil_div(M, MT_M); - size_t numMT_N = math::safe_ceil_div(N, MT_N); - size_t numMTs = numMT_M * numMT_N; - - // What SK does -- we already have skGrid so just compute numWaves and splitFactor - auto numWaves = skGrid > numMTs ? math::safe_ceil_div(skGrid, hardware.N_CU) - : math::safe_ceil_div(numMTs, hardware.N_CU); - auto splitFactor = math::safe_ceil_div(skGrid, numMTs); - - // ------------------- - // NonTemporal Cases - // ------------------- - if (nta > 3 && ntb < 4) - return std::make_tuple(hardware.NUM_XCD, 1); - else if (nta < 4 && ntb > 3) - return std::make_tuple(hardware.NUM_XCD, - std::min(max_CU_XCD, math::safe_ceil_div(numMTs, hardware.NUM_XCD))); - else if (nta > 3 && ntb > 3) - return std::make_tuple(hardware.NUM_XCD, 1); - - // ------------------- - // WGMXCC Prediction - // ------------------- - // Default WGMXCC -- always number of XCD - int defaultWGMXCC = static_cast(hardware.NUM_XCD); - bool isWGMXCCset = false; - int out_wgmxcc = static_cast(defaultWGMXCC); - - // Batched GEMMs - if (batch > 1 && !isWGMXCCset) { - // Total tiles including batch count - size_t numTotalTiles = numMTs * batch; - - // if only one MT per each GEMM -> no mapping - // if less than hardware.NUM_XCD total tiles -> no mapping - if (numMTs == 1 || numTotalTiles <= 
hardware.NUM_XCD) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - // else use the default (num_xcd) - } - - // If we are lucky that the splitFactor is a multiple of NUM_XCD -> no mapping - if ((splitFactor % hardware.NUM_XCD == 0) && !isWGMXCCset) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - - // Small GEMMs - if ((numMTs <= hardware.NUM_XCD) && !isWGMXCCset) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - - // For sizes that we have more than 2 waves of computations, we skip xcc mapping as MALL is - // more important -- matrix should not be skinny - // To avoid regressions, it's set to default, but it should actually be 1! - bool MallIsImportant = - (splitFactor == 1 && batch == 1 && numMTs > 2 * hardware.N_CU && numMT_M > 8 && numMT_N > 8); - if (MallIsImportant && !isWGMXCCset) { - out_wgmxcc = defaultWGMXCC; - isWGMXCCset = true; - } - - // ------------------- - // WGM Prediction - // ------------------- - // Default WGM - bool isWGMset = false; - int out_wgm = defaultWGM; - - // shortcut: - // 1. if we have decided to not remap xcc, there is no reason to use wgm - // 2. GEMMs that only have one tile in one dimension don't need wgm - // 3. Batched GEMMs don't need wgm (emprically -> batch count is often large!) - if (((out_wgmxcc == 1 && !MallIsImportant) || numMT_M == 1 || numMT_N == 1 || batch > 1) && - !isWGMset) { - out_wgm = 1; - isWGMset = true; - } - - // For tall cases (M >> N), if we have enough tiles to schedule, we use the number of tiles - // in the smaller dimension as WGM value - if (numMTs > hardware.N_CU && M > 10 * N && numMT_N <= 8) { - out_wgm = numMT_N; - isWGMset = true; - } - - // Cases where we have multiple rounds of computation per each CU - // To avoid regressions, it's set to defaultWGM. 
However, I think WGM=1 should be the winner - if (MallIsImportant && !isWGMset) { - out_wgm = defaultWGM; - isWGMset = true; - } - - if (!isWGMset) { - size_t numWGs = numWaves * splitFactor * numMTs; - size_t q = numWGs / hardware.NUM_XCD; - size_t r = numWGs % hardware.NUM_XCD; - - std::vector wgmList = {1, 2, 3, 4, 5, 6, 8, 16}; - int bestWGM = 1; - int bestL2 = std::numeric_limits::max(); - for (auto wgm : wgmList) { - auto slabTiles = numMT_M * std::min(wgm, static_cast(numMT_N)); - auto slabCount = math::safe_ceil_div(numMT_N, wgm); - auto edgeSlabWidth = numMT_N - (slabCount - 1) * wgm; - auto wgmL2Estimate = 0; - auto numXCD = std::min(hardware.NUM_XCD, numWGs); - - // Compute unique loads per L2 tile - for (uint32_t x = 0; x < numWaves * numXCD; ++x) { - // Range of "output tiles" that this xcd takes. - auto xccStart = q * x + (x < r ? x : r); - auto xccEnd = xccStart + q - 1 + (x < r ? 1 : 0); - // xccStart and xccEnd are supposed to be tile IDs - // In case of splitting, they are WG IDs. Modify to get tile IDs - xccStart /= splitFactor; - xccEnd /= splitFactor; - - auto slabStart = xccStart / slabTiles; - auto slabEnd = xccEnd / slabTiles; - - auto firstSlabWidth = (slabStart == slabCount - 1 ? edgeSlabWidth : wgm); - auto firstSlabStartIndex = xccStart % slabTiles; - auto firstSlabStartRow = firstSlabStartIndex / firstSlabWidth; - auto firstSlabEndRow = - std::min((firstSlabStartIndex + (xccEnd - xccStart)) / firstSlabWidth, numMT_M - 1); - auto rowsInFirstSlab = firstSlabEndRow - firstSlabStartRow + 1; - - auto lastSlabWidth = (slabEnd == slabCount - 1 ? edgeSlabWidth : wgm); - auto lastSlabEndIndex = xccEnd % slabTiles; - auto lastSlabEndRow = lastSlabEndIndex / lastSlabWidth; - auto colsInLastRow = (lastSlabEndIndex % lastSlabWidth) + 1; - auto colsInLastSlab = (lastSlabEndRow > 0 ? 
lastSlabWidth : colsInLastRow); - - size_t uniqueRows = 0; - size_t uniqueCols = 0; - if (slabEnd == slabStart) { - uniqueRows = lastSlabEndRow - firstSlabStartRow + 1; - uniqueCols = firstSlabWidth; - if (rowsInFirstSlab <= 2) uniqueCols = std::min(xccEnd - xccStart + 1, firstSlabWidth); - } else { - auto colsInFirstRow = firstSlabWidth - (xccStart % firstSlabWidth); - auto colsInFirstSlab = (rowsInFirstSlab > 1 ? firstSlabWidth : colsInFirstRow); - auto fullSlabs = slabEnd - slabStart - 1; - uniqueRows = - (fullSlabs > 0 ? numMT_M : std::min(rowsInFirstSlab + lastSlabEndRow + 1, numMT_M)); - uniqueCols = colsInFirstSlab + colsInLastSlab + fullSlabs * wgm; - } - - // Sum up the L2 loads over all XCD - // We should technically multiply by K (or splitted K), but it - // has no effect on sorting - auto xccL2Estimate = uniqueRows * MT_M + uniqueCols * MT_N; - wgmL2Estimate += xccL2Estimate; - } - - // If we have found a better WGM - if (wgmL2Estimate < bestL2) { - bestL2 = wgmL2Estimate; - bestWGM = wgm; - } - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("WGM", wgm); - config.logger.log("L2Estimate", wgmL2Estimate); - } - } - - out_wgm = bestWGM; - } - - return std::make_tuple(out_wgmxcc, out_wgm); -} - -std::vector rank_configs(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs) { - if (configs.empty()) { throw std::runtime_error("No configurations provided."); } - - std::vector results(configs.size()); - - std::transform(std::execution::seq, - configs.begin(), - configs.end(), - results.begin(), - [&](const config_t& config) -> prediction_result_t { - if (!check_lds_capacity(hardware, config.mt, problem.a_dtype, problem.b_dtype)) { - return {std::numeric_limits::max(), config}; - } - double latency = compute_total_latency(problem, hardware, config, hardware.N_CU); - return {latency, config}; - }); - - results.erase(std::remove_if(results.begin(), - results.end(), - [](const prediction_result_t& p) { - 
return p.latency == std::numeric_limits::max(); - }), - results.end()); - - std::stable_sort(results.begin(), - results.end(), - [](const prediction_result_t& a, const prediction_result_t& b) { - return a.latency < b.latency; - }); - - if (results.empty()) { throw std::runtime_error("No valid configs found."); } - - // Compute arithmetic intensity for tie-breaking - // Flops = 2 * MT_M * MT_N * MT_K, Memory traffic = MT_M*MT_K + MT_K*MT_N + MT_M*MT_N - auto compute_arithmetic_intensity = [](const config_t& config) -> double { - const auto MT_M = config.mt.m; - const auto MT_N = config.mt.n; - const auto MT_K = config.mt.k; - - const double flops = static_cast(2ull * MT_M * MT_N * MT_K); - const double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); - - if (memory_traffic == 0.0) return 0.0; - return flops / memory_traffic; - }; - - // Apply tie-breaking logic for configs with similar latency - double best_latency = results.front().latency; - size_t num_the_same = 0; - - // Count the number of similar latencies - constexpr double epsilon = 1e-9; - // variance is set through environment variable ANALYTICAL_GEMM_HEURISTICS_VARIANCE - // Use runtime_options from first config if available, otherwise global singleton - const double top_N_heuristic = get_runtime_options(configs.front()).heuristics_variance; - for (const auto& res : results) { - bool within_top; - const double diff = std::abs(res.latency - best_latency); - - if (top_N_heuristic <= epsilon) { - // Absolute tolerance path - within_top = diff < epsilon; - } else { - // Relative tolerance path (guard denom) - const double denom = std::max(std::abs(best_latency), epsilon); - // If it's within top_N_heuristic%, include it. 
- within_top = (diff / denom) < top_N_heuristic; - } - - if (within_top) - ++num_the_same; - else - break; - } - - // Sort top candidates by arithmetic intensity (descending - highest first) - if (num_the_same > 1) { - std::stable_sort(results.begin(), - results.begin() + num_the_same, - [&compute_arithmetic_intensity](const prediction_result_t& a, - const prediction_result_t& b) { - return compute_arithmetic_intensity(a.config) > - compute_arithmetic_intensity(b.config); - }); - - // After arithmetic intensity tie-breaking, check if we still have ties - // among the top results (those with same latency and arithmetic intensity) - // Check if the top tiles still have the same arithmetic intensity - double first_ai = compute_arithmetic_intensity(results.front().config); - size_t num_same_ai = 1; - for (size_t i = 1; i < num_the_same; ++i) { - double current_ai = compute_arithmetic_intensity(results[i].config); - if (std::abs(current_ai - first_ai) < 1e-6) { - num_same_ai++; - } else { - break; - } - } - - // If we still have ties after arithmetic intensity, apply problem dimension tie-breaker - if (num_same_ai > 1) { - // Problem dimension-based tie breaker: - // If M > N, prefer tiles with larger MT_M - // If N > M, prefer tiles with larger MT_N - // If M == N, this tie-breaker doesn't apply (will use final tie-breaker) - - if (problem.size.m != problem.size.n) { - std::stable_sort(results.begin(), - results.begin() + num_same_ai, - [problem](const prediction_result_t& a, const prediction_result_t& b) { - if (problem.size.m > problem.size.n) { - // M-dominant: prefer larger MT_M - if (a.config.mt.m != b.config.mt.m) - return a.config.mt.m > b.config.mt.m; - // If MT_M is same, prefer larger MT_N as secondary - return a.config.mt.n > b.config.mt.n; - } else // N > M - { - // N-dominant: prefer larger MT_N - if (a.config.mt.n != b.config.mt.n) - return a.config.mt.n > b.config.mt.n; - // If MT_N is same, prefer larger MT_M as secondary - return a.config.mt.m > 
b.config.mt.m; - } - }); - } - - // Final tie-breaker: when all else is equal (including square problems), - // consistently prefer tiles with larger MT_M - // This ensures deterministic selection regardless of input order - std::stable_sort(results.begin(), - results.begin() + num_same_ai, - [](const prediction_result_t& a, const prediction_result_t& b) { - // Prefer larger MT_M first - if (a.config.mt.m != b.config.mt.m) return a.config.mt.m > b.config.mt.m; - // If MT_M is same, prefer larger MT_N - if (a.config.mt.n != b.config.mt.n) return a.config.mt.n > b.config.mt.n; - // If both MT_M and MT_N are same, prefer larger MT_K - return a.config.mt.k > b.config.mt.k; - }); - } - } - - return results; -} - -prediction_result_t select_config_mnk(size_t M, - size_t N, - size_t K, - const hardware_t& hardware, - const std::vector& configs) { - // Create a default problem_t with the provided M, N, K and reasonable defaults - problem_t problem; - problem.size.m = M; - problem.size.n = N; - problem.size.k = K; - problem.batch = 1; - problem.a_transpose = transpose_t::T; // Default to T - problem.b_transpose = transpose_t::N; // Default to N - problem.a_dtype = data_type_t::Half; // Default to fp16 - problem.b_dtype = data_type_t::Half; // Default to fp16 - problem.c_dtype = data_type_t::Half; // Default to fp16 - problem.d_dtype = data_type_t::Half; // Default to fp16 - problem.mi_dtype = data_type_t::Half; // Default to fp16 - problem.a_mx_block_size = 0; // Default MX block size - problem.b_mx_block_size = 0; // Default MX block size - - // Use the existing select_config function with the constructed problem - return select_config(problem, hardware, configs); -} - -prediction_result_t select_config(const problem_t& problem, - const hardware_t& hardware, - const std::vector& configs) { - auto ranked_configs = rank_configs(problem, hardware, configs); - - // Return the top configuration - return ranked_configs[0]; -} - -double compute_perf_gflops(const hardware_t& 
hardware, - const problem_t& problem, - const double latency) { - // Extract parameters from structured types - size_t M = problem.size.m; - size_t N = problem.size.n; - size_t K = problem.size.k; - size_t batch = problem.batch; - - // Compute total FLOPs - double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs - // Compute total time in seconds - double cycles_per_second = hardware.compute_clock_ghz * 1e9; // 1 GHz = 1e9 cycles per second - - double total_time_seconds = latency / cycles_per_second; - - // Compute performance in FLOPS - double FLOPS = total_FLOPs / total_time_seconds; - // Convert to GFLOPS - double GFLOPS = FLOPS / 1e9; // 1 TFLOP = 1e9 FLOPs - return GFLOPS; -} -} // namespace origami diff --git a/shared/origami/src/origami/streamk.cpp b/shared/origami/src/origami/streamk.cpp index 1973e3c308a..8c6af0bfb2c 100644 --- a/shared/origami/src/origami/streamk.cpp +++ b/shared/origami/src/origami/streamk.cpp @@ -1,447 +1,504 @@ // Copyright Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "origami/streamk.hpp" -#include "origami/gemm.hpp" #include "origami/hardware.hpp" -#include "origami/math.hpp" -#include "origami/types.hpp" - -namespace origami { -namespace streamk { -size_t compute_number_of_output_tiles(size_t mt_m, size_t mt_n, size_t m, size_t n, size_t batch) { - size_t m_tiles = math::safe_ceil_div(m, mt_m); - size_t n_tiles = math::safe_ceil_div(n, mt_n); - return m_tiles * n_tiles * batch; -} - -/** - * @brief Returns number of k-iterations. - * - * @param output_tiles Number of output tiles. - * @param iters_per_tile Number of iterations per tile. - * @return constexpr size_t Number of total iterations. - */ -constexpr size_t num_iters_total(size_t output_tiles, size_t iters_per_tile) { - return output_tiles * iters_per_tile; -} - -/** - * @brief Returns number of k-iterations per tile. - * - * @param mt_k K-dimension tile size. - * @param k Reduction dimension. 
- * @return constexpr size_t Number of k-iteration per tile. - */ -constexpr size_t num_iters_per_tile(size_t mt_k, size_t k) { return math::safe_ceil_div(k, mt_k); } - -/** - * @brief Number of iterations per cta. - * - * @param iters_total Total number of k-iterations. - * @param g Number of workgroups (grid-size). - * @return constexpr size_t Number of iterations per cta. - */ -constexpr size_t num_iters_per_cta(size_t iters_total, int g) { - return math::safe_ceil_div(iters_total, g); -} - -constexpr size_t num_fixup_peers_v2(size_t g, - size_t iters_total, - size_t iters_per_tile, - size_t iters_per_cta) { - // If tiles don't evenly divide there are always at least 2 fixup peers, and more if - // iters_per_tile > iters_per_cta - size_t hasFixup = - (iters_total % g == 0 && // Check if some WGs have more iters than others - iters_per_cta % iters_per_tile == 0) // Check if WGs have an even number of full tiles - ? 0 - : 1; - return math::safe_ceil_div(iters_per_tile, iters_per_cta) + hasFixup; -} - -/** - * @brief Number of workgroups involved in the Stream-K's fixup step. - * - * @param g Number of total workgroups (grid-size.) - * @param iters_total Total number of k-iterations. - * @param iters_per_tile K-iterations per tile. - * @param iters_per_cta Number of iterations per workgroup. - * @return constexpr size_t Number of workgroups involved in fixup. - */ - -constexpr size_t num_fixup_peers(size_t iters_per_tile, size_t iters_per_cta) { - return math::safe_ceil_div(iters_per_tile, iters_per_cta); -} - -/** - * @brief Returns the predicted latency for a given grid-size. - * - * @param mt BLK_M, BLK_N, BLK_K macro-tile. - * @param size M, N, K size. - * @param batch Number of batches. - * @param g Grid size to test. - * @param a alpha (a), fixed-size cost incurred by each workgroup. - * @param b Beta (b) incorporates conditional costs of outputting temporary partial. - * @param c Represents instruction and stall workload of each MAC-iteration. 
- * @param d Delta (d) is the cost of reading and accumulating the partial sums. - * @return double Predicted latency. - */ -std::tuple predicted_runtime(dim3_t mt, - dim3_t size, - size_t batch, - size_t g, - double a, - double b, - double c, - double d) { - size_t output_tiles = compute_number_of_output_tiles(mt.m, mt.n, size.m, size.n, batch); - size_t iters_per_tile = num_iters_per_tile(mt.k, size.k); - size_t iters_total = num_iters_total(output_tiles, iters_per_tile); - size_t iters_per_cta = num_iters_per_cta(iters_total, g); - size_t fixup_peers = num_fixup_peers(iters_per_tile, iters_per_cta); - - double runtime = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)); - - return std::make_tuple(runtime, iters_per_cta, fixup_peers); -} - -std::tuple predicted_runtime_v2(dim3_t mt, - dim3_t size, - size_t batch, - size_t g, - double a, - double b, - double c, - double d) { - size_t output_tiles = compute_number_of_output_tiles(mt.m, mt.n, size.m, size.n, batch); - size_t iters_per_tile = num_iters_per_tile(mt.k, size.k); - size_t iters_total = num_iters_total(output_tiles, iters_per_tile); - size_t iters_per_cta = num_iters_per_cta(iters_total, g); - size_t fixup_peers = num_fixup_peers_v2(g, iters_total, iters_per_tile, iters_per_cta); - - size_t remainder_tiles = output_tiles % g; - double k_split_ratio = remainder_tiles / static_cast(g); - - double cache_penalty = 0.0; - if (fixup_peers >= 1) { - // Calculate the ideal equal split ratio - double ideal_split_ratio = 1.0 / fixup_peers; - - // Measure deviation from the ideal equal split - double imbalance = 1 / std::abs(k_split_ratio - ideal_split_ratio); - - // Scale the penalty by the imbalance and the per-collaborator cost (d) - cache_penalty = d * imbalance * fixup_peers; - } - - // Include the cache penalty in the runtime prediction - double runtime = - a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)) + cache_penalty; - - return std::make_tuple(runtime, 
iters_per_cta, fixup_peers, cache_penalty); -} - -reduction_t select_reduction(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - grid_selection_t algorithm) { - reduction_t reduction_strategy = reduction_t::tree; - - if (algorithm == grid_selection_t::k_split_aware) { - size_t tiles = compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); - size_t cu_count = hardware.N_CU; - size_t iters_per_tile = std::max(size_t(1), num_iters_per_tile(config.mt.k, problem.size.k)); - - if (tiles < cu_count) { - // For problems with large k and low number of tiles, use parallel reduction - // TODO Benchmark to check if limits are correct - constexpr int MinItersForParallel = 64; - constexpr int MaxTilesForParallel = 64; - if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) - reduction_strategy = reduction_t::parallel; - } - } - - return reduction_strategy; -} - -/** - * @brief Dynamically pick the minimum between the cu_count or number of tiles. - * @param problem Problem description (M, N, K, etc.) - * @param config Kernel configuration. - * @param cu_count cu count - * @return size_t minimum between the cu_count or number of tiles - */ -size_t grid_min_resources(const problem_t& problem, const config_t& config, size_t cu_count) { - size_t tiles = compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); - return std::min(cu_count, tiles); -} - -/** - * @brief Dynamically pick the minimum between the cu_count or number of tiles, - * and scale down really large sizes to use fewer CUs for power/energy savings - * @param problem Problem description (M, N, K, etc.) - * @param config Kernel configuration. 
- * @param cu_count cu count - * @return size_t minimum between the cu_count or number of tiles - */ -size_t grid_energy_aware(const problem_t& problem, const config_t& config, size_t cu_count) { - size_t tiles = compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); - size_t sk_grid = cu_count; - - if (tiles > sk_grid) { - for (size_t i = 1; i <= 32; i *= 2) { - size_t tiles_per_cu = math::safe_ceil_div(i * tiles, cu_count); - size_t reduced_grid = math::safe_ceil_div(i * tiles, tiles_per_cu); - float utilization = static_cast(reduced_grid) / static_cast(cu_count); - if (utilization > 0.75f) { - if (utilization < 1.0f) sk_grid = reduced_grid; - break; - } - } - } - return std::min(sk_grid, tiles); -} - -/** - * @brief Dynamically predict the best grid-size by weighing the cost of the fix-up - * step and the cost of processing MAC-loop instructions. When the cost of fix-up - * is the bottleneck, use smaller grid size. - * Architecture dependent. - * @param problem Problem description (M, N, K, etc.) - * @param config Kernel configuration. - * @param grid_start grid_start is 1 - * @param grid_end grid_end is cu_count - * @return size_t minimum between the cu_count or number of tiles - */ -size_t grid_reduction_cost_aware(const problem_t& problem, - const config_t& config, - size_t grid_start, - size_t grid_end) { - // Fixed overhead alpha (a), fixed-size cost incurred by - // each work-group, e.g. the grid launch latency, the initial - // compulsary cache misses, the cost of writing the final output tile - // to C. - // double a = 5544 + 9130; - double a = 2.772 + 4.565; // 5.04 + 8.30; - - // Beta (b) incorporates conditional costs of outputting temporary partial - // sums for scenarios where the number of output tiles does not quantize - // perfectly across the number of processors. - double b = 3.01; // 5.47; 6017; - - // c represents instruction and stall workload of each MAC-iteration. 
- double c = 2.2935; // 4.17; 4587; - - // Delta (d) is the cost of reading and accumulating the partial sums from - // other work-groups covering the same tile. - double d = 10.22; // 18.59; 20449; - - std::pair min_grid_runtime; - std::pair min_grid_runtime_v2; - min_grid_runtime.second = std::numeric_limits::max(); - min_grid_runtime_v2.second = std::numeric_limits::max(); - - size_t g = grid_start; - - // Predict the number of CTAs to use between 1 and 304 - for (; g <= static_cast(grid_end); ++g) { - auto [runtime, iters_per_cta, fixup_peers] = - predicted_runtime(config.mt, problem.size, problem.batch, g, a, b, c, d); - - auto [runtime_v2, iters_per_cta_v2, fixup_peers_v2, cache_penalty] = - predicted_runtime_v2(config.mt, problem.size, problem.batch, g, a, b, c, d); - - if (min_grid_runtime.second > runtime) { - min_grid_runtime.first = g; - min_grid_runtime.second = runtime; - } - - if (min_grid_runtime_v2.second > runtime_v2) { - min_grid_runtime_v2.first = g; - min_grid_runtime_v2.second = runtime_v2; - } - } - - if (get_runtime_options(config).debug_enabled) { - config.logger.log("grid_reduction_cost_aware_best_grid_size_original", min_grid_runtime.first); - config.logger.log("grid_reduction_cost_aware_best_runtime_original", min_grid_runtime.second); - config.logger.log("grid_reduction_cost_aware_best_grid_size_cache_offset", min_grid_runtime_v2.first); - config.logger.log("grid_reduction_cost_aware_best_runtime_cache_offset", min_grid_runtime_v2.second); - } - - return min_grid_runtime_v2.first; -} - -/** - * @brief Fix Stream-K algorithm to function like a Data-parallel schedule - * where grid size is equal to the number of output tiles. - * @param problem Problem description (M, N, K, etc.) - * @param config Kernel configuration. 
- * @return size_t number of tiles - */ -size_t grid_data_parallel(const problem_t& problem, const config_t& config) { - return compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); -} - -size_t grid_analytical(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - size_t biggest_allowable_split, - size_t max_cus) { - // Extract parameters from structured types - size_t M = problem.size.m; - size_t N = problem.size.n; - size_t K = problem.size.k; - size_t batch = problem.batch; - - size_t MT_M = config.mt.m; - size_t MT_N = config.mt.n; - - // compute how many 32×32 tiles are needed in each dim, - // then multiply to get total grid size: - size_t grid = ((M + MT_M - 1) / MT_M) * ((N + MT_N - 1) / MT_N) * batch; - - size_t max_hw_split = std::floor(hardware.N_CU / grid); - size_t MAX_SPLIT = std::min(biggest_allowable_split, max_hw_split); - - size_t best_split = 1; - double best_latency = std::numeric_limits::infinity(); - - for (size_t split = 1; split <= MAX_SPLIT; ++split) { - double latency = compute_total_latency(problem, hardware, config, max_cus); - - if (latency < best_latency) { - best_latency = latency; - best_split = split; - } - best_latency = latency; - best_split = split; - } - - size_t best_grid = best_split * grid; - - // you now have both `grid` and `best_split`— - // return whichever is appropriate (here we stick with split): - return best_grid; -} - -size_t grid_k_split_aware(const problem_t& problem, - const config_t& config, - size_t cu_count, - size_t max_cus) { - size_t tiles = compute_number_of_output_tiles( - config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); - - size_t sk_grid = tiles; // Fallback if no good fractional tile is found - if (max_cus > 0) sk_grid = std::min(sk_grid, max_cus); - - const size_t iters_per_tile = num_iters_per_tile(config.mt.k, problem.size.k); +#include "origami/streamk.hpp" +#include "origami/utils.hpp" +// 
#include + +namespace origami +{ + namespace streamk + { + namespace math + { + /** + * Performs `(n + d - 1) / d`, but is robust against the case where + * `(n + d - 1)` would overflow. + */ + template + __device__ __host__ inline constexpr N safe_ceil_div(N n, D d) + { + // Static cast to undo integral promotion. + return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); + } + } // namespace math + + constexpr size_t num_iters_total(size_t output_tiles, size_t iters_per_tile) + { + return output_tiles * iters_per_tile; + } - const size_t tile_size = config.mt.m * config.mt.n * config.workspace_size_per_elem_c; + constexpr size_t num_iters_per_tile(size_t BLK_K, size_t k) + { + return math::safe_ceil_div(k, BLK_K); + } - // More tiles than CUs - // Distribute tiles evenly across maximum number of CUs - // Split remaining tiles as evenly as possible for better caching - if (tiles > cu_count) { - size_t virt_cu_count = cu_count; - if (config.occupancy > 1 && max_cus == 0) virt_cu_count *= config.occupancy; + constexpr size_t num_iters_per_cta(size_t iters_total, int g) + { + return math::safe_ceil_div(iters_total, g); + } - const std::vector tile_fractions = { - 0.0, 1.0 / 2.0, 1.0 / 8.0, 1.0 / 5.0, 1.0 / 4.0, 1.0 / 3.0}; - const size_t min_even_tiles = tiles / virt_cu_count; + constexpr size_t + number_of_output_tiles(size_t BLK_M, size_t BLK_N, size_t m, size_t n, size_t batch) + { + size_t m_tiles = math::safe_ceil_div(m, BLK_M); + size_t n_tiles = math::safe_ceil_div(n, BLK_N); + return m_tiles * n_tiles * batch; + } - for (double frac : tile_fractions) { - const size_t frac_grid = static_cast((tiles / (min_even_tiles + frac)) + 0.5); + constexpr size_t num_fixup_peers_v2(size_t g, + size_t iters_total, + size_t iters_per_tile, + size_t iters_per_cta) + { + // If tiles don't evenly divide there are always at least 2 fixup peers, and more if iters_per_tile > iters_per_cta + size_t hasFixup + = (iters_total % g == 0 && // Check if some WGs have more iters 
than others + iters_per_cta % iters_per_tile + == 0) // Check if WGs have an even number of full tiles + ? 0 + : 1; + return math::safe_ceil_div(iters_per_tile, iters_per_cta) + hasFixup; + } - // Check if higher occupancy would cause excessive workspace requirements (set current limit - // to 128MB) - if ((tiles % frac_grid != 0) && (tile_size * frac_grid > 128ull * 1024ull * 1024ull)) - continue; + constexpr size_t num_fixup_peers(size_t iters_per_tile, size_t iters_per_cta) + { + return math::safe_ceil_div(iters_per_tile, iters_per_cta); + } - if (frac_grid <= virt_cu_count) { - sk_grid = frac_grid; - break; - } - } - } - // Fewer tiles than CUs - // Split tiles evenly in k-dimension - // Attempt to maximize CU utilization, up to a peak number of splits - // Max splitting is currently constant, but should be dependant on K dimension - else if (tiles < cu_count) { - // For problems with large k and low number of tiles, use parallel reduction - // TODO Benchmark to check if limits are correct - // constexpr int MinItersForParallel = 64; - // constexpr int MaxTilesForParallel = 16; - constexpr int MinItersPerCU = 8; - - if (config.reduction_strategy == reduction_t::parallel) { - size_t virt_cu_count = cu_count; - // TODO check if using occupancy info makes workspace too large - // if (occupancy > 1) - // virt_cu_count *= occupancy; - - // Find max splitting factor to use as much of GPU as possible - const size_t maxSplitsForTiles = virt_cu_count / tiles; - - // Find max splitting factor to ensure each CU has a minimum number of iterations to do - const size_t maxSplitsForIters = iters_per_tile / MinItersPerCU; - - const size_t maxSplits = std::min(maxSplitsForTiles, maxSplitsForIters); - sk_grid = tiles * maxSplits; - } else { - const std::vector tile_fractions = {16, 12, 8, 6, 4, 3, 2, 1}; - for (size_t frac : tile_fractions) { - const size_t splitGrid = tiles * frac; - const size_t itersPerCU = iters_per_tile / frac; - if (splitGrid <= cu_count && itersPerCU >= 
MinItersPerCU) { - sk_grid = splitGrid; - break; + const char* rtype_to_string(streamk::reduction_type r) + { + switch(r) + { + case streamk::reduction_type::Tree: + return "Tree"; + case streamk::reduction_type::Parallel: + return "Parallel"; + case streamk::reduction_type::None: + return "None"; + default: + return "Unknown"; } - } } - } - - if (tiles % sk_grid != 0 && tile_size * sk_grid > config.workspace_size) sk_grid = tiles; - - return sk_grid; -} - -size_t select_grid_size(const problem_t& problem, - const hardware_t& hardware, - const config_t& config, - grid_selection_t algorithm, - size_t max_cus) { - size_t cu_count = hardware.N_CU; - if (max_cus > 0) cu_count = std::min(cu_count, max_cus); - switch (algorithm) { - case grid_selection_t::min_resources: - return streamk::grid_min_resources(problem, config, cu_count); - case grid_selection_t::energy_aware: - return streamk::grid_energy_aware(problem, config, cu_count); + std::tuple predicted_runtime(size_t BLK_M, + size_t BLK_N, + size_t BLK_K, + size_t m, + size_t n, + size_t k, + size_t batch, + int g, + double a, + double b, + double c, + double d) + { + size_t output_tiles = number_of_output_tiles(BLK_M, BLK_N, m, n, batch); + size_t iters_per_tile = num_iters_per_tile(BLK_K, k); + size_t iters_total = num_iters_total(output_tiles, iters_per_tile); + size_t iters_per_cta = num_iters_per_cta(iters_total, g); + size_t fixup_peers = num_fixup_peers(iters_per_tile, iters_per_cta); + + double runtime + = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)); + + return std::make_tuple(runtime, iters_per_cta, fixup_peers); + } - case grid_selection_t::reduction_cost_aware: - return streamk::grid_reduction_cost_aware(problem, config, 1, cu_count); + std::tuple predicted_runtime_v2(size_t BLK_M, + size_t BLK_N, + size_t BLK_K, + size_t m, + size_t n, + size_t k, + size_t batch, + int g, + double a, + double b, + double c, + double d) + { + size_t output_tiles = 
number_of_output_tiles(BLK_M, BLK_N, m, n, batch); + size_t iters_per_tile = num_iters_per_tile(BLK_K, k); + size_t iters_total = num_iters_total(output_tiles, iters_per_tile); + size_t iters_per_cta = num_iters_per_cta(iters_total, g); + size_t fixup_peers + = num_fixup_peers_v2(g, iters_total, iters_per_tile, iters_per_cta); + + size_t remainder_tiles = output_tiles % g; + double k_split_ratio = remainder_tiles / static_cast(g); + + double cache_penalty = 0.0; + if(fixup_peers >= 1) + { + // Calculate the ideal equal split ratio + double ideal_split_ratio = 1.0 / fixup_peers; + + // Measure deviation from the ideal equal split + double imbalance = 1 / std::abs(k_split_ratio - ideal_split_ratio); + + // Scale the penalty by the imbalance and the per-collaborator cost (d) + cache_penalty = d * imbalance * fixup_peers; + } + + // Include the cache penalty in the runtime prediction + double runtime = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + + (d * (fixup_peers - 1)) + cache_penalty; + + return std::make_tuple(runtime, iters_per_cta, fixup_peers, cache_penalty); + } - case grid_selection_t::data_parallel: return streamk::grid_data_parallel(problem, config); + int best_predicted_grid_size(size_t BLK_M, + size_t BLK_N, + size_t BLK_K, + size_t m, + size_t n, + size_t k, + size_t batch, + int grid_start, + int grid_end, + bool verbose = false) + { + + // Fixed overhead alpha (a), fixed-size cost incurred by + // each work-group, e.g. the grid launch latency, the initial + // compulsary cache misses, the cost of writing the final output tile + // to C. + // double a = 5544 + 9130; + double a = 2.772 + 4.565; // 5.04 + 8.30; + + // Beta (b) incorporates conditional costs of outputting temporary partial + // sums for scenarios where the number of output tiles does not quantize + // perfectly across the number of processors. + double b = 3.01; // 5.47; 6017; + + // c represents instruction and stall workload of each MAC-iteration. 
+ double c = 2.2935; // 4.17; 4587; + + // Delta (d) is the cost of reading and accumulating the partial sums from + // other work-groups covering the same tile. + double d = 10.22; // 18.59; 20449; + + std::pair min_grid_runtime; + std::pair min_grid_runtime_v2; + min_grid_runtime.second = std::numeric_limits::max(); + min_grid_runtime_v2.second = std::numeric_limits::max(); + + size_t g = grid_start; + + // Predict the number of CTAs to use between 1 and 304 + for(; g <= static_cast(grid_end); ++g) + { + auto [runtime, iters_per_cta, fixup_peers] + = predicted_runtime(BLK_M, BLK_N, BLK_K, m, n, k, batch, g, a, b, c, d); + + auto [runtime_v2, iters_per_cta_v2, fixup_peers_v2, cache_penalty] + = predicted_runtime_v2(BLK_M, BLK_N, BLK_K, m, n, k, batch, g, a, b, c, d); + + if(verbose) + { + std::cout << "[original] " + << "grid size: " << g << ", runtime: " << runtime + << ", iters_per_cta: " << iters_per_cta << ", fixup_peers: " + << fixup_peers + // << ", cache_penalty: " << cache_penalty + << ", m: " << m << ", n: " << n << ", k: " << k << ", a: " << a + << ", b: " << b << ", c: " << c << ", d: " << d << std::endl; + + std::cout << "[cache-offset] " + << "grid size: " << g << ", runtime: " << runtime_v2 + << ", iters_per_cta: " << iters_per_cta_v2 + << ", fixup_peers: " << fixup_peers_v2 + << ", cache_penalty: " << cache_penalty << ", m: " << m + << ", n: " << n << ", k: " << k << ", a: " << a << ", b: " << b + << ", c: " << c << ", d: " << d << std::endl; + } + + if(min_grid_runtime.second > runtime) + { + min_grid_runtime.first = g; + min_grid_runtime.second = runtime; + } + + if(min_grid_runtime_v2.second > runtime_v2) + { + min_grid_runtime_v2.first = g; + min_grid_runtime_v2.second = runtime_v2; + } + } + + if(verbose) + { + std::cout << "[original] Number of Output Tiles: " + << number_of_output_tiles(BLK_M, BLK_N, m, n, batch) << std::endl; + std::cout << "[original] Minimum runtime: " << min_grid_runtime.second + << " @ grid size: " << 
min_grid_runtime.first << std::endl; + + std::cout << "[cache-offset] Number of Output Tiles: " + << number_of_output_tiles(BLK_M, BLK_N, m, n, batch) << std::endl; + std::cout << "[cache-offset] Minimum runtime: " << min_grid_runtime_v2.second + << " @ grid size: " << min_grid_runtime_v2.first << std::endl; + } + + return min_grid_runtime_v2.first; + } - case grid_selection_t::analytical: - return streamk::grid_analytical(problem, hardware, config, 10, max_cus); + size_t get_workspace( + size_t x, + size_t y, + size_t mt_m, + size_t mt_n, + size_t bpe_c, + size_t grid, + size_t tiles, + reduction_type reduction) + { + size_t size = 0; + if(reduction == reduction_type::Tree) + { + if(tiles % grid == 0) + { + size_t tileSize = mt_m * mt_n * bpe_c; + size += tileSize * grid; + } + } + else if(reduction == reduction_type::Parallel) + { + size_t splitSize = x * y * bpe_c; + size_t splitCount = grid / tiles; + size += splitSize * splitCount; + } + return size; + } - case grid_selection_t::k_split_aware: - return streamk::grid_k_split_aware(problem, config, cu_count, max_cus); + reduction_type select_reduction( + size_t x, + size_t y, + size_t z, + size_t batch, + size_t mt_m, + size_t mt_n, + size_t mt_k, + const hardware_t& analytical_hardware, + int dynamic_grid_version) + { + reduction_type reductionStrat = reduction_type::Tree; + + if(dynamic_grid_version == 6) + { + size_t tiles = number_of_output_tiles(mt_m, mt_n, x, y, batch); + size_t cu_count = analytical_hardware.N_CU; + size_t iters_per_tile = std::max(size_t(1), math::safe_ceil_div(z, mt_k)); + + if (tiles < cu_count) + { + // For problems with large k and low number of tiles, use parallel reduction + // TODO Benchmark to check if limits are correct + constexpr int MinItersForParallel = 64; + constexpr int MaxTilesForParallel = 64; + if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) + reductionStrat = reduction_type::Parallel; + } + } + + return reductionStrat; + } - case 
grid_selection_t::number_of_cus: - default: return hardware.N_CU; - } + size_t select_grid( + size_t x, + size_t y, + size_t z, + size_t batch, + bool trans_a, + bool trans_b, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + data_type_t mi_datatype, + size_t workspace_size, + size_t mt_m, + size_t mt_n, + size_t mt_k, + size_t mi_m, + size_t mi_n, + size_t mi_k, + int workgroup_mapping, + size_t workspace_size_per_elem_c, + int occupancy, + const hardware_t& analytical_hardware, + int dynamic_grid_version, + reduction_type reduction_strategy, + size_t max_cus) + { + size_t cu_count = analytical_hardware.N_CU; + if(max_cus > 0) + cu_count = std::min(cu_count, max_cus); + size_t tiles = number_of_output_tiles(mt_m, mt_n, x, y, batch); + + // Dynamically pick the minimum between the cu_count or number of tiles. + if(dynamic_grid_version == 1) + { + return std::min(cu_count, tiles); + } + + // Dynamically pick the minimum between the cu_count or number of tiles, + // and scale down really large sizes to use fewer CUs for power/energy savings. + else if(dynamic_grid_version == 2) + { + size_t sk_grid = cu_count; + if(tiles > sk_grid) + { + for(size_t i = 1; i <= 32; i *= 2) + { + size_t tilesPerCU = math::safe_ceil_div(i * tiles, cu_count); + size_t reducedGrid = math::safe_ceil_div(i * tiles, tilesPerCU); + float utilization = ((float)reducedGrid) / ((float)cu_count); + if(utilization > 0.75f) + { + if(utilization < 1.0f) + sk_grid = reducedGrid; + break; + } + } + } + + return std::min(sk_grid, tiles); + } + // Dynamically predict the best grid-size by weighing the cost of the fix-up + // step and the cost of processing MAC-loop instructions. When the cost of fix-up + // is the bottleneck, use smaller grid size. + // Architecture dependent. 
+ else if(dynamic_grid_version == 3) + { + return origami::streamk::best_predicted_grid_size(mt_m, + mt_n, + mt_k, + x, + y, + z, + batch, + 1, + cu_count); + } + // Fix Stream-K algorithm to function like a Data-parallel schedule + // where grid size is equal to the number of output tiles. + else if(dynamic_grid_version == 4) + { + return origami::streamk::number_of_output_tiles( + mt_m, mt_n, x, y, batch); + } + else if(dynamic_grid_version == 5) + { + return origami::select_best_grid_size(x, + y, + z, + batch, + trans_a, + trans_b, + analytical_hardware, + mt_m, + mt_n, + mt_k, + mi_m, + mi_n, + mi_k, + element_size_A, + element_size_B, + element_size_out, + mi_datatype, + 0, + 0.0, + workgroup_mapping, + 10, + max_cus); + } + else if(dynamic_grid_version == 6) + { + size_t iters_per_tile = std::max(size_t(1), math::safe_ceil_div(z, mt_k)); + size_t sk_grid = tiles; // Fallback if no good fractional tile is found + if(max_cus > 0) + sk_grid = std::min(sk_grid, max_cus); + size_t tile_size = mt_m * mt_n * workspace_size_per_elem_c; + // More tiles than CUs + // Distribute tiles evenly across maximum number of CUs + // Split remaining tiles as evenly as possible for better caching + if(tiles > cu_count) + { + size_t virt_cu_count = cu_count; + if (occupancy > 1 && max_cus == 0) + virt_cu_count *= occupancy; + + const std::vector tile_fractions = {0.0, 1.0/2.0, 1.0/8.0, 1.0/5.0, 1.0/4.0, 1.0/3.0}; + size_t min_even_tiles = tiles / virt_cu_count; + for(double frac: tile_fractions) + { + size_t frac_grid = (size_t)((tiles / (min_even_tiles + frac)) + 0.5); + // Check if higher occupancy would cause excessive workspace requirements (set current limit to 128MB) + if((tiles % frac_grid != 0) && (tile_size * frac_grid > 128*1024*1024)) + continue; + if(frac_grid <= virt_cu_count) + { + sk_grid = frac_grid; + break; + } + } + } + // Fewer tiles than CUs + // Split tiles evenly in k-dimension + // Attempt to maximize CU utilization, up to a peak number of splits + // Max 
splitting is currently constant, but should be dependant on K dimension + else if (tiles < cu_count) + { + // For problems with large k and low number of tiles, use parallel reduction + // TODO Benchmark to check if limits are correct + // constexpr int MinItersForParallel = 64; + // constexpr int MaxTilesForParallel = 16; + constexpr int MinItersPerCU = 8; + // if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) + if(reduction_strategy == reduction_type::Parallel) + { + size_t virt_cu_count = cu_count; + // TODO check if using occupancy info makes workspace too large + // if (occupancy > 1 && max_cus == 0) + // virt_cu_count *= occupancy; + + // Find max splitting factor to use as much of GPU as possible + size_t maxSplitsForTiles = virt_cu_count / tiles; + + // Find max splitting factor to ensure each CU has a minimum number of iterations to do + size_t maxSplitsForIters = iters_per_tile / MinItersPerCU; + + size_t maxSplits = std::min(maxSplitsForTiles, maxSplitsForIters); + sk_grid = tiles * maxSplits; + } + else + { + const std::vector tile_fractions = {16, 12, 8, 6, 4, 3, 2, 1}; + for(size_t frac: tile_fractions) + { + size_t splitGrid = tiles * frac; + size_t itersPerCU = iters_per_tile / frac; + if(splitGrid <= cu_count && itersPerCU >= 8) + { + sk_grid = splitGrid; + break; + } + } + } + } + + if (tiles % sk_grid != 0 && tile_size * sk_grid > workspace_size) + sk_grid = tiles; + return sk_grid; + } + // If no option is specified, launch exactly cu_count worth of workgroups. + else + { + return cu_count; + } + } + } // namespace streamk } -} // namespace streamk -} // namespace origami diff --git a/shared/origami/src/origami/types.cpp b/shared/origami/src/origami/types.cpp deleted file mode 100644 index add4dcbde81..00000000000 --- a/shared/origami/src/origami/types.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "origami/types.hpp" - -#include -#include -#include - -namespace origami { - -runtime_options::runtime_options() { update_from_env(); } - -runtime_options::runtime_options(bool debug, bool heuristics, double variance) - : debug_enabled(debug), heuristics_enabled(heuristics), heuristics_variance(variance) {} - -runtime_options& runtime_options::get() { - static runtime_options instance; - return instance; -} - -bool runtime_options::read_debug_from_env() { - const char* env = std::getenv("ANALYTICAL_GEMM_DEBUG"); - return env && std::string(env) == "1"; -} - -bool runtime_options::read_heuristics_from_env() { - const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); - return !(env && std::string(env) == "0"); -} - -double runtime_options::read_heuristics_variance_from_env() { - constexpr double default_variance = 0.01; // 1% - - if (const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS_VARIANCE")) { - try { - double val = std::stod(env); - if (std::isfinite(val) && val > 0.0) { return val; } - } catch (...) 
{ - // fall through to default - } - } - return default_variance; -} - -void runtime_options::update_from_env() { - debug_enabled = read_debug_from_env(); - heuristics_enabled = read_heuristics_from_env(); - heuristics_variance = read_heuristics_variance_from_env(); -} - -int datatype_to_bits(data_type_t type) { - switch (type) { - case data_type_t::Float: return 32; - case data_type_t::Double: return 64; - case data_type_t::ComplexFloat: return 64; - case data_type_t::ComplexDouble: return 128; - case data_type_t::Half: return 16; - case data_type_t::Int8x4: return 32; - case data_type_t::Int32: return 32; - case data_type_t::BFloat16: return 16; - case data_type_t::Int8: return 8; - case data_type_t::Int4: return 4; - case data_type_t::Int64: return 64; - case data_type_t::XFloat32: return 32; - case data_type_t::Float8_fnuz: return 8; - case data_type_t::BFloat8_fnuz: return 8; - case data_type_t::Float8BFloat8_fnuz: return 8; - case data_type_t::BFloat8Float8_fnuz: return 8; - case data_type_t::Float8: return 8; - case data_type_t::BFloat8: return 8; - case data_type_t::Float8BFloat8: return 8; - case data_type_t::BFloat8Float8: return 8; - case data_type_t::Float6: return 6; - case data_type_t::BFloat6: return 6; - case data_type_t::Float4: return 4; - default: return -1; // Invalid type - } -} - -std::string datatype_to_string(data_type_t type) { - switch (type) { - case data_type_t::Float: return "Float"; - case data_type_t::Double: return "Double"; - case data_type_t::ComplexFloat: return "ComplexFloat"; - case data_type_t::ComplexDouble: return "ComplexDouble"; - case data_type_t::Half: return "Half"; - case data_type_t::Int8x4: return "Int8x4"; - case data_type_t::Int32: return "Int32"; - case data_type_t::BFloat16: return "BFloat16"; - case data_type_t::Int8: return "Int8"; - case data_type_t::Int4: return "Int4"; - case data_type_t::Int64: return "Int64"; - case data_type_t::XFloat32: return "XFloat32"; - case data_type_t::Float8_fnuz: return 
"Float8_fnuz"; - case data_type_t::BFloat8_fnuz: return "BFloat8_fnuz"; - case data_type_t::Float8BFloat8_fnuz: return "Float8BFloat8_fnuz"; - case data_type_t::BFloat8Float8_fnuz: return "BFloat8Float8_fnuz"; - case data_type_t::Float8: return "Float8"; - case data_type_t::BFloat8: return "BFloat8"; - case data_type_t::Float8BFloat8: return "Float8BFloat8"; - case data_type_t::BFloat8Float8: return "BFloat8Float8"; - case data_type_t::Float6: return "Float6"; - case data_type_t::BFloat6: return "BFloat6"; - case data_type_t::Float4: return "Float4"; - default: return "Invalid"; - } -} - -data_type_t string_to_datatype(std::string s) { - if (s == "f32") return data_type_t::Float; - if (s == "c32") return data_type_t::ComplexFloat; - if (s == "c64") return data_type_t::ComplexDouble; - if (s == "f64") return data_type_t::Double; - if (s == "f16") return data_type_t::Half; - if (s == "i32") return data_type_t::Int32; - if (s == "bf16") return data_type_t::BFloat16; - if (s == "i8") return data_type_t::Int8; - if (s == "i4") return data_type_t::Int4; - if (s == "xf32") return data_type_t::XFloat32; - if (s == "f8") return data_type_t::Float8; - if (s == "bf8") return data_type_t::BFloat8; - if (s == "f6") return data_type_t::Float6; - if (s == "bf6") return data_type_t::BFloat6; - if (s == "f4") return data_type_t::Float4; - return data_type_t::None; -} - -} // namespace origami diff --git a/shared/origami/src/origami/utils.cpp b/shared/origami/src/origami/utils.cpp new file mode 100644 index 00000000000..39f9909e156 --- /dev/null +++ b/shared/origami/src/origami/utils.cpp @@ -0,0 +1,733 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "origami/utils.hpp" + +#include +#include // For timing +#include +#include +#include // For output formatting +#include + +namespace origami { + +static double read_heuristics_variance_env_var() { + constexpr double default_variance = 0.01; // 1% + + if (const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS_VARIANCE")) { + try { + double val = std::stod(env); + if (std::isfinite(val) && val > 0.0) { + return val; + } + } catch (...) { + // fall through to default + } + } + return default_variance; +} + +// +// Tiebreaker function. +// +void pick_best_tile_by_arithmetic_intensity(std::vector& top_results, + size_t num_to_sort) { + if (top_results.empty()) { + throw std::runtime_error("pick_best_tile_by_arithmetic_intensity received empty list."); + } + + // 1) Define a helper function to compute the arithmetic intensity of a tile. + // Here we assume: + // - Flops for tile (MT_M, MT_N, MT_K) is: 2 * MT_M * MT_N * MT_K + // - Memory traffic approximated as: MT_M*MT_K + MT_K*MT_N + MT_M*MT_N + // - Arithmetic intensity = flops / memory_traffic + auto compute_arithmetic_intensity = [](const result_tuple& t) -> double { + // The tuple is: (latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K) + auto MT_M = std::get<1>(t); + auto MT_N = std::get<2>(t); + auto MT_K = std::get<3>(t); + + double flops = static_cast(2ull * MT_M * MT_N * MT_K); + double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); + + // Avoid division by zero. + if (memory_traffic == 0.0) return 0.0; + + return flops / memory_traffic; + }; + // 2) Sort the results in descending order of arithmetic intensity + // (highest arithmetic intensity first). 
+ std::stable_sort(top_results.begin(), top_results.begin() + num_to_sort, + [&](const result_tuple& a, const result_tuple& b) { + double ai_a = compute_arithmetic_intensity(a); + double ai_b = compute_arithmetic_intensity(b); + return ai_a > ai_b; // descending + }); + // 3) Return the tile with the highest arithmetic intensity +} + +result_tuple pick_best_tile_with_dimension_priority(const std::vector& top_results, + size_t M, size_t N, size_t K) { + if (top_results.empty()) { + throw std::runtime_error("pick_best_tile_with_dimension_priority received empty list."); + } + + // 1) Determine whether M or N is more important + // (based on which is larger), and always place K last. + // This yields a priority order of either { 'M', 'N', 'K' } + // or { 'N', 'M', 'K' }. + std::vector dimPriority; + if (M >= N) + dimPriority = {'M', 'N', 'K'}; + else + dimPriority = {'N', 'M', 'K'}; + + // 2) Helper function to extract the tile dimension: + // (latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K) + auto getTileSize = [](const result_tuple& t, char dimChar) -> size_t { + switch (dimChar) { + case 'M': + return std::get<1>(t); // MT_M + case 'N': + return std::get<2>(t); // MT_N + case 'K': + return std::get<3>(t); // MT_K + default: + return 0; + } + }; + + // 3) Sort in descending order according to the dimension priority. + // - Compare dimensionPriority[0] first + // - If there's a tie, compare dimensionPriority[1] + // - If still a tie, compare dimensionPriority[2] + // - If they're all equal, consider them tied + std::vector sorted = top_results; // copy + std::stable_sort(sorted.begin(), sorted.end(), + [&](const result_tuple& a, const result_tuple& b) { + for (char d : dimPriority) { + size_t ta = getTileSize(a, d); + size_t tb = getTileSize(b, d); + if (ta > tb) return true; + if (ta < tb) return false; + } + // If all relevant dimensions are the same, treat as a tie + return false; + }); + + // 4) Return the best tile (the first after sorting). 
+ return sorted.front(); +} + +size_t select_best_grid_size(size_t M, size_t N, size_t K, size_t batch, bool transA, bool transB, + const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, + size_t MI_M, size_t MI_N, size_t MI_K, size_t element_size_A, + size_t element_size_B, size_t element_size_out, + data_type_t mi_datatype, size_t mx_block_size, double H_L2, int WGM, + size_t biggest_allowable_split, size_t max_cus) { + // compute how many 32×32 tiles are needed in each dim, + // then multiply to get total grid size: + size_t grid = ((M + MT_M - 1) / MT_M) * ((N + MT_N - 1) / MT_N) * batch; + + size_t max_hw_split = std::floor(hardware.N_CU / grid); + size_t MAX_SPLIT = std::min(biggest_allowable_split, max_hw_split); + + size_t best_split = 1; + double best_latency = std::numeric_limits::infinity(); + + for (size_t split = 1; split <= MAX_SPLIT; ++split) { + double latency = + compute_total_latency(hardware, M, N, + K, // problem dims + batch, transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, + element_size_A, // ElementSizeA + element_size_B, // ElementSizeB + element_size_out, // ElementSizeout + mi_datatype, mx_block_size, WGM, 0, 0, 0, split, max_cus); + + if (latency < best_latency) { + best_latency = latency; + best_split = split; + } + best_latency = latency; + best_split = split; + } + + size_t best_grid = best_split * grid; + + // you now have both `grid` and `best_split`— + // return whichever is appropriate (here we stick with split): + return best_grid; +} + +std::vector select_best_macro_tile_size(size_t M, size_t N, size_t K, size_t batch, + bool transA, bool transB, + const hardware_t& hardware, + const std::vector& MT_list, + size_t element_size_A, // In bits + size_t element_size_B, // In bits + size_t element_size_out, // In bits + data_type_t mi_datatype, size_t mx_block_size, + double H_L2, bool print, int defaultWGM, + size_t max_cus) { + std::vector valid_results; + valid_results.reserve(MT_list.size()); + + // bool tf32_emu = 
((mi_datatype == data_type_t::XFloat32) + // && (hardware.arch == hardware_t::architecture_t::gfx950)); + + for (const auto& mt : MT_list) { + size_t MT_M = std::get<0>(mt); + size_t MT_N = std::get<1>(mt); + size_t MT_K = std::get<2>(mt); + size_t MI_M = std::get<3>(mt); + size_t MI_N = std::get<4>(mt); + size_t MI_K = std::get<5>(mt); + int occupancy = std::get<6>(mt); + int WGM = std::get<7>(mt); + int non_temporal_a = std::get<8>(mt); + int non_temporal_b = std::get<9>(mt); + + if (hardware_t::is_debug_enabled()) { + std::cout << "Evaluating MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K + << ", MI_M=" << MI_M << ", MI_N=" << MI_N << ", MI_K=" << MI_K << "\n"; + } + + if (check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size_A)) { + double Total_latency = compute_total_latency( + hardware, M, N, K, batch, transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, + element_size_A, element_size_B, element_size_out, mi_datatype, mx_block_size, + defaultWGM, non_temporal_a, non_temporal_b, occupancy, 0, max_cus); + + valid_results.emplace_back(Total_latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, occupancy, + WGM, non_temporal_a, non_temporal_b); + } else if (hardware_t::is_debug_enabled()) { + std::cout << "Skipping MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K + << " due to LDS capacity\n"; + } + } + + if (valid_results.empty()) { + throw std::runtime_error("No valid macro-tile sizes found."); + } + + // 1) Sort results by ascending latency. + std::stable_sort(valid_results.begin(), valid_results.end(), + [](auto const& a, auto const& b) { return std::get<0>(a) < std::get<0>(b); }); + + // 2) Collect results that tie for the absolute best latency. 
+ double best_latency = std::get<0>(valid_results.front()); + size_t num_the_same = 0; + + // Count the number of similar latencies + constexpr double epsilon = 1e-9; + // variance is set through environment variable ANALYTICAL_GEMM_HEURISTICS_VARIANCE + static const double top_N_heuristic = read_heuristics_variance_env_var(); + for (const auto& res : valid_results) { + bool within_top; + const double diff = std::abs(std::get<0>(res) - best_latency); + + if (top_N_heuristic <= epsilon) { + // Absolute tolerance path + within_top = diff < epsilon; + } else { + // Relative tolerance path (guard denom) + const double denom = std::max(std::abs(best_latency), epsilon); + // If it's within top_N_heuristic%, include it. + within_top = (diff / denom) < top_N_heuristic; + } + + if (within_top) + ++num_the_same; + else + break; + } + + // 3) If that tie group has at least 10 entries, we only use those. + // 4) Otherwise, keep adding the next best latencies until we have 10 total or run out. + // std::vector top_candidates = tie_results; + // if(tie_results.size() < 10) + // { + // size_t i = tie_results.size(); + // while(top_candidates.size() < 10 && i < valid_results.size()) + // { + // top_candidates.push_back(valid_results[i]); + // i++; + // } + // } + // Now ‘top_candidates’ is either all the tied best-latency results (if >=10), + // or the top 10 latencies overall (including however many best-latency entries there were). 
+ + // Finally, use your existing tie-breaker on top_candidates + pick_best_tile_by_arithmetic_intensity(valid_results, num_the_same); + + // After arithmetic intensity tie-breaking, check if we still have ties + // among the top results (those with same latency and arithmetic intensity) + if (num_the_same > 1) { + // Helper to compute arithmetic intensity + auto compute_arithmetic_intensity = [](const result_tuple& t) -> double { + auto MT_M = std::get<1>(t); + auto MT_N = std::get<2>(t); + auto MT_K = std::get<3>(t); + double flops = static_cast(2ull * MT_M * MT_N * MT_K); + double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); + return (memory_traffic == 0.0) ? 0.0 : (flops / memory_traffic); + }; + + // Check if the top tiles still have the same arithmetic intensity + double first_ai = compute_arithmetic_intensity(valid_results[0]); + size_t num_same_ai = 1; + for (size_t i = 1; i < num_the_same; ++i) { + double current_ai = compute_arithmetic_intensity(valid_results[i]); + if (std::abs(current_ai - first_ai) < 1e-6) { + num_same_ai++; + } else { + break; + } + } + + // If we still have ties after arithmetic intensity, apply problem dimension tie-breaker + if (num_same_ai > 1) { + // Problem dimension-based tie breaker: + // If M > N, prefer tiles with larger MT_M + // If N > M, prefer tiles with larger MT_N + // If M == N, this tie-breaker doesn't apply (will use final tie-breaker) + + if (M != N) { + std::stable_sort(valid_results.begin(), valid_results.begin() + num_same_ai, + [M, N](const result_tuple& a, const result_tuple& b) { + size_t MT_M_a = std::get<1>(a); + size_t MT_N_a = std::get<2>(a); + size_t MT_M_b = std::get<1>(b); + size_t MT_N_b = std::get<2>(b); + + if (M > N) { + // M-dominant: prefer larger MT_M + if (MT_M_a != MT_M_b) return MT_M_a > MT_M_b; + // If MT_M is same, prefer larger MT_N as secondary + return MT_N_a > MT_N_b; + } else // N > M + { + // N-dominant: prefer larger MT_N + if (MT_N_a != MT_N_b) return 
MT_N_a > MT_N_b; + // If MT_N is same, prefer larger MT_M as secondary + return MT_M_a > MT_M_b; + } + }); + } + + // Final tie-breaker: when all else is equal (including square problems), + // consistently prefer tiles with larger MT_M + // This ensures deterministic selection regardless of input order + std::stable_sort(valid_results.begin(), valid_results.begin() + num_same_ai, + [](const result_tuple& a, const result_tuple& b) { + size_t MT_M_a = std::get<1>(a); + size_t MT_N_a = std::get<2>(a); + size_t MT_K_a = std::get<3>(a); + size_t MT_M_b = std::get<1>(b); + size_t MT_N_b = std::get<2>(b); + size_t MT_K_b = std::get<3>(b); + + // Prefer larger MT_M first + if (MT_M_a != MT_M_b) return MT_M_a > MT_M_b; + // If MT_M is same, prefer larger MT_N + if (MT_N_a != MT_N_b) return MT_N_a > MT_N_b; + // If both MT_M and MT_N are same, prefer larger MT_K + return MT_K_a > MT_K_b; + }); + } + } + + if (print) { + for (const auto& tile : valid_results) { + std::cout << M << "x" << N << "x" << K + << "Selected Macro-Tile: Latency=" << std::get<0>(tile) + << ", MT_M=" << std::get<0>(tile) << ", MT_N=" << std::get<1>(tile) + << ", MT_K=" << std::get<2>(tile) << ", MI_M=" << std::get<3>(tile) + << ", MI_N=" << std::get<4>(tile) << ", MI_K=" << std::get<5>(tile) + << ", Occupancy=" << std::get<6>(tile) << ", WGM=" << std::get<7>(tile) + << ", NonTemporalA=" << std::get<8>(tile) + << ", NonTemporalB=" << std::get<9>(tile) << "\n"; + } + } + + return valid_results; +} + +template +constexpr N safe_ceil_div(N n, D d) { + // Static cast to undo integral promotion. + return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); +} + +/*! + * \brief Selects the best WGM (maximizing L2 hit rate) given fixed macro tile sizes. + * + * \param[in] hardware - Hardware + * \param[in] M, N, K, batch - Problem + * \param[in] MT_M, MT_N, MT_K - Solution + * \param[in] print - whether to print the final best result + * + * \return best WGMXCC, WGM. 
+ */ +std::tuple select_best_wgm(const hardware_t& hardware, size_t M, size_t N, + size_t K, size_t batch, size_t MT_M, size_t MT_N, + size_t MT_K, int nta, int ntb, size_t skGrid, + bool print) { + // Default is the closest we can get to a square + size_t max_CU_XCD = hardware.N_CU / hardware.NUM_XCD; + size_t defaultWGM = ceil(std::sqrt(max_CU_XCD)); + + // Number of output MTs per split and batch + size_t numMT_M = safe_ceil_div(M, MT_M); + size_t numMT_N = safe_ceil_div(N, MT_N); + size_t numMTs = numMT_M * numMT_N; + + // What SK does -- we already have skGrid so just compute numWaves and splitFactor + auto numWaves = skGrid > numMTs ? safe_ceil_div(skGrid, hardware.N_CU) + : safe_ceil_div(numMTs, hardware.N_CU); + auto splitFactor = safe_ceil_div(skGrid, numMTs); + + // ------------------- + // NonTemporal Cases + // ------------------- + if (nta > 3 && ntb < 4) + return std::make_tuple(hardware.NUM_XCD, 1); + else if (nta < 4 && ntb > 3) + return std::make_tuple(hardware.NUM_XCD, + std::min(max_CU_XCD, safe_ceil_div(numMTs, hardware.NUM_XCD))); + else if (nta > 3 && ntb > 3) + return std::make_tuple(hardware.NUM_XCD, 1); + + // ------------------- + // WGMXCC Prediction + // ------------------- + // Default WGMXCC -- always number of XCD + auto defaultWGMXCC = hardware.NUM_XCD; + bool isWGMXCCset = false; + size_t out_wgmxcc = defaultWGMXCC; + + // Batched GEMMs + if (batch > 1 && !isWGMXCCset) { + // Total tiles including batch count + size_t numTotalTiles = numMTs * batch; + + // if only one MT per each GEMM -> no mapping + // if less than hardware.NUM_XCD total tiles -> no mapping + if (numMTs == 1 || numTotalTiles <= hardware.NUM_XCD) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + // else use the default (num_xcd) + } + + // If we are lucky that the splitFactor is a multiple of NUM_XCD -> no mapping + if ((splitFactor % hardware.NUM_XCD == 0) && !isWGMXCCset) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + + // Small GEMMs + if ((numMTs <= 
hardware.NUM_XCD) && !isWGMXCCset) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + + // For sizes that we have more than 2 waves of computations, we skip xcc mapping as MALL is + // more important -- matrix should not be skinny + // To avoid regressions, it's set to default, but it should actually be 1! + bool MallIsImportant = (splitFactor == 1 && batch == 1 && numMTs > 2 * hardware.N_CU && + numMT_M > 8 && numMT_N > 8); + if (MallIsImportant && !isWGMXCCset) { + out_wgmxcc = defaultWGMXCC; + isWGMXCCset = true; + } + + // ------------------- + // WGM Prediction + // ------------------- + // Default WGM + bool isWGMset = false; + size_t out_wgm = defaultWGM; + + // shortcut: + // 1. if we have decided to not remap xcc, there is no reason to use wgm + // 2. GEMMs that only have one tile in one dimension don't need wgm + // 3. Batched GEMMs don't need wgm (emprically -> batch count is often large!) + if (((out_wgmxcc == 1 && !MallIsImportant) || numMT_M == 1 || numMT_N == 1 || batch > 1) && + !isWGMset) { + out_wgm = 1; + isWGMset = true; + } + + // For tall cases (M >> N), if we have enough tiles to schedule, we use the number of tiles + // in the smaller dimension as WGM value + if (numMTs > hardware.N_CU && M > 10 * N && numMT_N <= 8) { + out_wgm = numMT_N; + isWGMset = true; + } + + // Cases where we have multiple rounds of computation per each CU + // To avoid regressions, it's set to defaultWGM. 
However, I think WGM=1 should be the winner + if (MallIsImportant && !isWGMset) { + out_wgm = defaultWGM; + isWGMset = true; + } + + if (!isWGMset) { + size_t numWGs = numWaves * splitFactor * numMTs; + size_t q = numWGs / hardware.NUM_XCD; + size_t r = numWGs % hardware.NUM_XCD; + + std::vector wgmList = {1, 2, 3, 4, 5, 6, 8, 16}; + int bestWGM = 1; + int bestL2 = std::numeric_limits::max(); + for (auto wgm : wgmList) { + auto slabTiles = numMT_M * std::min(wgm, static_cast(numMT_N)); + auto slabCount = safe_ceil_div(numMT_N, wgm); + auto edgeSlabWidth = numMT_N - (slabCount - 1) * wgm; + auto wgmL2Estimate = 0; + auto numXCD = std::min(hardware.NUM_XCD, numWGs); + + // Compute unique loads per L2 tile + for (uint32_t x = 0; x < numWaves * numXCD; ++x) { + // Range of "output tiles" that this xcd takes. + auto xccStart = q * x + (x < r ? x : r); + auto xccEnd = xccStart + q - 1 + (x < r ? 1 : 0); + // xccStart and xccEnd are supposed to be tile IDs + // In case of splitting, they are WG IDs. Modify to get tile IDs + xccStart /= splitFactor; + xccEnd /= splitFactor; + + auto slabStart = xccStart / slabTiles; + auto slabEnd = xccEnd / slabTiles; + + auto firstSlabWidth = (slabStart == slabCount - 1 ? edgeSlabWidth : wgm); + auto firstSlabStartIndex = xccStart % slabTiles; + auto firstSlabStartRow = firstSlabStartIndex / firstSlabWidth; + auto firstSlabEndRow = std::min( + (firstSlabStartIndex + (xccEnd - xccStart)) / firstSlabWidth, numMT_M - 1); + auto rowsInFirstSlab = firstSlabEndRow - firstSlabStartRow + 1; + + auto lastSlabWidth = (slabEnd == slabCount - 1 ? edgeSlabWidth : wgm); + auto lastSlabEndIndex = xccEnd % slabTiles; + auto lastSlabEndRow = lastSlabEndIndex / lastSlabWidth; + auto colsInLastRow = (lastSlabEndIndex % lastSlabWidth) + 1; + auto colsInLastSlab = (lastSlabEndRow > 0 ? 
lastSlabWidth : colsInLastRow); + + size_t uniqueRows = 0; + size_t uniqueCols = 0; + if (slabEnd == slabStart) { + uniqueRows = lastSlabEndRow - firstSlabStartRow + 1; + uniqueCols = firstSlabWidth; + if (rowsInFirstSlab <= 2) + uniqueCols = std::min(xccEnd - xccStart + 1, firstSlabWidth); + } else { + auto colsInFirstRow = firstSlabWidth - (xccStart % firstSlabWidth); + auto colsInFirstSlab = (rowsInFirstSlab > 1 ? firstSlabWidth : colsInFirstRow); + auto fullSlabs = slabEnd - slabStart - 1; + uniqueRows = + (fullSlabs > 0 ? numMT_M + : std::min(rowsInFirstSlab + lastSlabEndRow + 1, numMT_M)); + uniqueCols = colsInFirstSlab + colsInLastSlab + fullSlabs * wgm; + } + + // Sum up the L2 loads over all XCD + // We should technically multiply by K (or splitted K), but it + // has no effect on sorting + auto xccL2Estimate = uniqueRows * MT_M + uniqueCols * MT_N; + wgmL2Estimate += xccL2Estimate; + } + + // If we have found a better WGM + if (wgmL2Estimate < bestL2) { + bestL2 = wgmL2Estimate; + bestWGM = wgm; + } + + if (print || hardware_t::is_debug_enabled()) + std::cout << "WGM (" << wgm << "), L2Estimate (" << wgmL2Estimate << ")" + << std::endl; + } + + out_wgm = bestWGM; + } + + return std::make_tuple(out_wgmxcc, out_wgm); +} + +// Logic to decide between two MT that are "tied" +std::vector> tie_breaker_macro_tile_sizes( + const std::vector>& top_results, size_t M, size_t N, + size_t K, hardware_t& hardware, + std::function + tie_breaker_fn) { + std::vector> tie_breaker_results; + + for (const auto& res : top_results) { + size_t MT_M = std::get<1>(res); + size_t MT_N = std::get<2>(res); + size_t MT_K = std::get<3>(res); + + // Call user-provided tie-breaking function + double precise_latency = tie_breaker_fn(M, N, K, MT_M, MT_N, MT_K, hardware); + + tie_breaker_results.emplace_back(precise_latency, MT_M, MT_N, MT_K); + } + + // Sort results by precise_latency (ascending order) + std::stable_sort(tie_breaker_results.begin(), tie_breaker_results.end()); + + return 
tie_breaker_results; +} + +std::vector> +rank_macro_tile_sizes( + size_t M, size_t N, size_t K, bool transA, bool transB, hardware_t& hardware, + const std::vector& MT_list, size_t element_size, data_type_t mi_datatype, + double H_L2, bool print, size_t WGM, + std::function + tie_breaker_fn, + size_t max_cus) { + std::vector> results; + + typedef std::tuple result_tuple; + + for (size_t i = 0; i < MT_list.size(); ++i) { + size_t MT_M = std::get<0>(MT_list[i]); + size_t MT_N = std::get<1>(MT_list[i]); + size_t MT_K = std::get<2>(MT_list[i]); + size_t MI_M = std::get<3>(MT_list[i]); + size_t MI_N = std::get<4>(MT_list[i]); + size_t MI_K = std::get<5>(MT_list[i]); + + if (hardware_t::is_debug_enabled()) { + std::cout << "Evaluating MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K + << ", MI_M=" << MI_M << ", MI_N=" << MI_N << ", MI_K=" << MI_K << "\n"; + } + + if (check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size)) { + size_t split = 1; + size_t mx_block_size = 0; + double Total_latency = + compute_total_latency(hardware, M, N, K, + 1, // Batch + transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, + element_size * 8, // Element Size A + element_size * 8, // Element Size B + element_size * 8, // Element Size out + mi_datatype, mx_block_size, WGM, 0, 0, 0, split, max_cus); + + results.push_back(std::make_tuple(Total_latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K)); + } else if (hardware_t::is_debug_enabled()) { + std::cout << "Skipping MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K + << " due to LDS capacity\n"; + } + } + + // Sort results by Total_latency, from worst (largest latency) to best (smallest latency) + std::stable_sort(results.begin(), results.end(), + [](const result_tuple& a, const result_tuple& b) { + return std::get<0>(a) > std::get<0>(b); + }); + + if (!results.empty()) { + double best_latency = std::get<0>(results.back()); + + std::vector top_results; + for (size_t i = 0; i < results.size(); ++i) { + if 
(std::abs(std::get<0>(results[i]) - best_latency) < 1e-6) { + top_results.push_back(results[i]); + } + } + + if (top_results.size() > 1) { + if (hardware_t::is_debug_enabled()) { + std::cout << "Tie detected among top-ranked tile sizes. Applying " + "tie-breaker...\n"; + } + + // Compute tie-breaker scores and store them along with the result indices + std::vector> + tie_breaker_scores; // (score, index in top_results) + + for (size_t i = 0; i < top_results.size(); ++i) { + const result_tuple& res = top_results[i]; + size_t MT_M = std::get<1>(res); + size_t MT_N = std::get<2>(res); + size_t MT_K = std::get<3>(res); + size_t MI_M = std::get<4>(res); + size_t MI_N = std::get<5>(res); + size_t MI_K = std::get<6>(res); + double score = tie_breaker_fn(MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, hardware); + + tie_breaker_scores.push_back(std::make_pair(score, i)); + } + + // Now sort the tie_breaker_scores based on score + std::stable_sort(tie_breaker_scores.begin(), tie_breaker_scores.end(), + [](const std::pair& a, + const std::pair& b) { return a.first > b.first; }); + + // Now re-order 'top_results' based on sorted indices + std::vector sorted_top_results; + for (size_t i = 0; i < tie_breaker_scores.size(); ++i) { + size_t idx = tie_breaker_scores[i].second; + sorted_top_results.push_back(top_results[idx]); + } + + // Remove the tied results from 'results' and insert the sorted 'sorted_top_results' + results.erase(std::remove_if(results.begin(), results.end(), + [best_latency](const result_tuple& res) { + return std::abs(std::get<0>(res) - best_latency) < + 1e-6; + }), + results.end()); + + results.insert(results.end(), sorted_top_results.begin(), sorted_top_results.end()); + // No need to re-sort results as total_latency remains same for tied results + } + } + + if (print) { + std::cout << "Total Latency\tMT_M\tMT_N\tMT_K\tMI_M\tMI_N\tMI_K\n"; + for (size_t i = 0; i < results.size(); ++i) { + double latency = std::get<0>(results[i]); + size_t MT_M = 
std::get<1>(results[i]); + size_t MT_N = std::get<2>(results[i]); + size_t MT_K = std::get<3>(results[i]); + size_t MI_M = std::get<4>(results[i]); + size_t MI_N = std::get<5>(results[i]); + size_t MI_K = std::get<6>(results[i]); + std::cout << std::fixed << std::setprecision(2) << latency << "\t" << MT_M << "\t" + << MT_N << "\t" << MT_K << "\t" << MI_M << "\t" << MI_N << "\t" << MI_K + << "\n"; + } + } + + return results; +} + +double compute_tflops_from_latency(double latency_cycles, size_t M, size_t N, size_t K, + double clock_GHz) { + // Compute total FLOPs + double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs + // Compute total time in seconds + double cycles_per_second = clock_GHz * 1e9; // 1 GHz = 1e9 cycles per second + double total_time_seconds = latency_cycles / cycles_per_second; + // Compute performance in FLOPS + double FLOPS = total_FLOPs / total_time_seconds; + // Convert to TFLOPS + double TFLOPS = FLOPS / 1e12; // 1 TFLOP = 1e12 FLOPs + + if (hardware_t::is_debug_enabled()) { + std::cout << "Total FLOPs: " << total_FLOPs << "\n"; + std::cout << "Total Time: " << total_time_seconds << " seconds\n"; + std::cout << "Performance: " << FLOPS << " FLOPS\n"; + std::cout << "Achieved Performance: " << TFLOPS << " TFLOPS\n"; + } + + return TFLOPS; +} +} // namespace origami diff --git a/shared/origami/tests/CMakeLists.txt b/shared/origami/tests/CMakeLists.txt index 863c4d12547..822f4853d20 100644 --- a/shared/origami/tests/CMakeLists.txt +++ b/shared/origami/tests/CMakeLists.txt @@ -1,52 +1,39 @@ -################################################################################ -# -# MIT License -# -# Copyright 2025 AMD ROCm(TM) Software -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell cop- -# ies of the Software, and to permit persons to whom the Software is furnished -# to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- -# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- -# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -# -################################################################################ +# Copyright Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT -find_package(Catch2 3 QUIET CONFIG) -if(NOT Catch2_FOUND) - if(ORIGAMI_ENABLE_FETCH) - include(FetchContent) - fetchcontent_declare( - Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2.git GIT_TAG devel - ) - fetchcontent_makeavailable(Catch2) - else() - message(FATAL_ERROR "Failed to find Catch2") - endif() -endif() +find_package(GTest REQUIRED) +find_package(Boost REQUIRED COMPONENTS filesystem) add_executable(origami-tests) -target_sources( - origami-tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/test_gemm.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/test_origami.cpp" +target_sources(origami-tests + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/origami_gtest.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/test_tile_ordering_issue.cpp" + # "${CMAKE_CURRENT_SOURCE_DIR}/test_variance_issue.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/test_negative_occupancy.cpp" ) -target_include_directories(origami-tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") +target_include_directories(origami-tests + PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/include" +) 
-target_link_libraries(origami-tests PRIVATE roc::origami Catch2::Catch2WithMain) +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/origami_gtest.yaml" + "${CMAKE_CURRENT_BINARY_DIR}/origami_gtest.yaml" + COPYONLY +) -include(CTest) -include(Catch) -catch_discover_tests(origami-tests) +target_link_libraries(origami-tests + PRIVATE + roc::origami + Boost::filesystem + GTest::gtest + GTest::gtest_main +) + +gtest_discover_tests(origami-tests + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + TIMEOUT 60 +) diff --git a/shared/origami/tests/include/common.hpp b/shared/origami/tests/include/common.hpp deleted file mode 100644 index 3c18f1a3b48..00000000000 --- a/shared/origami/tests/include/common.hpp +++ /dev/null @@ -1,114 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ - -#pragma once -#include -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include "origami/origami.hpp" -#include "origami/streamk.hpp" - -// List of GPU architectures to test -inline const std::vector test_architectures = {942, 950}; - -// Helper function to construct problem_t -inline origami::problem_t make_problem(size_t m, - size_t n, - size_t k, - origami::transpose_t a_trans = origami::transpose_t::T, - origami::transpose_t b_trans = origami::transpose_t::N, - size_t batch = 1, - int mx_block_size = 0) { - origami::problem_t problem; - problem.size.m = m; - problem.size.n = n; - problem.size.k = k; - problem.batch = batch; - problem.a_transpose = a_trans; - problem.b_transpose = b_trans; - problem.a_dtype = origami::data_type_t::BFloat16; - problem.b_dtype = origami::data_type_t::BFloat16; - problem.c_dtype = origami::data_type_t::BFloat16; - problem.d_dtype = origami::data_type_t::BFloat16; - problem.mi_dtype = origami::data_type_t::BFloat16; - problem.a_mx_block_size = mx_block_size; - problem.b_mx_block_size = mx_block_size; - return problem; -} - -// Helper function to construct config_t -inline origami::config_t make_config(size_t mt_m, - size_t mt_n, - size_t mt_k, - size_t mi_m = 16, - size_t mi_n = 16, - size_t mi_k = 16, - int wgm = 1, - int occupancy = 1, - int non_temporal_a = 0, - int non_temporal_b = 0) { - origami::config_t config; - config.mt.m = mt_m; - config.mt.n = mt_n; - config.mt.k = mt_k; - config.mi.m = mi_m; - config.mi.n = mi_n; - config.mi.k = mi_k; - config.occupancy = occupancy; - config.workgroup_mapping = wgm; - config.cache_hints_a = non_temporal_a; - config.cache_hints_b = non_temporal_b; - return config; -} - -// Helper function to construct hardware_t with all parameters -inline origami::hardware_t make_hardware( - int gpu_arch, - size_t n_cu = 304, - size_t lds_capacity = 65536, - size_t num_xcd = 8, - double 
mem1_perf_ratio = 1.0, - double mem2_perf_ratio = 1.0, - double mem3_perf_ratio = 1.0, - size_t l2_capacity = 4000000, - double compute_clock_ghz = 1.0, - size_t parallel_mi_cu = 1, - std::tuple mem_bw_per_wg_coefficients = std::make_tuple(0, 0.015, 0)) { - const std::string gpu_arch_str = "gfx" + std::to_string(gpu_arch); - auto gpu_arch_enum = origami::hardware_t::arch_name_to_enum(gpu_arch_str); - - return origami::hardware_t(gpu_arch_enum, - n_cu, - lds_capacity, - num_xcd, - mem1_perf_ratio, - mem2_perf_ratio, - mem3_perf_ratio, - l2_capacity, - compute_clock_ghz, - parallel_mi_cu, - mem_bw_per_wg_coefficients); -} diff --git a/shared/origami/tests/include/testing_origami.hpp b/shared/origami/tests/include/testing_origami.hpp new file mode 100644 index 00000000000..f89603b1c53 --- /dev/null +++ b/shared/origami/tests/include/testing_origami.hpp @@ -0,0 +1,657 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/utils.hpp" +#include +#include +#include +#include + +struct InputWithExpected +{ + std::map values; + std::optional expected; + std::optional expected_gt; + std::optional expected_lt; +}; + +struct MyTestData +{ + std::string name; + std::vector inputs; +}; + +// Parameterized test class declaration +class AnalyticalGtest : public ::testing::TestWithParam +{ +}; + +void ComputeLoads(int MT_M, int MT_N, int MT_K, const std::optional expected) +{ + auto a_loads = origami::compute_A_loads(MT_M, MT_K); + auto b_loads = origami::compute_B_loads(MT_N, MT_K); + EXPECT_EQ(a_loads, expected); + EXPECT_EQ(b_loads, expected); +} + +void EstimateL2Hit(const origami::hardware_t& hardware, + int M, + int N, + int K, + int batch, + int MT_M, + int MT_N, + int MT_K, + size_t element_size, + int splittingFactor, + const std::optional expected_gt, + const std::optional expected_lt) +{ + double l2_hit; + for(int i = 1; i < 1025; 
i++) + { + l2_hit = origami::estimate_l2_hit( + hardware, M, N, K, batch, MT_M, MT_N, MT_K, element_size, i, splittingFactor); + EXPECT_GT(l2_hit, expected_gt); + EXPECT_LT(l2_hit, expected_lt); + } +} + +void ComputeNumMatrixInstructions(const origami::hardware_t& hardware, + int MT_M, + int MT_N, + int MT_K, + int MI_M, + int MI_N, + int MI_K, + const std::optional expected) +{ + auto NumberMatrixInstructions + = origami::compute_number_matrix_instructions(hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K); + EXPECT_EQ(NumberMatrixInstructions, expected); +} + +void ComputeMTComputeLatency(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + const std::optional expected, + const std::optional expected_gt) +{ + auto latency = origami::compute_mt_compute_latency(hardware, + M, + N, + K, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + origami::data_type_t::BFloat16); + + if(expected.has_value()) + EXPECT_EQ(latency, expected); + else if(expected_gt.has_value()) + EXPECT_GT(latency, expected_gt); +} + +void ComputeMemoryLatency(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t element_size_A, + size_t element_size_B, + size_t mx_block_size, + int wgm, + int numActiveCUs, + int splittingFactor) +{ + auto mem_latency_small = origami::compute_memory_latency(hardware, + M, + N, + K, + transA, + transB, + batch, + MT_M, + MT_N, + MT_K, + element_size_A, + element_size_B, + mx_block_size, + wgm, + numActiveCUs, + splittingFactor); + + auto mem_latency_large = origami::compute_memory_latency(hardware, + M, + N, + K, + transA, + transB, + batch, + MT_M * 2, + MT_N * 2, + MT_K * 2, + element_size_A, + 
element_size_B, + mx_block_size, + wgm, + numActiveCUs, + splittingFactor); + + EXPECT_LT(mem_latency_small, mem_latency_large); +} + +void ComputeTileLatency(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + size_t mx_block_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor) +{ + auto tile_latency_small = origami::compute_tile_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 1, + numActiveCUs, + splittingFactor); + + auto tile_latency_large = origami::compute_tile_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M * 2, + MT_N * 2, + MT_K * 2, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 1, + numActiveCUs, + splittingFactor); + + EXPECT_GT(tile_latency_large, tile_latency_small); +} + +void ComputeWaveLatency(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + size_t mx_block_size, + int WGM, + size_t numActiveCUs, + size_t splittingFactor) +{ + auto tile_latency = origami::compute_tile_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 1, + numActiveCUs, + 
splittingFactor); + auto wave_latency = origami::compute_wave_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 1, + numActiveCUs, + splittingFactor); + EXPECT_DOUBLE_EQ(wave_latency, tile_latency); +} + +void ComputeTotalLatency(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + size_t mx_block_size, + int WGM, + size_t splittingFactor, + size_t max_cus) +{ + double latency_cycles_small = origami::compute_total_latency(hardware, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 0, + 0, + splittingFactor, + max_cus); + + double latency_cycles_large = origami::compute_total_latency(hardware, + M * 2, + N * 2, + K * 2, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + WGM, + 0, + 0, + splittingFactor, + max_cus); + EXPECT_LT(latency_cycles_small, latency_cycles_large); +} + +void ComputePerfGflops(size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t MI_N, + size_t MI_K, + size_t element_size_A, //In bits + size_t element_size_B, //In bits, + size_t element_size_out, //In bits + int WGM, + size_t max_cus) +{ + auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); + auto gfx942_slow = origami::hardware_t( + gfx942arch, 304, 65536, 8, 
1.0, 1.0, 1.0, 4000000, 1.4, 1, std::make_tuple(0, 0.015, 0)); + auto gfx942_fast = origami::hardware_t( + gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.8, 1, std::make_tuple(0, 0.015, 0)); + double flops_slow = origami::compute_perf_gflops(gfx942_slow, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + WGM, + max_cus); + double flops_fast = origami::compute_perf_gflops(gfx942_fast, + M, + N, + K, + batch, + transA, + transB, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + WGM, + max_cus); + EXPECT_GT(flops_fast, flops_slow); // faster clock = higher flops +} + +void EstimateMallHit(const origami::hardware_t& hardware, + int M, + int N, + int K, + int batch, + int MT_M, + int MT_N, + int MT_K, + size_t element_size, + size_t numActiveCUs, + size_t splittingFactor, + const std::optional expected_gt) +{ + double mall_hit; + for(int i = 1; i < 1025; i++) + { + mall_hit = origami::estimate_mall_hit(hardware, + M, + N, + K, + batch, + MT_M, + MT_N, + MT_K, + element_size, + i, + numActiveCUs, + splittingFactor); + EXPECT_GT(mall_hit, expected_gt); + } +} + +void CheckLDSCapacity( + const origami::hardware_t& hardware, int MT_M, int MT_N, int MT_K, size_t element_size) +{ + auto fit_lds_memory = origami::check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size); + EXPECT_TRUE(fit_lds_memory); +} + +// hardware_t +void HardwareArchEnum(const std::string gpuArchNumber) +{ + auto gpuArchEnum = origami::hardware_t::arch_name_to_enum("gfx" + gpuArchNumber); + EXPECT_EQ(gpuArchEnum, origami::hardware_t::architecture_t::gfx942); +} + +// Utils +void BestGridSize(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t MT_M, + size_t MT_N, + size_t MT_K, + size_t MI_M, + size_t 
MI_N, + size_t MI_K, + size_t element_size_A, + size_t element_size_B, + size_t element_size_out, + size_t mx_block_size, + double H_L2, + int WGM, + size_t biggest_allowable_split, + size_t max_cus, + const std::optional expected_gt) +{ + size_t grid_size = origami::select_best_grid_size(M, + N, + K, + batch, + transA, + transB, + hardware, + MT_M, + MT_N, + MT_K, + MI_M, + MI_N, + MI_K, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + H_L2, + WGM, + biggest_allowable_split, + max_cus); + EXPECT_GT(grid_size, expected_gt); +} + +void BestMacroTileSize(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + bool transA, + bool transB, + size_t element_size_A, //In bits + size_t element_size_B, //In bits + size_t element_size_out, //In bits + size_t mx_block_size, + double H_L2, + size_t WGM, + size_t max_cus) +{ + const std::vector> + MT_list = {{256, 256, 32, 32, 32, 8, 1, 6, 0, 0}, + {128, 128, 64, 32, 32, 8, 1, 6, 0, 0}, + {64, 64, 64, 32, 32, 8, 1, 6, 0, 0}}; + auto results = select_best_macro_tile_size(M, + N, + K, + batch, + transA, + transB, + hardware, + MT_list, + element_size_A, + element_size_B, + element_size_out, + origami::data_type_t::BFloat16, + mx_block_size, + H_L2, + false, + WGM, + max_cus); + + EXPECT_EQ(results.size(), MT_list.size()); + for(int i = 0; i < results.size() - 1; i++) + EXPECT_LT(std::get<0>(results[i]), std::get<0>(results[i + 1])); +} + +void BestWGM(const origami::hardware_t& hardware, + size_t M, + size_t N, + size_t K, + size_t batch, + size_t MT_M, + size_t MT_N, + size_t MT_K) +{ + // Assume no nt + auto nta = 0; + auto ntb = 0; + // Assume DP + auto skGrid = (M + MT_M - 1) / MT_M * (N + MT_N - 1) / MT_N; + + auto [best_wgmxcc_large_tile, best_wgm_large_tile] = + select_best_wgm(hardware, + M, + N, + K, + batch, + MT_M, + MT_N, + MT_K, + nta, + ntb, + skGrid, + false); + + auto [best_wgmxcc_small_tile, best_wgm_small_tile] = + 
select_best_wgm(hardware, + M / 4, + N / 4, + K, + batch, + MT_M, + MT_N, + MT_K * 2, + nta, + ntb, + skGrid, + false); + + auto [best_wgmxcc_nonsquare_tile, best_wgm_nonsquare_tile] = + select_best_wgm(hardware, + 1024, + 5120, + K, + batch, + MT_M, + MT_N, + MT_K, + nta, + ntb, + skGrid, + false); + + EXPECT_EQ(best_wgmxcc_large_tile, best_wgmxcc_small_tile); + EXPECT_GT(best_wgm_large_tile, best_wgm_small_tile); + EXPECT_NE(best_wgm_large_tile, best_wgm_nonsquare_tile); +} + +void UtilsTFlopsFromLatency(size_t M, size_t N, size_t K, double latency_cycles, double clock_GHz) +{ + auto tflops = origami::compute_tflops_from_latency(latency_cycles, M, N, K, clock_GHz); + double Expected = 1.99; + EXPECT_LT(std::abs(tflops - Expected) / std::abs(Expected), 0.01); +} diff --git a/shared/origami/tests/origami_gtest.cpp b/shared/origami/tests/origami_gtest.cpp new file mode 100644 index 00000000000..30e4146bbb0 --- /dev/null +++ b/shared/origami/tests/origami_gtest.cpp @@ -0,0 +1,353 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "testing_origami.hpp" +#include +#include +#include +#include + +#ifndef _WIN32 +#include +#include +#else +#include +#endif + +// Returns the directory path where the current executable resides. +// Adds a trailing slash ('/' on Linux, '\' on Windows) for easy file concatenation. 
+std::string getExecutableDir() { +#ifndef _WIN32 + // Linux branch + + char result[PATH_MAX]; // Buffer to store the path + ssize_t count = readlink("/proc/self/exe", result, PATH_MAX); + // readlink reads the symbolic link /proc/self/exe, which points to the current executable + + if (count == -1) { + // If readlink fails, return empty string + return ""; + } + + result[count] = '\0'; // Null-terminate the buffer + std::string fullPath(result); // Convert to std::string + + // Find the position of the last slash ('/') in the path + // This separates the directory from the binary name + size_t pos = fullPath.find_last_of('/'); + + // Extract the directory portion + std::string dir = (pos != std::string::npos) ? fullPath.substr(0, pos) : fullPath; + + // Ensure the directory string ends with a slash + if (!dir.empty() && dir.back() != '/') + dir += '/'; + + return dir; + +#else + // Windows branch + + char path[MAX_PATH]; // Buffer to store the path + DWORD length = GetModuleFileNameA(NULL, path, MAX_PATH); + // GetModuleFileNameA returns the full path of the current executable + + if (length == 0) { + // Failed to get the executable path + return ""; + } + + std::string fullPath(path, length); // Convert to std::string + + // Find the position of the last backslash ('\') or forward slash ('/') + size_t pos = fullPath.find_last_of("\\/"); + + // Extract the directory portion + std::string dir = (pos != std::string::npos) ? 
fullPath.substr(0, pos) : fullPath; + + // Ensure the directory string ends with a backslash + if (!dir.empty() && dir.back() != '\\' && dir.back() != '/') + dir += '\\'; + + return dir; +#endif +} + +// Parse origami_gtest.yaml to get the test data +std::vector parseYamlManually(const std::string& filename) +{ + std::string YamlfullPath = getExecutableDir() + filename; + std::ifstream file(YamlfullPath); + if(!file) + { + std::cerr << "Failed to open file: " << YamlfullPath << std::endl; + return {}; + } + + std::string line; + std::vector tests; + MyTestData current; + enum class State + { + None, + Inputs + } state + = State::None; + int line_number = 0; + + while(std::getline(file, line)) + { + line_number++; + line.erase(0, line.find_first_not_of(" \t\r\n")); + line.erase(line.find_last_not_of(" \t\r\n") + 1); + + if(line.empty() || line[0] == '#') + continue; + + if(line.rfind("- name:", 0) == 0) + { + if(!current.name.empty()) + tests.push_back(current); + current = MyTestData{}; + current.name = line.substr(7); + current.name.erase(0, current.name.find_first_not_of(" \t")); + state = State::None; + } + else if(line.rfind("inputs:", 0) == 0) + { + state = State::Inputs; + } + else if(state == State::Inputs && line.rfind("- {", 0) == 0) + { + std::string inner = line.substr(3); + if(!inner.empty() && inner.back() == '}') + inner.pop_back(); + + std::map values; + std::optional expected; + std::optional expected_gt; + std::optional expected_lt; + + std::stringstream ss(inner); + std::string pair; + while(std::getline(ss, pair, ',')) + { + auto colon = pair.find(':'); + if(colon == std::string::npos) + continue; + + std::string key = pair.substr(0, colon); + std::string val = pair.substr(colon + 1); + key.erase(0, key.find_first_not_of(" \t")); + key.erase(key.find_last_not_of(" \t") + 1); + val.erase(0, val.find_first_not_of(" \t")); + val.erase(val.find_last_not_of(" \t") + 1); + + try + { + int num = std::stoi(val); + if(key == "expected") + expected = num; 
+ else if(key == "expected_gt") + expected_gt = num; + else if(key == "expected_lt") + expected_lt = num; + else + values[key] = num; + } + catch(...) + { + std::cerr << "Invalid number in line " << line_number << ": " << val + << std::endl; + } + } + + current.inputs.push_back(InputWithExpected{values, expected, expected_gt, expected_lt}); + } + } + + if(!current.name.empty()) + tests.push_back(current); + + return tests; +} + +TEST_P(AnalyticalGtest, DynamicDispatch) +{ + const MyTestData& test = GetParam(); + + const std::string gpuArchNumber = std::to_string(test.inputs[0].values.at("gpu_arch")); + auto gpuArchEnum = origami::hardware_t::arch_name_to_enum("gfx" + gpuArchNumber); + + //TODO: Hardcoding numbers for gfx942. Future archs could be added here with if else loop. + auto gpuInfo = origami::hardware_t( + gpuArchEnum, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.0, 1, std::make_tuple(0, 0.015, 0)); + + if(test.name == "ComputeLoads") + { + for(const auto& input_case : test.inputs) + { + ComputeLoads(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.expected); + } + } + else if(test.name == "EstimateL2Hit") + { + for(const auto& input_case : test.inputs) + { + EstimateL2Hit(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), input_case.values.at("element_size"), + input_case.values.at("splittingFactor"),input_case.expected_gt, input_case.expected_lt); + } + } + else if(test.name == "ComputeNumMatrixInstructions") + { + for(const auto& input_case : test.inputs) + { + ComputeNumMatrixInstructions(gpuInfo, input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), + input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), input_case.expected); + } + } + else if(test.name == "ComputeMTComputeLatency") 
+ { + for(const auto& input_case : test.inputs) + { + ComputeMTComputeLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), + input_case.values.at("transA"), input_case.values.at("transB"),input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.expected, input_case.expected_gt); + } + } + else if(test.name == "ComputeMemoryLatency") + { + for(const auto& input_case : test.inputs) + { + ComputeMemoryLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), + input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); + } + } + else if(test.name == "ComputeTileLatency") + { + for(const auto& input_case : test.inputs) + { + ComputeTileLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), + input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); + } + } + else if(test.name == 
"ComputeWaveLatency") + { + for(const auto& input_case : test.inputs) + { + ComputeWaveLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), + input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); + } + } + else if(test.name == "ComputeTotalLatency") + { + for(const auto& input_case : test.inputs) + { + ComputeTotalLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), + input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("splittingFactor"), + input_case.values.at("max_cus")); + } + } + else if(test.name == "ComputePerfGflops") + { + for(const auto& input_case : test.inputs) + { + ComputePerfGflops(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + 
input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), + input_case.values.at("WGM"), input_case.values.at("max_cus")); + } + } + else if(test.name == "EstimateMallHit") + { + for(const auto& input_case : test.inputs) + { + EstimateMallHit(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"),input_case.values.at("element_size_A"), + input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor"), input_case.expected_gt); + } + } + else if(test.name == "CheckLDSCapacity") + { + for(const auto& input_case : test.inputs) + { + CheckLDSCapacity(gpuInfo, input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), input_case.values.at("element_size")); + } + } + else if(test.name == "HardwareArchEnum") + { + for(const auto& input_case : test.inputs) + { + HardwareArchEnum(std::to_string(test.inputs[0].values.at("gpu_arch"))); + } + } + else if(test.name == "BestGridSize") + { + for(const auto& input_case : test.inputs) + { + BestGridSize(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), + input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), + input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), + input_case.values.at("mx_block_size"), input_case.values.at("H_L2"), input_case.values.at("WGM"), + input_case.values.at("biggest_allowable_split"), input_case.values.at("max_cus"), input_case.expected_gt); + } + } + else if(test.name == "BestMacroTileSize") + { + for(const auto& 
input_case : test.inputs) + { + BestMacroTileSize(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("element_size_A"), + input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), input_case.values.at("mx_block_size"), + input_case.values.at("H_L2"), input_case.values.at("WGM"), input_case.values.at("max_cus")); + } + } + else if(test.name == "BestWGM") + { + for(const auto& input_case : test.inputs) + { + BestWGM(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), + input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K")); + } + } + else if(test.name == "UtilsTFlopsFromLatency") + { + for(const auto& input_case : test.inputs) + { + UtilsTFlopsFromLatency(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), + input_case.values.at("latency_cycles"), input_case.values.at("clock_GHz")); + } + } + else + { + FAIL() << "Unknown test name: " << test.name; + } +} + +// Instantiate tests using manual parser +INSTANTIATE_TEST_SUITE_P(AnalyticalYamlTests, + AnalyticalGtest, + ::testing::ValuesIn(parseYamlManually("origami_gtest.yaml")), + [](const ::testing::TestParamInfo& info) { + std::string name = info.param.name; + for(auto& c : name) + if(!std::isalnum(c)) + c = '_'; + return name; + }); + diff --git a/shared/origami/tests/origami_gtest.yaml b/shared/origami/tests/origami_gtest.yaml new file mode 100644 index 00000000000..6e6975667ac --- /dev/null +++ b/shared/origami/tests/origami_gtest.yaml @@ -0,0 +1,70 @@ + +tests: + - name: ComputeLoads + inputs: + - { M: 128, N: 128, K: 64, gpu_arch: 942, expected: 8192 } + + - name: EstimateL2Hit + inputs: + - { M: 4096, N: 4096, K: 1024, batch: 1, MT_M: 256, MT_N: 256, MT_K: 64, element_size: 16, splittingFactor: 3, 
gpu_arch: 942, expected_gt: 0.0, expected_lt: 1.0} + + - name: ComputeNumMatrixInstructions + inputs: + - { MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 16, MI_N: 16, MI_K: 16, gpu_arch: 942, expected: 256} # 8 * 8 * 4 + - { MT_M: 16, MT_N: 16, MT_K: 64, MI_M: 16, MI_N: 16, MI_K: 32, gpu_arch: 942, expected: 2} # 1 * 1 * 2 + + - name: ComputeMTComputeLatency + inputs: + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 0, transB: 1, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 0, transB: 0, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 128, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 2048} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 224, MT_N: 224, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected_gt: 12543} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 32, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 512} # 8 * 8 * 4 + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 32, MT_N: 128, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 512} # 8 * 8 * 4 + + - name: ComputeMemoryLatency + inputs: 
+ - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 1, MT_M: 128, MT_N: 128, MT_K: 64, element_size_A: 16, element_size_B: 16, mx_block_size: 0, wgm: 8, numActiveCUs: 304, splittingFactor: 2, gpu_arch: 942} + + - name: ComputeTileLatency + inputs: + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 6, numActiveCUs: 304, splittingFactor: 3, gpu_arch: 942} + + - name: ComputeWaveLatency + inputs: + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 8, numActiveCUs: 304, splittingFactor: 4, gpu_arch: 942} + + - name: ComputeTotalLatency + inputs: + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 1, splittingFactor: 6, max_cus: 0, gpu_arch: 942} + + - name: ComputePerfGflops + inputs: + - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, WGM: 1, max_cus: 0, gpu_arch: 942} + + - name: EstimateMallHit + inputs: + - { M: 4096, N: 4096, K: 1024, batch: 1, MT_M: 256, MT_N: 256, MT_K: 64, numActiveCUs: 304, splittingFactor: 8, gpu_arch: 942, expected_gt: 0, element_size_A: 16} + + - name: HardwareArchEnum + inputs: + - { gpu_arch: 942} + + - name: BestGridSize + inputs: + - { M: 1024, N: 1024, K: 4096, batch: 1, transA: 1, transB: 0, MT_M: 256, MT_N: 256, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, biggest_allowable_split: 20, gpu_arch: 942, max_cus: 0, expected_gt: 15} + 
+ - name: BestMacroTileSize + inputs: + - { M: 1024, N: 1024, K: 4096, batch: 1, transA: 1, transB: 0, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, max_cus: 0, gpu_arch: 942} + - { M: 4, N: 4, K: 0, batch: 1, transA: 1, transB: 0, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, max_cus: 0, gpu_arch: 942} + + - name: BestWGM + inputs: + - { M: 4096, N: 4096, K: 8192, batch: 1, MT_M: 256, MT_N: 256, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size: 16, H_L2: 0.0, gpu_arch: 942} + + - name: UtilsTFlopsFromLatency + inputs: + - { M: 10, N: 10, K: 10, batch: 1, latency_cycles: 256, clock_GHz: 256, gpu_arch: 942} diff --git a/shared/origami/tests/test_gemm.cpp b/shared/origami/tests/test_gemm.cpp deleted file mode 100644 index 2d926aaba99..00000000000 --- a/shared/origami/tests/test_gemm.cpp +++ /dev/null @@ -1,236 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#include -#include -#include "common.hpp" - -using Catch::Approx; - -// Test functions for gemm.hpp/cpp - -TEST_CASE("GEMM: compute_num_matrix_instructions", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - 128x128x64 with 16x16x16") { - auto hardware = make_hardware(gpu_arch); - origami::dim3_t mt{128, 128, 64}; - origami::dim3_t mi{16, 16, 16}; - auto num_instructions = origami::compute_number_matrix_instructions(mt, mi); - REQUIRE(num_instructions == 256); // 8 * 8 * 4 - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - 16x16x64 with 16x16x32") { - auto hardware = make_hardware(gpu_arch); - origami::dim3_t mt{16, 16, 64}; - origami::dim3_t mi{16, 16, 32}; - auto num_instructions = origami::compute_number_matrix_instructions(mt, mi); - REQUIRE(num_instructions == 2); // 1 * 1 * 2 - } - } -} - -TEST_CASE("GEMM: compute_mt_compute_latency", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=T transB=N") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto latency = origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency == 4096); - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=N transB=T") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::N, origami::transpose_t::T); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto latency = 
origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency == 4096); - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=N transB=N") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::N, origami::transpose_t::N); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto latency = origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency == 4096); - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=T transB=T") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::T); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto latency = origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency == 4096); - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - different MT_K") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(4096, 4096, 1024); - auto config = make_config(128, 128, 32, 32, 32, 8, 1); - - auto latency = origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency == 2048); - } - - DYNAMIC_SECTION("gfx" << gpu_arch << " - larger tile") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(4096, 4096, 1024); - auto config = make_config(224, 224, 64, 32, 32, 8, 1); - - auto latency = origami::compute_mt_compute_latency(problem, hardware, config); - REQUIRE(latency > 12543); - } - } -} - -TEST_CASE("GEMM: compute_memory_latency", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - verify smaller tiles have lower latency") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 1); - auto config_small = make_config(128, 128, 64, 32, 32, 8, 8); - auto config_large = make_config(256, 256, 128, 32, 32, 8, 8); - - auto 
mem_latency_small = - origami::compute_memory_latency(problem, hardware, config_small, 304, 2); - auto mem_latency_large = - origami::compute_memory_latency(problem, hardware, config_large, 304, 2); - - REQUIRE(mem_latency_small < mem_latency_large); - } - } -} - -TEST_CASE("GEMM: compute_tile_latency", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - verify larger tiles have higher latency") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); - auto config_small = make_config(128, 128, 64, 32, 32, 8, 6); - auto config_large = make_config(256, 256, 128, 32, 32, 8, 6); - - auto tile_latency_small = - origami::compute_tile_latency(problem, hardware, config_small, 304, 3); - auto tile_latency_large = - origami::compute_tile_latency(problem, hardware, config_large, 304, 3); - - REQUIRE(tile_latency_large > tile_latency_small); - } - } -} - -TEST_CASE("GEMM: compute_timestep_latency", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - wave latency equals tile latency") { - auto hardware = make_hardware(gpu_arch); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); - auto config = make_config(128, 128, 64, 32, 32, 8, 8); - - auto tile_latency = origami::compute_tile_latency(problem, hardware, config, 304, 4); - auto wave_latency = origami::compute_timestep_latency(problem, hardware, config, 304, 4); - - REQUIRE(wave_latency == Approx(tile_latency)); - } - } -} - -TEST_CASE("GEMM: compute_total_latency", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - smaller tiles have lower total latency") { - auto hardware = make_hardware(gpu_arch); - auto problem_small = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); - auto problem_large = - 
make_problem(8192, 8192, 2048, origami::transpose_t::T, origami::transpose_t::N, 2); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto latency_small = - origami::compute_total_latency(problem_small, hardware, config, hardware.N_CU); - auto latency_large = - origami::compute_total_latency(problem_large, hardware, config, hardware.N_CU); - - REQUIRE(latency_small < latency_large); - } - } -} - -TEST_CASE("GEMM: check_lds_capacity", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - 256x256x64 tile fits in LDS") { - auto hardware = make_hardware(gpu_arch); - origami::dim3_t mt{256, 256, 64}; - - auto fits = origami::check_lds_capacity( - hardware, mt, origami::data_type_t::BFloat16, origami::data_type_t::BFloat16); - - REQUIRE(fits == true); - } - } -} - -TEST_CASE("GEMM: estimate_l2_hit", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - L2 hit rate in valid range") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(4096, 4096, 1024); - auto config = make_config(256, 256, 64, 32, 32, 8, 1); - - for (int wgm = 1; wgm < 1025; wgm++) { - config.workgroup_mapping = wgm; - auto l2_hit = origami::estimate_l2_hit(problem, hardware, config, 3); - REQUIRE(l2_hit > 0.0); - REQUIRE(l2_hit < 1.0); - } - } - } -} - -TEST_CASE("GEMM: estimate_mall_hit", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - Mall hit rate is positive") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(4096, 4096, 1024); - auto config = make_config(256, 256, 64, 32, 32, 8, 1); - - for (int wgm = 1; wgm < 1025; wgm++) { - config.workgroup_mapping = wgm; - auto mall_hit = origami::estimate_mall_hit(problem, hardware, config, 304, 8); - REQUIRE(mall_hit > 0.0); - } - } - } -} diff --git a/shared/origami/tests/test_negative_occupancy.cpp b/shared/origami/tests/test_negative_occupancy.cpp new file mode 
100644 index 00000000000..406e06aaebd --- /dev/null +++ b/shared/origami/tests/test_negative_occupancy.cpp @@ -0,0 +1,41 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/utils.hpp" + +TEST(OrigamiTileSelection, NegativeOccupancy) { + // Setup hardware (gfx942) + auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); + auto hardware = origami::hardware_t(gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.7, 1, + std::make_tuple(0, 0.015, 0)); + + // Square problem size + size_t M = 32; + size_t N = 800000; + size_t K = 16; + size_t batch = 1; + bool transA = false; + bool transB = true; + + // List 1: Tile A first, then Tile B + const std::vector MT_list = { + {256, 256, 32, 16, 16, 32, -1, 6, 0, 0}, // Tile A + {32, 256, 16, 32, 32, 8, 2, 6, 0, 0} // Tile B + }; + + auto results = origami::select_best_macro_tile_size( + M, N, K, batch, transA, transB, hardware, MT_list, 32, 32, 32, + origami::data_type_t::XFloat32, 0, 0.8, false, 6); + + auto best_tile = results[0]; + size_t MT_M = std::get<1>(best_tile); + size_t MT_N = std::get<2>(best_tile); + size_t MT_K = std::get<3>(best_tile); + EXPECT_EQ(MT_M, 32) << "MT_M should be 32"; + EXPECT_EQ(MT_N, 256) << "MT_N should be 256"; + EXPECT_EQ(MT_K, 16) << "MT_K should be 16"; +} \ No newline at end of file diff --git a/shared/origami/tests/test_origami.cpp b/shared/origami/tests/test_origami.cpp deleted file mode 100644 index a3445dfd40b..00000000000 --- a/shared/origami/tests/test_origami.cpp +++ /dev/null @@ -1,287 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright 2025 AMD ROCm(TM) Software - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without 
restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#include -#include -#include "common.hpp" - -using Catch::Approx; - -// Test functions for origami.hpp/cpp - -TEST_CASE("Origami: compute_perf_gflops", "[origami]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - faster clock yields higher GFLOPS") { - // TODO: Add support for make_hardware using hipDeviceProperties - auto hardware_slow = make_hardware(gpu_arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.4); - auto hardware_fast = make_hardware(gpu_arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.8); - auto problem = - make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); - auto config = make_config(128, 128, 64, 32, 32, 8, 1); - - auto config_slow = config; - auto config_fast = config; - - auto latency_config_slow = - origami::compute_total_latency(problem, hardware_slow, config_slow, hardware_slow.N_CU); - auto flops_slow = origami::compute_perf_gflops(hardware_slow, problem, latency_config_slow); - - auto latency_config_fast = 
- origami::compute_total_latency(problem, hardware_fast, config_fast, hardware_fast.N_CU); - auto flops_fast = origami::compute_perf_gflops(hardware_fast, problem, latency_config_fast); - - REQUIRE(flops_fast > flops_slow); - } - } -} - -TEST_CASE("Origami: hardware_arch_enum", "[origami]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " architecture enum") { - std::string arch_str = "gfx" + std::to_string(gpu_arch); - auto arch_enum = origami::hardware_t::arch_name_to_enum(arch_str); - - if (gpu_arch == 942) { - REQUIRE(arch_enum == origami::hardware_t::architecture_t::gfx942); - } else if (gpu_arch == 950) { - REQUIRE(arch_enum == origami::hardware_t::architecture_t::gfx950); - } - } - } -} - -TEST_CASE("Origami: best_grid_size", "[origami]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - grid size selection") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(1024, 1024, 4096); - auto config = make_config(256, 256, 64, 32, 32, 8, 1); - - auto grid_size = origami::streamk::select_grid_size( - problem, hardware, config, origami::grid_selection_t::k_split_aware, hardware.N_CU); - - REQUIRE(grid_size >= 16); - } - } -} - -TEST_CASE("Origami: best_macro_tile_size", "[origami]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - rank configs by latency") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(1024, 1024, 4096); - - // List 1: config A first, then config B - std::vector configs; - - // config A[0] - configs.push_back(make_config(256, 256, 32, 32, 32, 8, 1, 6, 0, 0)); - // config A[1] - configs.push_back(make_config(128, 128, 64, 32, 32, 8, 1, 6, 0, 0)); - // config A[2] - configs.push_back(make_config(64, 64, 64, 32, 32, 8, 1, 6, 0, 0)); - - auto results = origami::rank_configs(problem, hardware, configs); - - REQUIRE(results.size() == configs.size()); - // Results should be ranked, so 
latencies should be in ascending order (best first) - for (size_t i = 0; i < results.size() - 1; i++) { - REQUIRE(results[i].latency < results[i + 1].latency); - } - } - } -} - -TEST_CASE("Origami: select_workgroup_mapping", "[origami]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - workgroup mapping selection") { - auto hardware = make_hardware(gpu_arch); - auto problem = make_problem(4096, 4096, 8192); - - auto config_large = make_config(256, 256, 32, 32, 32, 8, 1); - auto skGrid_large = (4096 + 256 - 1) / 256 * (4096 + 256 - 1) / 256; - auto [best_wgmxcc_large_tile, best_wgm_large_tile] = - origami::select_workgroup_mapping(problem, hardware, config_large, skGrid_large); - - auto config_small = make_config(128, 128, 64, 32, 32, 8, 1); - auto skGrid_small = (4096 + 128 - 1) / 128 * (4096 + 128 - 1) / 128; - auto [best_wgmxcc_small_tile, best_wgm_small_tile] = - origami::select_workgroup_mapping(problem, hardware, config_small, skGrid_small); - - // Different problem size for nonsquare test - origami::problem_t problem_nonsquare = problem; - problem_nonsquare.size.m = 2048; - problem_nonsquare.size.n = 5120; - auto skGrid_nonsquare = (2048 + 128 - 1) / 128 * (5120 + 128 - 1) / 128; - - auto [best_wgmxcc_nonsquare_tile, best_wgm_nonsquare] = origami::select_workgroup_mapping( - problem_nonsquare, hardware, config_large, skGrid_nonsquare); - - REQUIRE(best_wgmxcc_large_tile == best_wgmxcc_small_tile); - REQUIRE(best_wgm_large_tile > best_wgm_small_tile); - REQUIRE(best_wgm_large_tile != best_wgm_nonsquare); - } - } -} - -TEST_CASE("GEMM: negative_occupancy", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - test negative occupancy") { - auto hardware = make_hardware(gpu_arch); - origami::problem_t problem = { - .size = {32, 800000, 16}, - .batch = 1, - .a_transpose = origami::transpose_t::N, - .b_transpose = origami::transpose_t::T, - .a_dtype = 
origami::data_type_t::XFloat32, // element_size_A = 16 - .b_dtype = origami::data_type_t::XFloat32, - .mi_dtype = origami::data_type_t::XFloat32, - .a_mx_block_size = 0, - .b_mx_block_size = 0, - }; - // List 1: config A first, then config B - std::vector config; - - // config[0] - config.push_back(make_config(256, 256, 32, 16, 16, 32, -1, 6, 0, 0)); - // config[1] - config.push_back(make_config(32, 256, 16, 32, 32, 8, 2, 6, 0, 0)); - - // Call select_config - auto best_tile = origami::select_config(problem, hardware, config); - - size_t MT_M = best_tile.config.mt.m; - size_t MT_N = best_tile.config.mt.n; - size_t MT_K = best_tile.config.mt.k; - REQUIRE(MT_M == 32); //"MT_M should be 32" - REQUIRE(MT_N == 256); //"MT_N should be 256" - REQUIRE(MT_K == 16); //"MT_K should be 16" - } - } -} - -TEST_CASE("GEMM: deterministic_tie_breaking", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - Verify deterministic selection") { - auto hardware = make_hardware(gpu_arch); - // Square problem size - auto problem = - make_problem(1024, 1024, 1024, origami::transpose_t::N, origami::transpose_t::N, 1); - - // Two config with same arithmetic intensity: 256x64x32 and 64x256x32 - // AI = (2 * MT_M * MT_N * MT_K) / (MT_M*MT_K + MT_N*MT_K + MT_M*MT_N) - // Both have AI = 1048576 / 26624 = 39.38 - - // List 1: config A first, then config B - std::vector config_A; - std::vector config_B; - - // config A[0] - config_A.push_back(make_config(256, 64, 32, 32, 32, 8, 1, 6, 0, 0)); - // config A[1] - config_A.push_back(make_config(64, 256, 32, 32, 32, 8, 1, 6, 0, 0)); - - // config B[0] (reversed order) - config_B.push_back(make_config(64, 256, 32, 32, 32, 8, 1, 6, 0, 0)); - // config B[1] (reversed order) - config_B.push_back(make_config(256, 64, 32, 32, 32, 8, 1, 6, 0, 0)); - - // Call select_config with both orderings - auto best_tile_A = origami::select_config(problem, hardware, config_A); - - auto best_tile_B = 
origami::select_config(problem, hardware, config_B); - - size_t MT_M_A_first = best_tile_A.config.mt.m; - size_t MT_N_A_first = best_tile_A.config.mt.n; - size_t MT_K_A_first = best_tile_A.config.mt.k; - - size_t MT_M_B_first = best_tile_B.config.mt.m; - size_t MT_N_B_first = best_tile_B.config.mt.n; - size_t MT_K_B_first = best_tile_B.config.mt.k; - - // Verify deterministic selection: both should select the same tile (256x64x32) - // regardless of input order, using the final tie-breaker (prefer larger MT_M) - REQUIRE(MT_M_A_first == MT_M_B_first); //"Selected tile MT_M should be consistent" - REQUIRE(MT_N_A_first == MT_N_B_first); //"Selected tile MT_N should be consistent" - REQUIRE(MT_K_A_first == MT_K_B_first); //"Selected tile MT_K should be consistent" - - // Verify it selected the tile with larger MT_M (256 > 64) - REQUIRE(MT_M_A_first == 256); //"Should prefer tile with larger MT_M" - REQUIRE(MT_N_A_first == 64); //"Should prefer tile with larger MT_M" - } - } -} - -TEST_CASE("GEMM: Verify deterministic tile selection", "[gemm]") { - for (int gpu_arch : test_architectures) { - DYNAMIC_SECTION("gfx" << gpu_arch << " - Verify deterministic selection") { - auto hardware = make_hardware(gpu_arch); - - // problem size - auto problem = - make_problem(42598, 153, 128, origami::transpose_t::N, origami::transpose_t::T); - - // List 1: config A first, then config B - std::vector config_A; - std::vector config_B; - - // config A[0] - config_A.push_back(make_config(256, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile A - // config A[1] - config_A.push_back(make_config(192, 160, 64, 16, 16, 32, 1, 6, 0, 0)); // Tile B - - // config B[0] Previous two tiles + a new one - config_B.push_back(make_config(256, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile A - // config B[1] - config_B.push_back(make_config(192, 160, 64, 16, 16, 32, 1, 6, 0, 0)); // Tile B - // config B[2] - config_B.push_back(make_config(192, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile C - - // Call select_config 
with both tile configs - auto best_tile_A = origami::select_config(problem, hardware, config_A); - - auto best_tile_B = origami::select_config(problem, hardware, config_B); - - size_t MT_M1 = best_tile_A.config.mt.m; - size_t MT_N1 = best_tile_A.config.mt.n; - size_t MT_K1 = best_tile_A.config.mt.k; - - size_t MT_M2 = best_tile_B.config.mt.m; - size_t MT_N2 = best_tile_B.config.mt.n; - size_t MT_K2 = best_tile_B.config.mt.k; - - auto winner_is_acceptable = - [](auto const& actual, auto const& candidate1, auto const& candidate2) -> bool { - return (actual == candidate1) || (actual == candidate2); - }; - - INFO("Winner is not acceptable"); - REQUIRE(winner_is_acceptable(MT_M2, MT_M1, config_B[2].mt.m)); - REQUIRE(winner_is_acceptable(MT_N2, MT_N1, config_B[2].mt.n)); - REQUIRE(winner_is_acceptable(MT_K2, MT_K1, config_B[2].mt.k)); - } - } -} diff --git a/shared/origami/tests/test_tile_ordering_issue.cpp b/shared/origami/tests/test_tile_ordering_issue.cpp new file mode 100644 index 00000000000..1dd0eb22041 --- /dev/null +++ b/shared/origami/tests/test_tile_ordering_issue.cpp @@ -0,0 +1,72 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/utils.hpp" + +// Test to verify deterministic tile selection when tiles have same latency and arithmetic intensity +// This test ensures that the tie-breaking logic works correctly regardless of input order +TEST(OrigamiTileSelection, DeterministicTieBreaking) { + // Setup hardware (gfx942) + auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); + auto hardware = origami::hardware_t(gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.7, 1, + std::make_tuple(0, 0.015, 0)); + + // Square problem size + size_t M = 1024; + size_t N = 1024; + size_t K = 1024; + size_t batch = 1; + bool transA = false; + bool transB = false; + + // Two tiles with same arithmetic intensity: 256x64x32 and 64x256x32 + // AI = (2 * MT_M * MT_N * MT_K) / (MT_M*MT_K + MT_N*MT_K + MT_M*MT_N) + // Both have AI = 1048576 / 26624 = 39.38 + + // List 1: Tile A first, then Tile B + const std::vector MT_list_A_first = { + {256, 64, 32, 32, 32, 8, 1, 6, 0, 0}, // Tile A + {64, 256, 32, 32, 32, 8, 1, 6, 0, 0} // Tile B + }; + + // List 2: Tile B first, then Tile A (reversed order) + const std::vector MT_list_B_first = { + {64, 256, 32, 32, 32, 8, 1, 6, 0, 0}, // Tile B + {256, 64, 32, 32, 32, 8, 1, 6, 0, 0} // Tile A + }; + + // Call select_best_macro_tile_size with both orderings + auto results_A_first = origami::select_best_macro_tile_size( + M, N, K, batch, transA, transB, hardware, MT_list_A_first, 16, 16, 16, + origami::data_type_t::BFloat16, 0, 0.8, false, 6); + + auto results_B_first = origami::select_best_macro_tile_size( + M, N, K, batch, transA, transB, hardware, MT_list_B_first, 16, 16, 16, + origami::data_type_t::BFloat16, 0, 0.8, false, 6); + + // Extract the best tile from each result + auto best_tile_A_first = results_A_first[0]; + auto best_tile_B_first = results_B_first[0]; + + size_t MT_M_A_first = std::get<1>(best_tile_A_first); + size_t 
MT_N_A_first = std::get<2>(best_tile_A_first); + size_t MT_K_A_first = std::get<3>(best_tile_A_first); + + size_t MT_M_B_first = std::get<1>(best_tile_B_first); + size_t MT_N_B_first = std::get<2>(best_tile_B_first); + size_t MT_K_B_first = std::get<3>(best_tile_B_first); + + // Verify deterministic selection: both should select the same tile (256x64x32) + // regardless of input order, using the final tie-breaker (prefer larger MT_M) + EXPECT_EQ(MT_M_A_first, MT_M_B_first) << "Selected tile MT_M should be consistent"; + EXPECT_EQ(MT_N_A_first, MT_N_B_first) << "Selected tile MT_N should be consistent"; + EXPECT_EQ(MT_K_A_first, MT_K_B_first) << "Selected tile MT_K should be consistent"; + + // Verify it selected the tile with larger MT_M (256 > 64) + EXPECT_EQ(MT_M_A_first, 256) << "Should prefer tile with larger MT_M"; + EXPECT_EQ(MT_N_A_first, 64) << "Should prefer tile with larger MT_M"; +} diff --git a/shared/origami/tests/test_variance_issue.cpp b/shared/origami/tests/test_variance_issue.cpp new file mode 100644 index 00000000000..451988cd13e --- /dev/null +++ b/shared/origami/tests/test_variance_issue.cpp @@ -0,0 +1,81 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include + +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/utils.hpp" + +// Test to verify deterministic tile selection +TEST(OrigamiTileSelection, VarianceEffect) { + // Setup hardware + auto gfx950arch = origami::hardware_t::arch_name_to_enum("gfx950"); + // auto hardware = origami::hardware_t(gfx950arch, 256, 163840, 8, 7727.272727, 2993.42, 3157.89, 4194304, 2.2, 4, + // std::make_tuple(0, 0.008, 0)); + auto hardware = origami::hardware_t(gfx950arch, 256, 163840, 8, 7727.272727, 2000.42, 2000.89, 4194304, 2.2, 4, + std::make_tuple(0, 0.008, 0)); + + // Tall, skinny (non-square) problem size: M >> N + size_t M = 42598; + size_t N = 153; + size_t K = 128; + size_t batch = 1; + bool transA = false; + bool transB = true; + + // List 1: Only two tiles + const std::vector MT_list1 = { + {256, 160, 32, 16, 16, 32, 1, 6, 0, 0}, // Tile A + {192, 160, 64, 16, 16, 32, 1, 6, 0, 0} // Tile B + }; + + // List 2: Previous two tiles + a new one + const std::vector MT_list2 = { + {256, 160, 32, 16, 16, 32, 1, 6, 0, 0}, // Tile A + {192, 160, 64, 16, 16, 32, 1, 6, 0, 0}, // Tile B + {192, 160, 32, 16, 16, 32, 1, 6, 0, 0} // Tile C + }; + + // Call select_best_macro_tile_size with both tile lists + auto results1 = origami::select_best_macro_tile_size( + M, N, K, batch, transA, transB, hardware, MT_list1, 16, 16, 32, + origami::data_type_t::BFloat16, 0, 0.8, false, 6); + + auto results2 = origami::select_best_macro_tile_size( + M, N, K, batch, transA, transB, hardware, MT_list2, 16, 16, 32, + origami::data_type_t::BFloat16, 0, 0.8, false, 6); + + // Extract the best tile from each result + auto best_tile1 = results1[0]; + auto best_tile2 = results2[0]; + + size_t MT_M1 = std::get<1>(best_tile1); + size_t MT_N1 = std::get<2>(best_tile1); + size_t MT_K1 = std::get<3>(best_tile1); + + size_t MT_M2 = std::get<1>(best_tile2); + size_t MT_N2 = std::get<2>(best_tile2); + size_t MT_K2 = std::get<3>(best_tile2); + + // At this time, the 
model predicts the following latencies: + // TileA: 35803.3 + // TileB: 36088 + // TileC: 35452.4 + // Hence, for list1 TileB is the winner as it has the largest AI. + // For list2, either TileB (previous winner) or TileC (new tile) should be the winner. + // Note that adding TileC (with a reasonable variance) eliminates TileB from the + // pool of selected kernels, and TileC is the winner. + std::cout << std::get<0>(results2[0]) << " " << std::get<1>(results2[0]) << " " << std::get<3>(results2[0]) << std::endl; + std::cout << std::get<0>(results2[1]) << " " << std::get<1>(results2[1]) << " " << std::get<3>(results2[1]) << std::endl; + std::cout << std::get<0>(results2[2]) << " " << std::get<1>(results2[2]) << " " << std::get<3>(results2[2]) << std::endl; + + // Verify deterministic selection + // After adding a new tile, we don't want the selection to change from + // tile A to tile B (or the other way around). + // We accept if the new tile is the winner, though! + EXPECT_THAT(MT_M2, ::testing::AnyOf(::testing::Eq(MT_M1), ::testing::Eq(std::get<0>(MT_list2[2])))) << "Winner is not acceptable"; + EXPECT_THAT(MT_N2, ::testing::AnyOf(::testing::Eq(MT_N1), ::testing::Eq(std::get<1>(MT_list2[2])))) << "Winner is not acceptable"; + EXPECT_THAT(MT_K2, ::testing::AnyOf(::testing::Eq(MT_K1), ::testing::Eq(std::get<2>(MT_list2[2])))) << "Winner is not acceptable"; +}