diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp index 17f221050338..81e08a2c9047 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/analytical_utils.hpp @@ -33,7 +33,7 @@ #include -#include +#include "origami/types.hpp" /** * @brief Convert rocRoller::Datatype to analytical::DataType diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp index ca4da79952d2..372c7f6b970e 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/runtime_args_selection.cpp @@ -28,55 +28,52 @@ #include "gemm.hpp" #include "runtime_args_selection.hpp" -#include - -const int DEFAULT_DYNAMIC_MODE = 6; +#include "origami/streamk.hpp" int chooseStreamKGridSize(std::shared_ptr gemm, const RocblasltContractionProblem& prob) { - const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0); + const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0); + + const origami::grid_selection_t DEFAULT_DYNAMIC_MODE = origami::grid_selection_t::k_split_aware; + + //setting max_cu's + size_t max_cus = analytical_hardware.N_CU; size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeA).elementBits; size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeB).elementBits; - size_t elementSizeD_bits = rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeD).elementBits; size_t elementSizeAcc = 
rocRoller::DataTypeInfo::Get(gemm->params->kernelType.typeAcc).elementBytes; - origami::data_type_t dataType; - if (elementSizeA_bits < elementSizeB_bits) - dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB); - else - dataType = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA); + origami::problem_t origami_problem = { + .size = {prob.m, prob.n, prob.k}, + .batch = prob.batch_count, + .a_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeA), + .b_dtype = rocroller_type_to_analytical_type(gemm->params->kernelType.typeB), + .mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? gemm->params->kernelType.typeB : gemm->params->kernelType.typeA), + }; + origami::config_t origami_config = { + .mt = { + static_cast(gemm->params->workgroupTile.m), + static_cast(gemm->params->workgroupTile.n), + static_cast(gemm->params->workgroupTile.k) + }, + .occupancy = gemm->occupancy, + .workspace_size = prob.workspaceSize, + .workspace_size_per_elem_c = elementSizeAcc, + }; + + auto reduction_type = origami::streamk::select_reduction(origami_problem, + analytical_hardware, + origami_config, + DEFAULT_DYNAMIC_MODE); - auto reduction_type = origami::streamk::select_reduction(prob.m, prob.n, prob.k, prob.batch_count, - gemm->params->workgroupTile.m, gemm->params->workgroupTile.n, gemm->params->workgroupTile.k, analaytical_hardware, DEFAULT_DYNAMIC_MODE); - // Override reduction type to tree reduction for now. 
- // When Parallel reduction is available, this line can be removed - reduction_type = origami::streamk::reduction_type::Tree; + origami_config.reduction_strategy = reduction_type; - auto result = origami::streamk::select_grid(prob.m, - prob.n, - prob.k, - prob.batch_count, - prob.trans_a == HIPBLAS_OP_T, - prob.trans_b == HIPBLAS_OP_T, - elementSizeA_bits, - elementSizeB_bits, - elementSizeD_bits, - dataType, - prob.workspaceSize, - gemm->params->workgroupTile.m, - gemm->params->workgroupTile.n, - gemm->params->workgroupTile.k, - gemm->params->machineInstruction.m, - gemm->params->machineInstruction.n, - gemm->params->machineInstruction.k, - DEFAULT_WGM, - elementSizeAcc, - gemm->occupancy, - analaytical_hardware, - DEFAULT_DYNAMIC_MODE, - reduction_type); + auto result = origami::streamk::select_grid_size(origami_problem, + analytical_hardware, + origami_config, + DEFAULT_DYNAMIC_MODE, + max_cus); return result; } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index 4919875c7ee2..2c9549e7bd02 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -29,7 +29,7 @@ #include "runtime_args_selection.hpp" #include "solution_selection.hpp" -#include +#include "origami/origami.hpp" const int MAX_BITS_WORKGROUPTILE_M = 8; const int MAX_BITS_WORKGROUPTILE_N = 8; @@ -44,7 +44,9 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096; * compile-time known. 
*/ - constexpr std::array possibleTileSizes = {{ + constexpr size_t possibleTileSizesCount = 34; + + constexpr std::array possibleTileSizes = {{ {256, 256, 128}, {256, 192, 128}, {256, 128, 128}, @@ -82,10 +84,10 @@ const int USE_WORKGROUP_MAPPING_K_SIZE = 4096; }}; template -constexpr auto generateTileList() { - std::array tileList{}; +auto generateTileList() { + std::array tileList{}; - for (size_t i = 0; i < possibleTileSizes.size(); ++i) { + for (size_t i = 0; i < possibleTileSizesCount; ++i) { const auto& wgt = possibleTileSizes[i]; auto MI = pickMI(typeA, typeB, wgt); @@ -96,27 +98,33 @@ constexpr auto generateTileList() { int unroll = preferredUnrolling(typeA, typeB, wgt); - int non_temporal_a = 0; - int non_temporal_b = 0; - - tileList[i] = std::make_tuple( - wgt.m, wgt.n, wgtk * unroll, - MI.m, MI.n, MI.k, - 1, // occupancy - DEFAULT_WGM, - non_temporal_a, - non_temporal_b - ); + origami::config_t origami_config = { + .mt = { + static_cast(wgt.m), + static_cast(wgt.n), + static_cast(wgtk * unroll) + }, + .mi = { + static_cast(MI.m), + static_cast(MI.n), + static_cast(MI.k) + }, + .occupancy = 1, + .cache_hints_a = 0, + .cache_hints_b = 0, + }; + + tileList[i] = origami_config; } return tileList; } -using TileListGeneratorFn = std::vector(*)(); +using TileListGeneratorFn = std::vector(*)(); template -std::vector generateTileListWrapper() { - constexpr auto arr = generateTileList(); +std::vector generateTileListWrapper() { + auto arr = generateTileList(); return {arr.begin(), arr.end()}; } @@ -144,7 +152,7 @@ const std::map, TileListGene INSTANTIATE_TILE_LIST_FOR(FP6) }; -std::vector getTileListForKernelType(KernelType kernelType) +std::vector getTileListForKernelType(KernelType kernelType) { auto key = std::make_pair(kernelType.typeA, kernelType.typeB); auto it = tileListGenerators.find(key); @@ -170,43 +178,42 @@ std::vector chooseSolutionIndexParameters( { std::vector params; - std::vector tile_list = getTileListForKernelType(kernelType); + std::vector 
origami_config_list = getTileListForKernelType(kernelType); size_t elementSizeA_bits = rocRoller::DataTypeInfo::Get(kernelType.typeA).elementBits; size_t elementSizeB_bits = rocRoller::DataTypeInfo::Get(kernelType.typeB).elementBits; - size_t elementSizeC_bits = rocRoller::DataTypeInfo::Get(kernelType.typeC).elementBits; - - origami::data_type_t dataType; - if (elementSizeA_bits < elementSizeB_bits) - dataType = rocroller_type_to_analytical_type(kernelType.typeB); - else - dataType = rocroller_type_to_analytical_type(kernelType.typeA); - - const origami::hardware_t analaytical_hardware = origami::hardware_t::get_hardware_for_device(0); - - int WGM = std::sqrt(std::floor(analaytical_hardware.N_CU / analaytical_hardware.NUM_XCD)); - - auto selected_tiles = origami::select_best_macro_tile_size( - prob.m, - prob.n, - prob.k, - prob.batch_count, - prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T, - prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T, - analaytical_hardware, - tile_list, - elementSizeA_bits, - elementSizeB_bits, - elementSizeC_bits, - dataType, - kernelType.scaleABlockRowSize * kernelType.scaleABlockColSize, //Handle A vs B block size. - 0.8, - false, - WGM); - - for(auto const& selected_tile : selected_tiles) + + const origami::hardware_t analytical_hardware = origami::hardware_t::get_hardware_for_device(0); + + origami::problem_t origami_problem = { + .size = {prob.m, prob.n, prob.k}, + .batch = prob.batch_count, + .a_transpose = (prob.trans_a == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N, + .b_transpose = (prob.trans_b == hipblasOperation_t::HIPBLAS_OP_T) ? origami::transpose_t::T : origami::transpose_t::N, + .a_dtype = rocroller_type_to_analytical_type(kernelType.typeA), + .b_dtype = rocroller_type_to_analytical_type(kernelType.typeB), + .mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? 
kernelType.typeB : kernelType.typeA), + .a_mx_block_size = kernelType.scaleABlockRowSize * kernelType.scaleABlockColSize, + .b_mx_block_size = kernelType.scaleBBlockRowSize * kernelType.scaleBBlockColSize, + }; + + int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD)); + for (auto& config : origami_config_list) { + config.workgroup_mapping = defaultWGM; + } + + auto prediction_result = origami::rank_configs( + origami_problem, + analytical_hardware, + origami_config_list + ); + + for(auto const& result : prediction_result) { - WorkGroupTileSize wgt{(int)std::get<1>(selected_tile), (int)std::get<2>(selected_tile), (int)std::get<3>(selected_tile)}; + auto mt_m = static_cast(result.config.mt.m); + auto mt_n = static_cast(result.config.mt.n); + auto mt_k = static_cast(result.config.mt.k); + WorkGroupTileSize wgt{mt_m, mt_n, mt_k}; int unrollAmount = preferredUnrolling(kernelType.typeA, kernelType.typeB, wgt); wgt.k /= unrollAmount; diff --git a/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp b/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp index 57b47fa6f40b..6d3a494301b2 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/ContractionSolution.hpp @@ -42,7 +42,8 @@ #include #include -#include +#include "origami/origami.hpp" +#include "origami/streamk.hpp" #define TENSILE_COMMON_KERNEL_ARGS_SIZE 16 @@ -166,7 +167,7 @@ namespace TensileLite struct StreamKSettings { - origami::streamk::reduction_type reduction = origami::streamk::reduction_type::Tree; + origami::reduction_t reduction = origami::reduction_t::tree; size_t grid = 0; }; @@ -183,7 +184,7 @@ namespace TensileLite using Problem = ContractionProblemGemm; using Inputs = ContractionInputs; using GroupedInputs = ContractionGroupedInputs; - using ParamsCache = CacheMap, Problem>; + using ParamsCache = CacheMap, Problem>; /** * Indicate a solution is 
equally or estimatedly matched. @@ -218,6 +219,11 @@ namespace TensileLite } virtual bool isFallbackForHW(Hardware const&) const; + bool isStreamK() const + { + return sizeMapping.streamK > 0; + } + //! Estimates based on problem size, solution tile, and machine hardware //! charz: struct StaticPerformanceModel @@ -290,8 +296,8 @@ namespace TensileLite void calculateGrid(dim3& workGroupSize, dim3& numWorkGroups, ContractionSolution::Problem const& problem) const; - origami::streamk::reduction_type getSKReduction(Problem const& problem, Hardware const& hardware) const; - size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::streamk::reduction_type& reductionStrat) const; + origami::reduction_t getSKReduction(Problem const& problem, Hardware const& hardware) const; + size_t getSKGrid(Problem const& problem, Hardware const& hardware, size_t tiles, origami::reduction_t reductionStrat) const; size_t partialTileSize(size_t skGrid) const; static float computeGranularity(float x); @@ -566,9 +572,9 @@ namespace TensileLite uint32_t magicNumber(int magicDivAlg, uint32_t x, uint32_t* magicShift) const; uint32_t smallMagicNumber(uint32_t x) const; - std::pair calculateAutoWGM(Problem const& problem, - Hardware const* hardware, - uint32_t skgrid) const; + std::pair calculateAutoWGM(Problem const& problem, + Hardware const* hardware, + uint32_t skgrid) const; uint32_t calculateAutoGSU(Problem const& problem, Hardware const* hardware) const; }; diff --git a/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp index 7b2d6b186921..df1920a8f23e 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/PredictionLibrary.hpp @@ -45,9 +45,9 @@ namespace TensileLite template struct ProblemPredictionLibrary : public SolutionLibrary { - std::unordered_map> solutionmap; - std::vector 
tile_list; - std::unordered_map tile_map; + std::unordered_map> solutionmap; + std::vector origami_config_list; + std::unordered_map origami_config_map; static std::string Type() { @@ -156,56 +156,35 @@ namespace TensileLite batch *= problem.batchSize(i); } - bool debug = Debug::Instance().printPropertyEvaluation(); hip::HipAMDGPU const* pAMDGPU = dynamic_cast(&hardware); - size_t elementSizeA_bits - = problem.a().elementBytes() * 8; - size_t elementSizeB_bits - = problem.b().elementBytes() * 8; - size_t elementSizeC_bits - = problem.c().elementBytes() * 8; + const origami::hardware_t& analytical_hardware = *(pAMDGPU->analyticalHardware); - if(origami::hardware_t::is_debug_enabled()) - { - analytical_hardware.print(); - } - int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD)); - origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); + auto miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); + if(problem.f32XdlMathOp() == rocisa::DataType::XFloat32) // Check F32 compute type miDataType = origami::data_type_t::XFloat32; - auto selected_tiles = origami::select_best_macro_tile_size( - m, - n, - k, - batch, - problem.transA(), - problem.transB(), - *(pAMDGPU->analyticalHardware), - tile_list, - elementSizeA_bits, - elementSizeB_bits, - elementSizeC_bits, - miDataType, - 0, // mx_block_size -> MX Data types come from rocroller. - 0.8, // L2 hit-rate (not used anymore -- should be removed) - false, - defaultWGM, - pAMDGPU->skMaxCUs); - for(const auto& tile : selected_tiles) + origami::problem_t origami_problem = { + .size = {m, n, k}, + .batch = batch, + .a_transpose = problem.transA() ? origami::transpose_t::T : origami::transpose_t::N, + .b_transpose = problem.transB() ? 
origami::transpose_t::T : origami::transpose_t::N, + .a_dtype = datatypeToAnalyticalDatatype(problem.a().dataType()), + .b_dtype = datatypeToAnalyticalDatatype(problem.b().dataType()), + .c_dtype = datatypeToAnalyticalDatatype(problem.c().dataType()), + .d_dtype = datatypeToAnalyticalDatatype(problem.d().dataType()), + .mi_dtype = miDataType, + .a_mx_block_size = 0, // MX Data types come from rocroller + .b_mx_block_size = 0, // MX Data types come from rocroller + }; + + auto prediction_result = origami::rank_configs( + origami_problem, *(pAMDGPU->analyticalHardware), origami_config_list); + + for(const auto& r : prediction_result) { - auto mapiter = tile_map.find(std::make_tuple(std::get<1>(tile), - std::get<2>(tile), - std::get<3>(tile), - std::get<4>(tile), - std::get<5>(tile), - std::get<6>(tile), - std::get<7>(tile), - std::get<8>(tile), - std::get<9>(tile), - std::get<10>(tile) - )); + auto mapiter = origami_config_map.find(r.config); auto smapiter = solutionmap.find(mapiter->second); - if(mapiter != tile_map.end() && smapiter != solutionmap.end()) + if(mapiter != origami_config_map.end() && smapiter != solutionmap.end()) { auto solution = smapiter->second; if((*solution->hardwarePredicate)(hardware) diff --git a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp index 77da8d810c6f..db2a5d05dafb 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/Serialization/PredictionLibrary.hpp @@ -83,21 +83,39 @@ namespace TensileLite auto solution = slnIter->second; lib.solutionmap.insert(std::make_pair(index, solution)); - auto solution_tuple = std::make_tuple( - solution->sizeMapping.macroTile.x, // MT_M - solution->sizeMapping.macroTile.y, // MT_N - solution->sizeMapping.depthU, // MT_K - solution->sizeMapping.matrixInstruction[0], // MI_M - 
solution->sizeMapping.matrixInstruction[1], // MI_N - solution->sizeMapping.matrixInstruction[2], // MI_K - solution->sizeMapping.CUOccupancy, // Occupancy - solution->sizeMapping.workGroupMapping, // WGM - solution->sizeMapping.nonTemporalA, // Cache flag: A - solution->sizeMapping.nonTemporalB // Cache flag: B - ); + origami::dim3_t origami_mi; + if(solution->sizeMapping.matrixInstruction[0] == 0 + && solution->sizeMapping.matrixInstruction[1] == 0 + && solution->sizeMapping.matrixInstruction[2] == 0) + { + // Override dot2 instruction with vector lane widths + origami_mi = {1, 1, 64}; + } + else + { + origami_mi = { + static_cast(solution->sizeMapping.matrixInstruction[0]), + static_cast(solution->sizeMapping.matrixInstruction[1]), + static_cast( + solution->sizeMapping.matrixInstruction[2])}; + } - lib.tile_list.emplace_back(solution_tuple); - lib.tile_map.insert(std::make_pair(solution_tuple, index)); + origami::config_t origami_config = { + .mt = {solution->sizeMapping.macroTile.x, + solution->sizeMapping.macroTile.y, + solution->sizeMapping.depthU}, + .mi = origami_mi, + .occupancy + = std::max(solution->sizeMapping.CUOccupancy, static_cast(1)), + .workgroup_mapping = solution->sizeMapping.workGroupMapping, + .cache_hints_a = solution->sizeMapping.nonTemporalA, + .cache_hints_b = solution->sizeMapping.nonTemporalB, + .workspace_size = std::numeric_limits::max(), + .workspace_size_per_elem_c = std::numeric_limits::max(), + }; + + lib.origami_config_list.emplace_back(origami_config); + lib.origami_config_map.insert(std::make_pair(origami_config, index)); } } } diff --git a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp index 4d2519947beb..21dd009eada2 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/UtilsOrigami.hpp @@ -26,58 +26,57 @@ #pragma once -#include +#include #include namespace TensileLite 
{ - + inline origami::data_type_t datatypeToAnalyticalDatatype(rocisa::DataType type) { switch(type) { - case rocisa::DataType::Float: - return origami::data_type_t::Float; - case rocisa::DataType::Double: - return origami::data_type_t::Double; - case rocisa::DataType::ComplexFloat: - return origami::data_type_t::ComplexFloat; - case rocisa::DataType::ComplexDouble: - return origami::data_type_t::ComplexDouble; - case rocisa::DataType::Half: - return origami::data_type_t::Half; - case rocisa::DataType::Int8x4: - return origami::data_type_t::Int8x4; - case rocisa::DataType::Int32: - return origami::data_type_t::Int32; - case rocisa::DataType::BFloat16: - return origami::data_type_t::BFloat16; - case rocisa::DataType::Int8: - return origami::data_type_t::Int8; - case rocisa::DataType::Int64: - return origami::data_type_t::Int64; - case rocisa::DataType::XFloat32: - return origami::data_type_t::XFloat32; - case rocisa::DataType::Float8_fnuz: - return origami::data_type_t::Float8_fnuz; - case rocisa::DataType::BFloat8_fnuz: - return origami::data_type_t::BFloat8_fnuz; - case rocisa::DataType::Float8BFloat8_fnuz: - return origami::data_type_t::Float8BFloat8_fnuz; - case rocisa::DataType::BFloat8Float8_fnuz: - return origami::data_type_t::BFloat8Float8_fnuz; - case rocisa::DataType::Float8: - return origami::data_type_t::Float8; - case rocisa::DataType::BFloat8: - return origami::data_type_t::BFloat8; - case rocisa::DataType::Float8BFloat8: - return origami::data_type_t::Float8BFloat8; - case rocisa::DataType::BFloat8Float8: - return origami::data_type_t::BFloat8Float8; + case rocisa::DataType::Float: + return origami::data_type_t::Float; + case rocisa::DataType::Double: + return origami::data_type_t::Double; + case rocisa::DataType::ComplexFloat: + return origami::data_type_t::ComplexFloat; + case rocisa::DataType::ComplexDouble: + return origami::data_type_t::ComplexDouble; + case rocisa::DataType::Half: + return origami::data_type_t::Half; + case 
rocisa::DataType::Int8x4: + return origami::data_type_t::Int8x4; + case rocisa::DataType::Int32: + return origami::data_type_t::Int32; + case rocisa::DataType::BFloat16: + return origami::data_type_t::BFloat16; + case rocisa::DataType::Int8: + return origami::data_type_t::Int8; + case rocisa::DataType::Int64: + return origami::data_type_t::Int64; + case rocisa::DataType::XFloat32: + return origami::data_type_t::XFloat32; + case rocisa::DataType::Float8_fnuz: + return origami::data_type_t::Float8_fnuz; + case rocisa::DataType::BFloat8_fnuz: + return origami::data_type_t::BFloat8_fnuz; + case rocisa::DataType::Float8BFloat8_fnuz: + return origami::data_type_t::Float8BFloat8_fnuz; + case rocisa::DataType::BFloat8Float8_fnuz: + return origami::data_type_t::BFloat8Float8_fnuz; + case rocisa::DataType::Float8: + return origami::data_type_t::Float8; + case rocisa::DataType::BFloat8: + return origami::data_type_t::BFloat8; + case rocisa::DataType::Float8BFloat8: + return origami::data_type_t::Float8BFloat8; + case rocisa::DataType::BFloat8Float8: + return origami::data_type_t::BFloat8Float8; - default: - return origami::data_type_t::None; + default: + return origami::data_type_t::None; } } } // namespace TensileLite - \ No newline at end of file diff --git a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp index a195ccff1052..fdcf98520516 100644 --- a/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp +++ b/projects/hipblaslt/tensilelite/include/Tensile/hip/HipHardware.hpp @@ -28,7 +28,7 @@ #include #include -#include +#include #include diff --git a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp index b7e86bb966fb..96a51e83e894 100644 --- a/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp +++ b/projects/hipblaslt/tensilelite/src/ContractionSolution.cpp @@ -32,10 +32,11 @@ #include #include 
#include +#include #include -#include #include +#include #include #include @@ -49,8 +50,6 @@ namespace TensileLite { - using ReductionType = origami::streamk::reduction_type; - enum class KERNELARGTYPE { NORMAL = 0, @@ -275,58 +274,32 @@ namespace TensileLite PrintBufferValueClass betaPrint( (void*)args[i].beta, sizeof(args[i].beta), problems[i].betaType()); std::cout << "Gemm " << i << ":" << std::endl; - std::cout << " " - << "m: " << args[i].m << std::endl; - std::cout << " " - << "n: " << args[i].n << std::endl; - std::cout << " " - << "batch: " << args[i].batch << std::endl; - std::cout << " " - << "k: " << args[i].k << std::endl; - std::cout << " " - << "D: " << args[i].d << std::endl; - std::cout << " " - << "C: " << args[i].c << std::endl; - std::cout << " " - << "A: " << args[i].a << std::endl; - std::cout << " " - << "B: " << args[i].b << std::endl; - std::cout << " " - << "strideD1: " << args[i].strideD1 << std::endl; - std::cout << " " - << "strideD2: " << args[i].strideD2 << std::endl; - std::cout << " " - << "strideC1: " << args[i].strideC1 << std::endl; - std::cout << " " - << "strideC2: " << args[i].strideC2 << std::endl; - std::cout << " " - << "strideA1: " << args[i].strideA1 << std::endl; - std::cout << " " - << "strideA2: " << args[i].strideA2 << std::endl; - std::cout << " " - << "strideB1: " << args[i].strideB1 << std::endl; - std::cout << " " - << "strideB2: " << args[i].strideB2 << std::endl; - std::cout << " " - << "Alpha: " << alphaPrint << std::endl; - std::cout << " " - << "Beta: " << betaPrint << std::endl; - std::cout << " " - << "scaleAlphaVec: " << args[i].scaleAlphaVec << std::endl; - std::cout << " " - << "bias: " << args[i].bias << std::endl; - std::cout << " " - << "e: " << args[i].e << std::endl; - std::cout << " " - << "strideE1: " << args[i].strideE1 << std::endl; - std::cout << " " - << "strideE2: " << args[i].strideE2 << std::endl; - std::cout << " " - << "act0: " << args[i].act0 << std::endl; - std::cout << " " - << "act1: 
" << args[i].act1 << std::endl; - std::cout << " " - << "activationType: " << args[i].activationType << std::endl; + std::cout << " " << "m: " << args[i].m << std::endl; + std::cout << " " << "n: " << args[i].n << std::endl; + std::cout << " " << "batch: " << args[i].batch << std::endl; + std::cout << " " << "k: " << args[i].k << std::endl; + std::cout << " " << "D: " << args[i].d << std::endl; + std::cout << " " << "C: " << args[i].c << std::endl; + std::cout << " " << "A: " << args[i].a << std::endl; + std::cout << " " << "B: " << args[i].b << std::endl; + std::cout << " " << "strideD1: " << args[i].strideD1 << std::endl; + std::cout << " " << "strideD2: " << args[i].strideD2 << std::endl; + std::cout << " " << "strideC1: " << args[i].strideC1 << std::endl; + std::cout << " " << "strideC2: " << args[i].strideC2 << std::endl; + std::cout << " " << "strideA1: " << args[i].strideA1 << std::endl; + std::cout << " " << "strideA2: " << args[i].strideA2 << std::endl; + std::cout << " " << "strideB1: " << args[i].strideB1 << std::endl; + std::cout << " " << "strideB2: " << args[i].strideB2 << std::endl; + std::cout << " " << "Alpha: " << alphaPrint << std::endl; + std::cout << " " << "Beta: " << betaPrint << std::endl; + std::cout << " " << "scaleAlphaVec: " << args[i].scaleAlphaVec << std::endl; + std::cout << " " << "bias: " << args[i].bias << std::endl; + std::cout << " " << "e: " << args[i].e << std::endl; + std::cout << " " << "strideE1: " << args[i].strideE1 << std::endl; + std::cout << " " << "strideE2: " << args[i].strideE2 << std::endl; + std::cout << " " << "act0: " << args[i].act0 << std::endl; + std::cout << " " << "act1: " << args[i].act1 << std::endl; + std::cout << " " << "activationType: " << args[i].activationType << std::endl; } } } @@ -544,11 +517,11 @@ namespace TensileLite template void ContractionSolution::singleCallArgs(ContractionSolution::Problem const& problem, ContractionInputs const& inputs, - uint32_t const& workspaceOffsetInByte, - Hardware 
const* hardware, - dim3 const& problemNumGroupTiles, - dim3 const& numWorkGroups, - KA& args, + uint32_t const& workspaceOffsetInByte, + Hardware const* hardware, + dim3 const& problemNumGroupTiles, + dim3 const& numWorkGroups, + KA& args, StreamKSettings const& sk) const { if(debugKernel) @@ -566,7 +539,7 @@ namespace TensileLite TensorDescriptor const& metadata = problem.metadata(); auto [autoWGM, autoWGMXCC] = calculateAutoWGM(problem, hardware, sk.grid); - uint32_t autoGsuVal = calculateAutoGSU(problem, hardware); + uint32_t autoGsuVal = calculateAutoGSU(problem, hardware); uint32_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; { @@ -599,10 +572,12 @@ namespace TensileLite } else if(problemType.stridedBatched) { - if(sizeMapping.streamK > 0 && sk.reduction == ReductionType::Parallel) + if(sizeMapping.streamK > 0 && sk.reduction == origami::reduction_t::parallel) { - args.template append("ws_d", (uint8_t*)inputs.ws + workspaceOffsetInByte); - args.template append("ws_c", (uint8_t*)inputs.ws + workspaceOffsetInByte); + args.template append("ws_d", + (uint8_t*)inputs.ws + workspaceOffsetInByte); + args.template append("ws_c", + (uint8_t*)inputs.ws + workspaceOffsetInByte); } else { @@ -640,7 +615,7 @@ namespace TensileLite // StreamK workspace + flags args.template append("ws", inputs.ws); - if(sk.reduction == ReductionType::Parallel) + if(sk.reduction == origami::reduction_t::parallel) args.template append("Flags", nullptr); else args.template append("Flags", inputs.Synchronizer); @@ -650,8 +625,9 @@ namespace TensileLite size_t startStrideAB = problemType.useInitialStridesAB ? 
0 : 1; // Pass wsStride if it's not in MBSK mode - bool gsuWSStride = gsu > 1 && sizeMapping.globalAccumulation != 3 && sizeMapping.streamK == 0; - bool skWSStride = sizeMapping.streamK > 0 && sk.reduction == ReductionType::Parallel; + bool gsuWSStride + = gsu > 1 && sizeMapping.globalAccumulation != 3 && sizeMapping.streamK == 0; + bool skWSStride = sizeMapping.streamK > 0 && sk.reduction == origami::reduction_t::parallel; if(gsuWSStride || skWSStride) { size_t wsStride = startStrideCD ? d.sizes()[0] : 1; @@ -735,23 +711,24 @@ namespace TensileLite args.template append("totalIters", totalIters); if(sizeMapping.streamK == 1) // Basic SK - { + { uint32_t itersPerWave = CeilDivide(totalIters, numWorkGroups.x); args.template append("SKItersPerWG", itersPerWave); - } + } else if(sizeMapping.streamK >= 2) // Two-tile SK - { - if(sk.reduction == ReductionType::Parallel) - { - uint32_t skSplit = sk.grid / tiles; // skTiles is skSplit in parallel reduction path + { + if(sk.reduction == origami::reduction_t::parallel) + { + uint32_t skSplit + = sk.grid / tiles; // skTiles is skSplit in parallel reduction path uint32_t skItersPerWG = itersPerTile / skSplit; args.template append("SKItersPerWG", skItersPerWG); - args.template append("skGrid", sk.grid); - args.template append("skTiles", skSplit); - } + args.template append("skGrid", sk.grid); + args.template append("skTiles", skSplit); + } else - { + { AMDGPU const* pAMDGPU = dynamic_cast(hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); int fullTiles = pAMDGPU->skFullTiles; @@ -765,21 +742,21 @@ namespace TensileLite uint32_t skTiles = sk.grid; // If not evenly divisible, determine number of Stream-K tiles if(tiles % sk.grid != 0) - { + { // Number of data-parallel tiles on each workgroup would be: // dpTilesPerWG = bigEnough ? (tiles - skTiles) / skGrid : 0; skTiles = bigEnough ? 
sk.grid * fullTiles + tiles % sk.grid : tiles; // Cap Stream-K tiles at total number of tiles in case of large multiplier skTiles = min(skTiles, tiles); - } + } uint32_t skItersPerWG = skTiles * itersPerTile / sk.grid; args.template append("SKItersPerWG", skItersPerWG); - args.template append("skGrid", sk.grid); - args.template append("skTiles", skTiles); - } - } + args.template append("skGrid", sk.grid); + args.template append("skTiles", skTiles); + } + } } if constexpr(insertKernelArgs) @@ -958,29 +935,26 @@ namespace TensileLite / std::ceil(std::ceil(m / mt0) * std::ceil(n / mt1) * gsu / cuCount); } - std::pair ContractionSolution::calculateAutoWGM( - Problem const& problem, - Hardware const* hardware, - uint32_t const skgrid) const + std::pair ContractionSolution::calculateAutoWGM(Problem const& problem, + Hardware const* hardware, + uint32_t const skgrid) const { // Hardware - AMDGPU const* pAMDGPU = dynamic_cast(hardware); + AMDGPU const* pAMDGPU = dynamic_cast(hardware); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(hardware); // Default WGM int32_t defaultWGM; - uint32_t defaultWGMXCC; + int32_t defaultWGMXCC; // Dynamically pick the values - if(sizeMapping.streamK != 0 - && skgrid != 0 - && sizeMapping.workGroupMapping == 0 - && sizeMapping.workGroupMappingXCC == -1 - && sizeMapping.nonTemporalA < 4 /* Exclude NTs for now till we fix libs */ - && sizeMapping.nonTemporalB < 4 /* Exclude NTs for now till we fix libs */) - { - int32_t c_wgm = 0; - uint32_t c_wgmxcc = 0; + if(sizeMapping.streamK != 0 && skgrid != 0 && sizeMapping.workGroupMapping == 0 + && sizeMapping.workGroupMappingXCC == -1 + && sizeMapping.nonTemporalA < 4 /* Exclude NTs for now till we fix libs */ + && sizeMapping.nonTemporalB < 4 /* Exclude NTs for now till we fix libs */) + { + int32_t c_wgm = 0; + int32_t c_wgmxcc = 0; // Try to find cached WGM and WGMXCC std::tie(c_wgm, c_wgmxcc) = paramsCache.find(problem); @@ -989,30 +963,30 @@ namespace TensileLite auto sizes = 
problem.problemSizes(); if(sizes.size() >= 4) { - auto wgm_pred = origami::select_best_wgm(*(hipAMDGPU->analyticalHardware), - sizes[0], - sizes[1], - sizes[3], - sizes[2], - sizeMapping.macroTile.x, - sizeMapping.macroTile.y, - sizeMapping.depthU, - sizeMapping.nonTemporalA, - sizeMapping.nonTemporalB, - skgrid, - false); - defaultWGMXCC = std::get<0>(wgm_pred); - defaultWGM = std::get<1>(wgm_pred); + origami::problem_t origami_problem = { + .size = {sizes[0], sizes[1], sizes[3]}, + .batch = sizes[2], + }; + origami::config_t origami_config = { + .mt = {static_cast(sizeMapping.macroTile.x), + static_cast(sizeMapping.macroTile.y), + static_cast(sizeMapping.depthU)}, + .cache_hints_a = sizeMapping.nonTemporalA, + .cache_hints_b = sizeMapping.nonTemporalB, + }; + std::tie(defaultWGMXCC, defaultWGM) = origami::select_workgroup_mapping( + origami_problem, *(hipAMDGPU->analyticalHardware), origami_config, skgrid); // Add to cache only if dynamically calculated. paramsCache.add(std::make_pair(defaultWGM, defaultWGMXCC), problem); if(Debug::Instance().printPropertyEvaluation()) - std::cout << "Dynamic WGM "<< defaultWGM << ", WGMXCC " << defaultWGMXCC << std::endl; + std::cout << "Dynamic WGM " << defaultWGM << ", WGMXCC " << defaultWGMXCC + << std::endl; } } else { - defaultWGM = c_wgm; + defaultWGM = c_wgm; defaultWGMXCC = c_wgmxcc; } } @@ -1038,7 +1012,6 @@ namespace TensileLite defaultWGMXCC = sizeMapping.workGroupMappingXCC; } - // If WGM and WGMXCC are explicitly specified at runtime, they override default and predictions if(pAMDGPU->fixedWGM != std::numeric_limits::max()) { @@ -1068,30 +1041,30 @@ namespace TensileLite AMDGPU const* pAMDGPU = dynamic_cast(hardware); assert(pAMDGPU); - uint32_t numCUs = pAMDGPU->computeUnitCount; - uint32_t numWGs = getNumWorkGroups(problem, sizeMapping); + uint32_t numCUs = pAMDGPU->computeUnitCount; + uint32_t numWGs = getNumWorkGroups(problem, sizeMapping); // avoid zero division - if (numWGs == 0) + if(numWGs == 0) { return 1; } 
- uint32_t MT0 = sizeMapping.macroTile.x; - uint32_t MT1 = sizeMapping.macroTile.y; - uint32_t MT2 = sizeMapping.depthU; - uint32_t M = problem.freeSizeA(0); - uint32_t N = problem.freeSizeB(0); - uint32_t B = problem.batchSize(0); - uint32_t K = problem.boundSize(0); - uint32_t GSULimit1 = max(1, (uint32_t)std::floor(numCUs / numWGs)); - uint32_t GSULimit2 = max(1, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); - uint32_t gsuVal = min(GSULimit2, max(1, GSULimit1)); + uint32_t MT0 = sizeMapping.macroTile.x; + uint32_t MT1 = sizeMapping.macroTile.y; + uint32_t MT2 = sizeMapping.depthU; + uint32_t M = problem.freeSizeA(0); + uint32_t N = problem.freeSizeB(0); + uint32_t B = problem.batchSize(0); + uint32_t K = problem.boundSize(0); + uint32_t GSULimit1 = max(1, (uint32_t)std::floor(numCUs / numWGs)); + uint32_t GSULimit2 = max(1, (uint32_t)std::floor((float)K / (float)MT2 / 3.0)); + uint32_t gsuVal = min(GSULimit2, max(1, GSULimit1)); // WorkgroupNumberCheck #define MAX_WORKGROUP_NUMBER 16777216 if(gsuVal > 1) gsuVal = min(gsuVal, - MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) - / std::ceil(static_cast(N) / MT1) / B); + MAX_WORKGROUP_NUMBER / std::ceil(static_cast(M) / MT0) + / std::ceil(static_cast(N) / MT1) / B); // GlobalSplitUCheckMinK if(gsuVal > 1) @@ -1100,17 +1073,19 @@ namespace TensileLite // SynchronizerSizeCheck if(gsuVal > 1 && sizeMapping.globalAccumulation == 3) // MBSK { - uint32_t synchronizerUsage = sizeMapping.synchronizerSizePerWG * problem.getNumTiles(sizeMapping, 1) * B; + uint32_t synchronizerUsage + = sizeMapping.synchronizerSizePerWG * problem.getNumTiles(sizeMapping, 1) * B; gsuVal = synchronizerUsage > 409600 ? 
1 : gsuVal; } // Avoid selecting a gsu value that would make launch grid over the limit - uint32_t tiles0 = CeilDivide(M, MT0); - uint32_t tiles1 = CeilDivide(N, MT1); - uint32_t tiles = tiles0 * tiles1 * B; - uint32_t workGroupSize = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y * sizeMapping.workGroupSize.z; + uint32_t tiles0 = CeilDivide(M, MT0); + uint32_t tiles1 = CeilDivide(N, MT1); + uint32_t tiles = tiles0 * tiles1 * B; + uint32_t workGroupSize = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y + * sizeMapping.workGroupSize.z; uint32_t maxGsuValue = (std::numeric_limits::max() / workGroupSize) / tiles; - gsuVal = min(gsuVal, maxGsuValue); + gsuVal = min(gsuVal, maxGsuValue); // avoid gsu < 1 gsuVal = max(gsuVal, 1); @@ -1174,7 +1149,8 @@ namespace TensileLite // NB: get value from param= set in runtime / vs value from sizeMapping: from logic yaml. // param: default values: [xcc = 0, xccg = 0]. So when we never set xcc/xccg in runtime: we always get from sizeMapping. // From sizeMapping = from logic yaml. If not set in Config-Yaml, use default value [1, -1] - wgmxccg = param.wgmxccg() != 0 ? param.wgmxccg() : sizeMapping.workGroupMappingXCCGroup; + wgmxccg + = param.wgmxccg() != 0 ? 
param.wgmxccg() : sizeMapping.workGroupMappingXCCGroup; if(wgmxcc >= 1 && wgmxccg == -1) { AMDGPU const* pAMDGPU = dynamic_cast(hardware); @@ -1185,7 +1161,7 @@ namespace TensileLite } else if(internalArgsSupport.version == 2 && internalArgsSupport.useSFC) { - internalArg1 = wgm; + internalArg1 = wgm; } } @@ -1209,9 +1185,9 @@ namespace TensileLite uint32_t staggerU = mask8 & sizeMapping.staggerU; if(Debug::Instance().disableStaggerU()) staggerU = 0; - staggerU = staggerU | staggerUShift; - staggerU = staggerU | staggerUMapping; - internalArg0 = internalArg0 | (staggerU << 16); + staggerU = staggerU | staggerUShift; + staggerU = staggerU | staggerUMapping; + internalArg0 = internalArg0 | (staggerU << 16); } else if(T_Debug && Debug::Instance().disableStaggerU()) std::cout << "solution doesn't support configurable staggerU" << std::endl; @@ -1225,12 +1201,12 @@ namespace TensileLite } } - void ContractionSolution::calculateGrid(dim3& workGroupSize, - dim3& numWorkGroups, + void ContractionSolution::calculateGrid(dim3& workGroupSize, + dim3& numWorkGroups, ContractionSolution::Problem const& problem) const { workGroupSize.x = sizeMapping.workGroupSize.x * sizeMapping.workGroupSize.y - * sizeMapping.workGroupSize.z; + * sizeMapping.workGroupSize.z; workGroupSize.y = 1; workGroupSize.z = 1; @@ -1337,8 +1313,15 @@ namespace TensileLite std::cout << "AutoWGM: " << autoWGM << std::endl; std::cout << "AutoWGMXCC: " << autoWGMXCC << std::endl; } - kernelArgs( - 1, 0, rv.args, getNumWorkGroups(rv), &hardware, problem.getParams(), autoWGM, autoWGMXCC, autoGsuVal); + kernelArgs(1, + 0, + rv.args, + getNumWorkGroups(rv), + &hardware, + problem.getParams(), + autoWGM, + autoWGMXCC, + autoGsuVal); } singleCallArgs( problem, inputs, 0, &hardware, problemNumGroupTiles, rv.numWorkGroups, rv.args, sk); @@ -1412,7 +1395,8 @@ namespace TensileLite numWorkGroups.x = CeilDivide(numWorkGroups.x, sizeMapping.macroTile.x); numWorkGroups.y = CeilDivide(numWorkGroups.y, 
sizeMapping.macroTile.y); - uint32_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; + uint32_t gsu + = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; if(gsu > 0) numWorkGroups.y *= gsu; @@ -1472,7 +1456,7 @@ namespace TensileLite { for(int idx = 0; idx < problems.size(); idx++) { - auto problem = problems[idx]; + auto problem = problems[idx]; StreamKSettings sk; // Grouped gemm currently not supported in SK // But this code path is run to calculate to determine if solution is supported @@ -1750,10 +1734,10 @@ namespace TensileLite template void ContractionSolution::outputConversionCallArgs(ContractionSolution::Problem const& problem, ContractionInputs const& inputs, - uint32_t const& workspaceOffsetInByte, - KA& args, + uint32_t const& workspaceOffsetInByte, + KA& args, StreamKSettings const& sk, - uint32_t autoGsuVal) const + uint32_t autoGsuVal) const { TensorDescriptor const& c = problem.c(); TensorDescriptor const& d = problem.d(); @@ -1906,14 +1890,15 @@ namespace TensileLite args.template append(concatenate_if("size_", i), size); i++; } - uint32_t gsu = sizeMapping.globalAccumulation == 1 - ? 1 - : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); + uint32_t gsu + = sizeMapping.globalAccumulation == 1 + ? 1 + : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); if(sizeMapping.streamK > 0) { auto tiles = problem.getNumTiles(sizeMapping, 1); - gsu = sk.grid / tiles; + gsu = sk.grid / tiles; } args.template append(concatenate_if("gsu"), gsu); @@ -1962,16 +1947,17 @@ namespace TensileLite vw = 2; } - uint32_t gsu = sizeMapping.globalAccumulation == 1 - ? 1 - : (problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal); + uint32_t gsu + = sizeMapping.globalAccumulation == 1 + ? 1 + : (problem.getParams().gsu() > 0 ? 
problem.getParams().gsu() : autoGsuVal); if(sizeMapping.streamK > 0) { // If using post kernel with stream-k then it is doing parallel reduciton // Calculate the splitting factor auto tiles = problem.getNumTiles(sizeMapping, 1); - gsu = sk.grid / tiles; + gsu = sk.grid / tiles; } rv.kernelName = outputConversionKernelName(problem, inputs, vw, gsu); @@ -2126,10 +2112,10 @@ namespace TensileLite problems, vw, rv.workGroupSize, rv.numWorkGroups, rv.numWorkItems, h_args); uint32_t autoGsuVal = calculateAutoGSU(problems[0], &hardware); - uint32_t gsu - = sizeMapping.globalAccumulation == 1 - ? 1 - : (problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() : autoGsuVal); + uint32_t gsu = sizeMapping.globalAccumulation == 1 + ? 1 + : (problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() + : autoGsuVal); if constexpr(std::is_same::value) { @@ -2140,7 +2126,7 @@ namespace TensileLite = this->requiredHostWorkspaceSizePerProblem * problems.size(); for(int idx = 0; idx < problems.size(); idx++) { - auto problem = problems[idx]; + auto problem = problems[idx]; StreamKSettings sk; outputConversionCallArgs( problem, inputs.grouped[idx], workspaceOffsetInByte, h_args, sk, autoGsuVal); @@ -2580,7 +2566,7 @@ namespace TensileLite std::vector rv; auto autoGsuVal = calculateAutoGSU(problem, &hardware); - auto gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : autoGsuVal; + auto gsu = problem.getParams().gsu() > 0 ? 
problem.getParams().gsu() : autoGsuVal; if(gsu > 1 && sizeMapping.globalAccumulation != 2 && sizeMapping.globalAccumulation != 3) { if(debug) @@ -2592,11 +2578,13 @@ namespace TensileLite StreamKSettings sk; if(sizeMapping.streamK > 0) { - sk.reduction = getSKReduction(problem, hardware); - auto tiles = problem.getNumTiles(sizeMapping, 1); - sk.grid = getSKGrid(problem, hardware, tiles, sk.reduction); + sk.reduction = getSKReduction(problem, hardware); + auto tiles = problem.getNumTiles(sizeMapping, 1); + sk.grid = getSKGrid(problem, hardware, tiles, sk.reduction); const bool streamKDP = Debug::Instance().useStreamKDataParrallel(); - if(sk.grid > 0 && (sk.reduction == ReductionType::Parallel || (tiles % sk.grid != 0 && !streamKDP))) + if(sk.grid > 0 + && (sk.reduction == origami::reduction_t::parallel + || (tiles % sk.grid != 0 && !streamKDP))) { // Check ideal amount of workspace for optimal performance size_t idealWorkspace = partialTileSize(sk.grid); @@ -2604,8 +2592,8 @@ namespace TensileLite // Performance will likely be lower, but the kernel can run if workspace is unavailable if(idealWorkspace > problem.workspaceSize()) { - sk.reduction = ReductionType::Tree; - sk.grid = tiles; + sk.reduction = origami::reduction_t::tree; + sk.grid = tiles; } } } @@ -2615,7 +2603,8 @@ namespace TensileLite else rv.push_back(generateSingleCall(problem, inputs, hardware, sk)); - if(((sizeMapping.globalAccumulation != 3) && gsu > 1 && sizeMapping.globalAccumulation) || sk.reduction == ReductionType::Parallel) + if(((sizeMapping.globalAccumulation != 3) && gsu > 1 && sizeMapping.globalAccumulation) + || sk.reduction == origami::reduction_t::parallel) { if(debug) rv.push_back(generateOutputConversionCall(problem, inputs, sk, autoGsuVal)); @@ -2804,7 +2793,8 @@ namespace TensileLite rv.push_back( generateSingleCallGroupedGemm(problems, inputs, hardware, h_args, dUA)); - auto gsu = problems[0].getParams().gsu() > 0 ? 
problems[0].getParams().gsu() : calculateAutoGSU(problems[0], &hardware); + auto gsu = problems[0].getParams().gsu() > 0 ? problems[0].getParams().gsu() + : calculateAutoGSU(problems[0], &hardware); if((sizeMapping.globalAccumulation && gsu > 1) && (sizeMapping.globalAccumulation != 3)) { @@ -2992,12 +2982,12 @@ namespace TensileLite auto tiles = problem.getNumTiles(sizeMapping, 1); if(tiles > 0) // Grouped GEMM reports 0 tiles { - ReductionType reductionStrat = getSKReduction(problem, hardware); - size_t skGrid = getSKGrid(problem, hardware, tiles, reductionStrat); + auto reductionStrat = getSKReduction(problem, hardware); + size_t skGrid = getSKGrid(problem, hardware, tiles, reductionStrat); // Get space required for partial tiles= - if(reductionStrat == ReductionType::Parallel) + if(reductionStrat == origami::reduction_t::parallel) { - size_t splitk = skGrid / tiles; + size_t splitk = skGrid / tiles; size_t idealWorkspace = requiredWorkspaceSizeGsu(problem, hardware, splitk); if(idealWorkspace <= problem.workspaceSize()) size += idealWorkspace; @@ -3016,22 +3006,24 @@ namespace TensileLite else { // TODO: Pass GSU from problem and change value[2] to gsu if gsu != default value - size_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() : calculateAutoGSU(problem, &hardware); + size_t gsu = problem.getParams().gsu() > 0 ? problem.getParams().gsu() + : calculateAutoGSU(problem, &hardware); size += requiredWorkspaceSizeGsu(problem, hardware, gsu); } - return size; } - size_t ContractionSolution::requiredWorkspaceSizeGsu(Problem const& problem, Hardware const& hardware, size_t gsu) const + size_t ContractionSolution::requiredWorkspaceSizeGsu(Problem const& problem, + Hardware const& hardware, + size_t gsu) const { size_t size = 0; size_t gsuMultiplier = gsu > 1 ? 
gsu : 0; size_t batch = problem.d().sizes()[2]; size_t tiles = problem.getNumTiles(sizeMapping, gsu) * batch; - size_t tileSize = sizeMapping.macroTile.x * sizeMapping.macroTile.y - * sizeMapping.workspaceSizePerElemC; + size_t tileSize + = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; size_t bufSize = gsu > 1 ? tiles * tileSize : 0; size += bufSize; @@ -3040,19 +3032,15 @@ namespace TensileLite { if(problem.biasSrc() == ContractionProblemGemm::TENSOR::A) { - size += problem.freeSizeA(0) * sizeMapping.workspaceSizePerElemBias - * gsuMultiplier; + size += problem.freeSizeA(0) * sizeMapping.workspaceSizePerElemBias * gsuMultiplier; } else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::B) { - size += problem.freeSizeB(0) * sizeMapping.workspaceSizePerElemBias - * gsuMultiplier; + size += problem.freeSizeB(0) * sizeMapping.workspaceSizePerElemBias * gsuMultiplier; } - else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::D - && (gsuMultiplier == 0)) + else if(problem.biasSrc() == ContractionProblemGemm::TENSOR::D && (gsuMultiplier == 0)) { - size += problem.d().totalLogicalElements() * problem.computeTypeElementSize() - * gsu; + size += problem.d().totalLogicalElements() * problem.computeTypeElementSize() * gsu; } } @@ -3112,7 +3100,8 @@ namespace TensileLite return h_args.size(); } - size_t ContractionSolution::requiredSynchronizerSize(Problem const& problem, Hardware const& hardware) const + size_t ContractionSolution::requiredSynchronizerSize(Problem const& problem, + Hardware const& hardware) const { if(sizeMapping.globalAccumulation == 3) { @@ -3123,9 +3112,10 @@ namespace TensileLite return 0; } - ReductionType ContractionSolution::getSKReduction(Problem const& problem, Hardware const& hardware) const + origami::reduction_t ContractionSolution::getSKReduction(Problem const& problem, + Hardware const& hardware) const { - ReductionType reductionStrat = ReductionType::Tree; + auto reductionStrat = 
origami::reduction_t::tree; AMDGPU const* pAMDGPU = dynamic_cast(&hardware); assert(pAMDGPU != nullptr && pAMDGPU->computeUnitCount != 0); @@ -3133,13 +3123,13 @@ namespace TensileLite if(!sizeMapping.customKernelName.empty()) { // Custom kernel currently only supports single-kernel reduction - reductionStrat = ReductionType::Tree; + reductionStrat = origami::reduction_t::tree; } else if(pAMDGPU->skDynamicGrid > 0) { size_t x = 1; size_t y = 1; - size_t z = 1; + size_t z = 1; size_t batch = 1; for(size_t i = 0; i < problem.freeIndicesA().size(); i++) { @@ -3157,30 +3147,34 @@ namespace TensileLite { batch *= problem.batchSize(i); } - origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(&hardware); + origami::problem_t origami_problem = { + .size = {x, y, z}, + .batch = batch, + }; + origami::config_t origami_config = { + .mt = {static_cast(sizeMapping.macroTile.x), + static_cast(sizeMapping.macroTile.y), + static_cast(sizeMapping.depthU)}, + }; + reductionStrat = origami::streamk::select_reduction( - x, - y, - z, - batch, - sizeMapping.macroTile.x, - sizeMapping.macroTile.y, - sizeMapping.depthU, + origami_problem, *(hipAMDGPU->analyticalHardware), - pAMDGPU->skDynamicGrid); + origami_config, + static_cast(pAMDGPU->skDynamicGrid)); } return reductionStrat; } - size_t ContractionSolution::getSKGrid(Problem const& problem, - Hardware const& hardware, - size_t tiles, - ReductionType& reductionStrat) const + size_t ContractionSolution::getSKGrid(Problem const& problem, + Hardware const& hardware, + size_t tiles, + origami::reduction_t reductionStrat) const { - size_t skGrid = tiles; // Fallback + size_t skGrid = tiles; // Fallback const bool streamKDP = Debug::Instance().useStreamKDataParrallel(); if(streamKDP) skGrid = tiles; @@ -3204,7 +3198,7 @@ namespace TensileLite { skGrid = pAMDGPU->skFixedGrid; } - else if (pAMDGPU->skDynamicGrid > 0) + else if(pAMDGPU->skDynamicGrid > 0) { 
size_t x = 1; size_t y = 1; @@ -3221,33 +3215,34 @@ namespace TensileLite { batch *= problem.batchSize(i); } - origami::data_type_t miDataType = datatypeToAnalyticalDatatype(problem.computeInputType()); hip::HipAMDGPU const* hipAMDGPU = dynamic_cast(&hardware); - skGrid = origami::streamk::select_grid(x, - y, - z, - batch, - problem.transA(), - problem.transB(), - problem.a().elementBytes() * 8, - problem.b().elementBytes() * 8, - problem.c().elementBytes() * 8, - miDataType, - problem.workspaceSize(), - sizeMapping.macroTile.x, - sizeMapping.macroTile.y, - sizeMapping.depthU, - sizeMapping.matrixInstruction[0], - sizeMapping.matrixInstruction[1], - sizeMapping.matrixInstruction[2], - sizeMapping.workGroupMapping, - sizeMapping.workspaceSizePerElemC, - sizeMapping.CUOccupancy, - *(hipAMDGPU->analyticalHardware), - pAMDGPU->skDynamicGrid, - reductionStrat, - pAMDGPU->skMaxCUs); + origami::problem_t origami_problem = { + .size = {x, y, z}, + .batch = batch, + .a_dtype = datatypeToAnalyticalDatatype(problem.alphaType()), + .b_dtype = datatypeToAnalyticalDatatype(problem.betaType()), + .mi_dtype = datatypeToAnalyticalDatatype(problem.computeInputType()), + }; + origami::config_t origami_config = { + .mt = {static_cast(sizeMapping.macroTile.x), + static_cast(sizeMapping.macroTile.y), + static_cast(sizeMapping.depthU)}, + .mi = {static_cast(sizeMapping.matrixInstruction[0]), + static_cast(sizeMapping.matrixInstruction[1]), + static_cast(sizeMapping.matrixInstruction[2])}, + .occupancy = std::max(sizeMapping.CUOccupancy, static_cast(1)), + .workgroup_mapping = sizeMapping.workGroupMapping, + .workspace_size = problem.workspaceSize(), + .workspace_size_per_elem_c = sizeMapping.workspaceSizePerElemC, + .reduction_strategy = reductionStrat, + }; + skGrid = origami::streamk::select_grid_size( + origami_problem, + *(hipAMDGPU->analyticalHardware), + origami_config, + static_cast(pAMDGPU->skDynamicGrid), + pAMDGPU->skMaxCUs); } // Limit the CUs Stream-K is launched on either 
max or the specified, // whichever is minimum. @@ -3272,11 +3267,11 @@ namespace TensileLite // For tree-reduction there are some limits for divisions to avoid overflow // If we hit one of the limits, fallback to DP size_t itersPerTile = problem.getItersPerTile(sizeMapping); - size_t itersPerWG = tiles * itersPerTile / skGrid; - if(itersPerTile >=65536 || itersPerWG >= 65536 || (tiles * itersPerTile) >= 16777216) + size_t itersPerWG = tiles * itersPerTile / skGrid; + if(itersPerTile >= 65536 || itersPerWG >= 65536 || (tiles * itersPerTile) >= 16777216) { - reductionStrat = ReductionType::Tree; - skGrid = tiles; + reductionStrat = origami::reduction_t::tree; + skGrid = tiles; } return skGrid; @@ -3286,7 +3281,8 @@ namespace TensileLite { size_t size = 0; - size_t tileSize = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; + size_t tileSize + = sizeMapping.macroTile.x * sizeMapping.macroTile.y * sizeMapping.workspaceSizePerElemC; size += tileSize * skGrid; // Partials tile per WG // TODO batches // TODO round up for alignment? 
@@ -3299,8 +3295,13 @@ namespace TensileLite return x / ceil(x); } - ContractionSolution::Granularities ContractionSolution::computeGranularities( - Hardware const& hardware, double M, double N, double K, double NumBatches, uint32_t autoGsuVal) const + ContractionSolution::Granularities + ContractionSolution::computeGranularities(Hardware const& hardware, + double M, + double N, + double K, + double NumBatches, + uint32_t autoGsuVal) const { ContractionSolution::Granularities granularities; @@ -3404,7 +3405,8 @@ namespace TensileLite } double K = problem.boundSize(0); // TODO - fix for multiple summations - pp.granularities = ContractionSolution::computeGranularities(hardware, M, N, K, NumBatches, calculateAutoGSU(problem, &hardware)); + pp.granularities = ContractionSolution::computeGranularities( + hardware, M, N, K, NumBatches, calculateAutoGSU(problem, &hardware)); auto it = ideals.begin(); diff --git a/shared/origami/.clang-format b/shared/origami/.clang-format new file mode 100644 index 000000000000..1158e834f77f --- /dev/null +++ b/shared/origami/.clang-format @@ -0,0 +1,34 @@ +BasedOnStyle: Google +BinPackArguments: false +BinPackParameters: false +ColumnLimit: 100 +IndentWidth: 2 +BreakConstructorInitializers: BeforeComma +IncludeBlocks: Preserve + + +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: true +AlignOperands: true +AlignTrailingComments: true + +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine : false +AllowShortBlocksOnASingleLine : true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine : All +AllowShortIfStatementsOnASingleLine: true +AllowShortLambdasOnASingleLine : All +AllowShortLoopsOnASingleLine: true +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakBeforeMultilineStrings : false +AlwaysBreakTemplateDeclarations: Yes + +BreakBeforeBinaryOperators: None +Cpp11BracedListStyle: true +IndentWrappedFunctionNames : false +KeepEmptyLinesAtTheStartOfBlocks : false 
+PointerAlignment: Left +ReflowComments : true +ExperimentalAutoDetectBinPacking: false +BreakBeforeBraces: Attach diff --git a/shared/origami/.gitignore b/shared/origami/.gitignore new file mode 100644 index 000000000000..050a6c80a574 --- /dev/null +++ b/shared/origami/.gitignore @@ -0,0 +1,20 @@ +# Build directory +build/ + +# CMake cache +CMakeCache.txt + +# IDE files +.vscode/ +.vs/ + +# Python bindings build +__pycache__/ +*.pyc +*.pyo +*.so + +# Temporary files +*.tmp +*.swp +*~ diff --git a/shared/origami/CMakeLists.txt b/shared/origami/CMakeLists.txt index 6779f95a6214..9a420a4ea866 100644 --- a/shared/origami/CMakeLists.txt +++ b/shared/origami/CMakeLists.txt @@ -1,5 +1,27 @@ -# Copyright Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ cmake_minimum_required(VERSION 3.24.4) @@ -24,11 +46,12 @@ option(ORIGAMI_BUILD_SHARED_LIBS "Build shared libraries." ${ORIGAMI_STANDALONE} option(ORIGAMI_ENABLE_PYTHON "Enable Python bindings." OFF) option(ORIGAMI_BUILD_TESTING "Enable Python binding tests." OFF) option(ORIGAMI_ENABLE_INSTALL "Configure origami installation" ON) +option(ORIGAMI_ENABLE_FETCH "Auto-fetch dependencies with FetchContent" ON) find_package(hip REQUIRED) if(ORIGAMI_BUILD_SHARED_LIBS OR (BUILD_SHARED_LIBS AND ORIGAMI_STANDALONE)) - add_library(origami SHARED) + add_library(origami SHARED) else() add_library(origami STATIC) endif() @@ -37,34 +60,39 @@ rocm_set_soversion(origami "${PROJECT_VERSION_MAJOR}.${PROJECT_VERSION_MINOR}") add_library(roc::origami ALIAS origami) -set_target_properties(origami PROPERTIES - POSITION_INDEPENDENT_CODE ON -) +set_target_properties(origami PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_features(origami PUBLIC cxx_std_17) add_library(origami-headers INTERFACE) add_library(roc::origami-headers ALIAS origami-headers) -target_include_directories(origami-headers - INTERFACE - $ - $ +target_compile_features(origami-headers INTERFACE cxx_std_17) + +target_include_directories( + origami-headers INTERFACE $ + $ ) -target_sources(origami-headers - INTERFACE - $ - $ - $ - $ +target_sources( + origami-headers + INTERFACE $ + $ + $ + $ + $ + $ + $ ) target_link_libraries(origami PUBLIC roc::origami-headers) -target_sources(origami - PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/gemm.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/utils.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/streamk.cpp" 
+target_sources( + origami + PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/gemm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/hardware.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/log.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/origami.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/streamk.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/src/origami/types.cpp" ) target_link_libraries(origami PUBLIC hip::host) @@ -75,32 +103,30 @@ endif() if(ORIGAMI_BUILD_TESTING OR BUILD_TESTING) enable_testing() - + add_subdirectory(tests) if(ORIGAMI_ENABLE_PYTHON) find_package(Python3 REQUIRED COMPONENTS Interpreter Development) - add_test( - NAME origami_python_test - COMMAND "${Python3_EXECUTABLE}" "origami_test.py" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" + add_test(NAME origami_python_test COMMAND "${Python3_EXECUTABLE}" "origami_test.py" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" ) - - add_test( - NAME origami_python_grid_test - COMMAND "${Python3_EXECUTABLE}" "origami_grid_test.py" - WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" + + add_test(NAME origami_python_grid_test COMMAND "${Python3_EXECUTABLE}" + "origami_grid_test.py" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/python" ) - set_tests_properties(origami_python_test origami_python_grid_test + set_tests_properties( + origami_python_test origami_python_grid_test PROPERTIES - ENVIRONMENT "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/python:${CMAKE_CURRENT_SOURCE_DIR}/python:$ENV{PYTHONPATH}" + ENVIRONMENT + "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/python:${CMAKE_CURRENT_SOURCE_DIR}/python:$ENV{PYTHONPATH}" ) - set_tests_properties(origami_python_test origami_python_grid_test - PROPERTIES - DEPENDS origami_python + set_tests_properties( + origami_python_test origami_python_grid_test PROPERTIES DEPENDS origami_python ) endif() endif() @@ -109,43 +135,49 @@ if(ORIGAMI_ENABLE_INSTALL OR ORIGAMI_STANDALONE) rocm_install(TARGETS origami origami-headers) rocm_export_targets( - TARGETS 
roc::origami roc::origami-headers - DEPENDS PACKAGE hip - NAMESPACE roc:: + TARGETS + roc::origami + roc::origami-headers + DEPENDS + PACKAGE + hip + NAMESPACE + roc:: ) if(ORIGAMI_BUILD_TESTING OR BUILD_TESTING) - rocm_install(TARGETS origami-tests - COMPONENT tests - ) - - rocm_install(FILES "${CMAKE_CURRENT_SOURCE_DIR}/tests/origami_gtest.yaml" - DESTINATION "${CMAKE_INSTALL_BINDIR}" - COMPONENT tests - ) + rocm_install(TARGETS origami-tests COMPONENT tests) + endif() rocm_install( - DIRECTORY include/ - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" - COMPONENT devel - FILES_MATCHING PATTERN "*.hpp" PATTERN "*.h" + DIRECTORY + include/ + DESTINATION + "${CMAKE_INSTALL_INCLUDEDIR}" + COMPONENT + devel + FILES_MATCHING + PATTERN + "*.hpp" + PATTERN + "*.h" ) configure_file( "${CMAKE_CURRENT_SOURCE_DIR}/cmake/origami-config.cmake.in" - "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" - @ONLY + "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" @ONLY ) rocm_install( - FILES "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" - DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/origami" - COMPONENT devel + FILES "${CMAKE_CURRENT_BINARY_DIR}/origami-config.cmake" DESTINATION + "${CMAKE_INSTALL_LIBDIR}/cmake/origami" COMPONENT devel ) set(BUILD_SHARED_LIBS ${ORIGAMI_BUILD_SHARED_LIBS}) - set(ORIGAMI_CONFIG_DIR "\${CPACK_PACKAGING_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" CACHE PATH "Path placed into ldconfig file") + set(ORIGAMI_CONFIG_DIR "\${CPACK_PACKAGING_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" + CACHE PATH "Path placed into ldconfig file" + ) rocm_create_package( NAME origami diff --git a/shared/origami/include/origami/gemm.hpp b/shared/origami/include/origami/gemm.hpp index 18d427ea69dd..607b34fb3f45 100644 --- a/shared/origami/include/origami/gemm.hpp +++ b/shared/origami/include/origami/gemm.hpp @@ -1,245 +1,197 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ #pragma once -#include "origami/hardware.hpp" #include - -namespace origami -{ - // Placeholder for compute_reuse_in_block_gemm function. - // TODO move over L2 hit rate simulation for tie-breaking. 
- double compute_reuse_in_block_gemm(size_t grid_m, - size_t grid_n, - size_t grid_k, - size_t A_size, - size_t B_size, - size_t C_size, - size_t nproc, - size_t capacity, - const std::vector& radix, - bool print_radix, - bool print_output, - size_t max_timesteps, - size_t max_iters); - - // Compute - std::tuple compute_CU_occupancy(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - int WGM, - size_t workspace_size, - size_t workspace_size_per_elem_c, - int occupancy, - int dynamic_grid_version, - size_t split, - size_t max_cus = 0); - - /* ---------------------------------------------------------------------------------------- */ - /* Compute-related functions */ - /* ---------------------------------------------------------------------------------------- */ - // Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. - size_t compute_number_matrix_instructions(const hardware_t& hardware, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K); - - // Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). 
- size_t compute_mt_compute_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - data_type_t mi_datatype); - - /* ---------------------------------------------------------------------------------------- */ - /* Memory-related functions */ - /* ---------------------------------------------------------------------------------------- */ - // Check if MT fits in LDS - bool check_lds_capacity( - const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, size_t element_size_out); - - // Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. - size_t compute_A_loads(size_t MT_M, size_t MT_K); - - // Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. - size_t compute_B_loads(size_t MT_N, size_t MT_K); - - // Computes total data loads per CU per MT from A and B - // Reads happen every MT, Writes happen every K-complete tile. - size_t compute_cu_loads(size_t MT_M, size_t MT_N, size_t MT_K); - - // Estimates the l2 hit-rate - double estimate_l2_hit(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size, - int WGM, - size_t splittingFactor); - - // Estimates the mall hit-rate - double estimate_mall_hit(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor); - - // Determine the memory latency per MT_MxMT_NxMT_K Macro Tile (L_MT). 
- double compute_memory_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t mx_block_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor); - - /* ---------------------------------------------------------------------------------------- */ - /* Tile-related functions */ - /* ---------------------------------------------------------------------------------------- */ - // Computes the latency to compute a K-COMPLETE tile. - double compute_tile_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int occupancy, - size_t numActiveCUs, - size_t splittingFactor); - - // Computes the latency per K-complete MT wave. 
- // A wave is defined as : The time it takes for one CU to complete one K-complete output tile - double compute_wave_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int occupancy, - size_t numActiveCUs, - size_t splittingFactor); - - // Compute the total latency of a gemm based on the latency of one wave multiplied by the number of waves - // A wave is defined as : The time it takes for one CU to complete one K-complete output tile - double compute_total_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int non_temporal_a = 0, - int non_temporal_b = 0, - int occupancy = 1, - size_t split = 0, - size_t max_cus = 0); - - // Compute the performance from the latency. - // IMPORTANT : This program is NOT meant to be an analytical model for performance, but rather a way to rank different macro tile sizes. - // These performance values could be wildly inaccurate in absolute terms, but will often result in the correct ranking of MTin relative terms. 
- double compute_perf_gflops(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - int WGM, - size_t max_cus = 0); -} // namespace origami +#include "origami/hardware.hpp" +#include "origami/types.hpp" + +namespace origami { + +/** + * @brief Compute the number of matrix instructions required to compute a single MT_M x MT_N x MT_K + * tile. + * + * @param mt Macro tile dimensions + * @param mi Matrix instruction dimensions + * @return size_t Number of matrix instructions + */ +size_t compute_number_matrix_instructions(dim3_t mt, dim3_t mi); + +/** + * @brief Compute TF32 conversion overhead. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @return double Latency in cycles. + */ +static inline double compute_cvt_overhead(const problem_t& problem, + const hardware_t& hardware, + const config_t& config); +/** + * @brief Compute the latency to process a single macro-tile for the given problem and hardware. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @return size_t Latency in cycles. 
+ */ +size_t compute_mt_compute_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config); + +/** + * @brief Check if MT fits in LDS + * + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param mt Macro tile dimensions + * @param a_dtype Data type of operand A + * @param b_dtype Data type of operand B + * @return bool True if MT fits in LDS, false otherwise + */ +bool check_lds_capacity(const hardware_t& hardware, + dim3_t mt, + data_type_t a_dtype, + data_type_t b_dtype); + +/** + * @brief Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. + * + * @param MT_M Macro tile dimension M + * @param MT_K Macro tile dimension K + * @return size_t Amount of data loaded from A + */ +size_t compute_A_loads(size_t MT_M, size_t MT_K); + +/** + * @brief Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. + * + * @param MT_N Macro tile dimension N + * @param MT_K Macro tile dimension K + * @return size_t Amount of data loaded from B + */ +size_t compute_B_loads(size_t MT_N, size_t MT_K); + +/** + * @brief A linear-estimation method for estimating L2-hitrate. + * + * @todo Parameterize this based on the space-filling curve algos. + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param splitting_factor + * @return double Predicted L2-hitrate. + */ +double estimate_l2_hit(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + std::size_t splitting_factor); + +/** + * @brief Estimate the MALL-hitrate (last-level cache.) + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param num_active_cus + * @param splitting_factor + * @return double Predicted MALL-hitrate. 
+ */ +double estimate_mall_hit(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + std::size_t num_active_cus, + std::size_t splitting_factor); + +/** + * @brief Determine the memory latency per MT_M x MT_N x MT_K Macro Tile (L_MT). + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param num_active_cus + * @param splitting_factor + * @return double Latency in cycles. + */ +double compute_memory_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + std::size_t num_active_cus, + std::size_t splitting_factor); + +/** + * @brief Computes the latency to compute a K-COMPLETE tile. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param num_active_cus + * @param splitting_factor + * @return double Latency in cycles. + */ +double compute_tile_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + std::size_t num_active_cus, + std::size_t splitting_factor); + +/** + * @brief Computes the latency per K-complete MT wave. + * A wave is defined as the time it takes for one CU to complete one + * K-complete output tile + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param num_active_cus + * @param splitting_factor + * @return double Latency in cycles. 
+ */ +double compute_timestep_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + std::size_t num_active_cus, + std::size_t splitting_factor); + +/** + * @brief Compute the total latency of a gemm based on the latency of one wave multiplied by the + * number of waves. A wave is defined as the time it takes for one CU to complete one K-complete + * output tile. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param max_cus + * @return double Latency in cycles. + */ +double compute_total_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t max_cus); + +} // namespace origami diff --git a/shared/origami/include/origami/hardware.hpp b/shared/origami/include/origami/hardware.hpp index f5dd7edb4ca2..0fe73d06cfeb 100644 --- a/shared/origami/include/origami/hardware.hpp +++ b/shared/origami/include/origami/hardware.hpp @@ -1,763 +1,252 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ #pragma once -#include +#include #include +#include #include #include #include -namespace origami -{ - enum class data_type_t : int - { - Float, - Double, - ComplexFloat, - ComplexDouble, - Half, - Int8x4, - Int32, - BFloat16, - Int8, - Int4, - Int64, - XFloat32, - Float8_fnuz, - BFloat8_fnuz, - Float8BFloat8_fnuz, - BFloat8Float8_fnuz, - Float8, - BFloat8, - Float8BFloat8, - BFloat8Float8, - Float6, - BFloat6, - Float4, - Count, - None = Count - }; - - inline data_type_t int_to_data_type(int dt) - { - return (data_type_t)dt; - } - - inline int data_type_to_bits(data_type_t type) - { - switch(type) - { - case data_type_t::Float: - return 32; - case data_type_t::Double: - return 64; - case data_type_t::ComplexFloat: - return 64; - case data_type_t::ComplexDouble: - return 128; - case data_type_t::Half: - return 16; - case data_type_t::Int8x4: - return 32; - case data_type_t::Int32: - return 32; - case data_type_t::BFloat16: - return 16; - case data_type_t::Int8: - return 8; - case data_type_t::Int4: - return 4; - case data_type_t::Int64: - return 64; - case data_type_t::XFloat32: - return 32; - case data_type_t::Float8_fnuz: - return 8; - case data_type_t::BFloat8_fnuz: - return 8; - case data_type_t::Float8BFloat8_fnuz: - return 8; - case data_type_t::BFloat8Float8_fnuz: - return 8; - case data_type_t::Float8: - return 8; - case data_type_t::BFloat8: - return 8; - case data_type_t::Float8BFloat8: 
- return 8; - case data_type_t::BFloat8Float8: - return 8; - case data_type_t::Float6: - return 6; - case data_type_t::BFloat6: - return 6; - case data_type_t::Float4: - return 4; - default: - return -1; // Invalid type - } +#include "origami/types.hpp" + +namespace origami { +/** + * @brief Represents hardware characteristics and capabilities of GPU architectures. + * + */ +class hardware_t { + public: + /** + * @brief Enumeration of supported GPU architectures. + * + */ + enum class architecture_t { gfx90a, gfx942, gfx950, gfx1201, gfx1100, gfx1151, Count }; + + /** + * @brief Convert architecture name string to architecture_t enum. + * + * @param str Architecture name as string (e.g., "gfx90a", "gfx942") + * @return architecture_t Corresponding enum value, or Count if not recognized + */ + static constexpr architecture_t arch_name_to_enum(std::string_view str) noexcept { + if (str == "gfx90a") return architecture_t::gfx90a; + if (str == "gfx942") return architecture_t::gfx942; + if (str == "gfx950") return architecture_t::gfx950; + if (str == "gfx1201") return architecture_t::gfx1201; + if (str == "gfx1100") return architecture_t::gfx1100; + if (str == "gfx1151") return architecture_t::gfx1151; + return architecture_t::Count; + } + + /** + * @brief Architecture-specific constants for memory and compute characteristics. 
+ * + */ + struct architecture_constants { + size_t num_xcds; ///< Number of XCDs (XCD = Accelerator Complex Die) + double mem1_perf_ratio; + double mem2_perf_ratio; + double mem3_perf_ratio; + size_t parallel_mi_cu; ///< Number of parallel matrix instructions per compute unit + std::tuple + mem_bw_per_wg_coefficients; ///< Memory bandwidth coefficients per workgroup + double mem_clock_ratio; ///< Memory clock ratio relative to compute clock + + constexpr architecture_constants(size_t num_xcds, + double mem1_perf_ratio, + double mem2_perf_ratio, + double mem3_perf_ratio, + size_t parallel_mi_cu, + std::tuple mem_bw_per_wg_coefficients, + double mem_clock_ratio) // Obtained through microbenchmarking + : num_xcds(num_xcds) + , mem1_perf_ratio(mem1_perf_ratio) + , mem2_perf_ratio(mem2_perf_ratio) + , mem3_perf_ratio(mem3_perf_ratio) + , parallel_mi_cu(parallel_mi_cu) + , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) + , mem_clock_ratio(mem_clock_ratio) {} + }; + + /** + * @brief Get architecture-specific constants for a given architecture. + * + * Returns the pre-configured constants (memory performance ratios, bandwidth + * coefficients, etc.) for the specified architecture. These values are + * determined through microbenchmarking. 
+ * + * @param arch Architecture enum value + * @return architecture_constants Constants for the specified architecture + */ + static constexpr architecture_constants get_arch_constants(architecture_t arch) { + switch (arch) { + case architecture_t::gfx90a: + return {1, 5.5, 1.21875121875121875122 * 1.2, 1.2, 4, std::make_tuple(0, 0.03, 0), 1.5}; + case architecture_t::gfx942: + return {8, 17, 1.21875121875121875122 * 6, 4, 4, std::make_tuple(0, 0.015, 0), 1.5}; + case architecture_t::gfx950: + return {8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(0, 0.008, 0), 1.5}; + case architecture_t::gfx1201: + return {1, 5.74, 1.21875121875121875122 * 2.41, 0.464, 2, std::make_tuple(0, 0.17, 0), 1.5}; + case architecture_t::gfx1100: + return {1, 7.12, 1.21875121875121875122 * 3.48, 0.732, 2, std::make_tuple(0, 0.11, 0), 1.5}; + case architecture_t::gfx1151: + return {1, 2.47, 1.21875121875121875122 * 0.93, 0.215, 2, std::make_tuple(0, 0.22, 0), 1.5}; + default: return {0, 0, 0, 0, 0, std::make_tuple(0, 0, 0), 0}; } - - inline std::string to_string(data_type_t type) - { - switch(type) - { - case data_type_t::Float: - return "Float"; - case data_type_t::Double: - return "Double"; - case data_type_t::ComplexFloat: - return "ComplexFloat"; - case data_type_t::ComplexDouble: - return "ComplexDouble"; - case data_type_t::Half: - return "Half"; - case data_type_t::Int8x4: - return "Int8x4"; - case data_type_t::Int32: - return "Int32"; - case data_type_t::BFloat16: - return "BFloat16"; - case data_type_t::Int8: - return "Int8"; - case data_type_t::Int4: - return "Int4"; - case data_type_t::Int64: - return "Int64"; - case data_type_t::XFloat32: - return "XFloat32"; - case data_type_t::Float8_fnuz: - return "Float8_fnuz"; - case data_type_t::BFloat8_fnuz: - return "BFloat8_fnuz"; - case data_type_t::Float8BFloat8_fnuz: - return "Float8BFloat8_fnuz"; - case data_type_t::BFloat8Float8_fnuz: - return "BFloat8Float8_fnuz"; - case data_type_t::Float8: - return "Float8"; - case 
data_type_t::BFloat8: - return "BFloat8"; - case data_type_t::Float8BFloat8: - return "Float8BFloat8"; - case data_type_t::BFloat8Float8: - return "BFloat8Float8"; - case data_type_t::Float6: - return "Float6"; - case data_type_t::BFloat6: - return "BFloat6"; - case data_type_t::Float4: - return "Float4"; - default: - return "Invalid"; - } - return "Invalid"; - } - - inline data_type_t string_to_data_type(std::string s) - { - if (s == "f32") - return data_type_t::Float; - if (s == "c32") - return data_type_t::ComplexFloat; - if (s == "c64") - return data_type_t::ComplexDouble; - if (s == "f64") - return data_type_t::Double; - if (s == "f16") - return data_type_t::Half; - if (s == "i32") - return data_type_t::Int32; - if (s == "bf16") - return data_type_t::BFloat16; - if (s == "i8") - return data_type_t::Int8; - if (s == "i4") - return data_type_t::Int4; - if (s == "xf32") - return data_type_t::XFloat32; - if (s == "f8") - return data_type_t::Float8; - if (s == "bf8") - return data_type_t::BFloat8; - if (s == "f6") - return data_type_t::Float6; - if (s == "bf6") - return data_type_t::BFloat6; - if (s == "f4") - return data_type_t::Float4; - return data_type_t::None; - } - - struct matrix_instruction - { - size_t MI_M; - size_t MI_N; - size_t MI_K; - data_type_t mi_input_type; - - matrix_instruction() - : MI_M(0) - , MI_N(0) - , MI_K(0) - , mi_input_type(data_type_t::Float) - { - } - - matrix_instruction(size_t m, size_t n, size_t k, data_type_t mi_input_type) - : MI_M(m) - , MI_N(n) - , MI_K(k) - , mi_input_type(mi_input_type) - { - } - - matrix_instruction(const matrix_instruction& other) - : MI_M(other.MI_M) - , MI_N(other.MI_N) - , MI_K(other.MI_K) - , mi_input_type(other.mi_input_type) - { - } - - bool operator<(const matrix_instruction& other) const - { - return std::tie(MI_M, MI_N, MI_K, mi_input_type) - < std::tie(other.MI_M, other.MI_N, other.MI_K, other.mi_input_type); - } - - bool operator==(const matrix_instruction& other) const - { - return MI_M == 
other.MI_M && MI_N == other.MI_N && MI_K == other.MI_K - && mi_input_type == other.mi_input_type; - } - - std::size_t hash() const - { - return std::hash()(MI_M) ^ std::hash()(MI_N) - ^ std::hash()(MI_K) ^ std::hash()(mi_input_type); - } - }; -} - -// Specialize std::hash for the matrix_instruction struct to use it as an unordered_map key. -namespace std -{ - template <> - struct hash - { - std::size_t operator()(const origami::matrix_instruction& k) const - { - return k.hash(); - } - }; -} - -namespace origami -{ - class hardware_t - { - public: - enum class architecture_t - { - gfx90a, - gfx942, - gfx950, - gfx1201, - gfx1100, - gfx1151, - Count - }; - - static architecture_t arch_name_to_enum(const std::string& str) - { - static const std::unordered_map str_to_enum_map - = {{"gfx90a", architecture_t::gfx90a}, - {"gfx942", architecture_t::gfx942}, - {"gfx950", architecture_t::gfx950}, - {"gfx1201", architecture_t::gfx1201}, - {"gfx1100", architecture_t::gfx1100}, - {"gfx1151", architecture_t::gfx1151}}; - - auto it = str_to_enum_map.find(str); - if(it != str_to_enum_map.end()) - { - return it->second; - } - else - { - return architecture_t::Count; - } - } - - struct architecture_constants - { - size_t num_xcds; - double mem1_perf_ratio; - double mem2_perf_ratio; - double mem3_perf_ratio; - size_t parallel_mi_cu; - std::tuple mem_bw_per_wg_coefficients; - double mem_clock_ratio; - - architecture_constants(size_t num_xcds, - double mem1_perf_ratio, - double mem2_perf_ratio, - double mem3_perf_ratio, - size_t parallel_mi_cu, - std::tuple mem_bw_per_wg_coefficients, - double mem_clock_ratio) //Obtained through microbenchmarking - : num_xcds(num_xcds) - , mem1_perf_ratio(mem1_perf_ratio) - , mem2_perf_ratio(mem2_perf_ratio) - , mem3_perf_ratio(mem3_perf_ratio) - , parallel_mi_cu(parallel_mi_cu) - , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) - , mem_clock_ratio(mem_clock_ratio) - { - } - }; - - inline static const std::unordered_map ARCH_CONSTANT_MAP - = 
{{hardware_t::architecture_t::gfx90a, - hardware_t::architecture_constants( - 1, 5.5, 1.21875121875121875122 * 1.2, 1.2, 4, std::make_tuple(0, 0.03, 0), 1.5)}, - {hardware_t::architecture_t::gfx942, - hardware_t::architecture_constants( - 8, 17, 1.21875121875121875122 * 6, 4, 4, std::make_tuple(0, 0.015, 0), 1.5)}, - {hardware_t::architecture_t::gfx950, - // hardware_t::architecture_constants( - // 8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(-0.000013, 0.007070, 0.027355), 1.5)}}; - hardware_t::architecture_constants( - 8, 17, 1.21875121875121875122 * 7, 6, 4, std::make_tuple(0, 0.008, 0), 1.5)}, - {hardware_t::architecture_t::gfx1201, - hardware_t::architecture_constants( - 1, 5.74, 1.21875121875121875122 * 2.41, 0.464, 2, std::make_tuple(0, 0.17, 0), 1.5)}, - {hardware_t::architecture_t::gfx1100, - hardware_t::architecture_constants( - 1, 7.12, 1.21875121875121875122 * 3.48, 0.732, 2, std::make_tuple(0, 0.11, 0), 1.5)}, - {hardware_t::architecture_t::gfx1151, - hardware_t::architecture_constants( - 1, 2.47, 1.21875121875121875122 * 0.93, 0.215, 2, std::make_tuple(0, 0.22, 0), 1.5)}}; - - inline static const std::unordered_map> INSTRUCTION_MAP - = {{hardware_t::architecture_t::gfx90a, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - // TODO ComplexFloat - // TODO ComplexDouble - // F16 - {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 - 
{matrix_instruction(32, 32, 8, data_type_t::Half), 64}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 - // BF16 - {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 - // I8 - {matrix_instruction(32, 32, 8, data_type_t::Int8), 64}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - // XF32 - {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 - {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 - {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - }}, - {hardware_t::architecture_t::gfx942, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // 
v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - // TODO ComplexFloat - // TODO ComplexDouble - // F16 - {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 - {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 - // BF16 - {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 - // F8 - {matrix_instruction(32, 32, 16, data_type_t::Float8_fnuz), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(16, 16, 32, data_type_t::Float8_fnuz), 16}, // v_mfma_f32_16x16x32_f8 - // BF8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8 - // F8B8 - {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_f8_bf8 - {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_f8_bf8 - // B8F8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8_f8 - {matrix_instruction(16, 16, 32, 
data_type_t::BFloat8Float8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8_f8 - // I8 - {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - // XF32 - {matrix_instruction(32, 32, 4, data_type_t::XFloat32), 32}, // v_mfma_f32_32x32x4_xf32 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 16}, // v_mfma_f32_16x16x8_xf32 - }}, - {hardware_t::architecture_t::gfx950, - { - // F32 - {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 - {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 - {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 - {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 - {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 - // F64 - {matrix_instruction(16, 16, 4, data_type_t::Double), 64}, // v_mfma_f64_16x16x4_f64 - {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 - // TODO ComplexFloat - // TODO ComplexDouble - // F16 - {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 - {matrix_instruction(32, 32, 16, data_type_t::Half), 32}, // v_mfma_f32_32x32x16_f16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 - {matrix_instruction(16, 16, 32, data_type_t::Half), 16}, // v_mfma_f32_16x16x32_f16 - // BF16 - {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 - {matrix_instruction(32, 32, 16, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x16_bf16 - {matrix_instruction(16, 16, 16, 
data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - {matrix_instruction(16, 16, 32, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 - // F8 - {matrix_instruction(32, 32, 64, data_type_t::Float8), 64}, // v_mfma_f32_32x32x64_f8 - {matrix_instruction(32, 32, 16, data_type_t::Float8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(16, 16, 128, data_type_t::Float8), 32}, // v_mfma_f32_16x16x128_f8 - {matrix_instruction(16, 16, 32, data_type_t::Float8), 16}, // v_mfma_f32_16x16x32_f8 - // BF8 - {matrix_instruction(32, 32, 64, data_type_t::BFloat8), 64}, // v_mfma_f32_32x32x64_bf8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8), 32}, // v_mfma_f32_32x32x16_bf8 - {matrix_instruction(16, 16, 128, data_type_t::BFloat8), 32}, // v_mfma_f32_16x16x128_bf8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8), 16}, // v_mfma_f32_16x16x32_bf8 - // F8B8 - {matrix_instruction(32, 32, 64, data_type_t::Float8BFloat8), 64}, // v_mfma_f32_32x32x64_f8_bf8 - {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_32x32x16_f8_bf8 - {matrix_instruction(16, 16, 128, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_16x16x128_f8_bf8 - {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8), 16}, // v_mfma_f32_16x16x32_f8_bf8 - // B8F8 - {matrix_instruction(32, 32, 64, data_type_t::BFloat8Float8), 64}, // v_mfma_f32_32x32x64_bf8_f8 - {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_32x32x16_bf8_f8 - {matrix_instruction(16, 16, 128, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_16x16x128_bf8_f8 - {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8), 16}, // v_mfma_f32_16x16x32_bf8_f8 - // I8 - {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 - {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 - {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_f32_16x16x32_i8 - {matrix_instruction(16, 16, 4, data_type_t::Int8), 
32}, // v_mfma_i32_16x16x4_4b_i8 - {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 - // XF32 - {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 - {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 - {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 - // F6 - {matrix_instruction(32, 32, 64, data_type_t::Float6), 32}, // v_mfma_f32_32x32x64_f6 - {matrix_instruction(16, 16, 128, data_type_t::Float6), 16}, // v_mfma_f32_16x16x128_f6 - // BF6 - {matrix_instruction(32, 32, 64, data_type_t::BFloat6), 32}, // v_mfma_f32_32x32x64_bf6 - {matrix_instruction(16, 16, 128, data_type_t::BFloat6), 16}, // v_mfma_f32_16x16x128_bf6 - // F4 - {matrix_instruction(32, 32, 64, data_type_t::Float4), 32}, // v_mfma_f32_32x32x64_f4 - {matrix_instruction(16, 16, 128, data_type_t::Float4), 16}, // v_mfma_f32_16x16x128_f4 - // DOT2 - {matrix_instruction( 1, 1, 64, data_type_t::Half), 16}, // V_DOT2_F32_F16 - {matrix_instruction( 1, 1, 64, data_type_t::BFloat16), 16}, // V_DOT2_F32_BF16 - }}, - {hardware_t::architecture_t::gfx1201, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_wmma_f16_16x16x16_f16/v_wmma_f32_16x16x16_f16 - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_wmma_bf16_16x16x16_bf16/v_wmma_f32_16x16x16_bf16 - // F8 - {matrix_instruction(16, 16, 16, data_type_t::Float8), 8}, // v_wmma_f32_16x16x16_fp8_fp8 - // F8B8 - {matrix_instruction(16, 16, 16, data_type_t::Float8BFloat8), 8}, // v_wmma_f32_16x16x16_fp8_bf8 - // B8F8 - {matrix_instruction(16, 16, 16, data_type_t::BFloat8Float8), 8}, // v_wmma_f32_16x16x16_bf8_fp8 - // B8 - {matrix_instruction(16, 16, 16, data_type_t::BFloat8), 8}, // v_wmma_f32_16x16x16_bf8_bf8 - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 
8}, // v_wmma_i32_16x16x16_iu8 - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 8}, // v_wmma_i32_16x16x16_iu4 - {matrix_instruction(16, 16, 32, data_type_t::Int4), 8}, // v_wmma_i32_16x16x32_iu4 - }}, - {hardware_t::architecture_t::gfx1100, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 - }}, - {hardware_t::architecture_t::gfx1151, - { - // F16 - {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 - // BF16 - {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 - // I8 - {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 - // I4 - {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 - }}}; - - architecture_t arch; - size_t N_CU; // Number of Compute Units - size_t lds_capacity; // Capacity of LDS - double mem1_perf_ratio; - double mem2_perf_ratio; - double mem3_perf_ratio; - size_t L2_capacity; // Capacity of L2 in bytes - size_t CU_per_L2; // Number of compute units per L2 domain - double compute_clock_ghz; - size_t parallel_mi_cu; // The number of parallel MI in a CU - std::tuple mem_bw_per_wg_coefficients; - size_t NUM_XCD; - - hardware_t(architecture_t arch, - size_t N_CU, - size_t lds_capacity, - size_t NUM_XCD, - double mem1_perf_ratio, - double mem2_perf_ratio, - double mem3_perf_ratio, - size_t L2_capacity, - double compute_clock_ghz, - size_t parallel_mi_cu, - std::tuple mem_bw_per_wg_coefficients) - : arch(arch) - , N_CU(N_CU) - , lds_capacity(lds_capacity) - , 
mem1_perf_ratio(mem1_perf_ratio) - , mem2_perf_ratio(mem2_perf_ratio) - , mem3_perf_ratio(mem3_perf_ratio) - , L2_capacity(L2_capacity) - , CU_per_L2(N_CU / NUM_XCD) - , compute_clock_ghz(compute_clock_ghz) - , parallel_mi_cu(parallel_mi_cu) - , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) - , NUM_XCD(NUM_XCD) - { - } - - hardware_t(hipDeviceProp_t properties) - : hardware_t(get_hardware_for_properties(properties)) - { - } - - hardware_t(const hardware_t& other) - : arch(other.arch) - , N_CU(other.N_CU) - , lds_capacity(other.lds_capacity) - , mem1_perf_ratio(other.mem1_perf_ratio) - , mem2_perf_ratio(other.mem2_perf_ratio) - , mem3_perf_ratio(other.mem3_perf_ratio) - , L2_capacity(other.L2_capacity) - , CU_per_L2(other.CU_per_L2) - , compute_clock_ghz(other.compute_clock_ghz) - , parallel_mi_cu(other.parallel_mi_cu) - , mem_bw_per_wg_coefficients(other.mem_bw_per_wg_coefficients) - , NUM_XCD(other.NUM_XCD) - { - } - - static hardware_t get_hardware_for_properties(hipDeviceProp_t properties) - { - auto arch_name = get_before_first_colon(properties.gcnArchName); - auto arch_enum = arch_name_to_enum(arch_name); - auto it = ARCH_CONSTANT_MAP.find(arch_enum); - if(it == ARCH_CONSTANT_MAP.end()) - { - throw std::runtime_error( - "Attempting to retrieve hardware constants for unsupported architecture: " - + arch_name); // Could also return default values here. 
- } - auto constants = it->second; - return hardware_t(arch_enum, - properties.multiProcessorCount, - properties.sharedMemPerBlock, - constants.num_xcds, - 1e9 * constants.mem1_perf_ratio / properties.clockRate, - 1e9 * constants.mem2_perf_ratio - / (properties.memoryClockRate * constants.mem_clock_ratio), - 1e9 * constants.mem3_perf_ratio / properties.memoryClockRate, - properties.l2CacheSize, - properties.clockRate / 1e6, - constants.parallel_mi_cu, - constants.mem_bw_per_wg_coefficients); - } - - static hardware_t get_hardware_for_device(int deviceId) - { - hipDeviceProp_t prop; - hipError_t e = hipGetDeviceProperties(&prop, deviceId); - if(e) - { - throw std::runtime_error(hipGetErrorString(e)); - } - return get_hardware_for_properties(prop); - } - - static bool is_hardware_supported(hipDeviceProp_t properties) - { - auto arch_name = get_before_first_colon(properties.gcnArchName); - auto arch_enum = arch_name_to_enum(arch_name); - auto it = ARCH_CONSTANT_MAP.find(arch_enum); - return it != ARCH_CONSTANT_MAP.end(); - } - - // Function to print hardware details - void print() const - { - std::cout << "================== Hardware Configuration ==================\n"; - std::cout << "Number of CUs (N_CU) : " << N_CU << "\n"; - std::cout << "LDS capacity : " << lds_capacity << " bytes\n"; - std::cout << "mem1_perf_ratio : " << mem1_perf_ratio << "\n"; - std::cout << "mem2_perf_ratio : " << mem2_perf_ratio << "\n"; - std::cout << "mem3_perf_ratio : " << mem3_perf_ratio << "\n"; - std::cout << "L2 Cache capacity : " << L2_capacity << " bytes\n"; - std::cout << "CUs per L2 domain : " << CU_per_L2 << "\n"; - std::cout << "Compute clock (GHz) : " << compute_clock_ghz << "\n"; - std::cout << "Parallel MI/CU : " << parallel_mi_cu << "\n"; - std::cout << "Number of XCDs (NUM_XCD) : " << NUM_XCD << "\n"; - std::cout << "mem_bw_per_wg_coefficients: " << std::get<0>(mem_bw_per_wg_coefficients) << ", " - << std::get<1>(mem_bw_per_wg_coefficients) << ", " - << 
std::get<2>(mem_bw_per_wg_coefficients) << "\n\n"; - - std::cout << "------------------ Instruction Map -------------------------\n"; - // Loop over the instruction_map and print each entry - for(const auto& kv : INSTRUCTION_MAP.at(arch)) - { - const auto& key = kv.first; - const auto& L_MI = kv.second; - - std::cout << "Instruction: MI_M=" << key.MI_M << ", MI_N=" << key.MI_N - << ", MI_K=" << key.MI_K << ", mi_input_type=" << to_string(key.mi_input_type) - << " bytes\n" - << " -> Latency (L_MI): " << L_MI << "\n"; - } - std::cout << "===========================================================\n"; - } - // Debug tracking info - mutable std::unordered_map debug_info; - - static bool is_debug_enabled() - { - static bool debugEnvVar = read_debug_env_var(); //Used to cache the read. - return debugEnvVar; - } - - static bool is_heuristics_enabled() - { - static bool heuristicsEnvVar = read_heuristics_env_var(); //Used to cache the read. - return heuristicsEnvVar; - } - - void log_debug(const std::string& key, const std::string& value) const - { - debug_info[key] = value; - } - - void log_debug(const std::string& key, double value) const - { - debug_info[key] = std::to_string(value); - } - - void clear_debug() const - { - debug_info.clear(); - } - - void print_debug_info() const - { - std::cout << "=== Hardware Debug Info ===\n"; - for(const auto& [key, val] : debug_info) - { - std::cout << key << ": " << val << "\n"; - } - std::cout << "===========================\n"; - } - - size_t get_mi_latency(size_t MI_M, size_t MI_N, size_t MI_K, data_type_t mi_input_type) const - { - const auto& instruction_map = INSTRUCTION_MAP.at(arch); - auto key = matrix_instruction(MI_M, MI_N, MI_K, mi_input_type); - - auto it = instruction_map.find(key); - if(it != instruction_map.end()) - { - return it->second / parallel_mi_cu; - } - else - { - std::cerr << "Warning: Latency not found for MI_M=" << MI_M << ", MI_N=" << MI_N - << ", MI_K=" << MI_K << ", mi_input_type=" << 
to_string(mi_input_type) - << ". Returning latency value of 32 (really slow).\n"; - return 32 / parallel_mi_cu; // Default latency if instruction is not found - } - } - - private: - static std::string get_before_first_colon(const std::string& input) - { - size_t pos = input.find(':'); - if(pos != std::string::npos) - { - return input.substr(0, pos); - } - return input; // Return the whole string if ':' is not found - } - - // Helper function to read the debug environment variable - static bool read_debug_env_var() - { - const char* env = std::getenv("ANALYTICAL_GEMM_DEBUG"); - return env && std::string(env) == "1"; - } - - // Helper function to read the heuristics environment variable - static bool read_heuristics_env_var() - { - const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); - return !(env && std::string(env) == "0"); - } - }; -} // namespace origami + } + + /** + * @brief Map of matrix instruction latencies by architecture. + * + */ + static const std::unordered_map> + INSTRUCTION_MAP; + + architecture_t arch; ///< GPU architecture type + size_t N_CU; ///< Number of Compute Units + size_t lds_capacity; ///< Capacity of Local Data Share (LDS) in bytes + double mem1_perf_ratio; + double mem2_perf_ratio; + double mem3_perf_ratio; + size_t L2_capacity; ///< Capacity of L2 cache in bytes + size_t CU_per_L2; ///< Number of compute units per L2 cache domain + double compute_clock_ghz; ///< Compute clock frequency in GHz + size_t parallel_mi_cu; ///< Number of parallel matrix instructions per compute unit + std::tuple + mem_bw_per_wg_coefficients; ///< Memory bandwidth coefficients per workgroup + size_t NUM_XCD; ///< Number of XCDs (XGMI Complex Die) + + /** + * @brief Construct hardware_t with explicit parameters. 
+ * + * @param arch GPU architecture type + * @param N_CU Number of compute units + * @param lds_capacity LDS capacity in bytes + * @param NUM_XCD Number of XCDs + * @param mem1_perf_ratio Memory level 1 performance ratio + * @param mem2_perf_ratio Memory level 2 performance ratio + * @param mem3_perf_ratio Memory level 3 performance ratio + * @param L2_capacity L2 cache capacity in bytes + * @param compute_clock_ghz Compute clock frequency in GHz + * @param parallel_mi_cu Number of parallel matrix instructions per CU + * @param mem_bw_per_wg_coefficients Memory bandwidth coefficients per workgroup + */ + hardware_t(architecture_t arch, + size_t N_CU, + size_t lds_capacity, + size_t NUM_XCD, + double mem1_perf_ratio, + double mem2_perf_ratio, + double mem3_perf_ratio, + size_t L2_capacity, + double compute_clock_ghz, + size_t parallel_mi_cu, + std::tuple mem_bw_per_wg_coefficients); + + /** + * @brief Construct hardware_t from HIP device properties. + * + * Automatically determines architecture and extracts hardware parameters + * from the provided HIP device properties structure. + * + * @param properties HIP device properties structure + */ + hardware_t(hipDeviceProp_t properties); + + /** + * @brief Copy constructor. + * + * @param other Another hardware_t instance to copy from + */ + hardware_t(const hardware_t& other); + + /** + * @brief Create hardware_t instance from HIP device properties. + * + * + * @param properties HIP device properties structure + * @return hardware_t Configured hardware instance + */ + static hardware_t get_hardware_for_properties(hipDeviceProp_t properties); + + /** + * @brief Create hardware_t instance for a specific HIP device. + * + * Queries the specified HIP device and creates a hardware_t instance + * with the appropriate architecture and parameters. 
+ * + * @param deviceId HIP device ID + * @return hardware_t Configured hardware instance for the device + */ + static hardware_t get_hardware_for_device(int deviceId); + + /** + * @brief Check if the hardware described by properties is supported. + * + * Determines whether the GPU architecture represented by the device + * properties is supported by the analytical model. + * + * @param properties HIP device properties structure + * @return true if the architecture is supported, false otherwise + */ + static bool is_hardware_supported(hipDeviceProp_t properties); + + /** + * @brief Print hardware details to stdout. + * + */ + void print() const; + + /** + * @brief Get matrix instruction latency for given instruction parameters. + * + * + * @param MI_M Matrix instruction M dimension + * @param MI_N Matrix instruction N dimension + * @param MI_K Matrix instruction K dimension + * @param mi_input_type Input data type for the matrix instruction + * @return size_t Instruction latency in cycles, or 0 if not found + */ + size_t get_mi_latency(size_t MI_M, size_t MI_N, size_t MI_K, data_type_t mi_input_type) const; + + private: + /** + * @brief Extract substring before the first colon character. + * + * Helper function used for parsing architecture names from device + * property strings (e.g., extracting "gfx90a" from "gfx90a:..."). 
+ * + * @param input Input string to parse + * @return std::string Substring before the first colon, or entire string if no colon found + */ + static std::string get_before_first_colon(const std::string& input); +}; +} // namespace origami diff --git a/shared/origami/include/origami/log.hpp b/shared/origami/include/origami/log.hpp new file mode 100644 index 000000000000..62d54f2392a1 --- /dev/null +++ b/shared/origami/include/origami/log.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +namespace origami { + +/** + * @brief Logger for collecting and exporting analytical metrics in JSON format. 
+ * + * Provides a single templated logging function that stores key-value pairs + * and can export them as JSON. The caller is responsible for checking if + * debug is enabled before calling log(). + */ +class logger_t { + public: + /** + * @brief Default constructor (constexpr-compatible). + * The map is lazily allocated on first use. + */ + constexpr logger_t() : metrics_(nullptr) {} + + /** + * @brief Destructor + */ + ~logger_t() = default; + + /** + * @brief Copy constructor. + */ + logger_t(const logger_t& other) { + if (other.metrics_) { + metrics_ = std::make_unique>(*other.metrics_); + } else { + metrics_ = nullptr; + } + } + + /** + * @brief Move constructor. + */ + logger_t(logger_t&& other) noexcept = default; + + /** + * @brief Copy assignment operator. + */ + logger_t& operator=(const logger_t& other) { + if (this != &other) { + if (other.metrics_) { + metrics_ = std::make_unique>(*other.metrics_); + } else { + metrics_ = nullptr; + } + } + return *this; + } + + /** + * @brief Move assignment operator. + */ + logger_t& operator=(logger_t&& other) noexcept = default; + + /** + * @brief Ensure metrics map is allocated. + */ + void ensure_metrics() const { + if (!metrics_) { metrics_ = std::make_unique>(); } + } + + /** + * @brief Log a key-value pair. + * + * @tparam T Type of the value (must be convertible to JSON-compatible string) + * @param key The metric key + * @param value The metric value + */ + template + void log(const std::string& key, const T& value) { + ensure_metrics(); + (*metrics_)[key] = to_json_string(value); + } + + /** + * @brief Clear all logged metrics. + */ + void clear() { + if (metrics_) { metrics_->clear(); } + } + + /** + * @brief Print all metrics as JSON to stdout. + */ + void print() const; + + /** + * @brief Export metrics to a JSON file. + * + * @param filename Output filename + */ + void export_json(const std::string& filename) const; + + /** + * @brief Get all metrics as a map. 
+ * + * @return Map of metric key-value pairs + */ + std::unordered_map get_metrics() const { + if (!metrics_) { return std::unordered_map(); } + return *metrics_; + } + + /** + * @brief Check if logger has any metrics. + * + * @return true if metrics map is not empty + */ + bool empty() const { return !metrics_ || metrics_->empty(); } + + private: + mutable std::unique_ptr> metrics_; + + // Convert value to JSON-compatible string + template + std::string to_json_string(const T& value) { + if constexpr (std::is_same_v>, std::string>) { + return "\"" + value + "\""; + } else if constexpr (std::is_convertible_v || + std::is_array_v>) { + return "\"" + std::string(value) + "\""; + } else if constexpr (std::is_same_v>, bool>) { + return value ? "true" : "false"; + } else if constexpr (std::is_arithmetic_v && !std::is_same_v) { + return std::to_string(value); + } else { + // For other types, try to_string (this may fail for some types) + static_assert(std::is_arithmetic_v, "Type must be convertible to string or arithmetic"); + return std::to_string(value); + } + } +}; + +} // namespace origami diff --git a/shared/origami/include/origami/math.hpp b/shared/origami/include/origami/math.hpp new file mode 100644 index 000000000000..8498f1a5bacb --- /dev/null +++ b/shared/origami/include/origami/math.hpp @@ -0,0 +1,44 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission 
notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +namespace origami { +namespace math { + +/** + * @brief Performs `(n + d - 1) / d`, but is robust against the case where + * `(n + d - 1)` would overflow. + * + */ +template +inline constexpr N safe_ceil_div(N n, D d) { + // Static cast to undo integral promotion. + return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 
1 : 0))); +} + +} // namespace math +} // namespace origami diff --git a/shared/origami/include/origami/origami.hpp b/shared/origami/include/origami/origami.hpp new file mode 100644 index 000000000000..8b78bdba5b1e --- /dev/null +++ b/shared/origami/include/origami/origami.hpp @@ -0,0 +1,120 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include "origami/hardware.hpp" +#include "origami/types.hpp" + +namespace origami { + +/** + * @brief Based on the provided problem and configs; selects the best config. + * + * @param problem Problem description (M, N, K, etc.) 
+ * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param configs Vector of all possible valid configurations. + * @return prediction_result_t Configurations with best latency. + */ +prediction_result_t select_config(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs); + +/** + * @brief Select best workgroup-mapping for the given tile size. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param skGrid StreamK grid size. + * @return std::tuple + */ +std::tuple select_workgroup_mapping(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t skGrid); + +/** + * @brief Rank configurations based on predicted performance. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param configs List of candidate configurations to rank + * @return std::vector Configurations with latencies ranked by performance + * (best first) + */ +std::vector rank_configs(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs); + +/** + * @brief Select best configuration based only on M, N, K dimensions with default settings. + * + * @param M Problem dimension M + * @param N Problem dimension N + * @param K Problem dimension K + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param configs List of candidate configurations + * @return prediction_result_t Configurations with best latency. + */ +prediction_result_t select_config_mnk(std::size_t M, + std::size_t N, + std::size_t K, + const hardware_t& hardware, + const std::vector& configs); + +/** + * @brief Select top K configurations based on performance ranking. + * + * @param problem Problem description (M, N, K, etc.) 
+ * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param configs List of candidate configurations to rank + * @param topk Number of top configurations to return + * @return std::vector Top K configurations ranked by performance (best first) + */ +std::vector select_topk_configs(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs, + std::size_t topk); + +/** + * @brief Given a latency, compute the achieved throughput in gflops. + * + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param problem Problem description (M, N, K, etc.) + * @param latency Kernel latency. + * @return double Throughput in gflops/s. + */ +double compute_perf_gflops(const hardware_t& hardware, + const problem_t& problem, + const double latency); + +} // namespace origami diff --git a/shared/origami/include/origami/streamk.hpp b/shared/origami/include/origami/streamk.hpp index 467bb94de9ff..10cc6077dbbd 100644 --- a/shared/origami/include/origami/streamk.hpp +++ b/shared/origami/include/origami/streamk.hpp @@ -1,78 +1,79 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ #pragma once #include "origami/hardware.hpp" -#include - -namespace origami -{ - namespace streamk - { - enum class reduction_type - { - // BasicReduction, - Tree, - Parallel, - // AtomicReduction, - Count, - None = Count - }; +#include "origami/types.hpp" - inline reduction_type int_to_reduction_type(int rt) - { - return (reduction_type)rt; - } - - size_t get_workspace( - size_t x, - size_t y, - size_t mt_m, - size_t mt_n, - size_t bpe_c, - size_t grid, - size_t tiles, - reduction_type reduction); +#include - reduction_type select_reduction( - size_t x, - size_t y, - size_t z, - size_t batch, - size_t mt_m, - size_t mt_n, - size_t mt_k, - const hardware_t& analytical_hardware, - int dynamic_grid_version); +namespace origami { +namespace streamk { +/** + * @brief Number of output tiles. + * + * @param mt_m Tile size in M-dimension. + * @param mt_n Tile size in N-dimension. + * @param m Matrix's m-dimension. + * @param n Matrix's n-dimension. + * @param batch Number of batches. + * @return size_t Total number of output tiles. + */ +size_t compute_number_of_output_tiles(size_t mt_m, size_t mt_n, size_t m, size_t n, size_t batch); - const char* rtype_to_string(streamk::reduction_type r); +/** + * @brief Select the best reduction strategy for StreamK. + * + * @param problem Problem description (M, N, K, etc.) 
+ * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param algorithm Grid selection algorithm + * @return reduction_t Selected reduction strategy + */ +reduction_t select_reduction(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + grid_selection_t algorithm); - size_t select_grid(size_t x, - size_t y, - size_t z, - size_t batch, - bool trans_a, - bool trans_b, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t workspace_size, - size_t mt_m, - size_t mt_n, - size_t mt_k, - size_t mi_m, - size_t mi_n, - size_t mi_k, - int workgroup_mapping, - size_t workspace_size_per_elem_c, - int occupancy, - const hardware_t& analytical_hardware, - int dynamic_grid_version, - reduction_type reduction_strategy, - size_t max_cus = 0); - // max workspace +/** + * @brief Based on the provided kernel config, select the best grid dimension. + * + * @param problem Problem description (M, N, K, etc.) + * @param hardware Hardware characteristics (@see origami::hardware_t) + * @param config Kernel configuration. + * @param grid_selection_t grid selection algorithm (@see origami::grid_selection_t) + * @param max_cus Maximum number of CUs to use. + * @return size_t Dimensions of the grid launched. 
+ */ +size_t select_grid_size(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + grid_selection_t algorithm, + size_t max_cus = 0); - } // namespace streamk -} +} // namespace streamk +} // namespace origami diff --git a/shared/origami/include/origami/types.hpp b/shared/origami/include/origami/types.hpp new file mode 100644 index 000000000000..69675590943d --- /dev/null +++ b/shared/origami/include/origami/types.hpp @@ -0,0 +1,414 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "origami/log.hpp" +#include "origami/math.hpp" + +namespace origami { + +/** + * @brief Enumeration of supported data types. + * + */ +enum class data_type_t : int { + Float, + Double, + ComplexFloat, + ComplexDouble, + Half, + Int8x4, + Int32, + BFloat16, + Int8, + Int4, + Int64, + XFloat32, + Float8_fnuz, + BFloat8_fnuz, + Float8BFloat8_fnuz, + BFloat8Float8_fnuz, + Float8, + BFloat8, + Float8BFloat8, + BFloat8Float8, + Float6, + BFloat6, + Float4, + Count, + None = Count +}; + +/** + * @brief Convert integer to data_type_t enum. + * + * @param dt Integer value to convert + * @return data_type_t Corresponding data type + */ +inline data_type_t int_to_data_type(int dt) { return static_cast(dt); } + +/** + * @brief Convert data_type_t to number of bits. + * + * @param type Data type + * @return int Number of bits + */ +int datatype_to_bits(data_type_t type); + +/** + * @brief Convert data_type_t to number of bytes. + * + * @param type Data type + * @return int Number of bytes + */ +inline int data_type_to_bytes(data_type_t type) { + return math::safe_ceil_div(datatype_to_bits(type), 8); +} + +/** + * @brief Convert data_type_t to string. + * + * @param type Data type + * @return std::string String representation of data type + */ +std::string datatype_to_string(data_type_t type); + +/** + * @brief Convert string to data_type_t enum. + * + * @param s String value to convert + * @return data_type_t Corresponding data type + */ +data_type_t string_to_datatype(std::string s); + +/** + * @brief Struct to define a matrix instruction. + * + * Contains the dimensions and data type of a matrix instruction. 
+ */ +struct matrix_instruction { + size_t MI_M; + size_t MI_N; + size_t MI_K; + data_type_t mi_input_type; + + matrix_instruction() : MI_M(0), MI_N(0), MI_K(0), mi_input_type(data_type_t::Float) {} + + matrix_instruction(size_t m, size_t n, size_t k, data_type_t mi_input_type) + : MI_M(m), MI_N(n), MI_K(k), mi_input_type(mi_input_type) {} + + matrix_instruction(const matrix_instruction& other) + : MI_M(other.MI_M), MI_N(other.MI_N), MI_K(other.MI_K), mi_input_type(other.mi_input_type) {} + + bool operator<(const matrix_instruction& other) const { + return std::tie(MI_M, MI_N, MI_K, mi_input_type) < + std::tie(other.MI_M, other.MI_N, other.MI_K, other.mi_input_type); + } + + bool operator==(const matrix_instruction& other) const { + return MI_M == other.MI_M && MI_N == other.MI_N && MI_K == other.MI_K && + mi_input_type == other.mi_input_type; + } + + std::size_t hash() const { + return std::hash()(MI_M) ^ std::hash()(MI_N) ^ std::hash()(MI_K) ^ + std::hash()(mi_input_type); + } +}; + +/** + * @brief Grid selection algorithms for StreamK. + * + * Different algorithms to select the grid size for kernel execution. + */ +enum class grid_selection_t : std::uint32_t { + number_of_cus = 0, ///< Use number of compute units + min_resources = 1, ///< Use minimum required resources + energy_aware = 2, ///< Energy-aware selection + reduction_cost_aware = 3, ///< Reduction cost-aware selection + data_parallel = 4, ///< Data parallel approach + analytical = 5, ///< Analytical model-based selection + k_split_aware = 6, ///< K-split aware selection + count, ///< Count of Grid selection algos + none = 0xFFFFFFFFu ///< Explicitly invalid +}; + +/** + * @brief Reduction strategy types for StreamK. + * + * Different algorithms for reduction operations in StreamK. 
+ */ +enum class reduction_t : std::uint32_t { + spinlock = 0, ///< Spinlock-based reduction + tree = 1, ///< Tree-based reduction + parallel = 2, ///< Parallel reduction + atomic = 3, ///< Atomic Add-based reduction + count, ///< Count of reduction types + none = 0xFFFFFFFFu ///< Explicitly invalid / no reduction +}; + +/** + * @brief Convert integer to reduction_t enum. + * + * @param rt Integer value to convert + * @return reduction_t Corresponding reduction type + */ +inline constexpr reduction_t int_to_reduction_t(int rt) { return static_cast(rt); } + +/** + * @brief Indicates whether a matrix is supplied in transposed form or not. + */ +enum class transpose_t { + T, + N, + + Count +}; + +/** + * @brief A compact 3-D dimension triple (M, N, K). + * + * Provides convenient accessors for common GEMM tiling parameters + * and helpers like mnk() for volume. + */ +struct dim3_t { + /// M dimension (rows). + std::size_t m; + + /// N dimension (columns). + std::size_t n; + + /// K dimension (reduction). + std::size_t k; + + constexpr bool operator==(const dim3_t& o) const noexcept { + return m == o.m && n == o.n && k == o.k; + } + + constexpr bool operator!=(const dim3_t& o) const noexcept { return !(*this == o); } + + /// @return Product m*n. + constexpr std::size_t mn() const noexcept { return m * n; } + + /// @return Product m*k. + constexpr std::size_t mk() const noexcept { return m * k; } + + /// @return Product n*k. + constexpr std::size_t nk() const noexcept { return n * k; } + + /// @return Product m*n*k. + constexpr std::size_t mnk() const noexcept { return m * n * k; } +}; + +/** + * @brief Runtime options for controlling debug, heuristics, and other behaviors. + * + * Provides programmatic access to runtime configuration options that can be + * set either programmatically or via environment variables. 
+ */ +struct runtime_options { + /// Enable debug logging (reads from ANALYTICAL_GEMM_DEBUG env var) + bool debug_enabled; + + /// Enable heuristics (reads from ANALYTICAL_GEMM_HEURISTICS env var) + bool heuristics_enabled; + + /// Heuristics variance threshold (reads from ANALYTICAL_GEMM_HEURISTICS_VARIANCE env var) + double heuristics_variance; + + /** + * @brief Default constructor that reads from environment variables. + */ + runtime_options(); + + /** + * @brief Constructor with explicit values (does not read from environment). + */ + runtime_options(bool debug, bool heuristics, double variance); + + /** + * @brief Get the global runtime options instance. + */ + static runtime_options& get(); + + /** + * @brief Read debug setting from environment variable. + * @return true if ANALYTICAL_GEMM_DEBUG is set to "1", false otherwise + */ + static bool read_debug_from_env(); + + /** + * @brief Read heuristics setting from environment variable. + * @return true if ANALYTICAL_GEMM_HEURISTICS is set to "1", false otherwise + */ + static bool read_heuristics_from_env(); + + /** + * @brief Read heuristics variance from environment variable. + * @return double Variance value from ANALYTICAL_GEMM_HEURISTICS_VARIANCE, or 0.0 if not set + */ + static double read_heuristics_variance_from_env(); + + /** + * @brief Update runtime options from environment variables. + */ + void update_from_env(); +}; + +/** + * @brief Full kernel configuration (tile shape + execution parameters). + * + * Holds the geometric tile sizes along with occupancy, + * work-group mapping (WGM), and cache-control hints. + */ +struct config_t { + /// Macro tile and matrix-instruction shape. + dim3_t mt{0, 0, 0}; + dim3_t mi{0, 0, 0}; + + /// Occupancy (number of waves resident per CU). + int occupancy = -1; + + /// Reorder workgroup id for L2 reuse. + int workgroup_mapping = 0; + + /// Whether operand A is accessed with cache-flags. 
+ int cache_hints_a = 0; + + /// Whether operand B is accessed with cache-flags. + int cache_hints_b = 0; + + /// Workspace size parameters. + std::size_t workspace_size = 0; + std::size_t workspace_size_per_elem_c = 0; + + /// Reduction strategy. + reduction_t reduction_strategy = reduction_t::none; + + /// Runtime options (if null, uses global singleton) + const runtime_options* runtime_opts{nullptr}; + + /// Logger for analytical metrics + mutable logger_t logger; + + constexpr bool operator==(const config_t& o) const noexcept { + return mt == o.mt && mi == o.mi && cache_hints_a == o.cache_hints_a && + cache_hints_b == o.cache_hints_b && workgroup_mapping == o.workgroup_mapping; + } + + std::size_t hash() const { + return std::hash()(mt.m) ^ std::hash()(mt.n) ^ std::hash()(mt.k) ^ + std::hash()(mi.m) ^ std::hash()(mi.n) ^ std::hash()(mi.k) ^ + std::hash()(cache_hints_a) ^ std::hash()(cache_hints_b) ^ + std::hash()(workgroup_mapping); + } + + void validate() const { + if (!is_valid()) { throw std::runtime_error("Invalid config_t"); } + } + + bool is_valid() const { + return mt.m > 0 && mt.n > 0 && mt.k > 0 && mi.m > 0 && mi.n > 0 && mi.k > 0 && occupancy > 0; + } +}; + +/** + * @brief Latency prediction result given kernel configuration. + * + * Combines a configuration with its estimated latency. + */ +struct prediction_result_t { + double latency; + config_t config; +}; + +/** + * @brief Struct to define the GEMM problem characteristics. + * + * Contains all the parameters needed to describe a GEMM operation, + * including matrix dimensions, data types, and operation flags. + */ +struct problem_t { + /// Size of the problem: M, N, K. + dim3_t size{0, 0, 0}; + + /// Batch size. + std::size_t batch = 1; + + /// Transpose types (NN, NT, TN, TT). + transpose_t a_transpose = transpose_t::N; + transpose_t b_transpose = transpose_t::N; + + /// Data types: A, B, C, D. 
+ data_type_t a_dtype = data_type_t::None; + data_type_t b_dtype = data_type_t::None; + data_type_t c_dtype = data_type_t::None; + data_type_t d_dtype = data_type_t::None; + + /// Compute type. + data_type_t mi_dtype = data_type_t::None; + + /// MX block size. + std::size_t a_mx_block_size = 0; + std::size_t b_mx_block_size = 0; +}; + +/** + * @brief Get runtime options from config, or global singleton if config doesn't specify. + * + * @param config Configuration struct (may contain runtime_opts pointer) + * @return const runtime_options& Reference to runtime options + */ +inline const runtime_options& get_runtime_options(const config_t& config) { + return config.runtime_opts ? *config.runtime_opts : runtime_options::get(); +} + +} // namespace origami + +// Specialization of std::hash in the std namespace for use of std::unordered_map with +// matrix_instruction and config_t as keys. +namespace std { +template <> +struct hash { + std::size_t operator()(const origami::matrix_instruction& k) const { return k.hash(); } +}; + +template <> +struct hash { + std::size_t operator()(const origami::config_t& config) const noexcept { return config.hash(); } +}; +} // namespace std diff --git a/shared/origami/include/origami/utils.hpp b/shared/origami/include/origami/utils.hpp deleted file mode 100644 index 92e07c6fe8f3..000000000000 --- a/shared/origami/include/origami/utils.hpp +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include -#include -#include -#include // For std::function - -namespace origami -{ - using result_tuple = std::tuple; // non_temporal_b - - using tile_tuple = std::tuple; // non_temporal_b - - size_t select_best_grid_size(size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - const hardware_t& hardware, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t mx_block_size, - double H_L2, - int WGM, - size_t biggest_allowable_split = 8, - size_t max_cus = 0); - - std::vector select_best_macro_tile_size(size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - const hardware_t& hardware, - const std::vector& MT_list, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t mx_block_size, - double H_L2, - bool print, - int WGM, - size_t max_cus = 0); - - std::vector sweep_macro_tile_sizes(size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - hardware_t& hardware, - size_t element_size = 2, - size_t max_MT_M = 256, - size_t max_MT_N = 256, - size_t max_MT_K = 128, - size_t step_MT_M = 32, - size_t step_MT_N = 32, - size_t step_MT_K = 32, - double H_L2 = 0.8, - const std::vector& tiles_to_add - = {}, - bool print = false); - - std::tuple select_best_wgm(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - int nta, - int ntb, - size_t skGrid, - bool print); - - double compute_tflops_from_latency(double latency_cycles, - size_t M, - size_t N, - size_t K, - double clock_GHz); - -} // namespace origami diff --git a/shared/origami/python/CMakeLists.txt b/shared/origami/python/CMakeLists.txt index d0ae6eac5482..f969e328585f 
100644 --- a/shared/origami/python/CMakeLists.txt +++ b/shared/origami/python/CMakeLists.txt @@ -1,5 +1,27 @@ -# Copyright Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) diff --git a/shared/origami/python/README.md b/shared/origami/python/README.md index ff7f2adfe83b..9cd51b02596d 100644 --- a/shared/origami/python/README.md +++ b/shared/origami/python/README.md @@ -35,9 +35,7 @@ import origami hardware = origami.getHardwareForDevice(args.device) -result = origami.select_best_macro_tile_size( - args.m, args.n, args.k, args.transA, args.transB, hardware, tile_list, args.element_size, args.miDataType, 0.8, args.debug, args.print, args.wgm - ) +result = origami.rank_configs(problem, hardware, configs) ``` ## Modifying `origami_module.cpp` diff --git a/shared/origami/python/origami_grid_test.py b/shared/origami/python/origami_grid_test.py index 8a6c6822c33f..20b3dac0afee 100755 --- a/shared/origami/python/origami_grid_test.py +++ b/shared/origami/python/origami_grid_test.py @@ -1,15 +1,35 @@ -# Copyright Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ import argparse import origami import math + def parseArguments(): parser = argparse.ArgumentParser(description="Test StreamK Grid Selection.") parser.add_argument("-m", type=int, default=8192, help="Problem M dimension") @@ -30,13 +50,26 @@ def parseArguments(): "--type_b", type=str, default="f16", help="Size of each element in B in bits" ) parser.add_argument( - "--type_acc", type=str, default="f32", help="Size of each element in partial tile in bits" + "--type_acc", + type=str, + default="f32", + help="Size of each element in partial tile in bits", + ) + parser.add_argument( + "--type_d", + type=str, + default="f16", + help="Size of each element in the output in bits", ) parser.add_argument( - "--type_d", type=str, default="f16", help="Size of each element in the output in bits" + "--type_compute", type=str, default=None, help="Instruction input type" + ) + parser.add_argument( + "--workspace_size", + type=int, + default=0, + help="Amount of workspace available in bytes", ) - parser.add_argument("--type_compute", type=str, default=None, help="Instruction input type") - parser.add_argument("--workspace_size", type=int, default=0, help="Amount of workspace available in bytes") parser.add_argument("--debug", action="store_true", help="Enable debug mode") parser.add_argument("--print", action="store_true", help="Print hardware info") parser.add_argument( @@ -46,17 +79,30 @@ def parseArguments(): 
parser.add_argument("--mt_m", type=int, default=32, help="Macro-tile dimension M") parser.add_argument("--mt_n", type=int, default=32, help="Macro-tile dimension N") parser.add_argument("--mt_k", type=int, default=256, help="Macro-tile dimension K") - parser.add_argument("--mi_m", type=int, default=16, help="Machine Instruction dimension M") - parser.add_argument("--mi_n", type=int, default=16, help="Machine Instruction dimension N") - parser.add_argument("--mi_k", type=int, default=16, help="Machine Instruction dimension K") + parser.add_argument( + "--mi_m", type=int, default=16, help="Machine Instruction dimension M" + ) + parser.add_argument( + "--mi_n", type=int, default=16, help="Machine Instruction dimension N" + ) + parser.add_argument( + "--mi_k", type=int, default=16, help="Machine Instruction dimension K" + ) parser.add_argument("--occupancy", type=int, default=1, help="Occupancy of kernel") - parser.add_argument("--dynamic_grid_version", type=int, default=5, help="Version of Dynamic Grid Selection to use") + parser.add_argument( + "--dynamic_grid_version", + type=int, + default=5, + help="Version of Dynamic Grid Selection to use", + ) args = parser.parse_args() if args.type_compute is None: - if origami.datatype_to_bits(origami.string_to_datatype(args.type_a)) > origami.datatype_to_bits(origami.string_to_datatype(args.type_b)): + if origami.datatype_to_bits( + origami.string_to_datatype(args.type_a) + ) > origami.datatype_to_bits(origami.string_to_datatype(args.type_b)): args.type_compute = args.type_a else: args.type_compute = args.type_b @@ -72,42 +118,52 @@ def main(): if args.print: hardware.print() - reduction = origami.select_reduction( - args.m, - args.n, - args.k, - args.batch, - args.mt_m, - args.mt_n, - args.mt_k, - hardware, - args.dynamic_grid_version - ) - - winner_grid = origami.select_grid( - args.m, - args.n, - args.k, - args.batch, - args.trans_a, - args.trans_b, - origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), - 
origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), - origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), - origami.string_to_datatype(args.type_compute), - args.workspace_size, - args.mt_m, - args.mt_n, - args.mt_k, - args.mi_m, # MI_M - args.mi_n, # MI_N - args.mi_k, # MI_K - args.wgm, - origami.datatype_to_bits(origami.string_to_datatype(args.type_acc)) // 8, - args.occupancy, - hardware, - args.dynamic_grid_version, - reduction + # Create problem description + problem = origami.problem_t() + problem.size = origami.dim3_t(args.m, args.n, args.k) + problem.batch = args.batch + problem.a_transpose = ( + origami.transpose_t.T if args.trans_a else origami.transpose_t.N + ) + problem.b_transpose = ( + origami.transpose_t.T if args.trans_b else origami.transpose_t.N + ) + problem.a_dtype = origami.string_to_datatype(args.type_a) + problem.b_dtype = origami.string_to_datatype(args.type_b) + problem.d_dtype = origami.string_to_datatype(args.type_d) + problem.c_dtype = problem.d_dtype + problem.mi_dtype = origami.string_to_datatype(args.type_compute) + problem.a_mx_block_size = 0 + problem.b_mx_block_size = 0 + + # Create config + config = origami.config_t() + config.mt = origami.dim3_t(args.mt_m, args.mt_n, args.mt_k) + config.mi = origami.dim3_t(args.mi_m, args.mi_n, args.mi_k) + config.occupancy = args.occupancy + config.workgroup_mapping = args.wgm + + # Map --dynamic_grid_version to a grid selection algorithm + grid_algorithm = origami.grid_selection_t.analytical # default to analytical + if args.dynamic_grid_version == 0: + grid_algorithm = origami.grid_selection_t.number_of_cus + elif args.dynamic_grid_version == 1: + grid_algorithm = origami.grid_selection_t.min_resources + elif args.dynamic_grid_version == 2: + grid_algorithm = origami.grid_selection_t.energy_aware + elif args.dynamic_grid_version == 3: + grid_algorithm = origami.grid_selection_t.reduction_cost_aware + elif args.dynamic_grid_version == 4: + grid_algorithm = origami.grid_selection_t.data_parallel 
+ elif args.dynamic_grid_version == 5: + grid_algorithm = origami.grid_selection_t.analytical + elif args.dynamic_grid_version == 6: + grid_algorithm = origami.grid_selection_t.k_split_aware + + reduction = origami.select_reduction(problem, hardware, config, grid_algorithm) + + winner_grid = origami.select_grid_size( + problem, hardware, config, grid_algorithm, hardware.N_CU ) print(f"Best reduction algo : {reduction}") diff --git a/shared/origami/python/origami_module.cpp b/shared/origami/python/origami_module.cpp index 26c8409bd28c..0bb0f8a9a8bf 100644 --- a/shared/origami/python/origami_module.cpp +++ b/shared/origami/python/origami_module.cpp @@ -1,172 +1,290 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "origami/hardware.hpp" -#include "origami/streamk.hpp" -#include "origami/utils.hpp" +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ #include -#include -#include #include -#include #include +#include #include +#include +#include +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +// logger_t is defined in types.hpp +#include "origami/origami.hpp" +#include "origami/streamk.hpp" +#include "origami/types.hpp" using hardware_t = origami::hardware_t; -namespace nb = nanobind; +namespace nb = nanobind; using namespace nb::literals; -NB_MODULE(origami, m) -{ - nanobind::enum_(m, "architecture_t") - .value("gfx942", hardware_t::architecture_t::gfx942) - .value("gfx950", hardware_t::architecture_t::gfx950) - .export_values(); - - nanobind::enum_(m, "data_type_t") - .value("Float", origami::data_type_t::Float) - .value("ComplexFloat", origami::data_type_t::ComplexFloat) - .value("ComplexDouble", origami::data_type_t::ComplexDouble) - .value("Double", origami::data_type_t::Double) - .value("Half", origami::data_type_t::Half) - .value("Int8x4", origami::data_type_t::Int8x4) - .value("Int32", origami::data_type_t::Int32) - .value("BFloat16", origami::data_type_t::BFloat16) - .value("Int8", origami::data_type_t::Int8) - .value("Int64", origami::data_type_t::Int64) - .value("XFloat32", origami::data_type_t::XFloat32) - .value("Float8_fnuz", origami::data_type_t::Float8_fnuz) - .value("BFloat8_fnuz", origami::data_type_t::BFloat8_fnuz) - .value("Float8BFloat8_fnuz", origami::data_type_t::Float8BFloat8_fnuz) - .value("BFloat8Float8_fnuz", origami::data_type_t::BFloat8Float8_fnuz) - .value("Float8", origami::data_type_t::Float8) - .value("BFloat8", origami::data_type_t::BFloat8) - .value("Float8BFloat8", origami::data_type_t::Float8BFloat8) - 
.value("BFloat8Float8", origami::data_type_t::BFloat8Float8) - .value("Float6", origami::data_type_t::Float6) - .value("BFloat6", origami::data_type_t::BFloat6) - .value("Float4", origami::data_type_t::Float4) - .export_values(); - - m.def("int_to_data_type", - &origami::int_to_data_type, - "Convert int to data_type_t."); - - nanobind::enum_(m, "reduction_type") - .value("Tree", origami::streamk::reduction_type::Tree) - .value("Parallel", origami::streamk::reduction_type::Parallel) - .export_values(); - - m.def("int_to_reduction_type", - &origami::streamk::int_to_reduction_type, - "Convert int to reduction_type."); - - nanobind::class_(m, "hardware_t") - .def(nanobind::init>()) - .def("print", &hardware_t::print) - .def("print_debug_info", &hardware_t::print_debug_info) - .def_rw("N_CU", &hardware_t::N_CU) - .def_rw("lds_capacity", &hardware_t::lds_capacity) - .def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio) - .def_rw("mem2_perf_ratio", &hardware_t::mem2_perf_ratio) - .def_rw("mem3_perf_ratio", &hardware_t::mem3_perf_ratio) - .def_rw("L2_capacity", &hardware_t::L2_capacity) - .def_rw("CU_per_L2", &hardware_t::CU_per_L2) - .def_rw("compute_clock_ghz", &hardware_t::compute_clock_ghz) - .def_rw("parallel_mi_cu", &hardware_t::parallel_mi_cu) - .def_rw("mem_bw_per_wg_coefficients", &hardware_t::mem_bw_per_wg_coefficients) - .def_rw("NUM_XCD", &hardware_t::NUM_XCD) - .def_rw("debug_info", &hardware_t::debug_info); - - m.def("get_hardware_for_device", - &hardware_t::get_hardware_for_device, - "This gets a hardware object for a device."); - - m.def("datatype_to_bits", &origami::data_type_to_bits, "Return the number of bits in a datatype"); - m.def("string_to_datatype", &origami::string_to_data_type, "Convert a string representation of a datatype into data_type_t enum"); - m.def("select_best_macro_tile_size", - &origami::select_best_macro_tile_size, - "M"_a, - "N"_a, - "K"_a, - "batch"_a, - "transA"_a, - "transB"_a, - "hardware"_a, - "MT_list"_a, - 
"element_size_A"_a, - "element_size_B"_a, - "element_size_out"_a, - "mi_datatype"_a, - "mx_block_size"_a, - "H_L2"_a, - "print"_a, - "WGM"_a, - "max_cus"_a = 0, - "Get best macro tile sizes."); - m.def("select_reduction", &origami::streamk::select_reduction, "Select best StreamK reduction strategy"); - m.def("select_grid", &origami::streamk::select_grid, - "x"_a, - "y"_a, - "z"_a, - "batch"_a, - "trans_a"_a, - "trans_b"_a, - "element_size_A"_a, - "element_size_B"_a, - "element_size_out"_a, - "mi_datatype"_a, - "workspace_size"_a, - "mt_m"_a, - "mt_n"_a, - "mt_k"_a, - "mi_m"_a, - "mi_n"_a, - "mi_k"_a, - "workgroup_mapping"_a, - "workspace_size_per_elem_c"_a, - "occupancy"_a, - "analytical_hardware"_a, - "dynamic_grid_version"_a, - "reduction_strategy"_a, - "max_cus"_a = 0, - "Select Best StreamK Grid Size"); - m.def("compute_total_latency", &origami::compute_total_latency, - "hardware"_a, - "M"_a, - "N"_a, - "K"_a, - "batch"_a, - "transA"_a, - "transB"_a, - "MT_M"_a, - "MT_N"_a, - "MT_K"_a, - "MI_M"_a, - "MI_N"_a, - "MI_K"_a, - "element_size_A"_a, - "element_size_B"_a, - "element_size_out"_a, - "mi_datatype"_a, - "mx_block_size"_a, - "WGM"_a, - "non_temporal_a"_a = 0, - "non_temporal_b"_a = 0, - "occupancy"_a = 1, - "split"_a = 0, - "max_cus"_a = 0, - "Compute the total latency of a gemm"); - m.def("select_best_wgm", &origami::select_best_wgm, "Get best workgroup mapping."); +NB_MODULE(origami, m) { + nanobind::enum_(m, "architecture_t") + .value("gfx942", hardware_t::architecture_t::gfx942) + .value("gfx950", hardware_t::architecture_t::gfx950) + .export_values(); + + nanobind::enum_(m, "data_type_t") + .value("Float", origami::data_type_t::Float) + .value("ComplexFloat", origami::data_type_t::ComplexFloat) + .value("ComplexDouble", origami::data_type_t::ComplexDouble) + .value("Double", origami::data_type_t::Double) + .value("Half", origami::data_type_t::Half) + .value("Int8x4", origami::data_type_t::Int8x4) + .value("Int32", origami::data_type_t::Int32) + 
.value("BFloat16", origami::data_type_t::BFloat16) + .value("Int8", origami::data_type_t::Int8) + .value("Int64", origami::data_type_t::Int64) + .value("XFloat32", origami::data_type_t::XFloat32) + .value("Float8_fnuz", origami::data_type_t::Float8_fnuz) + .value("BFloat8_fnuz", origami::data_type_t::BFloat8_fnuz) + .value("Float8BFloat8_fnuz", origami::data_type_t::Float8BFloat8_fnuz) + .value("BFloat8Float8_fnuz", origami::data_type_t::BFloat8Float8_fnuz) + .value("Float8", origami::data_type_t::Float8) + .value("BFloat8", origami::data_type_t::BFloat8) + .value("Float8BFloat8", origami::data_type_t::Float8BFloat8) + .value("BFloat8Float8", origami::data_type_t::BFloat8Float8) + .value("Float6", origami::data_type_t::Float6) + .value("BFloat6", origami::data_type_t::BFloat6) + .value("Float4", origami::data_type_t::Float4) + .export_values(); + + // Transpose flag for the A and B operands. + nanobind::enum_(m, "transpose_t") + .value("T", origami::transpose_t::T) + .value("N", origami::transpose_t::N) + // Count is a sentinel and is deliberately not exposed to Python: + // .value("Count", origami::transpose_t::Count) + .export_values(); + + m.def("int_to_data_type", &origami::int_to_data_type, "Convert int to data_type_t."); + + nanobind::enum_(m, "grid_selection_t") + .value("number_of_cus", origami::grid_selection_t::number_of_cus) + .value("min_resources", origami::grid_selection_t::min_resources) + .value("energy_aware", origami::grid_selection_t::energy_aware) + .value("reduction_cost_aware", origami::grid_selection_t::reduction_cost_aware) + .value("data_parallel", origami::grid_selection_t::data_parallel) + .value("analytical", origami::grid_selection_t::analytical) + .value("k_split_aware", origami::grid_selection_t::k_split_aware) + .export_values(); + + nanobind::enum_(m, "reduction_t") + .value("tree", origami::reduction_t::tree) + .value("parallel", origami::reduction_t::parallel) + .export_values(); + + m.def("int_to_reduction_t", 
&origami::int_to_reduction_t, "Convert int to reduction_t."); + + // Add new struct bindings + nanobind::class_(m, "dim3_t") + .def(nanobind::init()) + .def_rw("m", &origami::dim3_t::m) + .def_rw("n", &origami::dim3_t::n) + .def_rw("k", &origami::dim3_t::k) + .def("mn", &origami::dim3_t::mn) + .def("mk", &origami::dim3_t::mk) + .def("nk", &origami::dim3_t::nk) + .def("mnk", &origami::dim3_t::mnk); + + nanobind::class_(m, "logger_t") + .def(nanobind::init<>()) + .def("clear", &origami::logger_t::clear, "Clear all logged metrics") + .def("print", &origami::logger_t::print, "Print all metrics as JSON to stdout") + .def("export_json", &origami::logger_t::export_json, "Export metrics to a JSON file") + .def("get_metrics", &origami::logger_t::get_metrics, "Get all metrics as a map") + .def("empty", &origami::logger_t::empty, "Check if logger has any metrics") + // Overloads for templated log() method + .def( + "log", + [](origami::logger_t& self, const std::string& key, int value) { self.log(key, value); }, + "Log an integer value") + .def( + "log", + [](origami::logger_t& self, const std::string& key, double value) { + self.log(key, value); + }, + "Log a double value") + .def( + "log", + [](origami::logger_t& self, const std::string& key, const std::string& value) { + self.log(key, value); + }, + "Log a string value") + .def( + "log", + [](origami::logger_t& self, const std::string& key, bool value) { self.log(key, value); }, + "Log a boolean value") + .def( + "log", + [](origami::logger_t& self, const std::string& key, size_t value) { + self.log(key, value); + }, + "Log a size_t value"); + + nanobind::class_(m, "config_t") + .def(nanobind::init<>()) + .def_rw("mt", &origami::config_t::mt) + .def_rw("mi", &origami::config_t::mi) + .def_rw("occupancy", &origami::config_t::occupancy) + .def_rw("workgroup_mapping", &origami::config_t::workgroup_mapping) + .def_rw("cache_hints_a", &origami::config_t::cache_hints_a) + .def_rw("cache_hints_b", 
&origami::config_t::cache_hints_b) + .def_rw("workspace_size", &origami::config_t::workspace_size) + .def_rw("workspace_size_per_elem_c", &origami::config_t::workspace_size_per_elem_c) + .def_rw("logger", &origami::config_t::logger); + + nanobind::class_(m, "prediction_result_t") + .def(nanobind::init<>()) + .def_rw("latency", &origami::prediction_result_t::latency) + .def_rw("config", &origami::prediction_result_t::config); + + nanobind::class_(m, "problem_t") + .def(nanobind::init<>()) + .def_rw("size", &origami::problem_t::size) + .def_rw("batch", &origami::problem_t::batch) + .def_rw("a_transpose", &origami::problem_t::a_transpose) + .def_rw("b_transpose", &origami::problem_t::b_transpose) + .def_rw("a_dtype", &origami::problem_t::a_dtype) + .def_rw("b_dtype", &origami::problem_t::b_dtype) + .def_rw("c_dtype", &origami::problem_t::c_dtype) + .def_rw("d_dtype", &origami::problem_t::d_dtype) + .def_rw("mi_dtype", &origami::problem_t::mi_dtype) + .def_rw("a_mx_block_size", &origami::problem_t::a_mx_block_size) + .def_rw("b_mx_block_size", &origami::problem_t::b_mx_block_size); + + nanobind::class_(m, "hardware_t") + .def(nanobind::init>()) + .def("print", &hardware_t::print) + .def_rw("N_CU", &hardware_t::N_CU) + .def_rw("lds_capacity", &hardware_t::lds_capacity) + .def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio) + .def_rw("mem2_perf_ratio", &hardware_t::mem2_perf_ratio) + .def_rw("mem3_perf_ratio", &hardware_t::mem3_perf_ratio) + .def_rw("L2_capacity", &hardware_t::L2_capacity) + .def_rw("CU_per_L2", &hardware_t::CU_per_L2) + .def_rw("compute_clock_ghz", &hardware_t::compute_clock_ghz) + .def_rw("parallel_mi_cu", &hardware_t::parallel_mi_cu) + .def_rw("mem_bw_per_wg_coefficients", &hardware_t::mem_bw_per_wg_coefficients) + .def_rw("NUM_XCD", &hardware_t::NUM_XCD); + + m.def("get_hardware_for_device", + &hardware_t::get_hardware_for_device, + "This gets a hardware object for a device."); + + m.def("datatype_to_bits", &origami::datatype_to_bits, "Return the 
number of bits in a datatype"); + m.def("string_to_datatype", + &origami::string_to_datatype, + "Convert a string representation of a datatype into data_type_t enum"); + m.def("datatype_to_string", + &origami::datatype_to_string, + "Convert data_type_t enum to string representation"); + + m.def("select_config", + &origami::select_config, + "problem"_a, + "hardware"_a, + "configs"_a, + "Select best configuration based on problem and hardware"); + m.def("select_grid_size", + &origami::streamk::select_grid_size, + "problem"_a, + "hardware"_a, + "config"_a, + "algorithm"_a, + "max_cus"_a = 0, + "Select best grid size for the given configuration"); + m.def("select_workgroup_mapping", + &origami::select_workgroup_mapping, + "problem"_a, + "hardware"_a, + "config"_a, + "skGrid"_a, + + "Select best workgroup mapping"); + m.def("rank_configs", + &origami::rank_configs, + "problem"_a, + "hardware"_a, + "configs"_a, + "Rank configurations by performance"); + m.def("select_config_mnk", + &origami::select_config_mnk, + "M"_a, + "N"_a, + "K"_a, + "hardware"_a, + "configs"_a, + + "Select best configuration for M,N,K dimensions"); + m.def("select_topk_configs", + &origami::select_topk_configs, + "problem"_a, + "hardware"_a, + "configs"_a, + "topk"_a, + + "Select topk configurations"); + m.def("compute_perf_gflops", + &origami::compute_perf_gflops, + "hardware"_a, + "problem"_a, + "latency"_a, + + "Compute performance in GFLOPS"); + + // StreamK functions + m.def("select_reduction", + &origami::streamk::select_reduction, + "problem"_a, + "hardware"_a, + "config"_a, + "algorithm"_a, + "Select best StreamK reduction strategy"); + + // GEMM functions + m.def("compute_total_latency", + static_cast(&origami::compute_total_latency), + "problem"_a, + "hardware"_a, + "config"_a, + "max_cus"_a, + "Compute total latency"); } diff --git a/shared/origami/python/origami_test.py b/shared/origami/python/origami_test.py index fdf4d3931503..4946a378af33 100755 --- 
a/shared/origami/python/origami_test.py +++ b/shared/origami/python/origami_test.py @@ -1,7 +1,29 @@ #!/usr/bin/env python3 -# Copyright © Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ import argparse import origami @@ -28,89 +50,92 @@ # BFloat6 == "bf6" # Float4 == "f4" -MatInst = {'gfx950': {}, 'gfx942': {}} -MatInst['gfx950']["f32"] = [ - (16,16,4,1), - (32,32,2,1) - ] +MatInst = {"gfx950": {}, "gfx942": {}} +MatInst["gfx950"]["f32"] = [(16, 16, 4, 1), (32, 32, 2, 1)] # MatInst['gfx950']["c32"] = [] # MatInst['gfx950']["c64"] = [] -MatInst['gfx950']["f64"] = [ - (16,16,4,1) - ] -MatInst['gfx950']["f16"] = [ +MatInst["gfx950"]["f64"] = [(16, 16, 4, 1)] +MatInst["gfx950"]["f16"] = [ # (4,4,4,16), #gfx942 - #[16,16,4,4] # never use 16x16x4x4 - #[16,16,16,1] #gfx942 - #[32,32,4,2] # never use 32x32x4x2 - #[32,32,8,1] #gfx942 - (16,16,32,1), #gfx950 - (32,32,16,1) #gfx950 - ] + # [16,16,4,4] # never use 16x16x4x4 + # [16,16,16,1] #gfx942 + # [32,32,4,2] # never use 32x32x4x2 + # [32,32,8,1] #gfx942 + (16, 16, 32, 1), # gfx950 + (32, 32, 16, 1), # gfx950 +] # MatInst['gfx950']["i32"] =[] -MatInst['gfx950']["bf16"] = MatInst['gfx950']["f16"] +MatInst["gfx950"]["bf16"] = MatInst["gfx950"]["f16"] # MatInst['gfx950']["i8"] =[ # (32,32,16,1), # (16,16,32,1), # (4,4,4,16) # ] -MatInst['gfx950']["xf32"] = [ +MatInst["gfx950"]["xf32"] = [ # (4,4,4,16), #gfx942 - #[16,16,4,4] # never use 16x16x4x4 - #[16,16,16,1] #gfx942 - #[32,32,4,2] # never use 32x32x4x2 - #[32,32,8,1] #gfx942 - (16,16,32,1), #gfx950 - (32,32,16,1) #gfx950 - ] -MatInst['gfx950']["f8"] = [ - (4,4,4,16), #gfx950, gfx942 - (16,16,128,1), #gfx950 - (32,32,64,1) #gfx950 - ] -MatInst['gfx950']["bf8"] = MatInst['gfx950']["f8"] + # [16,16,4,4] # never use 16x16x4x4 + # [16,16,16,1] #gfx942 + # [32,32,4,2] # never use 32x32x4x2 + # [32,32,8,1] #gfx942 + (16, 16, 32, 1), # gfx950 + (32, 32, 16, 1), # gfx950 +] +MatInst["gfx950"]["f8"] = [ + (4, 4, 4, 16), # gfx950, gfx942 + (16, 16, 128, 1), # gfx950 + (32, 32, 64, 1), # gfx950 +] +MatInst["gfx950"]["bf8"] = MatInst["gfx950"]["f8"] + def parseArguments(): - parser 
= argparse.ArgumentParser(description="""Get hypothetical Origami MTxDU selection for a size""") + parser = argparse.ArgumentParser( + description="""Get hypothetical Origami MTxDU selection for a size""" + ) parser.add_argument("-m", type=int, default=8192) parser.add_argument("-n", type=int, default=8192) parser.add_argument("-b", type=int, default=1) parser.add_argument("-k", type=int, default=8192) parser.add_argument("--trans_a", type=bool, default=True) parser.add_argument("--trans_b", type=bool, default=False) - parser.add_argument("--device", type=int, default=0) # to get hardware specs + parser.add_argument("--device", type=int, default=0) # to get hardware specs parser.add_argument("--type_a", type=str, default="f16") parser.add_argument("--type_b", type=str, default="f16") parser.add_argument("--type_d", type=str, default="f16") parser.add_argument("--scale_block_size", type=int, default=0) parser.add_argument("--wgm", type=int, default=6) - parser.add_argument("--sizes", type=bool, default=False) # to load the sizes from a csv file. -m/-n/-b/-k will be ignored if True. - parser.add_argument("--path", type=str, default="./sizes.csv") # path to the csv file. Fails if sizes is True, and path or file does not exist. + parser.add_argument( + "--sizes", type=bool, default=False + ) # to load the sizes from a csv file. -m/-n/-b/-k will be ignored if True. + parser.add_argument( + "--path", type=str, default="./sizes.csv" + ) # path to the csv file. Fails if sizes is True, and path or file does not exist. 
parser.add_argument("--arch", type=str, default="gfx950") parser.add_argument("--print", action="store_true") return parser.parse_args() -def createTileList(arch, gemmType): + +def createConfigList(arch, gemmType): LIST_OF_WAVEs_TO_INCLUDE = [[4, 1], [2, 2], [1, 4], [1, 2], [2, 1], [1, 1]] MIN_MT0 = MIN_MT1 = 16 MAX_MT0 = MAX_MT1 = 512 - # generate all MTs for each datatype: + # generate all configs for each datatype: bm_max = 0 - tile_list = set() + configs = [] for MI in MatInst[arch][gemmType]: for bm in range(bm_max + 1): - MIBlockM = 2 ** bm + MIBlockM = 2**bm for wave in LIST_OF_WAVEs_TO_INCLUDE: waveTileM = 0 waveTileN = 0 while True: - waveTileM+=1 - waveTileN=0 + waveTileM += 1 + waveTileN = 0 MatrixInstM = MI[0] * MIBlockM MT0 = MatrixInstM * waveTileM * wave[0] if MT0 < MIN_MT0: @@ -119,7 +144,7 @@ def createTileList(arch, gemmType): break while True: - waveTileN+=1 + waveTileN += 1 MatrixInstN = MI[1] / MIBlockM * MI[3] MT1 = int(MatrixInstN * waveTileN * wave[1]) @@ -129,89 +154,118 @@ def createTileList(arch, gemmType): break # LDS size check for LSU - LSU = max(1, 4//wave[0]//wave[1]) - if LSU > 1 and MT0*MT1*4*LSU > 256*256: + LSU = max(1, 4 // wave[0] // wave[1]) + if LSU > 1 and MT0 * MT1 * 4 * LSU > 256 * 256: continue - if MT0*MT1 > 256*256: + if MT0 * MT1 > 256 * 256: continue for DU in [16, 32, 64, 128, 256, 512, 1024]: - tile_list.add((MT0, MT1, DU, MI[0], MI[1], MI[2], 1, 6, 0, 0)) + # Create config_t object + config = origami.config_t() + config.mt = origami.dim3_t(MT0, MT1, DU) + config.mi = origami.dim3_t(MI[0], MI[1], MI[2]) + config.occupancy = 1 + config.workgroup_mapping = 6 + configs.append(config) + + return configs - return [tile for tile in tile_list] def main(): args = parseArguments() hardware = origami.get_hardware_for_device(args.device) - tile_list = createTileList(args.arch, args.type_a) - - print(" Number of unique MTxDU: ", len(tile_list)) - - if args.sizes: # sizes from a file - try: - with open(args.path, 'r') as 
csvfile: - csv_reader = csv.reader(csvfile) - print(f"M,N,Batch,K,MT0,MT1,DU,MI0,MI1,MI2,MI3,latency") - for row in csv_reader: - M = int(row[0]) - N = int(row[1]) - B = int(row[2]) - K = int(row[3]) - ret = origami.select_best_macro_tile_size( - M, - N, - K, - B, - args.trans_a, - args.trans_b, - hardware, - tile_list, - origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), - origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), - origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), - origami.string_to_datatype(args.type_a), - args.scale_block_size, - 0.8, - args.print, - args.wgm, - ) - #MxNxBxK, MT0xMT1xDU, MI0xMI1xMI2xMI3, latency/cycles - print(f"{M},{N},{B},{K},{ret[0][1]},{ret[0][2]},{ret[0][3]},{ret[0][4]},{ret[0][5]},{ret[0][6]},{ret[0][7]},{ret[0][0]:0.3f}") - except FileNotFoundError: - raise FileNotFoundError(f"Error: The size file: '{args.path}' does not exist.") - else: # one size from the command line. - ret = origami.select_best_macro_tile_size( - args.m, - args.n, - args.k, - args.b, - args.trans_a, - args.trans_b, - hardware, - tile_list, - origami.datatype_to_bits(origami.string_to_datatype(args.type_a)), - origami.datatype_to_bits(origami.string_to_datatype(args.type_b)), - origami.datatype_to_bits(origami.string_to_datatype(args.type_d)), - origami.string_to_datatype(args.type_a), - args.scale_block_size, - 0.8, - args.print, - args.wgm, + configs = createConfigList(args.arch, args.type_a) + + print(" Number of unique configs: ", len(configs)) + + if args.sizes: # sizes from a file + try: + with open(args.path, "r") as csvfile: + csv_reader = csv.reader(csvfile) + print(f"M,N,Batch,K,MT0,MT1,DU,MI0,MI1,MI2,latency") + for row in csv_reader: + M = int(row[0]) + N = int(row[1]) + B = int(row[2]) + K = int(row[3]) + # Create problem description + problem = origami.problem_t() + problem.size = origami.dim3_t(M, N, K) + problem.batch = B + problem.a_transpose = ( + origami.transpose_t.T if args.trans_a else 
origami.transpose_t.N + ) + problem.b_transpose = ( + origami.transpose_t.T if args.trans_b else origami.transpose_t.N + ) + problem.a_dtype = origami.string_to_datatype(args.type_a) + problem.b_dtype = origami.string_to_datatype(args.type_b) + problem.d_dtype = origami.string_to_datatype(args.type_d) + problem.c_dtype = problem.d_dtype + problem.mi_dtype = problem.a_dtype + problem.a_mx_block_size = args.scale_block_size + problem.b_mx_block_size = args.scale_block_size + + # Select best config + best_config = origami.select_config(problem, hardware, configs) + latency = best_config.latency + + # MxNxBxK, MT0xMT1xDU, MI0xMI1xMI2, latency/cycles + print( + f"{M},{N},{B},{K},{best_config.config.mt.m},{best_config.config.mt.n},{best_config.config.mt.k},{best_config.config.mi.m},{best_config.config.mi.n},{best_config.config.mi.k},{latency:0.3f}" + ) + except FileNotFoundError: + raise FileNotFoundError( + f"Error: The size file: '{args.path}' does not exist." + ) + else: # one size from the command line. + # Create problem description + problem = origami.problem_t() + problem.size = origami.dim3_t(args.m, args.n, args.k) + problem.batch = args.b + problem.a_transpose = ( + origami.transpose_t.T if args.trans_a else origami.transpose_t.N + ) + problem.b_transpose = ( + origami.transpose_t.T if args.trans_b else origami.transpose_t.N + ) + problem.a_dtype = origami.string_to_datatype(args.type_a) + problem.b_dtype = origami.string_to_datatype(args.type_b) + problem.d_dtype = origami.string_to_datatype(args.type_d) + problem.c_dtype = problem.d_dtype + problem.mi_dtype = problem.a_dtype + problem.a_mx_block_size = args.scale_block_size + problem.b_mx_block_size = args.scale_block_size + + # Select best config + best_config = origami.select_config(problem, hardware, configs) + latency = best_config.latency + + print( + f"The best config for [{args.m}, {args.n}, {args.b}, {args.k}] is: MT=({best_config.config.mt.m},{best_config.config.mt.n},{best_config.config.mt.k}), 
MI=({best_config.config.mi.m},{best_config.config.mi.n},{best_config.config.mi.k}), latency={latency:0.3f}" ) - print(f"The best MTxDU for [{args.m}, {args.n}, {args.b}, {args.k}] is: {ret[0]}") # Match this with the top condition - # add an option to list top 10~15 - print(" full list of MTs: \n", ret) + + # Get top configs + ranked_configs = origami.rank_configs(problem, hardware, configs) + print(" Top 5 configs: ") + for i, config in enumerate(ranked_configs[:5]): + print( + f" {i+1}. MT=({config.config.mt.m},{config.config.mt.n},{config.config.mt.k}), MI=({config.config.mi.m},{config.config.mi.n},{config.config.mi.k}), latency={config.latency:0.3f}" + ) if args.print: hardware.print() - hardware.print_debug_info() - with open("MTxDU.log",'w') as file: - for tile in tile_list: - file.write(f'{tile}\n') + with open("configs.log", "w") as file: + for config in configs: + file.write( + f"MT=({config.mt.m},{config.mt.n},{config.mt.k}), MI=({config.mi.m},{config.mi.n},{config.mi.k})\n" + ) return 0 + if __name__ == "__main__": exit(main()) diff --git a/shared/origami/python/setup.py b/shared/origami/python/setup.py index ce2859d583da..e3493bb61fe5 100644 --- a/shared/origami/python/setup.py +++ b/shared/origami/python/setup.py @@ -1,5 +1,27 @@ -# Copyright © Advanced Micro Devices, Inc., or its affiliates. 
-# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +################################################################################ # setup.py from setuptools import setup, Extension diff --git a/shared/origami/src/origami/gemm.cpp b/shared/origami/src/origami/gemm.cpp index e7ee12689577..40facff606cd 100644 --- a/shared/origami/src/origami/gemm.cpp +++ b/shared/origami/src/origami/gemm.cpp @@ -1,12 +1,11 @@ // Copyright Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT -#include "origami/gemm.hpp" - -#include "origami/streamk.hpp" #include -#include // For timing +#include +#include #include +#include #include #include #include @@ -14,1484 +13,1064 @@ #include #include -namespace origami -{ - /* ---------------------------------------------------------------------------------------- */ - /* Misc. functions */ - /* ---------------------------------------------------------------------------------------- */ - // Performs `(n + d - 1) / d`, but is robust against the case where `(n + d - 1)` would - // overflow. - template - constexpr N safe_ceil_div(N n, D d) - { - // Static cast to undo integral promotion. - return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); - } - - auto lds128_penalty(size_t dim_elems, size_t element_bits) - { - const size_t bytes = dim_elems * safe_ceil_div(element_bits, 8); - const size_t mod = bytes % 128; - if(mod == 0) - return 1.0; - const double frac = double(mod) / 128.0; // 0..1 - const double base = (element_bits <= 16) ? 1.1 : 1.35; // BF16/FP16 < FP32 - return 1.0 + base * frac; // up to ~2.35x worst case - }; - - double calculate_work_utilization( - size_t M, size_t N, size_t K, size_t MT_M, size_t MT_N, size_t MT_K) - { - if(M == 0 || N == 0 || K == 0 || MT_M == 0 || MT_N == 0 || MT_K == 0) - return 1.0; - - // Calculate the full dimensions covered by the launched grid of tiles (spatial). - const double launched_M = static_cast(safe_ceil_div(M, MT_M)) * MT_M; - const double launched_N = static_cast(safe_ceil_div(N, MT_N)) * MT_N; - - // Calculate the full depth covered by the k-loop iterations (temporal). - const double launched_K = static_cast(safe_ceil_div(K, MT_K)) * MT_K; - - // The utilization is the ratio of the useful problem volume to the total scheduled volume. 
- const double useful_volume = static_cast(M * N * K); - const double launched_volume = launched_M * launched_N * launched_K; - - if(launched_volume < 1.0) - return 1.0; // Avoid division by zero for tiny/empty problems - - const double utilization = useful_volume / launched_volume; - - return utilization; - } - - double calculate_output_utilization( - size_t M, size_t N, size_t MT_M, size_t MT_N, size_t vector_elems = 1) - { - if(M == 0 || N == 0 || MT_M == 0 || MT_N == 0) - return 1.0; - - // Tiled coverage in M/N - const double launched_M = static_cast(safe_ceil_div(M, MT_M)) * MT_M; - const double launched_N = static_cast(safe_ceil_div(N, MT_N)) * MT_N; +#include "origami/hardware.hpp" +#include "origami/math.hpp" +#include "origami/types.hpp" - // Optional: model vectorization/alignment remainders (e.g., ld/st width) - // This assumes vectors must be fully inside bounds; tail elements are scalarized. - const size_t M_vec = (vector_elems > 1) ? (M / vector_elems) * vector_elems : M; - const size_t N_vec = (vector_elems > 1) ? 
(N / vector_elems) * vector_elems : N; - - const double useful = static_cast(M_vec) * static_cast(N_vec); - const double launched = launched_M * launched_N; - - if(launched < 1.0) - return 1.0; - return useful / launched; - } - - // Computes the number of active compute units if there is only one wave and it is partial - // Otherwise, returns hardware.N_CU - std::tuple compute_CU_occupancy(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - int WGM, - size_t workspace_size, - size_t workspace_size_per_elem_c, - int occupancy, - int dynamic_grid_version, - size_t split, - size_t max_cus) - { - // Number of output MTs - size_t numMT_M = safe_ceil_div(M, MT_M); - size_t numMT_N = safe_ceil_div(N, MT_N); - size_t numMTs = numMT_M * numMT_N * batch; - - size_t numWGs, numActiveCUs, numWaves, splitFactor; - - if(split) // if it is given - { - split = split > 1 ? split : 1; - numWGs = numMTs * split; - numActiveCUs = numWGs < hardware.N_CU ? numWGs : hardware.N_CU; - numWaves = safe_ceil_div(numWGs, hardware.N_CU); - splitFactor = split; - - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("reduction type", "Origami"); - } - } - else // as what StreamK predicts - { - streamk::reduction_type rt = streamk::select_reduction(M, - N, - K, - batch, - MT_M, - MT_N, - MT_K, - hardware, - dynamic_grid_version); - numWGs = streamk::select_grid(M, - N, - K, - batch, - transA, - transB, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - workspace_size, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - WGM, - workspace_size_per_elem_c, - occupancy, - hardware, - dynamic_grid_version, - rt, - max_cus); - - // output variables - numActiveCUs = numWGs < hardware.N_CU ? 
numWGs : hardware.N_CU; - // There are cases in which StreamK combines multiple output MTs and assigns to 1 WG. - // That means, we artifically observe one full wave, but that is not what actually happens - // under the hood. From a theoretical point of view, these distributions change all of the - // computations in Origami. With current implementation, it is hard to capture that - // behaviour analytically. So for now, if the numWGs is less than the numMTs, we calculate - // numWaves based on the numMTs. Otherwise, we use numWGs to compute numWaves. - numWaves = numWGs > numMTs ? safe_ceil_div(numWGs, hardware.N_CU) - : safe_ceil_div(numMTs, hardware.N_CU); - splitFactor = safe_ceil_div(numWGs, numMTs); - - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("reduction type", streamk::rtype_to_string(rt)); - } - } - - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("numMTs", numMTs); - hardware.log_debug("numWGs", numWGs); - hardware.log_debug("numActiveCUs", numActiveCUs); - hardware.log_debug("numWaves", numWaves); - hardware.log_debug("splitFactor", splitFactor); - hardware.log_debug("max_cus", max_cus); - } - - return std::make_tuple(numWGs, numActiveCUs, numWaves, splitFactor); - } - - /* ---------------------------------------------------------------------------------------- */ - /* Compute-related functions */ - /* ---------------------------------------------------------------------------------------- */ - // Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. 
- size_t compute_number_matrix_instructions(const hardware_t& hardware, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K) - { - // Compute the number of Matrix Instructions required in each dim - size_t N_MI_M = safe_ceil_div(MT_M, MI_M); - size_t N_MI_N = safe_ceil_div(MT_N, MI_N); - size_t N_MI_K = safe_ceil_div(MT_K, MI_K); - // Total number of matrix instructions for MT_MxMT_NxMT_K tile - size_t N_MI = N_MI_M * N_MI_N * N_MI_K; - - return N_MI; - } - - // Compute arithmic intensity - double arithmetic_intensity(double m, double n, double k, double bytes_per_element) - { - // Numerator: 2.0 * m * n * k - // Denominator: (m*n + n*k + m*k) * bytes_per_element - double numerator = 2.0 * m * n * k; - double denominator = (m * n + n * k + m * k) * bytes_per_element; - - return numerator / denominator; - } - - // Computes Emulated arithmetic intensity for TF32 (assumes 3xBF16). - double emulated_tf32_arithmetic_intensity(double m, double n, double k, double bytes_per_element) - { - // Numerator: 3.0 * 2.0 * m * n * k - // Denominator: (m*n + n*k + m*k) * bytes_per_element - double numerator = 3.0 * 2.0 * m * n * k; - double denominator = (m * n + n * k + m * k) * bytes_per_element; - - return numerator / denominator; - } - - // Compute cvt overhead in x1 tf32 emulation - // TODO: We can generalize the same routine to cover more GEMMs that perform conversion - static inline double compute_cvt_overhead_x1(const hardware_t& hardware, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - data_type_t mi_datatype) - { - // In X1 TF32 GEMMs, we do: - // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) - // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) - // ds_write_b64 - // That is, the extra instructions that we need to account for are the two cvt_pk ops - // per wave tile - - // However, these extra ops should not be added up to the overal tile 
latency becuase - // they can be run in parallel to Matix and Memory operations (given they are not dependent). - // So, We should ideally take L_tile = max{Mem, Comp, Vec (cvt latencies)}. - // Since, Vec latency is not modeled yet, we somehow model that into the current logic - // by scaling according to MFMA latencies and putting some heuristics to model the fact - // that these vector operations can be hidden (read interleaved) with the other memory - // or MFMA instructions. - - // TODO: Use kernel's actual wavetiles. - const double wave_tile_m = MT_M / 2.0; - const double wave_tile_n = MT_N / 2.0; - const double wave_tile_k = MT_K / MI_K; - - // MFMA count - const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; - const double num_mfma = 1.0 * static_cast(N_MI); - // Cycle scale per MI - const double L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, mi_datatype); - const double mfma_cycles = num_mfma * L_MI; - - // 2) Bytes (per K-slice), using ceil-div to whole bytes - const double bytesA - = static_cast(wave_tile_m) * MT_K * safe_ceil_div(element_size_A, 8); - const double bytesB - = static_cast(wave_tile_n) * MT_K * safe_ceil_div(element_size_B, 8); - - // 3) Modeled transfer quanta (128B lines) - // dsA = bytesA / (128 * MI_M) - // dsB = bytesB / (128 * MI_N) - // GR = dsA (global->LDS modeled equal to A-side DS) - const double dsA = (bytesA / 128.0) / static_cast(MI_M); // LDS->VGPR for A - const double dsB = (bytesB / 128.0) / static_cast(MI_N); // LDS->VGPR for B - const double GR = dsA; // Global->LDS reads - const double LR = dsA + dsB; // total DS->VGPR - - // 5) Exposed vs hidden CVT - // spare MFMA - const double spare_mfma = std::max(0.0, num_mfma - LR - GR); - // 2 cvt per each ds_write (this for SS_BSS -- should be revised for other datatypes) - // Each cvt has a latency of four. It is scaled by the MI Latency - // Note: change 16.0 based on mi_data_type if we want to generalize this for all - // casting GEMMs. 
- const double cvt = (2.0 * 4.0 / 16.0 * L_MI) * LR; - // cvt ops are interleaved in main loop and don't stall matrix or memory units. - // Heuristically, we set - const double H = (8.0 / 16.0 * L_MI) * spare_mfma + (4.0 / 16.0) * L_MI * (LR + GR); - const double overhead = std::max(cvt - H, 0.0); - - return overhead; - } - - // Compute cvt overhead in tf32 emulation - static inline double compute_cvt_overhead(const hardware_t& hardware, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B) - { - // Wave tile sizes - // TODO: Use kernel's actual wavetiles. - const double wave_tile_m = MT_M / 2.0; - const double wave_tile_n = MT_N / 2.0; - const double wave_tile_k = MT_K / MI_K; - - // MFMA count and cycles - const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; - - // TF32 emu: 3× BF16 MI issue slots - const double num_mfma = 3.0 * static_cast(N_MI); - - // Cycle scale per MI (use BF16 MI latency as the basic timing quantum) - const double L_MI_bf16 = hardware.get_mi_latency(MI_M, MI_N, MI_K, data_type_t::BFloat16); - //const double mfma_cycles = num_mfma * L_MI_bf16; - - // 2) Bytes (per K-slice), using ceil-div to whole bytes - const double bytesA - = static_cast(wave_tile_m) * MT_K * safe_ceil_div(element_size_A, 8); - const double bytesB - = static_cast(wave_tile_n) * MT_K * safe_ceil_div(element_size_B, 8); - - // const double mt_bytesA - // = static_cast(MT_M) * MT_K * safe_ceil_div(element_size_A, 8); - - // 3) Modeled transfer quanta (128B lines) - // dsA = bytesA / (128 * MI_M) - // dsB = bytesB / (128 * MI_N) - // GR = dsA (global->LDS modeled equal to A-side DS) - const double dsA = (bytesA / 128.0) / static_cast(MI_M); // LDS->VGPR for A - const double dsB = (bytesB / 128.0) / static_cast(MI_N); // LDS->VGPR for B - const double GR = dsA; // Global->LDS reads - const double LR = dsA + dsB; // total DS->VGPR - - // 4) Heuristic cycle weights 
(scaled to MI latency). - // Preserves your A=104, B=8, C=4 when L_MI_bf16 == 16. - // 24 vector instructions per 2 ds_reads (16x16x32) - // 24 vector instructions per 2 ds_reads for A and for B. - // 3 instructions per fp32 value read; number ds_read * size - const double A = (104.0 / 16.0) * L_MI_bf16; // CVT per LR-sized chunk (DS->VGPR) - const double B = (8.0 / 16.0) * L_MI_bf16; // hidden per spare MFMA slot - // MI16: 16 - 4 (12 cycles), for those 4 cycles, VGPRs are locked. 8 cycles to do anything. - const double C = (4.0 / 16.0) * L_MI_bf16; // hidden per (LR+GR) slot // MI16 - // 32 cycles (mfma), 4 cycles, 28, 4 vgpr lock, 24 cycles left. - // 24: 6 conv instructions, 3 ds_reads, ~6 grs - - // 5) Exposed vs hidden CVT - const double spare_mfma = std::max(0.0, num_mfma - LR - GR); - const double cvt = A * dsA; // only DS->VGPR contributes CVT - const double H = B * spare_mfma + C * (LR + GR); // hidden cycles - const double overhead = std::max(cvt - H, 0.0); - - // 6) Efficiency - //const double denom = mfma_cycles + overhead; - //const double eff = (denom > 0.0) ? (mfma_cycles / denom) : 1; - - return overhead; - } - - // Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). - size_t compute_mt_compute_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - data_type_t mi_datatype) - { - // Compute the number of matrix instructions - size_t N_MI - = compute_number_matrix_instructions(hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K); - // Latency of a single MT_MxMT_NxMT_k tile is the latency of one MI multiplied by - // number of MI per MT_MxMT_NxMT_k. 
- size_t L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, mi_datatype); - - // size_t mt_arith = arithmetic_intensity(MT_M, MT_N, MT_K, 2); - // printf("MT_M:%d MT_N:%d MT_K:%d arith:%d\n", MT_M, MT_N, MT_K, mt_arith); - // size_t arith = ((M * N * K * 2) / (M * K + N * K + M * N)); - size_t L_MT = L_MI * N_MI; - - return L_MT; - } - - /* ---------------------------------------------------------------------------------------- */ - /* Memory-related functions */ - /* ---------------------------------------------------------------------------------------- */ - // Check if MT fits in LDS - bool check_lds_capacity( - const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, size_t element_size) - { - // A and B size - size_t Ld_A_value = compute_A_loads(MT_M, MT_K); - size_t Ld_B_value = compute_B_loads(MT_N, MT_K); - // Size of those in bytes - size_t LDS_usage = (Ld_A_value + Ld_B_value) * (element_size / 8); - - if(LDS_usage > hardware.lds_capacity) - { - return false; // Exceeds LDS capacity - } - else - { - return true; // Within LDS capacity - } - } - - // Compute the amount of data loaded from A to produce a MT_MxMT_NxMT_K tile. - size_t compute_A_loads(size_t MT_M, size_t MT_K) - { - // Compute the size of loads from A for a single MT_MxMT_NxMT_K tiles - size_t Ld_A_value = MT_M * MT_K; - - return Ld_A_value; - } - - // Compute the amount of data loaded from B to produce a MT_MxMT_NxMT_K tile. 
- size_t compute_B_loads(size_t MT_N, size_t MT_K) - { - // Compute the size of loads from B for a single MT_MxMT_NxMT_K tiles - size_t Ld_B_value = MT_N * MT_K; +#include "origami/gemm.hpp" +#include "origami/streamk.hpp" - return Ld_B_value; +namespace origami { +double calculate_work_utilization(const problem_t& problem, const config_t& config) { + const size_t M = problem.size.m; + const size_t N = problem.size.n; + const size_t K = problem.size.k; + + const size_t MT_M = config.mt.m; + const size_t MT_N = config.mt.n; + const size_t MT_K = config.mt.k; + + if (MT_M <= 0 || MT_N <= 0) return 1.0; + + // Calculate the full dimensions covered by the launched grid of tiles (spatial). + const double launched_M = + static_cast(math::safe_ceil_div(M, MT_M)) * static_cast(MT_M); + const double launched_N = + static_cast(math::safe_ceil_div(N, MT_N)) * static_cast(MT_N); + + // Calculate the full depth covered by the k-loop iterations (temporal). + const double launched_K = + static_cast(math::safe_ceil_div(K, MT_K)) * static_cast(MT_K); + + // The utilization is the ratio of the useful problem volume to the total scheduled volume. 
+ const double useful_volume = static_cast(M * N * K); + const double launched_volume = launched_M * launched_N * launched_K; + + if (launched_volume < 1.0) return 1.0; // Avoid division by zero for tiny/empty problems + + const double utilization = useful_volume / launched_volume; + + return utilization; +} + +double calculate_output_utilization(const problem_t& problem, + const config_t& config, + size_t vector_elems = 1) { + const size_t M = problem.size.m; + const size_t N = problem.size.n; + + const size_t MT_M = config.mt.m; + const size_t MT_N = config.mt.n; + + if (MT_M <= 0 || MT_N <= 0) return 1.0; + + // Tiled coverage in M/N + const double launched_M = + static_cast(math::safe_ceil_div(M, MT_M)) * static_cast(MT_M); + const double launched_N = + static_cast(math::safe_ceil_div(N, MT_N)) * static_cast(MT_N); + + // Optional: model vectorization/alignment remainders (e.g., ld/st width) + // This assumes vectors must be fully inside bounds; tail elements are scalarized. + const size_t M_vec = (vector_elems > 1) ? math::safe_ceil_div(M, vector_elems) * vector_elems : M; + const size_t N_vec = (vector_elems > 1) ? math::safe_ceil_div(N, vector_elems) * vector_elems : N; + + const double useful = static_cast(M_vec) * static_cast(N_vec); + const double launched = launched_M * launched_N; + + if (launched < 1.0) return 1.0; + return useful / launched; +} + +// Computes the number of active compute units if there is only one wave and it is partial +// Otherwise, returns hardware.N_CU +std::tuple compute_cu_occupancy(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + grid_selection_t grid_selection, + size_t max_cus, + size_t split = 0) { + // Number of output MTs + size_t num_mts = streamk::compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); + + size_t num_wgs, num_active_cus, numWaves, splitFactor; + + if (split) // if it is given + { + split = split > 1 ? 
split : 1; + num_wgs = num_mts * split; + num_active_cus = num_wgs < hardware.N_CU ? num_wgs : hardware.N_CU; + numWaves = math::safe_ceil_div(num_wgs, hardware.N_CU); + splitFactor = split; + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("reduction type", "Origami"); } - - // Compute limited achievable memory bandwidth based on active CUs - double compute_mem_bw_from_occupancy(const hardware_t& hardware, size_t numActiveCUs) - { - const double CUs = static_cast(numActiveCUs); - - if(numActiveCUs > hardware.N_CU) - return 1.0; - - const double bw_limited = std::get<0>(hardware.mem_bw_per_wg_coefficients) * CUs * CUs - + std::get<1>(hardware.mem_bw_per_wg_coefficients) * CUs - + std::get<2>(hardware.mem_bw_per_wg_coefficients); - - return std::min(bw_limited, 1.0); + } else // as what StreamK predicts + { + auto config_with_reduction = config; + config_with_reduction.reduction_strategy = + streamk::select_reduction(problem, hardware, config, grid_selection); + + num_wgs = streamk::select_grid_size( + problem, hardware, config_with_reduction, grid_selection, max_cus); + + // output variables + num_active_cus = num_wgs < hardware.N_CU ? num_wgs : hardware.N_CU; + // There are cases in which StreamK combines multiple output MTs and assigns to 1 WG. + // That means, we artifically observe one full wave, but that is not what actually happens + // under the hood. From a theoretical point of view, these distributions change all of the + // computations in Origami. With current implementation, it is hard to capture that + // behaviour analytically. So for now, if the num_wgs is less than the num_mts, we calculate + // numWaves based on the num_mts. Otherwise, we use num_wgs to compute numWaves. + numWaves = num_wgs > num_mts ? 
math::safe_ceil_div(num_wgs, hardware.N_CU) + : math::safe_ceil_div(num_mts, hardware.N_CU); + splitFactor = math::safe_ceil_div(num_wgs, num_mts); + } + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("num_mts", num_mts); + config.logger.log("num_wgs", num_wgs); + config.logger.log("num_active_cus", num_active_cus); + config.logger.log("numWaves", numWaves); + config.logger.log("splitFactor", splitFactor); + config.logger.log("max_cus", max_cus); + } + + return std::make_tuple(num_wgs, num_active_cus, numWaves, splitFactor); +} + +/* ---------------------------------------------------------------------------------------- */ +/* Compute-related functions */ +/* ---------------------------------------------------------------------------------------- */ +// Compute the number of matrix instructions required to compute a single MT_MXMT_NXMT_K tile. +size_t compute_number_matrix_instructions(dim3_t mt, dim3_t mi) { + // Compute the number of Matrix Instructions required in each dim. + size_t num_m_instrs = math::safe_ceil_div(mt.m, mi.m); + size_t num_n_instrs = math::safe_ceil_div(mt.n, mi.n); + size_t num_k_instrs = math::safe_ceil_div(mt.k, mi.k); + + // Total number of matrix instructions. + size_t num_matrix_instrs = num_m_instrs * num_n_instrs * num_k_instrs; + + return num_matrix_instrs; +} + +// Compute arithmic intensity +double arithmetic_intensity(double m, double n, double k, double bytes_per_element) { + // Numerator: 2.0 * m * n * k + // Denominator: (m*n + n*k + m*k) * bytes_per_element + double numerator = 2.0 * m * n * k; + double denominator = (m * n + n * k + m * k) * bytes_per_element; + + return numerator / denominator; +} + +// Computes Emulated arithmetic intensity for TF32 (assumes 3xBF16). 
+double emulated_tf32_arithmetic_intensity(double m, double n, double k, double bytes_per_element) { + // Numerator: 3.0 * 2.0 * m * n * k + // Denominator: (m*n + n*k + m*k) * bytes_per_element + double numerator = 3.0 * 2.0 * m * n * k; + double denominator = (m * n + n * k + m * k) * bytes_per_element; + + return numerator / denominator; +} + +// Compute cvt overhead in x1 tf32 emulation +// TODO: We can generalize the same routine to cover more GEMMs that perform conversion +static inline double compute_cvt_overhead_x1(const problem_t& problem, + const hardware_t& hardware, + const config_t& config) { + // In X1 TF32 GEMMs, we do: + // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) + // v_cvt_pk_bf16_f32 (convert/pack fp32 to bf16) + // ds_write_b64 + // That is, the extra instructions that we need to account for are the two cvt_pk ops + // per wave tile + + // However, these extra ops should not be added up to the overal tile latency becuase + // they can be run in parallel to Matix and Memory operations (given they are not dependent). + // So, We should ideally take L_tile = max{Mem, Comp, Vec (cvt latencies)}. + // Since, Vec latency is not modeled yet, we somehow model that into the current logic + // by scaling according to MFMA latencies and putting some heuristics to model the fact + // that these vector operations can be hidden (read interleaved) with the other memory + // or MFMA instructions. + + // --- Shorthands ----------------------------------------------------------- + const double MT_M = static_cast(config.mt.m); + const double MT_N = static_cast(config.mt.n); + const double MT_K = static_cast(config.mt.k); + + const double MI_M = static_cast(config.mi.m); + const double MI_N = static_cast(config.mi.n); + const double MI_K = static_cast(config.mi.k); + + const auto a_bytes = data_type_to_bytes(problem.a_dtype); + const auto b_bytes = data_type_to_bytes(problem.b_dtype); + + // TODO: Use kernel's actual wavetiles. 
+ const double wave_tile_m = MT_M / 2.0; + const double wave_tile_n = MT_N / 2.0; + const double wave_tile_k = MT_K / MI_K; + + // MFMA count + const double N_MI = (wave_tile_m / MI_M) * (wave_tile_n / MI_N) * wave_tile_k; + const double num_mfma = 1.0 * N_MI; + // Cycle scale per MI + const double L_MI = hardware.get_mi_latency(MI_M, MI_N, MI_K, problem.mi_dtype); + const double mfma_cycles = num_mfma * L_MI; + + // 2) Bytes (per K-slice), using ceil-div to whole bytes + const double bytesA = wave_tile_m * MT_K * static_cast(a_bytes); + const double bytesB = wave_tile_n * MT_K * static_cast(b_bytes); + + // 3) Modeled transfer quanta (128B lines) + // dsA = bytesA / (128 * MI_M) + // dsB = bytesB / (128 * MI_N) + // GR = dsA (global->LDS modeled equal to A-side DS) + const double dsA = (bytesA / 128.0) / MI_M; // LDS->VGPR for A + const double dsB = (bytesB / 128.0) / MI_N; // LDS->VGPR for B + const double GR = dsA; // Global->LDS reads + const double LR = dsA + dsB; // total DS->VGPR + + // 5) Exposed vs hidden CVT + // spare MFMA + const double spare_mfma = std::max(0.0, num_mfma - LR - GR); + // 2 cvt per each ds_write (this for SS_BSS -- should be revised for other datatypes) + // Each cvt has a latency of four. It is scaled by the MI Latency + // Note: change 16.0 based on mi_data_type if we want to generalize this for all + // casting GEMMs. + const double cvt = (2.0 * 4.0 / 16.0 * L_MI) * LR; + // cvt ops are interleaved in main loop and don't stall matrix or memory units. + // Heuristically, we set + const double H = (8.0 / 16.0 * L_MI) * spare_mfma + (4.0 / 16.0) * L_MI * (LR + GR); + const double overhead = std::max(cvt - H, 0.0); + + return overhead; +} + +// Compute cvt overhead in tf32 emulation +static inline double compute_cvt_overhead(const problem_t& problem, + const hardware_t& hardware, + const config_t& config) { + // Wave tile sizes + // TODO: Use kernel's actual wavetiles. 
+ const double wave_tile_m = config.mt.m / 2.0; + const double wave_tile_n = config.mt.n / 2.0; + const double wave_tile_k = config.mt.k / config.mi.k; + + // MFMA count and cycles + const double N_MI = (wave_tile_m / config.mi.m) * (wave_tile_n / config.mi.n) * wave_tile_k; + + // TF32 emu: 3× BF16 MI issue slots + const double num_mfma = 3.0 * static_cast(N_MI); + + // Cycle scale per MI (use BF16 MI latency as the basic timing quantum) + const double L_MI_bf16 = + hardware.get_mi_latency(config.mi.m, config.mi.n, config.mi.k, data_type_t::BFloat16); + // const double mfma_cycles = num_mfma * L_MI_bf16; + + // 2) Bytes (per K-slice), using ceil-div to whole bytes + int a_bytes = data_type_to_bytes(problem.a_dtype); + int b_bytes = data_type_to_bytes(problem.b_dtype); + + const double bytesA = static_cast(wave_tile_m) * config.mt.k * a_bytes; + const double bytesB = static_cast(wave_tile_n) * config.mt.k * b_bytes; + + // const double mt_bytesA + // = static_cast(MT_M) * MT_K * safe_ceil_div(element_size_A, 8); + + // 3) Modeled transfer quanta (128B lines) + // dsA = bytesA / (128 * MI_M) + // dsB = bytesB / (128 * MI_N) + // GR = dsA (global->LDS modeled equal to A-side DS) + const double dsA = (bytesA / 128.0) / static_cast(config.mi.m); // LDS->VGPR for A + const double dsB = (bytesB / 128.0) / static_cast(config.mi.n); // LDS->VGPR for B + const double GR = dsA; // Global->LDS reads + const double LR = dsA + dsB; // total DS->VGPR + + // 4) Heuristic cycle weights (scaled to MI latency). + // Preserves your A=104, B=8, C=4 when L_MI_bf16 == 16. + // 24 vector instructions per 2 ds_reads (16x16x32) + // 24 vector instructions per 2 ds_reads for A and for B. + // 3 instructions per fp32 value read; number ds_read * size + const double A = (104.0 / 16.0) * L_MI_bf16; // CVT per LR-sized chunk (DS->VGPR) + const double B = (8.0 / 16.0) * L_MI_bf16; // hidden per spare MFMA slot + // MI16: 16 - 4 (12 cycles), for those 4 cycles, VGPRs are locked. 
8 cycles to do anything. + const double C = (4.0 / 16.0) * L_MI_bf16; // hidden per (LR+GR) slot // MI16 + // 32 cycles (mfma), 4 cycles, 28, 4 vgpr lock, 24 cycles left. + // 24: 6 conv instructions, 3 ds_reads, ~6 grs + + // 5) Exposed vs hidden CVT + const double spare_mfma = std::max(0.0, num_mfma - LR - GR); + const double cvt = A * dsA; // only DS->VGPR contributes CVT + const double H = B * spare_mfma + C * (LR + GR); // hidden cycles + const double overhead = std::max(cvt - H, 0.0); + + // 6) Efficiency + // const double denom = mfma_cycles + overhead; + // const double eff = (denom > 0.0) ? (mfma_cycles / denom) : 1; + + return overhead; +} + +// Determine the compute latency per MT_MxMT_NxMT_K Macro Tile (L_MT). +size_t compute_mt_compute_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config) { + // Compute the number of matrix instructions + size_t N_MI = compute_number_matrix_instructions(config.mt, config.mi); + // Latency of a single MT_MxMT_NxMT_k tile is the latency of one MI multiplied by + // number of MI per MT_MxMT_NxMT_k. 
+ size_t L_MI = hardware.get_mi_latency(config.mi.m, config.mi.n, config.mi.k, problem.mi_dtype); + + // size_t mt_arith = arithmetic_intensity(MT_M, MT_N, MT_K, 2); + // printf("MT_M:%d MT_N:%d MT_K:%d arith:%d\n", MT_M, MT_N, MT_K, mt_arith); + // size_t arith = ((M * N * K * 2) / (M * K + N * K + M * N)); + size_t L_MT = L_MI * N_MI; + + return L_MT; +} + +/* ---------------------------------------------------------------------------------------- */ +/* Memory-related functions */ +/* ---------------------------------------------------------------------------------------- */ +// Check if MT fits in LDS +bool check_lds_capacity(const hardware_t& hardware, + dim3_t mt, + data_type_t a_dtype, + data_type_t b_dtype) { + // A and B size + size_t a_loads_in_bytes = mt.mk() * data_type_to_bytes(a_dtype); + size_t b_loads_in_bytes = mt.nk() * data_type_to_bytes(b_dtype); + // Size of those in bytes + size_t LDS_usage = a_loads_in_bytes + b_loads_in_bytes; + + if (LDS_usage > hardware.lds_capacity) { + return false; // Exceeds LDS capacity + } else { + return true; // Within LDS capacity + } +} + +// Compute limited achievable memory bandwidth based on active CUs +double compute_mem_bw_from_occupancy(const hardware_t& hardware, size_t num_active_cus) { + const double CUs = static_cast(num_active_cus); + + if (num_active_cus > hardware.N_CU) return 1.0; + + const double bw_limited = std::get<0>(hardware.mem_bw_per_wg_coefficients) * CUs * CUs + + std::get<1>(hardware.mem_bw_per_wg_coefficients) * CUs + + std::get<2>(hardware.mem_bw_per_wg_coefficients); + + return std::min(bw_limited, 1.0); +} + +double estimate_l2_hit(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t splitting_factor) { + // Use size_t for dimensions and counts to ensure type safety. 
+ const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); + const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); + const size_t total_workgroups = workgroups_m * workgroups_n; + + // Concurrently executing workgroups are limited by the number of CUs.a + const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); + if (concurrent_workgroups == 0) + throw std::runtime_error("#Workgroups is zero in estimate l2 hit"); + + // Number of CUs that might share the same K-tiles, adjusted for K-splitting. + // This affects contention on the L2 cache partitions (XCDs). + const size_t effective_cus = math::safe_ceil_div(concurrent_workgroups, splitting_factor); + const size_t cu_per_xcd = + std::max(math::safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); + + // Initial guess for the L2 tile dimensions (a tile of workgroups). + size_t l2_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); + size_t l2_tile_m = math::safe_ceil_div(cu_per_xcd, l2_tile_n); + + // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. + if (l2_tile_m > workgroups_m) { + size_t num_wraps = (l2_tile_m / workgroups_m); + l2_tile_n += (num_wraps * config.workgroup_mapping); + l2_tile_m = workgroups_m; + } + + // Clamp initial tile dimensions to the actual grid size. + l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); + l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); + + // Calculate memory footprint in bytes. 
+ const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); + const size_t b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); + auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { + size_t a_footprint = tile_m * config.mt.mk() * a_bytes; + size_t b_footprint = tile_n * config.mt.nk() * b_bytes; + return a_footprint + b_footprint; + }; + + // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. + // This is more robust than shrinking only one dimension. + while (calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) { + if (l2_tile_m > 1 && l2_tile_m >= l2_tile_n) { + l2_tile_m--; + } else if (l2_tile_n > 1) { + l2_tile_n--; + } else { + // Cannot shrink further. + break; } - - /* - * This heuristic models data reuse by defining a "tile of workgroups" that can fit - * its working set (portions of matrices A and B) into the L2 cache. The hit rate - * is the ratio of reused data reads to total data reads within this tile. - * - * @param M, N, K, batch Problem dimensions. - * @param MT_M, MT_N, MT_K Macro-tile dimensions. - * @param element_size Size of a single data element in bits. - * @param WGM Workgroup mapping size (typically 64). - * @param splittingFactor K-splitting factor, reduces L2 contention. - * @return Estimated L2 hit rate (0.0 to 1.0). + } + + // Uncached reads are the first read of each unique element within the L2 tile. + const long long uncached_A_reads = static_cast(l2_tile_m) * config.mt.mk(); + const long long uncached_B_reads = static_cast(l2_tile_n) * config.mt.nk(); + const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; + + // Total reads are the sum of all reads performed by all workgroups in the L2 tile. + // Matrix A is reused l2_tile_n times, Matrix B is reused l2_tile_m times. 
+ const long long total_A_reads = uncached_A_reads * l2_tile_n; + const long long total_B_reads = uncached_B_reads * l2_tile_m; + const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); + + const long long cached_reads = total_reads - total_uncached_reads; + + double l2_hit_rate = static_cast(cached_reads) / static_cast(total_reads); + + // Final clamping and logging. + if (get_runtime_options(config).debug_enabled) { + config.logger.log("L2Tile_M", l2_tile_m); + config.logger.log("L2Tile_N", l2_tile_n); + config.logger.log("TotalWorkgroups", total_workgroups); + config.logger.log("ConcurrentWorkgroups", concurrent_workgroups); + } + + // Clamp the hit rate to be within a realistic [0, 1] range. + return std::max(0.0, std::min(l2_hit_rate, 1.0)); +} + +// Estimate MALL hit-rate +double estimate_mall_hit(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t num_active_cus, + size_t splitting_factor) { + const size_t workgroups_m = math::safe_ceil_div(problem.size.m, config.mt.m); + const size_t workgroups_n = math::safe_ceil_div(problem.size.n, config.mt.n); + + if (num_active_cus == 0) throw std::runtime_error("Number of Active CUs was 0"); + + // --- Initial Tile Sizing based on Concurrency --- + // Use ceiling division for a more accurate initial guess. + size_t mall_tile_m = + math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); + size_t mall_tile_n = std::min(static_cast(config.workgroup_mapping), workgroups_n); + + // Handle wrap-around case if the tile is taller than the grid. + if (mall_tile_m > workgroups_m) { + size_t num_wraps = mall_tile_m / workgroups_m; + mall_tile_n += (num_wraps * config.workgroup_mapping); + mall_tile_m = workgroups_m; + } + + // Clamp initial tile dimensions to the actual grid size. 
+ mall_tile_m = std::max(std::min(workgroups_m, mall_tile_m), static_cast(1)); + mall_tile_n = std::max(std::min(workgroups_n, mall_tile_n), static_cast(1)); + + // --- CRITICAL: Shrink tile to fit into MALL Capacity --- + const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); + const size_t b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); + + auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { + size_t a_footprint = tile_m * config.mt.mk() * a_bytes; + size_t b_footprint = tile_n * config.mt.nk() * b_bytes; + return a_footprint + b_footprint; + }; + + // --- Calculate Hit Rate based on the final, capacity-aware tile size --- + const long long uncached_A_reads = static_cast(mall_tile_m) * config.mt.mk(); + const long long uncached_B_reads = static_cast(mall_tile_n) * config.mt.nk(); + const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; + + const long long total_A_reads = uncached_A_reads * mall_tile_n; + const long long total_B_reads = uncached_B_reads * mall_tile_m; + const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); + + const long long cached_reads = total_reads - total_uncached_reads; + + double mall_hit_rate = static_cast(cached_reads) / static_cast(total_reads); + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("MallTile_M", mall_tile_m); + config.logger.log("MallTile_N", mall_tile_n); + config.logger.log("MallFootprint_Bytes", calculate_footprint(mall_tile_m, mall_tile_n)); + } + + // Clamp the final result to the valid [0, 1] range. + return std::max(0.0, std::min(mall_hit_rate, 1.0)); +} + +/** + * @brief L2 hit rate from a global (problem-wide) perspective using the refactored API. + * Computes in BYTES to correctly handle differing A/B dtypes. 
*/ - double estimate_l2_hit(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size, - int WGM, - size_t splittingFactor) - { - // Use size_t for dimensions and counts to ensure type safety. - const size_t workgroups_m = safe_ceil_div(M, MT_M); - const size_t workgroups_n = safe_ceil_div(N, MT_N); - const size_t total_workgroups = workgroups_m * workgroups_n; - - // Concurrently executing workgroups are limited by the number of CUs.a - const size_t concurrent_workgroups = std::min(total_workgroups, hardware.N_CU); - if(concurrent_workgroups == 0) - throw std::runtime_error("#Workgroups is zero in estimate l2 hit"); - - // Number of CUs that might share the same K-tiles, adjusted for K-splitting. - // This affects contention on the L2 cache partitions (XCDs). - const size_t effective_cus = safe_ceil_div(concurrent_workgroups, splittingFactor); - const size_t cu_per_xcd = std::max(safe_ceil_div(effective_cus, hardware.NUM_XCD), static_cast(1)); - - // Initial guess for the L2 tile dimensions (a tile of workgroups). - size_t l2_tile_n = std::min(static_cast(WGM), workgroups_n); - size_t l2_tile_m = safe_ceil_div(cu_per_xcd, l2_tile_n); - - // Handle wrap-around case: if the tile is taller than the grid, wrap it to be wider. - if(l2_tile_m > workgroups_m) - { - size_t num_wraps = (l2_tile_m / workgroups_m); - l2_tile_n += (num_wraps * WGM); - l2_tile_m = workgroups_m; - } - - // Clamp initial tile dimensions to the actual grid size. - l2_tile_m = std::max(std::min(workgroups_m, l2_tile_m), static_cast(1)); - l2_tile_n = std::max(std::min(workgroups_n, l2_tile_n), static_cast(1)); - - // Calculate memory footprint in bytes. 
- const size_t element_bytes = safe_ceil_div(element_size, 8); - auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { - size_t a_footprint = tile_m * MT_M * MT_K * element_bytes; - size_t b_footprint = tile_n * MT_N * MT_K * element_bytes; - return a_footprint + b_footprint; - }; - - // Symmetrically shrink the L2 tile until it fits in the L2 cache capacity. - // This is more robust than shrinking only one dimension. - while(calculate_footprint(l2_tile_m, l2_tile_n) > hardware.L2_capacity) - { - if(l2_tile_m > 1 && l2_tile_m >= l2_tile_n) - { - l2_tile_m--; - } - else if(l2_tile_n > 1) - { - l2_tile_n--; - } - else - { - // Cannot shrink further. - break; - } - } - - // Uncached reads are the first read of each unique element within the L2 tile. - const long long uncached_A_reads = static_cast(l2_tile_m) * MT_M * MT_K; - const long long uncached_B_reads = static_cast(l2_tile_n) * MT_N * MT_K; - const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; - - // Total reads are the sum of all reads performed by all workgroups in the L2 tile. - // Matrix A is reused l2_tile_n times, Matrix B is reused l2_tile_m times. - const long long total_A_reads = uncached_A_reads * l2_tile_n; - const long long total_B_reads = uncached_B_reads * l2_tile_m; - const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); - - const long long cached_reads = total_reads - total_uncached_reads; - - double l2_hit_rate = static_cast(cached_reads) / static_cast(total_reads); - - // Final clamping and logging. - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("L2Tile_M", l2_tile_m); - hardware.log_debug("L2Tile_N", l2_tile_n); - hardware.log_debug("TotalWorkgroups", total_workgroups); - hardware.log_debug("ConcurrentWorkgroups", concurrent_workgroups); - } - - // Clamp the hit rate to be within a realistic [0, 1] range. 
- return std::max(0.0, std::min(l2_hit_rate, 1.0)); +double compute_l2_hit_rate_global(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t l2_capacity_bytes) { + // --- Hardware Parameters (as requested, defined locally) --- + // You would normally get l2_capacity_bytes from your hardware_t struct. + if (l2_capacity_bytes == 0) throw std::runtime_error("L2 Capacity is zero"); + + // 1. Calculate the grid dimensions in terms of macro-tiles + const size_t grid_m = math::safe_ceil_div(problem.size.m, config.mt.m); + const size_t grid_n = math::safe_ceil_div(problem.size.n, config.mt.n); + + if (grid_m == 0 || grid_n == 0) + throw std::runtime_error("estimate_l2_hit grid dimensions can not be zero"); + + // 2. Calculate the working set size for one full pass of global reuse + // This is the data needed by one full column of CUs (for A) and one full row (for B). + const double a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); + const double b_bytes = static_cast(data_type_to_bytes(problem.b_dtype)); + + const double a_working_set = static_cast(grid_m * config.mt.mk()) * a_bytes; + const double b_working_set = static_cast(grid_n * config.mt.nk()) * b_bytes; + const double total_working_set_bytes = a_working_set + b_working_set; + + // 3. CRUCIAL: Check if the working set fits in the L2 cache. + // If it doesn't, the global reuse pattern is broken by capacity misses, + // and the hit rate will be very low. + if (total_working_set_bytes > l2_capacity_bytes) { + // Return a floor value for the hit rate. The exact value can be tuned, + // but it should be low to indicate that the ideal reuse is not possible. + return 0.1; // 10% hit rate + } + + // 4. 
If it fits, calculate the idealized global hit rate + // Total reads if nothing was cached + const double total_A_reads = static_cast(grid_m * grid_n * config.mt.mk()); + const double total_B_reads = static_cast(grid_m * grid_n * config.mt.nk()); + + // Uncached reads are the first-time fetches for each row/column + const double uncached_A_reads = + static_cast(grid_m * config.mt.mk()); // One full column fetches A + const double uncached_B_reads = + static_cast(grid_n * config.mt.nk()); // One full row fetches B + + const double total_reads = total_A_reads + total_B_reads; + if (total_reads == 0) return 1.0; // No reads, perfect hit rate. + + const double cached_reads = + (total_A_reads - uncached_A_reads) + (total_B_reads - uncached_B_reads); + + return cached_reads / total_reads; +} + +inline size_t round_up_mul(size_t x, size_t m) { return (x + m - 1) / m * m; } + +size_t round_elements_to_128B(size_t elements, size_t element_size_bits) { + const size_t transaction_bits = 128u * 8u; // 1024 + const size_t g = std::gcd(element_size_bits, transaction_bits); + const size_t E_block = transaction_bits / g; // elements per 128B-aligned chunk + return round_up_mul(elements, E_block); +} + +// Determine the memory latency +double compute_memory_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t num_active_cus, + size_t splitting_factor) { + // Extract parameters from structured types + const auto a_bytes = data_type_to_bytes(problem.a_dtype); + const auto b_bytes = data_type_to_bytes(problem.b_dtype); + const auto a_bits = datatype_to_bits(problem.a_dtype); + const auto b_bits = datatype_to_bits(problem.b_dtype); + size_t batch = problem.batch; + + const bool a_trans = (problem.a_transpose == transpose_t::T); + const bool b_trans = (problem.b_transpose == transpose_t::T); + + const size_t MT_M = config.mt.m; + const size_t MT_N = config.mt.n; + const size_t MT_K = config.mt.k; + + // 1) Estimate L2 hit-rate + double 
H_mem1 = estimate_l2_hit(problem, hardware, config, splitting_factor); + + // Global cap on L2 hit-rate (prevents impossible cache residency claims) + // (Assumes capacity is given in KiB, convert to bytes) + double H_mem1_global = + compute_l2_hit_rate_global(problem, hardware, config, hardware.L2_capacity * 1024); + + H_mem1 = std::min(H_mem1, H_mem1_global); + + if (H_mem1 == 0) { H_mem1 = 0.5; } + + // 2) Estimate mall hit-rate + double H_mem2 = estimate_mall_hit(problem, hardware, config, num_active_cus, splitting_factor); + + // 3) Total loads are loads from A and loads from B + size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, a_bits); + size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, a_bits); + size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, a_bits); + + if (!a_trans && !b_trans) { + MT_N_rounded_128bytes = MT_N; + MT_K_rounded_128bytes = MT_K; + } else if (a_trans && !b_trans) { + MT_M_rounded_128bytes = MT_M; + MT_N_rounded_128bytes = MT_N; + } else if (!a_trans && b_trans) { + MT_K_rounded_128bytes = MT_K; + } + + size_t Ld_A_value = MT_M_rounded_128bytes * MT_K_rounded_128bytes; + size_t Ld_B_value = MT_N_rounded_128bytes * MT_K_rounded_128bytes; + size_t Ld_CU_bytes = (Ld_A_value * static_cast(a_bytes)) // A Bytes + + (Ld_B_value * static_cast(b_bytes)); // B Bytes + + // Logic for block scaled datatypes (Assuming BS=32 and 8-bit scales) + // TODO This is technically wrong, need separate flag to enable MX so we can differentiate FP8 + // and MX8 + if (a_bits < 8 && problem.a_mx_block_size != 0) { + // Number of scales per tile + size_t num_scales_A = math::safe_ceil_div(config.mt.mk(), problem.a_mx_block_size); + Ld_CU_bytes += num_scales_A; // One Byte per scale + } + if (b_bits < 8 && problem.b_mx_block_size != 0) { + // Number of scales per tile + size_t num_scales_B = math::safe_ceil_div(config.mt.nk(), problem.b_mx_block_size); + Ld_CU_bytes += num_scales_B; // One Byte per scale + } + + // 4) total loads by all 
CUs + double total_Ld = Ld_CU_bytes * static_cast(num_active_cus); + + // 5) mem1‐limited factor (simple linear model) + double mem1_bw_limited = static_cast(num_active_cus) / static_cast(hardware.N_CU); + double limited_mem1_bw = (hardware.mem1_perf_ratio * mem1_bw_limited); + + // 6) mem1 latency + double L_mem_mem1 = (limited_mem1_bw > 0) ? (total_Ld / (limited_mem1_bw)) : 0.0; + + // 7) mem2‐limited from occupancy (Can't Issue enough load/stores) + double bw_limited = compute_mem_bw_from_occupancy(hardware, num_active_cus); + + // 8) loads that reach each level + double Ld_mem2 = (1.0 - H_mem1) * total_Ld; + double Ld_MEM = (1.0 - H_mem2) * Ld_mem2; + + // 9) enforce whole‐problem minimum loads when we can fit M/N in the CUs. + // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). + size_t grid_m = math::safe_ceil_div(problem.size.m, MT_M); + size_t grid_n = math::safe_ceil_div(problem.size.n, MT_N); + size_t mall_m = + math::safe_ceil_div(num_active_cus, static_cast(config.workgroup_mapping)); + size_t mall_n = std::min(static_cast(config.workgroup_mapping), grid_n); + // Handle wrap-around case + if (mall_m > grid_m) { + size_t num_wraps = (mall_m / grid_m); + mall_n += (num_wraps * config.workgroup_mapping); + mall_m = grid_m; + } + // Clamp tile dimensions + mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); + mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); + // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. + double min_load = static_cast((mall_m * config.mt.mk() * static_cast(a_bytes)) + + (mall_n * config.mt.nk() * static_cast(b_bytes))) * + batch; // Apply batching to the minimum load itself. + // The actual loads cannot be less than this physical minimum. 
+ Ld_MEM = std::max(Ld_MEM, min_load); + Ld_mem2 = std::max(Ld_mem2, min_load); + + // 10) mem2 latency + double limited_mem2_bw = (hardware.mem2_perf_ratio * bw_limited); + double L_mem_mem2 = (limited_mem2_bw > 0) ? (Ld_mem2 / limited_mem2_bw) : 0.0; + + // 11) MEM latency + double limited_mem_bw = (hardware.mem3_perf_ratio * bw_limited); + double L_mem_MEM = (limited_mem_bw > 0) ? (Ld_MEM / limited_mem_bw) : 0.0; + L_mem_MEM += 200; // Load Latency + + // 12) pick the worst‐case bound + double L_mem = std::max({L_mem_mem1, L_mem_mem2, L_mem_MEM}); + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("mem1_perf_ratio", hardware.mem1_perf_ratio); + config.logger.log("mem2_perf_ratio", hardware.mem2_perf_ratio); + config.logger.log("mem3_perf_ratio", hardware.mem3_perf_ratio); + config.logger.log("mem_bw_per_wg_coefficients(0)", + std::get<0>(hardware.mem_bw_per_wg_coefficients)); + config.logger.log("mem_bw_per_wg_coefficients(1)", + std::get<1>(hardware.mem_bw_per_wg_coefficients)); + config.logger.log("mem_bw_per_wg_coefficients(2)", + std::get<2>(hardware.mem_bw_per_wg_coefficients)); + config.logger.log("H_mem1 (mem1 hit ratio)", H_mem1); + config.logger.log("H_mem2 (mem2 hit ratio)", H_mem2); + config.logger.log("Total Load (bytes)", total_Ld); + config.logger.log("Ld_mem2 (bytes)", Ld_mem2); + config.logger.log("Ld_MEM (bytes)", Ld_MEM); + config.logger.log("L_mem_mem1 (cycles)", L_mem_mem1); + config.logger.log("L_mem_mem2 (cycles)", L_mem_mem2); + config.logger.log("L_mem_MEM (cycles)", L_mem_MEM); + config.logger.log("MT_K % 128 bytes", MT_K * static_cast(b_bytes) % 128); + config.logger.log("MT_M % 128 bytes", MT_M * static_cast(a_bytes) % 128); + config.logger.log("MT_N % 128 bytes", MT_N * static_cast(b_bytes) % 128); + config.logger.log( + "MT_N % 128 + MT_M % 128 bytes", + (MT_M * static_cast(a_bytes) % 128) + MT_N * static_cast(b_bytes) % 128); + config.logger.log( + "MT_N % 64 + MT_M % 64 bytes", + (MT_M * static_cast(a_bytes) % 
64) + MT_N * static_cast(b_bytes) % 64); + config.logger.log("MT_K % 64 bytes", MT_K * static_cast(b_bytes) % 64); + config.logger.log("MT_M % 64 bytes", MT_M * static_cast(a_bytes) % 64); + config.logger.log("MT_N % 64 bytes", MT_N * static_cast(b_bytes) % 64); + config.logger.log("Tile Arithmetic Intensity", + MT_M * MT_N * MT_K / (MT_M * MT_K + MT_N * MT_K)); + } + + return L_mem; +} + +/* ---------------------------------------------------------------------------------------- */ +/* Tile-related functions */ +/* ---------------------------------------------------------------------------------------- */ +double compute_tile_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t num_active_cus, + size_t splitting_factor) { + // Extract parameters from structured types + const size_t K = problem.size.k; + size_t batch = problem.batch; + + const size_t MT_M = config.mt.m; + const size_t MT_N = config.mt.n; + const size_t MT_K = config.mt.k; + + const auto a_bits = datatype_to_bits(problem.a_dtype); + const auto b_bits = datatype_to_bits(problem.b_dtype); + const size_t a_bytes = static_cast(data_type_to_bytes(problem.a_dtype)); + const size_t d_bytes = static_cast(data_type_to_bytes(problem.d_dtype)); + + // 1) Compute per-tile latencies + double L_compute = compute_mt_compute_latency(problem, hardware, config); + + double L_mem = + compute_memory_latency(problem, hardware, config, num_active_cus, splitting_factor); + + // TODO Does work utilization need to be 128-byte rounded for a cache line? + double utilization = calculate_work_utilization(problem, config); + double output_utilization = calculate_output_utilization(problem, config, 1UL); + // The effective latency per useful operation increases as utilization drops. + // This penalty affects BOTH compute and memory bounds for the tile's core work. + double effective_tile_penalty = (utilization > 1e-9) ? 
(1.0 / (utilization)) : 1.0; + double output_utilization_penalty = + (output_utilization > 1e-9) ? (1.0 / (output_utilization)) : 1.0; + // 2) Work-group setup & iteration latencies + double L_WG_setup = 1; // WG_setup_Latency + + // 3) Prologue: 2.2× memory latency + double L_prologue = 1.5 * L_mem; // 1.5 chosen emprically + + // L_compute *= std::max(L_compute, L_LDS); + + // 4) Epilogue: writes from all active CUs with limited bandwidth + double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, num_active_cus); + double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; + size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, datatype_to_bits(problem.a_dtype)); + + double L_epilogue = (static_cast(num_active_cus / splitting_factor) * + MT_M_rounded_128bytes * MT_N * static_cast(d_bytes)) / + mem_bw_occ_limited; + // One compute iteration happens in the prologue + L_epilogue += L_compute * effective_tile_penalty; + // Epilogue and Prologue overhead are reduced with higher occupancy kernels. + int grid_m = static_cast(math::safe_ceil_div(problem.size.m, MT_M)); + int grid_n = static_cast(math::safe_ceil_div(problem.size.n, MT_N)); + + size_t real_occupancy = + std::min(std::max(config.occupancy, static_cast(1)), + static_cast(math::safe_ceil_div(grid_m * grid_n * batch * splitting_factor, + hardware.N_CU))); // Number of WGs per CU. + + L_prologue = L_prologue * pow(0.95, real_occupancy); // Factor chosen empirically + L_epilogue = L_epilogue * pow(0.95, real_occupancy); // Factor chosen empirically + // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N + // tiles to coherent memory + if (splitting_factor > 1) { + size_t n_partials = splitting_factor - 1; + + // Only the reduction CU reads from all splits. 
+ double partial_read_bytes = + grid_m * grid_n * n_partials * MT_M_rounded_128bytes * MT_N * static_cast(d_bytes); + + // All CUs write (once for each partial, and once by the reduction CU for the output.) + double partial_write_bytes = + grid_m * grid_n * MT_M_rounded_128bytes * MT_N * static_cast(d_bytes); + + double partial_readwrite_bytes = partial_read_bytes + partial_write_bytes; + + // 64 Threads active in a SIMD. Exposed to at least latency of reducing splitting_factor + // tiles. + double partial_adds = + (static_cast(config.mt.mn()) * static_cast(splitting_factor)) / (64); + + double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); + L_epilogue += L_reduce + partial_adds + 10000; + } + // 4'') tf32 emu has some more overhead + double L_cvt = 0; + if ((problem.mi_dtype == data_type_t::XFloat32) && + (hardware.arch == hardware_t::architecture_t::gfx950)) { + L_cvt = compute_cvt_overhead(problem, hardware, config); + } else if ((a_bits == 32) && (b_bits == 32) && (problem.mi_dtype == data_type_t::BFloat16) && + (hardware.arch == hardware_t::architecture_t::gfx950)) // SS_BSS on GFX950 + { + L_cvt = compute_cvt_overhead_x1(problem, hardware, config); + } + + // 5) Single-tile latency (always additive) + // Calculate the fraction of the work that is useful (not padding). 
+ + // 5) Single-tile latency (apply penalty after finding the bottleneck) + double L_tile_single = (std::max(L_compute, L_mem) * effective_tile_penalty) + L_cvt; + L_prologue *= effective_tile_penalty; + // 6) Number of K-iterations (excluding epilogue), at least 1 + // long num_iter = static_cast(((K + MT_K - 1) / MT_K)) - 1; + // num_iter = std::ceil(num_iter / splitting_factor); + // num_iter = std::max(num_iter, 1L); + const long k_per_split = static_cast(math::safe_ceil_div(K, splitting_factor)); + long num_iter = + std::max(static_cast(math::safe_ceil_div(static_cast(k_per_split), MT_K) - 1), + static_cast(1)); + // Zero Padding in the K dimension on last iteration + if (K % MT_K != 0) { + const double problem_k_quant = static_cast(K % MT_K) / static_cast(K); + L_epilogue += problem_k_quant * 50000; // Scale by remainder proportion of problem. 50k cycle + // penalty if have to zero pad all except 1. + //(Scale Determined Empirically) + } + // L_epilogue *= output_utilization_penalty; + + // 7) Total tile latency + double L_tile_total = + (L_tile_single * static_cast(num_iter)) + L_prologue + L_epilogue * 2 + L_WG_setup + + (500 * static_cast( + num_iter)); // 7 instructions (each with 4 cycles) at the end of the loop + + if (get_runtime_options(config).debug_enabled) { + double problem_k_quant = ((K % MT_K) / (double)K); + config.logger.log("Iteration Compute Latency", L_compute); + config.logger.log("L_mem", L_mem); + config.logger.log("L_cvt", L_cvt); + config.logger.log("L_tile_single", L_tile_single); + config.logger.log("num_iter", num_iter); + config.logger.log("L_prologue", L_prologue); + config.logger.log("L_epilogue", L_epilogue); + config.logger.log("L_tile_total", L_tile_total); + config.logger.log("Effective Tile Penalty", effective_tile_penalty); + config.logger.log("Problem K quant", problem_k_quant); + config.logger.log("K quant overhead", (problem_k_quant * 50000)); + config.logger.log("Problem Tile Quant", utilization); + 
config.logger.log("Real Occupancy", utilization); + config.logger.log("Output Utilization Penalty", output_utilization_penalty); + config.logger.log("Output Utilization", output_utilization); + std::string bound_source; + if (L_compute >= L_mem) { + L_tile_single = L_compute + L_cvt; + bound_source = "Compute"; + } else { + L_tile_single = L_mem + L_cvt; + bound_source = "Memory"; } - - // Estimate MALL hit-rate - double estimate_mall_hit(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor) - { - const size_t workgroups_m = safe_ceil_div(M, MT_M); - const size_t workgroups_n = safe_ceil_div(N, MT_N); - - if(numActiveCUs == 0) - throw std::runtime_error("Number of Active CUs was 0"); - - // --- Initial Tile Sizing based on Concurrency --- - // Use ceiling division for a more accurate initial guess. - size_t mall_tile_m = safe_ceil_div(numActiveCUs, static_cast(WGM)); - size_t mall_tile_n = std::min(static_cast(WGM), workgroups_n); - - // Handle wrap-around case if the tile is taller than the grid. - if(mall_tile_m > workgroups_m) - { - size_t num_wraps = mall_tile_m / workgroups_m; - mall_tile_n += (num_wraps * WGM); - mall_tile_m = workgroups_m; - } - - // Clamp initial tile dimensions to the actual grid size. 
- mall_tile_m = std::max(std::min(workgroups_m, mall_tile_m), static_cast(1)); - mall_tile_n = std::max(std::min(workgroups_n, mall_tile_n), static_cast(1)); - - // --- CRITICAL: Shrink tile to fit into MALL Capacity --- - const size_t element_bytes = safe_ceil_div(element_size, 8); - auto calculate_footprint = [&](size_t tile_m, size_t tile_n) { - size_t a_footprint = tile_m * MT_M * MT_K * element_bytes; - size_t b_footprint = tile_n * MT_N * MT_K * element_bytes; - return a_footprint + b_footprint; - }; - - // --- Calculate Hit Rate based on the final, capacity-aware tile size --- - const long long uncached_A_reads = static_cast(mall_tile_m) * MT_M * MT_K; - const long long uncached_B_reads = static_cast(mall_tile_n) * MT_N * MT_K; - const long long total_uncached_reads = uncached_A_reads + uncached_B_reads; - - const long long total_A_reads = uncached_A_reads * mall_tile_n; - const long long total_B_reads = uncached_B_reads * mall_tile_m; - const long long total_reads = std::max(total_A_reads + total_B_reads, 1LL); - - const long long cached_reads = total_reads - total_uncached_reads; - - double mall_hit_rate = static_cast(cached_reads) / static_cast(total_reads); - - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("MallTile_M", mall_tile_m); - hardware.log_debug("MallTile_N", mall_tile_n); - hardware.log_debug("MallFootprint_Bytes", - calculate_footprint(mall_tile_m, mall_tile_n)); - } - - // Clamp the final result to the valid [0, 1] range. - return std::max(0.0, std::min(mall_hit_rate, 1.0)); - } - - /** - @brief Computes the L2 hit rate from a global, - problem - wide perspective. - **/ - double compute_l2_hit_rate_global(size_t M, - size_t N, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size, - size_t l2_capacity_bytes) - { - // --- Hardware Parameters (as requested, defined locally) --- - // You would normally get l2_capacity_bytes from your hardware_t struct. 
- if(l2_capacity_bytes == 0) - throw std::runtime_error("L2 Capacity is zero"); - ; - - // 1. Calculate the grid dimensions in terms of macro-tiles - const size_t grid_m = safe_ceil_div(M, MT_M); - const size_t grid_n = safe_ceil_div(N, MT_N); - - if(grid_m == 0 || grid_n == 0) - throw std::runtime_error("estimate_l2_hit grid dimensions can not be zero"); - ; - - // 2. Calculate the working set size for one full pass of global reuse - // This is the data needed by one full column of CUs (for A) and one full row (for B). - const double bytes_per_element = static_cast(element_size) / 8.0; - const double a_working_set = static_cast(grid_m * MT_M * MT_K) * bytes_per_element; - const double b_working_set = static_cast(grid_n * MT_N * MT_K) * bytes_per_element; - const double total_working_set_bytes = a_working_set + b_working_set; - - // 3. CRUCIAL: Check if the working set fits in the L2 cache. - // If it doesn't, the global reuse pattern is broken by capacity misses, - // and the hit rate will be very low. - if(total_working_set_bytes > l2_capacity_bytes) - { - // Return a floor value for the hit rate. The exact value can be tuned, - // but it should be low to indicate that the ideal reuse is not possible. - return 0.1; // 10% hit rate - } - - // 4. If it fits, calculate the idealized global hit rate - // Total reads if nothing was cached - const double total_A_reads = static_cast(grid_m * grid_n * MT_M * MT_K); - const double total_B_reads = static_cast(grid_m * grid_n * MT_N * MT_K); - - // Uncached reads are the first-time fetches for each row/column - const double uncached_A_reads - = static_cast(grid_m * MT_M * MT_K); // One full column fetches A - const double uncached_B_reads - = static_cast(grid_n * MT_N * MT_K); // One full row fetches B - - const double total_reads = total_A_reads + total_B_reads; - if(total_reads == 0) - return 1.0; // No reads, perfect hit rate. 
- - const double cached_reads - = (total_A_reads - uncached_A_reads) + (total_B_reads - uncached_B_reads); - - return cached_reads / total_reads; - } - - inline size_t round_up_mul(size_t x, size_t m) - { - return (x + m - 1) / m * m; - } - - size_t round_elements_to_128B(size_t elements, size_t element_size_bits) - { - const size_t transaction_bits = 128u * 8u; // 1024 - const size_t g = std::gcd(element_size_bits, transaction_bits); - const size_t E_block = transaction_bits / g; // elements per 128B-aligned chunk - return round_up_mul(elements, E_block); - } - - // Determine the memory latency - double compute_memory_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size_A, - size_t element_size_B, - size_t mx_block_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor) - { - // 1) Estimate L2 hit-rate - double H_mem1 = estimate_l2_hit( - hardware, M, N, K, batch, MT_M, MT_N, MT_K, element_size_A, WGM, splittingFactor); - - double H_mem1_global = compute_l2_hit_rate_global( - M, N, MT_M, MT_N, MT_K, element_size_A, hardware.L2_capacity * 1024); - - H_mem1 = std::min(H_mem1, H_mem1_global); - - if(H_mem1 == 0) - { - H_mem1 = 0.5; - } - - // 2) Estimate mall hit-rate - double H_mem2 = estimate_mall_hit(hardware, - M, - N, - K, - batch, - MT_M, - MT_N, - MT_K, - element_size_A, - WGM, - numActiveCUs, - splittingFactor); - - // 3) Total loads are loads from A and loads from B - size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, element_size_A); - size_t MT_N_rounded_128bytes = round_elements_to_128B(MT_N, element_size_A); - size_t MT_K_rounded_128bytes = round_elements_to_128B(MT_K, element_size_A); - if(!transA && !transB) - { - MT_N_rounded_128bytes = MT_N; - MT_K_rounded_128bytes = MT_K; - } - else if(transA && !transB) - { - MT_M_rounded_128bytes = MT_M; - MT_N_rounded_128bytes = MT_N; - } - else if(!transA && 
transB) - { - MT_K_rounded_128bytes = MT_K; + config.logger.log("Iteration Bound", bound_source + " (" + std::to_string(L_tile_single) + ")"); + config.logger.log("K % MT_K", K % MT_K); + } + + return L_tile_total; +} + +// Computes the latency per K-complete MT wave +// A wave is defined as : The time it takes for one CU to complete one K-complete output tile +double compute_timestep_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t num_active_cus, + size_t splitting_factor) { + // Assume latency of a wave is latency of a single k-complete output tile. + double L_wave = compute_tile_latency(problem, hardware, config, num_active_cus, splitting_factor); + + return L_wave; +} + +// Compute the total latency of a gemm based on the latency of one wave multiplied by the number of +// waves A wave is defined as : The time it takes for one CU to complete one K-complete output tile +double compute_total_latency(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t max_cus) { + assert(config.is_valid()); + + // Extract parameters from structured types + size_t M = problem.size.m; + size_t N = problem.size.n; + size_t K = problem.size.k; + size_t batch = problem.batch; + + bool a_trans = problem.a_transpose == transpose_t::T; + bool b_trans = problem.b_transpose == transpose_t::T; + + size_t MT_M = config.mt.m; + size_t MT_N = config.mt.n; + size_t MT_K = config.mt.k; + size_t MI_M = config.mi.m; + size_t MI_N = config.mi.n; + size_t MI_K = config.mi.k; + + const int a_bits = datatype_to_bits(problem.a_dtype); + const int b_bits = datatype_to_bits(problem.b_dtype); + const int a_bytes = data_type_to_bytes(problem.a_dtype); + const int d_bytes = data_type_to_bytes(problem.d_dtype); + + if (get_runtime_options(config).debug_enabled) { + config.logger.log( + "Problem_Size", + std::to_string(int(M)) + "x" + std::to_string(int(N)) + "x" + std::to_string(int(K))); + config.logger.log("Batch", 
std::to_string(int(batch))); + config.logger.log("Macro_Tile", + std::to_string(int(MT_M)) + "x" + std::to_string(int(MT_N)) + "x" + + std::to_string(int(MT_K))); + config.logger.log("Element Size A (bits)", a_bits); + config.logger.log("Element Size B (bits)", b_bits); + } + + // 0) Short-circuit + // We don't need to compute latency for all MTs. With this, we can shortcut. + bool shortCircuit = true; + if (shortCircuit) { + // When problem dimensions are small enough that we can fit them in one tile, we should do + // so. This short circuit condition also decreases selection latency when problems are very + // small :) + // TODO 256 and 256 here should be largest M and N tile dimensions in library + if (M <= 256 && N <= 256 && K < 1024 && batch != 1 && (MT_M < M || MT_N < N)) + return std::numeric_limits::max(); + + // Use Dot2 only for M < 3 + if (MI_M == 1 && MI_N == 1 && MI_K == 64 && M > 2) return std::numeric_limits::max(); + + size_t K_mod_128bytes = K * a_bytes % 128; + size_t MT_K_mod_128bytes = MT_K * a_bytes % 128; + if (K_mod_128bytes == 0 && MT_K_mod_128bytes == 0) { + // avoid division by 0 if K == 0 + if (M <= MT_M * 2 && !b_trans && ((N * b_bits) / (M * a_bits) > 5)) { + // Use nontemporal B + if (!(config.cache_hints_b == 4)) { return std::numeric_limits::max(); } + } else if (N <= MT_N * 2 && a_trans && ((M * a_bits) / (N * b_bits) > 5)) { + // Use Non Temporal A + if (!(config.cache_hints_a == 4)) { return std::numeric_limits::max(); } + } else { + // Never use Non Temporal + if (config.cache_hints_a || config.cache_hints_b) { + return std::numeric_limits::max(); } - - size_t Ld_A_value = compute_A_loads(MT_M_rounded_128bytes, MT_K_rounded_128bytes); - size_t Ld_B_value = compute_B_loads(MT_N_rounded_128bytes, MT_K_rounded_128bytes); - size_t Ld_CU_bytes = (Ld_A_value * safe_ceil_div(element_size_A, 8)) // A Bytes - + (Ld_B_value * safe_ceil_div(element_size_B, 8)); // B Bytes - - // Logic for block scaled datatypes (Assuming BS=32 and 8-bit 
scales) - // TODO This is technically wrong, need separate flag to enable MX so we can differentiate FP8 - // and MX8 - if(element_size_A < 8 && mx_block_size != 0) - { - // Number of scales per tile - size_t num_scales_A = safe_ceil_div(MT_M * MT_K, mx_block_size); - Ld_CU_bytes += num_scales_A; // One Byte per scale - } - if(element_size_B < 8 && mx_block_size != 0) - { - // Number of scales per tile - size_t num_scales_B = safe_ceil_div(MT_N * MT_K, mx_block_size); - Ld_CU_bytes += num_scales_B; // One Byte per scale - } - - // 4) total loads by all CUs - double total_Ld = Ld_CU_bytes * static_cast(numActiveCUs); - - // 5) mem1‐limited factor (simple linear model) - double mem1_bw_limited - = static_cast(numActiveCUs) / static_cast(hardware.N_CU); - double limited_mem1_bw = (hardware.mem1_perf_ratio * mem1_bw_limited); - - // 6) mem1 latency - double L_mem_mem1 = (limited_mem1_bw > 0) ? (total_Ld / (limited_mem1_bw)) : 0.0; - - // 7) mem2‐limited from occupancy (Can't Issue enough load/stores) - double bw_limited = compute_mem_bw_from_occupancy(hardware, numActiveCUs); - - // 8) loads that reach each level - double Ld_mem2 = (1.0 - H_mem1) * total_Ld; - double Ld_MEM = (1.0 - H_mem2) * Ld_mem2; - - // 9) enforce whole‐problem minimum loads when we can fit M/N in the CUs. - // Calculate the tile of workgroups that can run concurrently (logic from estimate_mall_hit). - size_t grid_m = safe_ceil_div(M, MT_M); - size_t grid_n = safe_ceil_div(N, MT_N); - size_t mall_m = safe_ceil_div(numActiveCUs, static_cast(WGM)); - size_t mall_n = std::min(static_cast(WGM), grid_n); - // Handle wrap-around case - if(mall_m > grid_m) - { - size_t num_wraps = (mall_m / grid_m); - mall_n += (num_wraps * WGM); - mall_m = grid_m; - } - // Clamp tile dimensions - mall_m = std::max(std::min(grid_m, mall_m), static_cast(1)); - mall_n = std::max(std::min(grid_n, mall_n), static_cast(1)); - // This is the minimum unique bytes needed from HBM to feed the concurrent workgroups. 
- double min_load - = static_cast((mall_m * MT_M * MT_K * safe_ceil_div(element_size_A, 8)) - + (mall_n * MT_N * MT_K * safe_ceil_div(element_size_B, 8))) - * batch; // Apply batching to the minimum load itself. - // The actual loads cannot be less than this physical minimum. - Ld_MEM = std::max(Ld_MEM, min_load); - Ld_mem2 = std::max(Ld_mem2, min_load); - - // 10) mem2 latency - double limited_mem2_bw = (hardware.mem2_perf_ratio * bw_limited); - double L_mem_mem2 = (limited_mem2_bw > 0) ? (Ld_mem2 / limited_mem2_bw) : 0.0; - - // 11) MEM latency - double limited_mem_bw = (hardware.mem3_perf_ratio * bw_limited); - double L_mem_MEM = (limited_mem_bw > 0) ? (Ld_MEM / limited_mem_bw) : 0.0; - L_mem_MEM += 200; // Load Latency - - // 12) pick the worst‐case bound - double L_mem = std::max({L_mem_mem1, L_mem_mem2, L_mem_MEM}); - - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("mem1_perf_ratio", hardware.mem1_perf_ratio); - hardware.log_debug("mem2_perf_ratio", hardware.mem2_perf_ratio); - hardware.log_debug("mem3_perf_ratio", hardware.mem3_perf_ratio); - hardware.log_debug("mem_bw_per_wg_coefficients(0)", - std::get<0>(hardware.mem_bw_per_wg_coefficients)); - hardware.log_debug("mem_bw_per_wg_coefficients(1)", - std::get<1>(hardware.mem_bw_per_wg_coefficients)); - hardware.log_debug("mem_bw_per_wg_coefficients(2)", - std::get<2>(hardware.mem_bw_per_wg_coefficients)); - hardware.log_debug("H_mem1 (mem1 hit ratio)", H_mem1); - hardware.log_debug("H_mem2 (mem2 hit ratio)", H_mem2); - hardware.log_debug("Total Load (bytes)", total_Ld); - hardware.log_debug("Ld_mem2 (bytes)", Ld_mem2); - hardware.log_debug("Ld_MEM (bytes)", Ld_MEM); - hardware.log_debug("L_mem_mem1 (cycles)", L_mem_mem1); - hardware.log_debug("L_mem_mem2 (cycles)", L_mem_mem2); - hardware.log_debug("L_mem_MEM (cycles)", L_mem_MEM); - hardware.log_debug("MT_K % 128 bytes", MT_K * safe_ceil_div(element_size_B, 8) % 128); - hardware.log_debug("MT_M % 128 bytes", MT_M * 
safe_ceil_div(element_size_A, 8) % 128); - hardware.log_debug("MT_N % 128 bytes", MT_N * safe_ceil_div(element_size_B, 8) % 128); - hardware.log_debug("MT_N % 128 + MT_M % 128 bytes", - (MT_M * safe_ceil_div(element_size_A, 8) % 128) - + MT_N * safe_ceil_div(element_size_B, 8) % 128); - hardware.log_debug("MT_N % 64 + MT_M % 64 bytes", - (MT_M * safe_ceil_div(element_size_A, 8) % 64) - + MT_N * safe_ceil_div(element_size_B, 8) % 64); - hardware.log_debug("MT_K % 64 bytes", MT_K * safe_ceil_div(element_size_B, 8) % 64); - hardware.log_debug("MT_M % 64 bytes", MT_M * safe_ceil_div(element_size_A, 8) % 64); - hardware.log_debug("MT_N % 64 bytes", MT_N * safe_ceil_div(element_size_B, 8) % 64); - hardware.log_debug("Tile Arithmetic Intensity", - MT_M * MT_N * MT_K / (MT_M * MT_K + MT_N * MT_K)); - } - - return L_mem; + } + } else if (config.cache_hints_a || config.cache_hints_b) { + return std::numeric_limits::max(); } + } - /* ---------------------------------------------------------------------------------------- */ - /* Tile-related functions */ - /* ---------------------------------------------------------------------------------------- */ - double compute_tile_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int occupancy, - size_t numActiveCUs, - size_t splittingFactor) - { - // 1) Compute per-tile latencies - double L_compute = compute_mt_compute_latency(hardware, - M, - N, - K, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - mi_datatype); - - double L_mem = compute_memory_latency(hardware, - M, - N, - K, - transA, - transB, - batch, - MT_M, - MT_N, - MT_K, - element_size_A, - element_size_B, - mx_block_size, - 
WGM, - numActiveCUs, - splittingFactor); - - // TODO Does work utilization need to be 128-byte rounded for a cache line? - double utilization = calculate_work_utilization(M, N, K, MT_M, MT_N, MT_K); - double output_utilization = calculate_output_utilization(M, N, MT_M, MT_N, 1); - // The effective latency per useful operation increases as utilization drops. - // This penalty affects BOTH compute and memory bounds for the tile's core work. - double effective_tile_penalty = (utilization > 1e-9) ? (1.0 / (utilization)) : 1.0; - double output_utilization_penalty - = (output_utilization > 1e-9) ? (1.0 / (output_utilization)) : 1.0; - // 2) Work-group setup & iteration latencies - double L_WG_setup = 1; // WG_setup_Latency - - // 3) Prologue: 2.2× memory latency - double L_prologue = 1.5 * L_mem; // 1.5 chosen emprically - - // L_compute *= std::max(L_compute, L_LDS); - - // 4) Epilogue: writes from all active CUs with limited bandwidth - double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, numActiveCUs); - double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; - size_t MT_M_rounded_128bytes = round_elements_to_128B(MT_M, element_size_A); - - double L_epilogue = (static_cast(numActiveCUs / splittingFactor) - * MT_M_rounded_128bytes * MT_N * safe_ceil_div(element_size_out, 8)) - / mem_bw_occ_limited; - // One compute iteration happens in the prologue - L_epilogue += L_compute * effective_tile_penalty; - // Epilogue and Prologue overhead are reduced with higher occupancy kernels. - int grid_m = static_cast(safe_ceil_div(M, MT_M)); - int grid_n = static_cast(safe_ceil_div(N, MT_N)); - int real_occupancy = std::min(occupancy, - static_cast(safe_ceil_div(grid_m * grid_n * batch * splittingFactor, - hardware.N_CU))); // Number of WGs per CU. 
- L_prologue = L_prologue * pow(0.95, real_occupancy); // Factor chosen empirically - L_epilogue = L_epilogue * pow(0.95, real_occupancy); // Factor chosen empirically - // 4') K-split reductions are globally coherent, we need to write and read split-1 MT_M*MT_N - // tiles to coherent memory - if(splittingFactor > 1) - { - size_t n_partials = splittingFactor - 1; - - // Only the reduction CU reads from all splits. - double partial_read_bytes = grid_m * grid_n * n_partials * MT_M_rounded_128bytes * MT_N - * safe_ceil_div(element_size_out, 8); - - // All CUs write (once for each partial, and once by the reduction CU for the output.) - double partial_write_bytes = grid_m * grid_n * MT_M_rounded_128bytes * MT_N - * safe_ceil_div(element_size_out, 8); - - double partial_readwrite_bytes = partial_read_bytes + partial_write_bytes; - - // 64 Threads active in a SIMD. Exposed to at least latency of reducing splittingfactor - // tiles. - double partial_adds = ((MT_M * MT_N) * splittingFactor) / (64); - // Things have to be written to memory - double mem_bw_occ = compute_mem_bw_from_occupancy(hardware, numActiveCUs); - double mem_bw_occ_limited = hardware.mem3_perf_ratio * mem_bw_occ; - - double L_reduce = partial_readwrite_bytes / (mem_bw_occ_limited); - L_epilogue += L_reduce + partial_adds + 10000; - } - // 4'') tf32 emu has some more overhead - double L_cvt = 0; - if((mi_datatype == data_type_t::XFloat32) - && (hardware.arch == hardware_t::architecture_t::gfx950)) - { - L_cvt = compute_cvt_overhead( - hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, element_size_A, element_size_B); - } - else if((element_size_A == 32) && (element_size_B == 32) - && (mi_datatype == data_type_t::BFloat16) - && (hardware.arch == hardware_t::architecture_t::gfx950)) // SS_BSS on GFX950 - { - L_cvt = compute_cvt_overhead_x1(hardware, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - mi_datatype); - } + // 1-1) To compute the latency, use default WGM. 
And WGM can't be greater than one + int defaultWGM = static_cast(ceil(std::sqrt(hardware.N_CU / hardware.NUM_XCD))); + auto config_with_default_wgm = config; + config_with_default_wgm.workgroup_mapping = std::max(defaultWGM, 1); - // 5) Single-tile latency (always additive) - // Calculate the fraction of the work that is useful (not padding). - - // 5) Single-tile latency (apply penalty after finding the bottleneck) - double L_tile_single = (std::max(L_compute, L_mem) * effective_tile_penalty) + L_cvt; - L_prologue *= effective_tile_penalty; - // 6) Number of K-iterations (excluding epilogue), at least 1 - // long num_iter = static_cast(((K + MT_K - 1) / MT_K)) - 1; - // num_iter = std::ceil(num_iter / splittingFactor); - // num_iter = std::max(num_iter, 1L); - long splittedK = static_cast(safe_ceil_div(K, splittingFactor)); - long num_iter - = std::max(static_cast(safe_ceil_div(splittedK, MT_K) - 1), static_cast(1)); - // Zero Padding in the K dimension on last iteration - if(K % MT_K != 0) - { - double problem_k_quant = ((K % MT_K) / (double)K); - L_epilogue - += problem_k_quant - * 50000; // Scale by remainder proportion of problem. 50k cycle penalty if have to zero pad all except 1. 
- //(Scale Determined Empirically) - } - //L_epilogue *= output_utilization_penalty; + // 1-2) Find CU occupancy + auto [num_wgs, num_active_cus, numWaves, splitting_factor] = compute_cu_occupancy( + problem, hardware, config_with_default_wgm, grid_selection_t::k_split_aware, max_cus); - // 7) Total tile latency - double L_tile_total - = (L_tile_single * num_iter) + L_prologue + L_epilogue * 2 + L_WG_setup - + (500 * num_iter); // 7 instructions (each with 4 cycles) at the end of the loop + // 2) Compute latency of a wave + // Compute latency of a wave + double L_wave = compute_timestep_latency( + problem, hardware, config_with_default_wgm, num_active_cus, splitting_factor); - if(MT_K == 1024) - { - L_prologue = L_prologue * 100; - } + // Compute latency for all waves and return it as the latency for the MT/problem + double total_latency = L_wave * numWaves; - if(hardware_t::is_debug_enabled()) - { - double problem_k_quant = ((K % MT_K) / (double)K); - hardware.log_debug("Iteration Compute Latency", L_compute); - hardware.log_debug("L_mem", L_mem); - hardware.log_debug("L_cvt", L_cvt); - hardware.log_debug("L_tile_single", L_tile_single); - hardware.log_debug("num_iter", num_iter); - hardware.log_debug("L_prologue", L_prologue); - hardware.log_debug("L_epilogue", L_epilogue); - hardware.log_debug("L_tile_total", L_tile_total); - hardware.log_debug("Effective Tile Peanlty", effective_tile_penalty); - hardware.log_debug("Problem K quant", problem_k_quant); - hardware.log_debug("K quant overhead", (problem_k_quant * 50000)); - hardware.log_debug("Problem Tiile Quant", utilization); - hardware.log_debug("Real Occupancy", utilization); - hardware.log_debug("Output Utilization Penalty", output_utilization_penalty); - hardware.log_debug("Output Utilization", output_utilization); - std::string bound_source; - if(L_compute >= L_mem) - { - L_tile_single = L_compute + L_cvt; - bound_source = "Compute"; - } - else - { - L_tile_single = L_mem + L_cvt; - bound_source = "Memory"; 
- } - hardware.log_debug("Iteration Bound", - bound_source + " (" + std::to_string(L_tile_single) + ")"); - - hardware.log_debug("K % MT_K", K % MT_K); - } + // 3) Customized heuristics + // TODO These are quantifying effects that don't work in the current math. + // TODO THESE SHOULD BE TEMPORARY FIXES AND BE MORE SOLIDLY INTEGRATED LATER + bool heuristics = get_runtime_options(config).heuristics_enabled; - return L_tile_total; + if (heuristics) { + if (MT_M == 64 && MT_N == 32 && MT_K == 32 && !b_trans && a_bits == 16) { + total_latency = total_latency * 10; } - // Computes the latency per K-complete MT wave - // A wave is defined as : The time it takes for one CU to complete one K-complete output tile - double compute_wave_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int occupancy, - size_t numActiveCUs, - size_t splittingFactor) - { - // Assume latency of a wave is latency of a single k-complete output tile. 
- double L_wave = compute_tile_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - mx_block_size, - WGM, - occupancy, - numActiveCUs, - splittingFactor); - - return L_wave; - } + bool tf32_emu = ((problem.mi_dtype == data_type_t::XFloat32) && + (hardware.arch == hardware_t::architecture_t::gfx950)); - // Compute the total latency of a gemm based on the latency of one wave multiplied by the number of - // waves A wave is defined as : The time it takes for one CU to complete one K-complete output tile - double compute_total_latency(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t mx_block_size, - int WGM, - int non_temporal_a, - int non_temporal_b, - int occupancy, - size_t split, - size_t max_cus) - { - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("Problem_Size", - std::to_string(int(M)) + "x" + std::to_string(int(N)) + "x" - + std::to_string(int(K))); - hardware.log_debug("Macro_Tile", - std::to_string(int(MT_M)) + "x" + std::to_string(int(MT_N)) + "x" - + std::to_string(int(MT_K))); - hardware.log_debug("Element Size A (bits)", element_size_A); - hardware.log_debug("Element Size B (bits)", element_size_B); - } + // Heuristics for TF32 + if (tf32_emu) { + double bytes_per_element = static_cast(a_bytes); + double arith = emulated_tf32_arithmetic_intensity(M, N, K, bytes_per_element); + double compute_threshold = 1000; // threshold empirically determined. - // 0) Short-circuit - // We don't need to compute latency for all MTs. With this, we can shortcut. 
- bool shortCircuit = true; - if(shortCircuit) - { - // When problem dimensions are small enough that we can fit them in one tile, we should do - // so. This short circuit condition also decreases selection latency when problems are very - // small :) - // TODO 256 and 256 here should be largest M and N tile dimensions in library - if(M <= 256 && N <= 256 && K < 1024 && batch != 1 && (MT_M < M || MT_N < N)) - return std::numeric_limits::max(); - - // Override dot2 instruction with vector lane widths - if(MI_N == 0 && MI_M == 0 && MI_K == 0) - { - // We only use Dot2 for NN layout where M < 3 - if(M > 2 || transA || transB) - return std::numeric_limits::max(); - MI_M = 1; - MI_N = 1; - MI_K = 64; - } - - - - - if(batch == 1) - { - - - - size_t K_mod_128bytes = K * safe_ceil_div(element_size_A, 8) % 128; - size_t MT_K_mod_128bytes = MT_K * safe_ceil_div(element_size_A, 8) % 128; - if(K_mod_128bytes == 0 && MT_K_mod_128bytes == 0 && batch == 1) - { - // avoid division by 0 if K == 0 - if(M <= MT_M *2 && !transB && ((N * element_size_B)/(M * element_size_A) > 5)) - { - //Use nontemporal B - if(!(non_temporal_b == 4)) - { - return std::numeric_limits::max(); - } - } - else if(N <= MT_N *2 && transA && ((M * element_size_A)/(N * element_size_B) > 5)) - { - //Use Non Temporal A - if(!(non_temporal_a == 4)) - { - return std::numeric_limits::max(); - } - } - else - { - //Never use Non Temporal - if(non_temporal_a || non_temporal_b) - { - return std::numeric_limits::max(); - } - } - } - else if(non_temporal_a || non_temporal_b) - { - return std::numeric_limits::max(); - } - - } - - - - } - - // 1-1) WGM - WGM = std::max(WGM, 1); // WGM can't be less than one. - occupancy = std::max(occupancy, 1); // occupancy can't be less than one. 
- - // 1-2) Find CU occupancy - auto [numWGs, numActiveCUs, numWaves, splittingFactor] - = compute_CU_occupancy(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - WGM, - std::numeric_limits::max(), // workspace - std::numeric_limits::max(), // workspace per c - 0, // occupancy - 6, // dynamic_grid - 0, - max_cus); - - // 2) Compute latency of a wave - // Compute latency of a wave - double L_wave = compute_wave_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - mx_block_size, - WGM, - occupancy, - numActiveCUs, - splittingFactor); - - // Compute latency for all waves and return it as the latency for the MT/problem - double total_latency = L_wave * numWaves; - - if(MT_M == 64 && MT_N == 32 && MT_K == 32 && !transB && element_size_A == 16) - { - total_latency = total_latency * 10; - } - - // 3) Customized heuristics - // TODO These are quantifying effects that don't work in the current math. - // TODO THESE SHOULD BE TEMPORARY FIXES AND BE MORE SOLIDLY INTEGRATED LATER - bool heuristics = hardware_t::is_heuristics_enabled(); - - const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); - heuristics = !(env && std::string(env) == "0"); - - - - // heuristics = 0; - // Heuristics for TF32 - bool tf32_emu = ((mi_datatype == data_type_t::XFloat32) - && (hardware.arch == hardware_t::architecture_t::gfx950)); - if(tf32_emu && heuristics) - { - double bytes_per_element = static_cast(element_size_A) / 8.0; - double arith = emulated_tf32_arithmetic_intensity(M, N, K, bytes_per_element); - double compute_threshold = 1000; // threshold empirically determined. 
- - // The kernel for this is more optimized (Custom kernel NT) - if((!transA && transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) - { - if(arith < compute_threshold) - total_latency = total_latency * 0.6; - else - total_latency = total_latency * 0.4; - } - - // The kernel for this is more optimized (Custom kernel NN) - if((!transA && !transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) - { - if(arith < compute_threshold) - total_latency = total_latency * 0.8; - else - total_latency = total_latency * 0.4; - } - - // The kernel for this is more optimized (Custom kernel TN) - if((transA && !transB) && MT_M == 256 && MT_N == 256 && MT_K == 32) - { - if(arith < compute_threshold) - total_latency = total_latency * 0.8; - else - total_latency = total_latency * 0.4; - } - - // Bias large DU where K-dimension is large and M and N are small. - if((K >= (M * 16) && K >= (N * 16)) && (MT_K >= 128)) - { - total_latency = total_latency * 0.5; - } - } + // The kernel for this is more optimized (Custom kernel NT) + if ((!a_trans && b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { + if (arith < compute_threshold) + total_latency = total_latency * 0.6; + else + total_latency = total_latency * 0.4; + } - if(hardware_t::is_debug_enabled()) - { - hardware.log_debug("Total_latency (with heuristics)", total_latency); - hardware.log_debug("non_temporal_a", non_temporal_a); - hardware.log_debug("non_temporal_b", non_temporal_b); - hardware.log_debug("kernel_occupancy", occupancy); - hardware.log_debug("splitting_factor", splittingFactor); - hardware.log_debug("Input Tile Size A", MT_M * MT_K); - hardware.log_debug("Input Tile Size B", MT_N * MT_K); - hardware.log_debug("Output Tile Size", MT_M * MT_N); - hardware.log_debug("Tile M/N", MT_M / MT_N); - hardware.log_debug("Tile N/M", MT_N / MT_M); - hardware.log_debug("Problem M/N", MT_M / MT_N); - hardware.log_debug("Problem N/M", MT_N / MT_M); - size_t occupancy_percent = numActiveCUs / hardware.N_CU; - hardware.log_debug("Peak 
theoretical GFLOPs based on occupancy", - 1300 * occupancy_percent); - if(hardware_t::is_debug_enabled()) - { - hardware.print_debug_info(); - } - } + // The kernel for this is more optimized (Custom kernel NN) + if ((!a_trans && !b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { + if (arith < compute_threshold) + total_latency = total_latency * 0.8; + else + total_latency = total_latency * 0.4; + } - return total_latency; - } + // The kernel for this is more optimized (Custom kernel TN) + if ((a_trans && !b_trans) && MT_M == 256 && MT_N == 256 && MT_K == 32) { + if (arith < compute_threshold) + total_latency = total_latency * 0.8; + else + total_latency = total_latency * 0.4; + } - // Compute the performance from the latency. - // IMPORTANT : This program is NOT meant to be an analytical model for performance, but rather a way - // to rank different macro tile sizes. These performance values could be wildly inaccurate in - // absolute terms, but will often result in the correct ranking of MTin relative terms. 
- double compute_perf_gflops(const hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - int WGM, - size_t max_cus) - { - // Compute total FLOPs - double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs - // Compute total time in seconds - double cycles_per_second - = hardware.compute_clock_ghz * 1e9; // 1 GHz = 1e9 cycles per second - size_t mx_block_size = 0; - double latency_cycles = compute_total_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - mx_block_size, - WGM, - 0, - 0, - 1, - 0, - max_cus); - double total_time_seconds = latency_cycles / cycles_per_second; - // Compute performance in FLOPS - double FLOPS = total_FLOPs / total_time_seconds; - // Convert to TFLOPS - double GFLOPS = FLOPS / 1e9; // 1 TFLOP = 1e9 FLOPs - return GFLOPS; + // Bias large DU where K-dimension is large and M and N are small. 
+ if ((K >= (M * 16) && K >= (N * 16)) && (MT_K >= 128)) { + total_latency = total_latency * 0.5; + } } -} // namespace origami + } + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("Total_latency (with heuristics)", total_latency); + config.logger.log("non_temporal_a", config.cache_hints_a); + config.logger.log("non_temporal_b", config.cache_hints_b); + config.logger.log("kernel_occupancy", config.occupancy); + config.logger.log("splitting_factor", splitting_factor); + config.logger.log("Input Tile Size A", MT_M * MT_K); + config.logger.log("Input Tile Size B", MT_N * MT_K); + config.logger.log("Output Tile Size", MT_M * MT_N); + config.logger.log("Tile M/N", MT_M / MT_N); + config.logger.log("Tile N/M", MT_N / MT_M); + config.logger.log("Problem M/N", M / N); + config.logger.log("Problem N/M", N / M); + size_t occupancy_percent = num_active_cus / hardware.N_CU; + config.logger.log("Peak theoretical GFLOPs based on occupancy", 1300 * occupancy_percent); + if (get_runtime_options(config).debug_enabled) { config.logger.print(); } + } + + return total_latency; +} + +} // namespace origami diff --git a/shared/origami/src/origami/hardware.cpp b/shared/origami/src/origami/hardware.cpp new file mode 100644 index 000000000000..51485fb2bf1d --- /dev/null +++ b/shared/origami/src/origami/hardware.cpp @@ -0,0 +1,388 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "origami/hardware.hpp" +#include "origami/types.hpp" + +#include +#include +#include +#include + +namespace origami { + +// Static member definition +// clang-format off +const std::unordered_map> + hardware_t::INSTRUCTION_MAP = { + {hardware_t::architecture_t::gfx90a, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + + // TODO ComplexFloat + // TODO ComplexDouble + + // F16 + {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 64}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 + + // BF16 + {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 + {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 + + // I8 + {matrix_instruction(32, 32, 8, data_type_t::Int8), 64}, // v_mfma_f32_32x32x16_f8 + 
{matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_mfma_f32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + + // XF32 + {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 + {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 + {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + }}, + {hardware_t::architecture_t::gfx942, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 32}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + + // TODO ComplexFloat + // TODO ComplexDouble + + // F16 + {matrix_instruction(32, 32, 4, data_type_t::Half), 64}, // v_mfma_f32_32x32x4_2b_f16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(16, 16, 4, data_type_t::Half), 32}, // v_mfma_f32_16x16x4_4b_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(4, 4, 4, data_type_t::Half), 8}, // v_mfma_f32_4x4x4_16b_f16 + + // BF16 + {matrix_instruction(32, 32, 4, data_type_t::BFloat16), 64}, // v_mfma_f32_32x32x4_2b_bf16 + 
{matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(16, 16, 4, data_type_t::BFloat16), 32}, // v_mfma_f32_16x16x4_4b_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(4, 4, 4, data_type_t::BFloat16), 8}, // v_mfma_f32_4x4x4_16b_bf16 + + // F8 + {matrix_instruction(32, 32, 16, data_type_t::Float8_fnuz), 32}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(16, 16, 32, data_type_t::Float8_fnuz), 16}, // v_mfma_f32_16x16x32_f8 + + // BF8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8 + + // F8B8 + {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8_fnuz), 32}, // v_mfma_f32_32x32x16_f8_bf8 + {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8_fnuz), 16}, // v_mfma_f32_16x16x32_f8_bf8 + + // B8F8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8_fnuz), 32}, // v_mfma_f32_32x32x16_bf8_f8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8_fnuz), 16}, // v_mfma_f32_16x16x32_bf8_f8 + + // I8 + {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_f32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + + // XF32 + {matrix_instruction(32, 32, 4, data_type_t::XFloat32), 32}, // v_mfma_f32_32x32x4_xf32 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 16}, // v_mfma_f32_16x16x8_xf32 + }}, + {hardware_t::architecture_t::gfx950, + { + // F32 + {matrix_instruction(32, 32, 2, data_type_t::Float), 64}, // v_mfma_f32_32x32x2_f32 + {matrix_instruction(32, 32, 1, 
data_type_t::Float), 64}, // v_mfma_f32_32x32x1_2b_f32 + {matrix_instruction(16, 16, 4, data_type_t::Float), 32}, // v_mfma_f32_16x16x4_f32 + {matrix_instruction(16, 16, 1, data_type_t::Float), 32}, // v_mfma_f32_16x16x1_4b_f32 + {matrix_instruction(4, 4, 1, data_type_t::Float), 8}, // v_mfma_f32_4x4x1_16b_f32 + + // F64 + {matrix_instruction(16, 16, 4, data_type_t::Double), 64}, // v_mfma_f64_16x16x4_f64 + {matrix_instruction(4, 4, 4, data_type_t::Double), 16}, // v_mfma_f64_4x4x4_4b_f64 + + // TODO ComplexFloat + // TODO ComplexDouble + + // F16 + {matrix_instruction(32, 32, 8, data_type_t::Half), 32}, // v_mfma_f32_32x32x8_f16 + {matrix_instruction(32, 32, 16, data_type_t::Half), 32}, // v_mfma_f32_32x32x16_f16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_mfma_f32_16x16x16_f16 + {matrix_instruction(16, 16, 32, data_type_t::Half), 16}, // v_mfma_f32_16x16x32_f16 + + // BF16 + {matrix_instruction(32, 32, 8, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x8_bf16 + {matrix_instruction(32, 32, 16, data_type_t::BFloat16), 32}, // v_mfma_f32_32x32x16_bf16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + {matrix_instruction(16, 16, 32, data_type_t::BFloat16), 16}, // v_mfma_f32_16x16x16_bf16 + + // F8 + {matrix_instruction(32, 32, 64, data_type_t::Float8), 64}, // v_mfma_f32_32x32x64_f8 + {matrix_instruction(32, 32, 16, data_type_t::Float8), 32}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(16, 16, 128, data_type_t::Float8), 32}, // v_mfma_f32_16x16x128_f8 + {matrix_instruction(16, 16, 32, data_type_t::Float8), 16}, // v_mfma_f32_16x16x32_f8 + + // BF8 + {matrix_instruction(32, 32, 64, data_type_t::BFloat8), 64}, // v_mfma_f32_32x32x64_bf8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8), 32}, // v_mfma_f32_32x32x16_bf8 + {matrix_instruction(16, 16, 128, data_type_t::BFloat8), 32}, // v_mfma_f32_16x16x128_bf8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8), 16}, // v_mfma_f32_16x16x32_bf8 
+ + // F8B8 + {matrix_instruction(32, 32, 64, data_type_t::Float8BFloat8), 64}, // v_mfma_f32_32x32x64_f8_bf8 + {matrix_instruction(32, 32, 16, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_32x32x16_f8_bf8 + {matrix_instruction(16, 16, 128, data_type_t::Float8BFloat8), 32}, // v_mfma_f32_16x16x128_f8_bf8 + {matrix_instruction(16, 16, 32, data_type_t::Float8BFloat8), 16}, // v_mfma_f32_16x16x32_f8_bf8 + + // B8F8 + {matrix_instruction(32, 32, 64, data_type_t::BFloat8Float8), 64}, // v_mfma_f32_32x32x64_bf8_f8 + {matrix_instruction(32, 32, 16, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_32x32x16_bf8_f8 + {matrix_instruction(16, 16, 128, data_type_t::BFloat8Float8), 32}, // v_mfma_f32_16x16x128_bf8_f8 + {matrix_instruction(16, 16, 32, data_type_t::BFloat8Float8), 16}, // v_mfma_f32_16x16x32_bf8_f8 + + // I8 + {matrix_instruction(32, 32, 16, data_type_t::Int8), 32}, // v_mfma_f32_32x32x16_f8 + {matrix_instruction(32, 32, 4, data_type_t::Int8), 64}, // v_mfma_i32_32x32x4_2b_i8 + {matrix_instruction(16, 16, 32, data_type_t::Int8), 16}, // v_mfma_f32_16x16x32_i8 + {matrix_instruction(16, 16, 4, data_type_t::Int8), 32}, // v_mfma_i32_16x16x4_4b_i8 + {matrix_instruction(4, 4, 4, data_type_t::Int8), 8}, // v_mfma_i32_4x4x4_16b_i8 + + // XF32 + {matrix_instruction(32, 32, 8, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x8_bf16 * 3 + {matrix_instruction(32, 32, 16, data_type_t::XFloat32), 96}, // v_mfma_f32_32x32x16_bf16 * 3 + {matrix_instruction(16, 16, 16, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + {matrix_instruction(16, 16, 32, data_type_t::XFloat32), 48}, // v_mfma_f32_16x16x16_bf16 * 3 + + // F6 + {matrix_instruction(32, 32, 64, data_type_t::Float6), 32}, // v_mfma_f32_32x32x64_f6 + {matrix_instruction(16, 16, 128, data_type_t::Float6), 16}, // v_mfma_f32_16x16x128_f6 + + // BF6 + {matrix_instruction(32, 32, 64, data_type_t::BFloat6), 32}, // v_mfma_f32_32x32x64_bf6 + {matrix_instruction(16, 16, 128, data_type_t::BFloat6), 16}, // 
v_mfma_f32_16x16x128_bf6 + + // F4 + {matrix_instruction(32, 32, 64, data_type_t::Float4), 32}, // v_mfma_f32_32x32x64_f4 + {matrix_instruction(16, 16, 128, data_type_t::Float4), 16}, // v_mfma_f32_16x16x128_f4 + + // DOT2 + {matrix_instruction(1, 1, 64, data_type_t::Half), 16}, // V_DOT2_F32_F16 + {matrix_instruction(1, 1, 64, data_type_t::BFloat16), 16}, // V_DOT2_F32_BF16 + }}, + {hardware_t::architecture_t::gfx1201, + { + // F16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 16}, // v_wmma_f16_16x16x16_f16/v_wmma_f32_16x16x16_f16 + + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 16}, // v_wmma_bf16_16x16x16_bf16/v_wmma_f32_16x16x16_bf16 + + // F8 + {matrix_instruction(16, 16, 16, data_type_t::Float8), 8}, // v_wmma_f32_16x16x16_fp8_fp8 + + // F8B8 + {matrix_instruction(16, 16, 16, data_type_t::Float8BFloat8), 8}, // v_wmma_f32_16x16x16_fp8_bf8 + + // B8F8 + {matrix_instruction(16, 16, 16, data_type_t::BFloat8Float8), 8}, // v_wmma_f32_16x16x16_bf8_fp8 + + // B8 + {matrix_instruction(16, 16, 16, data_type_t::BFloat8), 8}, // v_wmma_f32_16x16x16_bf8_bf8 + + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 8}, // v_wmma_i32_16x16x16_iu8 + + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 8}, // v_wmma_i32_16x16x16_iu4 + {matrix_instruction(16, 16, 32, data_type_t::Int4), 8}, // v_wmma_i32_16x16x32_iu4 + }}, + {hardware_t::architecture_t::gfx1100, + { + // F16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 + + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 + + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 + + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 + }}, + {hardware_t::architecture_t::gfx1151, + { + // F16 + {matrix_instruction(16, 16, 16, data_type_t::Half), 32}, // 
v_wmma_f32_16x16x16_f16/v_wmma_f16_16x16x16_f16 + + // BF16 + {matrix_instruction(16, 16, 16, data_type_t::BFloat16), 32}, // v_wmma_f32_16x16x16_bf16/v_wmma_bf16_16x16x16_bf16 + + // I8 + {matrix_instruction(16, 16, 16, data_type_t::Int8), 32}, // v_wmma_i32_16x16x16_iu8 + + // I4 + {matrix_instruction(16, 16, 16, data_type_t::Int4), 16}, // v_wmma_i32_16x16x16_iu4 + }}}; +// clang-format on + +hardware_t::hardware_t(architecture_t arch, + size_t N_CU, + size_t lds_capacity, + size_t NUM_XCD, + double mem1_perf_ratio, + double mem2_perf_ratio, + double mem3_perf_ratio, + size_t L2_capacity, + double compute_clock_ghz, + size_t parallel_mi_cu, + std::tuple mem_bw_per_wg_coefficients) + : arch(arch) + , N_CU(N_CU) + , lds_capacity(lds_capacity) + , mem1_perf_ratio(mem1_perf_ratio) + , mem2_perf_ratio(mem2_perf_ratio) + , mem3_perf_ratio(mem3_perf_ratio) + , L2_capacity(L2_capacity) + , CU_per_L2(N_CU / NUM_XCD) + , compute_clock_ghz(compute_clock_ghz) + , parallel_mi_cu(parallel_mi_cu) + , mem_bw_per_wg_coefficients(mem_bw_per_wg_coefficients) + , NUM_XCD(NUM_XCD) {} + +hardware_t::hardware_t(hipDeviceProp_t properties) + : hardware_t(get_hardware_for_properties(properties)) {} + +hardware_t::hardware_t(const hardware_t& other) + : arch(other.arch) + , N_CU(other.N_CU) + , lds_capacity(other.lds_capacity) + , mem1_perf_ratio(other.mem1_perf_ratio) + , mem2_perf_ratio(other.mem2_perf_ratio) + , mem3_perf_ratio(other.mem3_perf_ratio) + , L2_capacity(other.L2_capacity) + , CU_per_L2(other.CU_per_L2) + , compute_clock_ghz(other.compute_clock_ghz) + , parallel_mi_cu(other.parallel_mi_cu) + , mem_bw_per_wg_coefficients(other.mem_bw_per_wg_coefficients) + , NUM_XCD(other.NUM_XCD) {} + +hardware_t hardware_t::get_hardware_for_properties(hipDeviceProp_t properties) { + auto arch_name = get_before_first_colon(properties.gcnArchName); + auto arch_enum = arch_name_to_enum(arch_name); + if (arch_enum == architecture_t::Count) { + throw std::runtime_error( + 
std::string("Attempting to retrieve hardware constants for unsupported architecture: ") + + std::string(arch_name)); + } + auto constants = get_arch_constants(arch_enum); + return hardware_t( + arch_enum, + properties.multiProcessorCount, + properties.sharedMemPerBlock, + constants.num_xcds, + 1e9 * constants.mem1_perf_ratio / properties.clockRate, + 1e9 * constants.mem2_perf_ratio / (properties.memoryClockRate * constants.mem_clock_ratio), + 1e9 * constants.mem3_perf_ratio / properties.memoryClockRate, + properties.l2CacheSize, + properties.clockRate / 1e6, + constants.parallel_mi_cu, + constants.mem_bw_per_wg_coefficients); +} + +hardware_t hardware_t::get_hardware_for_device(int deviceId) { + hipDeviceProp_t prop; + hipError_t e = hipGetDeviceProperties(&prop, deviceId); + if (e) { throw std::runtime_error(hipGetErrorString(e)); } + return get_hardware_for_properties(prop); +} + +bool hardware_t::is_hardware_supported(hipDeviceProp_t properties) { + auto arch_name = get_before_first_colon(properties.gcnArchName); + auto arch_enum = arch_name_to_enum(arch_name); + return arch_enum != architecture_t::Count; +} + +void hardware_t::print() const { + std::cout << "================== Hardware Configuration ==================\n"; + std::cout << "Number of CUs (N_CU) : " << N_CU << "\n"; + std::cout << "LDS capacity : " << lds_capacity << " bytes\n"; + std::cout << "mem1_perf_ratio : " << mem1_perf_ratio << "\n"; + std::cout << "mem2_perf_ratio : " << mem2_perf_ratio << "\n"; + std::cout << "mem3_perf_ratio : " << mem3_perf_ratio << "\n"; + std::cout << "L2 Cache capacity : " << L2_capacity << " bytes\n"; + std::cout << "CUs per L2 domain : " << CU_per_L2 << "\n"; + std::cout << "Compute clock (GHz) : " << compute_clock_ghz << "\n"; + std::cout << "Parallel MI/CU : " << parallel_mi_cu << "\n"; + std::cout << "Number of XCDs (NUM_XCD) : " << NUM_XCD << "\n"; + std::cout << "mem_bw_per_wg_coefficients: " << std::get<0>(mem_bw_per_wg_coefficients) << ", " + << 
std::get<1>(mem_bw_per_wg_coefficients) << ", " + << std::get<2>(mem_bw_per_wg_coefficients) << "\n\n"; + + std::cout << "------------------ Instruction Map -------------------------\n"; + // Loop over the instruction_map and print each entry + for (const auto& kv : INSTRUCTION_MAP.at(arch)) { + const auto& key = kv.first; + const auto& L_MI = kv.second; + + std::cout << "Instruction: MI_M=" << key.MI_M << ", MI_N=" << key.MI_N << ", MI_K=" << key.MI_K + << ", mi_input_type=" << datatype_to_string(key.mi_input_type) << " bytes\n" + << " -> Latency (L_MI): " << L_MI << "\n"; + } + std::cout << "===========================================================\n"; +} + +size_t hardware_t::get_mi_latency(size_t MI_M, + size_t MI_N, + size_t MI_K, + data_type_t mi_input_type) const { + const auto& instruction_map = INSTRUCTION_MAP.at(arch); + auto key = matrix_instruction(MI_M, MI_N, MI_K, mi_input_type); + + auto it = instruction_map.find(key); + if (it != instruction_map.end()) { + return it->second / parallel_mi_cu; + } else { + if (origami::runtime_options().get().debug_enabled) + std::cerr << "Warning: Latency not found for MI_M=" << MI_M << ", MI_N=" << MI_N + << ", MI_K=" << MI_K << ", mi_input_type=" << datatype_to_string(mi_input_type) + << ". Returning latency value of 32 (really slow).\n"; + return 32 / parallel_mi_cu; // Default latency if instruction is not found + } +} + +std::string hardware_t::get_before_first_colon(const std::string& input) { + size_t pos = input.find(':'); + if (pos != std::string::npos) { return input.substr(0, pos); } + return input; // Return the whole string if ':' is not found +} + +} // namespace origami diff --git a/shared/origami/src/origami/log.cpp b/shared/origami/src/origami/log.cpp new file mode 100644 index 000000000000..b6e595f766e2 --- /dev/null +++ b/shared/origami/src/origami/log.cpp @@ -0,0 +1,62 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "origami/log.hpp" + +#include +#include +#include +#include + +namespace origami { + +void logger_t::print() const { + if (!metrics_ || metrics_->empty()) { + std::cout << "{}\n"; + return; + } + std::cout << "{\n"; + bool first = true; + for (const auto& [key, val] : *metrics_) { + if (!first) std::cout << ",\n"; + std::cout << " \"" << key << "\": " << val; + first = false; + } + std::cout << "\n}\n"; +} + +void logger_t::export_json(const std::string& filename) const { + if (!metrics_ || metrics_->empty()) { + std::ofstream file(filename); + if (!file.is_open()) { + std::cerr << "Error: Could not open file " << filename << " for writing\n"; + return; + } + file << "{}\n"; + file.close(); + std::cout << "Analytical metrics exported to JSON: " << filename << "\n"; + return; + } + + std::ofstream file(filename); + if (!file.is_open()) { + std::cerr << "Error: Could not open file " << filename << " for writing\n"; + return; + } + + file << std::fixed << std::setprecision(6); + file << "{\n"; + bool first = true; + for (const auto& [key, val] : *metrics_) { + if (!first) file << ",\n"; + file << " \"" << key << "\": " << val; + first = false; + } + file << "\n}\n"; + + file.close(); + std::cout << "Analytical metrics exported to JSON: " << filename << "\n"; +} + +} // namespace origami + diff --git a/shared/origami/src/origami/origami.cpp b/shared/origami/src/origami/origami.cpp new file mode 100644 index 000000000000..ada2b9bb3d30 --- /dev/null +++ b/shared/origami/src/origami/origami.cpp @@ -0,0 +1,446 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include + +#include "origami/gemm.hpp" +#include "origami/math.hpp" +#include "origami/origami.hpp" +#include "origami/streamk.hpp" +#include "origami/types.hpp" + +namespace origami { + +std::vector select_topk_configs(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs, + std::size_t topk) { + // Use rank_configs to get configurations with latencies ranked by performance + auto ranked_configs = rank_configs(problem, hardware, configs); + + // Return only the top K configurations + std::vector topk_configs; + size_t count = std::min(topk, ranked_configs.size()); + topk_configs.reserve(count); + for (size_t i = 0; i < count; ++i) { topk_configs.push_back(ranked_configs[i]); } + return topk_configs; +} + +/** + * @brief Selects the best WGM (maximizing L2 hit rate) given fixed macro tile sizes. + * + * @param[in] problem Problem description (M, N, K, etc.) + * @param[in] hardware Hardware characteristics + * @param config Kernel configuration. + * + * @return A tuple: best predicted (wgmxcc, wgm). 
+ */ + +std::tuple select_workgroup_mapping(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t skGrid) { + // Extract parameters from structured types + size_t M = problem.size.m; + size_t N = problem.size.n; + size_t K = problem.size.k; + size_t batch = problem.batch; + + size_t MT_M = config.mt.m; + size_t MT_N = config.mt.n; + size_t MT_K = config.mt.k; + + int nta = config.cache_hints_a; + int ntb = config.cache_hints_b; + + // Default is the closest we can get to a square + size_t max_CU_XCD = hardware.N_CU / hardware.NUM_XCD; + int defaultWGM = static_cast(ceil(std::sqrt(max_CU_XCD))); + + // Number of output MTs per split and batch + size_t numMT_M = math::safe_ceil_div(M, MT_M); + size_t numMT_N = math::safe_ceil_div(N, MT_N); + size_t numMTs = numMT_M * numMT_N; + + // What SK does -- we already have skGrid so just compute numWaves and splitFactor + auto numWaves = skGrid > numMTs ? math::safe_ceil_div(skGrid, hardware.N_CU) + : math::safe_ceil_div(numMTs, hardware.N_CU); + auto splitFactor = math::safe_ceil_div(skGrid, numMTs); + + // ------------------- + // NonTemporal Cases + // ------------------- + if (nta > 3 && ntb < 4) + return std::make_tuple(hardware.NUM_XCD, 1); + else if (nta < 4 && ntb > 3) + return std::make_tuple(hardware.NUM_XCD, + std::min(max_CU_XCD, math::safe_ceil_div(numMTs, hardware.NUM_XCD))); + else if (nta > 3 && ntb > 3) + return std::make_tuple(hardware.NUM_XCD, 1); + + // ------------------- + // WGMXCC Prediction + // ------------------- + // Default WGMXCC -- always number of XCD + int defaultWGMXCC = static_cast(hardware.NUM_XCD); + bool isWGMXCCset = false; + int out_wgmxcc = static_cast(defaultWGMXCC); + + // Batched GEMMs + if (batch > 1 && !isWGMXCCset) { + // Total tiles including batch count + size_t numTotalTiles = numMTs * batch; + + // if only one MT per each GEMM -> no mapping + // if less than hardware.NUM_XCD total tiles -> no mapping + if (numMTs == 1 || numTotalTiles <= 
hardware.NUM_XCD) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + // else use the default (num_xcd) + } + + // If we are lucky that the splitFactor is a multiple of NUM_XCD -> no mapping + if ((splitFactor % hardware.NUM_XCD == 0) && !isWGMXCCset) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + + // Small GEMMs + if ((numMTs <= hardware.NUM_XCD) && !isWGMXCCset) { + out_wgmxcc = 1; + isWGMXCCset = true; + } + + // For sizes that we have more than 2 waves of computations, we skip xcc mapping as MALL is + // more important -- matrix should not be skinny + // To avoid regressions, it's set to default, but it should actually be 1! + bool MallIsImportant = + (splitFactor == 1 && batch == 1 && numMTs > 2 * hardware.N_CU && numMT_M > 8 && numMT_N > 8); + if (MallIsImportant && !isWGMXCCset) { + out_wgmxcc = defaultWGMXCC; + isWGMXCCset = true; + } + + // ------------------- + // WGM Prediction + // ------------------- + // Default WGM + bool isWGMset = false; + int out_wgm = defaultWGM; + + // shortcut: + // 1. if we have decided to not remap xcc, there is no reason to use wgm + // 2. GEMMs that only have one tile in one dimension don't need wgm + // 3. Batched GEMMs don't need wgm (emprically -> batch count is often large!) + if (((out_wgmxcc == 1 && !MallIsImportant) || numMT_M == 1 || numMT_N == 1 || batch > 1) && + !isWGMset) { + out_wgm = 1; + isWGMset = true; + } + + // For tall cases (M >> N), if we have enough tiles to schedule, we use the number of tiles + // in the smaller dimension as WGM value + if (numMTs > hardware.N_CU && M > 10 * N && numMT_N <= 8) { + out_wgm = numMT_N; + isWGMset = true; + } + + // Cases where we have multiple rounds of computation per each CU + // To avoid regressions, it's set to defaultWGM. 
However, I think WGM=1 should be the winner + if (MallIsImportant && !isWGMset) { + out_wgm = defaultWGM; + isWGMset = true; + } + + if (!isWGMset) { + size_t numWGs = numWaves * splitFactor * numMTs; + size_t q = numWGs / hardware.NUM_XCD; + size_t r = numWGs % hardware.NUM_XCD; + + std::vector wgmList = {1, 2, 3, 4, 5, 6, 8, 16}; + int bestWGM = 1; + int bestL2 = std::numeric_limits::max(); + for (auto wgm : wgmList) { + auto slabTiles = numMT_M * std::min(wgm, static_cast(numMT_N)); + auto slabCount = math::safe_ceil_div(numMT_N, wgm); + auto edgeSlabWidth = numMT_N - (slabCount - 1) * wgm; + auto wgmL2Estimate = 0; + auto numXCD = std::min(hardware.NUM_XCD, numWGs); + + // Compute unique loads per L2 tile + for (uint32_t x = 0; x < numWaves * numXCD; ++x) { + // Range of "output tiles" that this xcd takes. + auto xccStart = q * x + (x < r ? x : r); + auto xccEnd = xccStart + q - 1 + (x < r ? 1 : 0); + // xccStart and xccEnd are supposed to be tile IDs + // In case of splitting, they are WG IDs. Modify to get tile IDs + xccStart /= splitFactor; + xccEnd /= splitFactor; + + auto slabStart = xccStart / slabTiles; + auto slabEnd = xccEnd / slabTiles; + + auto firstSlabWidth = (slabStart == slabCount - 1 ? edgeSlabWidth : wgm); + auto firstSlabStartIndex = xccStart % slabTiles; + auto firstSlabStartRow = firstSlabStartIndex / firstSlabWidth; + auto firstSlabEndRow = + std::min((firstSlabStartIndex + (xccEnd - xccStart)) / firstSlabWidth, numMT_M - 1); + auto rowsInFirstSlab = firstSlabEndRow - firstSlabStartRow + 1; + + auto lastSlabWidth = (slabEnd == slabCount - 1 ? edgeSlabWidth : wgm); + auto lastSlabEndIndex = xccEnd % slabTiles; + auto lastSlabEndRow = lastSlabEndIndex / lastSlabWidth; + auto colsInLastRow = (lastSlabEndIndex % lastSlabWidth) + 1; + auto colsInLastSlab = (lastSlabEndRow > 0 ? 
lastSlabWidth : colsInLastRow); + + size_t uniqueRows = 0; + size_t uniqueCols = 0; + if (slabEnd == slabStart) { + uniqueRows = lastSlabEndRow - firstSlabStartRow + 1; + uniqueCols = firstSlabWidth; + if (rowsInFirstSlab <= 2) uniqueCols = std::min(xccEnd - xccStart + 1, firstSlabWidth); + } else { + auto colsInFirstRow = firstSlabWidth - (xccStart % firstSlabWidth); + auto colsInFirstSlab = (rowsInFirstSlab > 1 ? firstSlabWidth : colsInFirstRow); + auto fullSlabs = slabEnd - slabStart - 1; + uniqueRows = + (fullSlabs > 0 ? numMT_M : std::min(rowsInFirstSlab + lastSlabEndRow + 1, numMT_M)); + uniqueCols = colsInFirstSlab + colsInLastSlab + fullSlabs * wgm; + } + + // Sum up the L2 loads over all XCD + // We should technically multiply by K (or splitted K), but it + // has no effect on sorting + auto xccL2Estimate = uniqueRows * MT_M + uniqueCols * MT_N; + wgmL2Estimate += xccL2Estimate; + } + + // If we have found a better WGM + if (wgmL2Estimate < bestL2) { + bestL2 = wgmL2Estimate; + bestWGM = wgm; + } + + if (get_runtime_options(config).debug_enabled) { + config.logger.log("WGM", wgm); + config.logger.log("L2Estimate", wgmL2Estimate); + } + } + + out_wgm = bestWGM; + } + + return std::make_tuple(out_wgmxcc, out_wgm); +} + +std::vector rank_configs(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs) { + if (configs.empty()) { throw std::runtime_error("No configurations provided."); } + + std::vector results(configs.size()); + + std::transform(std::execution::seq, + configs.begin(), + configs.end(), + results.begin(), + [&](const config_t& config) -> prediction_result_t { + if (!check_lds_capacity(hardware, config.mt, problem.a_dtype, problem.b_dtype)) { + return {std::numeric_limits::max(), config}; + } + double latency = compute_total_latency(problem, hardware, config, hardware.N_CU); + return {latency, config}; + }); + + results.erase(std::remove_if(results.begin(), + results.end(), + [](const prediction_result_t& p) { + 
return p.latency == std::numeric_limits::max(); + }), + results.end()); + + std::stable_sort(results.begin(), + results.end(), + [](const prediction_result_t& a, const prediction_result_t& b) { + return a.latency < b.latency; + }); + + if (results.empty()) { throw std::runtime_error("No valid configs found."); } + + // Compute arithmetic intensity for tie-breaking + // Flops = 2 * MT_M * MT_N * MT_K, Memory traffic = MT_M*MT_K + MT_K*MT_N + MT_M*MT_N + auto compute_arithmetic_intensity = [](const config_t& config) -> double { + const auto MT_M = config.mt.m; + const auto MT_N = config.mt.n; + const auto MT_K = config.mt.k; + + const double flops = static_cast(2ull * MT_M * MT_N * MT_K); + const double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); + + if (memory_traffic == 0.0) return 0.0; + return flops / memory_traffic; + }; + + // Apply tie-breaking logic for configs with similar latency + double best_latency = results.front().latency; + size_t num_the_same = 0; + + // Count the number of similar latencies + constexpr double epsilon = 1e-9; + // variance is set through environment variable ANALYTICAL_GEMM_HEURISTICS_VARIANCE + // Use runtime_options from first config if available, otherwise global singleton + const double top_N_heuristic = get_runtime_options(configs.front()).heuristics_variance; + for (const auto& res : results) { + bool within_top; + const double diff = std::abs(res.latency - best_latency); + + if (top_N_heuristic <= epsilon) { + // Absolute tolerance path + within_top = diff < epsilon; + } else { + // Relative tolerance path (guard denom) + const double denom = std::max(std::abs(best_latency), epsilon); + // If it's within top_N_heuristic%, include it. 
+ within_top = (diff / denom) < top_N_heuristic; + } + + if (within_top) + ++num_the_same; + else + break; + } + + // Sort top candidates by arithmetic intensity (descending - highest first) + if (num_the_same > 1) { + std::stable_sort(results.begin(), + results.begin() + num_the_same, + [&compute_arithmetic_intensity](const prediction_result_t& a, + const prediction_result_t& b) { + return compute_arithmetic_intensity(a.config) > + compute_arithmetic_intensity(b.config); + }); + + // After arithmetic intensity tie-breaking, check if we still have ties + // among the top results (those with same latency and arithmetic intensity) + // Check if the top tiles still have the same arithmetic intensity + double first_ai = compute_arithmetic_intensity(results.front().config); + size_t num_same_ai = 1; + for (size_t i = 1; i < num_the_same; ++i) { + double current_ai = compute_arithmetic_intensity(results[i].config); + if (std::abs(current_ai - first_ai) < 1e-6) { + num_same_ai++; + } else { + break; + } + } + + // If we still have ties after arithmetic intensity, apply problem dimension tie-breaker + if (num_same_ai > 1) { + // Problem dimension-based tie breaker: + // If M > N, prefer tiles with larger MT_M + // If N > M, prefer tiles with larger MT_N + // If M == N, this tie-breaker doesn't apply (will use final tie-breaker) + + if (problem.size.m != problem.size.n) { + std::stable_sort(results.begin(), + results.begin() + num_same_ai, + [problem](const prediction_result_t& a, const prediction_result_t& b) { + if (problem.size.m > problem.size.n) { + // M-dominant: prefer larger MT_M + if (a.config.mt.m != b.config.mt.m) + return a.config.mt.m > b.config.mt.m; + // If MT_M is same, prefer larger MT_N as secondary + return a.config.mt.n > b.config.mt.n; + } else // N > M + { + // N-dominant: prefer larger MT_N + if (a.config.mt.n != b.config.mt.n) + return a.config.mt.n > b.config.mt.n; + // If MT_N is same, prefer larger MT_M as secondary + return a.config.mt.m > 
b.config.mt.m; + } + }); + } + + // Final tie-breaker: when all else is equal (including square problems), + // consistently prefer tiles with larger MT_M + // This ensures deterministic selection regardless of input order + std::stable_sort(results.begin(), + results.begin() + num_same_ai, + [](const prediction_result_t& a, const prediction_result_t& b) { + // Prefer larger MT_M first + if (a.config.mt.m != b.config.mt.m) return a.config.mt.m > b.config.mt.m; + // If MT_M is same, prefer larger MT_N + if (a.config.mt.n != b.config.mt.n) return a.config.mt.n > b.config.mt.n; + // If both MT_M and MT_N are same, prefer larger MT_K + return a.config.mt.k > b.config.mt.k; + }); + } + } + + return results; +} + +prediction_result_t select_config_mnk(size_t M, + size_t N, + size_t K, + const hardware_t& hardware, + const std::vector& configs) { + // Create a default problem_t with the provided M, N, K and reasonable defaults + problem_t problem; + problem.size.m = M; + problem.size.n = N; + problem.size.k = K; + problem.batch = 1; + problem.a_transpose = transpose_t::T; // Default to T + problem.b_transpose = transpose_t::N; // Default to N + problem.a_dtype = data_type_t::Half; // Default to fp16 + problem.b_dtype = data_type_t::Half; // Default to fp16 + problem.c_dtype = data_type_t::Half; // Default to fp16 + problem.d_dtype = data_type_t::Half; // Default to fp16 + problem.mi_dtype = data_type_t::Half; // Default to fp16 + problem.a_mx_block_size = 0; // Default MX block size + problem.b_mx_block_size = 0; // Default MX block size + + // Use the existing select_config function with the constructed problem + return select_config(problem, hardware, configs); +} + +prediction_result_t select_config(const problem_t& problem, + const hardware_t& hardware, + const std::vector& configs) { + auto ranked_configs = rank_configs(problem, hardware, configs); + + // Return the top configuration + return ranked_configs[0]; +} + +double compute_perf_gflops(const hardware_t& 
hardware, + const problem_t& problem, + const double latency) { + // Extract parameters from structured types + size_t M = problem.size.m; + size_t N = problem.size.n; + size_t K = problem.size.k; + size_t batch = problem.batch; + + // Compute total FLOPs + double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs + // Compute total time in seconds + double cycles_per_second = hardware.compute_clock_ghz * 1e9; // 1 GHz = 1e9 cycles per second + + double total_time_seconds = latency / cycles_per_second; + + // Compute performance in FLOPS + double FLOPS = total_FLOPs / total_time_seconds; + // Convert to GFLOPS + double GFLOPS = FLOPS / 1e9; // 1 TFLOP = 1e9 FLOPs + return GFLOPS; +} +} // namespace origami diff --git a/shared/origami/src/origami/streamk.cpp b/shared/origami/src/origami/streamk.cpp index 8c6af0bfb2c8..1973e3c308a1 100644 --- a/shared/origami/src/origami/streamk.cpp +++ b/shared/origami/src/origami/streamk.cpp @@ -1,504 +1,447 @@ // Copyright Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include "origami/hardware.hpp" #include "origami/streamk.hpp" -#include "origami/utils.hpp" -// #include - -namespace origami -{ - namespace streamk - { - namespace math - { - /** - * Performs `(n + d - 1) / d`, but is robust against the case where - * `(n + d - 1)` would overflow. - */ - template - __device__ __host__ inline constexpr N safe_ceil_div(N n, D d) - { - // Static cast to undo integral promotion. - return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 
1 : 0))); - } - } // namespace math - - constexpr size_t num_iters_total(size_t output_tiles, size_t iters_per_tile) - { - return output_tiles * iters_per_tile; - } +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/math.hpp" +#include "origami/types.hpp" + +namespace origami { +namespace streamk { +size_t compute_number_of_output_tiles(size_t mt_m, size_t mt_n, size_t m, size_t n, size_t batch) { + size_t m_tiles = math::safe_ceil_div(m, mt_m); + size_t n_tiles = math::safe_ceil_div(n, mt_n); + return m_tiles * n_tiles * batch; +} - constexpr size_t num_iters_per_tile(size_t BLK_K, size_t k) - { - return math::safe_ceil_div(k, BLK_K); - } +/** + * @brief Returns number of k-iterations. + * + * @param output_tiles Number of output tiles. + * @param iters_per_tile Number of iterations per tile. + * @return constexpr size_t Number of total iterations. + */ +constexpr size_t num_iters_total(size_t output_tiles, size_t iters_per_tile) { + return output_tiles * iters_per_tile; +} - constexpr size_t num_iters_per_cta(size_t iters_total, int g) - { - return math::safe_ceil_div(iters_total, g); - } +/** + * @brief Returns number of k-iterations per tile. + * + * @param mt_k K-dimension tile size. + * @param k Reduction dimension. + * @return constexpr size_t Number of k-iteration per tile. + */ +constexpr size_t num_iters_per_tile(size_t mt_k, size_t k) { return math::safe_ceil_div(k, mt_k); } + +/** + * @brief Number of iterations per cta. + * + * @param iters_total Total number of k-iterations. + * @param g Number of workgroups (grid-size). + * @return constexpr size_t Number of iterations per cta. 
+ */ +constexpr size_t num_iters_per_cta(size_t iters_total, int g) { + return math::safe_ceil_div(iters_total, g); +} - constexpr size_t - number_of_output_tiles(size_t BLK_M, size_t BLK_N, size_t m, size_t n, size_t batch) - { - size_t m_tiles = math::safe_ceil_div(m, BLK_M); - size_t n_tiles = math::safe_ceil_div(n, BLK_N); - return m_tiles * n_tiles * batch; - } +constexpr size_t num_fixup_peers_v2(size_t g, + size_t iters_total, + size_t iters_per_tile, + size_t iters_per_cta) { + // If tiles don't evenly divide there are always at least 2 fixup peers, and more if + // iters_per_tile > iters_per_cta + size_t hasFixup = + (iters_total % g == 0 && // Check if some WGs have more iters than others + iters_per_cta % iters_per_tile == 0) // Check if WGs have an even number of full tiles + ? 0 + : 1; + return math::safe_ceil_div(iters_per_tile, iters_per_cta) + hasFixup; +} - constexpr size_t num_fixup_peers_v2(size_t g, - size_t iters_total, - size_t iters_per_tile, - size_t iters_per_cta) - { - // If tiles don't evenly divide there are always at least 2 fixup peers, and more if iters_per_tile > iters_per_cta - size_t hasFixup - = (iters_total % g == 0 && // Check if some WGs have more iters than others - iters_per_cta % iters_per_tile - == 0) // Check if WGs have an even number of full tiles - ? 0 - : 1; - return math::safe_ceil_div(iters_per_tile, iters_per_cta) + hasFixup; - } +/** + * @brief Number of workgroups involved in the Stream-K's fixup step. + * + * @param g Number of total workgroups (grid-size.) + * @param iters_total Total number of k-iterations. + * @param iters_per_tile K-iterations per tile. + * @param iters_per_cta Number of iterations per workgroup. + * @return constexpr size_t Number of workgroups involved in fixup. 
+ */ + +constexpr size_t num_fixup_peers(size_t iters_per_tile, size_t iters_per_cta) { + return math::safe_ceil_div(iters_per_tile, iters_per_cta); +} - constexpr size_t num_fixup_peers(size_t iters_per_tile, size_t iters_per_cta) - { - return math::safe_ceil_div(iters_per_tile, iters_per_cta); - } +/** + * @brief Returns the predicted latency for a given grid-size. + * + * @param mt BLK_M, BLK_N, BLK_K macro-tile. + * @param size M, N, K size. + * @param batch Number of batches. + * @param g Grid size to test. + * @param a alpha (a), fixed-size cost incurred by each workgroup. + * @param b Beta (b) incorporates conditional costs of outputting temporary partial. + * @param c Represents instruction and stall workload of each MAC-iteration. + * @param d Delta (d) is the cost of reading and accumulating the partial sums. + * @return double Predicted latency. + */ +std::tuple predicted_runtime(dim3_t mt, + dim3_t size, + size_t batch, + size_t g, + double a, + double b, + double c, + double d) { + size_t output_tiles = compute_number_of_output_tiles(mt.m, mt.n, size.m, size.n, batch); + size_t iters_per_tile = num_iters_per_tile(mt.k, size.k); + size_t iters_total = num_iters_total(output_tiles, iters_per_tile); + size_t iters_per_cta = num_iters_per_cta(iters_total, g); + size_t fixup_peers = num_fixup_peers(iters_per_tile, iters_per_cta); + + double runtime = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)); + + return std::make_tuple(runtime, iters_per_cta, fixup_peers); +} - const char* rtype_to_string(streamk::reduction_type r) - { - switch(r) - { - case streamk::reduction_type::Tree: - return "Tree"; - case streamk::reduction_type::Parallel: - return "Parallel"; - case streamk::reduction_type::None: - return "None"; - default: - return "Unknown"; - } +std::tuple predicted_runtime_v2(dim3_t mt, + dim3_t size, + size_t batch, + size_t g, + double a, + double b, + double c, + double d) { + size_t output_tiles = 
compute_number_of_output_tiles(mt.m, mt.n, size.m, size.n, batch); + size_t iters_per_tile = num_iters_per_tile(mt.k, size.k); + size_t iters_total = num_iters_total(output_tiles, iters_per_tile); + size_t iters_per_cta = num_iters_per_cta(iters_total, g); + size_t fixup_peers = num_fixup_peers_v2(g, iters_total, iters_per_tile, iters_per_cta); + + size_t remainder_tiles = output_tiles % g; + double k_split_ratio = remainder_tiles / static_cast(g); + + double cache_penalty = 0.0; + if (fixup_peers >= 1) { + // Calculate the ideal equal split ratio + double ideal_split_ratio = 1.0 / fixup_peers; + + // Measure deviation from the ideal equal split + double imbalance = 1 / std::abs(k_split_ratio - ideal_split_ratio); + + // Scale the penalty by the imbalance and the per-collaborator cost (d) + cache_penalty = d * imbalance * fixup_peers; + } + + // Include the cache penalty in the runtime prediction + double runtime = + a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)) + cache_penalty; + + return std::make_tuple(runtime, iters_per_cta, fixup_peers, cache_penalty); +} + +reduction_t select_reduction(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + grid_selection_t algorithm) { + reduction_t reduction_strategy = reduction_t::tree; + + if (algorithm == grid_selection_t::k_split_aware) { + size_t tiles = compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); + size_t cu_count = hardware.N_CU; + size_t iters_per_tile = std::max(size_t(1), num_iters_per_tile(config.mt.k, problem.size.k)); + + if (tiles < cu_count) { + // For problems with large k and low number of tiles, use parallel reduction + // TODO Benchmark to check if limits are correct + constexpr int MinItersForParallel = 64; + constexpr int MaxTilesForParallel = 64; + if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) + reduction_strategy = reduction_t::parallel; } + } + return 
reduction_strategy; +} - std::tuple predicted_runtime(size_t BLK_M, - size_t BLK_N, - size_t BLK_K, - size_t m, - size_t n, - size_t k, - size_t batch, - int g, - double a, - double b, - double c, - double d) - { - size_t output_tiles = number_of_output_tiles(BLK_M, BLK_N, m, n, batch); - size_t iters_per_tile = num_iters_per_tile(BLK_K, k); - size_t iters_total = num_iters_total(output_tiles, iters_per_tile); - size_t iters_per_cta = num_iters_per_cta(iters_total, g); - size_t fixup_peers = num_fixup_peers(iters_per_tile, iters_per_cta); - - double runtime - = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) + (d * (fixup_peers - 1)); - - return std::make_tuple(runtime, iters_per_cta, fixup_peers); - } +/** + * @brief Dynamically pick the minimum between the cu_count or number of tiles. + * @param problem Problem description (M, N, K, etc.) + * @param config Kernel configuration. + * @param cu_count cu count + * @return size_t minimum between the cu_count or number of tiles + */ +size_t grid_min_resources(const problem_t& problem, const config_t& config, size_t cu_count) { + size_t tiles = compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); + return std::min(cu_count, tiles); +} - std::tuple predicted_runtime_v2(size_t BLK_M, - size_t BLK_N, - size_t BLK_K, - size_t m, - size_t n, - size_t k, - size_t batch, - int g, - double a, - double b, - double c, - double d) - { - size_t output_tiles = number_of_output_tiles(BLK_M, BLK_N, m, n, batch); - size_t iters_per_tile = num_iters_per_tile(BLK_K, k); - size_t iters_total = num_iters_total(output_tiles, iters_per_tile); - size_t iters_per_cta = num_iters_per_cta(iters_total, g); - size_t fixup_peers - = num_fixup_peers_v2(g, iters_total, iters_per_tile, iters_per_cta); - - size_t remainder_tiles = output_tiles % g; - double k_split_ratio = remainder_tiles / static_cast(g); - - double cache_penalty = 0.0; - if(fixup_peers >= 1) - { - // Calculate the ideal equal 
split ratio - double ideal_split_ratio = 1.0 / fixup_peers; - - // Measure deviation from the ideal equal split - double imbalance = 1 / std::abs(k_split_ratio - ideal_split_ratio); - - // Scale the penalty by the imbalance and the per-collaborator cost (d) - cache_penalty = d * imbalance * fixup_peers; - } - - // Include the cache penalty in the runtime prediction - double runtime = a + (b * (fixup_peers > 1)) + (c * iters_per_cta) - + (d * (fixup_peers - 1)) + cache_penalty; - - return std::make_tuple(runtime, iters_per_cta, fixup_peers, cache_penalty); - } +/** + * @brief Dynamically pick the minimum between the cu_count or number of tiles, + * and scale down really large sizes to use fewer CUs for power/energy savings + * @param problem Problem description (M, N, K, etc.) + * @param config Kernel configuration. + * @param cu_count cu count + * @return size_t minimum between the cu_count or number of tiles + */ +size_t grid_energy_aware(const problem_t& problem, const config_t& config, size_t cu_count) { + size_t tiles = compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); + size_t sk_grid = cu_count; + + if (tiles > sk_grid) { + for (size_t i = 1; i <= 32; i *= 2) { + size_t tiles_per_cu = math::safe_ceil_div(i * tiles, cu_count); + size_t reduced_grid = math::safe_ceil_div(i * tiles, tiles_per_cu); + float utilization = static_cast(reduced_grid) / static_cast(cu_count); + if (utilization > 0.75f) { + if (utilization < 1.0f) sk_grid = reduced_grid; + break; + } + } + } + return std::min(sk_grid, tiles); +} - int best_predicted_grid_size(size_t BLK_M, - size_t BLK_N, - size_t BLK_K, - size_t m, - size_t n, - size_t k, - size_t batch, - int grid_start, - int grid_end, - bool verbose = false) - { - - // Fixed overhead alpha (a), fixed-size cost incurred by - // each work-group, e.g. the grid launch latency, the initial - // compulsary cache misses, the cost of writing the final output tile - // to C. 
- // double a = 5544 + 9130; - double a = 2.772 + 4.565; // 5.04 + 8.30; - - // Beta (b) incorporates conditional costs of outputting temporary partial - // sums for scenarios where the number of output tiles does not quantize - // perfectly across the number of processors. - double b = 3.01; // 5.47; 6017; - - // c represents instruction and stall workload of each MAC-iteration. - double c = 2.2935; // 4.17; 4587; - - // Delta (d) is the cost of reading and accumulating the partial sums from - // other work-groups covering the same tile. - double d = 10.22; // 18.59; 20449; - - std::pair min_grid_runtime; - std::pair min_grid_runtime_v2; - min_grid_runtime.second = std::numeric_limits::max(); - min_grid_runtime_v2.second = std::numeric_limits::max(); - - size_t g = grid_start; - - // Predict the number of CTAs to use between 1 and 304 - for(; g <= static_cast(grid_end); ++g) - { - auto [runtime, iters_per_cta, fixup_peers] - = predicted_runtime(BLK_M, BLK_N, BLK_K, m, n, k, batch, g, a, b, c, d); - - auto [runtime_v2, iters_per_cta_v2, fixup_peers_v2, cache_penalty] - = predicted_runtime_v2(BLK_M, BLK_N, BLK_K, m, n, k, batch, g, a, b, c, d); - - if(verbose) - { - std::cout << "[original] " - << "grid size: " << g << ", runtime: " << runtime - << ", iters_per_cta: " << iters_per_cta << ", fixup_peers: " - << fixup_peers - // << ", cache_penalty: " << cache_penalty - << ", m: " << m << ", n: " << n << ", k: " << k << ", a: " << a - << ", b: " << b << ", c: " << c << ", d: " << d << std::endl; - - std::cout << "[cache-offset] " - << "grid size: " << g << ", runtime: " << runtime_v2 - << ", iters_per_cta: " << iters_per_cta_v2 - << ", fixup_peers: " << fixup_peers_v2 - << ", cache_penalty: " << cache_penalty << ", m: " << m - << ", n: " << n << ", k: " << k << ", a: " << a << ", b: " << b - << ", c: " << c << ", d: " << d << std::endl; - } - - if(min_grid_runtime.second > runtime) - { - min_grid_runtime.first = g; - min_grid_runtime.second = runtime; - } - - 
if(min_grid_runtime_v2.second > runtime_v2) - { - min_grid_runtime_v2.first = g; - min_grid_runtime_v2.second = runtime_v2; - } - } - - if(verbose) - { - std::cout << "[original] Number of Output Tiles: " - << number_of_output_tiles(BLK_M, BLK_N, m, n, batch) << std::endl; - std::cout << "[original] Minimum runtime: " << min_grid_runtime.second - << " @ grid size: " << min_grid_runtime.first << std::endl; - - std::cout << "[cache-offset] Number of Output Tiles: " - << number_of_output_tiles(BLK_M, BLK_N, m, n, batch) << std::endl; - std::cout << "[cache-offset] Minimum runtime: " << min_grid_runtime_v2.second - << " @ grid size: " << min_grid_runtime_v2.first << std::endl; - } - - return min_grid_runtime_v2.first; - } +/** + * @brief Dynamically predict the best grid-size by weighing the cost of the fix-up + * step and the cost of processing MAC-loop instructions. When the cost of fix-up + * is the bottleneck, use smaller grid size. + * Architecture dependent. + * @param problem Problem description (M, N, K, etc.) + * @param config Kernel configuration. + * @param grid_start grid_start is 1 + * @param grid_end grid_end is cu_count + * @return size_t minimum between the cu_count or number of tiles + */ +size_t grid_reduction_cost_aware(const problem_t& problem, + const config_t& config, + size_t grid_start, + size_t grid_end) { + // Fixed overhead alpha (a), fixed-size cost incurred by + // each work-group, e.g. the grid launch latency, the initial + // compulsary cache misses, the cost of writing the final output tile + // to C. + // double a = 5544 + 9130; + double a = 2.772 + 4.565; // 5.04 + 8.30; + + // Beta (b) incorporates conditional costs of outputting temporary partial + // sums for scenarios where the number of output tiles does not quantize + // perfectly across the number of processors. + double b = 3.01; // 5.47; 6017; + + // c represents instruction and stall workload of each MAC-iteration. 
+ double c = 2.2935; // 4.17; 4587; + + // Delta (d) is the cost of reading and accumulating the partial sums from + // other work-groups covering the same tile. + double d = 10.22; // 18.59; 20449; + + std::pair min_grid_runtime; + std::pair min_grid_runtime_v2; + min_grid_runtime.second = std::numeric_limits::max(); + min_grid_runtime_v2.second = std::numeric_limits::max(); + + size_t g = grid_start; + + // Predict the number of CTAs to use between 1 and 304 + for (; g <= static_cast(grid_end); ++g) { + auto [runtime, iters_per_cta, fixup_peers] = + predicted_runtime(config.mt, problem.size, problem.batch, g, a, b, c, d); + + auto [runtime_v2, iters_per_cta_v2, fixup_peers_v2, cache_penalty] = + predicted_runtime_v2(config.mt, problem.size, problem.batch, g, a, b, c, d); + + if (min_grid_runtime.second > runtime) { + min_grid_runtime.first = g; + min_grid_runtime.second = runtime; + } - size_t get_workspace( - size_t x, - size_t y, - size_t mt_m, - size_t mt_n, - size_t bpe_c, - size_t grid, - size_t tiles, - reduction_type reduction) - { - size_t size = 0; - if(reduction == reduction_type::Tree) - { - if(tiles % grid == 0) - { - size_t tileSize = mt_m * mt_n * bpe_c; - size += tileSize * grid; - } - } - else if(reduction == reduction_type::Parallel) - { - size_t splitSize = x * y * bpe_c; - size_t splitCount = grid / tiles; - size += splitSize * splitCount; - } - return size; - } + if (min_grid_runtime_v2.second > runtime_v2) { + min_grid_runtime_v2.first = g; + min_grid_runtime_v2.second = runtime_v2; + } + } - reduction_type select_reduction( - size_t x, - size_t y, - size_t z, - size_t batch, - size_t mt_m, - size_t mt_n, - size_t mt_k, - const hardware_t& analytical_hardware, - int dynamic_grid_version) - { - reduction_type reductionStrat = reduction_type::Tree; - - if(dynamic_grid_version == 6) - { - size_t tiles = number_of_output_tiles(mt_m, mt_n, x, y, batch); - size_t cu_count = analytical_hardware.N_CU; - size_t iters_per_tile = std::max(size_t(1), 
math::safe_ceil_div(z, mt_k)); - - if (tiles < cu_count) - { - // For problems with large k and low number of tiles, use parallel reduction - // TODO Benchmark to check if limits are correct - constexpr int MinItersForParallel = 64; - constexpr int MaxTilesForParallel = 64; - if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) - reductionStrat = reduction_type::Parallel; - } - } - - return reductionStrat; - } + if (get_runtime_options(config).debug_enabled) { + config.logger.log("grid_reduction_cost_aware_best_grid_size_original", min_grid_runtime.first); + config.logger.log("grid_reduction_cost_aware_best_runtime_original", min_grid_runtime.second); + config.logger.log("grid_reduction_cost_aware_best_grid_size_cache_offset", min_grid_runtime_v2.first); + config.logger.log("grid_reduction_cost_aware_best_runtime_cache_offset", min_grid_runtime_v2.second); + } + + return min_grid_runtime_v2.first; +} + +/** + * @brief Fix Stream-K algorithm to function like a Data-parallel schedule + * where grid size is equal to the number of output tiles. + * @param problem Problem description (M, N, K, etc.) + * @param config Kernel configuration. 
+ * @return size_t number of tiles + */ +size_t grid_data_parallel(const problem_t& problem, const config_t& config) { + return compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); +} - size_t select_grid( - size_t x, - size_t y, - size_t z, - size_t batch, - bool trans_a, - bool trans_b, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - data_type_t mi_datatype, - size_t workspace_size, - size_t mt_m, - size_t mt_n, - size_t mt_k, - size_t mi_m, - size_t mi_n, - size_t mi_k, - int workgroup_mapping, - size_t workspace_size_per_elem_c, - int occupancy, - const hardware_t& analytical_hardware, - int dynamic_grid_version, - reduction_type reduction_strategy, - size_t max_cus) - { - size_t cu_count = analytical_hardware.N_CU; - if(max_cus > 0) - cu_count = std::min(cu_count, max_cus); - size_t tiles = number_of_output_tiles(mt_m, mt_n, x, y, batch); - - // Dynamically pick the minimum between the cu_count or number of tiles. - if(dynamic_grid_version == 1) - { - return std::min(cu_count, tiles); - } - - // Dynamically pick the minimum between the cu_count or number of tiles, - // and scale down really large sizes to use fewer CUs for power/energy savings. - else if(dynamic_grid_version == 2) - { - size_t sk_grid = cu_count; - if(tiles > sk_grid) - { - for(size_t i = 1; i <= 32; i *= 2) - { - size_t tilesPerCU = math::safe_ceil_div(i * tiles, cu_count); - size_t reducedGrid = math::safe_ceil_div(i * tiles, tilesPerCU); - float utilization = ((float)reducedGrid) / ((float)cu_count); - if(utilization > 0.75f) - { - if(utilization < 1.0f) - sk_grid = reducedGrid; - break; - } - } - } - - return std::min(sk_grid, tiles); - } - // Dynamically predict the best grid-size by weighing the cost of the fix-up - // step and the cost of processing MAC-loop instructions. When the cost of fix-up - // is the bottleneck, use smaller grid size. - // Architecture dependent. 
- else if(dynamic_grid_version == 3) - { - return origami::streamk::best_predicted_grid_size(mt_m, - mt_n, - mt_k, - x, - y, - z, - batch, - 1, - cu_count); - } - // Fix Stream-K algorithm to function like a Data-parallel schedule - // where grid size is equal to the number of output tiles. - else if(dynamic_grid_version == 4) - { - return origami::streamk::number_of_output_tiles( - mt_m, mt_n, x, y, batch); - } - else if(dynamic_grid_version == 5) - { - return origami::select_best_grid_size(x, - y, - z, - batch, - trans_a, - trans_b, - analytical_hardware, - mt_m, - mt_n, - mt_k, - mi_m, - mi_n, - mi_k, - element_size_A, - element_size_B, - element_size_out, - mi_datatype, - 0, - 0.0, - workgroup_mapping, - 10, - max_cus); - } - else if(dynamic_grid_version == 6) - { - size_t iters_per_tile = std::max(size_t(1), math::safe_ceil_div(z, mt_k)); - size_t sk_grid = tiles; // Fallback if no good fractional tile is found - if(max_cus > 0) - sk_grid = std::min(sk_grid, max_cus); - size_t tile_size = mt_m * mt_n * workspace_size_per_elem_c; - // More tiles than CUs - // Distribute tiles evenly across maximum number of CUs - // Split remaining tiles as evenly as possible for better caching - if(tiles > cu_count) - { - size_t virt_cu_count = cu_count; - if (occupancy > 1 && max_cus == 0) - virt_cu_count *= occupancy; - - const std::vector tile_fractions = {0.0, 1.0/2.0, 1.0/8.0, 1.0/5.0, 1.0/4.0, 1.0/3.0}; - size_t min_even_tiles = tiles / virt_cu_count; - for(double frac: tile_fractions) - { - size_t frac_grid = (size_t)((tiles / (min_even_tiles + frac)) + 0.5); - // Check if higher occupancy would cause excessive workspace requirements (set current limit to 128MB) - if((tiles % frac_grid != 0) && (tile_size * frac_grid > 128*1024*1024)) - continue; - if(frac_grid <= virt_cu_count) - { - sk_grid = frac_grid; - break; - } - } - } - // Fewer tiles than CUs - // Split tiles evenly in k-dimension - // Attempt to maximize CU utilization, up to a peak number of splits - // Max 
splitting is currently constant, but should be dependant on K dimension - else if (tiles < cu_count) - { - // For problems with large k and low number of tiles, use parallel reduction - // TODO Benchmark to check if limits are correct - // constexpr int MinItersForParallel = 64; - // constexpr int MaxTilesForParallel = 16; - constexpr int MinItersPerCU = 8; - // if (iters_per_tile >= MinItersForParallel && tiles <= MaxTilesForParallel) - if(reduction_strategy == reduction_type::Parallel) - { - size_t virt_cu_count = cu_count; - // TODO check if using occupancy info makes workspace too large - // if (occupancy > 1 && max_cus == 0) - // virt_cu_count *= occupancy; - - // Find max splitting factor to use as much of GPU as possible - size_t maxSplitsForTiles = virt_cu_count / tiles; - - // Find max splitting factor to ensure each CU has a minimum number of iterations to do - size_t maxSplitsForIters = iters_per_tile / MinItersPerCU; - - size_t maxSplits = std::min(maxSplitsForTiles, maxSplitsForIters); - sk_grid = tiles * maxSplits; - } - else - { - const std::vector tile_fractions = {16, 12, 8, 6, 4, 3, 2, 1}; - for(size_t frac: tile_fractions) - { - size_t splitGrid = tiles * frac; - size_t itersPerCU = iters_per_tile / frac; - if(splitGrid <= cu_count && itersPerCU >= 8) - { - sk_grid = splitGrid; - break; - } - } - } - } - - if (tiles % sk_grid != 0 && tile_size * sk_grid > workspace_size) - sk_grid = tiles; - return sk_grid; - } - // If no option is specified, launch exactly cu_count worth of workgroups. 
- else - { - return cu_count; - } +size_t grid_analytical(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + size_t biggest_allowable_split, + size_t max_cus) { + // Extract parameters from structured types + size_t M = problem.size.m; + size_t N = problem.size.n; + size_t K = problem.size.k; + size_t batch = problem.batch; + + size_t MT_M = config.mt.m; + size_t MT_N = config.mt.n; + + // compute how many 32×32 tiles are needed in each dim, + // then multiply to get total grid size: + size_t grid = ((M + MT_M - 1) / MT_M) * ((N + MT_N - 1) / MT_N) * batch; + + size_t max_hw_split = std::floor(hardware.N_CU / grid); + size_t MAX_SPLIT = std::min(biggest_allowable_split, max_hw_split); + + size_t best_split = 1; + double best_latency = std::numeric_limits::infinity(); + + for (size_t split = 1; split <= MAX_SPLIT; ++split) { + double latency = compute_total_latency(problem, hardware, config, max_cus); + + if (latency < best_latency) { + best_latency = latency; + best_split = split; + } + best_latency = latency; + best_split = split; + } + + size_t best_grid = best_split * grid; + + // you now have both `grid` and `best_split`— + // return whichever is appropriate (here we stick with split): + return best_grid; +} + +size_t grid_k_split_aware(const problem_t& problem, + const config_t& config, + size_t cu_count, + size_t max_cus) { + size_t tiles = compute_number_of_output_tiles( + config.mt.m, config.mt.n, problem.size.m, problem.size.n, problem.batch); + + size_t sk_grid = tiles; // Fallback if no good fractional tile is found + if (max_cus > 0) sk_grid = std::min(sk_grid, max_cus); + + const size_t iters_per_tile = num_iters_per_tile(config.mt.k, problem.size.k); + + const size_t tile_size = config.mt.m * config.mt.n * config.workspace_size_per_elem_c; + + // More tiles than CUs + // Distribute tiles evenly across maximum number of CUs + // Split remaining tiles as evenly as possible for better caching + if (tiles > cu_count) { + 
size_t virt_cu_count = cu_count; + if (config.occupancy > 1 && max_cus == 0) virt_cu_count *= config.occupancy; + + const std::vector tile_fractions = { + 0.0, 1.0 / 2.0, 1.0 / 8.0, 1.0 / 5.0, 1.0 / 4.0, 1.0 / 3.0}; + const size_t min_even_tiles = tiles / virt_cu_count; + + for (double frac : tile_fractions) { + const size_t frac_grid = static_cast((tiles / (min_even_tiles + frac)) + 0.5); + + // Check if higher occupancy would cause excessive workspace requirements (set current limit + // to 128MB) + if ((tiles % frac_grid != 0) && (tile_size * frac_grid > 128ull * 1024ull * 1024ull)) + continue; + + if (frac_grid <= virt_cu_count) { + sk_grid = frac_grid; + break; + } + } + } + // Fewer tiles than CUs + // Split tiles evenly in k-dimension + // Attempt to maximize CU utilization, up to a peak number of splits + // Max splitting is currently constant, but should be dependant on K dimension + else if (tiles < cu_count) { + // For problems with large k and low number of tiles, use parallel reduction + // TODO Benchmark to check if limits are correct + // constexpr int MinItersForParallel = 64; + // constexpr int MaxTilesForParallel = 16; + constexpr int MinItersPerCU = 8; + + if (config.reduction_strategy == reduction_t::parallel) { + size_t virt_cu_count = cu_count; + // TODO check if using occupancy info makes workspace too large + // if (occupancy > 1) + // virt_cu_count *= occupancy; + + // Find max splitting factor to use as much of GPU as possible + const size_t maxSplitsForTiles = virt_cu_count / tiles; + + // Find max splitting factor to ensure each CU has a minimum number of iterations to do + const size_t maxSplitsForIters = iters_per_tile / MinItersPerCU; + + const size_t maxSplits = std::min(maxSplitsForTiles, maxSplitsForIters); + sk_grid = tiles * maxSplits; + } else { + const std::vector tile_fractions = {16, 12, 8, 6, 4, 3, 2, 1}; + for (size_t frac : tile_fractions) { + const size_t splitGrid = tiles * frac; + const size_t itersPerCU = 
iters_per_tile / frac; + if (splitGrid <= cu_count && itersPerCU >= MinItersPerCU) { + sk_grid = splitGrid; + break; } - } // namespace streamk + } + } + } + + if (tiles % sk_grid != 0 && tile_size * sk_grid > config.workspace_size) sk_grid = tiles; + + return sk_grid; +} + +size_t select_grid_size(const problem_t& problem, + const hardware_t& hardware, + const config_t& config, + grid_selection_t algorithm, + size_t max_cus) { + size_t cu_count = hardware.N_CU; + if (max_cus > 0) cu_count = std::min(cu_count, max_cus); + + switch (algorithm) { + case grid_selection_t::min_resources: + return streamk::grid_min_resources(problem, config, cu_count); + + case grid_selection_t::energy_aware: + return streamk::grid_energy_aware(problem, config, cu_count); + + case grid_selection_t::reduction_cost_aware: + return streamk::grid_reduction_cost_aware(problem, config, 1, cu_count); + + case grid_selection_t::data_parallel: return streamk::grid_data_parallel(problem, config); + + case grid_selection_t::analytical: + return streamk::grid_analytical(problem, hardware, config, 10, max_cus); + + case grid_selection_t::k_split_aware: + return streamk::grid_k_split_aware(problem, config, cu_count, max_cus); + + case grid_selection_t::number_of_cus: + default: return hardware.N_CU; + } } +} // namespace streamk +} // namespace origami diff --git a/shared/origami/src/origami/types.cpp b/shared/origami/src/origami/types.cpp new file mode 100644 index 000000000000..add4dcbde81c --- /dev/null +++ b/shared/origami/src/origami/types.cpp @@ -0,0 +1,129 @@ +// Copyright Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "origami/types.hpp" + +#include +#include +#include + +namespace origami { + +runtime_options::runtime_options() { update_from_env(); } + +runtime_options::runtime_options(bool debug, bool heuristics, double variance) + : debug_enabled(debug), heuristics_enabled(heuristics), heuristics_variance(variance) {} + +runtime_options& runtime_options::get() { + static runtime_options instance; + return instance; +} + +bool runtime_options::read_debug_from_env() { + const char* env = std::getenv("ANALYTICAL_GEMM_DEBUG"); + return env && std::string(env) == "1"; +} + +bool runtime_options::read_heuristics_from_env() { + const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS"); + return !(env && std::string(env) == "0"); +} + +double runtime_options::read_heuristics_variance_from_env() { + constexpr double default_variance = 0.01; // 1% + + if (const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS_VARIANCE")) { + try { + double val = std::stod(env); + if (std::isfinite(val) && val > 0.0) { return val; } + } catch (...) 
{ + // fall through to default + } + } + return default_variance; +} + +void runtime_options::update_from_env() { + debug_enabled = read_debug_from_env(); + heuristics_enabled = read_heuristics_from_env(); + heuristics_variance = read_heuristics_variance_from_env(); +} + +int datatype_to_bits(data_type_t type) { + switch (type) { + case data_type_t::Float: return 32; + case data_type_t::Double: return 64; + case data_type_t::ComplexFloat: return 64; + case data_type_t::ComplexDouble: return 128; + case data_type_t::Half: return 16; + case data_type_t::Int8x4: return 32; + case data_type_t::Int32: return 32; + case data_type_t::BFloat16: return 16; + case data_type_t::Int8: return 8; + case data_type_t::Int4: return 4; + case data_type_t::Int64: return 64; + case data_type_t::XFloat32: return 32; + case data_type_t::Float8_fnuz: return 8; + case data_type_t::BFloat8_fnuz: return 8; + case data_type_t::Float8BFloat8_fnuz: return 8; + case data_type_t::BFloat8Float8_fnuz: return 8; + case data_type_t::Float8: return 8; + case data_type_t::BFloat8: return 8; + case data_type_t::Float8BFloat8: return 8; + case data_type_t::BFloat8Float8: return 8; + case data_type_t::Float6: return 6; + case data_type_t::BFloat6: return 6; + case data_type_t::Float4: return 4; + default: return -1; // Invalid type + } +} + +std::string datatype_to_string(data_type_t type) { + switch (type) { + case data_type_t::Float: return "Float"; + case data_type_t::Double: return "Double"; + case data_type_t::ComplexFloat: return "ComplexFloat"; + case data_type_t::ComplexDouble: return "ComplexDouble"; + case data_type_t::Half: return "Half"; + case data_type_t::Int8x4: return "Int8x4"; + case data_type_t::Int32: return "Int32"; + case data_type_t::BFloat16: return "BFloat16"; + case data_type_t::Int8: return "Int8"; + case data_type_t::Int4: return "Int4"; + case data_type_t::Int64: return "Int64"; + case data_type_t::XFloat32: return "XFloat32"; + case data_type_t::Float8_fnuz: return 
"Float8_fnuz"; + case data_type_t::BFloat8_fnuz: return "BFloat8_fnuz"; + case data_type_t::Float8BFloat8_fnuz: return "Float8BFloat8_fnuz"; + case data_type_t::BFloat8Float8_fnuz: return "BFloat8Float8_fnuz"; + case data_type_t::Float8: return "Float8"; + case data_type_t::BFloat8: return "BFloat8"; + case data_type_t::Float8BFloat8: return "Float8BFloat8"; + case data_type_t::BFloat8Float8: return "BFloat8Float8"; + case data_type_t::Float6: return "Float6"; + case data_type_t::BFloat6: return "BFloat6"; + case data_type_t::Float4: return "Float4"; + default: return "Invalid"; + } +} + +data_type_t string_to_datatype(std::string s) { + if (s == "f32") return data_type_t::Float; + if (s == "c32") return data_type_t::ComplexFloat; + if (s == "c64") return data_type_t::ComplexDouble; + if (s == "f64") return data_type_t::Double; + if (s == "f16") return data_type_t::Half; + if (s == "i32") return data_type_t::Int32; + if (s == "bf16") return data_type_t::BFloat16; + if (s == "i8") return data_type_t::Int8; + if (s == "i4") return data_type_t::Int4; + if (s == "xf32") return data_type_t::XFloat32; + if (s == "f8") return data_type_t::Float8; + if (s == "bf8") return data_type_t::BFloat8; + if (s == "f6") return data_type_t::Float6; + if (s == "bf6") return data_type_t::BFloat6; + if (s == "f4") return data_type_t::Float4; + return data_type_t::None; +} + +} // namespace origami diff --git a/shared/origami/src/origami/utils.cpp b/shared/origami/src/origami/utils.cpp deleted file mode 100644 index 39f9909e1566..000000000000 --- a/shared/origami/src/origami/utils.cpp +++ /dev/null @@ -1,733 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "origami/utils.hpp" - -#include -#include // For timing -#include -#include -#include // For output formatting -#include - -namespace origami { - -static double read_heuristics_variance_env_var() { - constexpr double default_variance = 0.01; // 1% - - if (const char* env = std::getenv("ANALYTICAL_GEMM_HEURISTICS_VARIANCE")) { - try { - double val = std::stod(env); - if (std::isfinite(val) && val > 0.0) { - return val; - } - } catch (...) { - // fall through to default - } - } - return default_variance; -} - -// -// Tiebreaker function. -// -void pick_best_tile_by_arithmetic_intensity(std::vector& top_results, - size_t num_to_sort) { - if (top_results.empty()) { - throw std::runtime_error("pick_best_tile_by_arithmetic_intensity received empty list."); - } - - // 1) Define a helper function to compute the arithmetic intensity of a tile. - // Here we assume: - // - Flops for tile (MT_M, MT_N, MT_K) is: 2 * MT_M * MT_N * MT_K - // - Memory traffic approximated as: MT_M*MT_K + MT_K*MT_N + MT_M*MT_N - // - Arithmetic intensity = flops / memory_traffic - auto compute_arithmetic_intensity = [](const result_tuple& t) -> double { - // The tuple is: (latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K) - auto MT_M = std::get<1>(t); - auto MT_N = std::get<2>(t); - auto MT_K = std::get<3>(t); - - double flops = static_cast(2ull * MT_M * MT_N * MT_K); - double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); - - // Avoid division by zero. - if (memory_traffic == 0.0) return 0.0; - - return flops / memory_traffic; - }; - // 2) Sort the results in descending order of arithmetic intensity - // (highest arithmetic intensity first). 
- std::stable_sort(top_results.begin(), top_results.begin() + num_to_sort, - [&](const result_tuple& a, const result_tuple& b) { - double ai_a = compute_arithmetic_intensity(a); - double ai_b = compute_arithmetic_intensity(b); - return ai_a > ai_b; // descending - }); - // 3) Return the tile with the highest arithmetic intensity -} - -result_tuple pick_best_tile_with_dimension_priority(const std::vector& top_results, - size_t M, size_t N, size_t K) { - if (top_results.empty()) { - throw std::runtime_error("pick_best_tile_with_dimension_priority received empty list."); - } - - // 1) Determine whether M or N is more important - // (based on which is larger), and always place K last. - // This yields a priority order of either { 'M', 'N', 'K' } - // or { 'N', 'M', 'K' }. - std::vector dimPriority; - if (M >= N) - dimPriority = {'M', 'N', 'K'}; - else - dimPriority = {'N', 'M', 'K'}; - - // 2) Helper function to extract the tile dimension: - // (latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K) - auto getTileSize = [](const result_tuple& t, char dimChar) -> size_t { - switch (dimChar) { - case 'M': - return std::get<1>(t); // MT_M - case 'N': - return std::get<2>(t); // MT_N - case 'K': - return std::get<3>(t); // MT_K - default: - return 0; - } - }; - - // 3) Sort in descending order according to the dimension priority. - // - Compare dimensionPriority[0] first - // - If there's a tie, compare dimensionPriority[1] - // - If still a tie, compare dimensionPriority[2] - // - If they're all equal, consider them tied - std::vector sorted = top_results; // copy - std::stable_sort(sorted.begin(), sorted.end(), - [&](const result_tuple& a, const result_tuple& b) { - for (char d : dimPriority) { - size_t ta = getTileSize(a, d); - size_t tb = getTileSize(b, d); - if (ta > tb) return true; - if (ta < tb) return false; - } - // If all relevant dimensions are the same, treat as a tie - return false; - }); - - // 4) Return the best tile (the first after sorting). 
- return sorted.front(); -} - -size_t select_best_grid_size(size_t M, size_t N, size_t K, size_t batch, bool transA, bool transB, - const hardware_t& hardware, size_t MT_M, size_t MT_N, size_t MT_K, - size_t MI_M, size_t MI_N, size_t MI_K, size_t element_size_A, - size_t element_size_B, size_t element_size_out, - data_type_t mi_datatype, size_t mx_block_size, double H_L2, int WGM, - size_t biggest_allowable_split, size_t max_cus) { - // compute how many 32×32 tiles are needed in each dim, - // then multiply to get total grid size: - size_t grid = ((M + MT_M - 1) / MT_M) * ((N + MT_N - 1) / MT_N) * batch; - - size_t max_hw_split = std::floor(hardware.N_CU / grid); - size_t MAX_SPLIT = std::min(biggest_allowable_split, max_hw_split); - - size_t best_split = 1; - double best_latency = std::numeric_limits::infinity(); - - for (size_t split = 1; split <= MAX_SPLIT; ++split) { - double latency = - compute_total_latency(hardware, M, N, - K, // problem dims - batch, transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, - element_size_A, // ElementSizeA - element_size_B, // ElementSizeB - element_size_out, // ElementSizeout - mi_datatype, mx_block_size, WGM, 0, 0, 0, split, max_cus); - - if (latency < best_latency) { - best_latency = latency; - best_split = split; - } - best_latency = latency; - best_split = split; - } - - size_t best_grid = best_split * grid; - - // you now have both `grid` and `best_split`— - // return whichever is appropriate (here we stick with split): - return best_grid; -} - -std::vector select_best_macro_tile_size(size_t M, size_t N, size_t K, size_t batch, - bool transA, bool transB, - const hardware_t& hardware, - const std::vector& MT_list, - size_t element_size_A, // In bits - size_t element_size_B, // In bits - size_t element_size_out, // In bits - data_type_t mi_datatype, size_t mx_block_size, - double H_L2, bool print, int defaultWGM, - size_t max_cus) { - std::vector valid_results; - valid_results.reserve(MT_list.size()); - - // bool tf32_emu = 
((mi_datatype == data_type_t::XFloat32) - // && (hardware.arch == hardware_t::architecture_t::gfx950)); - - for (const auto& mt : MT_list) { - size_t MT_M = std::get<0>(mt); - size_t MT_N = std::get<1>(mt); - size_t MT_K = std::get<2>(mt); - size_t MI_M = std::get<3>(mt); - size_t MI_N = std::get<4>(mt); - size_t MI_K = std::get<5>(mt); - int occupancy = std::get<6>(mt); - int WGM = std::get<7>(mt); - int non_temporal_a = std::get<8>(mt); - int non_temporal_b = std::get<9>(mt); - - if (hardware_t::is_debug_enabled()) { - std::cout << "Evaluating MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K - << ", MI_M=" << MI_M << ", MI_N=" << MI_N << ", MI_K=" << MI_K << "\n"; - } - - if (check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size_A)) { - double Total_latency = compute_total_latency( - hardware, M, N, K, batch, transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, - element_size_A, element_size_B, element_size_out, mi_datatype, mx_block_size, - defaultWGM, non_temporal_a, non_temporal_b, occupancy, 0, max_cus); - - valid_results.emplace_back(Total_latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, occupancy, - WGM, non_temporal_a, non_temporal_b); - } else if (hardware_t::is_debug_enabled()) { - std::cout << "Skipping MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K - << " due to LDS capacity\n"; - } - } - - if (valid_results.empty()) { - throw std::runtime_error("No valid macro-tile sizes found."); - } - - // 1) Sort results by ascending latency. - std::stable_sort(valid_results.begin(), valid_results.end(), - [](auto const& a, auto const& b) { return std::get<0>(a) < std::get<0>(b); }); - - // 2) Collect results that tie for the absolute best latency. 
- double best_latency = std::get<0>(valid_results.front()); - size_t num_the_same = 0; - - // Count the number of similar latencies - constexpr double epsilon = 1e-9; - // variance is set through environment variable ANALYTICAL_GEMM_HEURISTICS_VARIANCE - static const double top_N_heuristic = read_heuristics_variance_env_var(); - for (const auto& res : valid_results) { - bool within_top; - const double diff = std::abs(std::get<0>(res) - best_latency); - - if (top_N_heuristic <= epsilon) { - // Absolute tolerance path - within_top = diff < epsilon; - } else { - // Relative tolerance path (guard denom) - const double denom = std::max(std::abs(best_latency), epsilon); - // If it's within top_N_heuristic%, include it. - within_top = (diff / denom) < top_N_heuristic; - } - - if (within_top) - ++num_the_same; - else - break; - } - - // 3) If that tie group has at least 10 entries, we only use those. - // 4) Otherwise, keep adding the next best latencies until we have 10 total or run out. - // std::vector top_candidates = tie_results; - // if(tie_results.size() < 10) - // { - // size_t i = tie_results.size(); - // while(top_candidates.size() < 10 && i < valid_results.size()) - // { - // top_candidates.push_back(valid_results[i]); - // i++; - // } - // } - // Now ‘top_candidates’ is either all the tied best-latency results (if >=10), - // or the top 10 latencies overall (including however many best-latency entries there were). 
- - // Finally, use your existing tie-breaker on top_candidates - pick_best_tile_by_arithmetic_intensity(valid_results, num_the_same); - - // After arithmetic intensity tie-breaking, check if we still have ties - // among the top results (those with same latency and arithmetic intensity) - if (num_the_same > 1) { - // Helper to compute arithmetic intensity - auto compute_arithmetic_intensity = [](const result_tuple& t) -> double { - auto MT_M = std::get<1>(t); - auto MT_N = std::get<2>(t); - auto MT_K = std::get<3>(t); - double flops = static_cast(2ull * MT_M * MT_N * MT_K); - double memory_traffic = static_cast(MT_M * MT_K + MT_N * MT_K + MT_M * MT_N); - return (memory_traffic == 0.0) ? 0.0 : (flops / memory_traffic); - }; - - // Check if the top tiles still have the same arithmetic intensity - double first_ai = compute_arithmetic_intensity(valid_results[0]); - size_t num_same_ai = 1; - for (size_t i = 1; i < num_the_same; ++i) { - double current_ai = compute_arithmetic_intensity(valid_results[i]); - if (std::abs(current_ai - first_ai) < 1e-6) { - num_same_ai++; - } else { - break; - } - } - - // If we still have ties after arithmetic intensity, apply problem dimension tie-breaker - if (num_same_ai > 1) { - // Problem dimension-based tie breaker: - // If M > N, prefer tiles with larger MT_M - // If N > M, prefer tiles with larger MT_N - // If M == N, this tie-breaker doesn't apply (will use final tie-breaker) - - if (M != N) { - std::stable_sort(valid_results.begin(), valid_results.begin() + num_same_ai, - [M, N](const result_tuple& a, const result_tuple& b) { - size_t MT_M_a = std::get<1>(a); - size_t MT_N_a = std::get<2>(a); - size_t MT_M_b = std::get<1>(b); - size_t MT_N_b = std::get<2>(b); - - if (M > N) { - // M-dominant: prefer larger MT_M - if (MT_M_a != MT_M_b) return MT_M_a > MT_M_b; - // If MT_M is same, prefer larger MT_N as secondary - return MT_N_a > MT_N_b; - } else // N > M - { - // N-dominant: prefer larger MT_N - if (MT_N_a != MT_N_b) return 
MT_N_a > MT_N_b; - // If MT_N is same, prefer larger MT_M as secondary - return MT_M_a > MT_M_b; - } - }); - } - - // Final tie-breaker: when all else is equal (including square problems), - // consistently prefer tiles with larger MT_M - // This ensures deterministic selection regardless of input order - std::stable_sort(valid_results.begin(), valid_results.begin() + num_same_ai, - [](const result_tuple& a, const result_tuple& b) { - size_t MT_M_a = std::get<1>(a); - size_t MT_N_a = std::get<2>(a); - size_t MT_K_a = std::get<3>(a); - size_t MT_M_b = std::get<1>(b); - size_t MT_N_b = std::get<2>(b); - size_t MT_K_b = std::get<3>(b); - - // Prefer larger MT_M first - if (MT_M_a != MT_M_b) return MT_M_a > MT_M_b; - // If MT_M is same, prefer larger MT_N - if (MT_N_a != MT_N_b) return MT_N_a > MT_N_b; - // If both MT_M and MT_N are same, prefer larger MT_K - return MT_K_a > MT_K_b; - }); - } - } - - if (print) { - for (const auto& tile : valid_results) { - std::cout << M << "x" << N << "x" << K - << "Selected Macro-Tile: Latency=" << std::get<0>(tile) - << ", MT_M=" << std::get<0>(tile) << ", MT_N=" << std::get<1>(tile) - << ", MT_K=" << std::get<2>(tile) << ", MI_M=" << std::get<3>(tile) - << ", MI_N=" << std::get<4>(tile) << ", MI_K=" << std::get<5>(tile) - << ", Occupancy=" << std::get<6>(tile) << ", WGM=" << std::get<7>(tile) - << ", NonTemporalA=" << std::get<8>(tile) - << ", NonTemporalB=" << std::get<9>(tile) << "\n"; - } - } - - return valid_results; -} - -template -constexpr N safe_ceil_div(N n, D d) { - // Static cast to undo integral promotion. - return static_cast(d == 0 ? 0 : (n / d + (n % d != 0 ? 1 : 0))); -} - -/*! - * \brief Selects the best WGM (maximizing L2 hit rate) given fixed macro tile sizes. - * - * \param[in] hardware - Hardware - * \param[in] M, N, K, batch - Problem - * \param[in] MT_M, MT_N, MT_K - Solution - * \param[in] print - whether to print the final best result - * - * \return best WGMXCC, WGM. 
- */ -std::tuple select_best_wgm(const hardware_t& hardware, size_t M, size_t N, - size_t K, size_t batch, size_t MT_M, size_t MT_N, - size_t MT_K, int nta, int ntb, size_t skGrid, - bool print) { - // Default is the closest we can get to a square - size_t max_CU_XCD = hardware.N_CU / hardware.NUM_XCD; - size_t defaultWGM = ceil(std::sqrt(max_CU_XCD)); - - // Number of output MTs per split and batch - size_t numMT_M = safe_ceil_div(M, MT_M); - size_t numMT_N = safe_ceil_div(N, MT_N); - size_t numMTs = numMT_M * numMT_N; - - // What SK does -- we already have skGrid so just compute numWaves and splitFactor - auto numWaves = skGrid > numMTs ? safe_ceil_div(skGrid, hardware.N_CU) - : safe_ceil_div(numMTs, hardware.N_CU); - auto splitFactor = safe_ceil_div(skGrid, numMTs); - - // ------------------- - // NonTemporal Cases - // ------------------- - if (nta > 3 && ntb < 4) - return std::make_tuple(hardware.NUM_XCD, 1); - else if (nta < 4 && ntb > 3) - return std::make_tuple(hardware.NUM_XCD, - std::min(max_CU_XCD, safe_ceil_div(numMTs, hardware.NUM_XCD))); - else if (nta > 3 && ntb > 3) - return std::make_tuple(hardware.NUM_XCD, 1); - - // ------------------- - // WGMXCC Prediction - // ------------------- - // Default WGMXCC -- always number of XCD - auto defaultWGMXCC = hardware.NUM_XCD; - bool isWGMXCCset = false; - size_t out_wgmxcc = defaultWGMXCC; - - // Batched GEMMs - if (batch > 1 && !isWGMXCCset) { - // Total tiles including batch count - size_t numTotalTiles = numMTs * batch; - - // if only one MT per each GEMM -> no mapping - // if less than hardware.NUM_XCD total tiles -> no mapping - if (numMTs == 1 || numTotalTiles <= hardware.NUM_XCD) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - // else use the default (num_xcd) - } - - // If we are lucky that the splitFactor is a multiple of NUM_XCD -> no mapping - if ((splitFactor % hardware.NUM_XCD == 0) && !isWGMXCCset) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - - // Small GEMMs - if ((numMTs <= 
hardware.NUM_XCD) && !isWGMXCCset) { - out_wgmxcc = 1; - isWGMXCCset = true; - } - - // For sizes that we have more than 2 waves of computations, we skip xcc mapping as MALL is - // more important -- matrix should not be skinny - // To avoid regressions, it's set to default, but it should actually be 1! - bool MallIsImportant = (splitFactor == 1 && batch == 1 && numMTs > 2 * hardware.N_CU && - numMT_M > 8 && numMT_N > 8); - if (MallIsImportant && !isWGMXCCset) { - out_wgmxcc = defaultWGMXCC; - isWGMXCCset = true; - } - - // ------------------- - // WGM Prediction - // ------------------- - // Default WGM - bool isWGMset = false; - size_t out_wgm = defaultWGM; - - // shortcut: - // 1. if we have decided to not remap xcc, there is no reason to use wgm - // 2. GEMMs that only have one tile in one dimension don't need wgm - // 3. Batched GEMMs don't need wgm (emprically -> batch count is often large!) - if (((out_wgmxcc == 1 && !MallIsImportant) || numMT_M == 1 || numMT_N == 1 || batch > 1) && - !isWGMset) { - out_wgm = 1; - isWGMset = true; - } - - // For tall cases (M >> N), if we have enough tiles to schedule, we use the number of tiles - // in the smaller dimension as WGM value - if (numMTs > hardware.N_CU && M > 10 * N && numMT_N <= 8) { - out_wgm = numMT_N; - isWGMset = true; - } - - // Cases where we have multiple rounds of computation per each CU - // To avoid regressions, it's set to defaultWGM. 
However, I think WGM=1 should be the winner - if (MallIsImportant && !isWGMset) { - out_wgm = defaultWGM; - isWGMset = true; - } - - if (!isWGMset) { - size_t numWGs = numWaves * splitFactor * numMTs; - size_t q = numWGs / hardware.NUM_XCD; - size_t r = numWGs % hardware.NUM_XCD; - - std::vector wgmList = {1, 2, 3, 4, 5, 6, 8, 16}; - int bestWGM = 1; - int bestL2 = std::numeric_limits::max(); - for (auto wgm : wgmList) { - auto slabTiles = numMT_M * std::min(wgm, static_cast(numMT_N)); - auto slabCount = safe_ceil_div(numMT_N, wgm); - auto edgeSlabWidth = numMT_N - (slabCount - 1) * wgm; - auto wgmL2Estimate = 0; - auto numXCD = std::min(hardware.NUM_XCD, numWGs); - - // Compute unique loads per L2 tile - for (uint32_t x = 0; x < numWaves * numXCD; ++x) { - // Range of "output tiles" that this xcd takes. - auto xccStart = q * x + (x < r ? x : r); - auto xccEnd = xccStart + q - 1 + (x < r ? 1 : 0); - // xccStart and xccEnd are supposed to be tile IDs - // In case of splitting, they are WG IDs. Modify to get tile IDs - xccStart /= splitFactor; - xccEnd /= splitFactor; - - auto slabStart = xccStart / slabTiles; - auto slabEnd = xccEnd / slabTiles; - - auto firstSlabWidth = (slabStart == slabCount - 1 ? edgeSlabWidth : wgm); - auto firstSlabStartIndex = xccStart % slabTiles; - auto firstSlabStartRow = firstSlabStartIndex / firstSlabWidth; - auto firstSlabEndRow = std::min( - (firstSlabStartIndex + (xccEnd - xccStart)) / firstSlabWidth, numMT_M - 1); - auto rowsInFirstSlab = firstSlabEndRow - firstSlabStartRow + 1; - - auto lastSlabWidth = (slabEnd == slabCount - 1 ? edgeSlabWidth : wgm); - auto lastSlabEndIndex = xccEnd % slabTiles; - auto lastSlabEndRow = lastSlabEndIndex / lastSlabWidth; - auto colsInLastRow = (lastSlabEndIndex % lastSlabWidth) + 1; - auto colsInLastSlab = (lastSlabEndRow > 0 ? 
lastSlabWidth : colsInLastRow); - - size_t uniqueRows = 0; - size_t uniqueCols = 0; - if (slabEnd == slabStart) { - uniqueRows = lastSlabEndRow - firstSlabStartRow + 1; - uniqueCols = firstSlabWidth; - if (rowsInFirstSlab <= 2) - uniqueCols = std::min(xccEnd - xccStart + 1, firstSlabWidth); - } else { - auto colsInFirstRow = firstSlabWidth - (xccStart % firstSlabWidth); - auto colsInFirstSlab = (rowsInFirstSlab > 1 ? firstSlabWidth : colsInFirstRow); - auto fullSlabs = slabEnd - slabStart - 1; - uniqueRows = - (fullSlabs > 0 ? numMT_M - : std::min(rowsInFirstSlab + lastSlabEndRow + 1, numMT_M)); - uniqueCols = colsInFirstSlab + colsInLastSlab + fullSlabs * wgm; - } - - // Sum up the L2 loads over all XCD - // We should technically multiply by K (or splitted K), but it - // has no effect on sorting - auto xccL2Estimate = uniqueRows * MT_M + uniqueCols * MT_N; - wgmL2Estimate += xccL2Estimate; - } - - // If we have found a better WGM - if (wgmL2Estimate < bestL2) { - bestL2 = wgmL2Estimate; - bestWGM = wgm; - } - - if (print || hardware_t::is_debug_enabled()) - std::cout << "WGM (" << wgm << "), L2Estimate (" << wgmL2Estimate << ")" - << std::endl; - } - - out_wgm = bestWGM; - } - - return std::make_tuple(out_wgmxcc, out_wgm); -} - -// Logic to decide between two MT that are "tied" -std::vector> tie_breaker_macro_tile_sizes( - const std::vector>& top_results, size_t M, size_t N, - size_t K, hardware_t& hardware, - std::function - tie_breaker_fn) { - std::vector> tie_breaker_results; - - for (const auto& res : top_results) { - size_t MT_M = std::get<1>(res); - size_t MT_N = std::get<2>(res); - size_t MT_K = std::get<3>(res); - - // Call user-provided tie-breaking function - double precise_latency = tie_breaker_fn(M, N, K, MT_M, MT_N, MT_K, hardware); - - tie_breaker_results.emplace_back(precise_latency, MT_M, MT_N, MT_K); - } - - // Sort results by precise_latency (ascending order) - std::stable_sort(tie_breaker_results.begin(), tie_breaker_results.end()); - - return 
tie_breaker_results; -} - -std::vector> -rank_macro_tile_sizes( - size_t M, size_t N, size_t K, bool transA, bool transB, hardware_t& hardware, - const std::vector& MT_list, size_t element_size, data_type_t mi_datatype, - double H_L2, bool print, size_t WGM, - std::function - tie_breaker_fn, - size_t max_cus) { - std::vector> results; - - typedef std::tuple result_tuple; - - for (size_t i = 0; i < MT_list.size(); ++i) { - size_t MT_M = std::get<0>(MT_list[i]); - size_t MT_N = std::get<1>(MT_list[i]); - size_t MT_K = std::get<2>(MT_list[i]); - size_t MI_M = std::get<3>(MT_list[i]); - size_t MI_N = std::get<4>(MT_list[i]); - size_t MI_K = std::get<5>(MT_list[i]); - - if (hardware_t::is_debug_enabled()) { - std::cout << "Evaluating MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K - << ", MI_M=" << MI_M << ", MI_N=" << MI_N << ", MI_K=" << MI_K << "\n"; - } - - if (check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size)) { - size_t split = 1; - size_t mx_block_size = 0; - double Total_latency = - compute_total_latency(hardware, M, N, K, - 1, // Batch - transA, transB, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, - element_size * 8, // Element Size A - element_size * 8, // Element Size B - element_size * 8, // Element Size out - mi_datatype, mx_block_size, WGM, 0, 0, 0, split, max_cus); - - results.push_back(std::make_tuple(Total_latency, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K)); - } else if (hardware_t::is_debug_enabled()) { - std::cout << "Skipping MT_M=" << MT_M << ", MT_N=" << MT_N << ", MT_K=" << MT_K - << " due to LDS capacity\n"; - } - } - - // Sort results by Total_latency, from worst (largest latency) to best (smallest latency) - std::stable_sort(results.begin(), results.end(), - [](const result_tuple& a, const result_tuple& b) { - return std::get<0>(a) > std::get<0>(b); - }); - - if (!results.empty()) { - double best_latency = std::get<0>(results.back()); - - std::vector top_results; - for (size_t i = 0; i < results.size(); ++i) { - if 
(std::abs(std::get<0>(results[i]) - best_latency) < 1e-6) { - top_results.push_back(results[i]); - } - } - - if (top_results.size() > 1) { - if (hardware_t::is_debug_enabled()) { - std::cout << "Tie detected among top-ranked tile sizes. Applying " - "tie-breaker...\n"; - } - - // Compute tie-breaker scores and store them along with the result indices - std::vector> - tie_breaker_scores; // (score, index in top_results) - - for (size_t i = 0; i < top_results.size(); ++i) { - const result_tuple& res = top_results[i]; - size_t MT_M = std::get<1>(res); - size_t MT_N = std::get<2>(res); - size_t MT_K = std::get<3>(res); - size_t MI_M = std::get<4>(res); - size_t MI_N = std::get<5>(res); - size_t MI_K = std::get<6>(res); - double score = tie_breaker_fn(MT_M, MT_N, MT_K, MI_M, MI_N, MI_K, hardware); - - tie_breaker_scores.push_back(std::make_pair(score, i)); - } - - // Now sort the tie_breaker_scores based on score - std::stable_sort(tie_breaker_scores.begin(), tie_breaker_scores.end(), - [](const std::pair& a, - const std::pair& b) { return a.first > b.first; }); - - // Now re-order 'top_results' based on sorted indices - std::vector sorted_top_results; - for (size_t i = 0; i < tie_breaker_scores.size(); ++i) { - size_t idx = tie_breaker_scores[i].second; - sorted_top_results.push_back(top_results[idx]); - } - - // Remove the tied results from 'results' and insert the sorted 'sorted_top_results' - results.erase(std::remove_if(results.begin(), results.end(), - [best_latency](const result_tuple& res) { - return std::abs(std::get<0>(res) - best_latency) < - 1e-6; - }), - results.end()); - - results.insert(results.end(), sorted_top_results.begin(), sorted_top_results.end()); - // No need to re-sort results as total_latency remains same for tied results - } - } - - if (print) { - std::cout << "Total Latency\tMT_M\tMT_N\tMT_K\tMI_M\tMI_N\tMI_K\n"; - for (size_t i = 0; i < results.size(); ++i) { - double latency = std::get<0>(results[i]); - size_t MT_M = 
std::get<1>(results[i]); - size_t MT_N = std::get<2>(results[i]); - size_t MT_K = std::get<3>(results[i]); - size_t MI_M = std::get<4>(results[i]); - size_t MI_N = std::get<5>(results[i]); - size_t MI_K = std::get<6>(results[i]); - std::cout << std::fixed << std::setprecision(2) << latency << "\t" << MT_M << "\t" - << MT_N << "\t" << MT_K << "\t" << MI_M << "\t" << MI_N << "\t" << MI_K - << "\n"; - } - } - - return results; -} - -double compute_tflops_from_latency(double latency_cycles, size_t M, size_t N, size_t K, - double clock_GHz) { - // Compute total FLOPs - double total_FLOPs = 2.0 * M * N * K; // For GEMM, each multiply-add is 2 FLOPs - // Compute total time in seconds - double cycles_per_second = clock_GHz * 1e9; // 1 GHz = 1e9 cycles per second - double total_time_seconds = latency_cycles / cycles_per_second; - // Compute performance in FLOPS - double FLOPS = total_FLOPs / total_time_seconds; - // Convert to TFLOPS - double TFLOPS = FLOPS / 1e12; // 1 TFLOP = 1e12 FLOPs - - if (hardware_t::is_debug_enabled()) { - std::cout << "Total FLOPs: " << total_FLOPs << "\n"; - std::cout << "Total Time: " << total_time_seconds << " seconds\n"; - std::cout << "Performance: " << FLOPS << " FLOPS\n"; - std::cout << "Achieved Performance: " << TFLOPS << " TFLOPS\n"; - } - - return TFLOPS; -} -} // namespace origami diff --git a/shared/origami/tests/CMakeLists.txt b/shared/origami/tests/CMakeLists.txt index 822f4853d20c..863c4d125473 100644 --- a/shared/origami/tests/CMakeLists.txt +++ b/shared/origami/tests/CMakeLists.txt @@ -1,39 +1,52 @@ -# Copyright Advanced Micro Devices, Inc., or its affiliates. 
-# SPDX-License-Identifier: MIT +################################################################################ +# +# MIT License +# +# Copyright 2025 AMD ROCm(TM) Software +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- +# ies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- +# PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- +# CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+# +################################################################################ -find_package(GTest REQUIRED) -find_package(Boost REQUIRED COMPONENTS filesystem) +find_package(Catch2 3 QUIET CONFIG) +if(NOT Catch2_FOUND) + if(ORIGAMI_ENABLE_FETCH) + include(FetchContent) + fetchcontent_declare( + Catch2 GIT_REPOSITORY https://github.com/catchorg/Catch2.git GIT_TAG devel + ) + fetchcontent_makeavailable(Catch2) + else() + message(FATAL_ERROR "Failed to find Catch2") + endif() +endif() add_executable(origami-tests) -target_sources(origami-tests - PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/origami_gtest.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/test_tile_ordering_issue.cpp" - # "${CMAKE_CURRENT_SOURCE_DIR}/test_variance_issue.cpp" - "${CMAKE_CURRENT_SOURCE_DIR}/test_negative_occupancy.cpp" +target_sources( + origami-tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/test_gemm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/test_origami.cpp" ) -target_include_directories(origami-tests - PRIVATE - "${CMAKE_CURRENT_SOURCE_DIR}/include" -) +target_include_directories(origami-tests PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include") -configure_file( - "${CMAKE_CURRENT_SOURCE_DIR}/origami_gtest.yaml" - "${CMAKE_CURRENT_BINARY_DIR}/origami_gtest.yaml" - COPYONLY -) +target_link_libraries(origami-tests PRIVATE roc::origami Catch2::Catch2WithMain) -target_link_libraries(origami-tests - PRIVATE - roc::origami - Boost::filesystem - GTest::gtest - GTest::gtest_main -) - -gtest_discover_tests(origami-tests - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - TIMEOUT 60 -) +include(CTest) +include(Catch) +catch_discover_tests(origami-tests) diff --git a/shared/origami/tests/include/common.hpp b/shared/origami/tests/include/common.hpp new file mode 100644 index 000000000000..3c18f1a3b48d --- /dev/null +++ b/shared/origami/tests/include/common.hpp @@ -0,0 +1,114 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + 
* Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once +#include +#include "origami/gemm.hpp" +#include "origami/hardware.hpp" +#include "origami/origami.hpp" +#include "origami/streamk.hpp" + +// List of GPU architectures to test +inline const std::vector test_architectures = {942, 950}; + +// Helper function to construct problem_t +inline origami::problem_t make_problem(size_t m, + size_t n, + size_t k, + origami::transpose_t a_trans = origami::transpose_t::T, + origami::transpose_t b_trans = origami::transpose_t::N, + size_t batch = 1, + int mx_block_size = 0) { + origami::problem_t problem; + problem.size.m = m; + problem.size.n = n; + problem.size.k = k; + problem.batch = batch; + problem.a_transpose = a_trans; + problem.b_transpose = b_trans; + problem.a_dtype = origami::data_type_t::BFloat16; + problem.b_dtype = origami::data_type_t::BFloat16; + problem.c_dtype = origami::data_type_t::BFloat16; + problem.d_dtype = origami::data_type_t::BFloat16; + problem.mi_dtype = origami::data_type_t::BFloat16; + problem.a_mx_block_size = mx_block_size; + problem.b_mx_block_size = mx_block_size; + return problem; +} + +// Helper function to construct config_t +inline origami::config_t make_config(size_t mt_m, + size_t mt_n, + size_t mt_k, + size_t mi_m = 16, + size_t mi_n = 16, + size_t mi_k = 16, + int wgm = 1, + int occupancy = 1, + int non_temporal_a = 0, + int non_temporal_b = 0) { + origami::config_t config; + config.mt.m = mt_m; + config.mt.n = mt_n; + config.mt.k = mt_k; + config.mi.m = mi_m; + config.mi.n = mi_n; + config.mi.k = mi_k; + config.occupancy = occupancy; + config.workgroup_mapping = wgm; + config.cache_hints_a = non_temporal_a; + config.cache_hints_b = non_temporal_b; + return config; +} + +// Helper function to construct hardware_t with all parameters +inline origami::hardware_t make_hardware( + int gpu_arch, + size_t n_cu = 304, + size_t lds_capacity = 65536, + size_t num_xcd = 8, + double 
mem1_perf_ratio = 1.0, + double mem2_perf_ratio = 1.0, + double mem3_perf_ratio = 1.0, + size_t l2_capacity = 4000000, + double compute_clock_ghz = 1.0, + size_t parallel_mi_cu = 1, + std::tuple mem_bw_per_wg_coefficients = std::make_tuple(0, 0.015, 0)) { + const std::string gpu_arch_str = "gfx" + std::to_string(gpu_arch); + auto gpu_arch_enum = origami::hardware_t::arch_name_to_enum(gpu_arch_str); + + return origami::hardware_t(gpu_arch_enum, + n_cu, + lds_capacity, + num_xcd, + mem1_perf_ratio, + mem2_perf_ratio, + mem3_perf_ratio, + l2_capacity, + compute_clock_ghz, + parallel_mi_cu, + mem_bw_per_wg_coefficients); +} diff --git a/shared/origami/tests/include/testing_origami.hpp b/shared/origami/tests/include/testing_origami.hpp deleted file mode 100644 index f89603b1c53a..000000000000 --- a/shared/origami/tests/include/testing_origami.hpp +++ /dev/null @@ -1,657 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include "origami/utils.hpp" -#include -#include -#include -#include - -struct InputWithExpected -{ - std::map values; - std::optional expected; - std::optional expected_gt; - std::optional expected_lt; -}; - -struct MyTestData -{ - std::string name; - std::vector inputs; -}; - -// Parameterized test class declaration -class AnalyticalGtest : public ::testing::TestWithParam -{ -}; - -void ComputeLoads(int MT_M, int MT_N, int MT_K, const std::optional expected) -{ - auto a_loads = origami::compute_A_loads(MT_M, MT_K); - auto b_loads = origami::compute_B_loads(MT_N, MT_K); - EXPECT_EQ(a_loads, expected); - EXPECT_EQ(b_loads, expected); -} - -void EstimateL2Hit(const origami::hardware_t& hardware, - int M, - int N, - int K, - int batch, - int MT_M, - int MT_N, - int MT_K, - size_t element_size, - int splittingFactor, - const std::optional expected_gt, - const std::optional expected_lt) -{ - double l2_hit; - for(int i = 1; i < 
1025; i++) - { - l2_hit = origami::estimate_l2_hit( - hardware, M, N, K, batch, MT_M, MT_N, MT_K, element_size, i, splittingFactor); - EXPECT_GT(l2_hit, expected_gt); - EXPECT_LT(l2_hit, expected_lt); - } -} - -void ComputeNumMatrixInstructions(const origami::hardware_t& hardware, - int MT_M, - int MT_N, - int MT_K, - int MI_M, - int MI_N, - int MI_K, - const std::optional expected) -{ - auto NumberMatrixInstructions - = origami::compute_number_matrix_instructions(hardware, MT_M, MT_N, MT_K, MI_M, MI_N, MI_K); - EXPECT_EQ(NumberMatrixInstructions, expected); -} - -void ComputeMTComputeLatency(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - const std::optional expected, - const std::optional expected_gt) -{ - auto latency = origami::compute_mt_compute_latency(hardware, - M, - N, - K, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - origami::data_type_t::BFloat16); - - if(expected.has_value()) - EXPECT_EQ(latency, expected); - else if(expected_gt.has_value()) - EXPECT_GT(latency, expected_gt); -} - -void ComputeMemoryLatency(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t element_size_A, - size_t element_size_B, - size_t mx_block_size, - int wgm, - int numActiveCUs, - int splittingFactor) -{ - auto mem_latency_small = origami::compute_memory_latency(hardware, - M, - N, - K, - transA, - transB, - batch, - MT_M, - MT_N, - MT_K, - element_size_A, - element_size_B, - mx_block_size, - wgm, - numActiveCUs, - splittingFactor); - - auto mem_latency_large = origami::compute_memory_latency(hardware, - M, - N, - K, - transA, - transB, - batch, - MT_M * 2, - MT_N * 2, - MT_K * 2, - element_size_A, - 
element_size_B, - mx_block_size, - wgm, - numActiveCUs, - splittingFactor); - - EXPECT_LT(mem_latency_small, mem_latency_large); -} - -void ComputeTileLatency(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - size_t mx_block_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor) -{ - auto tile_latency_small = origami::compute_tile_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 1, - numActiveCUs, - splittingFactor); - - auto tile_latency_large = origami::compute_tile_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M * 2, - MT_N * 2, - MT_K * 2, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 1, - numActiveCUs, - splittingFactor); - - EXPECT_GT(tile_latency_large, tile_latency_small); -} - -void ComputeWaveLatency(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - size_t mx_block_size, - int WGM, - size_t numActiveCUs, - size_t splittingFactor) -{ - auto tile_latency = origami::compute_tile_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 1, - numActiveCUs, - 
splittingFactor); - auto wave_latency = origami::compute_wave_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 1, - numActiveCUs, - splittingFactor); - EXPECT_DOUBLE_EQ(wave_latency, tile_latency); -} - -void ComputeTotalLatency(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - size_t mx_block_size, - int WGM, - size_t splittingFactor, - size_t max_cus) -{ - double latency_cycles_small = origami::compute_total_latency(hardware, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 0, - 0, - splittingFactor, - max_cus); - - double latency_cycles_large = origami::compute_total_latency(hardware, - M * 2, - N * 2, - K * 2, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - WGM, - 0, - 0, - splittingFactor, - max_cus); - EXPECT_LT(latency_cycles_small, latency_cycles_large); -} - -void ComputePerfGflops(size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t MI_N, - size_t MI_K, - size_t element_size_A, //In bits - size_t element_size_B, //In bits, - size_t element_size_out, //In bits - int WGM, - size_t max_cus) -{ - auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); - auto gfx942_slow = origami::hardware_t( - gfx942arch, 304, 65536, 8, 
1.0, 1.0, 1.0, 4000000, 1.4, 1, std::make_tuple(0, 0.015, 0)); - auto gfx942_fast = origami::hardware_t( - gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.8, 1, std::make_tuple(0, 0.015, 0)); - double flops_slow = origami::compute_perf_gflops(gfx942_slow, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - WGM, - max_cus); - double flops_fast = origami::compute_perf_gflops(gfx942_fast, - M, - N, - K, - batch, - transA, - transB, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - WGM, - max_cus); - EXPECT_GT(flops_fast, flops_slow); // faster clock = higher flops -} - -void EstimateMallHit(const origami::hardware_t& hardware, - int M, - int N, - int K, - int batch, - int MT_M, - int MT_N, - int MT_K, - size_t element_size, - size_t numActiveCUs, - size_t splittingFactor, - const std::optional expected_gt) -{ - double mall_hit; - for(int i = 1; i < 1025; i++) - { - mall_hit = origami::estimate_mall_hit(hardware, - M, - N, - K, - batch, - MT_M, - MT_N, - MT_K, - element_size, - i, - numActiveCUs, - splittingFactor); - EXPECT_GT(mall_hit, expected_gt); - } -} - -void CheckLDSCapacity( - const origami::hardware_t& hardware, int MT_M, int MT_N, int MT_K, size_t element_size) -{ - auto fit_lds_memory = origami::check_lds_capacity(hardware, MT_M, MT_N, MT_K, element_size); - EXPECT_TRUE(fit_lds_memory); -} - -// hardware_t -void HardwareArchEnum(const std::string gpuArchNumber) -{ - auto gpuArchEnum = origami::hardware_t::arch_name_to_enum("gfx" + gpuArchNumber); - EXPECT_EQ(gpuArchEnum, origami::hardware_t::architecture_t::gfx942); -} - -// Utils -void BestGridSize(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t MT_M, - size_t MT_N, - size_t MT_K, - size_t MI_M, - size_t 
MI_N, - size_t MI_K, - size_t element_size_A, - size_t element_size_B, - size_t element_size_out, - size_t mx_block_size, - double H_L2, - int WGM, - size_t biggest_allowable_split, - size_t max_cus, - const std::optional expected_gt) -{ - size_t grid_size = origami::select_best_grid_size(M, - N, - K, - batch, - transA, - transB, - hardware, - MT_M, - MT_N, - MT_K, - MI_M, - MI_N, - MI_K, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - H_L2, - WGM, - biggest_allowable_split, - max_cus); - EXPECT_GT(grid_size, expected_gt); -} - -void BestMacroTileSize(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - bool transA, - bool transB, - size_t element_size_A, //In bits - size_t element_size_B, //In bits - size_t element_size_out, //In bits - size_t mx_block_size, - double H_L2, - size_t WGM, - size_t max_cus) -{ - const std::vector> - MT_list = {{256, 256, 32, 32, 32, 8, 1, 6, 0, 0}, - {128, 128, 64, 32, 32, 8, 1, 6, 0, 0}, - {64, 64, 64, 32, 32, 8, 1, 6, 0, 0}}; - auto results = select_best_macro_tile_size(M, - N, - K, - batch, - transA, - transB, - hardware, - MT_list, - element_size_A, - element_size_B, - element_size_out, - origami::data_type_t::BFloat16, - mx_block_size, - H_L2, - false, - WGM, - max_cus); - - EXPECT_EQ(results.size(), MT_list.size()); - for(int i = 0; i < results.size() - 1; i++) - EXPECT_LT(std::get<0>(results[i]), std::get<0>(results[i + 1])); -} - -void BestWGM(const origami::hardware_t& hardware, - size_t M, - size_t N, - size_t K, - size_t batch, - size_t MT_M, - size_t MT_N, - size_t MT_K) -{ - // Assume no nt - auto nta = 0; - auto ntb = 0; - // Assume DP - auto skGrid = (M + MT_M - 1) / MT_M * (N + MT_N - 1) / MT_N; - - auto [best_wgmxcc_large_tile, best_wgm_large_tile] = - select_best_wgm(hardware, - M, - N, - K, - batch, - MT_M, - MT_N, - MT_K, - nta, - ntb, - skGrid, - false); - - auto [best_wgmxcc_small_tile, best_wgm_small_tile] = - 
select_best_wgm(hardware, - M / 4, - N / 4, - K, - batch, - MT_M, - MT_N, - MT_K * 2, - nta, - ntb, - skGrid, - false); - - auto [best_wgmxcc_nonsquare_tile, best_wgm_nonsquare_tile] = - select_best_wgm(hardware, - 1024, - 5120, - K, - batch, - MT_M, - MT_N, - MT_K, - nta, - ntb, - skGrid, - false); - - EXPECT_EQ(best_wgmxcc_large_tile, best_wgmxcc_small_tile); - EXPECT_GT(best_wgm_large_tile, best_wgm_small_tile); - EXPECT_NE(best_wgm_large_tile, best_wgm_nonsquare_tile); -} - -void UtilsTFlopsFromLatency(size_t M, size_t N, size_t K, double latency_cycles, double clock_GHz) -{ - auto tflops = origami::compute_tflops_from_latency(latency_cycles, M, N, K, clock_GHz); - double Expected = 1.99; - EXPECT_LT(std::abs(tflops - Expected) / std::abs(Expected), 0.01); -} diff --git a/shared/origami/tests/origami_gtest.cpp b/shared/origami/tests/origami_gtest.cpp deleted file mode 100644 index 30e4146bbb09..000000000000 --- a/shared/origami/tests/origami_gtest.cpp +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "testing_origami.hpp" -#include -#include -#include -#include - -#ifndef _WIN32 -#include -#include -#else -#include -#endif - -// Returns the directory path where the current executable resides. -// Adds a trailing slash ('/' on Linux, '\' on Windows) for easy file concatenation. 
-std::string getExecutableDir() { -#ifndef _WIN32 - // Linux branch - - char result[PATH_MAX]; // Buffer to store the path - ssize_t count = readlink("/proc/self/exe", result, PATH_MAX); - // readlink reads the symbolic link /proc/self/exe, which points to the current executable - - if (count == -1) { - // If readlink fails, return empty string - return ""; - } - - result[count] = '\0'; // Null-terminate the buffer - std::string fullPath(result); // Convert to std::string - - // Find the position of the last slash ('/') in the path - // This separates the directory from the binary name - size_t pos = fullPath.find_last_of('/'); - - // Extract the directory portion - std::string dir = (pos != std::string::npos) ? fullPath.substr(0, pos) : fullPath; - - // Ensure the directory string ends with a slash - if (!dir.empty() && dir.back() != '/') - dir += '/'; - - return dir; - -#else - // Windows branch - - char path[MAX_PATH]; // Buffer to store the path - DWORD length = GetModuleFileNameA(NULL, path, MAX_PATH); - // GetModuleFileNameA returns the full path of the current executable - - if (length == 0) { - // Failed to get the executable path - return ""; - } - - std::string fullPath(path, length); // Convert to std::string - - // Find the position of the last backslash ('\') or forward slash ('/') - size_t pos = fullPath.find_last_of("\\/"); - - // Extract the directory portion - std::string dir = (pos != std::string::npos) ? 
fullPath.substr(0, pos) : fullPath; - - // Ensure the directory string ends with a backslash - if (!dir.empty() && dir.back() != '\\' && dir.back() != '/') - dir += '\\'; - - return dir; -#endif -} - -// Parse origami_gtest.yaml to get the test data -std::vector parseYamlManually(const std::string& filename) -{ - std::string YamlfullPath = getExecutableDir() + filename; - std::ifstream file(YamlfullPath); - if(!file) - { - std::cerr << "Failed to open file: " << YamlfullPath << std::endl; - return {}; - } - - std::string line; - std::vector tests; - MyTestData current; - enum class State - { - None, - Inputs - } state - = State::None; - int line_number = 0; - - while(std::getline(file, line)) - { - line_number++; - line.erase(0, line.find_first_not_of(" \t\r\n")); - line.erase(line.find_last_not_of(" \t\r\n") + 1); - - if(line.empty() || line[0] == '#') - continue; - - if(line.rfind("- name:", 0) == 0) - { - if(!current.name.empty()) - tests.push_back(current); - current = MyTestData{}; - current.name = line.substr(7); - current.name.erase(0, current.name.find_first_not_of(" \t")); - state = State::None; - } - else if(line.rfind("inputs:", 0) == 0) - { - state = State::Inputs; - } - else if(state == State::Inputs && line.rfind("- {", 0) == 0) - { - std::string inner = line.substr(3); - if(!inner.empty() && inner.back() == '}') - inner.pop_back(); - - std::map values; - std::optional expected; - std::optional expected_gt; - std::optional expected_lt; - - std::stringstream ss(inner); - std::string pair; - while(std::getline(ss, pair, ',')) - { - auto colon = pair.find(':'); - if(colon == std::string::npos) - continue; - - std::string key = pair.substr(0, colon); - std::string val = pair.substr(colon + 1); - key.erase(0, key.find_first_not_of(" \t")); - key.erase(key.find_last_not_of(" \t") + 1); - val.erase(0, val.find_first_not_of(" \t")); - val.erase(val.find_last_not_of(" \t") + 1); - - try - { - int num = std::stoi(val); - if(key == "expected") - expected = num; 
- else if(key == "expected_gt") - expected_gt = num; - else if(key == "expected_lt") - expected_lt = num; - else - values[key] = num; - } - catch(...) - { - std::cerr << "Invalid number in line " << line_number << ": " << val - << std::endl; - } - } - - current.inputs.push_back(InputWithExpected{values, expected, expected_gt, expected_lt}); - } - } - - if(!current.name.empty()) - tests.push_back(current); - - return tests; -} - -TEST_P(AnalyticalGtest, DynamicDispatch) -{ - const MyTestData& test = GetParam(); - - const std::string gpuArchNumber = std::to_string(test.inputs[0].values.at("gpu_arch")); - auto gpuArchEnum = origami::hardware_t::arch_name_to_enum("gfx" + gpuArchNumber); - - //TODO: Hardcoding numbers for gfx942. Future archs could be added here with if else loop. - auto gpuInfo = origami::hardware_t( - gpuArchEnum, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.0, 1, std::make_tuple(0, 0.015, 0)); - - if(test.name == "ComputeLoads") - { - for(const auto& input_case : test.inputs) - { - ComputeLoads(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.expected); - } - } - else if(test.name == "EstimateL2Hit") - { - for(const auto& input_case : test.inputs) - { - EstimateL2Hit(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), input_case.values.at("element_size"), - input_case.values.at("splittingFactor"),input_case.expected_gt, input_case.expected_lt); - } - } - else if(test.name == "ComputeNumMatrixInstructions") - { - for(const auto& input_case : test.inputs) - { - ComputeNumMatrixInstructions(gpuInfo, input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), - input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), input_case.expected); - } - } - else if(test.name == "ComputeMTComputeLatency") 
- { - for(const auto& input_case : test.inputs) - { - ComputeMTComputeLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), - input_case.values.at("transA"), input_case.values.at("transB"),input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.expected, input_case.expected_gt); - } - } - else if(test.name == "ComputeMemoryLatency") - { - for(const auto& input_case : test.inputs) - { - ComputeMemoryLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), - input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); - } - } - else if(test.name == "ComputeTileLatency") - { - for(const auto& input_case : test.inputs) - { - ComputeTileLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), - input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); - } - } - else if(test.name == 
"ComputeWaveLatency") - { - for(const auto& input_case : test.inputs) - { - ComputeWaveLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), - input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor")); - } - } - else if(test.name == "ComputeTotalLatency") - { - for(const auto& input_case : test.inputs) - { - ComputeTotalLatency(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), - input_case.values.at("mx_block_size"), input_case.values.at("wgm"), input_case.values.at("splittingFactor"), - input_case.values.at("max_cus")); - } - } - else if(test.name == "ComputePerfGflops") - { - for(const auto& input_case : test.inputs) - { - ComputePerfGflops(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - 
input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), - input_case.values.at("WGM"), input_case.values.at("max_cus")); - } - } - else if(test.name == "EstimateMallHit") - { - for(const auto& input_case : test.inputs) - { - EstimateMallHit(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"),input_case.values.at("element_size_A"), - input_case.values.at("numActiveCUs"), input_case.values.at("splittingFactor"), input_case.expected_gt); - } - } - else if(test.name == "CheckLDSCapacity") - { - for(const auto& input_case : test.inputs) - { - CheckLDSCapacity(gpuInfo, input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K"), input_case.values.at("element_size")); - } - } - else if(test.name == "HardwareArchEnum") - { - for(const auto& input_case : test.inputs) - { - HardwareArchEnum(std::to_string(test.inputs[0].values.at("gpu_arch"))); - } - } - else if(test.name == "BestGridSize") - { - for(const auto& input_case : test.inputs) - { - BestGridSize(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("MT_M"), input_case.values.at("MT_N"), - input_case.values.at("MT_K"), input_case.values.at("MI_M"), input_case.values.at("MI_N"), input_case.values.at("MI_K"), - input_case.values.at("element_size_A"), input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), - input_case.values.at("mx_block_size"), input_case.values.at("H_L2"), input_case.values.at("WGM"), - input_case.values.at("biggest_allowable_split"), input_case.values.at("max_cus"), input_case.expected_gt); - } - } - else if(test.name == "BestMacroTileSize") - { - for(const auto& 
input_case : test.inputs) - { - BestMacroTileSize(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("transA"), input_case.values.at("transB"), input_case.values.at("element_size_A"), - input_case.values.at("element_size_B"), input_case.values.at("element_size_out"), input_case.values.at("mx_block_size"), - input_case.values.at("H_L2"), input_case.values.at("WGM"), input_case.values.at("max_cus")); - } - } - else if(test.name == "BestWGM") - { - for(const auto& input_case : test.inputs) - { - BestWGM(gpuInfo, input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), input_case.values.at("batch"), - input_case.values.at("MT_M"), input_case.values.at("MT_N"), input_case.values.at("MT_K")); - } - } - else if(test.name == "UtilsTFlopsFromLatency") - { - for(const auto& input_case : test.inputs) - { - UtilsTFlopsFromLatency(input_case.values.at("M"), input_case.values.at("N"), input_case.values.at("K"), - input_case.values.at("latency_cycles"), input_case.values.at("clock_GHz")); - } - } - else - { - FAIL() << "Unknown test name: " << test.name; - } -} - -// Instantiate tests using manual parser -INSTANTIATE_TEST_SUITE_P(AnalyticalYamlTests, - AnalyticalGtest, - ::testing::ValuesIn(parseYamlManually("origami_gtest.yaml")), - [](const ::testing::TestParamInfo& info) { - std::string name = info.param.name; - for(auto& c : name) - if(!std::isalnum(c)) - c = '_'; - return name; - }); - diff --git a/shared/origami/tests/origami_gtest.yaml b/shared/origami/tests/origami_gtest.yaml deleted file mode 100644 index 6e6975667ac4..000000000000 --- a/shared/origami/tests/origami_gtest.yaml +++ /dev/null @@ -1,70 +0,0 @@ - -tests: - - name: ComputeLoads - inputs: - - { M: 128, N: 128, K: 64, gpu_arch: 942, expected: 8192 } - - - name: EstimateL2Hit - inputs: - - { M: 4096, N: 4096, K: 1024, batch: 1, MT_M: 256, MT_N: 256, MT_K: 64, element_size: 16, splittingFactor: 3, 
gpu_arch: 942, expected_gt: 0.0, expected_lt: 1.0} - - - name: ComputeNumMatrixInstructions - inputs: - - { MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 16, MI_N: 16, MI_K: 16, gpu_arch: 942, expected: 256} # 8 * 8 * 4 - - { MT_M: 16, MT_N: 16, MT_K: 64, MI_M: 16, MI_N: 16, MI_K: 32, gpu_arch: 942, expected: 2} # 1 * 1 * 2 - - - name: ComputeMTComputeLatency - inputs: - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 0, transB: 1, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 0, transB: 0, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 4096} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 128, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 2048} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 224, MT_N: 224, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected_gt: 12543} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 128, MT_N: 32, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 512} # 8 * 8 * 4 - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 1, MT_M: 32, MT_N: 128, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, gpu_arch: 942, expected: 512} # 8 * 8 * 4 - - - name: ComputeMemoryLatency - inputs: 
- - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 1, MT_M: 128, MT_N: 128, MT_K: 64, element_size_A: 16, element_size_B: 16, mx_block_size: 0, wgm: 8, numActiveCUs: 304, splittingFactor: 2, gpu_arch: 942} - - - name: ComputeTileLatency - inputs: - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 6, numActiveCUs: 304, splittingFactor: 3, gpu_arch: 942} - - - name: ComputeWaveLatency - inputs: - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 8, numActiveCUs: 304, splittingFactor: 4, gpu_arch: 942} - - - name: ComputeTotalLatency - inputs: - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, wgm: 1, splittingFactor: 6, max_cus: 0, gpu_arch: 942} - - - name: ComputePerfGflops - inputs: - - { M: 4096, N: 4096, K: 1024, transA: 1, transB: 0, batch: 2, MT_M: 128, MT_N: 128, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, WGM: 1, max_cus: 0, gpu_arch: 942} - - - name: EstimateMallHit - inputs: - - { M: 4096, N: 4096, K: 1024, batch: 1, MT_M: 256, MT_N: 256, MT_K: 64, numActiveCUs: 304, splittingFactor: 8, gpu_arch: 942, expected_gt: 0, element_size_A: 16} - - - name: HardwareArchEnum - inputs: - - { gpu_arch: 942} - - - name: BestGridSize - inputs: - - { M: 1024, N: 1024, K: 4096, batch: 1, transA: 1, transB: 0, MT_M: 256, MT_N: 256, MT_K: 64, MI_M: 32, MI_N: 32, MI_K: 8, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, biggest_allowable_split: 20, gpu_arch: 942, max_cus: 0, expected_gt: 15} - 
- - name: BestMacroTileSize - inputs: - - { M: 1024, N: 1024, K: 4096, batch: 1, transA: 1, transB: 0, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, max_cus: 0, gpu_arch: 942} - - { M: 4, N: 4, K: 0, batch: 1, transA: 1, transB: 0, element_size_A: 16, element_size_B: 16, element_size_out: 32, mx_block_size: 0, H_L2: 0.0, WGM: 1, max_cus: 0, gpu_arch: 942} - - - name: BestWGM - inputs: - - { M: 4096, N: 4096, K: 8192, batch: 1, MT_M: 256, MT_N: 256, MT_K: 32, MI_M: 32, MI_N: 32, MI_K: 8, element_size: 16, H_L2: 0.0, gpu_arch: 942} - - - name: UtilsTFlopsFromLatency - inputs: - - { M: 10, N: 10, K: 10, batch: 1, latency_cycles: 256, clock_GHz: 256, gpu_arch: 942} diff --git a/shared/origami/tests/test_gemm.cpp b/shared/origami/tests/test_gemm.cpp new file mode 100644 index 000000000000..2d926aaba994 --- /dev/null +++ b/shared/origami/tests/test_gemm.cpp @@ -0,0 +1,236 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include "common.hpp" + +using Catch::Approx; + +// Test functions for gemm.hpp/cpp + +TEST_CASE("GEMM: compute_num_matrix_instructions", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - 128x128x64 with 16x16x16") { + auto hardware = make_hardware(gpu_arch); + origami::dim3_t mt{128, 128, 64}; + origami::dim3_t mi{16, 16, 16}; + auto num_instructions = origami::compute_number_matrix_instructions(mt, mi); + REQUIRE(num_instructions == 256); // 8 * 8 * 4 + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - 16x16x64 with 16x16x32") { + auto hardware = make_hardware(gpu_arch); + origami::dim3_t mt{16, 16, 64}; + origami::dim3_t mi{16, 16, 32}; + auto num_instructions = origami::compute_number_matrix_instructions(mt, mi); + REQUIRE(num_instructions == 2); // 1 * 1 * 2 + } + } +} + +TEST_CASE("GEMM: compute_mt_compute_latency", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=T transB=N") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto latency = origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency == 4096); + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=N transB=T") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::N, origami::transpose_t::T); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto latency = 
origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency == 4096); + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=N transB=N") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::N, origami::transpose_t::N); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto latency = origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency == 4096); + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - transA=T transB=T") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::T); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto latency = origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency == 4096); + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - different MT_K") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(4096, 4096, 1024); + auto config = make_config(128, 128, 32, 32, 32, 8, 1); + + auto latency = origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency == 2048); + } + + DYNAMIC_SECTION("gfx" << gpu_arch << " - larger tile") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(4096, 4096, 1024); + auto config = make_config(224, 224, 64, 32, 32, 8, 1); + + auto latency = origami::compute_mt_compute_latency(problem, hardware, config); + REQUIRE(latency > 12543); + } + } +} + +TEST_CASE("GEMM: compute_memory_latency", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - verify smaller tiles have lower latency") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 1); + auto config_small = make_config(128, 128, 64, 32, 32, 8, 8); + auto config_large = make_config(256, 256, 128, 32, 32, 8, 8); + + auto 
mem_latency_small = + origami::compute_memory_latency(problem, hardware, config_small, 304, 2); + auto mem_latency_large = + origami::compute_memory_latency(problem, hardware, config_large, 304, 2); + + REQUIRE(mem_latency_small < mem_latency_large); + } + } +} + +TEST_CASE("GEMM: compute_tile_latency", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - verify larger tiles have higher latency") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); + auto config_small = make_config(128, 128, 64, 32, 32, 8, 6); + auto config_large = make_config(256, 256, 128, 32, 32, 8, 6); + + auto tile_latency_small = + origami::compute_tile_latency(problem, hardware, config_small, 304, 3); + auto tile_latency_large = + origami::compute_tile_latency(problem, hardware, config_large, 304, 3); + + REQUIRE(tile_latency_large > tile_latency_small); + } + } +} + +TEST_CASE("GEMM: compute_timestep_latency", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - wave latency equals tile latency") { + auto hardware = make_hardware(gpu_arch); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); + auto config = make_config(128, 128, 64, 32, 32, 8, 8); + + auto tile_latency = origami::compute_tile_latency(problem, hardware, config, 304, 4); + auto wave_latency = origami::compute_timestep_latency(problem, hardware, config, 304, 4); + + REQUIRE(wave_latency == Approx(tile_latency)); + } + } +} + +TEST_CASE("GEMM: compute_total_latency", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - smaller tiles have lower total latency") { + auto hardware = make_hardware(gpu_arch); + auto problem_small = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); + auto problem_large = + 
make_problem(8192, 8192, 2048, origami::transpose_t::T, origami::transpose_t::N, 2); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto latency_small = + origami::compute_total_latency(problem_small, hardware, config, hardware.N_CU); + auto latency_large = + origami::compute_total_latency(problem_large, hardware, config, hardware.N_CU); + + REQUIRE(latency_small < latency_large); + } + } +} + +TEST_CASE("GEMM: check_lds_capacity", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - 256x256x64 tile fits in LDS") { + auto hardware = make_hardware(gpu_arch); + origami::dim3_t mt{256, 256, 64}; + + auto fits = origami::check_lds_capacity( + hardware, mt, origami::data_type_t::BFloat16, origami::data_type_t::BFloat16); + + REQUIRE(fits == true); + } + } +} + +TEST_CASE("GEMM: estimate_l2_hit", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - L2 hit rate in valid range") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(4096, 4096, 1024); + auto config = make_config(256, 256, 64, 32, 32, 8, 1); + + for (int wgm = 1; wgm < 1025; wgm++) { + config.workgroup_mapping = wgm; + auto l2_hit = origami::estimate_l2_hit(problem, hardware, config, 3); + REQUIRE(l2_hit > 0.0); + REQUIRE(l2_hit < 1.0); + } + } + } +} + +TEST_CASE("GEMM: estimate_mall_hit", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - Mall hit rate is positive") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(4096, 4096, 1024); + auto config = make_config(256, 256, 64, 32, 32, 8, 1); + + for (int wgm = 1; wgm < 1025; wgm++) { + config.workgroup_mapping = wgm; + auto mall_hit = origami::estimate_mall_hit(problem, hardware, config, 304, 8); + REQUIRE(mall_hit > 0.0); + } + } + } +} diff --git a/shared/origami/tests/test_negative_occupancy.cpp b/shared/origami/tests/test_negative_occupancy.cpp deleted file 
mode 100644 index 406e06aaebde..000000000000 --- a/shared/origami/tests/test_negative_occupancy.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include - -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include "origami/utils.hpp" - -TEST(OrigamiTileSelection, NegativeOccupancy) { - // Setup hardware (gfx942) - auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); - auto hardware = origami::hardware_t(gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.7, 1, - std::make_tuple(0, 0.015, 0)); - - // Square problem size - size_t M = 32; - size_t N = 800000; - size_t K = 16; - size_t batch = 1; - bool transA = false; - bool transB = true; - - // List 1: Tile A first, then Tile B - const std::vector MT_list = { - {256, 256, 32, 16, 16, 32, -1, 6, 0, 0}, // Tile A - {32, 256, 16, 32, 32, 8, 2, 6, 0, 0} // Tile B - }; - - auto results = origami::select_best_macro_tile_size( - M, N, K, batch, transA, transB, hardware, MT_list, 32, 32, 32, - origami::data_type_t::XFloat32, 0, 0.8, false, 6); - - auto best_tile = results[0]; - size_t MT_M = std::get<1>(best_tile); - size_t MT_N = std::get<2>(best_tile); - size_t MT_K = std::get<3>(best_tile); - EXPECT_EQ(MT_M, 32) << "MT_M should be 32"; - EXPECT_EQ(MT_N, 256) << "MT_N should be 256"; - EXPECT_EQ(MT_K, 16) << "MT_K should be 16"; -} \ No newline at end of file diff --git a/shared/origami/tests/test_origami.cpp b/shared/origami/tests/test_origami.cpp new file mode 100644 index 000000000000..a3445dfd40b0 --- /dev/null +++ b/shared/origami/tests/test_origami.cpp @@ -0,0 +1,287 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright 2025 AMD ROCm(TM) Software + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software 
without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include "common.hpp" + +using Catch::Approx; + +// Test functions for origami.hpp/cpp + +TEST_CASE("Origami: compute_perf_gflops", "[origami]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - faster clock yields higher GFLOPS") { + // TODO: Add support for make_hardware using hipDeviceProperties + auto hardware_slow = make_hardware(gpu_arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.4); + auto hardware_fast = make_hardware(gpu_arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.8); + auto problem = + make_problem(4096, 4096, 1024, origami::transpose_t::T, origami::transpose_t::N, 2); + auto config = make_config(128, 128, 64, 32, 32, 8, 1); + + auto config_slow = config; + auto config_fast = config; + + auto latency_config_slow = + origami::compute_total_latency(problem, hardware_slow, config_slow, hardware_slow.N_CU); + auto flops_slow = origami::compute_perf_gflops(hardware_slow, problem, latency_config_slow); + + auto 
latency_config_fast = + origami::compute_total_latency(problem, hardware_fast, config_fast, hardware_fast.N_CU); + auto flops_fast = origami::compute_perf_gflops(hardware_fast, problem, latency_config_fast); + + REQUIRE(flops_fast > flops_slow); + } + } +} + +TEST_CASE("Origami: hardware_arch_enum", "[origami]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " architecture enum") { + std::string arch_str = "gfx" + std::to_string(gpu_arch); + auto arch_enum = origami::hardware_t::arch_name_to_enum(arch_str); + + if (gpu_arch == 942) { + REQUIRE(arch_enum == origami::hardware_t::architecture_t::gfx942); + } else if (gpu_arch == 950) { + REQUIRE(arch_enum == origami::hardware_t::architecture_t::gfx950); + } + } + } +} + +TEST_CASE("Origami: best_grid_size", "[origami]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - grid size selection") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(1024, 1024, 4096); + auto config = make_config(256, 256, 64, 32, 32, 8, 1); + + auto grid_size = origami::streamk::select_grid_size( + problem, hardware, config, origami::grid_selection_t::k_split_aware, hardware.N_CU); + + REQUIRE(grid_size >= 16); + } + } +} + +TEST_CASE("Origami: best_macro_tile_size", "[origami]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - rank configs by latency") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(1024, 1024, 4096); + + // List 1: config A first, then config B + std::vector configs; + + // config A[0] + configs.push_back(make_config(256, 256, 32, 32, 32, 8, 1, 6, 0, 0)); + // config A[1] + configs.push_back(make_config(128, 128, 64, 32, 32, 8, 1, 6, 0, 0)); + // config A[2] + configs.push_back(make_config(64, 64, 64, 32, 32, 8, 1, 6, 0, 0)); + + auto results = origami::rank_configs(problem, hardware, configs); + + REQUIRE(results.size() == configs.size()); + // Results should 
be ranked, so latencies should be in ascending order (best first) + for (size_t i = 0; i < results.size() - 1; i++) { + REQUIRE(results[i].latency < results[i + 1].latency); + } + } + } +} + +TEST_CASE("Origami: select_workgroup_mapping", "[origami]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - workgroup mapping selection") { + auto hardware = make_hardware(gpu_arch); + auto problem = make_problem(4096, 4096, 8192); + + auto config_large = make_config(256, 256, 32, 32, 32, 8, 1); + auto skGrid_large = ((4096 + 256 - 1) / 256) * ((4096 + 256 - 1) / 256); // ceil(M / MT_M) * ceil(N / MT_N); each ceil-division is parenthesized because `a / b * c / d` evaluates left to right + auto [best_wgmxcc_large_tile, best_wgm_large_tile] = + origami::select_workgroup_mapping(problem, hardware, config_large, skGrid_large); + + auto config_small = make_config(128, 128, 64, 32, 32, 8, 1); + auto skGrid_small = ((4096 + 128 - 1) / 128) * ((4096 + 128 - 1) / 128); // ceil(M / MT_M) * ceil(N / MT_N) for the 128x128 tile + auto [best_wgmxcc_small_tile, best_wgm_small_tile] = + origami::select_workgroup_mapping(problem, hardware, config_small, skGrid_small); + + // Different problem size for nonsquare test + origami::problem_t problem_nonsquare = problem; + problem_nonsquare.size.m = 2048; + problem_nonsquare.size.n = 5120; + auto skGrid_nonsquare = ((2048 + 128 - 1) / 128) * ((5120 + 128 - 1) / 128); // NOTE(review): uses a 128 tile edge although config_large (256x256) is passed below - confirm which tile size is intended + + auto [best_wgmxcc_nonsquare_tile, best_wgm_nonsquare] = origami::select_workgroup_mapping( + problem_nonsquare, hardware, config_large, skGrid_nonsquare); + + REQUIRE(best_wgmxcc_large_tile == best_wgmxcc_small_tile); + REQUIRE(best_wgm_large_tile > best_wgm_small_tile); + REQUIRE(best_wgm_large_tile != best_wgm_nonsquare); + } + } +} + +TEST_CASE("GEMM: negative_occupancy", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - test negative occupancy") { + auto hardware = make_hardware(gpu_arch); + origami::problem_t problem = { + .size = {32, 800000, 16}, + .batch = 1, + .a_transpose = origami::transpose_t::N, + .b_transpose = origami::transpose_t::T, + .a_dtype = 
origami::data_type_t::XFloat32, // element_size_A = 16 + .b_dtype = origami::data_type_t::XFloat32, + .mi_dtype = origami::data_type_t::XFloat32, + .a_mx_block_size = 0, + .b_mx_block_size = 0, + }; + // Two candidate configs; config[0] passes -1 where config[1] passes 2 (presumably the occupancy argument of make_config - TODO confirm) + std::vector config; + + // config[0] + config.push_back(make_config(256, 256, 32, 16, 16, 32, -1, 6, 0, 0)); + // config[1] + config.push_back(make_config(32, 256, 16, 32, 32, 8, 2, 6, 0, 0)); + + // Call select_config; the checks below expect config[1] (32x256x16) to be selected + auto best_tile = origami::select_config(problem, hardware, config); + + size_t MT_M = best_tile.config.mt.m; + size_t MT_N = best_tile.config.mt.n; + size_t MT_K = best_tile.config.mt.k; + REQUIRE(MT_M == 32); //"MT_M should be 32" + REQUIRE(MT_N == 256); //"MT_N should be 256" + REQUIRE(MT_K == 16); //"MT_K should be 16" + } + } +} + +TEST_CASE("GEMM: deterministic_tie_breaking", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - Verify deterministic selection") { + auto hardware = make_hardware(gpu_arch); + // Square problem size + auto problem = + make_problem(1024, 1024, 1024, origami::transpose_t::N, origami::transpose_t::N, 1); + + // Two configs with the same arithmetic intensity: 256x64x32 and 64x256x32 + // AI = (2 * MT_M * MT_N * MT_K) / (MT_M*MT_K + MT_N*MT_K + MT_M*MT_N) + // Both have AI = 1048576 / 26624 = 39.38 + + // config_A lists 256x64x32 first; config_B lists the same two configs in reversed order + std::vector config_A; + std::vector config_B; + + // config A[0] + config_A.push_back(make_config(256, 64, 32, 32, 32, 8, 1, 6, 0, 0)); + // config A[1] + config_A.push_back(make_config(64, 256, 32, 32, 32, 8, 1, 6, 0, 0)); + + // config B[0] (reversed order) + config_B.push_back(make_config(64, 256, 32, 32, 32, 8, 1, 6, 0, 0)); + // config B[1] (reversed order) + config_B.push_back(make_config(256, 64, 32, 32, 32, 8, 1, 6, 0, 0)); + + // Call select_config with both orderings + auto best_tile_A = origami::select_config(problem, hardware, config_A); + + auto best_tile_B =
origami::select_config(problem, hardware, config_B); + + size_t MT_M_A_first = best_tile_A.config.mt.m; + size_t MT_N_A_first = best_tile_A.config.mt.n; + size_t MT_K_A_first = best_tile_A.config.mt.k; + + size_t MT_M_B_first = best_tile_B.config.mt.m; + size_t MT_N_B_first = best_tile_B.config.mt.n; + size_t MT_K_B_first = best_tile_B.config.mt.k; + + // Verify deterministic selection: both should select the same tile (256x64x32) + // regardless of input order, using the final tie-breaker (prefer larger MT_M) + REQUIRE(MT_M_A_first == MT_M_B_first); //"Selected tile MT_M should be consistent" + REQUIRE(MT_N_A_first == MT_N_B_first); //"Selected tile MT_N should be consistent" + REQUIRE(MT_K_A_first == MT_K_B_first); //"Selected tile MT_K should be consistent" + + // Verify it selected the tile with larger MT_M (256 > 64) + REQUIRE(MT_M_A_first == 256); //"Should prefer tile with larger MT_M" + REQUIRE(MT_N_A_first == 64); //"Should prefer tile with larger MT_M" + } + } +} + +TEST_CASE("GEMM: Verify deterministic tile selection", "[gemm]") { + for (int gpu_arch : test_architectures) { + DYNAMIC_SECTION("gfx" << gpu_arch << " - Verify deterministic selection") { + auto hardware = make_hardware(gpu_arch); + + // problem size + auto problem = + make_problem(42598, 153, 128, origami::transpose_t::N, origami::transpose_t::T); + + // config_A holds tiles A and B; config_B holds the same two tiles plus tile C + std::vector config_A; + std::vector config_B; + + // config A[0] + config_A.push_back(make_config(256, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile A + // config A[1] + config_A.push_back(make_config(192, 160, 64, 16, 16, 32, 1, 6, 0, 0)); // Tile B + + // config B[0] Previous two tiles + a new one + config_B.push_back(make_config(256, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile A + // config B[1] + config_B.push_back(make_config(192, 160, 64, 16, 16, 32, 1, 6, 0, 0)); // Tile B + // config B[2] + config_B.push_back(make_config(192, 160, 32, 16, 16, 32, 1, 6, 0, 0)); // Tile C + + // Call select_config
with both tile configs + auto best_tile_A = origami::select_config(problem, hardware, config_A); + + auto best_tile_B = origami::select_config(problem, hardware, config_B); + + size_t MT_M1 = best_tile_A.config.mt.m; + size_t MT_N1 = best_tile_A.config.mt.n; + size_t MT_K1 = best_tile_A.config.mt.k; + + size_t MT_M2 = best_tile_B.config.mt.m; + size_t MT_N2 = best_tile_B.config.mt.n; + size_t MT_K2 = best_tile_B.config.mt.k; + + auto winner_is_acceptable = + [](auto const& actual, auto const& candidate1, auto const& candidate2) -> bool { // true when actual equals either candidate + return (actual == candidate1) || (actual == candidate2); + }; + + INFO("Winner is not acceptable"); // adding tile C may make tile C the winner, but must not flip the choice between tiles A and B; each dimension is checked separately against the config_A winner and tile C + REQUIRE(winner_is_acceptable(MT_M2, MT_M1, config_B[2].mt.m)); + REQUIRE(winner_is_acceptable(MT_N2, MT_N1, config_B[2].mt.n)); + REQUIRE(winner_is_acceptable(MT_K2, MT_K1, config_B[2].mt.k)); + } + } +} diff --git a/shared/origami/tests/test_tile_ordering_issue.cpp b/shared/origami/tests/test_tile_ordering_issue.cpp deleted file mode 100644 index 1dd0eb220411..000000000000 --- a/shared/origami/tests/test_tile_ordering_issue.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates.
-// SPDX-License-Identifier: MIT - -#include - -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include "origami/utils.hpp" - -// Test to verify deterministic tile selection when tiles have same latency and arithmetic intensity -// This test ensures that the tie-breaking logic works correctly regardless of input order -TEST(OrigamiTileSelection, DeterministicTieBreaking) { - // Setup hardware (gfx942) - auto gfx942arch = origami::hardware_t::arch_name_to_enum("gfx942"); - auto hardware = origami::hardware_t(gfx942arch, 304, 65536, 8, 1.0, 1.0, 1.0, 4000000, 1.7, 1, - std::make_tuple(0, 0.015, 0)); - - // Square problem size - size_t M = 1024; - size_t N = 1024; - size_t K = 1024; - size_t batch = 1; - bool transA = false; - bool transB = false; - - // Two tiles with same arithmetic intensity: 256x64x32 and 64x256x32 - // AI = (2 * MT_M * MT_N * MT_K) / (MT_M*MT_K + MT_N*MT_K + MT_M*MT_N) - // Both have AI = 1048576 / 26624 = 39.38 - - // List 1: Tile A first, then Tile B - const std::vector MT_list_A_first = { - {256, 64, 32, 32, 32, 8, 1, 6, 0, 0}, // Tile A - {64, 256, 32, 32, 32, 8, 1, 6, 0, 0} // Tile B - }; - - // List 2: Tile B first, then Tile A (reversed order) - const std::vector MT_list_B_first = { - {64, 256, 32, 32, 32, 8, 1, 6, 0, 0}, // Tile B - {256, 64, 32, 32, 32, 8, 1, 6, 0, 0} // Tile A - }; - - // Call select_best_macro_tile_size with both orderings - auto results_A_first = origami::select_best_macro_tile_size( - M, N, K, batch, transA, transB, hardware, MT_list_A_first, 16, 16, 16, - origami::data_type_t::BFloat16, 0, 0.8, false, 6); - - auto results_B_first = origami::select_best_macro_tile_size( - M, N, K, batch, transA, transB, hardware, MT_list_B_first, 16, 16, 16, - origami::data_type_t::BFloat16, 0, 0.8, false, 6); - - // Extract the best tile from each result - auto best_tile_A_first = results_A_first[0]; - auto best_tile_B_first = results_B_first[0]; - - size_t MT_M_A_first = std::get<1>(best_tile_A_first); - size_t 
MT_N_A_first = std::get<2>(best_tile_A_first); - size_t MT_K_A_first = std::get<3>(best_tile_A_first); - - size_t MT_M_B_first = std::get<1>(best_tile_B_first); - size_t MT_N_B_first = std::get<2>(best_tile_B_first); - size_t MT_K_B_first = std::get<3>(best_tile_B_first); - - // Verify deterministic selection: both should select the same tile (256x64x32) - // regardless of input order, using the final tie-breaker (prefer larger MT_M) - EXPECT_EQ(MT_M_A_first, MT_M_B_first) << "Selected tile MT_M should be consistent"; - EXPECT_EQ(MT_N_A_first, MT_N_B_first) << "Selected tile MT_N should be consistent"; - EXPECT_EQ(MT_K_A_first, MT_K_B_first) << "Selected tile MT_K should be consistent"; - - // Verify it selected the tile with larger MT_M (256 > 64) - EXPECT_EQ(MT_M_A_first, 256) << "Should prefer tile with larger MT_M"; - EXPECT_EQ(MT_N_A_first, 64) << "Should prefer tile with larger MT_M"; -} diff --git a/shared/origami/tests/test_variance_issue.cpp b/shared/origami/tests/test_variance_issue.cpp deleted file mode 100644 index 451988cd13ef..000000000000 --- a/shared/origami/tests/test_variance_issue.cpp +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include -#include - -#include "origami/gemm.hpp" -#include "origami/hardware.hpp" -#include "origami/utils.hpp" - -// Test to verify deterministic tile selection -TEST(OrigamiTileSelection, VarianceEffect) { - // Setup hardware - auto gfx950arch = origami::hardware_t::arch_name_to_enum("gfx950"); - // auto hardware = origami::hardware_t(gfx950arch, 256, 163840, 8, 7727.272727, 2993.42, 3157.89, 4194304, 2.2, 4, - // std::make_tuple(0, 0.008, 0)); - auto hardware = origami::hardware_t(gfx950arch, 256, 163840, 8, 7727.272727, 2000.42, 2000.89, 4194304, 2.2, 4, - std::make_tuple(0, 0.008, 0)); - - // Square problem size - size_t M = 42598; - size_t N = 153; - size_t K = 128; - size_t batch = 1; - bool transA = false; - bool transB = true; - - // List 1: Only two tiles - const std::vector MT_list1 = { - {256, 160, 32, 16, 16, 32, 1, 6, 0, 0}, // Tile A - {192, 160, 64, 16, 16, 32, 1, 6, 0, 0} // Tile B - }; - - // List 2: Previous two tiles + a new one - const std::vector MT_list2 = { - {256, 160, 32, 16, 16, 32, 1, 6, 0, 0}, // Tile A - {192, 160, 64, 16, 16, 32, 1, 6, 0, 0}, // Tile B - {192, 160, 32, 16, 16, 32, 1, 6, 0, 0} // Tile C - }; - - // Call select_best_macro_tile_size with both tile lists - auto results1 = origami::select_best_macro_tile_size( - M, N, K, batch, transA, transB, hardware, MT_list1, 16, 16, 32, - origami::data_type_t::BFloat16, 0, 0.8, false, 6); - - auto results2 = origami::select_best_macro_tile_size( - M, N, K, batch, transA, transB, hardware, MT_list2, 16, 16, 32, - origami::data_type_t::BFloat16, 0, 0.8, false, 6); - - // Extract the best tile from each result - auto best_tile1 = results1[0]; - auto best_tile2 = results2[0]; - - size_t MT_M1 = std::get<1>(best_tile1); - size_t MT_N1 = std::get<2>(best_tile1); - size_t MT_K1 = std::get<3>(best_tile1); - - size_t MT_M2 = std::get<1>(best_tile2); - size_t MT_N2 = std::get<2>(best_tile2); - size_t MT_K2 = std::get<3>(best_tile2); - - // At this time, the 
model predicts following latencies: - // TileA: 35803.3 - // TileB: 36088 - // TileC: 35452.4 - // Hence, for list1 TileB is the winner as it has the largest AI. - // For list2, either TileB (previous winner) or TileC (new tile) should be the winner. - // Note that adding TileC (with a reasonable variance) eliminates TileB from the - // pool of selected kernels, and TileC is the winner. - std::cout << std::get<0>(results2[0]) << " " << std::get<1>(results2[0]) << " " << std::get<3>(results2[0]) << std::endl; - std::cout << std::get<0>(results2[1]) << " " << std::get<1>(results2[1]) << " " << std::get<3>(results2[1]) << std::endl; - std::cout << std::get<0>(results2[2]) << " " << std::get<1>(results2[2]) << " " << std::get<3>(results2[2]) << std::endl; - - // Verify deterministic selection - // After adding a new tile, we don't want the selection to change from - // tile A to tile B (or the other way around). - // We accept if the new tile is the winner, though! - EXPECT_THAT(MT_M2, ::testing::AnyOf(::testing::Eq(MT_M1), ::testing::Eq(std::get<0>(MT_list2[2])))) << "Winner is not acceptable"; - EXPECT_THAT(MT_N2, ::testing::AnyOf(::testing::Eq(MT_N1), ::testing::Eq(std::get<1>(MT_list2[2])))) << "Winner is not acceptable"; - EXPECT_THAT(MT_K2, ::testing::AnyOf(::testing::Eq(MT_K1), ::testing::Eq(std::get<2>(MT_list2[2])))) << "Winner is not acceptable"; -}