diff --git a/projects/hipblaslt/library/src/CMakeLists.txt b/projects/hipblaslt/library/src/CMakeLists.txt index b89dc92a5ce..f1b32faa600 100644 --- a/projects/hipblaslt/library/src/CMakeLists.txt +++ b/projects/hipblaslt/library/src/CMakeLists.txt @@ -34,5 +34,6 @@ target_sources(hipblaslt "${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/parameter_selection.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/rocroller_host.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/runtime_args_selection.hpp" + "${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp" "${CMAKE_CURRENT_SOURCE_DIR}/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp" ) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/CMakeLists.txt b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/CMakeLists.txt index 750d509b9d3..172d5d32d8f 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/CMakeLists.txt +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources(hipblaslt-rocroller "${CMAKE_CURRENT_SOURCE_DIR}/parameter_selection.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/runtime_args_selection.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/rocroller_host.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/solution_cache.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/solution_selection.cpp" ) target_compile_features(hipblaslt-rocroller PRIVATE cxx_std_20) diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/README.md b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/README.md index 9c1bf5216fe..91243541096 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/README.md +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/README.md @@ -10,14 +10,14 @@ can contain other parameters besides the tile size. A list of `SolutionIndexPara in sorted order, based on what Origami predicts is the best performing kernel. If a kernel has already been generated for the specific `SolutionIndexParameters` instance that was selected, -the kernel can be found in the cache and returned. +the kernel can be found in the `SolutionCache` and returned. Otherwise, the rest of the `SolutionParameters` need to be selected. `SolutionParameters` contain all of the parameters required to generate a kernel. These parameters are selected based on the `KernelType` and the `SolutionIndexParameters`. Once all of the `SolutionParameters` have been selected, the kernel is generated using rocRoller. The kernel -is then saved for reuse in the cache and returned. +is then saved for reuse in the `SolutionCache` and returned. # Calling a rocRoller Kernel diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp index 324b3c7e8e1..6d68edacb4e 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/gemm.cpp @@ -113,10 +113,10 @@ std::string genKernelName(std::shared_ptr gemm) gemm->kernelType.typeAcc}) rv << toString(t) << "_"; - if(gemm->kernelType.scaleAMode != Operations::ScaleMode::None) - rv << "SA_" << genScaleModeString(gemm->kernelType.scaleAMode) << "_"; - if(gemm->kernelType.scaleBMode != Operations::ScaleMode::None) - rv << "SB_" << genScaleModeString(gemm->kernelType.scaleBMode) << "_"; + if(gemm->kernelType.scaleTypeA.mode != Operations::ScaleMode::None) + rv << "SA_" << genScaleModeString(gemm->kernelType.scaleTypeA.mode) << "_"; + if(gemm->kernelType.scaleTypeB.mode != Operations::ScaleMode::None) + rv << "SB_" << genScaleModeString(gemm->kernelType.scaleTypeB.mode) << "_"; rv << "WGT_"; rocRoller::streamJoin( @@ -158,25 +158,25 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge auto mulInputA = tagLoadA; auto mulInputB = tagLoadB; - AssertFatal(gemm->kernelType.scaleAMode == Operations::ScaleMode::None - || gemm->kernelType.scaleAMode == Operations::ScaleMode::SingleScale - || gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate, + AssertFatal(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::None + || gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::SingleScale + || gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate, "Scale mode not supported!", - ShowValue(gemm->kernelType.scaleAMode)); - AssertFatal(gemm->kernelType.scaleBMode == Operations::ScaleMode::None - || gemm->kernelType.scaleBMode == Operations::ScaleMode::SingleScale - || gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate, + ShowValue(gemm->kernelType.scaleTypeA.mode)); + AssertFatal(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::None + || gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::SingleScale + || gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate, "Scale mode not supported!", - ShowValue(gemm->kernelType.scaleBMode)); + ShowValue(gemm->kernelType.scaleTypeB.mode)); std::optional tagTensorScaleA, tagLoadScaleA, tagBlockScaleA, tagTensorScaleB, tagLoadScaleB, tagBlockScaleB, tagScratch, tagSKGrid, tagWGM; - if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate) + if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate) { tagTensorScaleA = command->addOperation(rocRoller::Operations::Tensor( 2, - gemm->kernelType.scaleTypeA, + gemm->kernelType.scaleTypeA.type, gemm->kernelType.transA ? oneStridesT : oneStridesN)); tagLoadScaleA = command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleA)); @@ -185,14 +185,14 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge tagLoadA, 2, tagLoadScaleA, - {gemm->kernelType.scaleABlockColSize, gemm->kernelType.scaleABlockRowSize})); + {gemm->kernelType.scaleTypeA.blockColSize, gemm->kernelType.scaleTypeA.blockRowSize})); } - if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate) + if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate) { tagTensorScaleB = command->addOperation(rocRoller::Operations::Tensor( 2, - gemm->kernelType.scaleTypeA, + gemm->kernelType.scaleTypeA.type, gemm->kernelType.transB ? oneStridesT : oneStridesN)); tagLoadScaleB = command->addOperation(rocRoller::Operations::T_Load_Tiled(*tagTensorScaleB)); @@ -201,7 +201,7 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge tagLoadB, 2, tagLoadScaleB, - {gemm->kernelType.scaleBBlockColSize, gemm->kernelType.scaleBBlockRowSize})); + {gemm->kernelType.scaleTypeB.blockColSize, gemm->kernelType.scaleTypeB.blockRowSize})); } auto tagTensorC @@ -330,11 +330,11 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge params->setDimensionInfo(tagLoadA, macTileA); } - if(gemm->kernelType.scaleAMode == Operations::ScaleMode::Separate) + if(gemm->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate) { // TODO: verify the division of scale block size is correct auto const scaleBlockSize - = gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize; + = gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize; auto macTileAScale = KernelGraph::CoordinateGraph::MacroTile( {gemm->workgroupTile.m, gemm->workgroupTile.k / (int)scaleBlockSize}, LayoutType::MATRIX_A, @@ -358,11 +358,11 @@ std::shared_ptr genGemmKernel(std::shared_ptr ge params->setDimensionInfo(tagLoadB, macTileB); } - if(gemm->kernelType.scaleBMode == Operations::ScaleMode::Separate) + if(gemm->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate) { // TODO: verify the division of scale block size is correct auto const scaleBlockSize - = gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize; + = gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize; auto macTileBScale = KernelGraph::CoordinateGraph::MacroTile( {gemm->workgroupTile.k / (int)scaleBlockSize, gemm->workgroupTile.n}, LayoutType::MATRIX_B, @@ -547,7 +547,7 @@ CommandArguments createCommandArguments(std::shared_ptr gemm, setCommandTensorArg(commandArgs, gemm->tagTensorA, descA, (float*)nullptr); setCommandTensorArg(commandArgs, gemm->tagTensorB, descB, (float*)nullptr); - if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate) + if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate) { auto const scaleBlockSize = prob.scaleABlockRowSize * prob.scaleABlockColSize; TensorDescriptor descAScale(gemm->params->kernelType.typeA, @@ -555,7 +555,7 @@ CommandArguments createCommandArguments(std::shared_ptr gemm, gemm->params->kernelType.transA ? "T" : "N"); setCommandTensorArg(commandArgs, gemm->tagTensorScaleA, descAScale, (float*)nullptr); } - if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate) + if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate) { auto const scaleBlockSize = prob.scaleBBlockRowSize * prob.scaleBBlockColSize; TensorDescriptor descBScale(gemm->params->kernelType.typeB, @@ -578,12 +578,12 @@ CommandArguments createCommandArguments(std::shared_ptr gemm, commandArgs.setArgument(gemm->tagTensorC, ArgumentType::Value, (float*)prob.C); commandArgs.setArgument(gemm->tagTensorD, ArgumentType::Value, (float*)prob.D); - if(gemm->params->kernelType.scaleAMode == Operations::ScaleMode::Separate) + if(gemm->params->kernelType.scaleTypeA.mode == Operations::ScaleMode::Separate) { commandArgs.setArgument(gemm->tagTensorScaleA, ArgumentType::Value, (uint8_t*)prob.scaleA); } - if(gemm->params->kernelType.scaleBMode == Operations::ScaleMode::Separate) + if(gemm->params->kernelType.scaleTypeB.mode == Operations::ScaleMode::Separate) { commandArgs.setArgument(gemm->tagTensorScaleB, ArgumentType::Value, (uint8_t*)prob.scaleB); } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/kernel_type.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/kernel_type.hpp index 9cbc0a961d8..da47ad8fd4d 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/kernel_type.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/kernel_type.hpp @@ -30,6 +30,30 @@ #include #include +struct ScaleType +{ + rocRoller::Operations::ScaleMode mode; + size_t blockRowSize = 32u; + size_t blockColSize = 1u; + rocRoller::DataType type = rocRoller::DataType::E8M0; + + auto operator<=>(const ScaleType& other) const = default; +}; + +template<> +struct std::hash +{ + size_t operator()(const ScaleType& s) const noexcept + { + size_t modeHash = std::hash{}(s.mode); + size_t blockRowSizeHash = std::hash{}(s.blockRowSize); + size_t blockColSizeHash = std::hash{}(s.blockColSize); + size_t typeHash = std::hash{}(s.type); + + return modeHash ^ (blockRowSizeHash << 1) ^ (blockColSizeHash << 2) ^ (typeHash << 3); + } +}; + /** * @brief KernelType * @@ -48,16 +72,29 @@ struct KernelType bool transA; bool transB; - rocRoller::Operations::ScaleMode scaleAMode; - rocRoller::Operations::ScaleMode scaleBMode; + ScaleType scaleTypeA; + ScaleType scaleTypeB; - size_t scaleABlockRowSize = 32u; - size_t scaleABlockColSize = 1u; - size_t scaleBBlockRowSize = 1u; - size_t scaleBBlockColSize = 32u; + auto operator<=>(const KernelType& other) const = default; +}; - rocRoller::DataType scaleTypeA = rocRoller::DataType::E8M0; - rocRoller::DataType scaleTypeB = rocRoller::DataType::E8M0; +template<> +struct std::hash +{ + size_t operator()(const KernelType& k) const noexcept + { + size_t typeAHash = std::hash{}(k.typeA); + size_t typeBHash = std::hash{}(k.typeB); + size_t typeCHash = std::hash{}(k.typeC); + size_t typeDHash = std::hash{}(k.typeD); + size_t typeAccHash = std::hash{}(k.typeAcc); + size_t scaleTypeAHash = std::hash{}(k.scaleTypeA); + size_t scaleTypeBHash = std::hash{}(k.scaleTypeB); + size_t transAHash = std::hash{}(k.transA); + size_t transBHash = std::hash{}(k.transB); - auto operator<=>(const KernelType& other) const = default; + return typeAHash ^ (typeBHash << 1) ^ (typeCHash << 2) ^ + (typeDHash << 3) ^ (typeAccHash << 4) ^ (scaleTypeAHash << 5) ^ + (scaleTypeBHash << 6) ^ (transAHash << 7) ^ (transBHash << 8); + } }; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp new file mode 100644 index 00000000000..8401f177804 --- /dev/null +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_cache.hpp @@ -0,0 +1,47 @@ +/*! \file */ +/* ************************************************************************ + * + * MIT License + * + * Copyright (C) 2025 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#pragma once + +#include "gemm.hpp" +#include "kernel_type.hpp" +#include "solution_selection.hpp" + +class SolutionCache +{ + public: + void addKernel(const KernelType& kernelType, const SolutionIndexParameters& params, std::shared_ptr kernel); + std::optional> getKernel(const KernelType& kernelType, const SolutionIndexParameters& params); + + private: + // Map of kernels that have already been generated. + // The first level of the map is indexed with a KernelType. + // The second level of the map is indexed with a hash value of a + // SolutionIndexParameters type. + // The value is a GemmKernel. + std::unordered_map>> m_generatedKernels; +}; diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp index 3fd0babe898..e4055d2f7ca 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/include/solution_selection.hpp @@ -41,6 +41,8 @@ struct WorkGroupTileSize int m; int n; int k; + + auto operator<=>(const WorkGroupTileSize& other) const = default; }; /** @@ -69,6 +71,8 @@ struct SolutionIndexParameters { WorkGroupTileSize workgroupTile; bool workgroupMapping; + + auto operator<=>(const SolutionIndexParameters& other) const = default; }; int parametersToIndex(const SolutionIndexParameters& params); diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp index 284060c2f6a..f5cbedb25a6 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/parameter_selection.cpp @@ -111,8 +111,8 @@ std::shared_ptr // Swizzle Scale only support in certain situations // Swizzle Scale also runs out of registers with FP8 - if (kernelType.scaleAMode != rocRoller::Operations::ScaleMode::Separate || - kernelType.scaleBMode != rocRoller::Operations::ScaleMode::Separate) + if (kernelType.scaleTypeA.mode != rocRoller::Operations::ScaleMode::Separate || + kernelType.scaleTypeB.mode != rocRoller::Operations::ScaleMode::Separate) { gemm->swizzleScale = false; gemm->prefetchScale = false; @@ -161,18 +161,18 @@ std::shared_ptr // LDS can only be used for scaling data with certain workgroup tile sizes auto workgroupSizeTotal = gemm->workgroupSizeX * gemm->workgroupSizeY; auto numScaleElementsA = 0; - if(gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize != 0) + if(gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize != 0) { numScaleElementsA = gemm->workgroupTile.m * (gemm->workgroupTile.k - / (gemm->kernelType.scaleABlockRowSize * gemm->kernelType.scaleABlockColSize)); + / (gemm->kernelType.scaleTypeA.blockRowSize * gemm->kernelType.scaleTypeA.blockColSize)); } auto numScaleElementsB = 0; - if(gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize != 0) + if(gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize != 0) { numScaleElementsB = gemm->workgroupTile.n * (gemm->workgroupTile.k - / (gemm->kernelType.scaleBBlockRowSize * gemm->kernelType.scaleBBlockColSize)); + / (gemm->kernelType.scaleTypeB.blockRowSize * gemm->kernelType.scaleTypeB.blockColSize)); } if(numScaleElementsA % workgroupSizeTotal != 0) { diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp index f6a1192c100..f41a2973497 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/rocroller_host.cpp @@ -33,6 +33,7 @@ #include "rocroller_host.hpp" #include "runtime_args_selection.hpp" #include "parameter_selection.hpp" +#include "solution_cache.hpp" #include "solution_selection.hpp" #include "Debug.hpp" @@ -49,12 +50,7 @@ using namespace rocRoller; */ struct RocRollerHandle { - // Map of kernels that have already been generated. - // The first level of the map is indexed with a KernelType. - // The second level of the map is indexed with a hash value of a - // SolutionIndexParameters type. - // The value is a GemmKernel. - std::map>> generatedKernels; + SolutionCache cache; }; /** @@ -401,16 +397,16 @@ KernelType genKernelType(const RocblasltContractionProblem& prob) kernelType.typeAcc = rocblaslt_compute_type_to_rocRoller_type(prob.compute_type); kernelType.transA = prob.trans_a == HIPBLAS_OP_T; kernelType.transB = prob.trans_b == HIPBLAS_OP_T; - kernelType.scaleAMode = prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block + kernelType.scaleTypeA.mode = prob.scaleAType == RocblasltContractionProblem::ScalingFormat::Block ? rocRoller::Operations::ScaleMode::Separate : rocRoller::Operations::ScaleMode::None; - kernelType.scaleBMode = prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block + kernelType.scaleTypeB.mode = prob.scaleBType == RocblasltContractionProblem::ScalingFormat::Block ? rocRoller::Operations::ScaleMode::Separate : rocRoller::Operations::ScaleMode::None; - kernelType.scaleABlockRowSize = prob.scaleABlockRowSize; - kernelType.scaleABlockColSize = prob.scaleABlockColSize; - kernelType.scaleBBlockRowSize = prob.scaleBBlockRowSize; - kernelType.scaleBBlockColSize = prob.scaleBBlockColSize; + kernelType.scaleTypeA.blockRowSize = prob.scaleABlockRowSize; + kernelType.scaleTypeA.blockColSize = prob.scaleABlockColSize; + kernelType.scaleTypeB.blockRowSize = prob.scaleBBlockRowSize; + kernelType.scaleTypeB.blockColSize = prob.scaleBBlockColSize; return kernelType; } @@ -429,8 +425,8 @@ rocblaslt_status auto params = genSolutionParameters(kernelType, solutionIndexParameter); try { - kernel = genGemmKernel(params); - rocroller_handle->generatedKernels[kernelType][solutionIndex] = kernel; + kernel = genGemmKernel(params); + rocroller_handle->cache.addKernel(kernelType, solutionIndexParameter, kernel); } catch(const std::exception& e) { @@ -503,12 +499,6 @@ rocblaslt_status return rocblaslt_status_invalid_value; } - auto existingKernelType = rocroller_handle->generatedKernels.find(kernelType); - if(existingKernelType == rocroller_handle->generatedKernels.end()) - { - rocroller_handle->generatedKernels[kernelType] = {}; - } - auto solutionIndexParameters = chooseSolutionIndexParameters(kernelType, prob, requestedAlgoCount); @@ -519,10 +509,10 @@ rocblaslt_status break; index = parametersToIndex(solutionIndexParameter); - auto existingSolutionIndex = rocroller_handle->generatedKernels[kernelType].find(index); + auto existingSolution = rocroller_handle->cache.getKernel(kernelType, solutionIndexParameter); std::shared_ptr kernel; // If kernel doesn't already exist, generate it - if(existingSolutionIndex == rocroller_handle->generatedKernels[kernelType].end()) + if(!existingSolution) { auto status = genKernelFromSolutionIndexParameters( rocroller_handle, kernelType, solutionIndexParameter, index, kernel); @@ -531,7 +521,7 @@ rocblaslt_status } else { - kernel = existingSolutionIndex->second; + kernel = *existingSolution; } // Fill out heuristicResultsArray @@ -616,26 +606,17 @@ rocblaslt_status getKernelFromAlgo(rocblaslt_handle handle, RocRollerHandle* rocroller_handle = static_cast(handle->rocroller_handle); auto kernelType = genKernelType(prob); - auto existingKernelType = rocroller_handle->generatedKernels.find(kernelType); - // If KernelType doesn't exist yet, add an empty container for it to map. - if(existingKernelType == rocroller_handle->generatedKernels.end()) + auto solutionIndexParameters = indexToParameters(*solutionIndex); + auto existingKernel = rocroller_handle->cache.getKernel(kernelType, solutionIndexParameters); + if(existingKernel) { - rocroller_handle->generatedKernels[kernelType] = {}; - existingKernelType = rocroller_handle->generatedKernels.find(kernelType); - } - - auto existingKernel = existingKernelType->second.find(*solutionIndex); - if(existingKernel != existingKernelType->second.end()) - { - kernel = existingKernel->second; + kernel = *existingKernel; return rocblaslt_status_success; } else { - auto solutionIndexParameter = indexToParameters(*solutionIndex); - auto status = genKernelFromSolutionIndexParameters( - rocroller_handle, kernelType, solutionIndexParameter, *solutionIndex, kernel); + rocroller_handle, kernelType, solutionIndexParameters, *solutionIndex, kernel); return status; } } diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_cache.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_cache.cpp new file mode 100644 index 00000000000..b03d7ad6f28 --- /dev/null +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_cache.cpp @@ -0,0 +1,59 @@ +/*! \file */ +/* ************************************************************************ + * + * MIT License + * + * Copyright (C) 2025 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ************************************************************************ */ + +#include "solution_cache.hpp" + +void SolutionCache::addKernel(const KernelType& kernelType, const SolutionIndexParameters& params, std::shared_ptr kernel) +{ + auto existingKernelType = m_generatedKernels.find(kernelType); + if(existingKernelType == m_generatedKernels.end()) + { + m_generatedKernels[kernelType] = {}; + } + + auto index = parametersToIndex(params); + + m_generatedKernels[kernelType][index] = kernel; +} + +std::optional> SolutionCache::getKernel(const KernelType& kernelType, const SolutionIndexParameters& params) +{ + auto existingKernelType = m_generatedKernels.find(kernelType); + if(existingKernelType == m_generatedKernels.end()) + { + return std::nullopt; + } + + auto index = parametersToIndex(params); + + auto kernel = existingKernelType->second.find(index); + + if (kernel == existingKernelType->second.end()) + return std::nullopt; + else + return kernel->second; +} diff --git a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp index 2c9549e7bd0..ae1dca5e046 100644 --- a/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp +++ b/projects/hipblaslt/library/src/amd_detail/rocblaslt/src/rocroller/solution_selection.cpp @@ -193,8 +193,8 @@ std::vector chooseSolutionIndexParameters( .a_dtype = rocroller_type_to_analytical_type(kernelType.typeA), .b_dtype = rocroller_type_to_analytical_type(kernelType.typeB), .mi_dtype = rocroller_type_to_analytical_type(elementSizeA_bits < elementSizeB_bits ? kernelType.typeB : kernelType.typeA), - .a_mx_block_size = kernelType.scaleABlockRowSize * kernelType.scaleABlockColSize, - .b_mx_block_size = kernelType.scaleBBlockRowSize * kernelType.scaleBBlockColSize, + .a_mx_block_size = kernelType.scaleTypeA.blockRowSize * kernelType.scaleTypeA.blockColSize, + .b_mx_block_size = kernelType.scaleTypeB.blockRowSize * kernelType.scaleTypeB.blockColSize, }; int defaultWGM = std::ceil(std::sqrt(analytical_hardware.N_CU / analytical_hardware.NUM_XCD));