diff --git a/projects/hipdnn/CMakeLists.txt b/projects/hipdnn/CMakeLists.txt index aca3eca02c3..4a6cfb8635f 100644 --- a/projects/hipdnn/CMakeLists.txt +++ b/projects/hipdnn/CMakeLists.txt @@ -101,6 +101,23 @@ set(HIPDNN_RELATIVE_INSTALL_PLUGIN_ENGINE_DIR "${HIPDNN_PLUGIN_ROOTDIR}/${HIPDNN_PLUGIN_ENGINE_SUBDIR}" CACHE STRING "Relative install directory for official hipDNN engine plugins" ) +set(HIPDNN_PLUGIN_HEURISTIC_SUBDIR "hipdnn_plugins/heuristics" + CACHE STRING "Subdirectory for official hipDNN heuristic plugins" +) +set(HIPDNN_BUILD_PLUGIN_HEURISTIC_DIR + "${CMAKE_BINARY_DIR}/${HIPDNN_PLUGIN_ROOTDIR}/${HIPDNN_PLUGIN_HEURISTIC_SUBDIR}" + CACHE PATH "Build directory for official hipDNN heuristic plugins" +) +set(HIPDNN_FULL_INSTALL_PLUGIN_HEURISTIC_DIR + "${HIPDNN_PLUGIN_FULL_ROOTDIR}/${HIPDNN_PLUGIN_HEURISTIC_SUBDIR}" + CACHE + STRING + "Full install directory for official hipDNN heuristic plugins. Uses the CMAKE_INSTALL_PREFIX set at configure time (defaults to /opt/rocm/)." +) +set(HIPDNN_RELATIVE_INSTALL_PLUGIN_HEURISTIC_DIR + "${HIPDNN_PLUGIN_ROOTDIR}/${HIPDNN_PLUGIN_HEURISTIC_SUBDIR}" + CACHE STRING "Relative install directory for official hipDNN heuristic plugins" +) set(HIPDNN_TEST_PLUGIN_DIR "${CMAKE_BINARY_DIR}/${HIPDNN_PLUGIN_ROOTDIR}/test_plugins" CACHE INTERNAL "Build directory for test plugins" ) @@ -182,6 +199,7 @@ include(cmake/ClangCheck.cmake) include(cmake/ClangTidy.cmake) include(cmake/Sanitizers.cmake) include(cmake/Tests.cmake) +include(cmake/TestPluginNames.cmake) include(cmake/Spdlog.cmake) # Add global compile options diff --git a/projects/hipdnn/backend/include/HipdnnBackendAttributeName.h b/projects/hipdnn/backend/include/HipdnnBackendAttributeName.h index 52ed6dba062..e1075f5e31e 100644 --- a/projects/hipdnn/backend/include/HipdnnBackendAttributeName.h +++ b/projects/hipdnn/backend/include/HipdnnBackendAttributeName.h @@ -85,6 +85,26 @@ typedef enum /** @brief Find first mode: stop after finding any applicable engine (bool, extension) */ HIPDNN_ATTR_ENGINEHEUR_FIND_FIRST_EXT = 105, + /** + * @brief Ordered list of heuristic policy IDs for engine selection (array of int64, extension) + * + * Specifies the policy order for the heuristic outer loop. Each element is an int64_t + * policy ID, produced by hashing a policy name (e.g., "SelectionHeuristic::StaticOrdering") + * with hipdnn_data_sdk::utilities::policyNameToId. + * Hashing is performed by the caller before the C ABI; the backend stores and dispatches + * by ID only. + * + * Resolution priority at finalize time (highest first): + * 1. HIPDNN_HEUR_POLICY_ORDER env var (comma-separated tokens; each token is + * either a policy name, which is hashed via policyNameToId, or a raw + * decimal int64 policy ID). + * 2. This descriptor attribute, if set. + * 3. Built-in default: [SelectionHeuristic::Config, SelectionHeuristic::StaticOrdering]. + * + * Type: HIPDNN_TYPE_INT64 + */ + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT = 106, + /** @} */ /** diff --git a/projects/hipdnn/backend/include/hipdnn_backend.h b/projects/hipdnn/backend/include/hipdnn_backend.h index 4a0fa73cf28..ad3e5205d6f 100644 --- a/projects/hipdnn/backend/include/hipdnn_backend.h +++ b/projects/hipdnn/backend/include/hipdnn_backend.h @@ -502,6 +502,28 @@ HIPDNN_BACKEND_EXPORT void hipdnnLoggingCallback_ext(hipdnnSeverity_t severity, HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnSetEnginePluginPaths_ext( size_t numPaths, const char* const* pluginPaths, hipdnnPluginLoadingMode_ext_t loadingMode); +/** + * @brief Sets the search paths for hipDNN heuristic plugins. + * + * Mirrors @ref hipdnnSetEnginePluginPaths_ext for the heuristic plugin search domain. + * Must be called before creating a hipDNN handle, as heuristic plugins are loaded + * during handle creation. + * + * Paths can be either directories or specific plugin files. Relative paths are resolved + * from the location of the libhipdnn_backend.so file. + * + * @param[in] numPaths The number of paths in the `pluginPaths` array. + * @param[in] pluginPaths An array of relative or absolute path strings. + * @param[in] loadingMode Specifies whether to add paths to or replace the default search paths. + * + * @retval HIPDNN_STATUS_SUCCESS The operation was successful. + * @retval HIPDNN_STATUS_BAD_PARAM_NULL_POINTER `pluginPaths` is nullptr when `numPaths` is greater than 0. + * @retval HIPDNN_STATUS_NOT_SUPPORTED Called with active handle. + * @retval HIPDNN_STATUS_INTERNAL_ERROR An internal error occurred. + */ +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnSetHeuristicPluginPaths_ext( + size_t numPaths, const char* const* pluginPaths, hipdnnPluginLoadingMode_ext_t loadingMode); + /** * @brief Sets the plugin unloading mode for hipDNN. * @@ -676,6 +698,72 @@ HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetEngineInfo_ext(hipdnnHandle_t hand char* type, size_t* typeLen); +/** + * @brief Gets the count of loaded heuristic policies. + * + * Returns the number of heuristic policy plugins that have been successfully loaded + * and validated by the backend. This count includes all policies available for use + * in the outer loop engine selection. + * + * @param[in] handle A valid hipDNN handle. + * @param[out] numPolicies Pointer where the policy count will be stored. + * + * @retval HIPDNN_STATUS_SUCCESS Success. + * @retval HIPDNN_STATUS_BAD_PARAM Invalid handle or null pointer. + * + * @see hipdnnGetHeuristicPolicyInfo_ext for retrieving individual policy metadata + */ +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetHeuristicPolicyCount_ext(hipdnnHandle_t handle, + size_t* numPolicies); + +/** + * @brief Gets information about a loaded heuristic policy by index. + * + * Retrieves metadata for a heuristic policy plugin, including policy ID, policy + * name, plugin name, plugin version, and API version. + * + * @note The enumeration order is unspecified and may change between calls or + * between backend versions. Callers must not assume a stable ordering across + * indices; use the returned `policyId` as the identity, not `policyIndex`. + * + * This function uses a two-call pattern for string fields: + * 1. First call: Pass all string buffers as `nullptr` to query required sizes. + * - Sets `policyNameLen`, `pluginNameLen`, `pluginVersionLen`, and `apiVersionLen` + * to the required buffer sizes (including null terminator). Note: if any + * string buffer is null, all sizes are updated. + * + * 2. Second call: Pass allocated buffers with sizes set from the first call. + * + * @param[in] handle A valid hipDNN handle. + * @param[in] policyIndex Zero-based index of the policy to query. + * @param[out] policyId Pointer where the policy ID will be stored, or `nullptr` to skip. + * @param[out] policyName Buffer for the policy name, or `nullptr` to query size. + * @param[in,out] policyNameLen Pointer to buffer size; updated with required size. + * @param[out] pluginName Buffer for the plugin name, or `nullptr` to query size. + * @param[in,out] pluginNameLen Pointer to buffer size; updated with required size. + * @param[out] pluginVersion Buffer for the plugin version, or `nullptr` to query size. + * @param[in,out] pluginVersionLen Pointer to buffer size; updated with required size. + * @param[out] apiVersion Buffer for the API version, or `nullptr` to query size. + * @param[in,out] apiVersionLen Pointer to buffer size; updated with required size. + * + * @retval HIPDNN_STATUS_SUCCESS Success. + * @retval HIPDNN_STATUS_BAD_PARAM Invalid handle, null pointers, or out-of-range index. + * @retval HIPDNN_STATUS_INTERNAL_ERROR Internal error. + * + * @see hipdnnGetHeuristicPolicyCount_ext for getting the total policy count + */ +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetHeuristicPolicyInfo_ext(hipdnnHandle_t handle, + size_t policyIndex, + int64_t* policyId, + char* policyName, + size_t* policyNameLen, + char* pluginName, + size_t* pluginNameLen, + char* pluginVersion, + size_t* pluginVersionLen, + char* apiVersion, + size_t* apiVersionLen); + /** * @brief Returns hipdnn backend version string. Returns an error if nullptr is passed * diff --git a/projects/hipdnn/backend/src/CMakeLists.txt b/projects/hipdnn/backend/src/CMakeLists.txt index 58f45c3c5f8..4d065cc50a3 100644 --- a/projects/hipdnn/backend/src/CMakeLists.txt +++ b/projects/hipdnn/backend/src/CMakeLists.txt @@ -55,6 +55,10 @@ add_library( logging/GraphLogger.cpp logging/Logging.cpp FlatbufferUtilities.cpp + heuristics/BuiltInHeuristics.cpp + heuristics/SelectionHeuristic.cpp + heuristics/config/ConfigBuiltIn.cpp + heuristics/static_ordering/StaticOrderingBuiltIn.cpp plugin/EnginePlugin.cpp plugin/EnginePluginResourceManager.cpp plugin/HeuristicPlugin.cpp diff --git a/projects/hipdnn/backend/src/HipdnnBackend.cpp b/projects/hipdnn/backend/src/HipdnnBackend.cpp index 46095fd5931..9d9b3dab09e 100644 --- a/projects/hipdnn/backend/src/HipdnnBackend.cpp +++ b/projects/hipdnn/backend/src/HipdnnBackend.cpp @@ -14,6 +14,7 @@ #include "hipdnn_backend.h" #include "logging/Logging.hpp" #include "plugin/EnginePluginResourceManager.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" #include #include @@ -480,13 +481,47 @@ HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnSetEnginePluginPaths_ext( }); } +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnSetHeuristicPluginPaths_ext( + size_t numPaths, const char* const* pluginPaths, hipdnnPluginLoadingMode_ext_t loadingMode) +{ + LOG_API_ENTRY("numPaths={}, pluginPaths_ptr={:p}, loadingMode={}", + numPaths, + static_cast(pluginPaths), + loadingMode); + + return hipdnn_backend::tryCatch([&, apiName = __func__] { + if(numPaths > 0) + { + throwIfNull(pluginPaths); + } + + std::vector pathsVec; + pathsVec.reserve(numPaths); + + for(size_t i = 0; i < numPaths; ++i) + { + throwIfNull(pluginPaths[i]); + pathsVec.emplace_back(pluginPaths[i]); + } + + hipdnn_backend::plugin::HeuristicPluginResourceManager::setPluginPaths(pathsVec, + loadingMode); + LOG_API_SUCCESS(apiName, "set_heuristic_plugin_paths={}", loadingMode); + return HIPDNN_STATUS_SUCCESS; + }); +} + HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnSetPluginUnloadMode_ext(hipdnnPluginUnloadingMode_ext_t unloadingMode) { LOG_API_ENTRY("unloadingMode={}", unloadingMode); return hipdnn_backend::tryCatch([&, apiName = __func__] { + // Apply to both plugin resource managers so the public ABI behaves + // uniformly regardless of plugin kind. hipdnn_backend::plugin::EnginePluginResourceManager::setPluginUnloadingMode(unloadingMode); + hipdnn_backend::plugin::HeuristicPluginResourceManager::setPluginUnloadingMode( + unloadingMode); LOG_API_SUCCESS(apiName, "set_plugin_unloading_mode={}", unloadingMode); return HIPDNN_STATUS_SUCCESS; }); @@ -654,6 +689,116 @@ HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetEngineInfo_ext(hipdnnHandle_t hand }); } +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetHeuristicPolicyCount_ext(hipdnnHandle_t handle, + size_t* numPolicies) +{ + LOG_API_ENTRY("handle={:p}, numPolicies_ptr={:p}", + static_cast(handle), + static_cast(numPolicies)); + + return hipdnn_backend::tryCatch([&, apiName = __func__] { + throwIfNull(handle); + throwIfNull(numPolicies); + + auto policyInfos = handle->getHeuristicPluginResourceManager()->getHeuristicPolicyInfos(); + *numPolicies = policyInfos.size(); + + LOG_API_SUCCESS(apiName, "retrieved_numPolicies={}", *numPolicies); + }); +} + +HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetHeuristicPolicyInfo_ext(hipdnnHandle_t handle, + size_t policyIndex, + int64_t* policyId, + char* policyName, + size_t* policyNameLen, + char* pluginName, + size_t* pluginNameLen, + char* pluginVersion, + size_t* pluginVersionLen, + char* apiVersion, + size_t* apiVersionLen) +{ + LOG_API_ENTRY("handle={:p}, policyIndex={}, policyId_ptr={:p}, policyName_ptr={:p}, " + "pluginName_ptr={:p}, pluginVersion_ptr={:p}, apiVersion_ptr={:p}", + static_cast(handle), + policyIndex, + static_cast(policyId), + static_cast(policyName), + static_cast(pluginName), + static_cast(pluginVersion), + static_cast(apiVersion)); + + return hipdnn_backend::tryCatch([&, apiName = __func__] { + throwIfNull(handle); + throwIfNull(policyNameLen); + throwIfNull(pluginNameLen); + throwIfNull(pluginVersionLen); + throwIfNull(apiVersionLen); + + // Built from an unordered_map; ordering is unspecified and may change + // between calls. If a stable enumeration is ever required, an explicit + // ordering must be applied here rather than relied on from the source map. + auto policyInfos = handle->getHeuristicPluginResourceManager()->getHeuristicPolicyInfos(); + if(policyIndex >= policyInfos.size()) + { + throw HipdnnException(HIPDNN_STATUS_BAD_PARAM, + "Policy index " + std::to_string(policyIndex) + " out of range (" + + std::to_string(policyInfos.size()) + " policies loaded)."); + } + + const auto& info = policyInfos[policyIndex]; + + if(policyId != nullptr) + { + *policyId = info.policyId; + } + + const size_t requiredPolicyNameLen = info.policyName.size() + 1; + const size_t requiredPluginNameLen = info.pluginName.size() + 1; + const size_t requiredPluginVersionLen = info.pluginVersion.size() + 1; + const size_t requiredApiVersionLen = info.apiVersion.size() + 1; + + // Query mode: return required sizes + if(policyName == nullptr || pluginName == nullptr || pluginVersion == nullptr + || apiVersion == nullptr) + { + *policyNameLen = requiredPolicyNameLen; + *pluginNameLen = requiredPluginNameLen; + *pluginVersionLen = requiredPluginVersionLen; + *apiVersionLen = requiredApiVersionLen; + return; + } + + // Retrieve mode: check buffer sizes + if(*policyNameLen < requiredPolicyNameLen || *pluginNameLen < requiredPluginNameLen + || *pluginVersionLen < requiredPluginVersionLen + || *apiVersionLen < requiredApiVersionLen) + { + throw HipdnnException(HIPDNN_STATUS_BAD_PARAM, "Insufficient buffer space provided."); + } + + hipdnn_data_sdk::utilities::copyMaxSizeWithNullTerminator( + policyName, info.policyName.c_str(), *policyNameLen); + hipdnn_data_sdk::utilities::copyMaxSizeWithNullTerminator( + pluginName, info.pluginName.c_str(), *pluginNameLen); + hipdnn_data_sdk::utilities::copyMaxSizeWithNullTerminator( + pluginVersion, info.pluginVersion.c_str(), *pluginVersionLen); + hipdnn_data_sdk::utilities::copyMaxSizeWithNullTerminator( + apiVersion, info.apiVersion.c_str(), *apiVersionLen); + + LOG_API_SUCCESS(apiName, + "policy[{}]: policyId={}, policyName={}, pluginName={}, " + "pluginVersion={}, apiVersion={}", + policyIndex, + info.policyId, + info.policyName, + info.pluginName, + info.pluginVersion, + info.apiVersion); + }); +} + HIPDNN_BACKEND_EXPORT hipdnnStatus_t hipdnnGetVersion_ext(const char** version) { return hipdnn_backend::tryCatch([&]() { diff --git a/projects/hipdnn/backend/src/descriptors/DescriptorFactory.cpp b/projects/hipdnn/backend/src/descriptors/DescriptorFactory.cpp index e0d736e651c..20a1c129b7e 100644 --- a/projects/hipdnn/backend/src/descriptors/DescriptorFactory.cpp +++ b/projects/hipdnn/backend/src/descriptors/DescriptorFactory.cpp @@ -32,6 +32,10 @@ #include "SdpaFwdOperationDescriptor.hpp" #include "TensorDescriptor.hpp" #include "VariantDescriptor.hpp" +// Required: EngineHeuristicDescriptor holds std::unique_ptr +// via forward declaration, so the complete type must be visible where +// make_shared() instantiates the destructor. +#include "heuristics/SelectionHeuristic.hpp" #include "logging/Logging.hpp" namespace hipdnn_backend diff --git a/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.cpp b/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.cpp index 1e65be157d4..63bcd78bd77 100644 --- a/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.cpp +++ b/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.cpp @@ -13,11 +13,163 @@ #include "handle/Handle.hpp" #include "utilities/EngineOrdering.hpp" +// Heuristics framework +#include "heuristics/SelectionHeuristic.hpp" +#include "logging/Logging.hpp" +#include "plugin/HeuristicPlugin.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + namespace hipdnn_backend { +std::vector EngineHeuristicDescriptor::resolveHeuristicPolicyOrder() +{ + // Policy order resolution. + // Priority: env > descriptor attr > default + // Storage and ABI are policy IDs (FNV-1a of the policy name); names are + // hashed at the point they enter the system. The Config built-in + // (HIPDNN_HEUR_CONFIG_PATH JSON rules) is a regular policy in this + // list, not a precursor; it declines when no rule matches so subsequent + // policies still run. The explicit Graph.preferred_engine_id setter is + // handled by the frontend as a post-hoc reorder of the heuristic-ranked + // engine configs. + + // 1. Environment variable HIPDNN_HEUR_POLICY_ORDER (highest priority) + // Use the data_sdk getEnv() wrapper rather than std::getenv() so that this + // reads the live process environment block on Windows. + // + // Tokens may be either policy names ("SelectionHeuristic::Config") or raw + // int64 policy IDs (decimal, optionally signed). A token parses as an ID + // only when std::strtoll consumes the *entire* trimmed token; anything + // else — including names that happen to start with digits — is hashed + // through policyNameToId. + const std::string envStr = hipdnn_data_sdk::utilities::getEnv("HIPDNN_HEUR_POLICY_ORDER"); + if(!envStr.empty()) + { + std::vector policyIds; + std::istringstream iss(envStr); + std::string token; + while(std::getline(iss, token, ',')) + { + // Trim whitespace + token.erase(0, token.find_first_not_of(" \t\n\r")); + token.erase(token.find_last_not_of(" \t\n\r") + 1); + if(token.empty()) + { + continue; + } + + char* end = nullptr; + errno = 0; + const int64_t asId = std::strtoll(token.c_str(), &end, 10); + const bool fullyParsed = (end != nullptr) && (*end == '\0') && (errno == 0); + policyIds.push_back(fullyParsed ? asId + : hipdnn_data_sdk::utilities::policyNameToId(token)); + } + HIPDNN_BACKEND_LOG_WARN("Using environment variable policy order: {} policies", + policyIds.size()); + return policyIds; + } + // 2. Descriptor attribute + if(_policyOrderSet) + { + HIPDNN_BACKEND_LOG_DEBUG("Using descriptor-level policy order: {} policies", + _policyOrder.size()); + return _policyOrder; + } + // 3. Default policy list — Config first so HIPDNN_HEUR_CONFIG_PATH + // rules win when set; StaticOrdering is the canonical last-resort fallback + // and always succeeds when there is at least one candidate. Vendor + // heuristic plugins may be inserted via env or descriptor attribute above. + std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::Config"), + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + HIPDNN_BACKEND_LOG_WARN( + "No heuristic policy order configured, falling back to built-in defaults " + "[SelectionHeuristic::Config, SelectionHeuristic::StaticOrdering]. " + "Set HIPDNN_HEUR_POLICY_ORDER or the descriptor attribute to silence " + "this warning."); + return policyIds; +} + +void EngineHeuristicDescriptor::syncPolicySlots(const std::vector& orderedPolicyIds) +{ + // Ensure one SelectionHeuristic per policy slot. + // If the policy list changed, recreate the slots. + + if(_orderedPolicyIds == orderedPolicyIds && !_policySlots.empty()) + { + // Policy list unchanged and slots already created + return; + } + + _orderedPolicyIds = orderedPolicyIds; + _policySlots.clear(); + + auto handle = _graph->getHandle(); + auto heurRm = handle->getHeuristicPluginResourceManager(); + + // Create one SelectionHeuristic per policy slot. The slot holds a + // shared_ptr to the resource manager so the underlying plugin and handle + // cannot be destroyed while the slot is alive; lookups happen by policy + // ID inside SelectionHeuristic. + for(const int64_t policyId : orderedPolicyIds) + { + if(heurRm->getHeuristicHandleForPolicyId(policyId) == nullptr) + { + // Policy not loaded - add null placeholder + _policySlots.push_back(nullptr); + continue; + } + + // SelectionHeuristic's constructor calls into the plugin to create the + // policy descriptor, which can throw. Treat a failed slot the same way + // we treat a not-loaded policy: log and insert a null placeholder so + // the policy loop in finalize() simply skips it via its existing + // nullptr branch instead of aborting the whole descriptor. + // HipdnnException derives from std::exception, so one catch covers both. + try + { + _policySlots.push_back( + std::make_unique(heurRm, policyId)); + continue; + } + catch(const std::exception& e) + { + HIPDNN_BACKEND_LOG_WARN("Failed to construct SelectionHeuristic for policy ID {}: {}. " + "Slot will be skipped during finalize().", + policyId, + e.what()); + } + catch(...) + { + HIPDNN_BACKEND_LOG_WARN( + "Failed to construct SelectionHeuristic for policy ID {} (unknown exception). " + "Slot will be skipped during finalize().", + policyId); + } + _policySlots.push_back(nullptr); + } +} + void EngineHeuristicDescriptor::finalize() { + // Outer loop policy selection THROW_IF_TRUE(isFinalized(), HIPDNN_STATUS_BAD_PARAM, "EngineHeuristicDescriptor::finalize() failed: Already finalized."); @@ -31,18 +183,218 @@ void EngineHeuristicDescriptor::finalize() "EngineHeuristicDescriptor::finalize() failed: Heuristic mode is not set."); auto handle = _graph->getHandle(); - auto pluginResourceManager = handle->getPluginResourceManager(); + auto engineRm = handle->getPluginResourceManager(); + auto heurRm = handle->getHeuristicPluginResourceManager(); + + // Get candidate engine IDs from engine plugins + auto candidates = engineRm->getApplicableEngineIds(_graph.get(), _findFirst); + + // If no engines available, finalize with empty result (no need to invoke heuristics). + // This is a valid state - not an error. + if(candidates.empty()) + { + _engineIds.clear(); + HipdnnBackendDescriptorImpl::finalize(); + return; + } - _engineIds = pluginResourceManager->getApplicableEngineIds(_graph.get(), _findFirst); + // findFirst is a fast applicability probe (Graph::is_supported_ext) — the + // caller only needs to know whether *any* engine can run the graph. + if(_findFirst) + { + _engineIds = std::move(candidates); + HipdnnBackendDescriptorImpl::finalize(); + return; + } + + // Query and serialize device properties for the device the handle's stream + // is bound to. + int deviceId = 0; + auto status = hipStreamGetDevice(handle->getStream(), &deviceId); + if(status != hipSuccess) + { + throw HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, + "Failed to get device from handle's stream"); + } + + hipDeviceProp_t hipProps; + status = hipGetDeviceProperties(&hipProps, deviceId); + if(status != hipSuccess) + { + throw HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Failed to get device properties"); + } - if(!_findFirst) + // Create DevicePropertiesT from HIP device properties + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT devProps; + devProps.device_id = deviceId; + devProps.multi_processor_count = hipProps.multiProcessorCount; + devProps.total_global_mem = hipProps.totalGlobalMem; + devProps.architecture_name = hipProps.gcnArchName; + + // Serialize DevicePropertiesT using FlatBuffers + flatbuffers::FlatBufferBuilder builder(256); + auto offset = hipdnn_flatbuffers_sdk::data_objects::DeviceProperties::Pack(builder, &devProps); + builder.Finish(offset, "HDDP"); + + // Copy serialized data to persistent storage + std::vector devicePropsSerialized(builder.GetBufferPointer(), + builder.GetBufferPointer() + builder.GetSize()); + + // Wrap serialized buffer in hipdnnPluginConstData_t + hipdnnPluginConstData_t devicePropsWrapper; + devicePropsWrapper.ptr = devicePropsSerialized.data(); + devicePropsWrapper.size = devicePropsSerialized.size(); + + // Get serialized graph from GraphDescriptor + const hipdnnPluginConstData_t serializedGraph = _graph->getSerializedGraph(); + + // Resolve ordered policy IDs + auto orderedPolicyIds = resolveHeuristicPolicyOrder(); + + // Ensure policy slots match the ordered policy list + syncPolicySlots(orderedPolicyIds); + + // Set device properties on all distinct plugin handles (once per handle, not per slot) + // NOTE: Multiple policies may share the same plugin handle if they come from the same + // .so file (e.g., a single plugin providing multiple ordering strategies like Fast/Balanced/Accurate). + // Use a map to deduplicate and call setDeviceProperties only once per unique handle. + std::unordered_map distinctHandles; + for(const int64_t policyId : orderedPolicyIds) { - // Sort engine IDs to prioritize MIOPEN_ENGINE and deprioritize MIOPEN_ENGINE_DETERMINISTIC - // In the future, we will need to implement a plugin system for engine heuristics that allows - // plugins to determine sort order of the returned engines. - utilities::sortEngineIds(_engineIds); + auto pluginHandle = heurRm->getHeuristicHandleForPolicyId(policyId); + if(pluginHandle != nullptr) + { + auto plugin = heurRm->getPluginForPolicyId(policyId); + if(plugin != nullptr) + { + distinctHandles[pluginHandle] = plugin; + } + } } + // Call SetDeviceProperties on each distinct handle. + // Sort by handle pointer so the call order is stable *within this process + // run* — std::unordered_map iteration order is otherwise unspecified, which + // would scramble the order of any per-handle log lines emitted below and + // make the fail-soft disable order non-reproducible from one finalize() to + // the next on the same descriptor. Pointers vary across runs (ASLR), so + // this is reproducible-per-run, not reproducible-across-runs. + std::vector> sortedHandles( + distinctHandles.begin(), distinctHandles.end()); + std::sort(sortedHandles.begin(), sortedHandles.end(), [](const auto& a, const auto& b) { + return a.first < b.first; + }); + + // Mirror the policy loop's fail-soft contract below: a single plugin's + // setDeviceProperties failure must not break the chain. Disable every slot + // backed by a failed plugin handle so the policy loop skips it via the + // existing nullptr branch. + auto disableSlotsForHandle = [&](hipdnnHeuristicHandle_t failedHandle) { + for(size_t i = 0; i < _policySlots.size(); ++i) + { + if(_policySlots[i] != nullptr + && heurRm->getHeuristicHandleForPolicyId(_orderedPolicyIds[i]) == failedHandle) + { + _policySlots[i].reset(); + } + } + }; + + for(const auto& [pluginHandle, plugin] : sortedHandles) + { + try + { + plugin->setDeviceProperties(pluginHandle, &devicePropsWrapper); + } + catch(const std::exception& e) + { + HIPDNN_BACKEND_LOG_WARN("setDeviceProperties failed for heuristic plugin '{}': {}. " + "Disabling all policies provided by this plugin.", + plugin->name(), + e.what()); + disableSlotsForHandle(pluginHandle); + } + catch(...) + { + HIPDNN_BACKEND_LOG_WARN("setDeviceProperties threw unknown exception for heuristic " + "plugin '{}'. Disabling all policies provided by this plugin.", + plugin->name()); + disableSlotsForHandle(pluginHandle); + } + } + + // Outer policy loop: try each policy in order until one succeeds + bool success = false; + for(size_t i = 0; i < _policySlots.size(); ++i) + { + auto& selection = _policySlots[i]; + if(selection == nullptr) + { + // Policy plugin not loaded - continue to next policy + continue; + } + + try + { + // Set candidate engine IDs and serialized graph on the policy descriptor + selection->setEngineIds(candidates); + selection->setSerializedGraph(&serializedGraph); + + // Call finalize on this policy + if(!selection->finalize()) + { + // Policy declined or not applicable - continue to next policy + continue; + } + + // Policy succeeded! Get the sorted engine IDs + candidates = selection->getSortedEngineIds(); + success = true; + break; + } + catch(const HipdnnException& e) + { + // Policy threw an exception - log and continue to next policy + HIPDNN_BACKEND_LOG_WARN("Heuristic policy at slot {} (ID {}) threw exception: {}. " + "Continuing to next policy.", + i, + _orderedPolicyIds[i], + e.what()); + continue; + } + catch(const std::exception& e) + { + // Plugin code is external and may throw any std-derived exception type. + // Treat the same as HipdnnException: log and continue. + HIPDNN_BACKEND_LOG_WARN("Heuristic policy at slot {} (ID {}) threw exception: {}. " + "Continuing to next policy.", + i, + _orderedPolicyIds[i], + e.what()); + continue; + } + catch(...) + { + // Plugin may throw a non-std-derived exception; never let it cross the C ABI. + HIPDNN_BACKEND_LOG_WARN("Heuristic policy at slot {} (ID {}) threw unknown exception. " + "Continuing to next policy.", + i, + _orderedPolicyIds[i]); + continue; + } + } + + // If no policy succeeded, throw exception. + // No hidden fallback to utilities::sortEngineIds. + if(!success) + { + throw HipdnnException( + HIPDNN_STATUS_INTERNAL_ERROR, + "EngineHeuristicDescriptor::finalize() failed: No heuristic policy succeeded."); + } + + _engineIds = candidates; + HipdnnBackendDescriptorImpl::finalize(); } @@ -70,6 +422,9 @@ void EngineHeuristicDescriptor::getAttribute(hipdnnBackendAttributeName_t attrib case HIPDNN_ATTR_ENGINEHEUR_FIND_FIRST_EXT: getFindFirst(attributeType, requestedElementCount, elementCount, arrayOfElements); break; + case HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT: + getPolicyOrder(attributeType, requestedElementCount, elementCount, arrayOfElements); + break; default: throw HipdnnException( HIPDNN_STATUS_NOT_SUPPORTED, @@ -98,6 +453,9 @@ void EngineHeuristicDescriptor::setAttribute(hipdnnBackendAttributeName_t attrib case HIPDNN_ATTR_ENGINEHEUR_FIND_FIRST_EXT: setFindFirst(attributeType, elementCount, arrayOfElements); break; + case HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT: + setPolicyOrder(attributeType, elementCount, arrayOfElements); + break; default: throw HipdnnException( HIPDNN_STATUS_NOT_SUPPORTED, @@ -327,12 +685,96 @@ hipdnnBackendDescriptorType_t EngineHeuristicDescriptor::getStaticType() return HIPDNN_BACKEND_ENGINEHEUR_DESCRIPTOR; } +void EngineHeuristicDescriptor::setPolicyOrder(hipdnnBackendAttributeType_t attributeType, + int64_t elementCount, + const void* arrayOfElements) +{ + THROW_IF_NE(attributeType, + HIPDNN_TYPE_INT64, + HIPDNN_STATUS_BAD_PARAM, + "EngineHeuristicDescriptor failed to set policy order: Invalid attribute type."); + + THROW_IF_TRUE(elementCount < 0, + HIPDNN_STATUS_BAD_PARAM, + "EngineHeuristicDescriptor failed to set policy order: Negative element count."); + + THROW_IF_TRUE(elementCount > 0 && arrayOfElements == nullptr, + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER, + "EngineHeuristicDescriptor failed to set policy order: Null pointer."); + + if(elementCount == 0) + { + _policyOrder.clear(); + _policyOrderSet = true; + HIPDNN_BACKEND_LOG_DEBUG("Set descriptor-level policy order: 0 policies"); + return; + } + + const auto* data = static_cast(arrayOfElements); + _policyOrder.assign(data, data + elementCount); + _policyOrderSet = true; + HIPDNN_BACKEND_LOG_DEBUG("Set descriptor-level policy order: {} policies", _policyOrder.size()); +} + +void EngineHeuristicDescriptor::getPolicyOrder(hipdnnBackendAttributeType_t attributeType, + int64_t requestedElementCount, + int64_t* elementCount, + void* arrayOfElements) const +{ + THROW_IF_NE(attributeType, + HIPDNN_TYPE_INT64, + HIPDNN_STATUS_BAD_PARAM, + "EngineHeuristicDescriptor failed to get policy order: Invalid attribute type."); + + THROW_IF_NULL(elementCount, + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER, + "EngineHeuristicDescriptor failed to get policy order: Null pointer for " + "element count."); + + THROW_IF_TRUE(requestedElementCount < 0, + HIPDNN_STATUS_BAD_PARAM, + "EngineHeuristicDescriptor failed to get policy order: Negative requested " + "element count."); + + THROW_IF_TRUE(requestedElementCount > 0 && arrayOfElements == nullptr, + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER, + "EngineHeuristicDescriptor failed to get policy order: Null pointer."); + + // The dispatcher requires isFinalized() before reaching here, so + // _orderedPolicyIds reflects the resolved order (env > descriptor > default) + // actually used during finalize(). See resolveHeuristicPolicyOrder(). + if(requestedElementCount == 0) + { + *elementCount = static_cast(_orderedPolicyIds.size()); + return; + } + + auto* output = static_cast(arrayOfElements); + const auto count + = std::min(static_cast(requestedElementCount), _orderedPolicyIds.size()); + std::memcpy(output, _orderedPolicyIds.data(), count * sizeof(int64_t)); + *elementCount = static_cast(count); +} + std::string EngineHeuristicDescriptor::toString() const { std::string str = "EngineHeuristicDescriptor: {heuristicMode="; str += _heuristicModeSet ? std::to_string(_heuristicMode) : "unset"; str += _graph ? ", graph=" + fmt::format("{:p}", static_cast(_graph.get())) : ", graph=null"; + if(_policyOrderSet) + { + str += ", policyOrder=["; + for(size_t i = 0; i < _policyOrder.size(); ++i) + { + if(i > 0) + { + str += ", "; + } + str += hipdnn_data_sdk::utilities::formatEngineIdHex(_policyOrder[i]); + } + str += "]"; + } str += "}"; return str; } diff --git a/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.hpp b/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.hpp index 4f5cd9c5198..4a304a7692d 100644 --- a/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.hpp +++ b/projects/hipdnn/backend/src/descriptors/EngineHeuristicDescriptor.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include "BackendDescriptor.hpp" @@ -12,6 +13,11 @@ namespace hipdnn_backend class GraphDescriptor; +namespace heuristics +{ +class SelectionHeuristic; +} + class EngineHeuristicDescriptor : public HipdnnBackendDescriptorImpl { private: @@ -21,6 +27,18 @@ class EngineHeuristicDescriptor : public HipdnnBackendDescriptorImpl _orderedPolicyIds; + std::vector> _policySlots; + std::vector _policyOrder; // descriptor-level policy IDs + bool _policyOrderSet = false; + + // Resolve policy order from descriptor/handle/env/default + std::vector resolveHeuristicPolicyOrder(); + + // Ensure policy slots match orderedPolicyIds + void syncPolicySlots(const std::vector& orderedPolicyIds); + void setGraph(hipdnnBackendAttributeType_t attributeType, int64_t elementCount, const void* arrayOfElements); @@ -53,6 +71,15 @@ class EngineHeuristicDescriptor : public HipdnnBackendDescriptorImpl + +#include +#include + +namespace hipdnn_backend::heuristics::detail +{ + +constexpr size_t BUILT_IN_LOG_BUFFER_SIZE = 1024; + +} // namespace hipdnn_backend::heuristics::detail + +// Shared logging macro for built-in heuristic policies. Each built-in declares +// its own file-scope g_loggingCallback / g_logLevel globals (set via the C-ABI +// SetLoggingCallback / SetLogLevel pseudo-plugin entrypoints) and uses this +// macro to format a prefixed message and dispatch through whatever sink the +// host wired up. +// +// Usage: +// #define MY_LOG(severity, ...) \ +// HIPDNN_BUILTIN_HEURISTIC_LOG( \ +// g_loggingCallback, g_logLevel, severity, "[MyBuiltIn] ", __VA_ARGS__) +#define HIPDNN_BUILTIN_HEURISTIC_LOG(callback, threshold, severity, prefix, ...) \ + do \ + { \ + if((callback) != nullptr && (severity) >= (threshold)) \ + { \ + std::array \ + _buf{}; \ + std::snprintf(_buf.data(), _buf.size(), prefix __VA_ARGS__); \ + (callback)((severity), _buf.data()); \ + } \ + } while(0) diff --git a/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.cpp b/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.cpp new file mode 100644 index 00000000000..787bf44cdf6 --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.cpp @@ -0,0 +1,229 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "SelectionHeuristic.hpp" + +#include +#include + +#include "HipdnnException.hpp" +#include "logging/Logging.hpp" +#include "plugin/HeuristicPlugin.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" + +namespace hipdnn_backend::heuristics +{ + +SelectionHeuristic::SelectionHeuristic( + std::shared_ptr resourceManager, int64_t policyId) + : _resourceManager(std::move(resourceManager)) + , _policyId(policyId) +{ + THROW_IF_FALSE(_resourceManager != nullptr, + HIPDNN_STATUS_BAD_PARAM, + "HeuristicPluginResourceManager pointer cannot be null"); + + auto pluginHandle = _resourceManager->getHeuristicHandleForPolicyId(_policyId); + THROW_IF_FALSE(pluginHandle != nullptr, + HIPDNN_STATUS_BAD_PARAM, + "No heuristic plugin handle loaded for policy ID " + std::to_string(_policyId)); + + auto plugin = _resourceManager->getPluginForPolicyId(_policyId); + THROW_IF_FALSE(plugin != nullptr, + HIPDNN_STATUS_BAD_PARAM, + "No heuristic plugin loaded for policy ID " + std::to_string(_policyId)); + + _descriptor = plugin->createPolicyDescriptor(pluginHandle, _policyId); +} + +SelectionHeuristic::~SelectionHeuristic() +{ + if(_descriptor != nullptr && _resourceManager != nullptr) + { + try + { + auto plugin = lookupPlugin(); + if(plugin != nullptr) + { + plugin->destroyPolicyDescriptor(_descriptor); + } + } + // Destructors must not propagate. HipdnnException derives from + // std::exception, so one catch covers both; plugin code is untrusted + // and may also throw non-std types — fall through to catch(...). + catch(const std::exception& e) + { + HIPDNN_BACKEND_LOG_WARN( + "Exception while destroying heuristic policy descriptor for policy ID {}: {}", + _policyId, + e.what()); + } + catch(...) + { + HIPDNN_BACKEND_LOG_WARN( + "Unknown exception while destroying heuristic policy descriptor for policy ID {}", + _policyId); + } + _descriptor = nullptr; + } +} + +SelectionHeuristic::SelectionHeuristic(SelectionHeuristic&& other) noexcept + : _resourceManager(std::move(other._resourceManager)) + , _policyId(other._policyId) + , _descriptor(other._descriptor) + , _inputEngineIds(std::move(other._inputEngineIds)) +{ + other._policyId = 0; + other._descriptor = nullptr; + other._inputEngineIds.clear(); +} + +SelectionHeuristic& SelectionHeuristic::operator=(SelectionHeuristic&& other) noexcept +{ + if(this != &other) + { + // Clean up current descriptor + if(_descriptor != nullptr && _resourceManager != nullptr) + { + try + { + auto plugin = lookupPlugin(); + if(plugin != nullptr) + { + plugin->destroyPolicyDescriptor(_descriptor); + } + } + // noexcept move-assign must not propagate. HipdnnException + // derives from std::exception, so one catch covers both; plugin + // code is untrusted and may also throw non-std types. + catch(const std::exception& e) + { + HIPDNN_BACKEND_LOG_WARN( + "Exception while destroying heuristic policy descriptor for policy ID {} " + "during move-assignment: {}", + _policyId, + e.what()); + } + catch(...) + { + HIPDNN_BACKEND_LOG_WARN("Unknown exception while destroying heuristic policy " + "descriptor for policy ID {} during move-assignment", + _policyId); + } + } + + // Move from other + _resourceManager = std::move(other._resourceManager); + _policyId = other._policyId; + _descriptor = other._descriptor; + _inputEngineIds = std::move(other._inputEngineIds); + other._policyId = 0; + other._descriptor = nullptr; + other._inputEngineIds.clear(); + } + return *this; +} + +void SelectionHeuristic::setEngineIds(const std::vector& engineIds) +{ + THROW_IF_FALSE( + _descriptor != nullptr, HIPDNN_STATUS_NOT_INITIALIZED, "Policy descriptor not initialized"); + + auto plugin = lookupPlugin(); + THROW_IF_FALSE(plugin != nullptr, + HIPDNN_STATUS_NOT_INITIALIZED, + "Heuristic plugin no longer registered for policy ID " + + std::to_string(_policyId)); + + plugin->setEngineIds(_descriptor, engineIds.data(), engineIds.size()); + _inputEngineIds = engineIds; +} + +void SelectionHeuristic::setSerializedGraph(const hipdnnPluginConstData_t* serializedGraph) +{ + THROW_IF_FALSE( + _descriptor != nullptr, HIPDNN_STATUS_NOT_INITIALIZED, "Policy descriptor not initialized"); + THROW_IF_FALSE(serializedGraph != nullptr, + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER, + "Serialized graph pointer cannot be null"); + + auto plugin = lookupPlugin(); + THROW_IF_FALSE(plugin != nullptr, + HIPDNN_STATUS_NOT_INITIALIZED, + "Heuristic plugin no longer registered for policy ID " + + std::to_string(_policyId)); + + plugin->setSerializedGraph(_descriptor, serializedGraph); +} + +bool SelectionHeuristic::finalize() +{ + THROW_IF_FALSE( + _descriptor != nullptr, HIPDNN_STATUS_NOT_INITIALIZED, "Policy descriptor not initialized"); + + auto plugin = lookupPlugin(); + THROW_IF_FALSE(plugin != nullptr, + HIPDNN_STATUS_NOT_INITIALIZED, + "Heuristic plugin no longer registered for policy ID " + + std::to_string(_policyId)); + + // Call the plugin's finalize method + // Returns true if policy succeeded (won the outer loop) + // Returns false if not applicable or declined + return plugin->finalize(_descriptor); +} + +std::vector SelectionHeuristic::getSortedEngineIds() +{ + THROW_IF_FALSE( + _descriptor != nullptr, HIPDNN_STATUS_NOT_INITIALIZED, "Policy descriptor not initialized"); + + auto plugin = lookupPlugin(); + THROW_IF_FALSE(plugin != nullptr, + HIPDNN_STATUS_NOT_INITIALIZED, + "Heuristic plugin no longer registered for policy ID " + + std::to_string(_policyId)); + + // Call the plugin's getSortedEngineIds method + // This is valid only after finalize() returned true + auto sortedIds = plugin->getSortedEngineIds(_descriptor); + + // Validate that the plugin returned a permutation or subset of the IDs + // we handed it via setEngineIds: every output ID must be in the input + // and no output ID may appear twice. Plugins are untrusted code. + THROW_IF_FALSE(sortedIds.size() <= _inputEngineIds.size(), + HIPDNN_STATUS_PLUGIN_ERROR, + "Heuristic plugin for policy ID " + std::to_string(_policyId) + + " returned more engine IDs (" + std::to_string(sortedIds.size()) + + ") than were provided (" + std::to_string(_inputEngineIds.size()) + ")"); + + const std::unordered_set inputSet(_inputEngineIds.begin(), _inputEngineIds.end()); + std::unordered_set seen; + seen.reserve(sortedIds.size()); + for(const int64_t id : sortedIds) + { + THROW_IF_FALSE(inputSet.count(id) != 0, + HIPDNN_STATUS_PLUGIN_ERROR, + "Heuristic plugin for policy ID " + std::to_string(_policyId) + + " returned engine ID " + std::to_string(id) + + " that was not in the provided candidate set"); + THROW_IF_FALSE(seen.insert(id).second, + HIPDNN_STATUS_PLUGIN_ERROR, + "Heuristic plugin for policy ID " + std::to_string(_policyId) + + " returned duplicate engine ID " + std::to_string(id)); + } + + return sortedIds; +} + +const plugin::HeuristicPlugin* SelectionHeuristic::lookupPlugin() const +{ + if(_resourceManager == nullptr) + { + return nullptr; + } + return _resourceManager->getPluginForPolicyId(_policyId); +} + +} // namespace hipdnn_backend::heuristics diff --git a/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.hpp b/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.hpp new file mode 100644 index 00000000000..c56a35de0a4 --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/SelectionHeuristic.hpp @@ -0,0 +1,157 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include +#include + +namespace hipdnn_backend::plugin +{ +class HeuristicPlugin; +class HeuristicPluginResourceManager; +} // namespace hipdnn_backend::plugin + +namespace hipdnn_backend::heuristics +{ + +/** + * @brief C++ facade for one policy slot on EngineHeuristicDescriptor. + * + * This class wraps a hipdnnHeuristicPolicyDescriptor_t created with a + * hipdnnHeuristicHandle_t for that policy's module. It provides a clean + * C++ interface over the heuristic plugin C ABI. + * + * Session state (caches, tuning data, etc.) lives in the plugin behind the + * handle, not in this wrapper. + * + * Device properties are NOT set on this facade. The host calls + * hipdnnHeuristicHandleSetDeviceProperties on the handle BEFORE calling + * finalize() on any descriptor created with that handle. + * + * Lifecycle: Owned by EngineHeuristicDescriptor, one per resolved policy slot. + * Created when the policy list is established, destroyed with the descriptor. + * + * Ownership: Holds a shared_ptr to the HeuristicPluginResourceManager and a + * policy ID rather than a raw HeuristicPlugin pointer. The shared_ptr keeps + * the manager (and transitively the loaded plugin) alive for the descriptor's + * lifetime; plugin and handle lookups happen by policy ID through the + * manager. This protects against the manager being destroyed while a slot + * still holds the descriptor. + */ +class SelectionHeuristic +{ +public: + /** + * @brief Constructs a SelectionHeuristic for a given policy. + * + * Resolves the HeuristicPlugin and hipdnnHeuristicHandle_t through the + * resource manager using the policy ID, and creates the underlying + * hipdnnHeuristicPolicyDescriptor_t. + * + * @param resourceManager Shared pointer to the resource manager that owns + * the plugin handles. Kept alive for the lifetime + * of this object so the resolved plugin/handle + * cannot dangle. + * @param policyId Stable int64_t policy ID identifying the policy this + * slot represents (policyNameToId hash). + */ + SelectionHeuristic(std::shared_ptr resourceManager, + int64_t policyId); + + /** + * @brief Destroys the SelectionHeuristic and releases the policy descriptor. + * + * Calls hipdnnHeuristicPolicyDescriptorDestroy on the underlying descriptor. + */ + ~SelectionHeuristic(); + + // Prevent copying + SelectionHeuristic(const SelectionHeuristic&) = delete; + SelectionHeuristic& operator=(const SelectionHeuristic&) = delete; + + // Allow moving + SelectionHeuristic(SelectionHeuristic&& other) noexcept; + SelectionHeuristic& operator=(SelectionHeuristic&& other) noexcept; + + /** + * @brief Sets the candidate engine IDs for this policy. + * + * Provides the list of candidate engine IDs from + * EnginePluginResourceManager::getApplicableEngineIds. + * The plugin must produce a reordered subset or permutation of these IDs. + * + * Mirrors hipdnnHeuristicPolicySetEngineIds (§8.8). + * + * @param engineIds Vector of candidate engine IDs. + */ + void setEngineIds(const std::vector& engineIds); + + /** + * @brief Sets the serialized operation graph for this policy. + * + * Provides the FlatBuffer-serialized operation graph from + * GraphDescriptor::getSerializedGraph(). + * + * Mirrors hipdnnHeuristicPolicySetSerializedGraph (§8.8). + * + * @param serializedGraph Pointer to hipdnnPluginConstData_t containing + * the serialized graph buffer. + */ + void setSerializedGraph(const hipdnnPluginConstData_t* serializedGraph); + + /** + * @brief Executes the policy selection logic. + * + * Performs applicability checking and engine ordering based on the inputs + * previously set via setEngineIds and setSerializedGraph. + * + * Device properties are queried from the bound hipdnnHeuristicHandle_t + * (which received SetDeviceProperties earlier in the finalize() flow). + * + * Two-phase design: This function performs the selection work; + * getSortedEngineIds retrieves the result. + * + * Mirrors hipdnnHeuristicPolicyFinalize (§8.9). + * + * @return true if policy succeeded (policy won the outer loop), + * false if not applicable or declined (host continues outer loop). + */ + bool finalize(); + + /** + * @brief Retrieves the sorted engine IDs after successful finalize. + * + * Valid only after finalize() returned true. Returns the reordered + * engine IDs produced by the policy. + * + * The output IDs are a permutation or subset of the input IDs from + * setEngineIds. The host validates this constraint. + * + * Mirrors hipdnnHeuristicPolicyGetSortedEngineIds (§8.9). + * + * @return Vector of sorted engine IDs. + */ + std::vector getSortedEngineIds(); + +private: + // Look up the plugin for _policyId via _resourceManager. Returns nullptr + // if the policy is no longer registered (e.g. plugin was removed). All + // operations route through the resource manager so the plugin lookup + // stays consistent with the manager's current state. + const plugin::HeuristicPlugin* lookupPlugin() const; + + std::shared_ptr _resourceManager; + int64_t _policyId = 0; + hipdnnHeuristicPolicyDescriptor_t _descriptor = nullptr; + // Cached candidate engine IDs from the most recent setEngineIds call. + // Used in getSortedEngineIds to validate that the plugin returned a + // permutation or subset of the provided IDs (no duplicates, no fabrications). + std::vector _inputEngineIds; +}; + +} // namespace hipdnn_backend::heuristics diff --git a/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.cpp b/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.cpp new file mode 100644 index 00000000000..7fbcaa16ad9 --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.cpp @@ -0,0 +1,597 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file ConfigBuiltIn.cpp + * @brief Backend-internal implementation of the + * SelectionHeuristic::Config policy. + * + * The policy reads HIPDNN_HEUR_CONFIG_PATH (a JSON file mapping + * conv-shape patterns to engine names via EngineOverrideConfig), walks + * conv-like nodes in the serialized graph, and on the first matching rule + * reorders the candidate engine IDs so the chosen engine is first. When the + * env var is unset, the file is missing/invalid, no rule matches, or the + * matched engine is not among the candidates, the policy declines so the + * outer policy loop can try the next plugin. + * + * Mechanics mirror StaticOrderingBuiltIn — a function-pointer table wrapped + * by HeuristicPlugin::createBuiltIn so registration and validation flow + * through the same paths as dlopen-loaded plugins. + */ + +#include "ConfigBuiltIn.hpp" + +#include "EngineOverrideConfig.hpp" +#include "heuristics/BuiltInLogging.hpp" +#include "logging/Logging.hpp" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace hipdnn_backend::heuristics::config +{ +namespace +{ + +constexpr const char* PLUGIN_NAME = "BuiltInConfigHeuristic"; +constexpr const char* PLUGIN_VERSION = "1.0.0"; +constexpr const char* POLICY_NAME = "SelectionHeuristic::Config"; + +// File-scope logging callback / level, set via the C-ABI-shaped +// SetLoggingCallback / SetLogLevel below. The backend supplies its own +// callback when registering the built-in (see PluginManagerBase::registerPlugin) +// so log lines from this module flow through the backend logger. +// +// Identity contract: the built-in is statically linked into the backend, so +// these globals live in the same process image as the caller. The last writer +// wins — if multiple HeuristicPluginManager instances register the built-in +// they overwrite each other's callback. This is intentional: registerPlugin() +// hands in a callback that forwards to the backend logger, which is itself a +// process-wide sink, so the identity of the "current" callback does not matter +// as long as one is installed. Do not assume per-instance scoping here. +hipdnnCallback_t g_loggingCallback = nullptr; // NOLINT(readability-identifier-naming) +hipdnnSeverity_t g_logLevel = HIPDNN_SEV_INFO; // NOLINT(readability-identifier-naming) + +#define CONFIG_BUILTIN_LOG(severity, ...) \ + HIPDNN_BUILTIN_HEURISTIC_LOG( \ + g_loggingCallback, g_logLevel, severity, "[BuiltInConfig] ", __VA_ARGS__) + +int64_t policyId() +{ + static const int64_t s_id = hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME); + return s_id; +} + +// ---- Graph-walk helpers (lifted from the former PreferredEngineResolver) --- + +using hipdnn_flatbuffers_sdk::data_objects::Graph; + +std::vector toVector(const flatbuffers::Vector* fb) +{ + if(fb == nullptr) + { + return {}; + } + return {fb->begin(), fb->end()}; +} + +struct TensorDimsStrides +{ + std::vector dims; + std::vector strides; +}; + +std::unordered_map indexTensorsByUid(const Graph* graph) +{ + std::unordered_map out; + const auto* tensors = graph->tensors(); + if(tensors == nullptr) + { + return out; + } + out.reserve(tensors->size()); + for(const auto* t : *tensors) + { + if(t == nullptr) + { + continue; + } + out.emplace(t->uid(), TensorDimsStrides{toVector(t->dims()), toVector(t->strides())}); + } + return out; +} + +std::optional + matchOverrideConfig(const EngineOverrideConfig& config, + const Graph* graph, + const std::unordered_map& tensorIndex) +{ + const auto* nodes = graph->nodes(); + if(nodes == nullptr) + { + return std::nullopt; + } + + auto viewFor = [&](int64_t uid) -> const TensorDimsStrides* { + auto it = tensorIndex.find(uid); + return it == tensorIndex.end() ? nullptr : &it->second; + }; + + auto buildView = [&](const TensorDimsStrides* t) { return TensorView{&t->dims, &t->strides}; }; + + for(const auto* node : *nodes) + { + if(node == nullptr) + { + continue; + } + + const char* op = nullptr; + const TensorDimsStrides* a = nullptr; + const TensorDimsStrides* b = nullptr; + + if(const auto* fwd = node->attributes_as_ConvolutionFwdAttributes()) + { + op = "conv_fprop"; + a = viewFor(fwd->x_tensor_uid()); + b = viewFor(fwd->w_tensor_uid()); + } + else if(const auto* bwd = node->attributes_as_ConvolutionBwdAttributes()) + { + op = "conv_dgrad"; + a = viewFor(bwd->dy_tensor_uid()); + b = viewFor(bwd->w_tensor_uid()); + } + else if(const auto* wrw = node->attributes_as_ConvolutionWrwAttributes()) + { + op = "conv_wgrad"; + a = viewFor(wrw->x_tensor_uid()); + b = viewFor(wrw->dy_tensor_uid()); + } + + if(op == nullptr || a == nullptr || b == nullptr) + { + continue; + } + + const std::vector views{buildView(a), buildView(b)}; + auto match = config.matchOperation(op, views); + if(match.has_value()) + { + return match; + } + } + return std::nullopt; +} + +/// Validate the buffer and return the typed Graph root, or nullptr on failure. +const Graph* parseGraphBuffer(const std::vector& buffer) +{ + if(buffer.empty()) + { + return nullptr; + } + flatbuffers::Verifier verifier(buffer.data(), buffer.size()); + if(!verifier.VerifyBuffer()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_WARN, "policyFinalize: invalid serialized graph buffer"); + return nullptr; + } + return hipdnn_flatbuffers_sdk::data_objects::GetGraph(buffer.data()); +} + +/// Reorder @p candidates so @p preferredEngineId comes first, preserving the +/// relative order of the rest. Returns nullopt if the preferred id is not in +/// the candidate list. +std::optional> + reorderWithPreferredFirst(const std::vector& candidates, int64_t preferredEngineId) +{ + auto it = std::find(candidates.begin(), candidates.end(), preferredEngineId); + if(it == candidates.end()) + { + return std::nullopt; + } + std::vector reordered; + reordered.reserve(candidates.size()); + reordered.push_back(preferredEngineId); + for(const int64_t engineId : candidates) + { + if(engineId != preferredEngineId) + { + reordered.push_back(engineId); + } + } + return reordered; +} + +// ---- Per-handle / per-descriptor state ------------------------------------- + +struct Handle +{ + std::vector devicePropertiesBuffer; + bool devicePropertiesSet = false; +}; + +struct PolicyDescriptor +{ + Handle* handle = nullptr; + std::vector candidateEngineIds; + std::vector serializedGraph; + std::vector sortedEngineIds; + bool finalized = false; + + explicit PolicyDescriptor(Handle* h) + : handle(h) + { + } +}; + +// ---- Base plugin metadata -------------------------------------------------- + +hipdnnPluginStatus_t getName(const char** name) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(name, CONFIG_BUILTIN_LOG, "getName: null output pointer"); + *name = PLUGIN_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getVersion(const char** version) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(version, CONFIG_BUILTIN_LOG, "getVersion: null output pointer"); + *version = PLUGIN_VERSION; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getApiVersion(const char** version) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + version, CONFIG_BUILTIN_LOG, "getApiVersion: null output pointer"); + *version = HIPDNN_HEURISTIC_API_VERSION; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getType(hipdnnPluginType_t* type) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(type, CONFIG_BUILTIN_LOG, "getType: null output pointer"); + *type = HIPDNN_PLUGIN_TYPE_HEURISTIC; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t setLoggingCallback(hipdnnCallback_t callback) +{ + g_loggingCallback = callback; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t setLogLevel(hipdnnSeverity_t level) +{ + g_logLevel = level; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +void getLastErrorString(const char** errorStr) +{ + if(errorStr == nullptr) + { + return; + } + *errorStr = "No error information available"; +} + +// ---- Policy enumeration ---------------------------------------------------- + +hipdnnPluginStatus_t + getAllPolicyIds(int64_t* policyIds, uint32_t maxPolicies, uint32_t* numPolicies) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + numPolicies, CONFIG_BUILTIN_LOG, "getAllPolicyIds: null num_policies"); + + constexpr uint32_t TOTAL_POLICIES = 1; + *numPolicies = TOTAL_POLICIES; + if(policyIds == nullptr || maxPolicies == 0) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if(maxPolicies < TOTAL_POLICIES) + { + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + policyIds[0] = policyId(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getPolicyName(int64_t id, const char** name) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(name, CONFIG_BUILTIN_LOG, "getPolicyName: null output pointer"); + if(id != policyId()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "getPolicyName: unknown policy ID"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + *name = POLICY_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---- Handle lifecycle ------------------------------------------------------ + +hipdnnPluginStatus_t handleCreate(hipdnnHeuristicHandle_t* outHandle) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outHandle, CONFIG_BUILTIN_LOG, "handleCreate: null output pointer"); + try + { + auto h = std::make_unique(); + *outHandle = reinterpret_cast(h.release()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "handleCreate failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t handleDestroy(hipdnnHeuristicHandle_t handle) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(handle, CONFIG_BUILTIN_LOG, "handleDestroy: null handle"); + delete reinterpret_cast(handle); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t handleSetDeviceProperties(hipdnnHeuristicHandle_t handle, + const hipdnnPluginConstData_t* devicePropsSerialized) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + handle, CONFIG_BUILTIN_LOG, "handleSetDeviceProperties: null handle"); + HIPDNN_PLUGIN_REQUIRE_CONST_DATA(devicePropsSerialized, + true, + CONFIG_BUILTIN_LOG, + "handleSetDeviceProperties: invalid buffer"); + try + { + auto* h = reinterpret_cast(handle); + const auto* data = reinterpret_cast(devicePropsSerialized->ptr); + h->devicePropertiesBuffer.assign(data, data + devicePropsSerialized->size); + h->devicePropertiesSet = true; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "handleSetDeviceProperties failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +// ---- Policy descriptor lifecycle ------------------------------------------ + +hipdnnPluginStatus_t policyDescriptorCreate(hipdnnHeuristicHandle_t pluginHandle, + int64_t id, + hipdnnHeuristicPolicyDescriptor_t* outDesc) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + pluginHandle, CONFIG_BUILTIN_LOG, "policyDescriptorCreate: null handle"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outDesc, CONFIG_BUILTIN_LOG, "policyDescriptorCreate: null output pointer"); + if(id != policyId()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policyDescriptorCreate: unknown policy ID"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + try + { + auto desc = std::make_unique(reinterpret_cast(pluginHandle)); + *outDesc = reinterpret_cast(desc.release()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policyDescriptorCreate failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policyDescriptorDestroy(hipdnnHeuristicPolicyDescriptor_t desc) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, CONFIG_BUILTIN_LOG, "policyDescriptorDestroy: null descriptor"); + delete reinterpret_cast(desc); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---- Policy inputs --------------------------------------------------------- + +hipdnnPluginStatus_t policySetEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, + const int64_t* engineIds, + size_t engineIdCount) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(desc, CONFIG_BUILTIN_LOG, "policySetEngineIds: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_ARRAY(engineIds, + engineIdCount, + CONFIG_BUILTIN_LOG, + "policySetEngineIds: null engine_ids with count > 0"); + try + { + auto* d = reinterpret_cast(desc); + d->candidateEngineIds.assign(engineIds, engineIds + engineIdCount); + d->finalized = false; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policySetEngineIds failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policySetSerializedGraph(hipdnnHeuristicPolicyDescriptor_t desc, + const hipdnnPluginConstData_t* serializedGraph) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, CONFIG_BUILTIN_LOG, "policySetSerializedGraph: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_CONST_DATA(serializedGraph, + false, + CONFIG_BUILTIN_LOG, + "policySetSerializedGraph: invalid graph buffer"); + try + { + auto* d = reinterpret_cast(desc); + const auto* bytes = reinterpret_cast(serializedGraph->ptr); + if(bytes == nullptr || serializedGraph->size == 0) + { + d->serializedGraph.clear(); + } + else + { + d->serializedGraph.assign(bytes, bytes + serializedGraph->size); + } + d->finalized = false; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policySetSerializedGraph failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +// ---- Selection ------------------------------------------------------------- + +hipdnnPluginStatus_t policyFinalize(hipdnnHeuristicPolicyDescriptor_t desc, int32_t* outApplied) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(desc, CONFIG_BUILTIN_LOG, "policyFinalize: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outApplied, CONFIG_BUILTIN_LOG, "policyFinalize: null output pointer"); + try + { + auto* d = reinterpret_cast(desc); + *outApplied = 0; + + if(d->candidateEngineIds.empty()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_INFO, "policyFinalize: no candidate engines; declining"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + const auto config = EngineOverrideConfig::loadFromEnv(); + if(!config.has_value()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_INFO, + "policyFinalize: HIPDNN_HEUR_CONFIG_PATH unset or " + "unreadable; declining"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + const Graph* graph = parseGraphBuffer(d->serializedGraph); + if(graph == nullptr) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + const auto preferredEngineId + = matchOverrideConfig(*config, graph, indexTensorsByUid(graph)); + if(!preferredEngineId.has_value()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_INFO, + "policyFinalize: no rule matched any conv node; declining"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + auto reordered = reorderWithPreferredFirst(d->candidateEngineIds, *preferredEngineId); + if(!reordered.has_value()) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_INFO, + "policyFinalize: matched engine 0x%llx not in candidates; declining", + static_cast(*preferredEngineId)); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + CONFIG_BUILTIN_LOG(HIPDNN_SEV_INFO, + "policyFinalize: reordered %zu engines with preferred 0x%llx first", + reordered->size(), + static_cast(*preferredEngineId)); + d->sortedEngineIds = std::move(*reordered); + d->finalized = true; + *outApplied = 1; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policyFinalize failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policyGetSortedEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, + int64_t* engineIds, + size_t* numEngines) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, CONFIG_BUILTIN_LOG, "policyGetSortedEngineIds: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + numEngines, CONFIG_BUILTIN_LOG, "policyGetSortedEngineIds: null num_engines pointer"); + try + { + auto* d = reinterpret_cast(desc); + if(!d->finalized) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, + "policyGetSortedEngineIds: descriptor not finalized"); + return HIPDNN_PLUGIN_STATUS_NOT_INITIALIZED; + } + if(engineIds == nullptr) + { + *numEngines = d->sortedEngineIds.size(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + *numEngines = std::min(*numEngines, d->sortedEngineIds.size()); + std::copy_n(d->sortedEngineIds.begin(), *numEngines, engineIds); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + CONFIG_BUILTIN_LOG(HIPDNN_SEV_ERROR, "policyGetSortedEngineIds failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +} // namespace + +hipdnn_backend::plugin::HeuristicPluginFunctionTable populateFunctionTable() +{ + hipdnn_backend::plugin::HeuristicPluginFunctionTable funcs{}; + funcs.getName = &getName; + funcs.getVersion = &getVersion; + funcs.getApiVersion = &getApiVersion; + funcs.getType = &getType; + funcs.setLoggingCallback = &setLoggingCallback; + funcs.setLogLevel = &setLogLevel; + funcs.getLastErrorString = &getLastErrorString; + funcs.getAllPolicyIds = &getAllPolicyIds; + funcs.getPolicyName = &getPolicyName; + funcs.handleCreate = &handleCreate; + funcs.handleDestroy = &handleDestroy; + funcs.handleSetDeviceProperties = &handleSetDeviceProperties; + funcs.policyDescriptorCreate = &policyDescriptorCreate; + funcs.policyDescriptorDestroy = &policyDescriptorDestroy; + funcs.policySetEngineIds = &policySetEngineIds; + funcs.policySetSerializedGraph = &policySetSerializedGraph; + funcs.policyFinalize = &policyFinalize; + funcs.policyGetSortedEngineIds = &policyGetSortedEngineIds; + return funcs; +} + +} // namespace hipdnn_backend::heuristics::config diff --git a/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.hpp b/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.hpp new file mode 100644 index 00000000000..8c3d4ea8afd --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/config/ConfigBuiltIn.hpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "plugin/HeuristicPlugin.hpp" + +namespace hipdnn_backend::heuristics::config +{ + +// Build a fully-populated heuristic plugin function table that exposes the +// SelectionHeuristic::Config policy as a backend built-in. The policy reads +// HIPDNN_HEUR_CONFIG_PATH, parses an EngineOverrideConfig JSON file, +// walks conv nodes in the serialized graph and reorders the candidate engine +// IDs so the rule-matched engine sits first. The policy declines (outApplied +// = 0) when the env var is unset, the file fails to load, no rule matches, +// or the matched engine is not in the candidate list — in those cases the +// outer policy loop continues to the next plugin. +hipdnn_backend::plugin::HeuristicPluginFunctionTable populateFunctionTable(); + +} // namespace hipdnn_backend::heuristics::config diff --git a/projects/hipdnn/backend/src/heuristics/config/EngineOverrideConfig.hpp b/projects/hipdnn/backend/src/heuristics/config/EngineOverrideConfig.hpp new file mode 100644 index 00000000000..640ed5b8759 --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/config/EngineOverrideConfig.hpp @@ -0,0 +1,339 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace hipdnn_backend::heuristics::config +{ + +/// Dimension value meaning "match any value in this slot". +inline constexpr int64_t WILDCARD_DIM = -1; + +/// View into one tensor: pointers to the live dim and stride vectors. +/// The matcher does not own this data; callers must keep the underlying +/// vectors alive for the duration of the match call. +struct TensorView +{ + const std::vector* dim; + const std::vector* stride; +}; + +/// Pattern for a single tensor: a list of expected dimensions and optional strides, +/// with -1 as a per-slot wildcard. When `stride` is empty no stride matching is +/// performed. +struct TensorPattern +{ + std::vector dim; + std::vector stride; + + bool matches(const TensorView& tensor) const + { + const auto& tdim = *tensor.dim; + if(dim.size() != tdim.size()) + { + return false; + } + for(size_t i = 0; i < dim.size(); ++i) + { + if(dim[i] != WILDCARD_DIM && dim[i] != tdim[i]) + { + return false; + } + } + if(!stride.empty()) + { + const auto& tstride = *tensor.stride; + if(stride.size() != tstride.size()) + { + return false; + } + for(size_t i = 0; i < stride.size(); ++i) + { + if(stride[i] != WILDCARD_DIM && stride[i] != tstride[i]) + { + return false; + } + } + } + return true; + } +}; + +/// A single engine-override rule (one operation, one engine, ordered tensor patterns). +struct OperationRule +{ + std::string op; + std::string engineName; + std::vector tensors; + + bool matches(const std::vector& inputs) const + { + if(tensors.size() != inputs.size()) + { + return false; + } + for(size_t i = 0; i < tensors.size(); ++i) + { + if(!tensors[i].matches(inputs[i])) + { + return false; + } + } + return true; + } +}; + +namespace detail +{ + +/// FNV-1a hash over a flat vector key. +struct DimKeyHash +{ + size_t operator()(const std::vector& key) const noexcept + { + size_t h = 14695981039346656037ULL; + for(int64_t v : key) + { + const auto* p = reinterpret_cast(&v); + for(size_t b = 0; b < sizeof(int64_t); ++b) + { + h ^= static_cast(p[b]); + h *= 1099511628211ULL; + } + } + return h; + } +}; + +} // namespace detail + +/// Loaded set of engine-override rules (process-lifetime cache around +/// HIPDNN_HEUR_CONFIG_PATH). Rules are evaluated in declaration order; +/// first match wins. Internally split per-op into an exact hash bucket and +/// an order-preserving wildcard list, reconciled by declaration index. +class EngineOverrideConfig +{ +public: + EngineOverrideConfig() = default; + + explicit EngineOverrideConfig(std::vector rules) + { + for(size_t i = 0; i < rules.size(); ++i) + { + indexRule(std::move(rules[i]), i); + } + } + + static std::optional load(const std::string& filepath) + { + std::ifstream file(filepath); + if(!file.is_open()) + { + return std::nullopt; + } + try + { + return parseJson(nlohmann::json::parse(file)); + } + catch(const nlohmann::json::exception&) + { + return std::nullopt; + } + } + + static std::optional loadFromContent(const std::string& content) + { + try + { + return parseJson(nlohmann::json::parse(content)); + } + catch(const nlohmann::json::exception&) + { + return std::nullopt; + } + } + + /// Read HIPDNN_HEUR_CONFIG_PATH and load the referenced config. + /// Returns nullopt when the variable is unset / empty / the file cannot + /// be opened or parsed. Called once per heuristic finalize so env changes + /// take effect without process restart and the path stays testable. + static std::optional loadFromEnv() + { + static constexpr const char* ENV_VAR = "HIPDNN_HEUR_CONFIG_PATH"; + const std::string path + = hipdnn_data_sdk::utilities::trim(hipdnn_data_sdk::utilities::getEnv(ENV_VAR, "")); + if(path.empty()) + { + return std::nullopt; + } + return load(path); + } + + /// Scan rules in declaration order; return the first matching engine ID or nullopt. + std::optional matchOperation(const std::string& op, + const std::vector& tensors) const + { + const auto opIt = _index.find(op); + if(opIt == _index.end()) + { + return std::nullopt; + } + const OpBucket& bucket = opIt->second; + + std::optional exactHit; + { + const auto key = buildDimKey(tensors); + const auto eit = bucket.exact.find(key); + if(eit != bucket.exact.end()) + { + exactHit = eit->second; + } + } + + for(const auto& entry : bucket.wildcards) + { + if(exactHit && entry.order > exactHit->order) + { + break; + } + if(entry.rule.matches(tensors)) + { + return entry.engineId; + } + } + + if(exactHit) + { + return exactHit->engineId; + } + return std::nullopt; + } + + size_t ruleCount() const + { + size_t n = 0; + for(const auto& [op, bucket] : _index) + { + n += bucket.exact.size() + bucket.wildcards.size(); + } + return n; + } + +private: + struct ExactEntry + { + int64_t engineId; + size_t order; + }; + + struct WildcardEntry + { + OperationRule rule; + int64_t engineId; + size_t order; + }; + + struct OpBucket + { + std::unordered_map, ExactEntry, detail::DimKeyHash> exact; + std::vector wildcards; + }; + + std::unordered_map _index; + + static EngineOverrideConfig parseJson(const nlohmann::json& j) + { + std::vector rules; + for(const auto& entry : j.at("engine_overrides")) + { + OperationRule rule; + rule.op = entry.at("op").get(); + rule.engineName = entry.at("engine_name").get(); + for(const auto& t : entry.at("tensors")) + { + TensorPattern pat; + pat.dim = t.at("dim").get>(); + if(t.contains("stride")) + { + pat.stride = t.at("stride").get>(); + } + rule.tensors.push_back(std::move(pat)); + } + rules.push_back(std::move(rule)); + } + return EngineOverrideConfig(std::move(rules)); + } + + static bool hasWildcard(const std::vector& patterns) + { + for(const auto& p : patterns) + { + for(const int64_t d : p.dim) + { + if(d == WILDCARD_DIM) + { + return true; + } + } + if(!p.stride.empty()) + { + return true; + } + } + return false; + } + + static std::vector buildDimKey(const std::vector& patterns) + { + std::vector key; + for(const auto& p : patterns) + { + key.push_back(static_cast(p.dim.size())); + key.insert(key.end(), p.dim.begin(), p.dim.end()); + } + return key; + } + + static std::vector buildDimKey(const std::vector& tensors) + { + std::vector key; + for(const auto& t : tensors) + { + const auto& d = *t.dim; + key.push_back(static_cast(d.size())); + key.insert(key.end(), d.begin(), d.end()); + } + return key; + } + + void indexRule(OperationRule rule, size_t order) + { + const int64_t resolvedId = hipdnn_data_sdk::utilities::engineNameToId(rule.engineName); + OpBucket& bucket = _index[rule.op]; + if(hasWildcard(rule.tensors)) + { + bucket.wildcards.push_back(WildcardEntry{std::move(rule), resolvedId, order}); + } + else + { + const auto key = buildDimKey(rule.tensors); + bucket.exact.try_emplace(key, ExactEntry{resolvedId, order}); + } + } +}; + +} // namespace hipdnn_backend::heuristics::config diff --git a/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.cpp b/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.cpp new file mode 100644 index 00000000000..29f209f9dfb --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.cpp @@ -0,0 +1,475 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file StaticOrderingBuiltIn.cpp + * @brief Backend-internal implementation of the + * SelectionHeuristic::StaticOrdering policy. + * + * The policy wraps utilities::sortEngineIds — the legacy MIOPEN_ENGINE-first / + * MIOPEN_ENGINE_DETERMINISTIC-last fallback ordering — and exposes it through + * the heuristic plugin C ABI shape. Functions live in an unnamed namespace + * inside a backend translation unit: they are *not* exported as + * hipdnnHeuristic* symbols, but their signatures match the C ABI exactly so a + * HeuristicPluginFunctionTable can dispatch through them just like a loaded + * plugin's dlsym'd entry points. The wrapper layer (HeuristicPlugin) does not + * distinguish dlopen plugins from built-ins. + */ + +#include "StaticOrderingBuiltIn.hpp" + +#include "heuristics/BuiltInLogging.hpp" +#include "logging/Logging.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace hipdnn_backend::heuristics::static_ordering +{ +namespace +{ + +constexpr const char* PLUGIN_NAME = "BuiltInStaticOrderingHeuristic"; +constexpr const char* PLUGIN_VERSION = "1.0.0"; +constexpr const char* POLICY_NAME = "SelectionHeuristic::StaticOrdering"; + +// File-scope logging callback / level, set via the C-ABI-shaped +// SetLoggingCallback / SetLogLevel below. The backend supplies its own +// callback when registering the built-in (see PluginManagerBase::registerPlugin) +// so log lines from this module flow through the backend logger. +// +// Identity contract: the built-in is statically linked into the backend, so +// these globals live in the same process image as the caller. The last writer +// wins — if multiple HeuristicPluginManager instances register the built-in +// they overwrite each other's callback. This is intentional: registerPlugin() +// hands in a callback that forwards to the backend logger, which is itself a +// process-wide sink, so the identity of the "current" callback does not matter +// as long as one is installed. Do not assume per-instance scoping here. +hipdnnCallback_t g_loggingCallback = nullptr; // NOLINT(readability-identifier-naming) +hipdnnSeverity_t g_logLevel = HIPDNN_SEV_INFO; // NOLINT(readability-identifier-naming) + +#define STATIC_ORDERING_LOG(severity, ...) \ + HIPDNN_BUILTIN_HEURISTIC_LOG( \ + g_loggingCallback, g_logLevel, severity, "[BuiltInStaticOrdering] ", __VA_ARGS__) + +int64_t policyId() +{ + static const int64_t s_id = hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME); + return s_id; +} + +constexpr const char* FALLBACK_ORDERING_ENV = "HIPDNN_HEUR_FALLBACK_ENGINE_ORDER"; + +/// Parse HIPDNN_HEUR_FALLBACK_ENGINE_ORDER (comma-separated engine names) into a +/// list of engine IDs in the order the user wrote them. Empty / unset env → +/// empty vector (caller falls back to the legacy sortEngineIds ordering). +/// Blank tokens are skipped; unknown engine names hash to a deterministic ID +/// via engineNameToId — the caller filters against the candidate list, so a +/// typo'd name simply won't match anything. +std::vector parseFallbackOrderingEnv() +{ + const std::string raw = hipdnn_data_sdk::utilities::getEnv(FALLBACK_ORDERING_ENV, ""); + if(hipdnn_data_sdk::utilities::trim(raw).empty()) + { + return {}; + } + + std::vector ids; + std::stringstream stream(raw); + std::string token; + while(std::getline(stream, token, ',')) + { + const std::string name = hipdnn_data_sdk::utilities::trim(token); + if(name.empty()) + { + continue; + } + ids.push_back(hipdnn_data_sdk::utilities::engineNameToId(name)); + } + return ids; +} + +/// Restrict @p candidates to engines named in @p envOrder, preserving the env +/// order. Engines not listed in the env are dropped — when the operator sets +/// HIPDNN_HEUR_FALLBACK_ENGINE_ORDER they are explicitly opting out of every +/// other engine. Names in the env that are not in @p candidates are silently +/// skipped (the policy loop only sees engines the rest of the stack already +/// filtered down to). +std::vector applyFallbackOrdering(const std::vector& candidates, + const std::vector& envOrder) +{ + const std::unordered_set candidateSet(candidates.begin(), candidates.end()); + std::vector out; + out.reserve(envOrder.size()); + for(const int64_t id : envOrder) + { + if(candidateSet.count(id) != 0U) + { + out.push_back(id); + } + } + return out; +} + +// Per-handle state. StaticOrdering does not consume device properties, but the +// C ABI requires a working SetDeviceProperties entry point. +struct Handle +{ + std::vector devicePropertiesBuffer; + bool devicePropertiesSet = false; +}; + +// Per-policy-descriptor state. +struct PolicyDescriptor +{ + Handle* handle = nullptr; + std::vector candidateEngineIds; + std::vector sortedEngineIds; + bool finalized = false; + + explicit PolicyDescriptor(Handle* h) + : handle(h) + { + } +}; + +// ---- Base plugin metadata -------------------------------------------------- + +hipdnnPluginStatus_t getName(const char** name) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(name, STATIC_ORDERING_LOG, "getName: null output pointer"); + *name = PLUGIN_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getVersion(const char** version) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(version, STATIC_ORDERING_LOG, "getVersion: null output pointer"); + *version = PLUGIN_VERSION; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getApiVersion(const char** version) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + version, STATIC_ORDERING_LOG, "getApiVersion: null output pointer"); + *version = HIPDNN_HEURISTIC_API_VERSION; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getType(hipdnnPluginType_t* type) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(type, STATIC_ORDERING_LOG, "getType: null output pointer"); + *type = HIPDNN_PLUGIN_TYPE_HEURISTIC; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t setLoggingCallback(hipdnnCallback_t callback) +{ + g_loggingCallback = callback; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t setLogLevel(hipdnnSeverity_t level) +{ + g_logLevel = level; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +void getLastErrorString(const char** errorStr) +{ + if(errorStr == nullptr) + { + return; + } + *errorStr = "No error information available"; +} + +// ---- Policy enumeration ---------------------------------------------------- + +hipdnnPluginStatus_t + getAllPolicyIds(int64_t* policyIds, uint32_t maxPolicies, uint32_t* numPolicies) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + numPolicies, STATIC_ORDERING_LOG, "getAllPolicyIds: null num_policies"); + + constexpr uint32_t TOTAL_POLICIES = 1; + *numPolicies = TOTAL_POLICIES; + if(policyIds == nullptr || maxPolicies == 0) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if(maxPolicies < TOTAL_POLICIES) + { + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + policyIds[0] = policyId(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t getPolicyName(int64_t id, const char** name) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(name, STATIC_ORDERING_LOG, "getPolicyName: null output pointer"); + if(id != policyId()) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "getPolicyName: unknown policy ID"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + *name = POLICY_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---- Handle lifecycle ------------------------------------------------------ + +hipdnnPluginStatus_t handleCreate(hipdnnHeuristicHandle_t* outHandle) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outHandle, STATIC_ORDERING_LOG, "handleCreate: null output pointer"); + try + { + auto h = std::make_unique(); + *outHandle = reinterpret_cast(h.release()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "handleCreate failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t handleDestroy(hipdnnHeuristicHandle_t handle) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(handle, STATIC_ORDERING_LOG, "handleDestroy: null handle"); + delete reinterpret_cast(handle); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t handleSetDeviceProperties(hipdnnHeuristicHandle_t handle, + const hipdnnPluginConstData_t* devicePropsSerialized) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + handle, STATIC_ORDERING_LOG, "handleSetDeviceProperties: null handle"); + HIPDNN_PLUGIN_REQUIRE_CONST_DATA(devicePropsSerialized, + true, + STATIC_ORDERING_LOG, + "handleSetDeviceProperties: invalid buffer"); + try + { + auto* h = reinterpret_cast(handle); + const auto* data = reinterpret_cast(devicePropsSerialized->ptr); + h->devicePropertiesBuffer.assign(data, data + devicePropsSerialized->size); + h->devicePropertiesSet = true; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "handleSetDeviceProperties failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +// ---- Policy descriptor lifecycle ------------------------------------------ + +hipdnnPluginStatus_t policyDescriptorCreate(hipdnnHeuristicHandle_t pluginHandle, + int64_t id, + hipdnnHeuristicPolicyDescriptor_t* outDesc) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + pluginHandle, STATIC_ORDERING_LOG, "policyDescriptorCreate: null handle"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outDesc, STATIC_ORDERING_LOG, "policyDescriptorCreate: null output pointer"); + if(id != policyId()) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "policyDescriptorCreate: unknown policy ID"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + try + { + auto desc = std::make_unique(reinterpret_cast(pluginHandle)); + *outDesc = reinterpret_cast(desc.release()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "policyDescriptorCreate failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policyDescriptorDestroy(hipdnnHeuristicPolicyDescriptor_t desc) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, STATIC_ORDERING_LOG, "policyDescriptorDestroy: null descriptor"); + delete reinterpret_cast(desc); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---- Policy inputs --------------------------------------------------------- + +hipdnnPluginStatus_t policySetEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, + const int64_t* engineIds, + size_t engineIdCount) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, STATIC_ORDERING_LOG, "policySetEngineIds: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_ARRAY(engineIds, + engineIdCount, + STATIC_ORDERING_LOG, + "policySetEngineIds: null engine_ids with count > 0"); + try + { + auto* d = reinterpret_cast(desc); + d->candidateEngineIds.assign(engineIds, engineIds + engineIdCount); + d->finalized = false; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "policySetEngineIds failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policySetSerializedGraph(hipdnnHeuristicPolicyDescriptor_t desc, + const hipdnnPluginConstData_t* serializedGraph) +{ + // StaticOrdering ignores the serialized graph; only validate args. + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, STATIC_ORDERING_LOG, "policySetSerializedGraph: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_CONST_DATA(serializedGraph, + false, + STATIC_ORDERING_LOG, + "policySetSerializedGraph: invalid graph buffer"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---- Selection ------------------------------------------------------------- + +hipdnnPluginStatus_t policyFinalize(hipdnnHeuristicPolicyDescriptor_t desc, int32_t* outApplied) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL(desc, STATIC_ORDERING_LOG, "policyFinalize: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + outApplied, STATIC_ORDERING_LOG, "policyFinalize: null output pointer"); + try + { + auto* d = reinterpret_cast(desc); + if(d->candidateEngineIds.empty()) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_WARN, "policyFinalize: no candidate engines"); + *outApplied = 0; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // HIPDNN_HEUR_FALLBACK_ENGINE_ORDER, when set, replaces the legacy + // MIOPEN-first / DETERMINISTIC-last ordering. Only engines named in + // the env are eligible — operators use this to constrain selection + // to a known-good shortlist. If the env is set but no listed engine + // is among the candidates the policy declines (outApplied = 0) so + // the policy loop can try the next plugin. + const auto envOrder = parseFallbackOrderingEnv(); + if(!envOrder.empty()) + { + d->sortedEngineIds = applyFallbackOrdering(d->candidateEngineIds, envOrder); + if(d->sortedEngineIds.empty()) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_WARN, + "policyFinalize: HIPDNN_HEUR_FALLBACK_ENGINE_ORDER listed no " + "engines that are candidates; declining."); + *outApplied = 0; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + d->finalized = true; + *outApplied = 1; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + d->sortedEngineIds = d->candidateEngineIds; + hipdnn_data_sdk::utilities::sortEngineIds(d->sortedEngineIds); + d->finalized = true; + *outApplied = 1; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "policyFinalize failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +hipdnnPluginStatus_t policyGetSortedEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, + int64_t* engineIds, + size_t* numEngines) +{ + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + desc, STATIC_ORDERING_LOG, "policyGetSortedEngineIds: null descriptor"); + HIPDNN_PLUGIN_REQUIRE_NOT_NULL( + numEngines, STATIC_ORDERING_LOG, "policyGetSortedEngineIds: null num_engines pointer"); + try + { + auto* d = reinterpret_cast(desc); + if(!d->finalized) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, + "policyGetSortedEngineIds: descriptor not finalized"); + return HIPDNN_PLUGIN_STATUS_NOT_INITIALIZED; + } + if(engineIds == nullptr) + { + *numEngines = d->sortedEngineIds.size(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + *numEngines = std::min(*numEngines, d->sortedEngineIds.size()); + std::copy_n(d->sortedEngineIds.begin(), *numEngines, engineIds); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + catch(const std::exception& e) + { + STATIC_ORDERING_LOG(HIPDNN_SEV_ERROR, "policyGetSortedEngineIds failed: %s", e.what()); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } +} + +} // namespace + +hipdnn_backend::plugin::HeuristicPluginFunctionTable populateFunctionTable() +{ + hipdnn_backend::plugin::HeuristicPluginFunctionTable funcs{}; + funcs.getName = &getName; + funcs.getVersion = &getVersion; + funcs.getApiVersion = &getApiVersion; + funcs.getType = &getType; + funcs.setLoggingCallback = &setLoggingCallback; + funcs.setLogLevel = &setLogLevel; + funcs.getLastErrorString = &getLastErrorString; + funcs.getAllPolicyIds = &getAllPolicyIds; + funcs.getPolicyName = &getPolicyName; + funcs.handleCreate = &handleCreate; + funcs.handleDestroy = &handleDestroy; + funcs.handleSetDeviceProperties = &handleSetDeviceProperties; + funcs.policyDescriptorCreate = &policyDescriptorCreate; + funcs.policyDescriptorDestroy = &policyDescriptorDestroy; + funcs.policySetEngineIds = &policySetEngineIds; + funcs.policySetSerializedGraph = &policySetSerializedGraph; + funcs.policyFinalize = &policyFinalize; + funcs.policyGetSortedEngineIds = &policyGetSortedEngineIds; + return funcs; +} + +} // namespace hipdnn_backend::heuristics::static_ordering diff --git a/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.hpp b/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.hpp new file mode 100644 index 00000000000..f3b9957a48a --- /dev/null +++ b/projects/hipdnn/backend/src/heuristics/static_ordering/StaticOrderingBuiltIn.hpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "plugin/HeuristicPlugin.hpp" + +namespace hipdnn_backend::heuristics::static_ordering +{ + +// Build a fully-populated heuristic plugin function table that exposes the +// SelectionHeuristic::StaticOrdering policy as a backend built-in. The table +// references file-static functions in the matching .cpp; no symbols are +// exported and no shared library is involved. The built-in is registered with +// HeuristicPluginManager via HeuristicPlugin::createBuiltIn at construction +// time. +hipdnn_backend::plugin::HeuristicPluginFunctionTable populateFunctionTable(); + +} // namespace hipdnn_backend::heuristics::static_ordering diff --git a/projects/hipdnn/backend/src/plugin/EnginePluginManager.hpp b/projects/hipdnn/backend/src/plugin/EnginePluginManager.hpp index 67c04ba6ab6..dfb2e164dcc 100644 --- a/projects/hipdnn/backend/src/plugin/EnginePluginManager.hpp +++ b/projects/hipdnn/backend/src/plugin/EnginePluginManager.hpp @@ -75,6 +75,11 @@ class EnginePluginManager : public PluginManagerBase _engineIds.insert(engineIds.begin(), engineIds.end()); } + void actionAfterClearing() override + { + _engineIds.clear(); + } + std::set _engineIds; }; diff --git a/projects/hipdnn/backend/src/plugin/HeuristicPlugin.cpp b/projects/hipdnn/backend/src/plugin/HeuristicPlugin.cpp index 8969461e74a..e1e32cf4686 100644 --- a/projects/hipdnn/backend/src/plugin/HeuristicPlugin.cpp +++ b/projects/hipdnn/backend/src/plugin/HeuristicPlugin.cpp @@ -5,15 +5,45 @@ #include "HipdnnException.hpp" #include "logging/Logging.hpp" -#include +#include + +#include +#include namespace hipdnn_backend::plugin { +namespace +{ + +// std::string_view{nullptr} is UB. Plugin code is untrusted: an out-param +// const char** may be left null on a "successful" status return. Funnel every +// plugin-supplied C string through this helper before constructing a view. +std::string_view safeStringView(const char* str) noexcept +{ + return (str != nullptr) ? std::string_view{str} : std::string_view{}; +} + +} // anonymous namespace + HeuristicPlugin::HeuristicPlugin(SharedLibrary&& lib) : _lib(std::move(lib)) + , _sourceLabel(_lib.libraryPath().string()) { resolveSymbols(); + validateFunctionTable(); + validatePluginMetadata(*this); +#ifndef NDEBUG + _initialized = true; +#endif +} + +HeuristicPlugin::HeuristicPlugin(HeuristicPluginFunctionTable funcs, std::string sourceLabel) + : _sourceLabel(std::move(sourceLabel)) + , _funcs(funcs) +{ + validateFunctionTable(); + validatePluginMetadata(*this); #ifndef NDEBUG _initialized = true; #endif @@ -21,6 +51,17 @@ HeuristicPlugin::HeuristicPlugin(SharedLibrary&& lib) HeuristicPlugin::HeuristicPlugin() = default; +std::shared_ptr HeuristicPlugin::createBuiltIn(HeuristicPluginFunctionTable funcs, + std::string sourceLabel) +{ + return std::shared_ptr(new HeuristicPlugin(funcs, std::move(sourceLabel))); +} + +std::string_view HeuristicPlugin::sourceLabel() const noexcept +{ + return _sourceLabel; +} + void HeuristicPlugin::resolveSymbols() { // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) @@ -36,9 +77,8 @@ void HeuristicPlugin::resolveSymbols() { \ throw HipdnnException( \ HIPDNN_STATUS_PLUGIN_ERROR, \ - std::string("ERROR: HEURISTIC PLUGIN ABI INCOMPLETE\n") \ - + "Plugin: " + _lib.libraryPath().string() + "\n" \ - + "Missing required symbol: " symbolName "\n" \ + std::string("ERROR: HEURISTIC PLUGIN ABI INCOMPLETE\n") + "Plugin: " \ + + _sourceLabel + "\n" + "Missing required symbol: " symbolName "\n" \ + "This plugin does not implement the complete heuristic plugin C ABI.\n" \ + "See plugin_sdk/include/hipdnn_plugin_sdk/HeuristicsPluginApi.h for the " \ "full API.\n" \ @@ -48,37 +88,87 @@ void HeuristicPlugin::resolveSymbols() // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage) // Required base plugin symbols (from PluginApi.h) - GET_REQUIRED_SYMBOL(_funcGetName, "hipdnnPluginGetName"); - GET_REQUIRED_SYMBOL(_funcGetVersion, "hipdnnPluginGetVersion"); - GET_REQUIRED_SYMBOL(_funcGetApiVersion, "hipdnnPluginGetApiVersion"); - GET_REQUIRED_SYMBOL(_funcGetType, "hipdnnPluginGetType"); - GET_REQUIRED_SYMBOL(_funcSetLoggingCallback, "hipdnnPluginSetLoggingCallback"); - GET_REQUIRED_SYMBOL(_funcGetLastErrorString, "hipdnnPluginGetLastErrorString"); + GET_REQUIRED_SYMBOL(_funcs.getName, "hipdnnPluginGetName"); + GET_REQUIRED_SYMBOL(_funcs.getVersion, "hipdnnPluginGetVersion"); + GET_REQUIRED_SYMBOL(_funcs.getApiVersion, "hipdnnPluginGetApiVersion"); + GET_REQUIRED_SYMBOL(_funcs.getType, "hipdnnPluginGetType"); + GET_REQUIRED_SYMBOL(_funcs.setLoggingCallback, "hipdnnPluginSetLoggingCallback"); + GET_REQUIRED_SYMBOL(_funcs.getLastErrorString, "hipdnnPluginGetLastErrorString"); // Optional base plugin symbols - tryAssignSymbol(_funcSetLogLevel, "hipdnnPluginSetLogLevel"); + tryAssignSymbol(_funcs.setLogLevel, "hipdnnPluginSetLogLevel"); + + // Required policy enumeration symbols + GET_REQUIRED_SYMBOL(_funcs.getAllPolicyIds, "hipdnnHeuristicPluginGetAllPolicyIds"); + GET_REQUIRED_SYMBOL(_funcs.getPolicyName, "hipdnnHeuristicPluginGetPolicyName"); // Required handle lifecycle symbols - GET_REQUIRED_SYMBOL(_funcHandleCreate, "hipdnnHeuristicHandleCreate"); - GET_REQUIRED_SYMBOL(_funcHandleDestroy, "hipdnnHeuristicHandleDestroy"); - GET_REQUIRED_SYMBOL(_funcHandleSetDeviceProperties, "hipdnnHeuristicHandleSetDeviceProperties"); + GET_REQUIRED_SYMBOL(_funcs.handleCreate, "hipdnnHeuristicHandleCreate"); + GET_REQUIRED_SYMBOL(_funcs.handleDestroy, "hipdnnHeuristicHandleDestroy"); + GET_REQUIRED_SYMBOL(_funcs.handleSetDeviceProperties, + "hipdnnHeuristicHandleSetDeviceProperties"); // Required policy descriptor lifecycle symbols - GET_REQUIRED_SYMBOL(_funcPolicyDescriptorCreate, "hipdnnHeuristicPolicyDescriptorCreate"); - GET_REQUIRED_SYMBOL(_funcPolicyDescriptorDestroy, "hipdnnHeuristicPolicyDescriptorDestroy"); + GET_REQUIRED_SYMBOL(_funcs.policyDescriptorCreate, "hipdnnHeuristicPolicyDescriptorCreate"); + GET_REQUIRED_SYMBOL(_funcs.policyDescriptorDestroy, "hipdnnHeuristicPolicyDescriptorDestroy"); // Required policy input symbols - GET_REQUIRED_SYMBOL(_funcPolicySetEngineIds, "hipdnnHeuristicPolicySetEngineIds"); - GET_REQUIRED_SYMBOL(_funcPolicySetSerializedGraph, "hipdnnHeuristicPolicySetSerializedGraph"); + GET_REQUIRED_SYMBOL(_funcs.policySetEngineIds, "hipdnnHeuristicPolicySetEngineIds"); + GET_REQUIRED_SYMBOL(_funcs.policySetSerializedGraph, "hipdnnHeuristicPolicySetSerializedGraph"); // Required selection symbols - GET_REQUIRED_SYMBOL(_funcPolicyFinalize, "hipdnnHeuristicPolicyFinalize"); - GET_REQUIRED_SYMBOL(_funcPolicyGetSortedEngineIds, "hipdnnHeuristicPolicyGetSortedEngineIds"); + GET_REQUIRED_SYMBOL(_funcs.policyFinalize, "hipdnnHeuristicPolicyFinalize"); + GET_REQUIRED_SYMBOL(_funcs.policyGetSortedEngineIds, "hipdnnHeuristicPolicyGetSortedEngineIds"); #undef GET_REQUIRED_SYMBOL +} + +void HeuristicPlugin::validateFunctionTable() const +{ + // Every required entry point must be populated. This mirrors the dlsym + // checks in resolveSymbols and catches built-ins that forget to wire + // something up. setLogLevel is optional and may remain null. + auto require = [&](const void* ptr, const char* name) { + if(ptr == nullptr) + { + throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, + std::string("ERROR: HEURISTIC PLUGIN ABI INCOMPLETE\n") + + "Plugin: " + _sourceLabel + "\n" + + "Missing required entry point: " + name); + } + }; + require(reinterpret_cast(_funcs.getName), "hipdnnPluginGetName"); + require(reinterpret_cast(_funcs.getVersion), "hipdnnPluginGetVersion"); + require(reinterpret_cast(_funcs.getApiVersion), "hipdnnPluginGetApiVersion"); + require(reinterpret_cast(_funcs.getType), "hipdnnPluginGetType"); + require(reinterpret_cast(_funcs.setLoggingCallback), + "hipdnnPluginSetLoggingCallback"); + require(reinterpret_cast(_funcs.getLastErrorString), + "hipdnnPluginGetLastErrorString"); + require(reinterpret_cast(_funcs.getAllPolicyIds), + "hipdnnHeuristicPluginGetAllPolicyIds"); + require(reinterpret_cast(_funcs.getPolicyName), + "hipdnnHeuristicPluginGetPolicyName"); + require(reinterpret_cast(_funcs.handleCreate), "hipdnnHeuristicHandleCreate"); + require(reinterpret_cast(_funcs.handleDestroy), "hipdnnHeuristicHandleDestroy"); + require(reinterpret_cast(_funcs.handleSetDeviceProperties), + "hipdnnHeuristicHandleSetDeviceProperties"); + require(reinterpret_cast(_funcs.policyDescriptorCreate), + "hipdnnHeuristicPolicyDescriptorCreate"); + require(reinterpret_cast(_funcs.policyDescriptorDestroy), + "hipdnnHeuristicPolicyDescriptorDestroy"); + require(reinterpret_cast(_funcs.policySetEngineIds), + "hipdnnHeuristicPolicySetEngineIds"); + require(reinterpret_cast(_funcs.policySetSerializedGraph), + "hipdnnHeuristicPolicySetSerializedGraph"); + require(reinterpret_cast(_funcs.policyFinalize), "hipdnnHeuristicPolicyFinalize"); + require(reinterpret_cast(_funcs.policyGetSortedEngineIds), + "hipdnnHeuristicPolicyGetSortedEngineIds"); +} - // Verify plugin type - auto pluginType = type(); +void HeuristicPlugin::validatePluginMetadata(const HeuristicPlugin& plugin) +{ + auto pluginType = plugin.type(); if(pluginType != HIPDNN_PLUGIN_TYPE_HEURISTIC) { throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, @@ -86,96 +176,171 @@ void HeuristicPlugin::resolveSymbols() + std::to_string(pluginType)); } - // Eagerly cache policy ID - it's always needed for plugin matching - // Compute it once during initialization rather than lazily - auto pluginName = name(); - if(pluginName.empty()) + // Verify the plugin reports a non-empty library name (used purely for diagnostics now; + // policy identity flows through the policy IDs enumerated below). + if(plugin.name().empty()) { throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, - "Cannot load heuristic plugin: policy name is empty"); + "Cannot load heuristic plugin: plugin name is empty"); + } + + // Eagerly enumerate policies and validate that each policy ID matches the FNV-1a hash of + // its canonical name. Mismatches indicate a malformed plugin and cause rejection at load. + const auto policyIds = plugin.getAllPolicyIds(); + for(const int64_t policyId : policyIds) + { + const auto policyName = plugin.getPolicyName(policyId); + if(policyName.empty()) + { + throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, + "Heuristic plugin returned empty name for policy ID " + + std::to_string(policyId)); + } + const int64_t expectedId + = hipdnn_data_sdk::utilities::policyNameToId(std::string(policyName)); + if(expectedId != policyId) + { + throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, + "Policy ID/name mismatch: plugin reported policy '" + + std::string(policyName) + "' with ID " + + std::to_string(policyId) + " but policyNameToId yields " + + std::to_string(expectedId)); + } } - _policyId = hipdnn_data_sdk::utilities::engineNameToId(pluginName); } std::string_view HeuristicPlugin::apiVersion() const { const char* version = nullptr; - invokeHeuristicFunction("get API version", _funcGetApiVersion, &version); - return version; + invokeHeuristicFunction("get API version", _funcs.getApiVersion, &version); + return safeStringView(version); } std::string_view HeuristicPlugin::name() const { const char* name = nullptr; - invokeHeuristicFunction("get plugin name", _funcGetName, &name); - return (name != nullptr) ? name : ""; + invokeHeuristicFunction("get plugin name", _funcs.getName, &name); + return safeStringView(name); } std::string_view HeuristicPlugin::version() const { const char* version = nullptr; - invokeHeuristicFunction("get plugin version", _funcGetVersion, &version); - return version; + invokeHeuristicFunction("get plugin version", _funcs.getVersion, &version); + return safeStringView(version); } hipdnnPluginType_t HeuristicPlugin::type() const { hipdnnPluginType_t pluginType = HIPDNN_PLUGIN_TYPE_UNSPECIFIED; - invokeHeuristicFunction("get plugin type", _funcGetType, &pluginType); + invokeHeuristicFunction("get plugin type", _funcs.getType, &pluginType); return pluginType; } -int64_t HeuristicPlugin::policyId() const +std::vector HeuristicPlugin::getAllPolicyIds() const +{ + if(!_allPolicyIds.empty()) + { + return _allPolicyIds; + } + + uint32_t expectedCount = 0; + invokeHeuristicFunction( + "get number of policies", _funcs.getAllPolicyIds, nullptr, 0u, &expectedCount); + + std::vector policyIds(expectedCount); + uint32_t actualCount = expectedCount; + if(expectedCount > 0) + { + invokeHeuristicFunction("get all policy IDs", + _funcs.getAllPolicyIds, + policyIds.data(), + expectedCount, + &actualCount); + } + + validatePolicyIdsBuffer(expectedCount, actualCount, policyIds); + + _allPolicyIds = policyIds; + return policyIds; +} + +void HeuristicPlugin::validatePolicyIdsBuffer(uint32_t expectedCount, + uint32_t actualCount, + std::vector& policyIds) { - // Policy ID is eagerly cached during construction in resolveSymbols() - return _policyId; + if(expectedCount == 0) + { + throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, "No policies found in the plugin"); + } + + if(actualCount != expectedCount) + { + throw HipdnnException( + HIPDNN_STATUS_PLUGIN_ERROR, + "Number of policies returned does not match the number reported by the plugin"); + } + + std::sort(policyIds.begin(), policyIds.end()); + if(std::adjacent_find(policyIds.begin(), policyIds.end()) != policyIds.end()) + { + throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, "Duplicate policy IDs found"); + } +} + +std::string_view HeuristicPlugin::getPolicyName(int64_t policyId) const +{ + const char* name = nullptr; + invokeHeuristicFunction("get policy name", _funcs.getPolicyName, policyId, &name); + return safeStringView(name); } hipdnnPluginStatus_t HeuristicPlugin::setLoggingCallback(hipdnnCallback_t callback) const { - return _funcSetLoggingCallback(callback); + return _funcs.setLoggingCallback(callback); } hipdnnPluginStatus_t HeuristicPlugin::setLogLevel(hipdnnSeverity_t level) const { - if(_funcSetLogLevel == nullptr) + if(_funcs.setLogLevel == nullptr) { return HIPDNN_PLUGIN_STATUS_SUCCESS; // Optional function not implemented } - return _funcSetLogLevel(level); + return _funcs.setLogLevel(level); } hipdnnHeuristicHandle_t HeuristicPlugin::createHandle() const { hipdnnHeuristicHandle_t handle = nullptr; - invokeHeuristicFunction("create handle", _funcHandleCreate, &handle); + invokeHeuristicFunction("create handle", _funcs.handleCreate, &handle); return handle; } void HeuristicPlugin::destroyHandle(hipdnnHeuristicHandle_t handle) const { - invokeHeuristicFunction("destroy handle", _funcHandleDestroy, handle); + invokeHeuristicFunction("destroy handle", _funcs.handleDestroy, handle); } void HeuristicPlugin::setDeviceProperties( hipdnnHeuristicHandle_t handle, const hipdnnPluginConstData_t* devicePropsSerialized) const { invokeHeuristicFunction( - "set device properties", _funcHandleSetDeviceProperties, handle, devicePropsSerialized); + "set device properties", _funcs.handleSetDeviceProperties, handle, devicePropsSerialized); } hipdnnHeuristicPolicyDescriptor_t - HeuristicPlugin::createPolicyDescriptor(hipdnnHeuristicHandle_t pluginHandle) const + HeuristicPlugin::createPolicyDescriptor(hipdnnHeuristicHandle_t pluginHandle, + int64_t policyId) const { hipdnnHeuristicPolicyDescriptor_t desc = nullptr; invokeHeuristicFunction( - "create policy descriptor", _funcPolicyDescriptorCreate, pluginHandle, &desc); + "create policy descriptor", _funcs.policyDescriptorCreate, pluginHandle, policyId, &desc); return desc; } void HeuristicPlugin::destroyPolicyDescriptor(hipdnnHeuristicPolicyDescriptor_t desc) const { - invokeHeuristicFunction("destroy policy descriptor", _funcPolicyDescriptorDestroy, desc); + invokeHeuristicFunction("destroy policy descriptor", _funcs.policyDescriptorDestroy, desc); } void HeuristicPlugin::setEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, @@ -183,20 +348,20 @@ void HeuristicPlugin::setEngineIds(hipdnnHeuristicPolicyDescriptor_t desc, size_t engineIdCount) const { invokeHeuristicFunction( - "set engine IDs", _funcPolicySetEngineIds, desc, engineIds, engineIdCount); + "set engine IDs", _funcs.policySetEngineIds, desc, engineIds, engineIdCount); } void HeuristicPlugin::setSerializedGraph(hipdnnHeuristicPolicyDescriptor_t desc, const hipdnnPluginConstData_t* serializedGraph) const { invokeHeuristicFunction( - "set serialized graph", _funcPolicySetSerializedGraph, desc, serializedGraph); + "set serialized graph", _funcs.policySetSerializedGraph, desc, serializedGraph); } bool HeuristicPlugin::finalize(hipdnnHeuristicPolicyDescriptor_t desc) const { int32_t applied = 0; - invokeHeuristicFunction("finalize policy", _funcPolicyFinalize, desc, &applied); + invokeHeuristicFunction("finalize policy", _funcs.policyFinalize, desc, &applied); return applied != 0; } @@ -206,7 +371,7 @@ std::vector // Query the count first (pass nullptr for engine_ids) size_t count = 0; invokeHeuristicFunction( - "get sorted engine IDs count", _funcPolicyGetSortedEngineIds, desc, nullptr, &count); + "get sorted engine IDs count", _funcs.policyGetSortedEngineIds, desc, nullptr, &count); if(count == 0) { @@ -217,7 +382,7 @@ std::vector std::vector ids(count); size_t actualCount = count; invokeHeuristicFunction( - "get sorted engine IDs", _funcPolicyGetSortedEngineIds, desc, ids.data(), &actualCount); + "get sorted engine IDs", _funcs.policyGetSortedEngineIds, desc, ids.data(), &actualCount); ids.resize(actualCount); return ids; @@ -226,8 +391,8 @@ std::vector std::string_view HeuristicPlugin::getLastErrorString() const noexcept { const char* error = nullptr; - _funcGetLastErrorString(&error); - return (error != nullptr) ? error : ""; + _funcs.getLastErrorString(&error); + return safeStringView(error); } } // namespace hipdnn_backend::plugin diff --git a/projects/hipdnn/backend/src/plugin/HeuristicPlugin.hpp b/projects/hipdnn/backend/src/plugin/HeuristicPlugin.hpp index 7bc466f702d..30d6e00a8bd 100644 --- a/projects/hipdnn/backend/src/plugin/HeuristicPlugin.hpp +++ b/projects/hipdnn/backend/src/plugin/HeuristicPlugin.hpp @@ -13,20 +13,87 @@ namespace hipdnn_backend::plugin { /** - * @brief Wrapper for a heuristic plugin shared library. + * @brief Function-pointer table for the heuristic plugin C ABI. * - * This class provides a C++ interface to the heuristic plugin C ABI defined in - * HeuristicsPluginApi.h. It manages symbol resolution and provides type-safe - * wrappers around the C function pointers. + * Holds every entry point HeuristicPlugin needs to drive a plugin. Populated + * either by dlsym from a loaded shared library (the `SharedLibrary` ctor below) + * or by a backend-internal "built-in" that supplies the table directly without + * a `.so` (see `HeuristicPluginManager::registerBuiltIn`). The downstream + * wrapper code does not distinguish the two cases. + */ +struct HeuristicPluginFunctionTable +{ + // Base plugin metadata (PluginApi.h) + hipdnnPluginStatus_t (*getName)(const char**) = nullptr; + hipdnnPluginStatus_t (*getVersion)(const char**) = nullptr; + hipdnnPluginStatus_t (*getApiVersion)(const char**) = nullptr; + hipdnnPluginStatus_t (*getType)(hipdnnPluginType_t*) = nullptr; + hipdnnPluginStatus_t (*setLoggingCallback)(hipdnnCallback_t) = nullptr; + hipdnnPluginStatus_t (*setLogLevel)(hipdnnSeverity_t) = nullptr; // optional + void (*getLastErrorString)(const char**) = nullptr; + + // Policy enumeration + hipdnnPluginStatus_t (*getAllPolicyIds)(int64_t*, uint32_t, uint32_t*) = nullptr; + hipdnnPluginStatus_t (*getPolicyName)(int64_t, const char**) = nullptr; + + // Handle lifecycle + hipdnnPluginStatus_t (*handleCreate)(hipdnnHeuristicHandle_t*) = nullptr; + hipdnnPluginStatus_t (*handleDestroy)(hipdnnHeuristicHandle_t) = nullptr; + hipdnnPluginStatus_t (*handleSetDeviceProperties)(hipdnnHeuristicHandle_t, + const hipdnnPluginConstData_t*) + = nullptr; + + // Policy descriptor lifecycle + hipdnnPluginStatus_t (*policyDescriptorCreate)(hipdnnHeuristicHandle_t, + int64_t, + hipdnnHeuristicPolicyDescriptor_t*) + = nullptr; + hipdnnPluginStatus_t (*policyDescriptorDestroy)(hipdnnHeuristicPolicyDescriptor_t) = nullptr; + + // Policy inputs + hipdnnPluginStatus_t (*policySetEngineIds)(hipdnnHeuristicPolicyDescriptor_t, + const int64_t*, + size_t) + = nullptr; + hipdnnPluginStatus_t (*policySetSerializedGraph)(hipdnnHeuristicPolicyDescriptor_t, + const hipdnnPluginConstData_t*) + = nullptr; + + // Selection + hipdnnPluginStatus_t (*policyFinalize)(hipdnnHeuristicPolicyDescriptor_t, int32_t*) = nullptr; + hipdnnPluginStatus_t (*policyGetSortedEngineIds)(hipdnnHeuristicPolicyDescriptor_t, + int64_t*, + size_t*) + = nullptr; +}; + +/** + * @brief Wrapper for a heuristic plugin (shared library or backend built-in). + * + * Provides a C++ interface over the heuristic plugin C ABI defined in + * HeuristicsPluginApi.h. Two construction paths populate the same function + * table: + * - `HeuristicPlugin(SharedLibrary&&)` resolves every symbol via dlsym from + * the loaded `.so` — used by `HeuristicPluginManager::loadPlugins`. + * - `HeuristicPlugin(HeuristicPluginFunctionTable, std::string sourceLabel)` + * accepts a pre-populated table from a backend built-in module — used by + * `HeuristicPluginManager::registerBuiltIn`. `sourceLabel` is shown in + * diagnostics in place of a library path. * - * Heuristic plugins implement base PluginApi.h functions PLUS HeuristicsPluginApi.h extensions. + * Validation (`validatePluginMetadata`, `validatePolicyIdsBuffer`) runs + * identically for both paths. */ class HeuristicPlugin : public PluginBase { protected: - // Protected constructor to prevent direct instantiation + // Shared-library ctor: populates the function table via dlsym. explicit HeuristicPlugin(SharedLibrary&& lib); + // Built-in ctor: caller hands over a fully populated function table. + // `sourceLabel` is a human-readable identifier used in error/diagnostic + // messages (e.g. "built-in:SelectionHeuristic::StaticOrdering"). + HeuristicPlugin(HeuristicPluginFunctionTable funcs, std::string sourceLabel); + // For mocking in tests HeuristicPlugin(); @@ -36,14 +103,17 @@ class HeuristicPlugin : public PluginBase // Base plugin metadata (from PluginApi.h) std::string_view apiVersion() const override; - std::string_view name() const override; // Returns policy name (via hipdnnPluginGetName) + std::string_view name() const override; // Plugin (library) name (via hipdnnPluginGetName) std::string_view version() const override; // Returns plugin version (via hipdnnPluginGetVersion) hipdnnPluginType_t type() const override; // Returns HIPDNN_PLUGIN_TYPE_HEURISTIC (via hipdnnPluginGetType) - // Heuristic-specific metadata - virtual int64_t policyId() const; // Computed from name via engineNameToId + // Heuristic-specific metadata: a single plugin may expose multiple policies. + // getAllPolicyIds() is cached after first invocation; getPolicyName() is + // queried on demand and returns the canonical name reported by the plugin. + virtual std::vector getAllPolicyIds() const; + virtual std::string_view getPolicyName(int64_t policyId) const; // Plugin type - heuristic plugins return HEURISTIC static hipdnnPluginType_t getPluginType() @@ -56,15 +126,15 @@ class HeuristicPlugin : public PluginBase hipdnnPluginStatus_t setLogLevel(hipdnnSeverity_t level) const; - // Plugin handle lifecycle + // Plugin handle lifecycle (one handle per loaded plugin, shared across policies) virtual hipdnnHeuristicHandle_t createHandle() const; virtual void destroyHandle(hipdnnHeuristicHandle_t handle) const; virtual void setDeviceProperties(hipdnnHeuristicHandle_t handle, const hipdnnPluginConstData_t* devicePropsSerialized) const; - // Policy descriptor lifecycle + // Policy descriptor lifecycle (one descriptor per policy slot) virtual hipdnnHeuristicPolicyDescriptor_t - createPolicyDescriptor(hipdnnHeuristicHandle_t pluginHandle) const; + createPolicyDescriptor(hipdnnHeuristicHandle_t pluginHandle, int64_t policyId) const; virtual void destroyPolicyDescriptor(hipdnnHeuristicPolicyDescriptor_t desc) const; // Policy inputs @@ -78,6 +148,34 @@ class HeuristicPlugin : public PluginBase virtual bool finalize(hipdnnHeuristicPolicyDescriptor_t desc) const; virtual std::vector getSortedEngineIds(hipdnnHeuristicPolicyDescriptor_t desc) const; + // Validation helpers shared between resolveSymbols() (run at load time) and + // unit tests. Each helper throws HipdnnException on failure. + // + // validatePluginMetadata: checks plugin type is HEURISTIC, plugin library + // name is non-empty, and every reported policy has a non-empty name whose + // FNV-1a hash matches its policy ID. Operates entirely through virtual + // accessors so a NiceMock can drive each rejection path. + static void validatePluginMetadata(const HeuristicPlugin& plugin); + + // validatePolicyIdsBuffer: checks the raw policy ID buffer returned by a + // plugin: actual count matches the expected count from the prior count + // query, and the buffer contains no intra-plugin duplicates. Sorts + // policyIds in place. + static void validatePolicyIdsBuffer(uint32_t expectedCount, + uint32_t actualCount, + std::vector& policyIds); + + // Factory for backend built-in heuristics. Wraps a fully populated function + // table in a HeuristicPlugin without going through dlopen. Validates the + // table is complete and the metadata matches the same rules as a loaded + // plugin (HeuristicPlugin::validatePluginMetadata). + static std::shared_ptr createBuiltIn(HeuristicPluginFunctionTable funcs, + std::string sourceLabel); + + // Source identifier used in diagnostics (library path for dlopen plugins, + // "built-in:" for built-ins). + std::string_view sourceLabel() const noexcept; + protected: // Error handling helper (must not throw, used during error handling) std::string_view getLastErrorString() const noexcept; @@ -112,48 +210,24 @@ class HeuristicPlugin : public PluginBase SharedLibrary _lib; + // For diagnostics — either the library path string or a built-in label. + std::string _sourceLabel; + + // Function-pointer table for the heuristic C ABI. Populated by + // resolveSymbols() in the dlopen ctor or by the caller in the built-in ctor. + HeuristicPluginFunctionTable _funcs; + private: void resolveSymbols(); + void validateFunctionTable() const; #ifndef NDEBUG bool _initialized = false; #endif - // Cached metadata (eagerly initialized in resolveSymbols) - int64_t _policyId = -1; - - // Base plugin function pointers (from PluginApi.h) - hipdnnPluginStatus_t (*_funcGetName)(const char**); - hipdnnPluginStatus_t (*_funcGetVersion)(const char**); - hipdnnPluginStatus_t (*_funcGetApiVersion)(const char**); - hipdnnPluginStatus_t (*_funcGetType)(hipdnnPluginType_t*); - hipdnnPluginStatus_t (*_funcSetLoggingCallback)(hipdnnCallback_t); - hipdnnPluginStatus_t (*_funcSetLogLevel)(hipdnnSeverity_t); - void (*_funcGetLastErrorString)(const char**); - - // Handle lifecycle function pointers - hipdnnPluginStatus_t (*_funcHandleCreate)(hipdnnHeuristicHandle_t*); - hipdnnPluginStatus_t (*_funcHandleDestroy)(hipdnnHeuristicHandle_t); - hipdnnPluginStatus_t (*_funcHandleSetDeviceProperties)(hipdnnHeuristicHandle_t, - const hipdnnPluginConstData_t*); - - // Policy descriptor lifecycle function pointers - hipdnnPluginStatus_t (*_funcPolicyDescriptorCreate)(hipdnnHeuristicHandle_t, - hipdnnHeuristicPolicyDescriptor_t*); - hipdnnPluginStatus_t (*_funcPolicyDescriptorDestroy)(hipdnnHeuristicPolicyDescriptor_t); - - // Policy input function pointers - hipdnnPluginStatus_t (*_funcPolicySetEngineIds)(hipdnnHeuristicPolicyDescriptor_t, - const int64_t*, - size_t); - hipdnnPluginStatus_t (*_funcPolicySetSerializedGraph)(hipdnnHeuristicPolicyDescriptor_t, - const hipdnnPluginConstData_t*); - - // Selection function pointers - hipdnnPluginStatus_t (*_funcPolicyFinalize)(hipdnnHeuristicPolicyDescriptor_t, int32_t*); - hipdnnPluginStatus_t (*_funcPolicyGetSortedEngineIds)(hipdnnHeuristicPolicyDescriptor_t, - int64_t*, - size_t*); + // Cached policy IDs (lazily populated by getAllPolicyIds and validated in + // resolveSymbols). Mutable so the const accessor can fill the cache. + mutable std::vector _allPolicyIds; friend class PluginManagerBase; }; diff --git a/projects/hipdnn/backend/src/plugin/HeuristicPluginManager.hpp b/projects/hipdnn/backend/src/plugin/HeuristicPluginManager.hpp index 99da053402e..1e552089bb1 100644 --- a/projects/hipdnn/backend/src/plugin/HeuristicPluginManager.hpp +++ b/projects/hipdnn/backend/src/plugin/HeuristicPluginManager.hpp @@ -12,6 +12,7 @@ #include "PluginCore.hpp" #include #include +#include #include #include @@ -19,16 +20,23 @@ namespace hipdnn_backend::plugin { /** - * @brief Manager for loading and validating heuristic plugin shared libraries. + * @brief Manager for loading and validating heuristic plugins. * - * This class extends PluginManagerBase to provide heuristic-specific plugin - * discovery, loading, and validation. It uses a separate search path from - * engine plugins and validates heuristic-specific constraints. + * Loads heuristic plugins from disk and registers backend-internal "built-in" + * heuristics (e.g. SelectionHeuristic::StaticOrdering) at construction time. + * Both kinds flow through the same `validateBeforeAdding` checks so downstream + * code cannot tell them apart. * * Validation includes: * - Heuristic C ABI major version compatibility - * - Unique policy IDs across all loaded heuristic plugins - * - Policy name is provided via hipdnnPluginGetName() + * - Unique policy IDs across all loaded heuristic plugins (one plugin may expose many) + * - Plugin (library) name is provided via hipdnnPluginGetName() + * + * A single heuristic plugin may expose one or more selection policies. Each + * policy is identified by a stable int64 policy ID derived from the canonical + * policy name (FNV-1a hash via policyNameToId). The plugin layer validates the + * policy ID/name pairing eagerly at load; this manager enforces uniqueness + * across all loaded plugins (built-ins included). */ class HeuristicPluginManager : public PluginManagerBase { @@ -37,6 +45,7 @@ class HeuristicPluginManager : public PluginManagerBase : PluginManagerBase(getPluginSearchPaths( "HIPDNN_HEURISTIC_PLUGIN_DIR", {std::filesystem::path("hipdnn_plugins/heuristics/")})) { + registerBuiltIns(); } protected: @@ -45,7 +54,7 @@ class HeuristicPluginManager : public PluginManagerBase using hipdnn_data_sdk::utilities::Version; // Validate heuristic C ABI major version against the heuristic API version - // (RFC 0007: heuristic plugin API has independent versioning from backend) + // (the heuristic plugin API has independent versioning from the backend) if(Version{plugin.apiVersion()}.major != HIPDNN_HEURISTIC_API_VERSION_MAJOR) { throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, @@ -59,36 +68,58 @@ class HeuristicPluginManager : public PluginManagerBase + "Expected API version: " HIPDNN_HEURISTIC_API_VERSION); } - // Validate unique policy ID - const int64_t policyId = plugin.policyId(); - if(_policyIds.find(policyId) != _policyIds.end()) - { - throw HipdnnException(HIPDNN_STATUS_PLUGIN_ERROR, - "ERROR: HEURISTIC PLUGIN VALIDATION FAILED\n" - "Policy ID " - + std::to_string(policyId) - + " already exists in the list of loaded heuristic plugins.\n" - + "Each heuristic plugin must have a unique policy ID."); - } - - // Validate policy name is provided (required for all heuristic plugins) - auto policyNameView = plugin.name(); - if(policyNameView.empty()) + // Validate plugin (library) name is provided + if(plugin.name().empty()) { throw HipdnnException( HIPDNN_STATUS_PLUGIN_ERROR, "ERROR: HEURISTIC PLUGIN VALIDATION FAILED\n" - "Policy name is required but was not provided.\n" - "Plugin must implement hipdnnPluginGetName() and return a non-empty policy name."); + "Plugin name is required but was not provided.\n" + "Plugin must implement hipdnnPluginGetName() and return a non-empty name."); + } + + // Validate every policy ID is globally unique across loaded plugins. + // The plugin layer (HeuristicPlugin::resolveSymbols) already checks intra-plugin + // uniqueness and policyNameToId(name) == policyId; here we extend the check across + // the full set of loaded plugins (including built-ins). + const auto policyIds = plugin.getAllPolicyIds(); + for(const int64_t policyId : policyIds) + { + if(_policyIds.find(policyId) != _policyIds.end()) + { + throw HipdnnException( + HIPDNN_STATUS_PLUGIN_ERROR, + "ERROR: HEURISTIC PLUGIN VALIDATION FAILED\n" + "Policy ID " + + std::to_string(policyId) + + " already exists in the list of loaded heuristic plugins.\n" + + "Each policy must have a unique ID."); + } } } void actionAfterAdding(const HeuristicPlugin& plugin) override { - _policyIds.insert(plugin.policyId()); + const auto policyIds = plugin.getAllPolicyIds(); + _policyIds.insert(policyIds.begin(), policyIds.end()); + } + + void actionAfterClearing() override + { + _policyIds.clear(); + // Built-ins survive a clear: they are not loaded from a path and the + // ABSOLUTE plugin-loading mode is intended to replace external plugins, + // not first-party backend modules. Re-register them so policy IDs and + // _plugins stay consistent. + registerBuiltIns(); } private: + // Registers all backend-internal heuristic policies. Implemented in + // backend/src/heuristics/BuiltInHeuristics.cpp so this header doesn't need + // to know about each built-in module's internals. + void registerBuiltIns(); + std::set _policyIds; }; diff --git a/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.cpp b/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.cpp index a6ef8042503..a81be3dae1d 100644 --- a/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.cpp +++ b/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.cpp @@ -35,6 +35,29 @@ struct HeuristicPluginShutdownRegistrar HeuristicPluginShutdownRegistrar gHeuristicShutdownRegistrar; +// Best-effort handle destruction used by both the constructor (rollback) and +// the destructor. Plugin code is untrusted and may throw arbitrary types; we +// must swallow everything because we are either mid-construction-failure or +// mid-destruction. +void safeDestroyHandle(const HeuristicPlugin* plugin, hipdnnHeuristicHandle_t handle) noexcept +{ + try + { + plugin->destroyHandle(handle); + } + catch(const std::exception& e) + { + HIPDNN_BACKEND_LOG_WARN( + "Failed to destroy handle for heuristic plugin '{}': {}", plugin->name(), e.what()); + } + catch(...) + { + HIPDNN_BACKEND_LOG_WARN( + "Failed to destroy handle for heuristic plugin '{}' (unknown exception)", + plugin->name()); + } +} + } // anonymous namespace // Static accessor implementations for CRTP base class @@ -80,86 +103,101 @@ HeuristicPluginResourceManager::HeuristicPluginResourceManager( return; // No plugins to initialize } - // Create plugin handles for all loaded heuristic plugins + // Create plugin handles for all loaded heuristic plugins. Each plugin has a single + // handle, but a plugin may expose multiple policies; every policy ID maps to the + // same handle (N:1). + // + // Per-plugin setup is wrapped in try/catch so that a single failing plugin does + // not leave the resource manager partially constructed. If any step after + // createHandle() throws, the constructor would otherwise exit with an active + // plugin handle recorded only in _handleToPlugin — and since the destructor + // does NOT run for an object whose constructor threw, that handle would leak. + // The rollback path below mirrors what the destructor would do for the + // half-initialized plugin. const auto& plugins = _pm->getPlugins(); for(const auto& plugin : plugins) { - // Set logging callback before creating handle + // Set logging callback before creating handle. This is a status-returning + // call (not throwing) so it stays outside the try block. auto logStatus = plugin->setLoggingCallback(hipdnn_backend::logging::backendLoggingCallback); if(logStatus != HIPDNN_PLUGIN_STATUS_SUCCESS) { - HIPDNN_BACKEND_LOG_WARN( - "Failed to set logging callback on heuristic plugin with policy ID {}", - plugin->policyId()); + HIPDNN_BACKEND_LOG_WARN("Failed to set logging callback on heuristic plugin '{}'", + plugin->name()); continue; } - // Set log level (optional) - hipdnnSeverity_t level = HIPDNN_SEV_INFO; - hipdnn_backend::logging::getGlobalLogLevel(level); - plugin->setLogLevel(level); - - // Create plugin handle hipdnnHeuristicHandle_t handle = nullptr; + std::vector registeredPolicyIds; + try { + // Set log level (optional). May throw from untrusted plugin code. + hipdnnSeverity_t level = HIPDNN_SEV_INFO; + hipdnn_backend::logging::getGlobalLogLevel(level); + plugin->setLogLevel(level); + + // Create plugin handle (one per plugin, shared across all its policies) handle = plugin->createHandle(); if(handle == nullptr) { - HIPDNN_BACKEND_LOG_ERROR("Plugin with policy ID {} ({}) returned null handle " - "despite reporting success. Plugin will be unavailable.", - plugin->policyId(), + HIPDNN_BACKEND_LOG_ERROR("Heuristic plugin '{}' returned null handle despite " + "reporting success. Plugin will be unavailable.", plugin->name()); continue; } + + _handleToPlugin[handle] = plugin.get(); + + const auto policyIds = plugin->getAllPolicyIds(); + for(const int64_t policyId : policyIds) + { + _policyIdToHandle[policyId] = handle; + registeredPolicyIds.push_back(policyId); + HIPDNN_BACKEND_LOG_INFO("Registered heuristic policy ID {} ({}) from plugin '{}'", + policyId, + std::string(plugin->getPolicyName(policyId)), + plugin->name()); + } + + continue; // success — skip the rollback block below } - catch(const HipdnnException& e) + catch(const std::exception& e) { - HIPDNN_BACKEND_LOG_ERROR("Failed to create handle for heuristic plugin with policy ID " - "{} ({}): {}. Plugin will be unavailable.", - plugin->policyId(), + HIPDNN_BACKEND_LOG_ERROR("Failed to initialize heuristic plugin '{}': {}. " + "Plugin will be unavailable.", plugin->name(), e.what()); - continue; } - - _handleToPlugin[handle] = plugin.get(); - _policyIdToHandle[plugin->policyId()] = handle; - - HIPDNN_BACKEND_LOG_INFO("Created heuristic plugin handle for policy ID {} ({})", - plugin->policyId(), - plugin->name()); - } -} - -HeuristicPluginResourceManager::~HeuristicPluginResourceManager() -{ - // Lambda to safely destroy a handle, catching all errors - auto safeDestroyHandle = [](const HeuristicPlugin* plugin, hipdnnHeuristicHandle_t handle) { - try + catch(...) { - plugin->destroyHandle(handle); + HIPDNN_BACKEND_LOG_ERROR( + "Failed to initialize heuristic plugin '{}' (unknown exception). " + "Plugin will be unavailable.", + plugin->name()); } - catch(const std::exception& e) + + // Reached only via catch: undo any partial state for this plugin so the + // resource manager remains consistent (and free the handle if we got + // far enough to create it). + for(const int64_t policyId : registeredPolicyIds) { - HIPDNN_BACKEND_LOG_WARN("Failed to destroy handle for heuristic plugin '{}' (policy ID " - "{}) during cleanup: {}", - plugin->name(), - plugin->policyId(), - e.what()); + _policyIdToHandle.erase(policyId); } - catch(...) + if(handle != nullptr) { - HIPDNN_BACKEND_LOG_WARN( - "Failed to destroy handle for heuristic plugin '{}' (policy ID {}) during cleanup: " - "unknown error", - plugin->name(), - plugin->policyId()); + _handleToPlugin.erase(handle); + safeDestroyHandle(plugin.get(), handle); } - }; + } +} - // Destroy all plugin handles +HeuristicPluginResourceManager::~HeuristicPluginResourceManager() +{ + // Destroy all plugin handles. safeDestroyHandle (anon namespace above) + // swallows exceptions; the constructor uses the same helper for its + // partial-failure rollback. for(const auto& [handle, plugin] : _handleToPlugin) { safeDestroyHandle(plugin, handle); @@ -184,6 +222,14 @@ HeuristicPluginResourceManager& { if(this != &other) { + // Destroy any handles we currently own before overwriting the map — + // move-assigning into _handleToPlugin would otherwise silently drop + // them and leak. Mirrors the destructor's loop. + for(const auto& [handle, plugin] : _handleToPlugin) + { + safeDestroyHandle(plugin, handle); + } + _handleToPlugin = std::move(other._handleToPlugin); _policyIdToHandle = std::move(other._policyIdToHandle); _cachedPolicyInfos = std::move(other._cachedPolicyInfos); @@ -236,10 +282,9 @@ void HeuristicPluginResourceManager::setDevicePropertiesOnAllHandles( } catch(const HipdnnException& e) { - HIPDNN_BACKEND_LOG_WARN( - "Failed to set device properties on heuristic plugin with policy ID {}: {}", - plugin->policyId(), - e.what()); + HIPDNN_BACKEND_LOG_WARN("Failed to set device properties on heuristic plugin '{}': {}", + plugin->name(), + e.what()); // Continue with other plugins } } @@ -253,16 +298,21 @@ std::vector HeuristicPluginResourceManager::getHeuristicPol } std::vector infos; - infos.reserve(_handleToPlugin.size()); + infos.reserve(_policyIdToHandle.size()); for(const auto& [handle, plugin] : _handleToPlugin) { - HeuristicPolicyInfo info; - info.policyId = plugin->policyId(); - info.policyName = std::string(plugin->name()); - info.pluginVersion = std::string(plugin->version()); - info.apiVersion = std::string(plugin->apiVersion()); - infos.push_back(info); + const auto policyIds = plugin->getAllPolicyIds(); + for(const int64_t policyId : policyIds) + { + HeuristicPolicyInfo info; + info.policyId = policyId; + info.policyName = std::string(plugin->getPolicyName(policyId)); + info.pluginName = std::string(plugin->name()); + info.pluginVersion = std::string(plugin->version()); + info.apiVersion = std::string(plugin->apiVersion()); + infos.push_back(info); + } } _cachedPolicyInfos = infos; @@ -301,6 +351,7 @@ std::string HeuristicPluginResourceManager::toString() const { oss << " (" << info.policyName << ")"; } + oss << ", Plugin: " << info.pluginName; oss << ", Plugin Version: " << info.pluginVersion; oss << ", API Version: " << info.apiVersion << "\n"; } diff --git a/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.hpp b/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.hpp index 8e652e89c62..7e5f7740f59 100644 --- a/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.hpp +++ b/projects/hipdnn/backend/src/plugin/HeuristicPluginResourceManager.hpp @@ -39,7 +39,8 @@ class HeuristicPluginManager; struct HeuristicPolicyInfo { std::string policyName; ///< Canonical policy name (UTF-8) - int64_t policyId; ///< Stable policy ID (engineNameToId hash) + int64_t policyId; ///< Stable policy ID (policyNameToId hash) + std::string pluginName; ///< Plugin (library) name; "hipdnn-builtin" for built-in policies std::string pluginVersion; ///< Plugin implementation version std::string apiVersion; ///< Heuristic C ABI version }; @@ -121,7 +122,7 @@ class HeuristicPluginResourceManager * Returns the hipdnnHeuristicHandle_t created for the plugin that implements * the given policy ID. Returns nullptr if no plugin with that policy ID is loaded. * - * @param policyId The policy ID (int64_t from engineNameToId) + * @param policyId The policy ID (int64_t from policyNameToId) * @return The plugin handle, or nullptr if not found */ virtual hipdnnHeuristicHandle_t getHeuristicHandleForPolicyId(int64_t policyId) const; @@ -132,7 +133,7 @@ class HeuristicPluginResourceManager * Returns the HeuristicPlugin* that implements the given policy ID. * Returns nullptr if no plugin with that policy ID is loaded. * - * @param policyId The policy ID (int64_t from engineNameToId) + * @param policyId The policy ID (int64_t from policyNameToId) * @return The plugin pointer, or nullptr if not found */ virtual const HeuristicPlugin* getPluginForPolicyId(int64_t policyId) const; diff --git a/projects/hipdnn/backend/src/plugin/PluginCore.hpp b/projects/hipdnn/backend/src/plugin/PluginCore.hpp index 2cfe7f202ad..894210b7d4a 100644 --- a/projects/hipdnn/backend/src/plugin/PluginCore.hpp +++ b/projects/hipdnn/backend/src/plugin/PluginCore.hpp @@ -134,6 +134,11 @@ class PluginManagerBase // This function is called after the plugin is added to the plugin list. virtual void actionAfterAdding([[maybe_unused]] const Plugin& plugin) {} + // This function is called after the plugin list is cleared. Derived classes + // must override this to reset any auxiliary state kept in sync with _plugins + // (e.g. derived-class indexes populated from actionAfterAdding). + virtual void actionAfterClearing() {} + // For cases where tests need to override the default plugin search paths static std::set getPluginSearchPaths(const char* envVarName, @@ -208,9 +213,30 @@ class PluginManagerBase } } + // Track failed plugin loads for summary reporting + size_t failedCount = 0; + for(const auto& filePath : filesToLoad) { - loadPluginFromFile(filePath); + if(!loadPluginFromFile(filePath)) + { + failedCount++; + // Error already logged in loadPluginFromFile + } + } + + // Emit summary if any plugins failed to load + if(failedCount > 0) + { + HIPDNN_BACKEND_LOG_WARN( + "⚠️ Plugin loading summary: {} plugin(s) failed to load out of {} attempted. " + "Check error messages above for details.", + failedCount, + filesToLoad.size()); + } + else if(!filesToLoad.empty()) + { + HIPDNN_BACKEND_LOG_INFO("✓ Successfully loaded all {} plugin(s)", filesToLoad.size()); } } @@ -224,11 +250,31 @@ class PluginManagerBase return _loadedPluginFiles; } +protected: + // Register a backend-internal plugin (e.g. a built-in heuristic) without going + // through dlopen. Runs the same setLoggingCallback/setLogLevel/validateBeforeAdding + // path as loadPluginFromFile so built-ins and dlopen-loaded plugins are + // indistinguishable downstream. Throws on validation failure — built-in + // failures are build bugs, not silent skips. + void registerPlugin(std::shared_ptr plugin) + { + plugin->setLoggingCallback(logging::backendLoggingCallback); + hipdnnSeverity_t currentLogLevel{}; + logging::getGlobalLogLevel(currentLogLevel); + plugin->setLogLevel(currentLogLevel); + + validateBeforeAdding(*plugin); + + _plugins.emplace_back(std::move(plugin)); + actionAfterAdding(*_plugins.back()); + } + private: void clearPlugins() { _plugins.clear(); _loadedPluginFiles.clear(); + actionAfterClearing(); } void scanDirectoryForPlugins(const std::filesystem::path& dirPath, @@ -253,19 +299,23 @@ class PluginManagerBase } } - void loadPluginFromFile(const std::filesystem::path& filePath) +protected: + bool loadPluginFromFile(const std::filesystem::path& filePath) { - HIPDNN_BACKEND_LOG_INFO("Attempting to load plugin from [{}]", filePath.string()); + bool success = false; hipdnn_backend::tryCatch( [&]() { SharedLibrary lib(filePath); const auto libraryPath = lib.libraryPath(); - // Shared library ensures an injective, weakly canonical mapping to a path + // Shared library ensures an injective, weakly canonical mapping to a path. + // Treat an already-loaded library as a successful no-op so the caller's + // failedCount reflects real load failures only. if(_loadedPluginFiles.find(libraryPath) != _loadedPluginFiles.end()) { + success = true; return; } @@ -305,10 +355,15 @@ class PluginManagerBase static_cast(type)); actionAfterAdding(*_plugins.back()); + + success = true; }, - fmt::format("Error loading plugin from [{}]: ", filePath.string())); + fmt::format("❌ Error loading plugin from [{}]: ", filePath.string())); + + return success; } +private: std::vector> _plugins; std::set _loadedPluginFiles; std::set _defaultPluginPaths; diff --git a/projects/hipdnn/backend/src/utilities/EngineOrdering.cpp b/projects/hipdnn/backend/src/utilities/EngineOrdering.cpp index ab80ec88768..f5b0c82ae6b 100644 --- a/projects/hipdnn/backend/src/utilities/EngineOrdering.cpp +++ b/projects/hipdnn/backend/src/utilities/EngineOrdering.cpp @@ -1,12 +1,10 @@ // Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -#include -#include - -#include "hipdnn_data_sdk/utilities/EngineNames.hpp" #include "utilities/EngineOrdering.hpp" +#include + namespace hipdnn_backend { namespace utilities @@ -14,38 +12,8 @@ namespace utilities void sortEngineIds(std::vector& engineIds) { - // Sort engine IDs: MIOPEN_ENGINE first, MIOPEN_ENGINE_DETERMINISTIC last, others in middle - // Using index-based sorting with std::sort to achieve stable sort behavior - - std::vector indices(engineIds.size()); - std::iota(indices.begin(), indices.end(), 0); - - auto getPriority = [](int64_t engineId) -> int { - if(engineId == hipdnn_data_sdk::utilities::MIOPEN_ENGINE_ID) - { - return 0; - } - if(engineId == hipdnn_data_sdk::utilities::MIOPEN_ENGINE_DETERMINISTIC_ID) - { - return 2; - } - return 1; // Other engines - }; - - std::sort(indices.begin(), indices.end(), [&](size_t i, size_t j) { - const int priA = getPriority(engineIds[i]); - const int priB = getPriority(engineIds[j]); - return (priA != priB) ? (priA < priB) : (i < j); - }); - - // Reorder engineIds based on sorted indices - std::vector sorted; - sorted.reserve(engineIds.size()); - for(const size_t idx : indices) - { - sorted.push_back(engineIds[idx]); - } - engineIds = std::move(sorted); + // Delegate to data_sdk implementation (shared with heuristic plugins) + hipdnn_data_sdk::utilities::sortEngineIds(engineIds); } } // namespace utilities diff --git a/projects/hipdnn/backend/tests/CMakeLists.txt b/projects/hipdnn/backend/tests/CMakeLists.txt index 3263492c65f..4670a29787d 100644 --- a/projects/hipdnn/backend/tests/CMakeLists.txt +++ b/projects/hipdnn/backend/tests/CMakeLists.txt @@ -11,16 +11,10 @@ endif() find_package(hip REQUIRED) find_package(Threads REQUIRED) -set(TEST_PLUGIN1_NAME "hipdnn_test_plugin1") -set(TEST_PLUGIN2_NAME "hipdnn_test_plugin2") -set(TEST_NO_API_VERSION_PLUGIN_NAME "hipdnn_test_no_api_version_plugin_name") -set(TEST_ENGINE_PLUGIN1_NAME "hipdnn_test_engine_plugin1") - -# Heuristic plugin test names (RFC 0007) -# Match the names defined in tests/test_plugins/CMakeLists.txt -set(TEST_GOOD_HEURISTIC_PLUGIN_NAME "test_good_heuristic_plugin") -set(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME "test_incomplete_heuristic_api_plugin") -set(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME "test_no_optional_heuristic_plugin") +# Test plugin target names come from cmake/TestPluginNames.cmake (included by +# the root CMakeLists.txt before any subdirectory). The targets themselves are +# created later by tests/test_plugins/, but the names are needed here for +# compile definitions and add_dependencies forward references. add_executable( hipdnn_backend_tests @@ -104,12 +98,19 @@ add_executable( TestFlatbufferUtilities.cpp TestHandle.cpp TestHelpers.cpp + IntegrationHeuristicPlugin.cpp + TestHeuristicPolicyFramework.cpp + IntegrationHeuristicPolicyPlugins.cpp TestBackendLogger.cpp TestGraphLogger.cpp TestHeuristicPlugin.cpp - TestHeuristicPluginIntegration.cpp TestHeuristicPluginManager.cpp + TestHeuristicPluginManagerValidationPaths.cpp + TestSelectionHeuristic.cpp TestHeuristicPluginResourceManager.cpp + heuristics/TestStaticOrderingBuiltIn.cpp + heuristics/TestEngineOverrideConfig.cpp + heuristics/TestConfigBuiltIn.cpp TestUserLoggingApis.cpp TestPlatformUtils.cpp TestPlugin.cpp @@ -137,6 +138,10 @@ target_compile_definitions( TEST_GOOD_HEURISTIC_PLUGIN_NAME="${TEST_GOOD_HEURISTIC_PLUGIN_NAME}" TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME="${TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME}" TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME="${TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME}" + TEST_BAD_API_VERSION_HEURISTIC_PLUGIN_NAME="${TEST_BAD_API_VERSION_HEURISTIC_PLUGIN_NAME}" + TEST_EMPTY_NAME_HEURISTIC_PLUGIN_NAME="${TEST_EMPTY_NAME_HEURISTIC_PLUGIN_NAME}" + TEST_DUPLICATE_POLICY_ID_A_PLUGIN_NAME="${TEST_DUPLICATE_POLICY_ID_A_PLUGIN_NAME}" + TEST_DUPLICATE_POLICY_ID_B_PLUGIN_NAME="${TEST_DUPLICATE_POLICY_ID_B_PLUGIN_NAME}" ) target_compile_options(hipdnn_backend_tests PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS} ${HIPDNN_NO_RTTI_OPTIONS}) clang_tidy_check(hipdnn_backend_tests) @@ -215,6 +220,10 @@ add_dependencies(hipdnn_backend_tests ${TEST_GOOD_HEURISTIC_PLUGIN_NAME} ${TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME} ${TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME} + ${TEST_BAD_API_VERSION_HEURISTIC_PLUGIN_NAME} + ${TEST_EMPTY_NAME_HEURISTIC_PLUGIN_NAME} + ${TEST_DUPLICATE_POLICY_ID_A_PLUGIN_NAME} + ${TEST_DUPLICATE_POLICY_ID_B_PLUGIN_NAME} ) add_unit_test_target(hipdnn_backend_tests ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/projects/hipdnn/backend/tests/IntegrationHeuristicPlugin.cpp b/projects/hipdnn/backend/tests/IntegrationHeuristicPlugin.cpp new file mode 100644 index 00000000000..25a8cf763ba --- /dev/null +++ b/projects/hipdnn/backend/tests/IntegrationHeuristicPlugin.cpp @@ -0,0 +1,753 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file IntegrationHeuristicPlugin.cpp + * @brief Integration tests for HeuristicPlugin workflow coverage + * + * These tests exercise full workflows with real plugins. The file holds four + * sibling fixtures, each focused on a distinct concept: + * - IntegrationHeuristicPlugin: resource-manager-mediated paths against a + * directory of test plugins. + * - IntegrationHeuristicPluginLoadedGood: direct construction of the good + * test plugin to exercise the HeuristicPlugin object surface. + * - IntegrationHeuristicPluginLoadedNoOptional: direct construction of a + * plugin missing optional symbols, verifying graceful degradation. + * - IntegrationHeuristicPluginIncomplete: a plugin missing required symbols + * must be rejected at construction time. + */ + +#include "HipdnnException.hpp" +#include "PlatformUtils.hpp" +#include "TestPluginConstants.hpp" +#include "plugin/HeuristicPlugin.hpp" +#include "plugin/HeuristicPluginManager.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" +#include "plugin/SharedLibrary.hpp" + +#include +#include +#include +#include +#include + +#include + +#include +#include + +using namespace hipdnn_backend; +using namespace hipdnn_backend::plugin; +using namespace hipdnn_backend::plugin_constants; + +namespace +{ +// Note: TEST_GOOD_HEURISTIC_PLUGIN_NAME, TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME, +// and TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME are defined as macros in CMakeLists.txt + +// Helper to serialize DevicePropertiesT using FlatBuffers Pack +std::vector + serializeDeviceProperties(const hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT& props) +{ + flatbuffers::FlatBufferBuilder builder(256); + auto offset = hipdnn_flatbuffers_sdk::data_objects::DeviceProperties::Pack(builder, &props); + builder.Finish(offset, "HDDP"); + return {builder.GetBufferPointer(), builder.GetBufferPointer() + builder.GetSize()}; +} + +// Wrapper class to access protected constructor +class TestableHeuristicPlugin : public HeuristicPlugin +{ +public: + explicit TestableHeuristicPlugin(SharedLibrary&& lib) + : HeuristicPlugin(std::move(lib)) + { + } +}; + +// RAII guards for plugin opaque handles. The handle and policy descriptor are +// allocated by the plugin's C ABI (`createHandle` / `createPolicyDescriptor`) +// and must be released via the matching destroy call. Without a guard, any +// ASSERT_* abort or thrown exception between create and destroy leaks under +// ASAN. Built on top of the shared ScopedResource utility. +inline auto makeScopedHandle(const HeuristicPlugin& plugin, hipdnnHeuristicHandle_t handle) +{ + return hipdnn_data_sdk::utilities::ScopedResource( + handle, [p = &plugin](hipdnnHeuristicHandle_t h) { + if(h != nullptr) + { + p->destroyHandle(h); + } + }); +} + +inline auto makeScopedPolicyDescriptor(const HeuristicPlugin& plugin, + hipdnnHeuristicPolicyDescriptor_t desc) +{ + return hipdnn_data_sdk::utilities::ScopedResource( + desc, [p = &plugin](hipdnnHeuristicPolicyDescriptor_t d) { + if(d != nullptr) + { + p->destroyPolicyDescriptor(d); + } + }); +} + +} // namespace + +// ==================================================================================== +// IntegrationHeuristicPlugin: resource-manager-mediated workflows +// ==================================================================================== + +class IntegrationHeuristicPlugin : public ::testing::Test +{ +protected: + void SetUp() override + { + // Set plugin path to test plugins directory + const auto testPluginDir = getHeuristicPluginPath("").parent_path(); + HeuristicPluginResourceManager::setHeuristicPluginPaths({testPluginDir}, + HIPDNN_PLUGIN_LOADING_ABSOLUTE); + } + + void TearDown() override + { + // Reset to default empty paths + HeuristicPluginResourceManager::setHeuristicPluginPaths({}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + } +}; + +// ========== Complete Workflow Tests ========== + +TEST_F(IntegrationHeuristicPlugin, CompleteHandleLifecycleWithGoodPlugin) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + // Should have loaded test plugins + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // Look up the good test plugin by its known policy id rather than scanning + // — getPluginForPolicyId is non-null for any id sourced from + // getHeuristicPolicyInfos(), so a "find the first non-null" loop is dead. + const auto goodPolicyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(goodPolicyId); + hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(goodPolicyId); + + ASSERT_NE(plugin, nullptr); + ASSERT_NE(handle, nullptr); + + // Verify plugin metadata is available + EXPECT_FALSE(plugin->version().empty()); +} + +// ========== Basic Operation Tests ========== +// Note: Basic individual operations (createHandle, createPolicyDescriptor, setEngineIds, +// setSerializedGraph, finalize, getSortedEngineIds) are tested in IntegrationHeuristicPluginLoadedGood +// fixture with focused assertions. Tests here focus on resource manager integration. + +TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesOnHandle) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // policyInfos ordering comes from an unordered_map iteration; target the + // test plugin by known policy ID so behavior is stable across platforms. + const auto goodPolicyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(goodPolicyId); + ASSERT_NE(plugin, nullptr); + + hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(goodPolicyId); + ASSERT_NE(handle, nullptr); + + // Create device properties + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; + props.device_id = 0; + props.multi_processor_count = 120; + props.total_global_mem = 16ULL * 1024 * 1024 * 1024; // 16 GB + props.architecture_name = "gfx90a"; + + // Serialize + auto serialized = serializeDeviceProperties(props); + hipdnnPluginConstData_t devicePropsData; + devicePropsData.ptr = serialized.data(); + devicePropsData.size = serialized.size(); + + // Set on handle (should not throw) + EXPECT_NO_THROW(plugin->setDeviceProperties(handle, &devicePropsData)); +} + +TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesOnAllHandles) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + // Create device properties + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; + props.device_id = 0; + props.multi_processor_count = 120; + props.total_global_mem = 16ULL * 1024 * 1024 * 1024; + props.architecture_name = "gfx90a"; + + auto serialized = serializeDeviceProperties(props); + hipdnnPluginConstData_t devicePropsData; + devicePropsData.ptr = serialized.data(); + devicePropsData.size = serialized.size(); + + // Set on all handles via resource manager + EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&devicePropsData)); +} + +TEST_F(IntegrationHeuristicPlugin, CompleteWorkflowWithDevicePropertiesAndFinalize) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // Target the test plugin by known policy ID; built-ins reject the + // synthetic graph payload below. + const auto goodPolicyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(goodPolicyId); + ASSERT_NE(plugin, nullptr); + + hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(goodPolicyId); + ASSERT_NE(handle, nullptr); + + // Set device properties on handle + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; + props.device_id = 0; + props.multi_processor_count = 120; + props.total_global_mem = 16ULL * 1024 * 1024 * 1024; + props.architecture_name = "gfx90a"; + + auto serialized = serializeDeviceProperties(props); + hipdnnPluginConstData_t devicePropsData; + devicePropsData.ptr = serialized.data(); + devicePropsData.size = serialized.size(); + plugin->setDeviceProperties(handle, &devicePropsData); + + // Create policy descriptor (RAII so destroy runs even on ASSERT_* abort) + const auto descGuard + = makeScopedPolicyDescriptor(*plugin, plugin->createPolicyDescriptor(handle, goodPolicyId)); + ASSERT_NE(descGuard.get(), nullptr); + + // Set engine IDs + const std::vector engineIds = {1000, 2000, 3000}; + plugin->setEngineIds(descGuard.get(), engineIds.data(), engineIds.size()); + + // Set serialized graph + const std::vector graphBytes = {10, 20, 30}; + hipdnnPluginConstData_t serializedGraph; + serializedGraph.ptr = graphBytes.data(); + serializedGraph.size = graphBytes.size(); + plugin->setSerializedGraph(descGuard.get(), &serializedGraph); + + // Finalize + plugin->finalize(descGuard.get()); + + // Get results + const auto sortedIds = plugin->getSortedEngineIds(descGuard.get()); +} + +// ========== Plugin Metadata Coverage ========== +// Note: Plugin metadata queries (name, version, API version, policy ID) are tested +// in IntegrationHeuristicPluginLoadedGood fixture with more specific assertions + +TEST_F(IntegrationHeuristicPlugin, GetPluginTypeFromLoadedPlugin) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // policyInfos ordering comes from an unordered_map iteration; target the + // test plugin by known policy ID so behavior is stable across platforms. + const auto goodPolicyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(goodPolicyId); + ASSERT_NE(plugin, nullptr); + + // Heuristic plugins report HEURISTIC type + const auto pluginType = plugin->type(); + EXPECT_EQ(pluginType, HIPDNN_PLUGIN_TYPE_HEURISTIC); +} + +// ========== Resource Manager Enumeration Coverage ========== + +TEST_F(IntegrationHeuristicPlugin, GetLoadedPluginFilesReturnsCorrectCount) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + size_t numPlugins = 0; + size_t maxStringLen = 0; + + rm->getLoadedPluginFiles(&numPlugins, nullptr, &maxStringLen); + + // Should have at least the test plugins + EXPECT_GT(numPlugins, 0u); +} + +TEST_F(IntegrationHeuristicPlugin, ToStringContainsPluginInformation) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto str = rm->toString(); + + EXPECT_NE(str.find("HeuristicPluginResourceManager"), std::string::npos); + EXPECT_NE(str.find("Loaded plugins:"), std::string::npos); +} + +// ========== Multiple Descriptors Per Handle ========== + +TEST_F(IntegrationHeuristicPlugin, MultipleDescriptorsFromSameHandle) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // policyInfos ordering comes from an unordered_map iteration; target the + // test plugin by known policy ID so behavior is stable across platforms. + const auto policyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyId); + ASSERT_NE(plugin, nullptr); + + hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyId); + ASSERT_NE(handle, nullptr); + + // Create multiple descriptors from the same handle (RAII-wrapped so any + // assertion abort below still releases them). + const auto desc1 + = makeScopedPolicyDescriptor(*plugin, plugin->createPolicyDescriptor(handle, policyId)); + const auto desc2 + = makeScopedPolicyDescriptor(*plugin, plugin->createPolicyDescriptor(handle, policyId)); + const auto desc3 + = makeScopedPolicyDescriptor(*plugin, plugin->createPolicyDescriptor(handle, policyId)); + + EXPECT_NE(desc1.get(), nullptr); + EXPECT_NE(desc2.get(), nullptr); + EXPECT_NE(desc3.get(), nullptr); + + // Note: Test plugins may return the same hardcoded pointer for simplicity, + // but real plugins should return distinct descriptors. We just verify they're created. +} + +// ========== Error Path Tests ========== +// These tests exercise error handling and edge cases + +// ========== Error Path: Device Properties Exceptions ========== + +TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesHandlesPluginFailures) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + // Create invalid device properties (empty buffer) + hipdnnPluginConstData_t invalidProps; + invalidProps.ptr = nullptr; + invalidProps.size = 0; + + // Should not throw even if some plugins fail - logs warning and continues + EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&invalidProps)); +} + +// ========== Error Path: Missing Optional Functions ========== + +TEST_F(IntegrationHeuristicPlugin, SetPluginLogLevelHandlesMissingOptionalFunction) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // setPluginLogLevel should not throw even if optional function is missing + EXPECT_NO_THROW(rm->setPluginLogLevel(HIPDNN_SEV_INFO)); +} + +// ========== Error Path: Empty Engine IDs ========== + +TEST_F(IntegrationHeuristicPlugin, FinalizeWithEmptyEngineIdsSucceeds) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // Target the test plugin by known policy ID; built-ins reject finalize() + // without a real graph payload. + const auto goodPolicyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + const HeuristicPlugin* plugin = rm->getPluginForPolicyId(goodPolicyId); + ASSERT_NE(plugin, nullptr); + + hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(goodPolicyId); + ASSERT_NE(handle, nullptr); + + const auto descGuard + = makeScopedPolicyDescriptor(*plugin, plugin->createPolicyDescriptor(handle, goodPolicyId)); + ASSERT_NE(descGuard.get(), nullptr); + + // Don't set any engine IDs - just finalize + plugin->finalize(descGuard.get()); + + // Get sorted IDs (should be empty) + const auto sortedIds = plugin->getSortedEngineIds(descGuard.get()); + EXPECT_TRUE(sortedIds.empty()); +} + +// ========== Error Path: Multiple Policy Lookups (Same Handle/Plugin Reuse) ========== + +TEST_F(IntegrationHeuristicPlugin, MultipleGetHandleCallsReturnSameHandle) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // Pin to a known policy ID so the test does not depend on the + // unordered_map iteration order behind getHeuristicPolicyInfos(). + const auto policyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + + // Multiple calls should return the same handle (cached) + auto handle1 = rm->getHeuristicHandleForPolicyId(policyId); + auto handle2 = rm->getHeuristicHandleForPolicyId(policyId); + + EXPECT_EQ(handle1, handle2); + EXPECT_NE(handle1, nullptr); +} + +TEST_F(IntegrationHeuristicPlugin, MultipleGetPluginCallsReturnSamePlugin) +{ + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + const auto policyInfos = rm->getHeuristicPolicyInfos(); + ASSERT_FALSE(policyInfos.empty()); + + // Pin to a known policy ID so the test does not depend on the + // unordered_map iteration order behind getHeuristicPolicyInfos(). + const auto policyId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + + // Multiple calls should return the same plugin pointer + const HeuristicPlugin* plugin1 = rm->getPluginForPolicyId(policyId); + const HeuristicPlugin* plugin2 = rm->getPluginForPolicyId(policyId); + + EXPECT_EQ(plugin1, plugin2); + EXPECT_NE(plugin1, nullptr); +} + +// ========== Error Path: No plugins loaded scenario ========== + +TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesWithNoPluginsLoaded) +{ + // Create RM with no plugins + HeuristicPluginResourceManager::setHeuristicPluginPaths({}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + auto rm = HeuristicPluginResourceManager::create(); + ASSERT_NE(rm, nullptr); + + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; + props.device_id = 0; + props.multi_processor_count = 120; + props.total_global_mem = 16ULL * 1024 * 1024 * 1024; + props.architecture_name = "gfx90a"; + + auto serialized = serializeDeviceProperties(props); + hipdnnPluginConstData_t devicePropsData; + devicePropsData.ptr = serialized.data(); + devicePropsData.size = serialized.size(); + + // Should not throw when no plugins loaded + EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&devicePropsData)); +} + +// ==================================================================================== +// IntegrationHeuristicPluginLoadedGood: direct construction of the good test plugin +// ==================================================================================== + +class IntegrationHeuristicPluginLoadedGood : public ::testing::Test +{ +protected: + void SetUp() override + { + const auto pluginPath = getHeuristicPluginPath(TEST_GOOD_HEURISTIC_PLUGIN_NAME); + ASSERT_TRUE(std::filesystem::exists(pluginPath)) + << "Test plugin not found: " << pluginPath + << "\nMake sure test_plugins are built before running tests"; + + SharedLibrary lib(pluginPath); + _pluginPtr = std::make_unique(std::move(lib)); + } + + void TearDown() override + { + _pluginPtr.reset(); + } + + TestableHeuristicPlugin& plugin() + { + return *_pluginPtr; + } + +private: + std::unique_ptr _pluginPtr; +}; + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanQueryApiVersion) +{ + const auto version = plugin().apiVersion(); + EXPECT_FALSE(version.empty()); + EXPECT_EQ(version, HIPDNN_HEURISTIC_API_VERSION); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanQueryPolicyId) +{ + const auto policyIds = plugin().getAllPolicyIds(); + ASSERT_EQ(policyIds.size(), 1u); + const auto expectedId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + EXPECT_EQ(policyIds.front(), expectedId); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanQueryPluginName) +{ + const auto name = plugin().name(); + EXPECT_EQ(name, "TestGoodHeuristicPlugin"); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanQueryPolicyName) +{ + const auto policyIds = plugin().getAllPolicyIds(); + ASSERT_EQ(policyIds.size(), 1u); + const auto policyName = plugin().getPolicyName(policyIds.front()); + EXPECT_EQ(policyName, "TestGoodHeuristicPolicy"); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanQueryPluginVersion) +{ + const auto version = plugin().version(); + EXPECT_EQ(version, "1.0.0"); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, LoadedPluginCanGetSortedEngineIds) +{ + const auto handleGuard = makeScopedHandle(plugin(), plugin().createHandle()); + const auto policyId = plugin().getAllPolicyIds().front(); + const auto descGuard = makeScopedPolicyDescriptor( + plugin(), plugin().createPolicyDescriptor(handleGuard.get(), policyId)); + + const std::vector inputIds = {1, 2, 3, 4, 5}; + plugin().setEngineIds(descGuard.get(), inputIds.data(), inputIds.size()); + plugin().finalize(descGuard.get()); + + std::vector sortedIds; + ASSERT_NO_THROW({ sortedIds = plugin().getSortedEngineIds(descGuard.get()); }); + + // Good plugin reverses the order + EXPECT_EQ(sortedIds.size(), inputIds.size()); + EXPECT_EQ(sortedIds, std::vector({5, 4, 3, 2, 1})); +} + +TEST_F(IntegrationHeuristicPluginLoadedGood, RealPluginCachesPolicyIds) +{ + // First call - IDs are queried from the plugin + const auto ids1 = plugin().getAllPolicyIds(); + ASSERT_EQ(ids1.size(), 1u); + const auto expectedId = hipdnn_data_sdk::utilities::policyNameToId("TestGoodHeuristicPolicy"); + EXPECT_EQ(ids1.front(), expectedId); + + // Second call should return the cached vector + const auto ids2 = plugin().getAllPolicyIds(); + EXPECT_EQ(ids2, ids1); +} + +// ==================================================================================== +// IntegrationHeuristicPluginLoadedNoOptional: plugin missing optional symbols +// ==================================================================================== + +class IntegrationHeuristicPluginLoadedNoOptional : public ::testing::Test +{ +protected: + void SetUp() override + { + const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); + ASSERT_TRUE(std::filesystem::exists(pluginPath)) + << "Test plugin not found: " << pluginPath + << "\nMake sure test_plugins are built before running tests"; + + SharedLibrary lib(pluginPath); + _pluginPtr = std::make_unique(std::move(lib)); + } + + void TearDown() override + { + _pluginPtr.reset(); + } + + TestableHeuristicPlugin& plugin() + { + return *_pluginPtr; + } + +private: + std::unique_ptr _pluginPtr; +}; + +TEST_F(IntegrationHeuristicPluginLoadedNoOptional, PluginWithoutOptionalPolicyNameHasName) +{ + // hipdnnPluginGetName is required + const auto name = plugin().name(); + EXPECT_FALSE(name.empty()); + EXPECT_EQ(name, "TestNoOptionalHeuristicPlugin"); + + // Each policy has a non-empty name + for(const int64_t policyId : plugin().getAllPolicyIds()) + { + EXPECT_FALSE(plugin().getPolicyName(policyId).empty()); + } +} + +TEST_F(IntegrationHeuristicPluginLoadedNoOptional, PluginWithoutOptionalSetLogLevelSucceeds) +{ + // Plugin doesn't implement hipdnnHeuristicSetLogLevel + // Should return SUCCESS without calling the function + const auto status = plugin().setLogLevel(HIPDNN_SEV_INFO); + EXPECT_EQ(status, HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST_F(IntegrationHeuristicPluginLoadedNoOptional, PluginWithoutOptionalCanStillExecuteWorkflow) +{ + // Full workflow should work despite missing optional functions + const auto handleGuard = makeScopedHandle(plugin(), plugin().createHandle()); + ASSERT_NE(handleGuard.get(), nullptr); + + const auto policyId = plugin().getAllPolicyIds().front(); + const auto descGuard = makeScopedPolicyDescriptor( + plugin(), plugin().createPolicyDescriptor(handleGuard.get(), policyId)); + ASSERT_NE(descGuard.get(), nullptr); + + const std::vector inputIds = {1, 2, 3}; + plugin().setEngineIds(descGuard.get(), inputIds.data(), inputIds.size()); + + const bool applied = plugin().finalize(descGuard.get()); + EXPECT_FALSE(applied); // This plugin declines to apply + + const auto sortedIds = plugin().getSortedEngineIds(descGuard.get()); + EXPECT_TRUE(sortedIds.empty()); // Returns empty list +} + +// ==================================================================================== +// IntegrationHeuristicPluginIncomplete: plugin missing required symbols is rejected +// ==================================================================================== +// +// This fixture cannot pre-load the plugin in SetUp because construction is expected +// to throw. Each test loads the SharedLibrary and attempts construction in its body. + +class IntegrationHeuristicPluginIncomplete : public ::testing::Test +{ +protected: + static std::filesystem::path incompletePluginPath() + { + return getHeuristicPluginPath(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME); + } +}; + +TEST_F(IntegrationHeuristicPluginIncomplete, LoadIncompletePluginThrowsException) +{ + const auto pluginPath = incompletePluginPath(); + + ASSERT_TRUE(std::filesystem::exists(pluginPath)) << "Test plugin not found: " << pluginPath; + + SharedLibrary lib(pluginPath); + + // Loading should fail during symbol resolution + EXPECT_THROW( + { + try + { + const TestableHeuristicPlugin plugin(std::move(lib)); + } + catch(const HipdnnException& e) + { + // Verify the exception contains expected error details + const std::string errorMsg(e.what()); + EXPECT_NE(errorMsg.find("HEURISTIC PLUGIN ABI INCOMPLETE"), std::string::npos); + EXPECT_NE(errorMsg.find("Missing required symbol"), std::string::npos); + // Error text uses SharedLibrary's weakly_canonical path; on Windows the string can + // differ in drive-letter case or separators from a fresh weakly_canonical(pluginPath). + const auto canonicalPath = std::filesystem::weakly_canonical(pluginPath); + static constexpr std::string_view K_PLUGIN_PREFIX{"Plugin: "}; + const auto prefixPos = errorMsg.find(K_PLUGIN_PREFIX); + ASSERT_NE(prefixPos, std::string::npos); + const auto pathStart = prefixPos + K_PLUGIN_PREFIX.size(); + const auto pathEnd = errorMsg.find('\n', pathStart); + ASSERT_NE(pathEnd, std::string::npos); + const std::filesystem::path pluginPathInMessage( + errorMsg.substr(pathStart, pathEnd - pathStart)); + EXPECT_TRUE( + hipdnn_data_sdk::utilities::pathCompEq(pluginPathInMessage, canonicalPath)) + << "pluginPathInMessage='" << pluginPathInMessage.string() + << "' canonicalPath='" << canonicalPath.string() << "'"; + throw; + } + }, + HipdnnException); +} + +TEST_F(IntegrationHeuristicPluginIncomplete, IncompletePluginExceptionContainsSymbolName) +{ + const auto pluginPath = incompletePluginPath(); + SharedLibrary lib(pluginPath); + + EXPECT_THROW( + { + try + { + const TestableHeuristicPlugin plugin(std::move(lib)); + } + catch(const HipdnnException& e) + { + const std::string errorMsg(e.what()); + // Should mention one of the missing required symbols + const bool hasPluginNameError + = errorMsg.find("hipdnnPluginGetName") != std::string::npos; + const bool hasFinalizeError + = errorMsg.find("hipdnnHeuristicPolicyFinalize") != std::string::npos; + const bool hasGetSortedError + = errorMsg.find("hipdnnHeuristicPolicyGetSortedEngineIds") != std::string::npos; + EXPECT_TRUE(hasPluginNameError || hasFinalizeError || hasGetSortedError); + throw; + } + }, + HipdnnException); +} + +TEST_F(IntegrationHeuristicPluginIncomplete, IncompletePluginExceptionHasPluginErrorStatus) +{ + const auto pluginPath = incompletePluginPath(); + SharedLibrary lib(pluginPath); + + EXPECT_THROW( + { + try + { + const TestableHeuristicPlugin plugin(std::move(lib)); + } + catch(const HipdnnException& e) + { + EXPECT_EQ(e.getStatus(), HIPDNN_STATUS_PLUGIN_ERROR); + throw; + } + }, + HipdnnException); +} diff --git a/projects/hipdnn/backend/tests/IntegrationHeuristicPolicyPlugins.cpp b/projects/hipdnn/backend/tests/IntegrationHeuristicPolicyPlugins.cpp new file mode 100644 index 00000000000..89da3d15838 --- /dev/null +++ b/projects/hipdnn/backend/tests/IntegrationHeuristicPolicyPlugins.cpp @@ -0,0 +1,453 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file IntegrationHeuristicPolicyPlugins.cpp + * @brief Integration tests for real heuristic policy plugins + * + * These tests verify the heuristic policy chain seen at runtime: + * - The StaticOrdering built-in (registered at HeuristicPluginManager + * construction time as a function-table-shaped pseudo-plugin). + * - Any external heuristic .so found in HIPDNN_HEURISTIC_PLUGIN_DIR. + * - Plugin discovery and loading from installed location + * - Symbol resolution and ABI validation + * - Plugin handle creation and lifecycle + * - Policy descriptor creation and execution + * - API version compatibility + * - Policy ID/name consistency + */ + +#include "PlatformUtils.hpp" +#include "descriptors/EngineHeuristicDescriptor.hpp" +#include "descriptors/GraphDescriptor.hpp" +#include "handle/Handle.hpp" +#include "plugin/HeuristicPlugin.hpp" +#include "plugin/HeuristicPluginManager.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace hipdnn_backend; +using namespace hipdnn_backend::plugin; + +namespace +{ +// RAII guards so ASSERT_* aborts mid-test do not leak the underlying +// plugin-allocated resources under ASAN. Mirrors makeScopedHandle / +// makeScopedPolicyDescriptor in IntegrationHeuristicPlugin.cpp. +inline auto makeScopedHipdnnHandle(hipdnnHandle_t handle) +{ + return hipdnn_data_sdk::utilities::ScopedResource(handle, [](hipdnnHandle_t h) { + if(h != nullptr) + { + hipdnnDestroy(h); + } + }); +} + +inline auto makeScopedPolicyDescriptor(const HeuristicPlugin& plugin, + hipdnnHeuristicPolicyDescriptor_t desc) +{ + return hipdnn_data_sdk::utilities::ScopedResource( + desc, [p = &plugin](hipdnnHeuristicPolicyDescriptor_t d) { + if(d != nullptr) + { + p->destroyPolicyDescriptor(d); + } + }); +} + +// Helper to get the plugin directory for tests +// Tests binaries are in build/bin/, plugins could be in: +// - build/lib/hipdnn_plugins/heuristics/ (Linux/Unix) +// - build/bin/hipdnn_plugins/heuristics/ (Windows DLLs) +std::filesystem::path getTestPluginDirectory() +{ + // First, check for environment variable override + const auto envPath = hipdnn_data_sdk::utilities::getEnv("HIPDNN_HEURISTIC_PLUGIN_DIR"); + if(!envPath.empty()) + { + return {envPath}; + } + + // Get the directory containing the test binary + const auto testBinDir = hipdnn_backend::platform_utilities::getCurrentModuleDirectory(); + const auto buildRoot = testBinDir.parent_path(); + + // Try multiple possible locations + const std::vector candidatePaths = { + buildRoot / "lib" / "hipdnn_plugins" / "heuristics", // Linux/Unix + buildRoot / "bin" / "hipdnn_plugins" / "heuristics", // Windows DLLs + buildRoot / "lib64" / "hipdnn_plugins" / "heuristics", // lib64 systems + }; + + // Return the first path that exists and contains plugin files + for(const auto& path : candidatePaths) + { + if(std::filesystem::exists(path)) + { + // Check if directory contains any .so or .dll files + for(const auto& entry : std::filesystem::directory_iterator(path)) + { + const auto ext = entry.path().extension(); + if(ext == ".so" || ext == ".dll" || ext == ".dylib") + { + return path; + } + } + } + } + + // Fallback to original behavior (lib) + return buildRoot / "lib" / "hipdnn_plugins" / "heuristics"; +} +} // anonymous namespace + +class IntegrationHeuristicPolicyPlugins : public ::testing::Test +{ +protected: + void SetUp() override + { + // hipdnnCreate loads real heuristic plugins (e.g. hipBLASLt in the + // superbuild) whose initializers probe the device. Skip on no-GPU + // runners to avoid a hard abort from the plugin's HIP error path. + SKIP_IF_NO_DEVICES(); + const hipdnnStatus_t status = hipdnnCreate(&_handle); + ASSERT_EQ(status, HIPDNN_STATUS_SUCCESS); + ASSERT_NE(_handle, nullptr); + } + + void TearDown() override + { + if(_handle != nullptr) + { + hipdnnDestroy(_handle); + _handle = nullptr; + } + } + + hipdnnHandle_t _handle = nullptr; +}; + +// ========== Plugin Discovery Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, PluginManagerLoadsFromDefaultPath) +{ + // Create manager and load plugins + auto manager = std::make_shared(); + // For tests, explicitly pass the plugin directory since getCurrentModuleDirectory + // resolves to the test binary's location (bin/) not the backend library's (lib/) + manager->loadPlugins({getTestPluginDirectory()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Plugin handles are created by HeuristicPluginResourceManager, not by the bare + // manager. The resource manager creates handles and can enumerate loaded plugins. + auto resourceMgr = std::make_shared(manager); + + // Enumerate loaded policies via resource manager + const auto& policyInfos = resourceMgr->getHeuristicPolicyInfos(); + + // The StaticOrdering built-in is always registered; vendor plugins (if any) + // discovered under the search path raise the count further. + EXPECT_GE(policyInfos.size(), 1u) << "Expected at least the StaticOrdering built-in"; +} + +TEST_F(IntegrationHeuristicPolicyPlugins, PluginManagerRejectsInvalidPlugins) +{ + // Plugins with wrong ABI version should be rejected during validation + // This is tested by HeuristicPluginManager::validateBeforeAdding() + + auto manager = std::make_shared(); + manager->loadPlugins({getTestPluginDirectory()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Create resource manager to access loaded plugins + auto resourceMgr = std::make_shared(manager); + const auto& policyInfos = resourceMgr->getHeuristicPolicyInfos(); + + // All loaded plugins should have valid metadata + for(const auto& info : policyInfos) + { + EXPECT_NE(info.policyId, -1) << "Policy ID should be valid"; + EXPECT_FALSE(info.apiVersion.empty()) << "API version should not be empty"; + EXPECT_FALSE(info.pluginVersion.empty()) << "Plugin version should not be empty"; + } +} + +TEST_F(IntegrationHeuristicPolicyPlugins, PluginManagerRejectsDuplicatePolicyIds) +{ + // HeuristicPluginManager should reject plugins with duplicate policy IDs + // This is enforced by validateBeforeAdding() + + auto manager = std::make_shared(); + manager->loadPlugins({getTestPluginDirectory()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Create resource manager to enumerate loaded plugins + auto resourceMgr = std::make_shared(manager); + const auto& policyInfos = resourceMgr->getHeuristicPolicyInfos(); + + // Collect all policy IDs + std::set policyIds; + for(const auto& info : policyInfos) + { + const int64_t id = info.policyId; + EXPECT_EQ(policyIds.count(id), 0u) << "Duplicate policy ID detected: " << id; + policyIds.insert(id); + } +} + +// ========== Symbol Resolution Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, LoadedPluginsHaveRequiredSymbols) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + + ASSERT_GT(policyInfos.size(), 0u); + + // Each loaded plugin must have successfully resolved required symbols + // If symbol resolution failed, the plugin wouldn't be in the list + for(const auto& info : policyInfos) + { + EXPECT_NE(info.policyId, -1); + EXPECT_FALSE(info.policyName.empty()); + EXPECT_FALSE(info.apiVersion.empty()); + EXPECT_FALSE(info.pluginVersion.empty()); + } +} + +// ========== Handle Lifecycle Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, ResourceManagerCreatesHandlesForAllPlugins) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + ASSERT_NE(heurRm, nullptr); + + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + + // Should have created a handle for each loaded plugin + for(const auto& info : policyInfos) + { + auto handle = heurRm->getHeuristicHandleForPolicyId(info.policyId); + EXPECT_NE(handle, nullptr) << "Handle should exist for policy ID " << info.policyId; + + auto plugin = heurRm->getPluginForPolicyId(info.policyId); + EXPECT_NE(plugin, nullptr) << "Plugin should exist for policy ID " << info.policyId; + } +} + +TEST_F(IntegrationHeuristicPolicyPlugins, HandleDestructionCleansUpResources) +{ + // Create and destroy a handle + hipdnnHandle_t tempHandle = nullptr; + ASSERT_EQ(hipdnnCreate(&tempHandle), HIPDNN_STATUS_SUCCESS); + auto scopedHandle = makeScopedHipdnnHandle(tempHandle); + + // Get resource manager (creates plugin handles) + auto heurRm = tempHandle->getHeuristicPluginResourceManager(); + ASSERT_NE(heurRm, nullptr); + + const size_t policyCount = heurRm->getHeuristicPolicyInfos().size(); + EXPECT_GT(policyCount, 0u); +} + +// ========== Policy Descriptor Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, PolicyDescriptorCreationSucceeds) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + ASSERT_GT(policyInfos.size(), 0u); + + // Pin to the StaticOrdering built-in: it is always registered, so the + // test is deterministic regardless of which vendor plugins are present + // (policyInfos ordering is derived from an unordered_map iteration). + const auto policyId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + auto pluginHandle = heurRm->getHeuristicHandleForPolicyId(policyId); + auto plugin = heurRm->getPluginForPolicyId(policyId); + + ASSERT_NE(pluginHandle, nullptr); + ASSERT_NE(plugin, nullptr); + + // Create policy descriptor + auto descriptor = makeScopedPolicyDescriptor( + *plugin, plugin->createPolicyDescriptor(pluginHandle, policyId)); + EXPECT_NE(descriptor.get(), nullptr); +} + +// ========== Logging Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, PluginsReceiveLoggingCallback) +{ + // Verify that setLoggingCallback was wired up during plugin initialization. + // If the callback registration failed, resource manager construction would have + // logged warnings; here we at least confirm the manager came up. + auto heurRm = _handle->getHeuristicPluginResourceManager(); + ASSERT_NE(heurRm, nullptr); +} + +// ========== Device Properties Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, DevicePropertiesAreSetOnAllHandles) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + ASSERT_NE(heurRm, nullptr); + + // Create serialized device properties using FlatBuffers + hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT devProps; + devProps.device_id = 0; + devProps.multi_processor_count = 64; + devProps.total_global_mem = 8ULL * 1024 * 1024 * 1024; + devProps.architecture_name = "gfx90a"; + + // Serialize using FlatBuffers + flatbuffers::FlatBufferBuilder builder(256); + auto offset = hipdnn_flatbuffers_sdk::data_objects::DeviceProperties::Pack(builder, &devProps); + builder.Finish(offset, "HDDP"); + + // Create wrapper + hipdnnPluginConstData_t wrapper; + wrapper.ptr = builder.GetBufferPointer(); + wrapper.size = builder.GetSize(); + + // Set on all handles (should not throw) + EXPECT_NO_THROW(heurRm->setDevicePropertiesOnAllHandles(&wrapper)); +} + +// ========== Policy ID Consistency Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, PolicyIdMatchesNameHash) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + + for(const auto& info : policyInfos) + { + if(!info.policyName.empty()) + { + // Policy ID should match policyNameToId(policyName) + const int64_t expectedId = hipdnn_data_sdk::utilities::policyNameToId(info.policyName); + EXPECT_EQ(info.policyId, expectedId) << "Policy ID mismatch for " << info.policyName; + } + } +} + +// ========== API Version Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, AllPluginsHaveCompatibleApiVersion) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + + // All loaded plugins should have compatible API versions + // (major version matches the heuristic API version) + for(const auto& info : policyInfos) + { + EXPECT_FALSE(info.apiVersion.empty()); + + // Parse version + const hipdnn_data_sdk::utilities::Version apiVer{info.apiVersion}; + + // Major version should match heuristic API (independent of backend version) + EXPECT_EQ(apiVer.major, HIPDNN_HEURISTIC_API_VERSION_MAJOR) + << "Plugin " << info.policyName << " has incompatible API major version"; + } +} + +// ========== Enumeration Consistency Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, EnumerationMatchesResourceManager) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + auto rmInfos = heurRm->getHeuristicPolicyInfos(); + + // Get count via C API + size_t apiCount = 0; + ASSERT_EQ(hipdnnGetHeuristicPolicyCount_ext(_handle, &apiCount), HIPDNN_STATUS_SUCCESS); + + // Counts should match + EXPECT_EQ(apiCount, rmInfos.size()); + + // Enumeration order of getHeuristicPolicyInfos() is documented as unspecified + // (built from an unordered_map). Compare by set of policy IDs instead of by + // index so the test does not implicitly depend on the cache happening to + // return the same vector for two successive calls. + std::set rmPolicyIds; + for(const auto& info : rmInfos) + { + rmPolicyIds.insert(info.policyId); + } + + std::set apiPolicyIds; + for(size_t i = 0; i < apiCount; ++i) + { + int64_t apiPolicyId = -1; + size_t nameLen = 0; + size_t pluginNameLen = 0; + size_t pluginVerLen = 0; + size_t apiVerLen = 0; + + ASSERT_EQ(hipdnnGetHeuristicPolicyInfo_ext(_handle, + i, + &apiPolicyId, + nullptr, + &nameLen, + nullptr, + &pluginNameLen, + nullptr, + &pluginVerLen, + nullptr, + &apiVerLen), + HIPDNN_STATUS_SUCCESS); + + apiPolicyIds.insert(apiPolicyId); + } + + EXPECT_EQ(apiPolicyIds, rmPolicyIds); +} + +// ========== Stress Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, MultipleResourceManagersCanCoexist) +{ + // Create multiple handles, each with its own resource manager. + // ScopedResource entries destroy their handle on test exit (including + // ASSERT_* short-circuit), so a mid-loop abort cannot leak earlier ones. + std::vector> handles; + + for(int i = 0; i < 5; ++i) + { + hipdnnHandle_t h = nullptr; + ASSERT_EQ(hipdnnCreate(&h), HIPDNN_STATUS_SUCCESS); + handles.push_back(makeScopedHipdnnHandle(h)); + + // Access resource manager (triggers creation) + auto heurRm = h->getHeuristicPluginResourceManager(); + EXPECT_NE(heurRm, nullptr); + } +} + +// ========== Error Recovery Tests ========== + +TEST_F(IntegrationHeuristicPolicyPlugins, MissingPolicyGracefullyHandled) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + + // Query a non-existent policy ID + const int64_t fakePolicyId = 0x1234567890ABCDEF; + auto handle = heurRm->getHeuristicHandleForPolicyId(fakePolicyId); + auto plugin = heurRm->getPluginForPolicyId(fakePolicyId); + + // Should return nullptr, not crash + EXPECT_EQ(handle, nullptr); + EXPECT_EQ(plugin, nullptr); +} +// ========== Workflow Tests with Test Plugins (from pr1) ========== diff --git a/projects/hipdnn/backend/tests/TestHeuristicPlugin.cpp b/projects/hipdnn/backend/tests/TestHeuristicPlugin.cpp index d6e733e0bdf..e10b60808ae 100644 --- a/projects/hipdnn/backend/tests/TestHeuristicPlugin.cpp +++ b/projects/hipdnn/backend/tests/TestHeuristicPlugin.cpp @@ -3,13 +3,12 @@ /** * @file TestHeuristicPlugin.cpp - * @brief Unit tests for HeuristicPlugin class (RFC 0007 Part 1) + * @brief Unit tests for HeuristicPlugin's load-time validation helpers. * - * These tests verify the plugin wrapper class including: - * - Symbol resolution and error handling - * - Plugin metadata access - * - Handle lifecycle - * - Policy descriptor lifecycle + * Workflow / call-sequence / metadata-passthrough behaviors are covered via + * real plugins in IntegrationHeuristicPlugin.cpp — gmock can only round-trip + * its own configured returns, which would not exercise any HeuristicPlugin + * logic. */ #include "HipdnnException.hpp" @@ -19,215 +18,78 @@ #include #include +#include + using namespace hipdnn_backend; using namespace hipdnn_backend::plugin; using ::testing::NiceMock; using ::testing::Return; -namespace -{ -// Helper to create fake handles for testing -// NOLINTBEGIN(performance-no-int-to-ptr) -hipdnnHeuristicHandle_t makeFakeHandle(int id) -{ - return reinterpret_cast(static_cast(id)); -} - -hipdnnHeuristicPolicyDescriptor_t makeFakePolicyDescriptor(int id) -{ - return reinterpret_cast(static_cast(id)); -} -// NOLINTEND(performance-no-int-to-ptr) -} // anonymous namespace - class TestHeuristicPlugin : public ::testing::Test { -protected: - void SetUp() override - { - // Tests use mocks, not real plugin shared libraries - } - - void TearDown() override {} }; -// ========== Mock Plugin Behavior Tests ========== -// Note: Trivial single-method mocking tests removed - integration tests with real plugins -// provide better coverage of actual behavior - -// ========== Complete Workflow Tests ========== -// Note: Complete workflow is tested with real plugins in TestHeuristicPluginLoading - -// ========== Multiple Handles Tests ========== +// ========== Plugin Metadata Validation ========== +// HeuristicPlugin::validatePluginMetadata is the load-time gate invoked from +// resolveSymbols(). Each test below pins one specific rejection path that is +// otherwise unreachable without a dedicated test plugin .so. -TEST_F(TestHeuristicPlugin, MockPluginCanManageMultipleHandles) +namespace { - const NiceMock plugin; - - const auto handle1 = makeFakeHandle(1); - const auto handle2 = makeFakeHandle(2); - const auto handle3 = makeFakeHandle(3); - - EXPECT_CALL(plugin, createHandle()) - .WillOnce(Return(handle1)) - .WillOnce(Return(handle2)) - .WillOnce(Return(handle3)); +const std::string_view GOOD_POLICY_NAME = "TestPolicy::Good"; +const int64_t GOOD_POLICY_ID = hipdnn_data_sdk::utilities::policyNameToId("TestPolicy::Good"); +} // namespace - // Create multiple handles - const auto h1 = plugin.createHandle(); - const auto h2 = plugin.createHandle(); - const auto h3 = plugin.createHandle(); - - EXPECT_EQ(h1, handle1); - EXPECT_EQ(h2, handle2); - EXPECT_EQ(h3, handle3); - - // All handles should be unique - EXPECT_NE(h1, h2); - EXPECT_NE(h2, h3); - EXPECT_NE(h1, h3); -} - -// ========== Multiple Policy Descriptors Tests ========== - -TEST_F(TestHeuristicPlugin, MockPluginCanManageMultiplePolicyDescriptors) +TEST_F(TestHeuristicPlugin, ValidatePluginMetadataRejectsNonHeuristicPluginType) { const NiceMock plugin; + EXPECT_CALL(plugin, type()).WillRepeatedly(Return(HIPDNN_PLUGIN_TYPE_ENGINE)); - const auto handle = makeFakeHandle(42); - const auto desc1 = makeFakePolicyDescriptor(1); - const auto desc2 = makeFakePolicyDescriptor(2); - const auto desc3 = makeFakePolicyDescriptor(3); - - EXPECT_CALL(plugin, createPolicyDescriptor(handle)) - .WillOnce(Return(desc1)) - .WillOnce(Return(desc2)) - .WillOnce(Return(desc3)); - - // Create multiple descriptors from same handle - const auto d1 = plugin.createPolicyDescriptor(handle); - const auto d2 = plugin.createPolicyDescriptor(handle); - const auto d3 = plugin.createPolicyDescriptor(handle); - - EXPECT_EQ(d1, desc1); - EXPECT_EQ(d2, desc2); - EXPECT_EQ(d3, desc3); - - // All descriptors should be unique - EXPECT_NE(d1, d2); - EXPECT_NE(d2, d3); - EXPECT_NE(d1, d3); + EXPECT_THROW(HeuristicPlugin::validatePluginMetadata(plugin), HipdnnException); } -// ========== Call Count Verification Tests ========== - -TEST_F(TestHeuristicPlugin, MockPluginTracksCallCounts) +TEST_F(TestHeuristicPlugin, ValidatePluginMetadataRejectsPolicyIdNameHashMismatch) { const NiceMock plugin; - - const auto handle = makeFakeHandle(42); - - EXPECT_CALL(plugin, createHandle()).Times(3).WillRepeatedly(Return(handle)); - - // Create handle 3 times - plugin.createHandle(); - plugin.createHandle(); - plugin.createHandle(); - - // Expectations verified by gmock + EXPECT_CALL(plugin, type()).WillRepeatedly(Return(HIPDNN_PLUGIN_TYPE_HEURISTIC)); + EXPECT_CALL(plugin, name()).WillRepeatedly(Return("MockPlugin")); + // Plugin reports policy name "TestPolicy::Good" but tags it with an ID that + // is NOT policyNameToId("TestPolicy::Good"). validatePluginMetadata must + // reject this mismatch. + const int64_t bogusId = GOOD_POLICY_ID ^ int64_t { 0xDEADBEEF }; + EXPECT_CALL(plugin, getAllPolicyIds()).WillRepeatedly(Return(std::vector{bogusId})); + EXPECT_CALL(plugin, getPolicyName(bogusId)).WillRepeatedly(Return(GOOD_POLICY_NAME)); + + EXPECT_THROW(HeuristicPlugin::validatePluginMetadata(plugin), HipdnnException); } -TEST_F(TestHeuristicPlugin, MockPluginVerifiesCallSequence) -{ - const NiceMock plugin; - - const auto handle = makeFakeHandle(42); - const auto descriptor = makeFakePolicyDescriptor(100); - - { - const ::testing::InSequence seq; - - EXPECT_CALL(plugin, createHandle()).WillOnce(Return(handle)); +// ========== Policy ID Buffer Validation ========== +// HeuristicPlugin::validatePolicyIdsBuffer is invoked from getAllPolicyIds() +// after the second (fetch) call into the plugin and gates the lazy enumeration +// cache. Tests below exercise it directly with raw buffers since +// MockHeuristicPlugin overrides getAllPolicyIds() entirely. - EXPECT_CALL(plugin, createPolicyDescriptor(handle)).WillOnce(Return(descriptor)); - - EXPECT_CALL(plugin, destroyPolicyDescriptor(descriptor)); - - EXPECT_CALL(plugin, destroyHandle(handle)); - } - - // Execute in expected order - const auto h = plugin.createHandle(); - const auto d = plugin.createPolicyDescriptor(h); - plugin.destroyPolicyDescriptor(d); - plugin.destroyHandle(h); -} - -// ========== Edge Cases Tests ========== - -TEST_F(TestHeuristicPlugin, MockPluginCanReturnNullHandle) +TEST_F(TestHeuristicPlugin, ValidatePolicyIdsBufferRejectsZeroPolicyCount) { - const NiceMock plugin; - - EXPECT_CALL(plugin, createHandle()).WillOnce(Return(nullptr)); - - const auto handle = plugin.createHandle(); - EXPECT_EQ(handle, nullptr); -} - -TEST_F(TestHeuristicPlugin, MockPluginCanReturnNullDescriptor) -{ - const NiceMock plugin; - - const auto handle = makeFakeHandle(42); - - EXPECT_CALL(plugin, createPolicyDescriptor(handle)).WillOnce(Return(nullptr)); - - const auto descriptor = plugin.createPolicyDescriptor(handle); - EXPECT_EQ(descriptor, nullptr); + std::vector ids; + EXPECT_THROW(HeuristicPlugin::validatePolicyIdsBuffer(0, 0, ids), HipdnnException); } -// ========== Policy ID Caching Tests ========== - -TEST_F(TestHeuristicPlugin, MockPluginPolicyIdCanBeCached) +TEST_F(TestHeuristicPlugin, ValidatePolicyIdsBufferRejectsCountMismatchBetweenQueries) { - const NiceMock plugin; - - const int64_t testPolicyId = 0xABCDEF; - - // First call should query the mock - EXPECT_CALL(plugin, policyId()).Times(2).WillRepeatedly(Return(testPolicyId)); - - // Multiple calls - const int64_t id1 = plugin.policyId(); - const int64_t id2 = plugin.policyId(); - - EXPECT_EQ(id1, testPolicyId); - EXPECT_EQ(id2, testPolicyId); + std::vector ids = {10, 20, 30}; + EXPECT_THROW(HeuristicPlugin::validatePolicyIdsBuffer(3, 2, ids), HipdnnException); } -// ========== Policy Name Edge Cases ========== - -TEST_F(TestHeuristicPlugin, MockPluginEmptyPolicyNameIsValid) +TEST_F(TestHeuristicPlugin, ValidatePolicyIdsBufferRejectsIntraPluginDuplicateIds) { - const NiceMock plugin; - - EXPECT_CALL(plugin, name()).WillOnce(Return("")); - - const auto name = plugin.name(); - EXPECT_TRUE(name.empty()); + std::vector ids = {42, 42}; + EXPECT_THROW(HeuristicPlugin::validatePolicyIdsBuffer(2, 2, ids), HipdnnException); } -TEST_F(TestHeuristicPlugin, MockPluginLongPolicyNameIsValid) +TEST_F(TestHeuristicPlugin, ValidatePolicyIdsBufferAcceptsValidUniqueIdsAndSorts) { - const NiceMock plugin; - - const std::string_view longName = "VeryLongPolicyNameThatExceedsTypicalLengthsButIsStillValid"; - EXPECT_CALL(plugin, name()).WillOnce(Return(longName)); - - const auto name = plugin.name(); - EXPECT_EQ(name, longName); + std::vector ids = {30, 10, 20}; + EXPECT_NO_THROW(HeuristicPlugin::validatePolicyIdsBuffer(3, 3, ids)); + EXPECT_EQ(ids, (std::vector{10, 20, 30})); } - -// Base class interface tests (name, version, type) are covered by loading tests -// since they delegate to virtual methods that can't be effectively tested with mocks diff --git a/projects/hipdnn/backend/tests/TestHeuristicPluginIntegration.cpp b/projects/hipdnn/backend/tests/TestHeuristicPluginIntegration.cpp deleted file mode 100644 index 6cc5beefe63..00000000000 --- a/projects/hipdnn/backend/tests/TestHeuristicPluginIntegration.cpp +++ /dev/null @@ -1,966 +0,0 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -/** - * @file TestHeuristicPluginIntegration.cpp - * @brief Integration tests for HeuristicPlugin workflow coverage - * - * These tests exercise full workflows with real plugins to improve coverage: - * - Complete handle and descriptor lifecycle - * - Device properties serialization and setting - * - Engine ID setting and finalization - * - Error handling paths - */ - -#include "HipdnnException.hpp" -#include "PlatformUtils.hpp" -#include "TestPluginConstants.hpp" -#include "plugin/HeuristicPlugin.hpp" -#include "plugin/HeuristicPluginManager.hpp" -#include "plugin/HeuristicPluginResourceManager.hpp" -#include "plugin/SharedLibrary.hpp" - -#include -#include -#include - -#include - -#include -#include - -using namespace hipdnn_backend; -using namespace hipdnn_backend::plugin; -using namespace hipdnn_backend::plugin_constants; - -namespace -{ -// Note: TEST_GOOD_HEURISTIC_PLUGIN_NAME, TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME, -// and TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME are defined as macros in CMakeLists.txt - -// Helper to serialize DevicePropertiesT using FlatBuffers Pack -std::vector - serializeDeviceProperties(const hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT& props) -{ - flatbuffers::FlatBufferBuilder builder(256); - auto offset = hipdnn_flatbuffers_sdk::data_objects::DeviceProperties::Pack(builder, &props); - builder.Finish(offset, "HDDP"); - return {builder.GetBufferPointer(), builder.GetBufferPointer() + builder.GetSize()}; -} -} // namespace - -class IntegrationHeuristicPlugin : public ::testing::Test -{ -protected: - void SetUp() override - { - // Set plugin path to test plugins directory - const auto testPluginDir = getHeuristicPluginPath("").parent_path(); - HeuristicPluginResourceManager::setHeuristicPluginPaths({testPluginDir}, - HIPDNN_PLUGIN_LOADING_ABSOLUTE); - } - - void TearDown() override - { - // Reset to default empty paths - HeuristicPluginResourceManager::setHeuristicPluginPaths({}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); - } -}; - -// ========== Complete Workflow Tests ========== - -TEST_F(IntegrationHeuristicPlugin, CompleteHandleLifecycleWithGoodPlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - // Should have loaded test plugins - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - // Find the good test plugin - const HeuristicPlugin* plugin = nullptr; - hipdnnHeuristicHandle_t handle = nullptr; - for(const auto& info : policyInfos) - { - plugin = rm->getPluginForPolicyId(info.policyId); - if(plugin != nullptr) - { - handle = rm->getHeuristicHandleForPolicyId(info.policyId); - break; - } - } - - ASSERT_NE(plugin, nullptr); - ASSERT_NE(handle, nullptr); - - // Verify plugin metadata is available - EXPECT_FALSE(plugin->version().empty()); -} - -TEST_F(IntegrationHeuristicPlugin, CompletePolicyDescriptorLifecycle) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - // Create policy descriptor - auto desc = plugin->createPolicyDescriptor(handle); - EXPECT_NE(desc, nullptr); - - // Destroy policy descriptor (should not throw) - EXPECT_NO_THROW(plugin->destroyPolicyDescriptor(desc)); -} - -TEST_F(IntegrationHeuristicPlugin, SetEngineIdsOnPolicyDescriptor) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - auto desc = plugin->createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Set engine IDs - const std::vector engineIds = {1, 2, 3}; - EXPECT_NO_THROW(plugin->setEngineIds(desc, engineIds.data(), engineIds.size())); - - plugin->destroyPolicyDescriptor(desc); -} - -TEST_F(IntegrationHeuristicPlugin, SetSerializedGraphOnPolicyDescriptor) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - auto desc = plugin->createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Create a simple serialized graph (just some bytes) - const std::vector graphBytes = {1, 2, 3, 4, 5}; - hipdnnPluginConstData_t serializedGraph; - serializedGraph.ptr = graphBytes.data(); - serializedGraph.size = graphBytes.size(); - - EXPECT_NO_THROW(plugin->setSerializedGraph(desc, &serializedGraph)); - - plugin->destroyPolicyDescriptor(desc); -} - -TEST_F(IntegrationHeuristicPlugin, FinalizeAndGetSortedEngineIds) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - auto desc = plugin->createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Set engine IDs - const std::vector engineIds = {100, 200, 300}; - plugin->setEngineIds(desc, engineIds.data(), engineIds.size()); - - // Finalize the policy - plugin->finalize(desc); - - // Get sorted IDs - const auto sortedIds = plugin->getSortedEngineIds(desc); - EXPECT_TRUE(sortedIds.empty() || !sortedIds.empty()); // May or may not apply - - plugin->destroyPolicyDescriptor(desc); -} - -TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesOnHandle) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - // Create device properties - hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; - props.device_id = 0; - props.multi_processor_count = 120; - props.total_global_mem = 16ULL * 1024 * 1024 * 1024; // 16 GB - props.architecture_name = "gfx90a"; - - // Serialize - auto serialized = serializeDeviceProperties(props); - hipdnnPluginConstData_t devicePropsData; - devicePropsData.ptr = serialized.data(); - devicePropsData.size = serialized.size(); - - // Set on handle (should not throw) - EXPECT_NO_THROW(plugin->setDeviceProperties(handle, &devicePropsData)); -} - -TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesOnAllHandles) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - // Create device properties - hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; - props.device_id = 0; - props.multi_processor_count = 120; - props.total_global_mem = 16ULL * 1024 * 1024 * 1024; - props.architecture_name = "gfx90a"; - - auto serialized = serializeDeviceProperties(props); - hipdnnPluginConstData_t devicePropsData; - devicePropsData.ptr = serialized.data(); - devicePropsData.size = serialized.size(); - - // Set on all handles via resource manager - EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&devicePropsData)); -} - -TEST_F(IntegrationHeuristicPlugin, CompleteWorkflowWithDevicePropertiesAndFinalize) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - // Set device properties on handle - hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; - props.device_id = 0; - props.multi_processor_count = 120; - props.total_global_mem = 16ULL * 1024 * 1024 * 1024; - props.architecture_name = "gfx90a"; - - auto serialized = serializeDeviceProperties(props); - hipdnnPluginConstData_t devicePropsData; - devicePropsData.ptr = serialized.data(); - devicePropsData.size = serialized.size(); - plugin->setDeviceProperties(handle, &devicePropsData); - - // Create policy descriptor - auto desc = plugin->createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Set engine IDs - const std::vector engineIds = {1000, 2000, 3000}; - plugin->setEngineIds(desc, engineIds.data(), engineIds.size()); - - // Set serialized graph - const std::vector graphBytes = {10, 20, 30}; - hipdnnPluginConstData_t serializedGraph; - serializedGraph.ptr = graphBytes.data(); - serializedGraph.size = graphBytes.size(); - plugin->setSerializedGraph(desc, &serializedGraph); - - // Finalize - plugin->finalize(desc); - - // Get results - const auto sortedIds = plugin->getSortedEngineIds(desc); - - // Clean up - plugin->destroyPolicyDescriptor(desc); -} - -// ========== Plugin Metadata Coverage ========== - -TEST_F(IntegrationHeuristicPlugin, GetPolicyNameFromLoadedPlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - for(const auto& info : policyInfos) - { - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(info.policyId); - if(plugin != nullptr) - { - // Should return policy name (may be empty for no-optional plugin) - const auto policyName = plugin->name(); - // Just verify it doesn't crash - EXPECT_TRUE(policyName.empty() || !policyName.empty()); - } - } -} - -TEST_F(IntegrationHeuristicPlugin, GetPluginVersionFromLoadedPlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - const auto pluginVersion = plugin->version(); - EXPECT_FALSE(pluginVersion.empty()); -} - -TEST_F(IntegrationHeuristicPlugin, GetApiVersionFromLoadedPlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - const auto apiVersion = plugin->apiVersion(); - EXPECT_FALSE(apiVersion.empty()); - EXPECT_NE(apiVersion.find("0."), std::string_view::npos); // Should be version 0.x -} - -TEST_F(IntegrationHeuristicPlugin, GetPluginTypeFromLoadedPlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - // Heuristic plugins report HEURISTIC type - const auto pluginType = plugin->type(); - EXPECT_EQ(pluginType, HIPDNN_PLUGIN_TYPE_HEURISTIC); -} - -// ========== Resource Manager Enumeration Coverage ========== - -TEST_F(IntegrationHeuristicPlugin, GetLoadedPluginFilesReturnsCorrectCount) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - size_t numPlugins = 0; - size_t maxStringLen = 0; - - rm->getLoadedPluginFiles(&numPlugins, nullptr, &maxStringLen); - - // Should have at least the test plugins - EXPECT_GT(numPlugins, 0u); -} - -TEST_F(IntegrationHeuristicPlugin, ToStringContainsPluginInformation) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto str = rm->toString(); - - EXPECT_NE(str.find("HeuristicPluginResourceManager"), std::string::npos); - EXPECT_NE(str.find("Loaded plugins:"), std::string::npos); -} - -// ========== Multiple Descriptors Per Handle ========== - -TEST_F(IntegrationHeuristicPlugin, MultipleDescriptorsFromSameHandle) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - // Create multiple descriptors from the same handle - auto desc1 = plugin->createPolicyDescriptor(handle); - auto desc2 = plugin->createPolicyDescriptor(handle); - auto desc3 = plugin->createPolicyDescriptor(handle); - - EXPECT_NE(desc1, nullptr); - EXPECT_NE(desc2, nullptr); - EXPECT_NE(desc3, nullptr); - - // Note: Test plugins may return the same hardcoded pointer for simplicity, - // but real plugins should return distinct descriptors. We just verify they're created. - - // Clean up all - plugin->destroyPolicyDescriptor(desc1); - plugin->destroyPolicyDescriptor(desc2); - plugin->destroyPolicyDescriptor(desc3); -} - -// ========== Error Path Tests ========== -// These tests exercise error handling and edge cases - -// ========== Error Path: Device Properties Exceptions ========== - -TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesHandlesPluginFailures) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - // Create invalid device properties (empty buffer) - hipdnnPluginConstData_t invalidProps; - invalidProps.ptr = nullptr; - invalidProps.size = 0; - - // Should not throw even if some plugins fail - logs warning and continues - EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&invalidProps)); -} - -// ========== Error Path: Missing Optional Functions ========== - -TEST_F(IntegrationHeuristicPlugin, PolicyNameReturnsEmptyWhenOptionalFunctionMissing) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - // Find the no-optional plugin which doesn't implement hipdnnHeuristicGetPolicyName - for(const auto& info : policyInfos) - { - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(info.policyId); - if(plugin != nullptr) - { - // policyName() should return empty string for plugins without the optional function - const auto name = plugin->name(); - // Either has a name or empty string (both valid) - EXPECT_TRUE(name.empty() || !name.empty()); - } - } -} - -TEST_F(IntegrationHeuristicPlugin, SetPluginLogLevelHandlesMissingOptionalFunction) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - // setPluginLogLevel should not throw even if optional function is missing - EXPECT_NO_THROW(rm->setPluginLogLevel(HIPDNN_SEV_INFO)); -} - -// ========== Error Path: Empty Engine IDs ========== - -TEST_F(IntegrationHeuristicPlugin, FinalizeWithEmptyEngineIdsSucceeds) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const HeuristicPlugin* plugin = rm->getPluginForPolicyId(policyInfos[0].policyId); - ASSERT_NE(plugin, nullptr); - - hipdnnHeuristicHandle_t handle = rm->getHeuristicHandleForPolicyId(policyInfos[0].policyId); - ASSERT_NE(handle, nullptr); - - auto desc = plugin->createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Don't set any engine IDs - just finalize - plugin->finalize(desc); - - // Get sorted IDs (should be empty) - const auto sortedIds = plugin->getSortedEngineIds(desc); - EXPECT_TRUE(sortedIds.empty()); - - plugin->destroyPolicyDescriptor(desc); -} - -// ========== Error Path: Multiple Policy Lookups (Same Handle/Plugin Reuse) ========== - -TEST_F(IntegrationHeuristicPlugin, MultipleGetHandleCallsReturnSameHandle) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const auto policyId = policyInfos[0].policyId; - - // Multiple calls should return the same handle (cached) - auto handle1 = rm->getHeuristicHandleForPolicyId(policyId); - auto handle2 = rm->getHeuristicHandleForPolicyId(policyId); - - EXPECT_EQ(handle1, handle2); - EXPECT_NE(handle1, nullptr); -} - -TEST_F(IntegrationHeuristicPlugin, MultipleGetPluginCallsReturnSamePlugin) -{ - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - const auto policyInfos = rm->getHeuristicPolicyInfos(); - ASSERT_FALSE(policyInfos.empty()); - - const auto policyId = policyInfos[0].policyId; - - // Multiple calls should return the same plugin pointer - const HeuristicPlugin* plugin1 = rm->getPluginForPolicyId(policyId); - const HeuristicPlugin* plugin2 = rm->getPluginForPolicyId(policyId); - - EXPECT_EQ(plugin1, plugin2); - EXPECT_NE(plugin1, nullptr); -} - -// ========== Error Path: No plugins loaded scenario ========== - -TEST_F(IntegrationHeuristicPlugin, SetDevicePropertiesWithNoPluginsLoaded) -{ - // Create RM with no plugins - HeuristicPluginResourceManager::setHeuristicPluginPaths({}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); - - auto rm = HeuristicPluginResourceManager::create(); - ASSERT_NE(rm, nullptr); - - hipdnn_flatbuffers_sdk::data_objects::DevicePropertiesT props; - props.device_id = 0; - props.multi_processor_count = 120; - props.total_global_mem = 16ULL * 1024 * 1024 * 1024; - props.architecture_name = "gfx90a"; - - auto serialized = serializeDeviceProperties(props); - hipdnnPluginConstData_t devicePropsData; - devicePropsData.ptr = serialized.data(); - devicePropsData.size = serialized.size(); - - // Should not throw when no plugins loaded - EXPECT_NO_THROW(rm->setDevicePropertiesOnAllHandles(&devicePropsData)); -} - -// ========== Plugin Loading Tests ========== -// These tests exercise loading real plugins and their functionality - -namespace -{ - -// Wrapper class to access protected constructor -class TestableHeuristicPlugin : public HeuristicPlugin -{ -public: - explicit TestableHeuristicPlugin(SharedLibrary&& lib) - : HeuristicPlugin(std::move(lib)) - { - } -}; - -} // anonymous namespace - -// Fixture that loads the good plugin for tests that need it -class TestHeuristicPluginLoadedGood : public ::testing::Test -{ -protected: - void SetUp() override - { - const auto pluginPath = getHeuristicPluginPath(TEST_GOOD_HEURISTIC_PLUGIN_NAME); - ASSERT_TRUE(std::filesystem::exists(pluginPath)) - << "Test plugin not found: " << pluginPath - << "\nMake sure test_plugins are built before running tests"; - - SharedLibrary lib(pluginPath); - _pluginPtr = std::make_unique(std::move(lib)); - } - - void TearDown() override - { - _pluginPtr.reset(); - } - - TestableHeuristicPlugin& plugin() - { - return *_pluginPtr; - } - -private: - std::unique_ptr _pluginPtr; -}; -TEST_F(IntegrationHeuristicPlugin, LoadGoodPluginSucceeds) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_GOOD_HEURISTIC_PLUGIN_NAME); - - ASSERT_TRUE(std::filesystem::exists(pluginPath)) - << "Test plugin not found: " << pluginPath - << "\nMake sure test_plugins are built before running tests"; - - // Load the plugin - SharedLibrary lib(pluginPath); - // NOLINTNEXTLINE(misc-const-correctness) - ASSERT_NO_THROW({ TestableHeuristicPlugin plugin(std::move(lib)); }); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanQueryApiVersion) -{ - const auto version = plugin().apiVersion(); - EXPECT_FALSE(version.empty()); - EXPECT_EQ(version, "0.0.1"); // HIPDNN_HEURISTIC_API_VERSION -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanQueryPolicyId) -{ - const auto policyId = plugin().policyId(); - const auto expectedId = hipdnn_data_sdk::utilities::engineNameToId("TestGoodHeuristicPolicy"); - EXPECT_EQ(policyId, expectedId); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanQueryPolicyName) -{ - const auto name = plugin().name(); - EXPECT_EQ(name, "TestGoodHeuristicPolicy"); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanQueryPluginVersion) -{ - const auto version = plugin().version(); - EXPECT_EQ(version, "1.0.0"); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanCreateAndDestroyHandle) -{ - hipdnnHeuristicHandle_t handle = nullptr; - ASSERT_NO_THROW({ handle = plugin().createHandle(); }); - EXPECT_NE(handle, nullptr); - - ASSERT_NO_THROW({ plugin().destroyHandle(handle); }); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanSetDeviceProperties) -{ - const auto handle = plugin().createHandle(); - ASSERT_NE(handle, nullptr); - - hipdnnPluginConstData_t deviceProps{}; - deviceProps.ptr = nullptr; - deviceProps.size = 0; - - ASSERT_NO_THROW({ plugin().setDeviceProperties(handle, &deviceProps); }); - - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanManagePolicyDescriptor) -{ - const auto handle = plugin().createHandle(); - ASSERT_NE(handle, nullptr); - - hipdnnHeuristicPolicyDescriptor_t desc = nullptr; - ASSERT_NO_THROW({ desc = plugin().createPolicyDescriptor(handle); }); - EXPECT_NE(desc, nullptr); - - ASSERT_NO_THROW({ plugin().destroyPolicyDescriptor(desc); }); - - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanSetEngineIds) -{ - const auto handle = plugin().createHandle(); - const auto desc = plugin().createPolicyDescriptor(handle); - - const std::vector engineIds = {1, 2, 3, 4, 5}; - ASSERT_NO_THROW({ plugin().setEngineIds(desc, engineIds.data(), engineIds.size()); }); - - plugin().destroyPolicyDescriptor(desc); - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanSetSerializedGraph) -{ - const auto handle = plugin().createHandle(); - const auto desc = plugin().createPolicyDescriptor(handle); - - hipdnnPluginConstData_t graphData{}; - graphData.ptr = nullptr; - graphData.size = 0; - - ASSERT_NO_THROW({ plugin().setSerializedGraph(desc, &graphData); }); - - plugin().destroyPolicyDescriptor(desc); - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanFinalizePolicy) -{ - const auto handle = plugin().createHandle(); - const auto desc = plugin().createPolicyDescriptor(handle); - - const std::vector engineIds = {1, 2, 3}; - plugin().setEngineIds(desc, engineIds.data(), engineIds.size()); - - bool applied = false; - ASSERT_NO_THROW({ applied = plugin().finalize(desc); }); - EXPECT_TRUE(applied); // Good plugin always applies - - plugin().destroyPolicyDescriptor(desc); - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCanGetSortedEngineIds) -{ - const auto handle = plugin().createHandle(); - const auto desc = plugin().createPolicyDescriptor(handle); - - const std::vector inputIds = {1, 2, 3, 4, 5}; - plugin().setEngineIds(desc, inputIds.data(), inputIds.size()); - plugin().finalize(desc); - - std::vector sortedIds; - ASSERT_NO_THROW({ sortedIds = plugin().getSortedEngineIds(desc); }); - - // Good plugin reverses the order - EXPECT_EQ(sortedIds.size(), inputIds.size()); - EXPECT_EQ(sortedIds, std::vector({5, 4, 3, 2, 1})); - - plugin().destroyPolicyDescriptor(desc); - plugin().destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, LoadedPluginCompleteWorkflow) -{ - // Create handle and descriptor - const auto handle = plugin().createHandle(); - ASSERT_NE(handle, nullptr); - - const auto desc = plugin().createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - // Set inputs - const std::vector inputIds = {10, 20, 30}; - plugin().setEngineIds(desc, inputIds.data(), inputIds.size()); - - hipdnnPluginConstData_t graphData{}; - graphData.ptr = nullptr; - graphData.size = 0; - plugin().setSerializedGraph(desc, &graphData); - - // Finalize and retrieve results - const bool applied = plugin().finalize(desc); - EXPECT_TRUE(applied); - - const auto sortedIds = plugin().getSortedEngineIds(desc); - EXPECT_EQ(sortedIds, std::vector({30, 20, 10})); - - // Cleanup - plugin().destroyPolicyDescriptor(desc); - plugin().destroyHandle(handle); -} -TEST_F(IntegrationHeuristicPlugin, LoadIncompletePluginThrowsException) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME); - - ASSERT_TRUE(std::filesystem::exists(pluginPath)) << "Test plugin not found: " << pluginPath; - - SharedLibrary lib(pluginPath); - - // Loading should fail during symbol resolution - EXPECT_THROW( - { - try - { - const TestableHeuristicPlugin plugin(std::move(lib)); - } - catch(const HipdnnException& e) - { - // Verify the exception contains expected error details - const std::string errorMsg(e.what()); - EXPECT_NE(errorMsg.find("HEURISTIC PLUGIN ABI INCOMPLETE"), std::string::npos); - EXPECT_NE(errorMsg.find("Missing required symbol"), std::string::npos); - // Error text uses SharedLibrary's weakly_canonical path; on Windows the string can - // differ in drive-letter case or separators from a fresh weakly_canonical(pluginPath). - const auto canonicalPath = std::filesystem::weakly_canonical(pluginPath); - static constexpr std::string_view K_PLUGIN_PREFIX{"Plugin: "}; - const auto prefixPos = errorMsg.find(K_PLUGIN_PREFIX); - ASSERT_NE(prefixPos, std::string::npos); - const auto pathStart = prefixPos + K_PLUGIN_PREFIX.size(); - const auto pathEnd = errorMsg.find('\n', pathStart); - ASSERT_NE(pathEnd, std::string::npos); - const std::filesystem::path pluginPathInMessage( - errorMsg.substr(pathStart, pathEnd - pathStart)); - EXPECT_TRUE( - hipdnn_data_sdk::utilities::pathCompEq(pluginPathInMessage, canonicalPath)) - << "pluginPathInMessage='" << pluginPathInMessage.string() - << "' canonicalPath='" << canonicalPath.string() << "'"; - throw; - } - }, - HipdnnException); -} -TEST_F(IntegrationHeuristicPlugin, IncompletePluginExceptionContainsSymbolName) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - - EXPECT_THROW( - { - try - { - const TestableHeuristicPlugin plugin(std::move(lib)); - } - catch(const HipdnnException& e) - { - const std::string errorMsg(e.what()); - // Should mention one of the missing required symbols - const bool hasPluginNameError - = errorMsg.find("hipdnnPluginGetName") != std::string::npos; - const bool hasFinalizeError - = errorMsg.find("hipdnnHeuristicPolicyFinalize") != std::string::npos; - const bool hasGetSortedError - = errorMsg.find("hipdnnHeuristicPolicyGetSortedEngineIds") != std::string::npos; - EXPECT_TRUE(hasPluginNameError || hasFinalizeError || hasGetSortedError); - throw; - } - }, - HipdnnException); -} -TEST_F(IntegrationHeuristicPlugin, IncompletePluginExceptionHasPluginErrorStatus) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - - EXPECT_THROW( - { - try - { - const TestableHeuristicPlugin plugin(std::move(lib)); - } - catch(const HipdnnException& e) - { - EXPECT_EQ(e.getStatus(), HIPDNN_STATUS_PLUGIN_ERROR); - throw; - } - }, - HipdnnException); -} -TEST_F(IntegrationHeuristicPlugin, LoadPluginWithoutOptionalSymbolsSucceeds) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); - - // NOLINTNEXTLINE(readability-implicit-bool-conversion) - ASSERT_TRUE(std::filesystem::exists(pluginPath)) << "Test plugin not found: " << pluginPath; - - SharedLibrary lib(pluginPath); - - // Should load successfully despite missing optional symbols - // NOLINTNEXTLINE(misc-const-correctness) - ASSERT_NO_THROW({ TestableHeuristicPlugin plugin(std::move(lib)); }); -} -TEST_F(IntegrationHeuristicPlugin, PluginWithoutOptionalPolicyNameHasName) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - const TestableHeuristicPlugin plugin(std::move(lib)); - - // GetPolicyName is now required - const auto name = plugin.name(); - EXPECT_FALSE(name.empty()); - EXPECT_EQ(name, "TestNoOptionalHeuristicPolicy"); -} -TEST_F(IntegrationHeuristicPlugin, PluginWithoutOptionalSetLogLevelSucceeds) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - const TestableHeuristicPlugin plugin(std::move(lib)); - - // Plugin doesn't implement hipdnnHeuristicSetLogLevel - // Should return SUCCESS without calling the function - const auto status = plugin.setLogLevel(HIPDNN_SEV_INFO); - EXPECT_EQ(status, HIPDNN_PLUGIN_STATUS_SUCCESS); -} -TEST_F(IntegrationHeuristicPlugin, PluginWithoutOptionalCanStillExecuteWorkflow) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - const TestableHeuristicPlugin plugin(std::move(lib)); - - // Full workflow should work despite missing optional functions - const auto handle = plugin.createHandle(); - ASSERT_NE(handle, nullptr); - - const auto desc = plugin.createPolicyDescriptor(handle); - ASSERT_NE(desc, nullptr); - - const std::vector inputIds = {1, 2, 3}; - plugin.setEngineIds(desc, inputIds.data(), inputIds.size()); - - const bool applied = plugin.finalize(desc); - EXPECT_FALSE(applied); // This plugin declines to apply - - const auto sortedIds = plugin.getSortedEngineIds(desc); - EXPECT_TRUE(sortedIds.empty()); // Returns empty list - - plugin.destroyPolicyDescriptor(desc); - plugin.destroyHandle(handle); -} -TEST_F(TestHeuristicPluginLoadedGood, RealPluginCachesPolicyId) -{ - // First call - ID is computed from policy name - const auto id1 = plugin().policyId(); - const auto expectedId = hipdnn_data_sdk::utilities::engineNameToId("TestGoodHeuristicPolicy"); - EXPECT_EQ(id1, expectedId); - - // Second call should return cached value - const auto id2 = plugin().policyId(); - EXPECT_EQ(id2, id1); -} -TEST_F(IntegrationHeuristicPlugin, GetSortedEngineIdsReturnsEmptyWhenNoEngines) -{ - const auto pluginPath = getHeuristicPluginPath(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME); - SharedLibrary lib(pluginPath); - const TestableHeuristicPlugin plugin(std::move(lib)); - - const auto handle = plugin.createHandle(); - const auto desc = plugin.createPolicyDescriptor(handle); - - // Don't set any engine IDs - plugin.finalize(desc); - - const auto sortedIds = plugin.getSortedEngineIds(desc); - EXPECT_TRUE(sortedIds.empty()); - - plugin.destroyPolicyDescriptor(desc); - plugin.destroyHandle(handle); -} diff --git a/projects/hipdnn/backend/tests/TestHeuristicPluginManager.cpp b/projects/hipdnn/backend/tests/TestHeuristicPluginManager.cpp index 9c5809b4e02..0904e29699a 100644 --- a/projects/hipdnn/backend/tests/TestHeuristicPluginManager.cpp +++ b/projects/hipdnn/backend/tests/TestHeuristicPluginManager.cpp @@ -3,7 +3,7 @@ /** * @file TestHeuristicPluginManager.cpp - * @brief Unit tests for HeuristicPluginManager validation logic (RFC 0007 Part 1) + * @brief Unit tests for HeuristicPluginManager validation logic * * These tests verify the plugin discovery and validation layer including: * - API version compatibility validation @@ -15,6 +15,9 @@ #include "plugin/HeuristicPlugin.hpp" #include "plugin/HeuristicPluginManager.hpp" +#include +#include + #include #include @@ -34,7 +37,7 @@ class TestHeuristicPluginManager : public ::testing::Test // Helper to create a valid policy name/ID pair static std::pair makeValidPolicyPair(const std::string& baseName) { - const int64_t policyId = hipdnn_data_sdk::utilities::engineNameToId(baseName); + const int64_t policyId = hipdnn_data_sdk::utilities::policyNameToId(baseName); return {baseName, policyId}; } }; @@ -46,31 +49,18 @@ TEST_F(TestHeuristicPluginManager, ConstructorSucceeds) EXPECT_NO_THROW(const HeuristicPluginManager manager); } -TEST_F(TestHeuristicPluginManager, ConstructorInitializesSearchPaths) -{ - const HeuristicPluginManager manager; - - // Manager should be created (implementation uses default search paths) - // We can't easily test internal state, but construction should succeed - SUCCEED(); -} - // ========== Plugin Loading Tests ========== TEST_F(TestHeuristicPluginManager, LoadPluginsFromEmptyDirectorySucceeds) { HeuristicPluginManager manager; - // Create a temporary empty directory - const std::filesystem::path emptyDir - = std::filesystem::temp_directory_path() / "hipdnn_test_empty"; - std::filesystem::create_directories(emptyDir); - - // Should not throw when loading from empty directory - EXPECT_NO_THROW(manager.loadPlugins({emptyDir}, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); + const auto uniqueName = std::string("hipdnn_test_empty_") + + std::to_string(::testing::UnitTest::GetInstance()->random_seed()); + const hipdnn_test_sdk::utilities::ScopedDirectory emptyDir( + std::filesystem::temp_directory_path() / uniqueName); - // Cleanup - std::filesystem::remove(emptyDir); + EXPECT_NO_THROW(manager.loadPlugins({emptyDir.path()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); } TEST_F(TestHeuristicPluginManager, LoadPluginsFromNonexistentDirectorySucceeds) @@ -97,20 +87,9 @@ TEST_F(TestHeuristicPluginManager, LoadPluginsWithMultiplePathsSucceeds) EXPECT_NO_THROW(manager.loadPlugins(paths, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); } -// ========== Validation Tests - API Version ========== - // Note: Actual validation happens inside validateBeforeAdding() which is protected. -// We test it indirectly through loadPlugins() with real plugin files, but those -// tests are in integration tests. Here we verify the manager's structure supports validation. - -TEST_F(TestHeuristicPluginManager, ManagerSupportsValidation) -{ - const HeuristicPluginManager manager; - - // The manager should be constructed with validation capabilities - // This is a structural test - actual validation tested via integration tests - SUCCEED(); -} +// It is exercised end-to-end in TestHeuristicPluginManagerValidationPaths.cpp using +// real test plugins. // ========== Multiple Load Cycles Tests ========== @@ -154,8 +133,10 @@ TEST_F(TestHeuristicPluginManager, MultipleInstancesAreIndependent) manager2.loadPlugins({std::filesystem::temp_directory_path() / "path2"}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); - // Both should work independently - SUCCEED(); + // Only the always-registered Config + StaticOrdering built-ins remain; no + // external plugin loaded from a non-existent path. + EXPECT_EQ(manager1.getPlugins().size(), 2u); + EXPECT_EQ(manager2.getPlugins().size(), 2u); } // ========== Edge Cases Tests ========== @@ -181,226 +162,99 @@ TEST_F(TestHeuristicPluginManager, LoadPluginsWithSamePathTwiceSucceeds) EXPECT_NO_THROW(manager.loadPlugins(paths, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); } -// ========== Destructor Tests ========== +// ========== Plugin Loading Mode Tests ========== -TEST_F(TestHeuristicPluginManager, DestructorCleansUpResources) +TEST_F(TestHeuristicPluginManager, LoadPluginsWithAbsoluteModeResets) { - { - HeuristicPluginManager manager; - manager.loadPlugins({std::filesystem::temp_directory_path() / "test"}, - HIPDNN_PLUGIN_LOADING_ABSOLUTE); - } // manager destroyed here - - // If we get here without crashes, cleanup succeeded - SUCCEED(); -} - -TEST_F(TestHeuristicPluginManager, MultipleDestructionsSucceed) -{ - for(int i = 0; i < 10; ++i) - { - HeuristicPluginManager manager; - manager.loadPlugins({std::filesystem::temp_directory_path() / "test"}, - HIPDNN_PLUGIN_LOADING_ABSOLUTE); - // Destroyed at end of loop - } - - SUCCEED(); -} - -// ========== Policy ID Tracking Tests ========== - -TEST_F(TestHeuristicPluginManager, ManagerTracksLoadedPolicies) -{ - const HeuristicPluginManager manager; - - // After construction, no policies should be loaded - // This verifies the manager maintains internal state for policy tracking - // Actual policy ID validation tested via integration tests with real plugins - SUCCEED(); -} + HeuristicPluginManager manager; -// ========== Search Path Tests ========== + const std::set paths1 + = {std::filesystem::temp_directory_path() / "path1"}; + const std::set paths2 + = {std::filesystem::temp_directory_path() / "path2"}; -TEST_F(TestHeuristicPluginManager, DefaultSearchPathsAreUsed) -{ - const HeuristicPluginManager manager; + manager.loadPlugins(paths1, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const size_t count1 = manager.getPlugins().size(); - // Manager should initialize with default search paths - // (implementation detail - verified indirectly) - SUCCEED(); -} + manager.loadPlugins(paths2, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const size_t count2 = manager.getPlugins().size(); -TEST_F(TestHeuristicPluginManager, EnvironmentVariablePathsAreSupported) -{ - // The manager uses getPluginSearchPaths() which checks HIPDNN_HEURISTIC_PLUGIN_DIR - // This is tested indirectly through construction - const HeuristicPluginManager manager; - SUCCEED(); + // Both should be 0 since paths don't exist, but operation should succeed + EXPECT_EQ(count1, count2); } -using namespace hipdnn_backend::plugin; - -class TestHeuristicPluginManagerValidation : public ::testing::Test -{ -protected: - void SetUp() override {} - void TearDown() override {} - - // Helper to create a valid policy name/ID pair - static std::pair makeValidPolicyPair(const std::string& baseName) - { - const int64_t policyId = hipdnn_data_sdk::utilities::engineNameToId(baseName); - return {baseName, policyId}; - } -}; -// ========== API Version Validation Tests ========== - -TEST_F(TestHeuristicPluginManager, ValidApiVersionAccepted) +TEST_F(TestHeuristicPluginManager, LoadPluginsWithAdditiveModeAdds) { - // This test verifies that plugins with matching API version are accepted - // We can't easily inject a mock plugin into the real manager without actual .so files, - // so this is more of a structural test - const HeuristicPluginManager manager; + HeuristicPluginManager manager; - // If construction succeeds, validation infrastructure is in place - SUCCEED(); -} + const std::set paths1 + = {std::filesystem::temp_directory_path() / "path1"}; + const std::set paths2 + = {std::filesystem::temp_directory_path() / "path2"}; -TEST_F(TestHeuristicPluginManager, ManagerConstructorSucceeds) -{ - // Verify manager can be constructed - auto manager = std::make_shared(); - ASSERT_NE(manager, nullptr); + // Exercise the additive code path. With non-existent paths the plugin count + // stays at 0, but the additive branch in PluginManagerBase::loadPlugins is + // executed and any leak/crash inside it would be caught under ASAN. + EXPECT_NO_THROW(manager.loadPlugins(paths1, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); + EXPECT_NO_THROW(manager.loadPlugins(paths2, HIPDNN_PLUGIN_LOADING_ADDITIVE)); } -TEST_F(TestHeuristicPluginManager, ManagerUsesHeuristicPluginSearchPaths) -{ - // Verify the manager initializes with heuristic-specific search paths - const HeuristicPluginManager manager; +// Destructor coverage with real loaded plugins lives in +// TestHeuristicPluginManagerValidationPaths.cpp::DestructorUnloadsLoadedPluginLibraries, +// where _testPluginPath provides actual shared libraries to exercise the unload path. - // Manager should have been constructed with HIPDNN_HEURISTIC_PLUGIN_DIR paths - SUCCEED(); -} +// ========== Policy ID/Name Validation Tests ========== -// ========== Policy ID Uniqueness Tests ========== - -TEST_F(TestHeuristicPluginManager, ValidPolicyIdNamePairStructure) +TEST_F(TestHeuristicPluginManager, PolicyNameToIdIsConsistent) { - // Test the helper function that creates valid pairs - const auto [name, id] = makeValidPolicyPair("TestPolicy"); - - EXPECT_FALSE(name.empty()); - EXPECT_NE(id, 0); + const std::string name1 = "Vendor::PolicyA"; + const int64_t id1a = hipdnn_data_sdk::utilities::policyNameToId(name1); + const int64_t id1b = hipdnn_data_sdk::utilities::policyNameToId(name1); - // Verify consistency - const int64_t computedId = hipdnn_data_sdk::utilities::engineNameToId(name); - EXPECT_EQ(computedId, id); + EXPECT_EQ(id1a, id1b); } TEST_F(TestHeuristicPluginManager, DifferentNamesProduceDifferentIds) { - const auto [name1, id1] = makeValidPolicyPair("Policy1"); - const auto [name2, id2] = makeValidPolicyPair("Policy2"); + const std::string name1 = "Vendor::PolicyA"; + const std::string name2 = "SelectionHeuristic::StaticOrdering"; + + const int64_t id1 = hipdnn_data_sdk::utilities::policyNameToId(name1); + const int64_t id2 = hipdnn_data_sdk::utilities::policyNameToId(name2); - EXPECT_NE(name1, name2); EXPECT_NE(id1, id2); } -// ========== Policy ID/Name Consistency Tests ========== - -TEST_F(TestHeuristicPluginManager, EngineNameToIdIsConsistent) +TEST_F(TestHeuristicPluginManager, PolicyIdIsNonZero) { - const std::string policyName = "SelectionHeuristic::Config"; - const int64_t id1 = hipdnn_data_sdk::utilities::engineNameToId(policyName); - const int64_t id2 = hipdnn_data_sdk::utilities::engineNameToId(policyName); + const std::string name = "Vendor::PolicyA"; + const int64_t id = hipdnn_data_sdk::utilities::policyNameToId(name); - EXPECT_EQ(id1, id2); // Same name should produce same ID + EXPECT_NE(id, 0); } TEST_F(TestHeuristicPluginManager, EmptyPolicyNameProducesZeroId) { const std::string emptyName; - const int64_t id = hipdnn_data_sdk::utilities::engineNameToId(emptyName); + const int64_t id = hipdnn_data_sdk::utilities::policyNameToId(emptyName); - // Empty string should produce a specific ID (likely 0 or a hash of empty string) + // Empty string should produce a specific ID (FNV-1a hash of empty string) EXPECT_EQ(id, 0); } -// ========== Multiple Manager Instances Tests ========== +// ========== State Verification Tests ========== -TEST_F(TestHeuristicPluginManager, MultipleManagersAreIndependent) -{ - auto manager1 = std::make_shared(); - auto manager2 = std::make_shared(); - - EXPECT_NE(manager1, manager2); - - // Both should be functional - EXPECT_NE(manager1, nullptr); - EXPECT_NE(manager2, nullptr); -} - -TEST_F(TestHeuristicPluginManager, ManagerDestructionSucceeds) -{ - { - const HeuristicPluginManager manager; - // Use the manager - const auto& plugins = manager.getPlugins(); - EXPECT_TRUE(plugins.empty() || !plugins.empty()); // Always true, just use it - } // Manager destroyed here - - SUCCEED(); -} - -// ========== Additional Search Path Tests ========== - -TEST_F(TestHeuristicPluginManager, LoadPluginsFromEmptyPathsSucceeds) +TEST_F(TestHeuristicPluginManager, GetPluginsAfterEmptyLoadReturnsEmpty) { HeuristicPluginManager manager; - const std::set emptyPaths; - EXPECT_NO_THROW(manager.loadPlugins(emptyPaths, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); -} - -TEST_F(TestHeuristicPluginManager, LoadPluginsWithNonexistentPathSucceeds) -{ - HeuristicPluginManager manager; - - const std::set paths - = {std::filesystem::temp_directory_path() / "nonexistent_path_to_plugins"}; - EXPECT_NO_THROW(manager.loadPlugins(paths, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); -} - -TEST_F(TestHeuristicPluginManager, MultipleLoadCallsSucceed) -{ - HeuristicPluginManager manager; - - const std::set paths1 - = {std::filesystem::temp_directory_path() / "plugins1"}; - const std::set paths2 - = {std::filesystem::temp_directory_path() / "plugins2"}; - - EXPECT_NO_THROW(manager.loadPlugins(paths1, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); - EXPECT_NO_THROW(manager.loadPlugins(paths2, HIPDNN_PLUGIN_LOADING_ADDITIVE)); -} - -// ========== Plugin Enumeration Tests ========== - -TEST_F(TestHeuristicPluginManager, GetPluginsWhenNoneLoaded) -{ - const HeuristicPluginManager manager; - - const auto& plugins = manager.getPlugins(); - EXPECT_TRUE(plugins.empty()); -} - -TEST_F(TestHeuristicPluginManager, GetPluginsIsConsistent) -{ - const HeuristicPluginManager manager; - - const auto& plugins1 = manager.getPlugins(); - const auto& plugins2 = manager.getPlugins(); + const auto uniqueName = std::string("hipdnn_empty_load_test_") + + std::to_string(::testing::UnitTest::GetInstance()->random_seed()); + const hipdnn_test_sdk::utilities::ScopedDirectory emptyDir( + std::filesystem::temp_directory_path() / uniqueName); - EXPECT_EQ(plugins1.size(), plugins2.size()); + manager.loadPlugins({emptyDir.path()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + // Only the always-registered Config + StaticOrdering built-ins remain; the + // empty dir contributed nothing. + EXPECT_EQ(manager.getPlugins().size(), 2u); } diff --git a/projects/hipdnn/backend/tests/TestHeuristicPluginManagerValidationPaths.cpp b/projects/hipdnn/backend/tests/TestHeuristicPluginManagerValidationPaths.cpp new file mode 100644 index 00000000000..683461acbbc --- /dev/null +++ b/projects/hipdnn/backend/tests/TestHeuristicPluginManagerValidationPaths.cpp @@ -0,0 +1,535 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestHeuristicPluginManagerValidationPaths.cpp + * @brief Tests for HeuristicPluginManager validation code paths + * + * These tests load actual test plugins to exercise validateBeforeAdding() and + * actionAfterAdding() to improve coverage of HeuristicPluginManager.hpp + */ + +#include "HipdnnException.hpp" +#include "PlatformUtils.hpp" +#include "TestPluginConstants.hpp" +#include "plugin/HeuristicPluginManager.hpp" + +#include +#include +#include +#include +#include + +// Test plugin name constants (defined here because CMake ordering prevents proper macro propagation) +namespace +{ +constexpr const char* BAD_API_VERSION_PLUGIN = "test_bad_api_version_heuristic_plugin"; +constexpr const char* EMPTY_NAME_PLUGIN = "test_empty_name_heuristic_plugin"; +constexpr const char* DUPLICATE_POLICY_ID_A_PLUGIN = "test_duplicate_policy_id_a_plugin"; +constexpr const char* DUPLICATE_POLICY_ID_B_PLUGIN = "test_duplicate_policy_id_b_plugin"; +} // namespace + +using namespace hipdnn_backend; +using namespace hipdnn_backend::plugin; +using namespace hipdnn_backend::plugin_constants; + +class TestHeuristicPluginManagerValidationPaths : public ::testing::Test +{ +protected: + static void SetUpTestSuite() + { + // Check once if test plugins are available + const auto pluginPath = getHeuristicPluginPath("").parent_path(); + if(!std::filesystem::exists(pluginPath)) + { + GTEST_SKIP() << "Test plugins not found at: " << pluginPath + << "\nMake sure test_plugins are built before running tests"; + } + } + + void SetUp() override + { + // Test plugins are in lib/test_plugins/custom relative to backend library location + _testPluginPath = getHeuristicPluginPath("").parent_path(); + + // Create manager for each test + _manager = std::make_unique(); + } + + std::unique_ptr _manager; + std::filesystem::path _testPluginPath; +}; + +// ========== validateBeforeAdding() Tests - API Version Check ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, GoodPluginPassesApiVersionValidation) +{ + // Load good test plugin - should pass API version validation + EXPECT_NO_THROW(_manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); + + const auto& plugins = _manager->getPlugins(); + + // Should have loaded at least the good plugin + EXPECT_GT(plugins.size(), 0); + + // All loaded plugins should have correct API major version + for(const auto& plugin : plugins) + { + const auto version = hipdnn_data_sdk::utilities::Version{plugin->apiVersion()}; + EXPECT_EQ(version.major, HIPDNN_HEURISTIC_API_VERSION_MAJOR) + << "validateBeforeAdding should have checked API version for plugin: " + << plugin->name(); + } +} + +// ========== validateBeforeAdding() Tests - Policy ID Uniqueness ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ActionAfterAddingStoresPolicyIds) +{ + // Load plugins - actionAfterAdding should store policy IDs + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const auto& plugins = _manager->getPlugins(); + + // Collect all policy IDs to verify actionAfterAdding was called + std::set policyIds; + size_t totalPolicyCount = 0; + for(const auto& plugin : plugins) + { + for(const int64_t id : plugin->getAllPolicyIds()) + { + ++totalPolicyCount; + + // Each policy ID should be unique (actionAfterAdding should have tracked this) + EXPECT_EQ(policyIds.count(id), 0) << "Policy ID " << id + << " appears multiple times (actionAfterAdding " + "tracking failed)"; + + policyIds.insert(id); + } + } + + // Should have loaded plugins with unique policy IDs + EXPECT_GT(policyIds.size(), 0) << "Should have loaded at least one plugin policy"; + EXPECT_EQ(policyIds.size(), totalPolicyCount) << "All policy IDs should be unique"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, PolicyIdTrackingAcrossMultiplePlugins) +{ + // Load all available plugins + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + + // Collect all policy IDs from every loaded plugin/policy + std::set policyIds; + size_t totalPolicyCount = 0; + for(const auto& plugin : plugins) + { + for(const int64_t id : plugin->getAllPolicyIds()) + { + ++totalPolicyCount; + + // Each policy ID should be unique (actionAfterAdding tracks this) + EXPECT_EQ(policyIds.count(id), 0) + << "Policy ID " << id << " appears multiple times (validateBeforeAdding failed)"; + + policyIds.insert(id); + } + } + + // Should have as many unique IDs as policies across all plugins + EXPECT_EQ(policyIds.size(), totalPolicyCount); +} + +// ========== validateBeforeAdding() Tests - Policy Name Check ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, AllLoadedPluginsHaveNonEmptyNames) +{ + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + + // validateBeforeAdding should have rejected any plugins with empty names + for(const auto& plugin : plugins) + { + EXPECT_FALSE(plugin->name().empty()) + << "validateBeforeAdding should reject plugins with empty policy names"; + } +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, PolicyNameIsProvided) +{ + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + EXPECT_GT(plugins.size(), 0); + + for(const auto& plugin : plugins) + { + // Plugin name must be non-empty (validated in HeuristicPluginManager) + const std::string pluginName(plugin->name()); + EXPECT_FALSE(pluginName.empty()) << "Plugin has empty name"; + + // Each policy must have a non-empty name (validated eagerly in HeuristicPlugin) + for(const int64_t policyId : plugin->getAllPolicyIds()) + { + const std::string policyName(plugin->getPolicyName(policyId)); + EXPECT_FALSE(policyName.empty()) << "Policy ID " << policyId << " has empty name"; + } + } +} + +// ========== Validation Success Path Tests ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ValidPluginPassesAllValidation) +{ + // Load should succeed for valid plugins + EXPECT_NO_THROW(_manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); + + const auto& plugins = _manager->getPlugins(); + + for(const auto& plugin : plugins) + { + // API version check + const auto version = hipdnn_data_sdk::utilities::Version{plugin->apiVersion()}; + EXPECT_EQ(version.major, HIPDNN_HEURISTIC_API_VERSION_MAJOR); + + // Plugin name check + EXPECT_FALSE(plugin->name().empty()); + + // Each policy ID should be non-zero + for(const int64_t policyId : plugin->getAllPolicyIds()) + { + EXPECT_NE(policyId, 0); + } + } +} + +// ========== actionAfterAdding() Coverage Tests ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ActionAfterAddingExecutesForEachPlugin) +{ + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + + // For each policy across all plugins, actionAfterAdding should have inserted the + // policy ID into _policyIds set. Verified indirectly by ensuring no duplicates exist. + std::set observedIds; + for(const auto& plugin : plugins) + { + for(const int64_t id : plugin->getAllPolicyIds()) + { + EXPECT_EQ(observedIds.count(id), 0) + << "Duplicate policy ID detected - actionAfterAdding may have failed"; + observedIds.insert(id); + } + } +} + +// ========== Multiple Load Cycles with Validation ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ValidationRunsOnEachLoadCycle) +{ + // Create new manager for each load to ensure fresh state + HeuristicPluginManager manager1; // NOLINT(misc-const-correctness) + manager1.loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const size_t count1 = manager1.getPlugins().size(); + + // Second manager should also run validation and load same plugins + HeuristicPluginManager manager2; // NOLINT(misc-const-correctness) + manager2.loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const size_t count2 = manager2.getPlugins().size(); + + EXPECT_EQ(count1, count2) << "Both managers should validate and load same plugins"; + EXPECT_GT(count1, 0) << "Should have loaded at least one plugin"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, AbsoluteReloadResetsPolicyIdTracking) +{ + // Regression: ABSOLUTE-mode reload must clear the derived-class policy-id + // index, not just _plugins. Otherwise reloading a plugin whose policy id + // matches one from the previous load triggers a false "already exists" + // failure in validateBeforeAdding. + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + const size_t firstCount = _manager->getPlugins().size(); + ASSERT_GT(firstCount, 0u) << "Test precondition: at least one plugin must load"; + + EXPECT_NO_THROW(_manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE)); + EXPECT_EQ(_manager->getPlugins().size(), firstCount); +} + +// ========== Constructor Path Coverage ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ConstructorSetsUpValidationInfrastructure) +{ + // Constructor initializes with search paths and empty _policyIds set. + // Assert against the freshly-constructed local manager (not the fixture's), + // otherwise the test is just re-checking SetUp's invariant. + const HeuristicPluginManager manager; + + // A freshly-constructed manager always contains the Config + StaticOrdering + // built-ins (registered in HeuristicPluginManager's constructor); nothing else yet. + EXPECT_EQ(manager.getPlugins().size(), 2u); +} + +// ========== Destructor Path Coverage ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, DestructorUnloadsLoadedPluginLibraries) +{ + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + ASSERT_FALSE(_manager->getPlugins().empty()) + << "Test precondition: at least one plugin must load to exercise the unload path"; + + // Destruction triggers SharedLibrary teardown for each loaded plugin + // (dlclose / FreeLibrary). ASAN catches any leak in plugin-side static teardown. + EXPECT_NO_THROW(_manager.reset()); + EXPECT_EQ(_manager, nullptr); +} + +// ========== Integration with PluginManagerBase ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, ValidationIntegratesWithBaseClass) +{ + // PluginManagerBase calls validateBeforeAdding before adding each plugin + // and actionAfterAdding after successful add + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + + // All plugins should have passed validation + for(const auto& plugin : plugins) + { + // These checks verify that validateBeforeAdding was called and passed + EXPECT_FALSE(plugin->name().empty()); + + const auto version = hipdnn_data_sdk::utilities::Version{plugin->apiVersion()}; + EXPECT_EQ(version.major, HIPDNN_HEURISTIC_API_VERSION_MAJOR); + + for(const int64_t policyId : plugin->getAllPolicyIds()) + { + EXPECT_NE(policyId, 0); + } + } +} + +// ========== Specific Test Plugin Tests ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, NoOptionalHeuristicPluginPassesValidation) +{ + // test_no_optional_heuristic_plugin doesn't implement optional functions + // but should still pass validation + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + ASSERT_FALSE(plugins.empty()) << "Expected at least one plugin to load from " + << _testPluginPath; + + bool foundNoOptional = false; + for(const auto& plugin : plugins) + { + const std::string name(plugin->name()); + if(name.find("NoOptional") != std::string::npos) + { + foundNoOptional = true; + // Should pass all validation checks + EXPECT_FALSE(name.empty()); + for(const int64_t policyId : plugin->getAllPolicyIds()) + { + EXPECT_NE(policyId, 0); + } + } + } + EXPECT_TRUE(foundNoOptional) << "test_no_optional_heuristic_plugin should be loaded"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, GoodHeuristicPluginPassesValidation) +{ + _manager->loadPlugins({_testPluginPath}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + const auto& plugins = _manager->getPlugins(); + ASSERT_FALSE(plugins.empty()) << "Expected at least one plugin to load from " + << _testPluginPath; + + // Match the exact good-plugin name. A loose substring match would pass even if + // the good plugin failed to load. + constexpr const char* K_GOOD_PLUGIN_NAME = "TestGoodHeuristicPlugin"; + + bool foundGood = false; + for(const auto& plugin : plugins) + { + const std::string name(plugin->name()); + if(name == K_GOOD_PLUGIN_NAME) + { + foundGood = true; + // Verify it passed all validation: + // 1. API version + const auto version = hipdnn_data_sdk::utilities::Version{plugin->apiVersion()}; + EXPECT_EQ(version.major, HIPDNN_HEURISTIC_API_VERSION_MAJOR); + + // 2. Plugin name is non-empty + EXPECT_FALSE(name.empty()); + + // 3. All policy IDs are unique and non-zero + for(const int64_t policyId : plugin->getAllPolicyIds()) + { + EXPECT_NE(policyId, 0); + } + } + } + EXPECT_TRUE(foundGood) << "test_good_heuristic_plugin should be loaded"; +} + +// ========== Validation Failure Tests ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, BadApiVersionPluginRejected) +{ + // ABSOLUTE mode accepts a single plugin file path, so we can load just the bad + // plugin directly from the build tree instead of copying it to a temp dir. + const auto badPlugin + = _testPluginPath / hipdnn_data_sdk::utilities::getLibraryName(BAD_API_VERSION_PLUGIN); + + // Without this precondition, loadPlugins silently no-ops on a missing file + // and the empty-plugins assertion below would pass vacuously. + ASSERT_TRUE(std::filesystem::exists(badPlugin)) + << "Test precondition: bad-API-version plugin missing at " << badPlugin; + + _manager->loadPlugins({badPlugin}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Built-in Config + StaticOrdering are always present; no external plugin should have loaded. + EXPECT_EQ(_manager->getPlugins().size(), 2u) << "Bad API version plugin should be rejected"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, EmptyNamePluginRejected) +{ + const auto emptyNamePlugin + = _testPluginPath / hipdnn_data_sdk::utilities::getLibraryName(EMPTY_NAME_PLUGIN); + + ASSERT_TRUE(std::filesystem::exists(emptyNamePlugin)) + << "Test precondition: empty-name plugin missing at " << emptyNamePlugin; + + _manager->loadPlugins({emptyNamePlugin}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Built-in Config + StaticOrdering are always present; no external plugin should have loaded. + EXPECT_EQ(_manager->getPlugins().size(), 2u) << "Empty policy name plugin should be rejected"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, DuplicatePolicyIdPluginsRejected) +{ + const auto pluginA = _testPluginPath + / hipdnn_data_sdk::utilities::getLibraryName(DUPLICATE_POLICY_ID_A_PLUGIN); + const auto pluginB = _testPluginPath + / hipdnn_data_sdk::utilities::getLibraryName(DUPLICATE_POLICY_ID_B_PLUGIN); + + if(!std::filesystem::exists(pluginA) || !std::filesystem::exists(pluginB)) + { + GTEST_SKIP() << "test_duplicate_policy_id plugins not found"; + } + + _manager->loadPlugins({pluginA, pluginB}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Built-in Config + StaticOrdering are always present, plus the first of the + // duplicate pair (pluginA). The second (pluginB) is rejected for duplicate policy ID. + const auto& plugins = _manager->getPlugins(); + ASSERT_EQ(plugins.size(), 3u) << "Built-ins + first duplicate plugin should be present"; + + // The survivor must be pluginA (first offered). Probe pluginA on its own to + // capture its policy IDs, then verify those IDs appear in the loaded set — + // without this, the size check would still pass if pluginA was rejected for + // an unrelated reason and pluginB loaded. + HeuristicPluginManager probeA; + probeA.loadPlugins({pluginA}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + ASSERT_EQ(probeA.getPlugins().size(), 3u) + << "pluginA should load successfully alongside the built-ins to be a valid baseline"; + + // Find the non-built-in plugin in the probe to get pluginA's policy IDs. + std::vector pluginAPolicyIds; + for(const auto& plugin : probeA.getPlugins()) + { + const std::string name(plugin->name()); + if(name != "BuiltInStaticOrderingHeuristic" && name != "BuiltInConfigHeuristic") + { + pluginAPolicyIds = plugin->getAllPolicyIds(); + break; + } + } + ASSERT_FALSE(pluginAPolicyIds.empty()) << "Probe failed to identify pluginA"; + + bool foundPluginA = false; + for(const auto& plugin : plugins) + { + if(plugin->getAllPolicyIds() == pluginAPolicyIds) + { + foundPluginA = true; + break; + } + } + EXPECT_TRUE(foundPluginA) << "Survivor should be pluginA (the first offered), not pluginB"; +} + +// ========== loadPluginFromFile Return-Value Regression ========== + +// Expose loadPluginFromFile() so we can directly observe its bool return. +// The bug it guards against: success was set to true at the top of the +// tryCatch lambda, so a throwing validateBeforeAdding() (e.g. bad API +// version) left success == true and the caller's failedCount silently +// stayed at zero. +class LoadPluginFromFileProbe : public HeuristicPluginManager +{ +public: + using HeuristicPluginManager::loadPluginFromFile; +}; + +TEST_F(TestHeuristicPluginManagerValidationPaths, LoadPluginFromFileReturnsFalseOnValidationFailure) +{ + const auto badPlugin + = _testPluginPath / hipdnn_data_sdk::utilities::getLibraryName(BAD_API_VERSION_PLUGIN); + ASSERT_TRUE(std::filesystem::exists(badPlugin)) + << "Test precondition: bad-API-version plugin missing at " << badPlugin; + + LoadPluginFromFileProbe probe; + const size_t pluginCountBefore = probe.getPlugins().size(); + + EXPECT_FALSE(probe.loadPluginFromFile(badPlugin)) + << "loadPluginFromFile must report failure when validateBeforeAdding throws"; + EXPECT_EQ(probe.getPlugins().size(), pluginCountBefore) + << "Rejected plugin must not be appended to _plugins"; +} + +TEST_F(TestHeuristicPluginManagerValidationPaths, LoadPluginFromFileReturnsTrueOnSuccess) +{ + const auto goodPlugin = getHeuristicPluginPath("test_good_heuristic_plugin"); + ASSERT_TRUE(std::filesystem::exists(goodPlugin)) + << "Test precondition: good plugin missing at " << goodPlugin; + + LoadPluginFromFileProbe probe; + const size_t pluginCountBefore = probe.getPlugins().size(); + + EXPECT_TRUE(probe.loadPluginFromFile(goodPlugin)); + EXPECT_EQ(probe.getPlugins().size(), pluginCountBefore + 1u); + + // A second load of the same file is an idempotent no-op (already in + // _loadedPluginFiles); it must also return true so failedCount is + // not inflated by retries. + EXPECT_TRUE(probe.loadPluginFromFile(goodPlugin)) + << "Idempotent reload of an already-loaded plugin must not count as failure"; + EXPECT_EQ(probe.getPlugins().size(), pluginCountBefore + 1u); +} + +// ========== Edge Case: Empty Plugin Directory ========== + +TEST_F(TestHeuristicPluginManagerValidationPaths, EmptyDirectorySkipsValidation) +{ + // ScopedDirectory creates the dir and remove_all's it on destruction, so + // the temp dir is cleaned up even if an assertion below aborts. + const auto uniqueName = std::string("hipdnn_empty_heur_test_") + + std::to_string(::testing::UnitTest::GetInstance()->random_seed()); + const hipdnn_test_sdk::utilities::ScopedDirectory emptyDir( + std::filesystem::temp_directory_path() / uniqueName); + + // Load from empty directory - no plugins to validate + _manager->loadPlugins({emptyDir.path()}, HIPDNN_PLUGIN_LOADING_ABSOLUTE); + + // Built-in Config + StaticOrdering are always present; the empty directory contributed nothing. + EXPECT_EQ(_manager->getPlugins().size(), 2u); +} diff --git a/projects/hipdnn/backend/tests/TestHeuristicPluginResourceManager.cpp b/projects/hipdnn/backend/tests/TestHeuristicPluginResourceManager.cpp index b9644e57048..a0089f466a5 100644 --- a/projects/hipdnn/backend/tests/TestHeuristicPluginResourceManager.cpp +++ b/projects/hipdnn/backend/tests/TestHeuristicPluginResourceManager.cpp @@ -3,14 +3,13 @@ /** * @file TestHeuristicPluginResourceManager.cpp - * @brief Unit tests for HeuristicPluginResourceManager (RFC 0007 Part 1) + * @brief Unit tests for HeuristicPluginResourceManager * * These tests verify the plugin resource management layer that provides * per-handle plugin lifecycle management and policy lookup. */ #include "HipdnnException.hpp" -#include "descriptors/mocks/MockHeuristicPlugin.hpp" #include "plugin/HeuristicPluginManager.hpp" #include "plugin/HeuristicPluginResourceManager.hpp" @@ -57,9 +56,10 @@ TEST_F(TestHeuristicPluginResourceManager, MoveConstructorTransfersOwnership) // Move construct const HeuristicPluginResourceManager rm2(std::move(*rm1)); - // rm2 should be usable + // rm2 should be usable. The shared plugin manager always contains the + // Config + StaticOrdering built-ins, so the policy-info list is not empty. const auto infos = rm2.getHeuristicPolicyInfos(); - EXPECT_TRUE(infos.empty()); // No plugins loaded + EXPECT_EQ(infos.size(), 2u); } TEST_F(TestHeuristicPluginResourceManager, MoveAssignmentTransfersOwnership) @@ -71,10 +71,11 @@ TEST_F(TestHeuristicPluginResourceManager, MoveAssignmentTransfersOwnership) // Move assign *rm2 = std::move(*rm1); - // rm2 should be usable + // rm2 should be usable. The shared plugin manager always contains the + // StaticOrdering built-in, so the policy-info list is not empty. const HeuristicPluginResourceManager& constRm2 = *rm2; const auto infos = constRm2.getHeuristicPolicyInfos(); - EXPECT_TRUE(infos.empty()); + EXPECT_EQ(infos.size(), 2u); } // ========== Policy Lookup Tests ========== @@ -108,8 +109,10 @@ TEST_F(TestHeuristicPluginResourceManager, GetPolicyInfosWhenNoPluginsLoaded) auto pm = std::make_shared(); auto rm = std::make_shared(pm); + // No external plugin paths configured, but the Config + StaticOrdering built-ins + // are always registered in the plugin manager's constructor. const auto infos = rm->getHeuristicPolicyInfos(); - EXPECT_TRUE(infos.empty()); + EXPECT_EQ(infos.size(), 2u); } TEST_F(TestHeuristicPluginResourceManager, GetPolicyInfosCachesResult) @@ -256,10 +259,11 @@ TEST_F(TestHeuristicPluginResourceManager, MultipleInstancesCanCoexist) EXPECT_NE(rm2, nullptr); EXPECT_NE(rm3, nullptr); - // Each should work independently - EXPECT_TRUE(rm1->getHeuristicPolicyInfos().empty()); - EXPECT_TRUE(rm2->getHeuristicPolicyInfos().empty()); - EXPECT_TRUE(rm3->getHeuristicPolicyInfos().empty()); + // Each should work independently. The shared plugin manager always contains + // the StaticOrdering built-in, so each resource manager observes it. + EXPECT_EQ(rm1->getHeuristicPolicyInfos().size(), 2u); + EXPECT_EQ(rm2->getHeuristicPolicyInfos().size(), 2u); + EXPECT_EQ(rm3->getHeuristicPolicyInfos().size(), 2u); } // ========== Copy Prevention Tests ========== @@ -288,9 +292,6 @@ TEST_F(TestHeuristicPluginResourceManager, DestructorCleansUpResources) // Use rm rm->getHeuristicPolicyInfos(); } // rm destroyed here - - // If we get here without crashes, cleanup succeeded - SUCCEED(); } TEST_F(TestHeuristicPluginResourceManager, MultipleDestructionsSucceed) @@ -303,8 +304,6 @@ TEST_F(TestHeuristicPluginResourceManager, MultipleDestructionsSucceed) rm->getHeuristicPolicyInfos(); // Destroyed at end of loop } - - SUCCEED(); } // ========== Constructor Null Pointer Tests ========== diff --git a/projects/hipdnn/backend/tests/TestHeuristicPolicyFramework.cpp b/projects/hipdnn/backend/tests/TestHeuristicPolicyFramework.cpp new file mode 100644 index 00000000000..48895f651c9 --- /dev/null +++ b/projects/hipdnn/backend/tests/TestHeuristicPolicyFramework.cpp @@ -0,0 +1,216 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestHeuristicPolicyFramework.cpp + * @brief Unit tests for the Heuristic Policy Framework + * + * Tests cover: + * - Policy enumeration and metadata via the public hipdnnGetHeuristicPolicy* APIs + * - Default-policy loading through the resource manager + * + * Lower-level coverage (policy order resolution, outer-loop failure handling, + * empty-policy-list throws) lives in TestEngineHeuristicDescriptorAdditional.cpp, + * where the mocked plugin manager makes those paths reachable in a unit test. + */ + +#include "handle/Handle.hpp" +#include "plugin/HeuristicPluginResourceManager.hpp" + +#include +#include + +using namespace hipdnn_backend; + +class TestHeuristicPolicyFramework : public ::testing::Test +{ +protected: + void SetUp() override + { + // hipdnnCreate loads real heuristic plugins (e.g. hipBLASLt in the + // superbuild) whose initializers probe the device. Skip on no-GPU + // runners to avoid a hard abort from the plugin's HIP error path. + SKIP_IF_NO_DEVICES(); + const hipdnnStatus_t status = hipdnnCreate(&_handle); + ASSERT_EQ(status, HIPDNN_STATUS_SUCCESS); + ASSERT_NE(_handle, nullptr); + } + + void TearDown() override + { + if(_handle != nullptr) + { + hipdnnDestroy(_handle); + _handle = nullptr; + } + } + + hipdnnHandle_t _handle = nullptr; +}; + +// ========== Policy Enumeration Tests ========== + +TEST_F(TestHeuristicPolicyFramework, GetHeuristicPolicyCountReturnsNonZero) +{ + size_t numPolicies = 0; + const hipdnnStatus_t status = hipdnnGetHeuristicPolicyCount_ext(_handle, &numPolicies); + + EXPECT_EQ(status, HIPDNN_STATUS_SUCCESS); + // At minimum, the StaticOrdering built-in should be loaded. + EXPECT_GE(numPolicies, 1u); +} + +TEST_F(TestHeuristicPolicyFramework, GetHeuristicPolicyInfoReturnsValidData) +{ + size_t numPolicies = 0; + ASSERT_EQ(hipdnnGetHeuristicPolicyCount_ext(_handle, &numPolicies), HIPDNN_STATUS_SUCCESS); + ASSERT_GT(numPolicies, 0u); + + // Query first policy (two-call pattern) + int64_t policyId = -1; + size_t policyNameLen = 0; + size_t pluginNameLen = 0; + size_t pluginVersionLen = 0; + size_t apiVersionLen = 0; + + // First call: query sizes + hipdnnStatus_t status = hipdnnGetHeuristicPolicyInfo_ext(_handle, + 0, + &policyId, + nullptr, + &policyNameLen, + nullptr, + &pluginNameLen, + nullptr, + &pluginVersionLen, + nullptr, + &apiVersionLen); + + ASSERT_EQ(status, HIPDNN_STATUS_SUCCESS); + EXPECT_NE(policyId, -1); + EXPECT_GT(policyNameLen, 0u); + EXPECT_GT(pluginNameLen, 0u); + EXPECT_GT(pluginVersionLen, 0u); + EXPECT_GT(apiVersionLen, 0u); + + // Second call: retrieve strings + std::vector policyName(policyNameLen); + std::vector pluginName(pluginNameLen); + std::vector pluginVersion(pluginVersionLen); + std::vector apiVersion(apiVersionLen); + + status = hipdnnGetHeuristicPolicyInfo_ext(_handle, + 0, + &policyId, + policyName.data(), + &policyNameLen, + pluginName.data(), + &pluginNameLen, + pluginVersion.data(), + &pluginVersionLen, + apiVersion.data(), + &apiVersionLen); + + EXPECT_EQ(status, HIPDNN_STATUS_SUCCESS); + EXPECT_GT(std::strlen(policyName.data()), 0u); + EXPECT_GT(std::strlen(pluginName.data()), 0u); + EXPECT_GT(std::strlen(pluginVersion.data()), 0u); + EXPECT_GT(std::strlen(apiVersion.data()), 0u); +} + +TEST_F(TestHeuristicPolicyFramework, GetHeuristicPolicyInfoOutOfRangeFails) +{ + size_t numPolicies = 0; + ASSERT_EQ(hipdnnGetHeuristicPolicyCount_ext(_handle, &numPolicies), HIPDNN_STATUS_SUCCESS); + + // Try to query beyond range + int64_t policyId = -1; + size_t policyNameLen = 0; + size_t pluginNameLen = 0; + size_t pluginVersionLen = 0; + size_t apiVersionLen = 0; + + const hipdnnStatus_t status = hipdnnGetHeuristicPolicyInfo_ext(_handle, + numPolicies + 100, + &policyId, + nullptr, + &policyNameLen, + nullptr, + &pluginNameLen, + nullptr, + &pluginVersionLen, + nullptr, + &apiVersionLen); + + EXPECT_EQ(status, HIPDNN_STATUS_BAD_PARAM); +} + +// Policy order resolution (descriptor / env / default), policy decline behavior, +// and "no policy succeeds" failure paths are covered with mocked plugin managers +// in descriptors/TestEngineHeuristicDescriptor.cpp. The StaticOrdering "never +// declines" contract is enforced by the built-in's Finalize implementation and +// exercised in heuristics/TestStaticOrderingBuiltIn.cpp. + +// ========== Integration Tests ========== + +TEST_F(TestHeuristicPolicyFramework, HeuristicResourceManagerLoadsDefaultPolicies) +{ + auto heurRm = _handle->getHeuristicPluginResourceManager(); + ASSERT_NE(heurRm, nullptr); + + auto policyInfos = heurRm->getHeuristicPolicyInfos(); + + // The StaticOrdering built-in is registered at construction time and is the + // canonical fallback policy. + EXPECT_GE(policyInfos.size(), 1u); + + bool hasStaticOrdering = false; + for(const auto& info : policyInfos) + { + if(info.policyName.find("StaticOrdering") != std::string::npos) + { + hasStaticOrdering = true; + } + } + + EXPECT_TRUE(hasStaticOrdering) << "StaticOrdering policy should be loaded"; +} + +// ========== Negative Tests ========== + +TEST_F(TestHeuristicPolicyFramework, GetPolicyCountWithNullHandleFails) +{ + size_t numPolicies = 0; + const hipdnnStatus_t status = hipdnnGetHeuristicPolicyCount_ext(nullptr, &numPolicies); + + EXPECT_NE(status, HIPDNN_STATUS_SUCCESS); +} + +TEST_F(TestHeuristicPolicyFramework, GetPolicyCountWithNullPointerFails) +{ + const hipdnnStatus_t status = hipdnnGetHeuristicPolicyCount_ext(_handle, nullptr); + + EXPECT_NE(status, HIPDNN_STATUS_SUCCESS); +} + +TEST_F(TestHeuristicPolicyFramework, GetPolicyInfoWithNullLengthPointersFails) +{ + int64_t policyId = -1; + + // All length pointers are required (not nullptr) + const hipdnnStatus_t status = hipdnnGetHeuristicPolicyInfo_ext(_handle, + 0, + &policyId, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); + + EXPECT_NE(status, HIPDNN_STATUS_SUCCESS); +} + +// Note: gtest provides main(), do not define it here diff --git a/projects/hipdnn/backend/tests/TestSelectionHeuristic.cpp b/projects/hipdnn/backend/tests/TestSelectionHeuristic.cpp new file mode 100644 index 00000000000..1d53a5d98e3 --- /dev/null +++ b/projects/hipdnn/backend/tests/TestSelectionHeuristic.cpp @@ -0,0 +1,473 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestSelectionHeuristic.cpp + * @brief Unit tests for SelectionHeuristic RAII wrapper + * + * Tests the C++ facade that wraps hipdnnHeuristicPolicyDescriptor_t lifecycle + * and provides clean API over the heuristic plugin C ABI. + */ + +#include "heuristics/SelectionHeuristic.hpp" + +#include "descriptors/mocks/MockHeuristicPlugin.hpp" +#include "descriptors/mocks/MockHeuristicPluginResourceManager.hpp" +#include "plugin/HeuristicPlugin.hpp" + +#include +#include +#include + +using namespace hipdnn_backend::heuristics; +using namespace hipdnn_backend::plugin; +using ::testing::NiceMock; + +class TestSelectionHeuristic : public ::testing::Test +{ +protected: + void SetUp() override + { + // Create a mock plugin handle (just a non-null pointer for testing) + _mockHandle = reinterpret_cast(this); + _mockPlugin = std::make_unique(); + _mockResourceManager = std::make_shared>(); + } + + // Wires the mock resource manager to return _mockHandle / _mockPlugin.get() + // for the test policy ID. Use AnyNumber so cleanup-time lookups in the + // SelectionHeuristic destructor (and move ops) are also satisfied. + void wireResourceManager() + { + EXPECT_CALL(*_mockResourceManager, getHeuristicHandleForPolicyId(_policyId)) + .WillRepeatedly(::testing::Return(_mockHandle)); + EXPECT_CALL(*_mockResourceManager, getPluginForPolicyId(_policyId)) + .WillRepeatedly(::testing::Return(_mockPlugin.get())); + } + + hipdnnHeuristicHandle_t _mockHandle = nullptr; + std::unique_ptr _mockPlugin; + std::shared_ptr> _mockResourceManager; + int64_t _policyId = 12345; +}; + +// ========== Constructor Tests ========== + +TEST_F(TestSelectionHeuristic, ConstructorWithValidInputs) +{ + wireResourceManager(); + + // Expect createPolicyDescriptor to be called + auto mockDescriptor = reinterpret_cast(0x1234); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + // Should not throw + const SelectionHeuristic heuristic(_mockResourceManager, _policyId); +} + +TEST_F(TestSelectionHeuristic, ConstructorThrowsOnNullResourceManager) +{ + EXPECT_THROW( + { + SelectionHeuristic heuristic(nullptr, _policyId); // NOLINT(misc-const-correctness) + }, + hipdnn_backend::HipdnnException); +} + +TEST_F(TestSelectionHeuristic, ConstructorThrowsWhenPolicyHasNoHandle) +{ + // Manager reports no handle for this policy ID + EXPECT_CALL(*_mockResourceManager, getHeuristicHandleForPolicyId(_policyId)) + .WillRepeatedly(::testing::Return(nullptr)); + + EXPECT_THROW( + { const SelectionHeuristic heuristic(_mockResourceManager, _policyId); }, + hipdnn_backend::HipdnnException); +} + +TEST_F(TestSelectionHeuristic, ConstructorThrowsWhenPolicyHasNoPlugin) +{ + // Manager reports a handle but no plugin (defensive check) + EXPECT_CALL(*_mockResourceManager, getHeuristicHandleForPolicyId(_policyId)) + .WillRepeatedly(::testing::Return(_mockHandle)); + EXPECT_CALL(*_mockResourceManager, getPluginForPolicyId(_policyId)) + .WillRepeatedly(::testing::Return(nullptr)); + + EXPECT_THROW( + { const SelectionHeuristic heuristic(_mockResourceManager, _policyId); }, + hipdnn_backend::HipdnnException); +} + +// ========== Move Semantics Tests ========== + +TEST_F(TestSelectionHeuristic, MoveConstructor) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0x5678); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + // Destroy should be called exactly once (when moved-to object is destroyed) + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + // Regression: the move must carry _inputEngineIds too, otherwise the + // validator in getSortedEngineIds sees an empty candidate set and rejects + // every legitimate plugin output with HIPDNN_STATUS_PLUGIN_ERROR. + const std::vector inputIds = {1, 2, 3}; + const std::vector sortedIds = {3, 1, 2}; + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(sortedIds)); + + { + SelectionHeuristic heuristic1(_mockResourceManager, _policyId); + heuristic1.setEngineIds(inputIds); + + // Move construct + SelectionHeuristic heuristic2(std::move(heuristic1)); // NOLINT(misc-const-correctness) + + // heuristic2 should now own the descriptor and the input-ID candidate set + EXPECT_EQ(heuristic2.getSortedEngineIds(), sortedIds); + // heuristic1 should be empty (moved-from state) + } // Both destructors called, but only heuristic2 has valid descriptor +} + +TEST_F(TestSelectionHeuristic, MoveAssignment) +{ + wireResourceManager(); + + auto mockDescriptor1 = reinterpret_cast(0x1111); + auto mockDescriptor2 = reinterpret_cast(0x2222); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor1)) + .WillOnce(::testing::Return(mockDescriptor2)); + + // First descriptor destroyed during move assignment + // Second descriptor destroyed when moved-to object is destroyed + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor1)).Times(1); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor2)).Times(1); + + // Regression: the move must carry _inputEngineIds too, so the validator + // in getSortedEngineIds still has the candidate set after the assignment. + const std::vector inputIds = {10, 20, 30}; + const std::vector sortedIds = {30, 20, 10}; + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor2, ::testing::_, inputIds.size())) + .Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor2)) + .WillOnce(::testing::Return(sortedIds)); + + { + SelectionHeuristic heuristic1(_mockResourceManager, _policyId); + SelectionHeuristic heuristic2(_mockResourceManager, _policyId); + heuristic2.setEngineIds(inputIds); + + // Move assign + heuristic1 = std::move(heuristic2); + + // heuristic1 should now own mockDescriptor2 and heuristic2's input IDs + // heuristic2 should be empty + // mockDescriptor1 should have been destroyed + EXPECT_EQ(heuristic1.getSortedEngineIds(), sortedIds); + } +} + +TEST_F(TestSelectionHeuristic, MoveAssignmentSelfAssignment) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0x9999); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + { + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + + // Self-assignment should be safe (use reference to avoid warning) + SelectionHeuristic& heuristicRef = heuristic; + heuristic = std::move(heuristicRef); + + // Descriptor should still be valid + } +} + +// ========== API Tests ========== + +TEST_F(TestSelectionHeuristic, SetEngineIds) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xAAAA); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + std::vector testEngineIds = {1, 2, 3, 4, 5}; + + EXPECT_CALL( + *_mockPlugin, + setEngineIds(mockDescriptor, ::testing::Pointee(testEngineIds[0]), testEngineIds.size())) + .Times(1); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(testEngineIds); +} + +TEST_F(TestSelectionHeuristic, SetSerializedGraph) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xBBBB); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + std::vector graphData = {0x01, 0x02, 0x03}; + const hipdnnPluginConstData_t serializedGraph{graphData.data(), graphData.size()}; + + EXPECT_CALL(*_mockPlugin, setSerializedGraph(mockDescriptor, &serializedGraph)).Times(1); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setSerializedGraph(&serializedGraph); +} + +TEST_F(TestSelectionHeuristic, FinalizeReturnsTrue) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xCCCC); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + EXPECT_CALL(*_mockPlugin, finalize(mockDescriptor)).WillOnce(::testing::Return(true)); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + EXPECT_TRUE(heuristic.finalize()); +} + +TEST_F(TestSelectionHeuristic, FinalizeReturnsFalse) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xDDDD); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + EXPECT_CALL(*_mockPlugin, finalize(mockDescriptor)).WillOnce(::testing::Return(false)); + + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + EXPECT_FALSE(heuristic.finalize()); +} + +TEST_F(TestSelectionHeuristic, GetSortedEngineIds) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xEEEE); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + // The plugin output must be a permutation/subset of the IDs we hand in + // via setEngineIds — that's what SelectionHeuristic validates. + const std::vector inputIds = {1, 2, 3, 4, 5}; + const std::vector expectedIds = {5, 4, 3, 2, 1}; + + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(expectedIds)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(inputIds); + auto result = heuristic.getSortedEngineIds(); + + EXPECT_EQ(result, expectedIds); +} + +TEST_F(TestSelectionHeuristic, GetSortedEngineIdsRejectsFabricatedId) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xEEEE); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + const std::vector inputIds = {1, 2, 3}; + // Plugin returns an ID (99) that wasn't in the candidate set. + const std::vector badIds = {2, 99, 1}; + + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(badIds)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(inputIds); + EXPECT_THROW(heuristic.getSortedEngineIds(), hipdnn_backend::HipdnnException); +} + +TEST_F(TestSelectionHeuristic, GetSortedEngineIdsRejectsDuplicates) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xEEEE); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + const std::vector inputIds = {1, 2, 3}; + const std::vector badIds = {1, 2, 2}; + + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(badIds)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(inputIds); + EXPECT_THROW(heuristic.getSortedEngineIds(), hipdnn_backend::HipdnnException); +} + +TEST_F(TestSelectionHeuristic, GetSortedEngineIdsRejectsOversizedOutput) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xEEEE); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + const std::vector inputIds = {1, 2}; + const std::vector badIds = {1, 2, 3}; + + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(badIds)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(inputIds); + EXPECT_THROW(heuristic.getSortedEngineIds(), hipdnn_backend::HipdnnException); +} + +TEST_F(TestSelectionHeuristic, GetSortedEngineIdsAcceptsProperSubset) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xEEEE); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + // Plugin may decline some candidates and return a strict subset. + const std::vector inputIds = {1, 2, 3, 4, 5}; + const std::vector expectedIds = {3, 1}; + + EXPECT_CALL(*_mockPlugin, setEngineIds(mockDescriptor, ::testing::_, inputIds.size())).Times(1); + EXPECT_CALL(*_mockPlugin, getSortedEngineIds(mockDescriptor)) + .WillOnce(::testing::Return(expectedIds)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + heuristic.setEngineIds(inputIds); + EXPECT_EQ(heuristic.getSortedEngineIds(), expectedIds); +} + +// ========== Lifetime Tests ========== + +// Verifies the SelectionHeuristic keeps the resource manager alive even if the +// caller drops its own shared_ptr — this is the core memory-safety guarantee +// the shared_ptr-to-manager design provides. +TEST_F(TestSelectionHeuristic, KeepsResourceManagerAliveAcrossCallerRelease) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xABCD); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + EXPECT_CALL(*_mockPlugin, finalize(mockDescriptor)).WillOnce(::testing::Return(true)); + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)).Times(1); + + const std::weak_ptr> weakManager + = _mockResourceManager; + + SelectionHeuristic heuristic(_mockResourceManager, _policyId); + + // Caller releases its strong reference; the slot should still hold one. + _mockResourceManager.reset(); + EXPECT_FALSE(weakManager.expired()); + + // Operations through the slot still work because the manager is alive. + EXPECT_TRUE(heuristic.finalize()); +} + +// ========== Exception Safety Tests ========== + +TEST_F(TestSelectionHeuristic, DestructorHandlesExceptionInCleanup) +{ + wireResourceManager(); + + auto mockDescriptor = reinterpret_cast(0xFFFF); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor)); + + // Destructor should catch and suppress exceptions + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor)) + .WillOnce(::testing::Throw( + hipdnn_backend::HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Cleanup failed"))); + + // Should not throw from destructor + { + const SelectionHeuristic heuristic(_mockResourceManager, _policyId); + } +} + +TEST_F(TestSelectionHeuristic, MoveAssignmentHandlesExceptionInCleanup) +{ + wireResourceManager(); + + auto mockDescriptor1 = reinterpret_cast(0x1001); + auto mockDescriptor2 = reinterpret_cast(0x2002); + + EXPECT_CALL(*_mockPlugin, createPolicyDescriptor(_mockHandle, _policyId)) + .WillOnce(::testing::Return(mockDescriptor1)) + .WillOnce(::testing::Return(mockDescriptor2)); + + // First descriptor cleanup throws during move assignment + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor1)) + .WillOnce(::testing::Throw(std::runtime_error("Cleanup error"))); + + // Second descriptor cleanup succeeds + EXPECT_CALL(*_mockPlugin, destroyPolicyDescriptor(mockDescriptor2)).Times(1); + + { + SelectionHeuristic heuristic1(_mockResourceManager, _policyId); + SelectionHeuristic heuristic2(_mockResourceManager, _policyId); + + // Move assignment should not throw even though cleanup of old descriptor throws + heuristic1 = std::move(heuristic2); + } +} diff --git a/projects/hipdnn/backend/tests/descriptors/TestEngineHeuristicDescriptor.cpp b/projects/hipdnn/backend/tests/descriptors/TestEngineHeuristicDescriptor.cpp index f31fe849baa..2894166c9ba 100644 --- a/projects/hipdnn/backend/tests/descriptors/TestEngineHeuristicDescriptor.cpp +++ b/projects/hipdnn/backend/tests/descriptors/TestEngineHeuristicDescriptor.cpp @@ -9,13 +9,20 @@ #include "descriptors/EngineHeuristicDescriptor.hpp" #include "descriptors/GraphDescriptor.hpp" #include "descriptors/ScopedDescriptor.hpp" +#include "heuristics/SelectionHeuristic.hpp" #include "hipdnn_backend.h" #include "mocks/MockDescriptor.hpp" #include "mocks/MockEnginePluginResourceManager.hpp" #include "mocks/MockHandle.hpp" +#include "mocks/MockHeuristicPlugin.hpp" +#include "mocks/MockHeuristicPluginResourceManager.hpp" #include +#include +#include #include +#include +#include #include #include @@ -58,6 +65,13 @@ class TestEngineHeuristicDescriptor : public ::testing::Test EXPECT_CALL(*getMockGraph(), getHandle()).WillRepeatedly(Return(_mockHandle.get())); EXPECT_CALL(*_mockHandle, getPluginResourceManager()) .WillRepeatedly(Return(_mockEnginePluginResourceManager)); + EXPECT_CALL(*_mockHandle, getHeuristicPluginResourceManager()) + .WillRepeatedly(Return(_mockHeuristicPluginResourceManager)); + EXPECT_CALL(*_mockHandle, getStream()).WillRepeatedly(Return(_testStream)); + + // Set up mock heuristic plugin automatically when graph is set + setupMockHeuristicPlugin(); + ASSERT_NO_THROW( getEngineHeuristicDescriptor()->setAttribute(HIPDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH, HIPDNN_TYPE_BACKEND_DESCRIPTOR, @@ -74,20 +88,83 @@ class TestEngineHeuristicDescriptor : public ::testing::Test void makeEngineHeuristicFinalized() const { - setGraph(); + setGraph(); // This now automatically sets up the heuristic mock setHeuristicMode(); EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) .WillRepeatedly(Return(std::vector{0, 1, 2})); ASSERT_NO_THROW(getEngineHeuristicDescriptor()->finalize()); } + void setupMockHeuristicPlugin() const + { + const int64_t staticOrderingPolicyId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + + // Mock handles and descriptors + auto mockHandle = reinterpret_cast(0x1234); + auto mockDescriptor = reinterpret_cast(0x5678); + + // Catch-all for unknown policy IDs (must be set first; gmock matches LIFO). + EXPECT_CALL(*_mockHeuristicPluginResourceManager, getPluginForPolicyId(_)) + .WillRepeatedly(Return(nullptr)); + EXPECT_CALL(*_mockHeuristicPluginResourceManager, getHeuristicHandleForPolicyId(_)) + .WillRepeatedly(Return(nullptr)); + + // StaticOrdering policy succeeds + EXPECT_CALL(*_mockHeuristicPluginResourceManager, + getPluginForPolicyId(staticOrderingPolicyId)) + .WillRepeatedly(Return(_mockHeuristicPlugin.get())); + EXPECT_CALL(*_mockHeuristicPluginResourceManager, + getHeuristicHandleForPolicyId(staticOrderingPolicyId)) + .WillRepeatedly(Return(mockHandle)); + EXPECT_CALL(*_mockHeuristicPluginResourceManager, setDevicePropertiesOnAllHandles(_)) + .WillRepeatedly(Return()); + + // Set up expectations for the mock plugin + EXPECT_CALL(*_mockHeuristicPlugin, setDeviceProperties(mockHandle, _)) + .WillRepeatedly(Return()); + EXPECT_CALL(*_mockHeuristicPlugin, + createPolicyDescriptor(mockHandle, staticOrderingPolicyId)) + .WillRepeatedly(Return(mockDescriptor)); + EXPECT_CALL(*_mockHeuristicPlugin, destroyPolicyDescriptor(mockDescriptor)) + .WillRepeatedly(Return()); + + // Store engine IDs when setEngineIds is called + EXPECT_CALL(*_mockHeuristicPlugin, setEngineIds(mockDescriptor, _, _)) + .WillRepeatedly([this](hipdnnHeuristicPolicyDescriptor_t, + const int64_t* engineIds, + size_t engineIdCount) { + _mockStoredEngineIds.assign(engineIds, engineIds + engineIdCount); + }); + + EXPECT_CALL(*_mockHeuristicPlugin, setSerializedGraph(mockDescriptor, _)) + .WillRepeatedly(Return()); + EXPECT_CALL(*_mockHeuristicPlugin, finalize(mockDescriptor)) + .WillRepeatedly(Return(true)); // Always succeed + + // Return the same engine IDs that were set + EXPECT_CALL(*_mockHeuristicPlugin, getSortedEngineIds(mockDescriptor)) + .WillRepeatedly([this]() { return _mockStoredEngineIds; }); + + EXPECT_CALL(*getMockGraph(), getSerializedGraph()).WillRepeatedly([]() { + static const std::vector s_dummyData = {0x01, 0x02, 0x03}; + return hipdnnPluginConstData_t{s_dummyData.data(), s_dummyData.size()}; + }); + } + protected: std::unique_ptr _engineHeuristicWrapper = nullptr; std::unique_ptr _mockGraphWrapper = nullptr; std::unique_ptr _mockGraphBadTypeWrapper = nullptr; std::unique_ptr _mockWrongTypeWrapper = nullptr; - std::unique_ptr _mockHandle = nullptr; - std::shared_ptr _mockEnginePluginResourceManager = nullptr; + std::unique_ptr> _mockHandle = nullptr; + std::shared_ptr> _mockEnginePluginResourceManager + = nullptr; + std::shared_ptr> + _mockHeuristicPluginResourceManager = nullptr; + std::shared_ptr> _mockHeuristicPlugin = nullptr; + mutable std::vector _mockStoredEngineIds; + hipStream_t _testStream = nullptr; void SetUp() override { @@ -95,12 +172,21 @@ class TestEngineHeuristicDescriptor : public ::testing::Test _mockGraphWrapper = createDescriptor(); _mockGraphBadTypeWrapper = createDescriptor(); _mockWrongTypeWrapper = createDescriptor>(); - _mockHandle = std::make_unique(); - _mockEnginePluginResourceManager = std::make_shared(); + _mockHandle = std::make_unique>(); + _mockEnginePluginResourceManager + = std::make_shared>(); + _mockHeuristicPluginResourceManager + = std::make_shared>(); + _mockHeuristicPlugin = std::make_shared>(); } void TearDown() override { + // Destroy descriptor before mocks to ensure proper cleanup order + _engineHeuristicWrapper.reset(); + _mockGraphWrapper.reset(); + _mockGraphBadTypeWrapper.reset(); + _mockWrongTypeWrapper.reset(); _engineDetailBuffers.clear(); } @@ -120,6 +206,33 @@ class TestEngineHeuristicDescriptor : public ::testing::Test std::vector _engineDetailBuffers; }; +// GPU-requiring variant. finalize() reads the device through +// hipStreamGetDevice(handle->getStream(), ...) once getApplicableEngineIds +// returns a non-empty list, so any test that finalizes with results needs a +// real stream the MockHandle can return. Tests that only exercise descriptor +// validation, attribute setters/getters, or finalize-with-empty-engines stay +// on the base fixture and continue to run on no-GPU CI runners. +class TestGpuEngineHeuristicDescriptor : public TestEngineHeuristicDescriptor +{ +protected: + void SetUp() override + { + SKIP_IF_NO_DEVICES(); + TestEngineHeuristicDescriptor::SetUp(); + ASSERT_EQ(hipStreamCreate(&_testStream), hipSuccess); + } + + void TearDown() override + { + if(_testStream != nullptr) + { + EXPECT_EQ(hipStreamDestroy(_testStream), hipSuccess); + _testStream = nullptr; + } + TestEngineHeuristicDescriptor::TearDown(); + } +}; + TEST_F(TestEngineHeuristicDescriptor, CreateEngineHeuristicDescriptor) { auto heur = getEngineHeuristicDescriptor(); @@ -219,7 +332,7 @@ TEST_F(TestEngineHeuristicDescriptor, SetEngineHeuristicDescriptorUnsupportedAtt HIPDNN_STATUS_NOT_SUPPORTED); } -TEST_F(TestEngineHeuristicDescriptor, SetAttrOnFinalizedEngineHeuristicDescriptor) +TEST_F(TestGpuEngineHeuristicDescriptor, SetAttrOnFinalizedEngineHeuristicDescriptor) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -278,7 +391,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetAttrOnUnfinalizedEngineHeuristicDescrip HIPDNN_STATUS_BAD_PARAM_NOT_FINALIZED); } -TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorUnsupportedAttr) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineHeuristicDescriptorUnsupportedAttr) { auto heur = getEngineHeuristicDescriptor(); hipdnnBackendHeurMode_t dummy; @@ -290,7 +403,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorUnsupportedAtt HIPDNN_STATUS_NOT_SUPPORTED); } -TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorGraph) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineHeuristicDescriptorGraph) { auto heur = getEngineHeuristicDescriptor(); ScopedDescriptor graph; @@ -333,7 +446,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorGraph) ASSERT_EQ(count, 1); } -TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorEngineConfigs) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineHeuristicDescriptorEngineConfigs) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -361,10 +474,16 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorEngineConfigs) HIPDNN_ATTR_ENGINEHEUR_RESULTS, HIPDNN_TYPE_BACKEND_DESCRIPTOR, 1, nullptr, nullptr), HIPDNN_STATUS_BAD_PARAM_NULL_POINTER); - std::vector configs(3); - for(size_t i = 0; i < 3; ++i) + std::vector ownedConfigs(3); + for(auto& owned : ownedConfigs) { - configs[i] = createDescriptorPtr(); + owned = ScopedDescriptor(createDescriptorPtr()); + } + std::vector configs; + configs.reserve(ownedConfigs.size()); + for(auto& owned : ownedConfigs) + { + configs.push_back(owned.get()); } ASSERT_THROW_HIPDNN_STATUS(heur->getAttribute(HIPDNN_ATTR_ENGINEHEUR_RESULTS, @@ -382,13 +501,6 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorEngineConfigs) static_cast(configs.data()))); ASSERT_EQ(count, 3); - for(auto config : configs) - { - delete config; - } - - configs.clear(); - ScopedDescriptor singleConfig(createDescriptorPtr()); count = 0; @@ -397,7 +509,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorEngineConfigs) ASSERT_EQ(count, 1); } -TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNullConfig) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineConfigsWithNullConfig) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -411,10 +523,17 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNullConfig) EXPECT_CALL(*_mockEnginePluginResourceManager, destroyEngineDetails(_, _)) .WillRepeatedly(Return()); - std::vector configs(3); - configs[0] = createDescriptorPtr(); - configs[1] = nullptr; - configs[2] = createDescriptorPtr(); + std::vector ownedConfigs(3); + ownedConfigs[0] = ScopedDescriptor(createDescriptorPtr()); + // ownedConfigs[1] left as default (nullptr) to trigger the null-element path + ownedConfigs[2] = ScopedDescriptor(createDescriptorPtr()); + + std::vector configs; + configs.reserve(ownedConfigs.size()); + for(auto& owned : ownedConfigs) + { + configs.push_back(owned.get()); + } EXPECT_CALL(*getMockGraph(), isFinalized()).WillRepeatedly(Return(true)); int64_t count = 0; @@ -424,11 +543,6 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNullConfig) &count, configs.data()), HIPDNN_STATUS_BAD_PARAM_NULL_POINTER); - - for(auto config : configs) - { - delete config; - } } TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNoEngineIds) @@ -442,10 +556,16 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNoEngineIds) ASSERT_NO_THROW(heur->finalize()); - std::vector configs(3); - for(size_t i = 0; i < 3; ++i) + std::vector ownedConfigs(3); + for(auto& owned : ownedConfigs) { - configs[i] = createDescriptorPtr(); + owned = ScopedDescriptor(createDescriptorPtr()); + } + std::vector configs; + configs.reserve(ownedConfigs.size()); + for(auto& owned : ownedConfigs) + { + configs.push_back(owned.get()); } int64_t count = 0; @@ -455,14 +575,9 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsWithNoEngineIds) &count, static_cast(configs.data()))); ASSERT_EQ(count, 0); - - for(auto config : configs) - { - delete config; - } } -TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsRequestMoreThanAvailable) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineConfigsRequestMoreThanAvailable) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -476,10 +591,16 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsRequestMoreThanAvailable) EXPECT_CALL(*_mockEnginePluginResourceManager, destroyEngineDetails(_, _)) .WillRepeatedly(Return()); - std::vector configs(5); + std::vector ownedConfigs(5); for(size_t i = 0; i < 3; ++i) { - configs[i] = createDescriptorPtr(); + ownedConfigs[i] = ScopedDescriptor(createDescriptorPtr()); + } + std::vector configs; + configs.reserve(ownedConfigs.size()); + for(auto& owned : ownedConfigs) + { + configs.push_back(owned.get()); } int64_t count = 0; @@ -489,14 +610,9 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsRequestMoreThanAvailable) &count, static_cast(configs.data()))); ASSERT_EQ(count, 3); - - for(auto config : configs) - { - delete config; - } } -TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsCountOnly) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineConfigsCountOnly) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -509,7 +625,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetEngineConfigsCountOnly) ASSERT_EQ(count, 3); } -TEST_F(TestEngineHeuristicDescriptor, GetEngineHeuristicDescriptorHeurMode) +TEST_F(TestGpuEngineHeuristicDescriptor, GetEngineHeuristicDescriptorHeurMode) { auto heur = getEngineHeuristicDescriptor(); hipdnnBackendHeurMode_t mode = HIPDNN_HEUR_MODE_FALLBACK; @@ -546,7 +662,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetGraphThrowsIfNotFinalized) ASSERT_THROW_HIPDNN_STATUS(heur->getGraph(), HIPDNN_STATUS_INTERNAL_ERROR); } -TEST_F(TestEngineHeuristicDescriptor, GetGraphReturnsPointerIfFinalized) +TEST_F(TestGpuEngineHeuristicDescriptor, GetGraphReturnsPointerIfFinalized) { auto heur = getEngineHeuristicDescriptor(); makeEngineHeuristicFinalized(); @@ -573,7 +689,7 @@ TEST_F(TestEngineHeuristicDescriptor, SetFindFirstInvalidType) HIPDNN_STATUS_BAD_PARAM); } -TEST_F(TestEngineHeuristicDescriptor, GetFindFirstAfterFinalize) +TEST_F(TestGpuEngineHeuristicDescriptor, GetFindFirstAfterFinalize) { auto heur = getEngineHeuristicDescriptor(); bool findFirst = true; @@ -594,7 +710,7 @@ TEST_F(TestEngineHeuristicDescriptor, GetFindFirstAfterFinalize) ASSERT_EQ(count, 1); } -TEST_F(TestEngineHeuristicDescriptor, FinalizeWithFindFirstPassesToPluginManager) +TEST_F(TestGpuEngineHeuristicDescriptor, FinalizeWithFindFirstPassesToPluginManager) { auto heur = getEngineHeuristicDescriptor(); bool findFirst = true; @@ -607,3 +723,567 @@ TEST_F(TestEngineHeuristicDescriptor, FinalizeWithFindFirstPassesToPluginManager .WillOnce(Return(std::vector{1})); ASSERT_NO_THROW(heur->finalize()); } + +// ========== Policy Order API Tests ========== + +TEST_F(TestEngineHeuristicDescriptor, SetPolicyOrderValid) +{ + auto heur = getEngineHeuristicDescriptor(); + + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("Policy1"), + hipdnn_data_sdk::utilities::policyNameToId("Policy2"), + hipdnn_data_sdk::utilities::policyNameToId("Policy3"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); +} + +TEST_F(TestEngineHeuristicDescriptor, SetPolicyOrderInvalidType) +{ + auto heur = getEngineHeuristicDescriptor(); + const char dummy = '\0'; + + ASSERT_THROW_HIPDNN_STATUS( + heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_CHAR, 1, &dummy), + HIPDNN_STATUS_BAD_PARAM); +} + +TEST_F(TestEngineHeuristicDescriptor, SetPolicyOrderNullPointer) +{ + auto heur = getEngineHeuristicDescriptor(); + + ASSERT_THROW_HIPDNN_STATUS( + heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 1, nullptr), + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderWhenNotSet) +{ + // Make sure no env-var override leaks in from the surrounding shell. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter envGuard( + "HIPDNN_HEUR_POLICY_ORDER", ""); + + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + // With no descriptor-level override and no env var, resolveHeuristicPolicyOrder + // returns the built-in default: Config first, then StaticOrdering. + int64_t count = 999; + ASSERT_NO_THROW(heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 0, &count, nullptr)); + ASSERT_EQ(count, 2); + + std::vector buffer(2); + ASSERT_NO_THROW(heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 2, &count, buffer.data())); + ASSERT_EQ(count, 2); + EXPECT_EQ(buffer[0], hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::Config")); + EXPECT_EQ(buffer[1], + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering")); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderCountOnly) +{ + auto heur = getEngineHeuristicDescriptor(); + + // Caller-provided policy list is preserved verbatim; nothing is prepended. + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + int64_t count = 0; + ASSERT_NO_THROW(heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 0, &count, nullptr)); + ASSERT_EQ(count, static_cast(policyIds.size())); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderInvalidType) +{ + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + std::vector buffer(256); + int64_t count = 0; + ASSERT_THROW_HIPDNN_STATUS( + heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_CHAR, 256, &count, buffer.data()), + HIPDNN_STATUS_BAD_PARAM); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderNullPointer) +{ + auto heur = getEngineHeuristicDescriptor(); + + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + ASSERT_THROW_HIPDNN_STATUS( + heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 0, nullptr, nullptr), + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderNegativeRequestedCount) +{ + auto heur = getEngineHeuristicDescriptor(); + + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + std::vector buffer(4); + int64_t count = 0; + ASSERT_THROW_HIPDNN_STATUS( + heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, -1, &count, buffer.data()), + HIPDNN_STATUS_BAD_PARAM); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderBufferTooSmall) +{ + auto heur = getEngineHeuristicDescriptor(); + + // The caller-supplied list is stored verbatim; no dedup, no prepend. + const int64_t firstId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + const int64_t secondId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::Other"); + const std::vector policyIds = {firstId, secondId}; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + // Request fewer elements than the descriptor holds; should truncate. + std::vector buffer(1); + int64_t count = 0; + ASSERT_NO_THROW(heur->getAttribute( + HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 1, &count, buffer.data())); + ASSERT_EQ(count, 1); + ASSERT_EQ(buffer[0], firstId); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderRoundTrip) +{ + auto heur = getEngineHeuristicDescriptor(); + + // The descriptor stores the caller-supplied list verbatim, including + // duplicates and unknown policies — nothing is prepended or dedup'd. + const int64_t otherId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::Other"); + const int64_t staticOrderingId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + const std::vector policyIds = {staticOrderingId, otherId, staticOrderingId}; + const std::vector& expected = policyIds; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + std::vector getBuffer(expected.size()); + int64_t count = 0; + ASSERT_NO_THROW(heur->getAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(getBuffer.size()), + &count, + getBuffer.data())); + + ASSERT_EQ(count, static_cast(expected.size())); + for(size_t i = 0; i < expected.size(); ++i) + { + ASSERT_EQ(getBuffer[i], expected[i]); + } +} + +// ========== Exception Handling Tests ========== + +TEST_F(TestGpuEngineHeuristicDescriptor, FinalizeWithAllPoliciesFailing) +{ + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + // Make both policies fail + auto mockDescriptor = reinterpret_cast(0x5678); + + EXPECT_CALL(*_mockHeuristicPlugin, finalize(mockDescriptor)).WillRepeatedly(Return(false)); + + // finalize() should throw when all policies fail + ASSERT_THROW_HIPDNN_STATUS(heur->finalize(), HIPDNN_STATUS_INTERNAL_ERROR); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, FinalizeWithPolicyThrowingException) +{ + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + + auto mockDescriptor = reinterpret_cast(0x5678); + + // First call to setEngineIds throws + EXPECT_CALL(*_mockHeuristicPlugin, setEngineIds(mockDescriptor, _, _)) + .WillOnce( + Throw(HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Mock setEngineIds failure"))); + + // finalize() should throw when all policies fail (including exception paths) + ASSERT_THROW_HIPDNN_STATUS(heur->finalize(), HIPDNN_STATUS_INTERNAL_ERROR); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, FinalizeWithSetDevicePropertiesThrowingDisablesSlot) +{ + // setDeviceProperties failure for a plugin must disable that plugin's slots + // (mirroring the policy loop's fail-soft contract). With the only available + // policy disabled, finalize falls through to the "no policy succeeded" throw. + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + auto mockHandle = reinterpret_cast(0x1234); + EXPECT_CALL(*_mockHeuristicPlugin, setDeviceProperties(mockHandle, _)) + .WillOnce(Throw( + HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Mock setDeviceProperties failure"))); + + ASSERT_THROW_HIPDNN_STATUS(heur->finalize(), HIPDNN_STATUS_INTERNAL_ERROR); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, + FinalizeWithSetDevicePropertiesFailingForOnePluginContinuesWithOthers) +{ + // When one plugin's setDeviceProperties throws, only that plugin's policy + // slots are disabled. Policies backed by other plugins still get + // setDeviceProperties called and remain selectable. + const int64_t failingPolicyId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::Failing"); + const int64_t staticOrderingId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + + auto failingHandle = reinterpret_cast(0xABCD); + auto failingPlugin = std::make_shared>(); + + // Wire the failing policy to a distinct plugin/handle. Registering after + // setupMockHeuristicPlugin's catch-all (LIFO match) routes failingPolicyId + // to this plugin while staticOrderingId continues to use _mockHeuristicPlugin. + EXPECT_CALL(*_mockHeuristicPluginResourceManager, getPluginForPolicyId(failingPolicyId)) + .WillRepeatedly(Return(failingPlugin.get())); + EXPECT_CALL(*_mockHeuristicPluginResourceManager, + getHeuristicHandleForPolicyId(failingPolicyId)) + .WillRepeatedly(Return(failingHandle)); + EXPECT_CALL(*failingPlugin, setDeviceProperties(failingHandle, _)) + .WillRepeatedly(Throw( + HipdnnException(HIPDNN_STATUS_INTERNAL_ERROR, "Mock setDeviceProperties failure"))); + + auto heur = getEngineHeuristicDescriptor(); + + // Failing policy first, then StaticOrdering. The failing slot is disabled + // by setDeviceProperties throwing; StaticOrdering succeeds. + const std::vector policyIds = {failingPolicyId, staticOrderingId}; + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + ASSERT_NO_THROW(heur->finalize()); +} + +// ========== toString Tests ========== + +TEST_F(TestEngineHeuristicDescriptor, ToStringBeforeFinalize) +{ + auto heur = getEngineHeuristicDescriptor(); + const std::string str = heur->toString(); + ASSERT_NE(str.find("EngineHeuristicDescriptor"), std::string::npos); + ASSERT_NE(str.find("unset"), std::string::npos); +} + +TEST_F(TestEngineHeuristicDescriptor, ToStringAfterSetHeurMode) +{ + auto heur = getEngineHeuristicDescriptor(); + setHeuristicMode(); + const std::string str = heur->toString(); + ASSERT_NE(str.find("EngineHeuristicDescriptor"), std::string::npos); + ASSERT_NE(str.find("heuristicMode"), std::string::npos); +} + +TEST_F(TestEngineHeuristicDescriptor, ToStringAfterSetGraph) +{ + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + const std::string str = heur->toString(); + ASSERT_NE(str.find("graph="), std::string::npos); + ASSERT_EQ(str.find("graph=null"), std::string::npos); // Should not be null +} + +TEST_F(TestEngineHeuristicDescriptor, ToStringWithPolicyOrder) +{ + auto heur = getEngineHeuristicDescriptor(); + + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("Policy1"), + hipdnn_data_sdk::utilities::policyNameToId("Policy2"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + const std::string str = heur->toString(); + ASSERT_NE(str.find("policyOrder"), std::string::npos); + ASSERT_NE(str.find(hipdnn_data_sdk::utilities::formatEngineIdHex(policyIds[0])), + std::string::npos); + ASSERT_NE(str.find(hipdnn_data_sdk::utilities::formatEngineIdHex(policyIds[1])), + std::string::npos); +} + +// ========== Edge Case Tests ========== + +TEST_F(TestEngineHeuristicDescriptor, SetEmptyPolicyOrder) +{ + auto heur = getEngineHeuristicDescriptor(); + + // Setting an empty policy order is allowed at the attribute level; finalize() + // would later fail because no policy can be selected, but that is exercised by + // FinalizeWithAllPoliciesFailing. Here we only verify the attribute path. + ASSERT_NO_THROW( + heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 0, nullptr)); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, GetPolicyOrderNullElementCount) +{ + auto heur = getEngineHeuristicDescriptor(); + + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + std::vector getBuffer(16); + ASSERT_THROW_HIPDNN_STATUS(heur->getAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(getBuffer.size()), + nullptr, + getBuffer.data()), + HIPDNN_STATUS_BAD_PARAM_NULL_POINTER); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, MultipleSetPolicyOrderCalls) +{ + auto heur = getEngineHeuristicDescriptor(); + + // First set + { + const std::vector policyIds = { + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(policyIds.size()), + policyIds.data())); + } + + // Second set should override + const std::vector secondPolicyIds = { + hipdnn_data_sdk::utilities::policyNameToId("Vendor::Other"), + hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"), + }; + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(secondPolicyIds.size()), + secondPolicyIds.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1})); + ASSERT_NO_THROW(heur->finalize()); + + std::vector getBuffer(secondPolicyIds.size()); + int64_t count = 0; + ASSERT_NO_THROW(heur->getAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(getBuffer.size()), + &count, + getBuffer.data())); + + ASSERT_EQ(count, static_cast(secondPolicyIds.size())); + for(size_t i = 0; i < secondPolicyIds.size(); ++i) + { + ASSERT_EQ(getBuffer[i], secondPolicyIds[i]); + } +} + +// ========== Policy Order Resolution: Environment Variable ========== + +TEST_F(TestGpuEngineHeuristicDescriptor, EnvironmentVariablePolicyOrderIsRespected) +{ + // The mock setup in setupMockHeuristicPlugin() makes the catch-all return a + // null handle for any unknown policy and StaticOrdering succeed. With no + // descriptor-level override, the default order [StaticOrdering] therefore + // succeeds. Restricting the env-var order to a policy nothing maps to should + // make finalize() throw, proving the env var supersedes the default. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter guard( + "HIPDNN_HEUR_POLICY_ORDER", "Vendor::Unregistered"); + + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + ASSERT_THROW_HIPDNN_STATUS(heur->finalize(), HIPDNN_STATUS_INTERNAL_ERROR); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, EnvironmentPolicyOrderTakesPrecedenceOverDescriptor) +{ + // Same mock setup. The env var (highest priority) lists only + // StaticOrdering — which the mock makes succeed — while the descriptor + // attribute lists an unregistered policy that would otherwise throw. + // Env winning means finalize() succeeds. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter guard( + "HIPDNN_HEUR_POLICY_ORDER", "SelectionHeuristic::StaticOrdering"); + + auto heur = getEngineHeuristicDescriptor(); + + const std::vector descriptorOrder = { + hipdnn_data_sdk::utilities::policyNameToId("Vendor::Unregistered"), + }; + ASSERT_NO_THROW(heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, + HIPDNN_TYPE_INT64, + static_cast(descriptorOrder.size()), + descriptorOrder.data())); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + ASSERT_NO_THROW(heur->finalize()); +} + +TEST_F(TestGpuEngineHeuristicDescriptor, EnvironmentPolicyOrderAcceptsRawIds) +{ + // HIPDNN_HEUR_POLICY_ORDER tokens may be either policy names or raw int64 + // policy IDs. Mixing both forms — including a negative ID for an + // unregistered policy — must round-trip through resolution and reach the + // outer policy loop in the order written. + const int64_t staticOrderingId + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + const std::string envValue + = "-1234567890," + std::to_string(staticOrderingId) + ",SelectionHeuristic::Config"; + + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter guard( + "HIPDNN_HEUR_POLICY_ORDER", envValue); + + auto heur = getEngineHeuristicDescriptor(); + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + // The first token is an unregistered ID (slot becomes a null placeholder), + // the StaticOrdering ID succeeds, and Config (no rules → declines) is the + // last. Finalize succeeds because StaticOrdering is reached. + ASSERT_NO_THROW(heur->finalize()); +} + +// ========== Failure Handling: Empty Policy List ========== + +TEST_F(TestGpuEngineHeuristicDescriptor, FinalizeWithEmptyPolicyListThrows) +{ + // Empty policy list reaches the "no policy succeeded" path via a different + // route from FinalizeWithAllPoliciesFailing: the outer loop never executes + // because there are no slots to try. Both paths must produce the same throw. + auto heur = getEngineHeuristicDescriptor(); + + ASSERT_NO_THROW( + heur->setAttribute(HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT, HIPDNN_TYPE_INT64, 0, nullptr)); + + setGraph(); + setHeuristicMode(); + EXPECT_CALL(*_mockEnginePluginResourceManager, getApplicableEngineIds(_, _)) + .WillRepeatedly(Return(std::vector{1, 2})); + + ASSERT_THROW_HIPDNN_STATUS(heur->finalize(), HIPDNN_STATUS_INTERNAL_ERROR); +} diff --git a/projects/hipdnn/backend/tests/descriptors/mocks/MockHandle.hpp b/projects/hipdnn/backend/tests/descriptors/mocks/MockHandle.hpp index e4c53aa9b37..fa0ccee33a4 100644 --- a/projects/hipdnn/backend/tests/descriptors/mocks/MockHandle.hpp +++ b/projects/hipdnn/backend/tests/descriptors/mocks/MockHandle.hpp @@ -14,4 +14,8 @@ struct MockHandle : hipdnnHandle getPluginResourceManager, (), (const, override)); + MOCK_METHOD(std::shared_ptr, + getHeuristicPluginResourceManager, + (), + (const, override)); }; diff --git a/projects/hipdnn/backend/tests/descriptors/mocks/MockHeuristicPlugin.hpp b/projects/hipdnn/backend/tests/descriptors/mocks/MockHeuristicPlugin.hpp index 5fa7c3d851b..a39da600808 100644 --- a/projects/hipdnn/backend/tests/descriptors/mocks/MockHeuristicPlugin.hpp +++ b/projects/hipdnn/backend/tests/descriptors/mocks/MockHeuristicPlugin.hpp @@ -23,14 +23,12 @@ class MockHeuristicPlugin : public HeuristicPlugin MockHeuristicPlugin() = default; // Module metadata - MOCK_METHOD(std::string_view, apiVersion, (), (const, override)); - MOCK_METHOD(int64_t, policyId, (), (const, override)); + MOCK_METHOD(std::vector, getAllPolicyIds, (), (const, override)); + MOCK_METHOD(std::string_view, getPolicyName, (int64_t policyId), (const, override)); MOCK_METHOD(std::string_view, name, (), (const, override)); - MOCK_METHOD(std::string_view, version, (), (const, override)); + MOCK_METHOD(hipdnnPluginType_t, type, (), (const, override)); // Handle lifecycle - MOCK_METHOD(hipdnnHeuristicHandle_t, createHandle, (), (const, override)); - MOCK_METHOD(void, destroyHandle, (hipdnnHeuristicHandle_t handle), (const, override)); MOCK_METHOD(void, setDeviceProperties, (hipdnnHeuristicHandle_t handle, @@ -40,7 +38,7 @@ class MockHeuristicPlugin : public HeuristicPlugin // Policy descriptor lifecycle MOCK_METHOD(hipdnnHeuristicPolicyDescriptor_t, createPolicyDescriptor, - (hipdnnHeuristicHandle_t pluginHandle), + (hipdnnHeuristicHandle_t pluginHandle, int64_t policyId), (const, override)); MOCK_METHOD(void, destroyPolicyDescriptor, diff --git a/projects/hipdnn/backend/tests/heuristics/TestConfigBuiltIn.cpp b/projects/hipdnn/backend/tests/heuristics/TestConfigBuiltIn.cpp new file mode 100644 index 00000000000..61bad3c00b7 --- /dev/null +++ b/projects/hipdnn/backend/tests/heuristics/TestConfigBuiltIn.cpp @@ -0,0 +1,715 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestConfigBuiltIn.cpp + * @brief Tests for the SelectionHeuristic::Config built-in. + * + * The built-in lives inside hipdnn_backend_private as a function-pointer table + * (ConfigBuiltIn::populateFunctionTable) wrapped by HeuristicPlugin via + * createBuiltIn. There is no .so to dlopen; the wrapper reaches the same code + * paths used in production registration through HeuristicPluginManager. + * + * Wraps the table once via HeuristicPlugin::createBuiltIn and exercises both + * the C-ABI rejection paths (null pointers, unknown policy IDs) and the + * policy's end-to-end behavior driven by HIPDNN_HEUR_CONFIG_PATH: + * matching rule reorders the candidate list, miss paths decline so the + * outer policy loop falls through. + */ + +#include "heuristics/config/ConfigBuiltIn.hpp" +#include "plugin/HeuristicPlugin.hpp" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace fb = hipdnn_flatbuffers_sdk::data_objects; +using hipdnn_backend::heuristics::config::populateFunctionTable; +using hipdnn_backend::plugin::HeuristicPlugin; +using hipdnn_backend::plugin::HeuristicPluginFunctionTable; +using hipdnn_data_sdk::utilities::engineNameToId; + +namespace +{ + +const int64_t MIOPEN_ENGINE_ID = engineNameToId("MIOPEN_ENGINE"); +const int64_t MIOPEN_DETERMINISTIC_ID = engineNameToId("MIOPEN_ENGINE_DETERMINISTIC"); +const int64_t CUSTOM_ENGINE_ID = engineNameToId("Plugin1::CustomEngine"); + +const int64_t CONFIG_POLICY_ID + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::Config"); + +constexpr const char* OVERRIDE_ENV = "HIPDNN_HEUR_CONFIG_PATH"; + +/// Build a minimal serialized Graph FlatBuffer with no nodes. +std::vector buildEmptyGraphBuffer() +{ + flatbuffers::FlatBufferBuilder builder; + auto graphOffset = fb::CreateGraphDirect(builder, + nullptr, + fb::DataType::UNSET, + fb::DataType::UNSET, + fb::DataType::UNSET, + nullptr, + nullptr, + ::flatbuffers::nullopt); + fb::FinishGraphBuffer(builder, graphOffset); + const auto* data = builder.GetBufferPointer(); + return {data, data + builder.GetSize()}; +} + +/// Build a serialized Graph with a single ConvolutionFwd node referencing +/// (x, w) tensors of the requested shapes. +std::vector buildConvFwdGraphBuffer(const std::vector& xDims, + const std::vector& xStrides, + const std::vector& wDims, + const std::vector& wStrides) +{ + flatbuffers::FlatBufferBuilder builder; + + constexpr int64_t X_UID = 1; + constexpr int64_t W_UID = 2; + constexpr int64_t Y_UID = 3; + + const std::vector> tensors{ + fb::CreateTensorAttributesDirect( + builder, X_UID, "x", fb::DataType::FLOAT, &xStrides, &xDims), + fb::CreateTensorAttributesDirect( + builder, W_UID, "w", fb::DataType::FLOAT, &wStrides, &wDims), + fb::CreateTensorAttributesDirect( + builder, Y_UID, "y", fb::DataType::FLOAT, nullptr, nullptr), + }; + + auto convAttrs = fb::CreateConvolutionFwdAttributesDirect(builder, X_UID, W_UID, Y_UID); + + const std::vector> nodes{ + fb::CreateNodeDirect(builder, + "conv", + fb::DataType::FLOAT, + fb::NodeAttributes::ConvolutionFwdAttributes, + convAttrs.Union())}; + + auto graphOffset = fb::CreateGraphDirect(builder, + nullptr, + fb::DataType::UNSET, + fb::DataType::UNSET, + fb::DataType::UNSET, + &tensors, + &nodes, + ::flatbuffers::nullopt); + fb::FinishGraphBuffer(builder, graphOffset); + const auto* data = builder.GetBufferPointer(); + return {data, data + builder.GetSize()}; +} + +/// Build a serialized Graph with a single ConvolutionBwd node referencing +/// (dy, w) tensors. Mirrors buildConvFwdGraphBuffer; matchOverrideConfig pulls +/// the rule's first two tensors against (dy, w) for "conv_dgrad". +std::vector buildConvBwdGraphBuffer(const std::vector& dyDims, + const std::vector& dyStrides, + const std::vector& wDims, + const std::vector& wStrides) +{ + flatbuffers::FlatBufferBuilder builder; + + constexpr int64_t DY_UID = 1; + constexpr int64_t W_UID = 2; + constexpr int64_t DX_UID = 3; + + const std::vector> tensors{ + fb::CreateTensorAttributesDirect( + builder, DY_UID, "dy", fb::DataType::FLOAT, &dyStrides, &dyDims), + fb::CreateTensorAttributesDirect( + builder, W_UID, "w", fb::DataType::FLOAT, &wStrides, &wDims), + fb::CreateTensorAttributesDirect( + builder, DX_UID, "dx", fb::DataType::FLOAT, nullptr, nullptr), + }; + + auto convAttrs = fb::CreateConvolutionBwdAttributesDirect(builder, DY_UID, W_UID, DX_UID); + + const std::vector> nodes{ + fb::CreateNodeDirect(builder, + "conv_bwd", + fb::DataType::FLOAT, + fb::NodeAttributes::ConvolutionBwdAttributes, + convAttrs.Union())}; + + auto graphOffset = fb::CreateGraphDirect(builder, + nullptr, + fb::DataType::UNSET, + fb::DataType::UNSET, + fb::DataType::UNSET, + &tensors, + &nodes, + ::flatbuffers::nullopt); + fb::FinishGraphBuffer(builder, graphOffset); + const auto* data = builder.GetBufferPointer(); + return {data, data + builder.GetSize()}; +} + +/// Build a serialized Graph with a single ConvolutionWrw node referencing +/// (x, dy) tensors. matchOverrideConfig pairs (a=x, b=dy) for "conv_wgrad". +std::vector buildConvWrwGraphBuffer(const std::vector& xDims, + const std::vector& xStrides, + const std::vector& dyDims, + const std::vector& dyStrides) +{ + flatbuffers::FlatBufferBuilder builder; + + constexpr int64_t X_UID = 1; + constexpr int64_t DY_UID = 2; + constexpr int64_t DW_UID = 3; + + const std::vector> tensors{ + fb::CreateTensorAttributesDirect( + builder, X_UID, "x", fb::DataType::FLOAT, &xStrides, &xDims), + fb::CreateTensorAttributesDirect( + builder, DY_UID, "dy", fb::DataType::FLOAT, &dyStrides, &dyDims), + fb::CreateTensorAttributesDirect( + builder, DW_UID, "dw", fb::DataType::FLOAT, nullptr, nullptr), + }; + + auto convAttrs = fb::CreateConvolutionWrwAttributesDirect(builder, X_UID, DY_UID, DW_UID); + + const std::vector> nodes{ + fb::CreateNodeDirect(builder, + "conv_wrw", + fb::DataType::FLOAT, + fb::NodeAttributes::ConvolutionWrwAttributes, + convAttrs.Union())}; + + auto graphOffset = fb::CreateGraphDirect(builder, + nullptr, + fb::DataType::UNSET, + fb::DataType::UNSET, + fb::DataType::UNSET, + &tensors, + &nodes, + ::flatbuffers::nullopt); + fb::FinishGraphBuffer(builder, graphOffset); + const auto* data = builder.GetBufferPointer(); + return {data, data + builder.GetSize()}; +} + +/// RAII temp directory + JSON file. Returns a path that can be assigned to +/// HIPDNN_HEUR_CONFIG_PATH; the directory is removed on destruction. +class TempJsonOverrideFile +{ +public: + explicit TempJsonOverrideFile(const std::string& contents) + : _dir(makeUniqueDir()) + , _path(_dir.path() / "override.json") + { + std::ofstream(_path) << contents; + } + + std::string path() const + { + return _path.string(); + } + +private: + static std::filesystem::path makeUniqueDir() + { + static std::atomic s_counter{0}; + const auto path = std::filesystem::temp_directory_path() + / ("hipdnn_test_config_" + std::to_string(s_counter.fetch_add(1))); + std::filesystem::remove_all(path); + return path; + } + + hipdnn_test_sdk::utilities::ScopedDirectory _dir; + std::filesystem::path _path; +}; + +constexpr const char* DETERMINISTIC_RULE_JSON = R"({ + "engine_overrides": [ + { + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE_DETERMINISTIC", + "tensors": [ + { "dim": [1, 3, 4, 4] }, + { "dim": [2, 3, 1, 1] } + ] + } + ] +})"; + +const std::vector X_DIMS{1, 3, 4, 4}; +const std::vector X_STRIDES{48, 16, 4, 1}; +const std::vector W_DIMS{2, 3, 1, 1}; +const std::vector W_STRIDES{3, 1, 1, 1}; + +class TestConfigBuiltIn : public ::testing::Test +{ +protected: + void SetUp() override + { + _plugin = HeuristicPlugin::createBuiltIn(populateFunctionTable(), "built-in:Config-test"); + _handle = _plugin->createHandle(); + ASSERT_NE(_handle, nullptr); + _desc = _plugin->createPolicyDescriptor(_handle, CONFIG_POLICY_ID); + ASSERT_NE(_desc, nullptr); + } + + void TearDown() override + { + if(_desc != nullptr) + { + _plugin->destroyPolicyDescriptor(_desc); + } + if(_handle != nullptr) + { + _plugin->destroyHandle(_handle); + } + } + + void setEngineIds(const std::vector& ids) + { + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + } + + void setSerializedGraph(const std::vector& buffer) + { + const hipdnnPluginConstData_t data{buffer.data(), buffer.size()}; + _plugin->setSerializedGraph(_desc, &data); + } + + std::shared_ptr _plugin; + hipdnnHeuristicHandle_t _handle = nullptr; + hipdnnHeuristicPolicyDescriptor_t _desc = nullptr; +}; + +// Convenience: grab the raw function table once for direct C-ABI rejection tests. +const HeuristicPluginFunctionTable& configAbi() +{ + static const HeuristicPluginFunctionTable s_funcs = populateFunctionTable(); + return s_funcs; +} + +} // namespace + +// ========== Built-in metadata exposed via the wrapper ========== + +TEST_F(TestConfigBuiltIn, ReportsHeuristicPluginType) +{ + EXPECT_EQ(_plugin->type(), HIPDNN_PLUGIN_TYPE_HEURISTIC); +} + +TEST_F(TestConfigBuiltIn, EnumeratesSingleConfigPolicy) +{ + const auto ids = _plugin->getAllPolicyIds(); + ASSERT_EQ(ids.size(), 1u); + EXPECT_EQ(ids[0], CONFIG_POLICY_ID); + EXPECT_EQ(_plugin->getPolicyName(CONFIG_POLICY_ID), "SelectionHeuristic::Config"); +} + +// ========== Policy Descriptor Lifecycle (BAD_PARAM via raw ABI) ========== + +TEST(TestConfigBuiltInRejection, DescriptorCreateRejectsNullHandle) +{ + hipdnnHeuristicPolicyDescriptor_t desc = nullptr; + EXPECT_EQ(configAbi().policyDescriptorCreate(nullptr, CONFIG_POLICY_ID, &desc), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(desc, nullptr); +} + +TEST_F(TestConfigBuiltIn, DescriptorCreateRejectsNullOutPointer) +{ + EXPECT_EQ(configAbi().policyDescriptorCreate(_handle, CONFIG_POLICY_ID, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST(TestConfigBuiltInRejection, DescriptorDestroyRejectsNullDescriptor) +{ + EXPECT_EQ(configAbi().policyDescriptorDestroy(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, DescriptorCreateRejectsUnknownPolicyId) +{ + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, CONFIG_POLICY_ID); + + hipdnnHeuristicPolicyDescriptor_t desc = nullptr; + EXPECT_EQ(configAbi().policyDescriptorCreate(_handle, unknownId, &desc), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(desc, nullptr); +} + +TEST_F(TestConfigBuiltIn, GetPolicyNameRejectsUnknownPolicyId) +{ + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, CONFIG_POLICY_ID); + + const char* name = nullptr; + EXPECT_EQ(configAbi().getPolicyName(unknownId, &name), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(name, nullptr); +} + +// ========== SetEngineIds / SetSerializedGraph BAD_PARAM ========== + +TEST(TestConfigBuiltInRejection, SetEngineIdsRejectsNullDescriptor) +{ + const std::array ids{MIOPEN_ENGINE_ID}; + EXPECT_EQ(configAbi().policySetEngineIds(nullptr, ids.data(), ids.size()), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, SetEngineIdsRejectsNullPointerWithCount) +{ + EXPECT_EQ(configAbi().policySetEngineIds(_desc, nullptr, 3), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST(TestConfigBuiltInRejection, SetSerializedGraphRejectsNullDescriptor) +{ + const std::array buffer{0x00}; + const hipdnnPluginConstData_t data{buffer.data(), buffer.size()}; + EXPECT_EQ(configAbi().policySetSerializedGraph(nullptr, &data), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, SetSerializedGraphRejectsNullBufferStruct) +{ + EXPECT_EQ(configAbi().policySetSerializedGraph(_desc, nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +// ========== Finalize BAD_PARAM / NOT_INITIALIZED ========== + +TEST(TestConfigBuiltInRejection, FinalizeRejectsNullDescriptor) +{ + int32_t applied = -1; // NOLINT(misc-const-correctness) + EXPECT_EQ(configAbi().policyFinalize(nullptr, &applied), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, FinalizeRejectsNullOutApplied) +{ + EXPECT_EQ(configAbi().policyFinalize(_desc, nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST(TestConfigBuiltInRejection, GetSortedRejectsNullDescriptor) +{ + size_t count = 0; // NOLINT(misc-const-correctness) + EXPECT_EQ(configAbi().policyGetSortedEngineIds(nullptr, nullptr, &count), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, GetSortedRejectsNullCountPointer) +{ + EXPECT_EQ(configAbi().policyGetSortedEngineIds(_desc, nullptr, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestConfigBuiltIn, GetSortedReturnsNotInitializedBeforeFinalize) +{ + size_t count = 0; // NOLINT(misc-const-correctness) + EXPECT_EQ(configAbi().policyGetSortedEngineIds(_desc, nullptr, &count), + HIPDNN_PLUGIN_STATUS_NOT_INITIALIZED); +} + +// ========== End-to-end: miss paths decline so the policy loop continues ========== + +TEST_F(TestConfigBuiltIn, FinalizeWithEmptyCandidatesDeclines) +{ + // Even with a valid env file, no candidates means the policy can't pick + // anything — decline rather than producing an empty list. + const TempJsonOverrideFile json(DETERMINISTIC_RULE_JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithNoEnvDeclines) +{ + // Make sure no override file leaks in from the environment. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter overrideEnv(OVERRIDE_ENV, ""); + + setEngineIds({MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithMissingFileDeclines) +{ + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env( + OVERRIDE_ENV, "/nonexistent/path/hipdnn_no_such_file.json"); + + setEngineIds({MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithInvalidGraphBufferDeclines) +{ + // Garbage bytes large enough to clear the null check but fail FlatBuffers + // verification — must be tolerated quietly so the policy loop still runs. + const TempJsonOverrideFile json(DETERMINISTIC_RULE_JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + const std::vector garbage(64, 0xFF); + setEngineIds({MIOPEN_ENGINE_ID}); + setSerializedGraph(garbage); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithNoMatchingRuleDeclines) +{ + // Rule targets dim [99, 99, 99, 99] — no conv in the test graph matches. + constexpr const char* JSON = R"({ + "engine_overrides": [ + { + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE_DETERMINISTIC", + "tensors": [ + { "dim": [99, 99, 99, 99] }, + { "dim": [99, 99, 99, 99] } + ] + } + ] + })"; + const TempJsonOverrideFile json(JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setEngineIds({MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithMatchedEngineNotInCandidatesDeclines) +{ + const TempJsonOverrideFile json(DETERMINISTIC_RULE_JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + // Rule selects DETERMINISTIC; candidate list omits it. + setEngineIds({MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestConfigBuiltIn, FinalizeWithGraphMissingNodesDeclines) +{ + // Empty graph: nothing to walk; nothing matches. Decline. + const TempJsonOverrideFile json(DETERMINISTIC_RULE_JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setEngineIds({MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildEmptyGraphBuffer()); + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +// ========== End-to-end: matching rule reorders candidates ========== + +TEST_F(TestConfigBuiltIn, FinalizeMatchedRuleMovesEngineToFront) +{ + const TempJsonOverrideFile json(DETERMINISTIC_RULE_JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setEngineIds({MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 3u); + EXPECT_EQ(sorted[0], MIOPEN_DETERMINISTIC_ID); + EXPECT_EQ(sorted[1], MIOPEN_ENGINE_ID); + EXPECT_EQ(sorted[2], CUSTOM_ENGINE_ID); +} + +TEST_F(TestConfigBuiltIn, FinalizeRereadsEnvOnEachInvocation) +{ + // loadFromEnv must not be process-cached: pointing the env at a different + // file between invocations picks up the new rule. + const TempJsonOverrideFile firstFile(DETERMINISTIC_RULE_JSON); + constexpr const char* SECOND_RULE = R"({ + "engine_overrides": [ + { + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE", + "tensors": [ + { "dim": [1, 3, 4, 4] }, + { "dim": [2, 3, 1, 1] } + ] + } + ] + })"; + const TempJsonOverrideFile secondFile(SECOND_RULE); + + hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, firstFile.path()); + + setEngineIds({MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + + ASSERT_TRUE(_plugin->finalize(_desc)); + { + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_FALSE(sorted.empty()); + EXPECT_EQ(sorted.front(), MIOPEN_DETERMINISTIC_ID); + } + + env.setValue(secondFile.path()); + + // Rerun finalize — the rule from secondFile should win this time. + setEngineIds({MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvFwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + ASSERT_TRUE(_plugin->finalize(_desc)); + { + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_FALSE(sorted.empty()); + EXPECT_EQ(sorted.front(), MIOPEN_ENGINE_ID); + } +} + +// ========== End-to-end: ConvolutionBwd / ConvolutionWrw node parsing ========== + +TEST_F(TestConfigBuiltIn, FinalizeMatchedRuleMovesEngineToFrontBwdNode) +{ + // Drives the conv_dgrad branch in matchOverrideConfig — the rule pairs + // (dy, w), so dim entries here must match the Bwd node's tensor pair. + constexpr const char* JSON = R"({ + "engine_overrides": [ + { + "op": "conv_dgrad", + "engine_name": "MIOPEN_ENGINE_DETERMINISTIC", + "tensors": [ + { "dim": [1, 3, 4, 4] }, + { "dim": [2, 3, 1, 1] } + ] + } + ] + })"; + const TempJsonOverrideFile json(JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setEngineIds({MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvBwdGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 3u); + EXPECT_EQ(sorted[0], MIOPEN_DETERMINISTIC_ID); +} + +TEST_F(TestConfigBuiltIn, FinalizeMatchedRuleMovesEngineToFrontWrwNode) +{ + // Drives the conv_wgrad branch in matchOverrideConfig — the rule pairs + // (x, dy). + constexpr const char* JSON = R"({ + "engine_overrides": [ + { + "op": "conv_wgrad", + "engine_name": "MIOPEN_ENGINE_DETERMINISTIC", + "tensors": [ + { "dim": [1, 3, 4, 4] }, + { "dim": [2, 3, 1, 1] } + ] + } + ] + })"; + const TempJsonOverrideFile json(JSON); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(OVERRIDE_ENV, + json.path()); + + setEngineIds({MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}); + setSerializedGraph(buildConvWrwGraphBuffer(X_DIMS, X_STRIDES, W_DIMS, W_STRIDES)); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 3u); + EXPECT_EQ(sorted[0], MIOPEN_DETERMINISTIC_ID); +} + +// ========== Logging callback / getLastErrorString ABI shape ========== + +namespace +{ +// Counter and severity capture for the logging-callback test. File-scope so a +// plain C function pointer can mutate them. +std::atomic gCallbackInvocations{0}; +std::atomic gCallbackLastSeverity{HIPDNN_SEV_INFO}; + +void testLoggingCallback(hipdnnSeverity_t severity, const char* /*message*/) +{ + gCallbackInvocations.fetch_add(1); + gCallbackLastSeverity.store(severity); +} +} // namespace + +TEST(TestConfigBuiltInLogging, LoggingCallbackReceivesErrorOnUnknownPolicyId) +{ + // Drive the STATIC_ORDERING_LOG-equivalent macro body in ConfigBuiltIn. + // getPolicyName(unknownId) logs at ERROR severity before returning + // BAD_PARAM; with a callback installed and log level SEV_ERROR we should + // observe at least one invocation tagged at HIPDNN_SEV_ERROR. + gCallbackInvocations.store(0); + gCallbackLastSeverity.store(HIPDNN_SEV_INFO); + + ASSERT_EQ(configAbi().setLoggingCallback(&testLoggingCallback), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(configAbi().setLogLevel(HIPDNN_SEV_ERROR), HIPDNN_PLUGIN_STATUS_SUCCESS); + + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, CONFIG_POLICY_ID); + + const char* name = nullptr; + EXPECT_EQ(configAbi().getPolicyName(unknownId, &name), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + EXPECT_GE(gCallbackInvocations.load(), 1); + EXPECT_EQ(gCallbackLastSeverity.load(), HIPDNN_SEV_ERROR); + + // Reset globals so other tests in the binary do not see a dangling callback. + EXPECT_EQ(configAbi().setLoggingCallback(nullptr), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(configAbi().setLogLevel(HIPDNN_SEV_INFO), HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST(TestConfigBuiltInLogging, GetLastErrorStringHandlesNullOutPointer) +{ + // Pure ABI-shape branch coverage: getLastErrorString(nullptr) must return + // (void) without dereferencing the null pointer. + EXPECT_NO_FATAL_FAILURE(configAbi().getLastErrorString(nullptr)); +} + +TEST(TestConfigBuiltInLogging, GetLastErrorStringWritesPlaceholder) +{ + const char* msg = nullptr; + configAbi().getLastErrorString(&msg); + ASSERT_NE(msg, nullptr); + EXPECT_STRNE(msg, ""); +} + +// ========== Empty serialized graph buffer ========== + +TEST_F(TestConfigBuiltIn, SetSerializedGraphAcceptsZeroSizeBuffer) +{ + // Drives the size==0 branch in policySetSerializedGraph that clears the + // descriptor's stored buffer instead of copying. The validation macro + // rejects ptr==nullptr unconditionally, so we pass a real (but unused) + // byte alongside size==0. + const std::array placeholder{0x00}; + const hipdnnPluginConstData_t data{placeholder.data(), 0}; + EXPECT_EQ(configAbi().policySetSerializedGraph(_desc, &data), HIPDNN_PLUGIN_STATUS_SUCCESS); + + // With no graph and no candidates, finalize declines (covers the empty + // buffer → parseGraphBuffer null-return path through finalize). + setEngineIds({MIOPEN_ENGINE_ID}); + EXPECT_FALSE(_plugin->finalize(_desc)); +} diff --git a/projects/hipdnn/backend/tests/heuristics/TestEngineOverrideConfig.cpp b/projects/hipdnn/backend/tests/heuristics/TestEngineOverrideConfig.cpp new file mode 100644 index 00000000000..12af8bcc049 --- /dev/null +++ b/projects/hipdnn/backend/tests/heuristics/TestEngineOverrideConfig.cpp @@ -0,0 +1,449 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestEngineOverrideConfig.cpp + * @brief Unit tests for the rule-matching internals of EngineOverrideConfig + * (op + dim + stride wildcards, exact/wildcard partition ordering, + * JSON parsing). The end-to-end policy behavior driven by + * HIPDNN_HEUR_CONFIG_PATH is covered in TestConfigBuiltIn.cpp. + */ + +#include "heuristics/config/EngineOverrideConfig.hpp" + +#include +#include + +#include +#include + +using namespace hipdnn_backend::heuristics::config; +using namespace hipdnn_data_sdk::utilities; + +namespace +{ + +struct TensorData +{ + std::vector dim; + std::vector stride; +}; + +TensorView viewOf(const TensorData& t) +{ + return TensorView{&t.dim, &t.stride}; +} + +std::vector viewsOf(const std::vector& ts) +{ + std::vector views; + views.reserve(ts.size()); + for(const auto& t : ts) + { + views.push_back(viewOf(t)); + } + return views; +} + +TensorPattern makePattern(std::vector dim) +{ + TensorPattern p; + p.dim = std::move(dim); + return p; +} + +TensorPattern makePatternWithStride(std::vector dim, std::vector stride) +{ + TensorPattern p; + p.dim = std::move(dim); + p.stride = std::move(stride); + return p; +} + +EngineOverrideConfig makeConfig(std::vector rules) +{ + return EngineOverrideConfig(std::move(rules)); +} + +} // namespace + +// ── Test 1: exact dim match, single rule ──────────────────────────────────── + +TEST(TestEngineOverrideConfig, ExactDimMatchSingleRule) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = MIOPEN_ENGINE_NAME; + rule.tensors = {makePattern({1, 3, 224, 224}), makePattern({64, 3, 7, 7})}; + + const auto config = makeConfig({std::move(rule)}); + + const std::vector tensors = {{{1, 3, 224, 224}, {}}, {{64, 3, 7, 7}, {}}}; + + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, MIOPEN_ENGINE_ID); +} + +// ── Test 2: first matching rule wins ──────────────────────────────────────── + +TEST(TestEngineOverrideConfig, FirstMatchingRuleWins) +{ + OperationRule rule1; + rule1.op = "conv_fprop"; + rule1.engineName = MIOPEN_ENGINE_NAME; + rule1.tensors = {makePattern({1, 3, 224, 224})}; + + OperationRule rule2; + rule2.op = "conv_fprop"; + rule2.engineName = HIPBLASLT_ENGINE_NAME; + rule2.tensors = {makePattern({1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(rule1), std::move(rule2)}); + + const std::vector tensors = {{{1, 3, 224, 224}, {}}}; + + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, MIOPEN_ENGINE_ID); +} + +// ── Test 3: no rule matches (wrong dims) ──────────────────────────────────── + +TEST(TestEngineOverrideConfig, NoRuleMatchesWrongDims) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = MIOPEN_ENGINE_NAME; + rule.tensors = {makePattern({1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(rule)}); + + const std::vector tensors = {{{1, 3, 112, 112}, {}}}; + + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + EXPECT_FALSE(result.has_value()); +} + +// ── Test 4: wildcard (-1) in one dimension ────────────────────────────────── + +TEST(TestEngineOverrideConfig, WildcardInOneDimension) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = HIPBLASLT_ENGINE_NAME; + rule.tensors = {makePattern({-1, 64, 56, 56})}; + + const auto config = makeConfig({std::move(rule)}); + + for(const int64_t batch : {1, 4, 8, 32}) + { + const std::vector tensors = {{{batch, 64, 56, 56}, {}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()) << "batch=" << batch << " should match"; + EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); + } + + const std::vector tensors = {{{4, 128, 56, 56}, {}}}; + EXPECT_FALSE(config.matchOperation("conv_fprop", viewsOf(tensors)).has_value()); +} + +// ── Test 5: all-wildcard rule matches any shape ───────────────────────────── + +TEST(TestEngineOverrideConfig, AllWildcardRuleMatchesAnyShape) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = FUSILLI_ENGINE_NAME; + rule.tensors = {makePattern({-1, -1, -1, -1})}; + + const auto config = makeConfig({std::move(rule)}); + + for(const auto& shape : + std::vector>{{1, 3, 224, 224}, {8, 64, 56, 56}, {32, 256, 14, 14}}) + { + const std::vector tensors = {{shape, {}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, FUSILLI_ENGINE_ID); + } +} + +// ── Test 6: wrong op name → nullopt ───────────────────────────────────────── + +TEST(TestEngineOverrideConfig, WrongOpNameReturnsNullopt) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = MIOPEN_ENGINE_NAME; + rule.tensors = {makePattern({1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(rule)}); + + const std::vector tensors = {{{1, 3, 224, 224}, {}}}; + + EXPECT_FALSE(config.matchOperation("conv_dgrad", viewsOf(tensors)).has_value()); + EXPECT_FALSE(config.matchOperation("conv_wgrad", viewsOf(tensors)).has_value()); + EXPECT_FALSE(config.matchOperation("matmul", viewsOf(tensors)).has_value()); +} + +// ── Test 7: wrong tensor count in rule → nullopt ──────────────────────────── + +TEST(TestEngineOverrideConfig, WrongTensorCountReturnsNullopt) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = MIOPEN_ENGINE_NAME; + rule.tensors = {makePattern({1, 3, 224, 224}), makePattern({64, 3, 7, 7})}; + + const auto config = makeConfig({std::move(rule)}); + + const std::vector tensors1 = {{{1, 3, 224, 224}, {}}}; + EXPECT_FALSE(config.matchOperation("conv_fprop", viewsOf(tensors1)).has_value()); + + const std::vector tensors3 + = {{{1, 3, 224, 224}, {}}, {{64, 3, 7, 7}, {}}, {{64, 1, 1, 1}, {}}}; + EXPECT_FALSE(config.matchOperation("conv_fprop", viewsOf(tensors3)).has_value()); +} + +// ── Tests 11–12: cross-partition ordering (exact vs wildcard) ─────────────── + +TEST(TestEngineOverrideConfig, WildcardBeforeExactBothMatch) +{ + OperationRule wildcard; + wildcard.op = "conv_fprop"; + wildcard.engineName = FUSILLI_ENGINE_NAME; + wildcard.tensors = {makePattern({-1, 3, 224, 224})}; + + OperationRule exact; + exact.op = "conv_fprop"; + exact.engineName = HIPBLASLT_ENGINE_NAME; + exact.tensors = {makePattern({1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(wildcard), std::move(exact)}); + + const std::vector tensors = {{{1, 3, 224, 224}, {}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, FUSILLI_ENGINE_ID); +} + +TEST(TestEngineOverrideConfig, ExactBeforeWildcardBothMatch) +{ + OperationRule exact; + exact.op = "conv_fprop"; + exact.engineName = HIPBLASLT_ENGINE_NAME; + exact.tensors = {makePattern({1, 3, 224, 224})}; + + OperationRule wildcard; + wildcard.op = "conv_fprop"; + wildcard.engineName = FUSILLI_ENGINE_NAME; + wildcard.tensors = {makePattern({-1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(exact), std::move(wildcard)}); + + const std::vector tensors = {{{1, 3, 224, 224}, {}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); +} + +// ── Stride matching tests ──────────────────────────────────────────────────── + +TEST(TestEngineOverrideConfig, ExactStrideMatchSelectsEngine) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = MIOPEN_ENGINE_NAME; + rule.tensors = {makePatternWithStride({1, 3, 224, 224}, {150528, 50176, 224, 1})}; + + const auto config = makeConfig({std::move(rule)}); + + const std::vector matching = {{{1, 3, 224, 224}, {150528, 50176, 224, 1}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(matching)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, MIOPEN_ENGINE_ID); + + const std::vector wrongStride + = {{{1, 3, 224, 224}, {1, 224, int64_t{224} * 3, int64_t{224} * 3 * 224}}}; + EXPECT_FALSE(config.matchOperation("conv_fprop", viewsOf(wrongStride)).has_value()); +} + +TEST(TestEngineOverrideConfig, WildcardStrideElement) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = HIPBLASLT_ENGINE_NAME; + rule.tensors = {makePatternWithStride({1, 3, 224, 224}, {150528, 50176, -1, -1})}; + + const auto config = makeConfig({std::move(rule)}); + + for(const int64_t s2 : {224, 112, 56}) + { + const std::vector tensors = {{{1, 3, 224, 224}, {150528, 50176, s2, 1}}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()) << "stride[2]=" << s2; + EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); + } + + const std::vector wrongStride = {{{1, 3, 224, 224}, {999, 50176, 224, 1}}}; + EXPECT_FALSE(config.matchOperation("conv_fprop", viewsOf(wrongStride)).has_value()); +} + +TEST(TestEngineOverrideConfig, EmptyStridePatternMatchesAnyStride) +{ + OperationRule rule; + rule.op = "conv_fprop"; + rule.engineName = FUSILLI_ENGINE_NAME; + rule.tensors = {makePattern({1, 3, 224, 224})}; + + const auto config = makeConfig({std::move(rule)}); + + for(const auto& strides : std::vector>{ + {150528, 50176, 224, 1}, {1, 3, 672, 150528}, {999, 888, 777, 666}}) + { + const std::vector tensors = {{{1, 3, 224, 224}, strides}}; + auto result = config.matchOperation("conv_fprop", viewsOf(tensors)); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(*result, FUSILLI_ENGINE_ID); + } +} + +// ── JSON-dependent ────────────────────────────────────────────────────────── + +TEST(TestEngineOverrideConfig, LoadFromValidJsonFile) +{ + constexpr const char* CONTENTS = R"({ + "engine_overrides": [ + { + "comment": "test rule for ResNet first conv", + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE", + "tensors": [ + { "dim": [1, 3, 224, 224] }, + { "dim": [64, 3, 7, 7] } + ] + }, + { + "comment": "wildcard catch-all", + "op": "conv_fprop", + "engine_name": "FUSILLI_ENGINE", + "tensors": [ + { "dim": [-1, -1, -1, -1] }, + { "dim": [-1, -1, -1, -1] } + ] + } + ] +})"; + + auto config = EngineOverrideConfig::loadFromContent(CONTENTS); + ASSERT_TRUE(config.has_value()); + + const std::vector exact = {{{1, 3, 224, 224}, {}}, {{64, 3, 7, 7}, {}}}; + auto r1 = config->matchOperation("conv_fprop", viewsOf(exact)); + ASSERT_TRUE(r1.has_value()); + EXPECT_EQ(*r1, MIOPEN_ENGINE_ID); + + const std::vector other = {{{8, 64, 56, 56}, {}}, {{64, 64, 3, 3}, {}}}; + auto r2 = config->matchOperation("conv_fprop", viewsOf(other)); + ASSERT_TRUE(r2.has_value()); + EXPECT_EQ(*r2, FUSILLI_ENGINE_ID); +} + +TEST(TestEngineOverrideConfig, LoadFromMissingFileReturnsNullopt) +{ + auto config = EngineOverrideConfig::load("/nonexistent/path/hipdnn_no_such_file.json"); + EXPECT_FALSE(config.has_value()); +} + +TEST(TestEngineOverrideConfig, EnvVarUnsetReturnsNullopt) +{ + EXPECT_FALSE(EngineOverrideConfig::loadFromEnv().has_value()); +} + +// ── Malformed JSON: parser must return nullopt, not throw ────────────────── + +TEST(TestEngineOverrideConfig, LoadFromContentRejectsInvalidJsonSyntax) +{ + // Trailing comma + unterminated array — nlohmann::json::parse_error. + constexpr const char* CONTENTS = R"({ "engine_overrides": [)"; + EXPECT_FALSE(EngineOverrideConfig::loadFromContent(CONTENTS).has_value()); +} + +TEST(TestEngineOverrideConfig, LoadFromContentRejectsMissingTopLevelKey) +{ + // Valid JSON, but no "engine_overrides" key — at() throws out_of_range. + constexpr const char* CONTENTS = R"({ "other_key": [] })"; + EXPECT_FALSE(EngineOverrideConfig::loadFromContent(CONTENTS).has_value()); +} + +TEST(TestEngineOverrideConfig, LoadFromContentRejectsMissingEntryFields) +{ + // engine_overrides entry missing "engine_name" — at() throws out_of_range. + constexpr const char* CONTENTS = R"({ + "engine_overrides": [ + { "op": "conv_fprop", "tensors": [ { "dim": [1, 3, 224, 224] } ] } + ] +})"; + EXPECT_FALSE(EngineOverrideConfig::loadFromContent(CONTENTS).has_value()); +} + +TEST(TestEngineOverrideConfig, LoadFromContentRejectsMissingTensorDim) +{ + // tensor entry missing "dim" — at() throws out_of_range. + constexpr const char* CONTENTS = R"({ + "engine_overrides": [ + { + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE", + "tensors": [ { "stride": [1, 2, 3, 4] } ] + } + ] +})"; + EXPECT_FALSE(EngineOverrideConfig::loadFromContent(CONTENTS).has_value()); +} + +TEST(TestEngineOverrideConfig, LoadFromContentRejectsWrongFieldType) +{ + // "op" is a number, not a string — get throws type_error. + constexpr const char* CONTENTS = R"({ + "engine_overrides": [ + { + "op": 42, + "engine_name": "MIOPEN_ENGINE", + "tensors": [ { "dim": [1, 3, 224, 224] } ] + } + ] +})"; + EXPECT_FALSE(EngineOverrideConfig::loadFromContent(CONTENTS).has_value()); +} + +TEST(TestEngineOverrideConfig, JsonWithStrideConstraint) +{ + constexpr const char* CONTENTS = R"({ + "engine_overrides": [ + { + "op": "conv_fprop", + "engine_name": "MIOPEN_ENGINE", + "tensors": [ + { "dim": [1, 3, 224, 224], "stride": [150528, 50176, 224, 1] }, + { "dim": [64, 3, 7, 7] } + ] + } + ] +})"; + + auto config = EngineOverrideConfig::loadFromContent(CONTENTS); + ASSERT_TRUE(config.has_value()); + + const std::vector matching + = {{{1, 3, 224, 224}, {150528, 50176, 224, 1}}, {{64, 3, 7, 7}, {}}}; + auto r1 = config->matchOperation("conv_fprop", viewsOf(matching)); + ASSERT_TRUE(r1.has_value()); + EXPECT_EQ(*r1, MIOPEN_ENGINE_ID); + + const std::vector wrong + = {{{1, 3, 224, 224}, {1, 224, int64_t{224} * 3, int64_t{224} * 3 * 224}}, + {{64, 3, 7, 7}, {}}}; + EXPECT_FALSE(config->matchOperation("conv_fprop", viewsOf(wrong)).has_value()); +} diff --git a/projects/hipdnn/backend/tests/heuristics/TestStaticOrderingBuiltIn.cpp b/projects/hipdnn/backend/tests/heuristics/TestStaticOrderingBuiltIn.cpp new file mode 100644 index 00000000000..457de0902a8 --- /dev/null +++ b/projects/hipdnn/backend/tests/heuristics/TestStaticOrderingBuiltIn.cpp @@ -0,0 +1,441 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestStaticOrderingBuiltIn.cpp + * @brief Tests for the SelectionHeuristic::StaticOrdering built-in. + * + * The built-in lives inside hipdnn_backend_private as a function-pointer table + * (StaticOrderingBuiltIn::populateFunctionTable) wrapped by HeuristicPlugin via + * createBuiltIn. There is no .so to dlopen; the wrapper reaches the same code + * paths used in production registration through HeuristicPluginManager. + * + * Wraps the table once via HeuristicPlugin::createBuiltIn and exercises the + * wrapper API (createHandle, createPolicyDescriptor, setEngineIds, finalize, + * getSortedEngineIds). BAD_PARAM rejection paths in the C-ABI shape are driven + * by calling the populated function-table entries directly so test code can + * pass nullptr arguments without the wrapper translating them into exceptions. + */ + +#include "heuristics/static_ordering/StaticOrderingBuiltIn.hpp" +#include "plugin/HeuristicPlugin.hpp" + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +using hipdnn_backend::heuristics::static_ordering::populateFunctionTable; +using hipdnn_backend::plugin::HeuristicPlugin; +using hipdnn_backend::plugin::HeuristicPluginFunctionTable; + +namespace +{ + +const int64_t MIOPEN_ENGINE_ID = hipdnn_data_sdk::utilities::engineNameToId("MIOPEN_ENGINE"); +const int64_t MIOPEN_DETERMINISTIC_ID + = hipdnn_data_sdk::utilities::engineNameToId("MIOPEN_ENGINE_DETERMINISTIC"); +const int64_t CUSTOM_ENGINE_ID + = hipdnn_data_sdk::utilities::engineNameToId("Plugin1::CustomEngine"); + +const int64_t STATIC_ORDERING_POLICY_ID + = hipdnn_data_sdk::utilities::policyNameToId("SelectionHeuristic::StaticOrdering"); + +constexpr const char* FALLBACK_ORDERING_ENV = "HIPDNN_HEUR_FALLBACK_ENGINE_ORDER"; + +class TestStaticOrderingBuiltIn : public ::testing::Test +{ +protected: + void SetUp() override + { + _plugin = HeuristicPlugin::createBuiltIn(populateFunctionTable(), + "built-in:StaticOrdering-test"); + _handle = _plugin->createHandle(); + ASSERT_NE(_handle, nullptr); + _desc = _plugin->createPolicyDescriptor(_handle, STATIC_ORDERING_POLICY_ID); + ASSERT_NE(_desc, nullptr); + } + + void TearDown() override + { + if(_desc != nullptr) + { + _plugin->destroyPolicyDescriptor(_desc); + } + if(_handle != nullptr) + { + _plugin->destroyHandle(_handle); + } + } + + std::shared_ptr _plugin; + hipdnnHeuristicHandle_t _handle = nullptr; + hipdnnHeuristicPolicyDescriptor_t _desc = nullptr; +}; + +// Convenience: grab the raw function table once for direct C-ABI rejection tests. +// Named to avoid clash with the global `abi` namespace alias from . +const HeuristicPluginFunctionTable& staticOrderingAbi() +{ + static const HeuristicPluginFunctionTable s_funcs = populateFunctionTable(); + return s_funcs; +} + +} // namespace + +// ========== Built-in metadata exposed via the wrapper ========== + +TEST_F(TestStaticOrderingBuiltIn, ReportsHeuristicPluginType) +{ + EXPECT_EQ(_plugin->type(), HIPDNN_PLUGIN_TYPE_HEURISTIC); +} + +TEST_F(TestStaticOrderingBuiltIn, EnumeratesSingleStaticOrderingPolicy) +{ + const auto ids = _plugin->getAllPolicyIds(); + ASSERT_EQ(ids.size(), 1u); + EXPECT_EQ(ids[0], STATIC_ORDERING_POLICY_ID); + EXPECT_EQ(_plugin->getPolicyName(STATIC_ORDERING_POLICY_ID), + "SelectionHeuristic::StaticOrdering"); +} + +// ========== Policy Descriptor Lifecycle (BAD_PARAM via raw ABI) ========== + +TEST(TestStaticOrderingBuiltInRejection, DescriptorCreateRejectsNullHandle) +{ + hipdnnHeuristicPolicyDescriptor_t desc = nullptr; + EXPECT_EQ(staticOrderingAbi().policyDescriptorCreate(nullptr, STATIC_ORDERING_POLICY_ID, &desc), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(desc, nullptr); +} + +TEST_F(TestStaticOrderingBuiltIn, DescriptorCreateRejectsNullOutPointer) +{ + EXPECT_EQ( + staticOrderingAbi().policyDescriptorCreate(_handle, STATIC_ORDERING_POLICY_ID, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST(TestStaticOrderingBuiltInRejection, DescriptorDestroyRejectsNullDescriptor) +{ + EXPECT_EQ(staticOrderingAbi().policyDescriptorDestroy(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, DescriptorCreateRejectsUnknownPolicyId) +{ + // The built-in only owns SelectionHeuristic::StaticOrdering; any other ID + // (here: a different policy name's hash) must be rejected with BAD_PARAM + // and must not allocate an output descriptor. + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, STATIC_ORDERING_POLICY_ID); + + hipdnnHeuristicPolicyDescriptor_t desc = nullptr; + EXPECT_EQ(staticOrderingAbi().policyDescriptorCreate(_handle, unknownId, &desc), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(desc, nullptr); +} + +TEST_F(TestStaticOrderingBuiltIn, GetPolicyNameRejectsUnknownPolicyId) +{ + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, STATIC_ORDERING_POLICY_ID); + + const char* name = nullptr; + EXPECT_EQ(staticOrderingAbi().getPolicyName(unknownId, &name), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + EXPECT_EQ(name, nullptr); +} + +// ========== SetEngineIds ========== + +TEST_F(TestStaticOrderingBuiltIn, SetEngineIdsAcceptsValidIds) +{ + const std::array ids{MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID}; + EXPECT_NO_THROW(_plugin->setEngineIds(_desc, ids.data(), ids.size())); +} + +TEST_F(TestStaticOrderingBuiltIn, SetEngineIdsAcceptsZeroCountWithNullPointer) +{ + EXPECT_NO_THROW(_plugin->setEngineIds(_desc, nullptr, 0)); +} + +TEST(TestStaticOrderingBuiltInRejection, SetEngineIdsRejectsNullDescriptor) +{ + const std::array ids{MIOPEN_ENGINE_ID}; + EXPECT_EQ(staticOrderingAbi().policySetEngineIds(nullptr, ids.data(), ids.size()), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, SetEngineIdsRejectsNullPointerWithCount) +{ + EXPECT_EQ(staticOrderingAbi().policySetEngineIds(_desc, nullptr, 3), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +// ========== SetSerializedGraph ========== +// StaticOrdering ignores the graph contents but validates the parameter shape. + +TEST_F(TestStaticOrderingBuiltIn, SetSerializedGraphAcceptsAnyBuffer) +{ + const std::array buffer{0x01, 0x02, 0x03}; + const hipdnnPluginConstData_t data{buffer.data(), buffer.size()}; + EXPECT_NO_THROW(_plugin->setSerializedGraph(_desc, &data)); +} + +TEST(TestStaticOrderingBuiltInRejection, SetSerializedGraphRejectsNullDescriptor) +{ + const std::array buffer{0x00}; + const hipdnnPluginConstData_t data{buffer.data(), buffer.size()}; + EXPECT_EQ(staticOrderingAbi().policySetSerializedGraph(nullptr, &data), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, SetSerializedGraphRejectsNullBufferStruct) +{ + EXPECT_EQ(staticOrderingAbi().policySetSerializedGraph(_desc, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +// ========== Finalize ========== + +TEST(TestStaticOrderingBuiltInRejection, FinalizeRejectsNullDescriptor) +{ + // applied is non-const because policyFinalize takes int32_t* — clang-tidy + // would still flag a never-modified local, so suppress with NOLINT here. + int32_t applied = -1; // NOLINT(misc-const-correctness) + EXPECT_EQ(staticOrderingAbi().policyFinalize(nullptr, &applied), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, FinalizeRejectsNullOutApplied) +{ + EXPECT_EQ(staticOrderingAbi().policyFinalize(_desc, nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, FinalizeWithNoCandidatesReportsNotApplied) +{ + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestStaticOrderingBuiltIn, FinalizeSortsCandidatesAndReportsApplied) +{ + const std::array ids{MIOPEN_DETERMINISTIC_ID, CUSTOM_ENGINE_ID, MIOPEN_ENGINE_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + EXPECT_TRUE(_plugin->finalize(_desc)); + + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), ids.size()); + EXPECT_EQ(sorted.front(), MIOPEN_ENGINE_ID); + EXPECT_EQ(sorted.back(), MIOPEN_DETERMINISTIC_ID); +} + +TEST_F(TestStaticOrderingBuiltIn, FinalizeResetsByLastSetEngineIdsCall) +{ + const std::array ids{MIOPEN_ENGINE_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + _plugin->setEngineIds(_desc, nullptr, 0); + + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +// ========== GetSortedEngineIds ========== + +TEST(TestStaticOrderingBuiltInRejection, GetSortedRejectsNullDescriptor) +{ + size_t count = 0; // NOLINT(misc-const-correctness) — passed by-pointer, mutable required + EXPECT_EQ(staticOrderingAbi().policyGetSortedEngineIds(nullptr, nullptr, &count), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, GetSortedRejectsNullCountPointer) +{ + EXPECT_EQ(staticOrderingAbi().policyGetSortedEngineIds(_desc, nullptr, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); +} + +TEST_F(TestStaticOrderingBuiltIn, GetSortedReturnsNotInitializedBeforeFinalize) +{ + size_t count = 0; // NOLINT(misc-const-correctness) — passed by-pointer, mutable required + EXPECT_EQ(staticOrderingAbi().policyGetSortedEngineIds(_desc, nullptr, &count), + HIPDNN_PLUGIN_STATUS_NOT_INITIALIZED); +} + +TEST_F(TestStaticOrderingBuiltIn, GetSortedClipsToCallerProvidedCapacity) +{ + const std::array ids{MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + ASSERT_TRUE(_plugin->finalize(_desc)); + + std::array outBuf{0, 0}; + size_t count = outBuf.size(); + EXPECT_EQ(staticOrderingAbi().policyGetSortedEngineIds(_desc, outBuf.data(), &count), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(count, outBuf.size()); + EXPECT_EQ(outBuf[0], MIOPEN_ENGINE_ID); + EXPECT_EQ(outBuf[1], CUSTOM_ENGINE_ID); +} + +// ========== HIPDNN_HEUR_FALLBACK_ENGINE_ORDER ========== +// The env replaces sortEngineIds: only listed engines are eligible, in env order. + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvBlankFallsBackToDefaultOrdering) +{ + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(FALLBACK_ORDERING_ENV, + " "); + + const std::array ids{MIOPEN_DETERMINISTIC_ID, CUSTOM_ENGINE_ID, MIOPEN_ENGINE_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), ids.size()); + EXPECT_EQ(sorted.front(), MIOPEN_ENGINE_ID); + EXPECT_EQ(sorted.back(), MIOPEN_DETERMINISTIC_ID); +} + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvOverridesDefaultOrdering) +{ + // Env order is the inverse of the default sort: CUSTOM first, then + // DETERMINISTIC, then MIOPEN_ENGINE last. The default would put + // MIOPEN_ENGINE first and DETERMINISTIC last; if the env path were + // ignored we would see that ordering instead. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env( + FALLBACK_ORDERING_ENV, "Plugin1::CustomEngine,MIOPEN_ENGINE_DETERMINISTIC,MIOPEN_ENGINE"); + + const std::array ids{MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 3u); + EXPECT_EQ(sorted[0], CUSTOM_ENGINE_ID); + EXPECT_EQ(sorted[1], MIOPEN_DETERMINISTIC_ID); + EXPECT_EQ(sorted[2], MIOPEN_ENGINE_ID); +} + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvSubsetDropsUnlistedCandidates) +{ + // Env names only MIOPEN_ENGINE; CUSTOM and DETERMINISTIC must be dropped + // from the result — the operator opted them out. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(FALLBACK_ORDERING_ENV, + "MIOPEN_ENGINE"); + + const std::array ids{MIOPEN_DETERMINISTIC_ID, CUSTOM_ENGINE_ID, MIOPEN_ENGINE_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 1u); + EXPECT_EQ(sorted[0], MIOPEN_ENGINE_ID); +} + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvTrimsTokensAndSkipsBlanks) +{ + // Whitespace around tokens, plus an empty token between the commas, should + // not affect parsing — names are trimmed and blanks are skipped. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env( + FALLBACK_ORDERING_ENV, " MIOPEN_ENGINE_DETERMINISTIC ,, MIOPEN_ENGINE "); + + const std::array ids{MIOPEN_ENGINE_ID, CUSTOM_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 2u); + EXPECT_EQ(sorted[0], MIOPEN_DETERMINISTIC_ID); + EXPECT_EQ(sorted[1], MIOPEN_ENGINE_ID); +} + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvNoListedEngineInCandidatesDeclines) +{ + // Env names an engine that isn't a candidate. The policy declines so the + // outer policy loop can try the next plugin. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env(FALLBACK_ORDERING_ENV, + "Plugin1::CustomEngine"); + + const std::array ids{MIOPEN_ENGINE_ID, MIOPEN_DETERMINISTIC_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + EXPECT_FALSE(_plugin->finalize(_desc)); +} + +TEST_F(TestStaticOrderingBuiltIn, FallbackEnvUnknownNameSilentlySkipped) +{ + // A typo'd name hashes to its own (unmatchable) ID; it should be ignored + // and the recognized name should still apply. + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter env( + FALLBACK_ORDERING_ENV, "NotARealEngine,MIOPEN_ENGINE"); + + const std::array ids{MIOPEN_DETERMINISTIC_ID, MIOPEN_ENGINE_ID}; + _plugin->setEngineIds(_desc, ids.data(), ids.size()); + + ASSERT_TRUE(_plugin->finalize(_desc)); + const auto sorted = _plugin->getSortedEngineIds(_desc); + ASSERT_EQ(sorted.size(), 1u); + EXPECT_EQ(sorted[0], MIOPEN_ENGINE_ID); +} + +// ========== Logging callback / getLastErrorString ABI shape ========== + +namespace +{ +// Counter and severity capture for the logging-callback test. File-scope so a +// plain C function pointer can mutate them. +std::atomic gCallbackInvocations{0}; +std::atomic gCallbackLastSeverity{HIPDNN_SEV_INFO}; + +void testLoggingCallback(hipdnnSeverity_t severity, const char* /*message*/) +{ + gCallbackInvocations.fetch_add(1); + gCallbackLastSeverity.store(severity); +} +} // namespace + +TEST(TestStaticOrderingBuiltInLogging, LoggingCallbackReceivesErrorOnUnknownPolicyId) +{ + // Drive the STATIC_ORDERING_LOG macro body — getPolicyName(unknownId) logs + // at ERROR severity before returning BAD_PARAM. With a callback installed + // and log level SEV_ERROR we must observe at least one invocation tagged + // at HIPDNN_SEV_ERROR. + gCallbackInvocations.store(0); + gCallbackLastSeverity.store(HIPDNN_SEV_INFO); + + ASSERT_EQ(staticOrderingAbi().setLoggingCallback(&testLoggingCallback), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(staticOrderingAbi().setLogLevel(HIPDNN_SEV_ERROR), HIPDNN_PLUGIN_STATUS_SUCCESS); + + const int64_t unknownId = hipdnn_data_sdk::utilities::policyNameToId("Vendor::NotARealPolicy"); + ASSERT_NE(unknownId, STATIC_ORDERING_POLICY_ID); + + const char* name = nullptr; + EXPECT_EQ(staticOrderingAbi().getPolicyName(unknownId, &name), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + EXPECT_GE(gCallbackInvocations.load(), 1); + EXPECT_EQ(gCallbackLastSeverity.load(), HIPDNN_SEV_ERROR); + + // Reset globals so other tests in the binary do not see a dangling callback. + EXPECT_EQ(staticOrderingAbi().setLoggingCallback(nullptr), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(staticOrderingAbi().setLogLevel(HIPDNN_SEV_INFO), HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST(TestStaticOrderingBuiltInLogging, GetLastErrorStringHandlesNullOutPointer) +{ + // Pure ABI-shape branch coverage: getLastErrorString(nullptr) must return + // (void) without dereferencing the null pointer. + EXPECT_NO_FATAL_FAILURE(staticOrderingAbi().getLastErrorString(nullptr)); +} + +TEST(TestStaticOrderingBuiltInLogging, GetLastErrorStringWritesPlaceholder) +{ + const char* msg = nullptr; + staticOrderingAbi().getLastErrorString(&msg); + ASSERT_NE(msg, nullptr); + EXPECT_STRNE(msg, ""); +} diff --git a/projects/hipdnn/backend/tests/utilities/TestEngineOrdering.cpp b/projects/hipdnn/backend/tests/utilities/TestEngineOrdering.cpp index 935ec6984e3..b06daaab064 100644 --- a/projects/hipdnn/backend/tests/utilities/TestEngineOrdering.cpp +++ b/projects/hipdnn/backend/tests/utilities/TestEngineOrdering.cpp @@ -125,3 +125,31 @@ TEST(TestEngineOrdering, StableOrderPreservedForOthers) EXPECT_EQ(engineIds[4], other4); EXPECT_EQ(engineIds[5], MIOPEN_ENGINE_DETERMINISTIC_ID); } + +TEST(TestEngineOrdering, IsIdempotent) +{ + std::vector engineIds + = {MIOPEN_ENGINE_DETERMINISTIC_ID, HIPBLASLT_ENGINE_ID, MIOPEN_ENGINE_ID}; + sortEngineIds(engineIds); + const auto firstPass = engineIds; + + sortEngineIds(engineIds); + EXPECT_EQ(engineIds, firstPass); +} + +TEST(TestEngineOrdering, UnknownEngineIdsTreatedAsMiddlePriority) +{ + // Engine IDs that don't correspond to any known well-known name should + // sort into the middle bucket (between MIOPEN_ENGINE and + // MIOPEN_ENGINE_DETERMINISTIC) without crashing. + const auto unknown1 = static_cast(0x1234567890ABCDEF); + const auto unknown2 = static_cast(0xFEDCBA0987654321); + + std::vector engineIds = {unknown1, MIOPEN_ENGINE_ID, unknown2}; + + EXPECT_NO_THROW(sortEngineIds(engineIds)); + ASSERT_EQ(engineIds.size(), 3u); + EXPECT_EQ(engineIds[0], MIOPEN_ENGINE_ID); + EXPECT_EQ(engineIds[1], unknown1); + EXPECT_EQ(engineIds[2], unknown2); +} diff --git a/projects/hipdnn/cmake/TestPluginNames.cmake b/projects/hipdnn/cmake/TestPluginNames.cmake new file mode 100644 index 00000000000..796bb36945d --- /dev/null +++ b/projects/hipdnn/cmake/TestPluginNames.cmake @@ -0,0 +1,43 @@ +# Copyright © Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Shared definitions for test plugin target names. +# +# Several test executables (hipdnn_backend_tests, hipdnn_public_frontend_tests, …) +# reference test plugin targets that are added by tests/test_plugins/. CMake +# processes that subdirectory after backend/ and frontend/, so any test that +# uses the names without a shared definition would either silently expand them +# to empty strings (producing -DTEST_FOO=\"\" compile defs and no-op +# add_dependencies entries) or have to duplicate the literal name and rely on +# the two copies staying in sync. +# +# Defining the names here, before any subdirectory is processed, avoids both +# pitfalls. Both tests/test_plugins/ (which creates the targets) and the test +# executables (which depend on them) read from the same source of truth. + +# Engine / generic test plugins +set(TEST_PLUGIN1_NAME "hipdnn_test_plugin1") +set(TEST_PLUGIN2_NAME "hipdnn_test_plugin2") +set(TEST_NO_API_VERSION_PLUGIN_NAME "hipdnn_test_no_api_version_plugin_name") +set(TEST_ENGINE_PLUGIN1_NAME "hipdnn_test_engine_plugin1") + +set(TEST_GOOD_PLUGIN_NAME "test_good_plugin") +set(TEST_EXECUTE_FAILS_PLUGIN_NAME "test_execute_fails_plugin") +set(TEST_NO_APPLICABLE_ENGINES_A_PLUGIN_NAME "test_no_applicable_engines_a_plugin") +set(TEST_NO_APPLICABLE_ENGINES_B_PLUGIN_NAME "test_no_applicable_engines_b_plugin") +set(TEST_DUPLICATE_ID_A_PLUGIN_NAME "test_duplicate_id_a_plugin") +set(TEST_DUPLICATE_ID_B_PLUGIN_NAME "test_duplicate_id_b_plugin") +set(TEST_INCOMPLETE_API_PLUGIN_NAME "test_incomplete_api_plugin") +set(TEST_GOOD_DEFAULT_PLUGIN_NAME "test_good_default_plugin") +set(TEST_KNOBS_PLUGIN_NAME "test_knobs_plugin") +set(TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME "test_knob_constraint_validation_plugin") +set(TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME "test_incompatible_version_plugin") + +# Heuristic plugin test names +set(TEST_GOOD_HEURISTIC_PLUGIN_NAME "test_good_heuristic_plugin") +set(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME "test_incomplete_heuristic_api_plugin") +set(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME "test_no_optional_heuristic_plugin") +set(TEST_BAD_API_VERSION_HEURISTIC_PLUGIN_NAME "test_bad_api_version_heuristic_plugin") +set(TEST_EMPTY_NAME_HEURISTIC_PLUGIN_NAME "test_empty_name_heuristic_plugin") +set(TEST_DUPLICATE_POLICY_ID_A_PLUGIN_NAME "test_duplicate_policy_id_a_plugin") +set(TEST_DUPLICATE_POLICY_ID_B_PLUGIN_NAME "test_duplicate_policy_id_b_plugin") diff --git a/projects/hipdnn/data_sdk/CMakeLists.txt b/projects/hipdnn/data_sdk/CMakeLists.txt index 76827545fb6..af6338b666b 100644 --- a/projects/hipdnn/data_sdk/CMakeLists.txt +++ b/projects/hipdnn/data_sdk/CMakeLists.txt @@ -58,6 +58,8 @@ configure_package_config_file( INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/hipdnn_data_sdk PATH_VARS HIPDNN_PLUGIN_ENGINE_SUBDIR HIPDNN_FULL_INSTALL_PLUGIN_ENGINE_DIR HIPDNN_RELATIVE_INSTALL_PLUGIN_ENGINE_DIR + HIPDNN_PLUGIN_HEURISTIC_SUBDIR HIPDNN_FULL_INSTALL_PLUGIN_HEURISTIC_DIR + HIPDNN_RELATIVE_INSTALL_PLUGIN_HEURISTIC_DIR ) # Generate package version file for find_package version checking diff --git a/projects/hipdnn/data_sdk/cmake/hipdnn_data_sdkConfig.cmake.in b/projects/hipdnn/data_sdk/cmake/hipdnn_data_sdkConfig.cmake.in index 0af1e0aa0ec..0804edc0c4e 100644 --- a/projects/hipdnn/data_sdk/cmake/hipdnn_data_sdkConfig.cmake.in +++ b/projects/hipdnn/data_sdk/cmake/hipdnn_data_sdkConfig.cmake.in @@ -14,8 +14,18 @@ set(HIPDNN_PLUGIN_ENGINE_SUBDIR "@HIPDNN_PLUGIN_ENGINE_SUBDIR@") set(HIPDNN_FULL_INSTALL_PLUGIN_ENGINE_DIR "@PACKAGE_HIPDNN_FULL_INSTALL_PLUGIN_ENGINE_DIR@") set(HIPDNN_RELATIVE_INSTALL_PLUGIN_ENGINE_DIR "@HIPDNN_RELATIVE_INSTALL_PLUGIN_ENGINE_DIR@") +# Subdirectory path for hipDNN heuristic plugins (relative to lib) +set(HIPDNN_PLUGIN_HEURISTIC_SUBDIR "@HIPDNN_PLUGIN_HEURISTIC_SUBDIR@") + +# Full install path for heuristic plugins (relocatable using PACKAGE_ prefix from configure_package_config_file) +set(HIPDNN_FULL_INSTALL_PLUGIN_HEURISTIC_DIR "@PACKAGE_HIPDNN_FULL_INSTALL_PLUGIN_HEURISTIC_DIR@") +set(HIPDNN_RELATIVE_INSTALL_PLUGIN_HEURISTIC_DIR "@HIPDNN_RELATIVE_INSTALL_PLUGIN_HEURISTIC_DIR@") + include("${CMAKE_CURRENT_LIST_DIR}/hipdnn_data_sdkTargets.cmake") message(STATUS "hipDNN Data SDK: Engine plugin build directory is ${CMAKE_INSTALL_LIBDIR}/${HIPDNN_PLUGIN_ENGINE_SUBDIR}") message(STATUS "hipDNN Data SDK: Plugin absolute installation directory ${HIPDNN_FULL_INSTALL_PLUGIN_ENGINE_DIR}") message(STATUS "hipDNN Data SDK: Plugin relative installation directory ${HIPDNN_RELATIVE_INSTALL_PLUGIN_ENGINE_DIR}") +message(STATUS "hipDNN Data SDK: Heuristic plugin build directory is ${CMAKE_INSTALL_LIBDIR}/${HIPDNN_PLUGIN_HEURISTIC_SUBDIR}") +message(STATUS "hipDNN Data SDK: Heuristic plugin absolute installation directory ${HIPDNN_FULL_INSTALL_PLUGIN_HEURISTIC_DIR}") +message(STATUS "hipDNN Data SDK: Heuristic plugin relative installation directory ${HIPDNN_RELATIVE_INSTALL_PLUGIN_HEURISTIC_DIR}") diff --git a/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/EngineOrdering.hpp b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/EngineOrdering.hpp new file mode 100644 index 00000000000..33602e5c201 --- /dev/null +++ b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/EngineOrdering.hpp @@ -0,0 +1,65 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include + +namespace hipdnn_data_sdk::utilities +{ + +/** + * @brief Sorts engine IDs with MIOpen-specific ordering requirements. + * + * Ordering rationale: + * - MIOPEN_ENGINE first: Default engine with full operation support + * - Other engines middle: Stable order preserved for predictability + * - MIOPEN_ENGINE_DETERMINISTIC last: Limited to conv operations only, + * deprioritized due to performance trade-offs and reduced operation support + * + * This is a header-only implementation shared between backend and heuristic plugins. + * + * @param engineIds Vector of engine IDs to sort (modified in-place) + */ +inline void sortEngineIds(std::vector& engineIds) +{ + // Sort engine IDs: MIOPEN_ENGINE first, MIOPEN_ENGINE_DETERMINISTIC last, others in middle + // Using index-based sorting with std::sort to achieve stable sort behavior + + std::vector indices(engineIds.size()); + std::iota(indices.begin(), indices.end(), 0); + + auto getPriority = [](int64_t engineId) -> int { + if(engineId == hipdnn_data_sdk::utilities::MIOPEN_ENGINE_ID) + { + return 0; + } + if(engineId == hipdnn_data_sdk::utilities::MIOPEN_ENGINE_DETERMINISTIC_ID) + { + return 2; + } + return 1; // Other engines + }; + + std::sort(indices.begin(), indices.end(), [&](size_t i, size_t j) { + const int priA = getPriority(engineIds[i]); + const int priB = getPriority(engineIds[j]); + return (priA != priB) ? (priA < priB) : (i < j); + }); + + // Reorder engineIds based on sorted indices + std::vector sorted; + sorted.reserve(engineIds.size()); + for(const size_t idx : indices) + { + sorted.push_back(engineIds[idx]); + } + engineIds = std::move(sorted); +} + +} // namespace hipdnn_data_sdk::utilities diff --git a/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/PolicyNames.hpp b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/PolicyNames.hpp new file mode 100644 index 00000000000..85314c99a1e --- /dev/null +++ b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/PolicyNames.hpp @@ -0,0 +1,45 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace hipdnn_data_sdk::utilities +{ + +/** + * @brief Converts a heuristic policy name string to a deterministic int64_t ID + * + * Uses the FNV-1a hash. The same input always produces the same output, so the + * hash is stable across processes and platforms. Used to identify selection + * heuristic policies (e.g., "SelectionHeuristic::StaticOrdering") on the C ABI + * without shipping null-separated string blobs. + * + * @param policyName The policy name to convert to an ID + * @return int64_t The policy ID + */ +inline int64_t policyNameToId(const char* policyName) noexcept +{ + return static_cast(fnv1aHash(policyName)); +} + +/** + * @brief Overload for std::string + */ +inline int64_t policyNameToId(const std::string& policyName) +{ + return static_cast(fnv1aHash(policyName)); +} + +/** + * @brief Overload for std::string_view + */ +inline int64_t policyNameToId(std::string_view policyName) +{ + return static_cast(fnv1aHash(policyName)); +} + +} // namespace hipdnn_data_sdk::utilities diff --git a/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/StringUtil.hpp b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/StringUtil.hpp index aa4e26e4792..8fb61dd4427 100644 --- a/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/StringUtil.hpp +++ b/projects/hipdnn/data_sdk/include/hipdnn_data_sdk/utilities/StringUtil.hpp @@ -93,7 +93,7 @@ inline void copyMaxSizeWithNullTerminator(char* destination, const char* source, inline std::string toLower(const std::string& str) { - std::string lowerStr = str; + std::string lowerStr = str; // NOLINT(misc-const-correctness) std::transform(lowerStr.begin(), lowerStr.end(), lowerStr.begin(), ::tolower); return lowerStr; } @@ -111,7 +111,7 @@ inline std::string trim(const std::string& str) inline std::string removeNewlines(const std::string& str) { - std::string result = str; + std::string result = str; // NOLINT(misc-const-correctness) result.erase(std::remove(result.begin(), result.end(), '\r'), result.end()); result.erase(std::remove(result.begin(), result.end(), '\n'), result.end()); return result; diff --git a/projects/hipdnn/docs/Environment.md b/projects/hipdnn/docs/Environment.md index bca7bd8c715..6b426692e27 100644 --- a/projects/hipdnn/docs/Environment.md +++ b/projects/hipdnn/docs/Environment.md @@ -6,6 +6,7 @@ This document describes the environment variables and runtime configuration opti - [Environment Variables](#environment-variables) - [Plugin Discovery](#plugin-discovery) + - [Heuristic Policy Selection](#heuristic-policy-selection) - [Logging Variables](#logging-variables) - [MIOpen Plugin Logging](#miopen-plugin-logging) - [Test Configuration](#test-configuration) @@ -67,6 +68,62 @@ export HIPDNN_HEURISTIC_PLUGIN_DIR=/opt/rocm/lib/hipdnn/plugins/heuristics - Each heuristic plugin must provide a unique policy ID and policy name - See the [Plugin Development Guide](PluginDevelopment.md) for details on creating heuristic plugins +### Heuristic Policy Selection + +hipDNN's heuristic framework selects an engine for each graph by running a configurable list of selection policies (the *outer loop*). The following variables tune that loop and the behavior of two built-in policies. + +#### HIPDNN_HEUR_POLICY_ORDER + +Overrides the heuristic policy order for the outer loop. Read by every `EngineHeuristicDescriptor::finalize()` call. + +| Value | Description | +|------------|------------------------------------------------------------| +| (unset) | Use the descriptor's `HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT` attribute if set; otherwise fall back to the built-in default `[SelectionHeuristic::Config, SelectionHeuristic::StaticOrdering]`. | +| `` | Comma-separated tokens consulted in the order written. Each token is either a policy name (hashed via `policyNameToId`) or a raw decimal int64 policy ID. Whitespace around tokens is trimmed; empty tokens are skipped. | + +This variable has the **highest priority** — it overrides both the descriptor attribute and the built-in default. + +**Example:** +```bash +# By name +export HIPDNN_HEUR_POLICY_ORDER="SelectionHeuristic::Config,SelectionHeuristic::StaticOrdering" + +# By raw ID (or mixed names + IDs) +export HIPDNN_HEUR_POLICY_ORDER="-1234567890123456789,SelectionHeuristic::StaticOrdering" +``` + +#### HIPDNN_HEUR_CONFIG_PATH + +Path to a JSON rule file consumed by the `SelectionHeuristic::Config` built-in policy. The file maps convolution op + tensor-shape patterns to a preferred engine name; the policy walks conv-like nodes in the serialized graph and, on the first matching rule, reorders the candidate engines so the chosen one runs first. Re-read on every `Finalize` invocation — there is no process-wide cache. + +| Value | Description | +|------------|------------------------------------------------------------| +| (unset) | The `SelectionHeuristic::Config` policy declines, allowing subsequent policies to run. | +| `` | Absolute or working-directory-relative path to a JSON rule file. | + +If the file is missing, unreadable, fails to parse, no rule matches, or the matched engine name is not among the current candidates, the policy declines (so the outer loop continues with the next policy). + +**Example:** +```bash +export HIPDNN_HEUR_CONFIG_PATH=/etc/hipdnn/engine_overrides.json +``` + +#### HIPDNN_HEUR_FALLBACK_ENGINE_ORDER + +Replaces the built-in ordering used by `SelectionHeuristic::StaticOrdering`. When set, **only** engines named here are eligible — anything else is dropped from the candidate list. + +| Value | Description | +|------------|------------------------------------------------------------| +| (unset) | Use the built-in static ordering (MIOpen-first, deterministic engines last). | +| `` | Comma-separated engine names, applied in the order written. Whitespace is trimmed and empty tokens are skipped. | + +Engine names that are not among the current candidates are silently skipped. If no listed engine matches any candidate, the policy declines so the outer loop can try the next plugin. + +**Example:** +```bash +export HIPDNN_HEUR_FALLBACK_ENGINE_ORDER="MIOpenConvolutionFwdEngine,HipBLASLtMatmulEngine" +``` + ### Logging Variables hipDNN provides the following environment variables to control logging behavior: diff --git a/projects/hipdnn/docs/rfcs/0007_EngineSelectionHeuristicsFramework.md b/projects/hipdnn/docs/rfcs/0007_EngineSelectionHeuristicsFramework.md index 16243e0dfdf..4862708f510 100644 --- a/projects/hipdnn/docs/rfcs/0007_EngineSelectionHeuristicsFramework.md +++ b/projects/hipdnn/docs/rfcs/0007_EngineSelectionHeuristicsFramework.md @@ -37,7 +37,7 @@ Device capabilities enter the system **only as explicit data** (`DevicePropertie Heuristic plugins implement a **two-tier** **stable C ABI** ([§8](#8-c-abi-for-heuristic-plugins)): a long-lived **plugin handle** (**`hipdnnHeuristicHandle_t`**, **createHandle** / **destroyHandle** / **setDeviceProperties**) created with the same timing and storage pattern as other hipDNN plugin handles, and a **policy descriptor** (**`hipdnnHeuristicPolicyDescriptor_t`**, **createPolicyDescriptor** / **destroyPolicyDescriptor** / **setEngineIds** / **Finalize** / **getSortedIds**) whose **lifecycle is owned by** **`EngineHeuristicDescriptor`**—one per slot in the ordered policy list, destroyed with the heuristic descriptor. That ABI is **separate from the engine plugin ABI**—a `.so` is one or the other, never both. -The **`SelectionHeuristic`** C++ type (or equivalent facade) wraps a **policy descriptor** and forwards to the C ABI; **stateful tracking** in the plugin lives behind the **plugin handle**. **HeuristicPluginManager** (and handle-scoped **HeuristicPluginResourceManager**) own discovery, loading, version validation, registration, and **plugin-handle** instances per **`hipdnnHandle`**. **`EngineHeuristicDescriptor::finalize()`** walks the **ordered policy list** using **descriptor-owned** policy objects, which is **user-configurable** with default **policy name** strings **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** at the public/config surface; the backend resolves those names to **`int64_t` policy IDs** with the **same deterministic hash as engine IDs** ([RFC 0003](0003_EngineIdDesign.md), `hipdnn_data_sdk::utilities::engineNameToId`) before matching loaded plugins and running the outer loop ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). +The **`SelectionHeuristic`** C++ type (or equivalent facade) wraps a **policy descriptor** and forwards to the C ABI; **stateful tracking** in the plugin lives behind the **plugin handle**. **HeuristicPluginManager** (and handle-scoped **HeuristicPluginResourceManager**) own discovery, loading, version validation, registration, and **plugin-handle** instances per **`hipdnnHandle`**. **`Graph.preferred_engine_id`** is honored by the **frontend** as a post-hoc reorder of the heuristic-ranked engine configs returned by **`HIPDNN_ATTR_ENGINEHEUR_RESULTS`** (see **`Graph::initializeEngineConfig`**); the backend heuristic loop is unaware of it. The previous **`HIPDNN_ENGINE_OVERRIDE_FILE`** knob has been renamed to **`HIPDNN_HEUR_CONFIG_PATH`** and is now implemented as a regular built-in policy, **`SelectionHeuristic::Config`**, that lives in the ordered policy list and **declines** when no rule matches so subsequent policies (typically **`SelectionHeuristic::StaticOrdering`**) take over ([§5.3.5](#535-the-config-built-in-policy)). The default deterministic ordering policy (**`SelectionHeuristic::StaticOrdering`**) is shipped as a **backend built-in**: it implements the same heuristic plugin C ABI shape (function-pointer table) but is **registered in-process** at **`HeuristicPluginManager`** construction time, with no separate `.so` to discover or dlopen ([§10.1](#101-heuristicpluginmanager)). **`EngineHeuristicDescriptor::finalize()`** walks the **ordered policy list** using **descriptor-owned** policy objects; the resolved order is **user-configurable** at the public/config surface ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). The backend resolves user-supplied **policy name** strings to **`int64_t` policy IDs** with the **same deterministic hash as engine IDs** ([RFC 0003](0003_EngineIdDesign.md), `hipdnn_data_sdk::utilities::policyNameToId`) before matching loaded plugins (and built-ins) and running the outer loop. There is **no** separate post-loop step that applies **`utilities::sortEngineIds`** (or any other ordering) inside the backend. Legacy deterministic ordering is available **only** when a policy in **`orderedPolicyIds`** implements it—for example **`SelectionHeuristic::StaticOrdering`**. If every policy declines or the list is misconfigured so that no policy succeeds, **`finalize()`** fails using the **same error path as the rest of the hipDNN backend** (for example **`THROW_IF_FALSE`** / **`HipdnnException`** with an appropriate **`hipdnnStatus_t`** such as **`HIPDNN_STATUS_INTERNAL_ERROR`**, matching other descriptor **`finalize()`** failures when a required step does not complete—exact status TBD). @@ -60,7 +60,7 @@ There is **no** separate post-loop step that applies **`utilities::sortEngineIds 2. Sorts them with `utilities::sortEngineIds` inside `finalize()` (for example prioritizing MIOPEN engine IDs and deprioritizing deterministic variants). 3. Exposes the ordered list through `HIPDNN_ATTR_ENGINEHEUR_RESULTS` as engine configuration descriptors. -This RFC preserves **equivalent deterministic ordering** as one **policy** (the “StaticOrdering” selector), implemented as a plugin or thin wrapper, **only when that policy appears** in **`orderedPolicyIds`**. It is **not** applied as unconditional backend logic after the outer loop. +This RFC preserves **equivalent deterministic ordering** as one **policy** (the “StaticOrdering” selector), shipped as a **backend built-in** that implements the heuristic plugin C ABI shape (function-pointer table) without a separate `.so` ([§10.1](#101-heuristicpluginmanager), [§16.2](#162-loaded-heuristic-policy-enumeration)). It runs **only when that policy appears** in **`orderedPolicyIds`** and is **not** applied as unconditional backend logic after the outer loop. **`Graph.preferred_engine_id`** is now a **frontend** concern — **`Graph::initializeEngineConfig`** reorders the backend's heuristic-ranked engine configs after **`finalize()`**. The previous file-based override (formerly **`HIPDNN_ENGINE_OVERRIDE_FILE`**) is replaced by **`HIPDNN_HEUR_CONFIG_PATH`** and implemented as the built-in **`SelectionHeuristic::Config`** policy ([§5.3.5](#535-the-config-built-in-policy)) — it lives in the ordered policy list like any other selector and declines when no rule matches. --- @@ -95,9 +95,9 @@ There is **one** primary control flow: 1. Build the list of **candidate engine IDs** from existing **engine** plugins (unchanged). 2. Obtain **serialized device properties** and the **serialized operation graph** (see [§6](#6-device-properties) and [§13](#13-serialized-graph-device-properties-and-graph-level-preferences)). 3. Resolve the **ordered list of policy plugin IDs** (`int64_t`, from user **policy name** strings via the shared name→ID hash—[§5.3](#53-ordered-policy-list-default-and-user-configuration)). -4. Ensure **`EngineHeuristicDescriptor`** owns one **plugin policy descriptor** per slot (see [§5.4](#54-two-tier-plugin-objects-handle-vs-policy-descriptor)); each slot binds to the **`hipdnnHeuristicHandle_t`** for that policy’s loaded module (from **`HeuristicPluginResourceManager`**). The policy descriptor is **created with** that handle ([§8.7](#87-policy-descriptor-per-slot-graph--candidate-ids)); **Finalize** consumes device facts **from that handle’s session state**, not from separate per-call device arguments on the policy. +4. Ensure **`EngineHeuristicDescriptor`** owns one **plugin policy descriptor** per slot (see [§5.4](#54-two-tier-plugin-objects-handle-vs-policy-descriptor)); each slot binds to the **`hipdnnHeuristicHandle_t`** for that policy’s loaded module **or built-in** (from **`HeuristicPluginResourceManager`**). The policy descriptor is **created with** that handle ([§8.7](#87-policy-descriptor-per-slot-graph--candidate-ids)); **Finalize** consumes device facts **from that handle’s session state**, not from separate per-call device arguments on the policy. 5. After resolving **serialized device properties** (FlatBuffer bytes from **`DeviceProperties`**—[§6](#6-device-properties), [§13.2](#132-serialized-device-properties-flatbuffer)), call **`hipdnnHeuristicHandleSetDeviceProperties`** **once per distinct** **`hipdnnHeuristicHandle_t`** that appears among the policy slots (not once per slot iteration). That establishes the device context on the **handle**; heuristic code **queries** it as needed during **Finalize** (and any later selection work on that handle). -6. For each slot in order: **setEngineIds** and **serialized graph** on that slot’s policy descriptor, call **Finalize**; if the policy **wins**, read **getSortedIds** and **stop**; otherwise continue. +6. For each slot in order: **setEngineIds** (current candidates) and **serialized graph** on that slot’s policy descriptor, call **Finalize**; if the policy **wins**, read **getSortedIds** and **stop**; otherwise continue. The default order puts **`SelectionHeuristic::Config`** ([§5.3.5](#535-the-config-built-in-policy)) ahead of **`SelectionHeuristic::StaticOrdering`** so user-supplied **`HIPDNN_HEUR_CONFIG_PATH`** rules win when matched and the deterministic fallback runs otherwise. 7. If no policy succeeds, **`finalize()`** **fails** using the same pattern as other backend descriptor logic errors (for example **`THROW_IF_FALSE(success, HIPDNN_STATUS_INTERNAL_ERROR, …)`** throwing **`HipdnnException`**, analogous to **`GraphDescriptor::finalize()`** when verification fails—exact message and status code are implementation details). There is **no** additional built-in sort after the loop. There is **no** separate inner registry of sub-stages inside a single policy plugin in this design. If a team wants “config then static ordering” inside one deliverable, they ship **one** policy plugin that implements that sequence internally, or they register two entries in the **outer** list. @@ -111,54 +111,78 @@ There is **no** separate inner registry of sub-stages inside a single policy plu | **Rule selector** | Deterministic rules over graph features and `DeviceProperties`. | | **StaticOrdering-style selector** | Reproduce current `sortEngineIds` behavior when listed in **`orderedPolicyIds`**; users often place it **last** in the default list so behavior matches today unless they omit it. | -These are **examples**; each plugin chooses a **canonical UTF-8 policy name** (for documentation and user config). The **stable identifier** across the loader, registry, and C ABI is the **`int64_t`** produced by the **same `engineNameToId` hash** as computational engines, returned from **`hipdnnHeuristicGetPolicyId`** ([§8](#8-c-abi-for-heuristic-plugins)). +These are **examples**; each plugin chooses a **canonical UTF-8 policy name** (for documentation and user config). The **stable identifier** across the loader, registry, and C ABI is the **`int64_t`** produced by the **same `policyNameToId` hash** as computational engines, returned from **`hipdnnHeuristicGetPolicyId`** ([§8](#8-c-abi-for-heuristic-plugins)). ### 5.3 Ordered policy list: default and user configuration Policy order has **two representations**: - **User / public surface:** an ordered list of **UTF-8 policy name** strings (attributes, env vars, optional frontend helpers)—the same **human-readable names** vendors document for their plugins (for example **`SelectionHeuristic::StaticOrdering`**). -- **Internal / loader / C ABI:** each name maps to a deterministic **`int64_t`** using **`hipdnn_data_sdk::utilities::engineNameToId`** (FNV-1a, same as computational **engine** IDs—[RFC 0003](0003_EngineIdDesign.md), `EngineNames.hpp`). The outer loop operates on **`std::vector orderedPolicyIds`**; each element must match **`hipdnnHeuristicGetPolicyId`** for a loaded heuristic plugin (or a **built-in** policy registered under the same **`int64_t`** without a separate `.so`). +- **Internal / loader / C ABI:** each name maps to a deterministic **`int64_t`** using **`hipdnn_data_sdk::utilities::policyNameToId`** (`PolicyNames.hpp`). The hash is FNV-1a, the same hash family used for computational **engine** IDs ([RFC 0003](0003_EngineIdDesign.md), `EngineNames.hpp::engineNameToId`); the two functions are siblings rather than aliases — they share the algorithm but are kept as separate symbols so policy and engine name spaces stay textually distinct in call sites. The outer loop operates on **`std::vector orderedPolicyIds`**; each element must match **`hipdnnHeuristicGetPolicyId`** for a loaded heuristic plugin (or a **built-in** policy registered under the same **`int64_t`** without a separate `.so`). **Collision note:** Policy IDs and engine IDs are both `int64_t` hashes of **different string namespaces** in normal usage (policy names like `SelectionHeuristic::…` vs engine names like `MIOPEN_PLUGIN`). A numeric collision is theoretically possible; **context** (policy registry vs engine registry) keeps them separate. #### 5.3.1 Well-known policy names and IDs -This draft standardizes two logical policies using **well-known UTF-8 names** (the `SelectionHeuristic::` prefix is a **naming convention** in the string itself, not C++ language linkage or a shared type with the C++ `SelectionHeuristic` class). Their **canonical `int64_t` IDs** are **`engineNameToId(name)`** for each name below (plugins **must** return that exact value from **`hipdnnHeuristicGetPolicyId`**): +This draft standardizes one logical policy using a **well-known UTF-8 name** (the `SelectionHeuristic::` prefix is a **naming convention** in the string itself, not C++ language linkage or a shared type with the C++ `SelectionHeuristic` class). Its **canonical `int64_t` ID** is **`policyNameToId(name)`** (a plugin or built-in implementing this policy **must** return that exact value from **`hipdnnHeuristicGetPolicyId`**): | Policy name string | Role | |--------------------|------| -| **`SelectionHeuristic::Config`** | Applies user / graph configuration (for example honoring **`preferred_engine_id`**, env-based disables, future descriptor knobs). Typically runs **first** so explicit intent overrides later policies. | -| **`SelectionHeuristic::StaticOrdering`** | Deterministic ordering aligned with today’s **`utilities::sortEngineIds`** (for example MIOPEN preference, deterministic engine last). Typically runs **after** Config in the **default** list; it is **not** invoked by the backend outside the resolved policy order. | +| **`SelectionHeuristic::Config`** | Reads the JSON file pointed to by **`HIPDNN_HEUR_CONFIG_PATH`**, walks the graph’s convolution nodes against the rule set, and on a match returns the candidate list with the matched engine moved to the front. **Declines** when the env var is unset, the file is missing, the graph cannot be parsed, no rule matches, or the matched engine is not in the candidate set — so the policy loop falls through to the next slot. Shipped as a **backend built-in** ([§10.1](#101-heuristicpluginmanager), [§5.3.5](#535-the-config-built-in-policy)); no separate `.so`. | +| **`SelectionHeuristic::StaticOrdering`** | Deterministic ordering aligned with today’s **`utilities::sortEngineIds`** (for example MIOPEN preference, deterministic engine last). Shipped as a **backend built-in** ([§10.1](#101-heuristicpluginmanager)); registered through the heuristic plugin C ABI shape but with no separate `.so`. Invoked only when listed in the resolved policy order. | -**Mandatory registration.** A shipped heuristic plugin (or built-in adapter) that implements **`SelectionHeuristic::Config`** or **`SelectionHeuristic::StaticOrdering`** **must** expose **`hipdnnHeuristicGetPolicyId`** returning exactly **`engineNameToId("SelectionHeuristic::Config")`** or **`engineNameToId("SelectionHeuristic::StaticOrdering")`** respectively—there is **no** alternate ID for those behaviors. Optional **`hipdnnHeuristicGetPolicyName`** ([§8.2](#82-plugin-module-metadata)) **should** report the same canonical UTF-8 string; when present, the host **validates at load time** that **`engineNameToId(hipdnnHeuristicGetPolicyName()) == hipdnnHeuristicGetPolicyId()`** and **rejects** the module on mismatch (same idea as keeping engine IDs and registered engine names consistent). +**`Graph.preferred_engine_id`** is **not** a policy. The frontend reorders the heuristic-ranked engine configs in **`Graph::initializeEngineConfig`** after **`EngineHeuristicDescriptor::finalize()`** has returned, so it is honored independently of the policy list. -**Overlapping IDs at load.** **`HeuristicPluginManager`** **must** reject a heuristic module whose **`hipdnnHeuristicGetPolicyId`** duplicates an **`int64_t`** already claimed by another **accepted** heuristic plugin, using the **same** duplicate-ID tracking as **`EnginePluginManager::validateBeforeAdding`** does for computational engine IDs ([§10.1](#101-heuristicpluginmanager), [§11](#11-versioning-and-compatibility-checks)). +**Mandatory registration.** A shipped heuristic plugin (or built-in adapter) that implements **`SelectionHeuristic::StaticOrdering`** **must** expose **`hipdnnHeuristicGetPolicyId`** returning exactly **`policyNameToId("SelectionHeuristic::StaticOrdering")`** — there is **no** alternate ID for that behavior. Optional **`hipdnnHeuristicGetPolicyName`** ([§8.2](#82-plugin-module-metadata)) **should** report the same canonical UTF-8 string; when present, the host **validates at load time** that **`policyNameToId(hipdnnHeuristicGetPolicyName()) == hipdnnHeuristicGetPolicyId()`** and **rejects** the module on mismatch (same idea as keeping engine IDs and registered engine names consistent). + +**Overlapping IDs at load.** **`HeuristicPluginManager`** **must** reject a heuristic module whose **`hipdnnHeuristicGetPolicyId`** duplicates an **`int64_t`** already claimed by another **accepted** heuristic plugin or built-in, using the **same** duplicate-ID tracking as **`EnginePluginManager::validateBeforeAdding`** does for computational engine IDs ([§10.1](#101-heuristicpluginmanager), [§11](#11-versioning-and-compatibility-checks)). The built-in **`StaticOrdering`** is registered first, so a third-party module trying to claim the same ID will be rejected. #### 5.3.2 Default policy order (user strings) -If the user does **not** override the list, the **configured** names are: +If the user does **not** override the list, the **configured** fallback name list is: ```text { "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" } ``` -The backend **hashes** each entry to build the internal **`orderedPolicyIds`** used in **`finalize()`**. So **Config** is tried first; if it does not win the outer loop (not applicable or the policy **`finalize()`** does not report success), **StaticOrdering** runs. **StaticOrdering** is expected to succeed for valid candidate sets when implemented correctly; if **no** listed policy succeeds (for example the user omits **StaticOrdering** and other policies all decline), **`EngineHeuristicDescriptor::finalize()`** fails per [§5.1](#51-single-orchestration-model-outer-loop)—there is **no** backend fallback sort. +That list is also the **effective default order** seen by **`finalize()`** — the backend does **not** prepend or append anything. **`Config`** runs first so user-supplied **`HIPDNN_HEUR_CONFIG_PATH`** rules win when matched; otherwise it declines and **`StaticOrdering`** takes over. **`Graph.preferred_engine_id`** is **not** part of this list — the frontend handles it after the heuristic returns ([§5.3.5](#535-the-config-built-in-policy)). + +The backend **hashes** each entry to build the internal **`orderedPolicyIds`** used in **`finalize()`**. **StaticOrdering** is expected to succeed for valid candidate sets when implemented correctly; if **no** listed policy succeeds (for example the user supplies a policy order that omits **StaticOrdering** and every listed policy declines), **`EngineHeuristicDescriptor::finalize()`** fails per [§5.1](#51-single-orchestration-model-outer-loop) — there is **no** backend fallback sort. #### 5.3.3 How the user sets orderedPolicyIds -Users configure **policy name** strings (UTF-8); the backend hashes them to **`int64_t`** for **`orderedPolicyIds`**. **Resolution order (highest precedence first):** +The user surface is mixed: the env-var path accepts UTF-8 **policy name** strings (and raw IDs as an escape hatch); the backend attribute path takes pre-hashed **`int64_t`** IDs directly. **Resolution order (highest precedence first):** -1. **Engine-heuristic descriptor (per finalize)** — User sets an explicit ordered list via a backend attribute (proposal: **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`** or extension equivalent: array of **UTF-8 policy name** strings). Applies only to that **`EngineHeuristicDescriptor`** instance. -2. **hipDNN handle (per handle)** — Optional API on **`HeuristicPluginResourceManager`** or a handle extension (proposal: **`hipdnnSetHeuristicPolicyOrder_ext(handle, names, count)`** or C++ **`setDefaultHeuristicPolicyOrder`** taking **strings**). Used when the descriptor has no override. -3. **Process environment (optional)** — Proposal: **`HIPDNN_HEURISTIC_POLICY_ORDER`** as a comma-separated list of **policy names**, applied when neither descriptor nor handle supplied a list. Exact syntax and escaping TBD. -4. **Built-in default** — **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** when nothing above applies. +1. **Process environment** — **`HIPDNN_HEUR_POLICY_ORDER`** as a comma-separated list. Each token is either a policy name (UTF-8) or a signed decimal **`int64_t`** policy ID. A token is treated as an ID only when `std::strtoll` consumes the entire trimmed token; anything else (including names that happen to start with digits) is hashed through **`policyNameToId`**. Highest precedence so operators can override application-supplied lists at deploy time. +2. **Engine-heuristic descriptor (per finalize)** — User sets an explicit ordered list via the **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** backend attribute, typed **`HIPDNN_TYPE_INT64`** (array of pre-hashed **`int64_t`** policy IDs). The C ABI takes IDs rather than name strings because a `void* arrayOfElements` + `int64_t elementCount` surface does not cleanly carry variable-length string arrays; callers (typically the frontend) are expected to hash names through **`policyNameToId`** before calling `setAttribute`. Applies only to that **`EngineHeuristicDescriptor`** instance. +3. **Built-in fallback** — **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** when neither env nor descriptor supplied a list. The names are hashed in the backend via **`policyNameToId`** at the call site. **`Graph.preferred_engine_id`** is honored by the frontend after the heuristic returns and is independent of this list ([§5.3.5](#535-the-config-built-in-policy)). -After merging, the implementation **hashes** each name with **`engineNameToId`** and **validates** each resulting **`int64_t`** against loaded/built-in policies (and logs or skips unknown IDs per policy TBD). +A **handle-level** default (e.g. **`hipdnnSetHeuristicPolicyOrder_ext(handle, ...)`**) is **not implemented** in the current backend. It is retained as a candidate extension in [§11](#11-options-and-tradeoffs) but is not part of the resolution order today. + +Unknown IDs that survive resolution simply produce a null `_policySlots` entry inside the per-finalize sync (the policy is treated as "not loaded") and the outer loop skips them without raising an error. #### 5.3.4 Relationship to `finalize()` -`EngineHeuristicDescriptor::finalize()` (or a helper on **`HeuristicPluginResourceManager`**) **merges** the resolution order above into an ordered list of **policy name** strings, **maps** them to **`std::vector orderedPolicyIds`**, ensures **policy descriptor** objects exist **one-to-one** with that list (recreating them if the list changed since last setup), then runs the outer loop in [§14.2](#142-pseudocode-for-finalize-first-draft). If the loop completes without a successful policy, **`finalize()`** does not mark the descriptor finalized and exits via **`HipdnnException`** like other backend failures. +`EngineHeuristicDescriptor::finalize()` resolves an ordered **`std::vector orderedPolicyIds`** from the precedence chain above (env tokens are name-or-ID hashed per token; the descriptor attribute is already a raw ID array; the default fallback is hashed from its built-in name list via **`policyNameToId`**). It then ensures **policy descriptor** objects exist **one-to-one** with that list (recreating them if the list changed since last setup) and runs the outer loop in [§14.2](#142-pseudocode-for-finalize-first-draft). If the loop completes without a successful policy, **`finalize()`** does not mark the descriptor finalized and exits via **`HipdnnException`** like other backend failures. + +#### 5.3.5 The Config built-in policy + +**`SelectionHeuristic::Config`** (in `backend/src/heuristics/config/`) is a regular built-in policy that wraps the JSON-driven engine-override behavior previously exposed through **`HIPDNN_ENGINE_OVERRIDE_FILE`**. The env var has been **renamed to `HIPDNN_HEUR_CONFIG_PATH`** and the resolver is no longer a precursor — it sits in **`orderedPolicyIds`** like any other policy and **declines** when nothing applies, letting the policy loop fall through to the next slot (typically **`SelectionHeuristic::StaticOrdering`**). + +**Behavior on each `Finalize`:** + +1. If the candidate list is empty, decline. +2. Read **`HIPDNN_HEUR_CONFIG_PATH`** from the environment on every invocation (no process cache). If the variable is unset or the file cannot be opened/parsed, decline. +3. Parse the serialized graph buffer with the data-SDK FlatBuffer verifier. If the buffer is malformed or the graph has no nodes, decline. +4. Walk the graph’s **`ConvolutionFwd`** / **`ConvolutionBwdData`** / **`ConvolutionBwdFilter`** nodes against the rule’s op + tensor-shape patterns (exact dim/stride matching with `-1` wildcards, exact-bucket-before-wildcard-bucket within declaration order). On the first match, take the matched engine ID. +5. If the matched engine ID is **not** in the candidate list, decline. +6. Otherwise return the candidates with the matched engine moved to the front (other candidates preserve their original order) and report success (`*outApplied = 1`). + +**Default placement.** The default policy list is **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** ([§5.3.2](#532-default-policy-order-user-strings)) so user-supplied JSON rules win over the deterministic fallback. Users may freely omit, reorder, or replace either built-in. There is no implicit guarantee that **`Config`** runs — if the user supplies a custom policy order without it, the JSON file is simply not consulted for that **`finalize()`**. + +**Graph-level `preferred_engine_id` is no longer in scope of this policy.** **`Graph::set_preferred_engine_id_ext`** is honored by the **frontend** as a post-hoc reorder of the heuristic-ranked engine configs in **`Graph::initializeEngineConfig`**, after **`EngineHeuristicDescriptor::finalize()`** has returned. The backend heuristic loop is unaware of it, and the **`Config`** built-in does **not** consult it. + +**`StaticOrdering`** retains no override guarantee — it is a normal default-ordered fallback ([§5.3.2](#532-default-policy-order-user-strings)) and may be omitted or reordered by the user. Both built-ins are shipped through **`HeuristicPlugin::createBuiltIn(populateFunctionTable(), label)`** at **`HeuristicPluginManager`** construction time ([§10.1](#101-heuristicpluginmanager)); neither has a separate `.so` and the user cannot unload them. ### 5.4 Two-tier plugin objects: handle vs policy descriptor @@ -287,8 +311,8 @@ Each heuristic `.so` exports the following (names are illustrative; implementati | Function | Purpose | |----------|---------| | `hipdnnHeuristicGetApiVersion(const char** version)` | Semantic version of **this C ABI** (for example `"1.0.0"`). Host rejects load on **major** mismatch. | -| `hipdnnHeuristicGetPolicyId(int64_t* policy_id)` | Stable **`int64_t`** policy identifier: **must** equal **`engineNameToId(canonical_utf8_name)`** for the plugin’s documented policy name (same hash as computational engine IDs—[§5.3](#53-ordered-policy-list-default-and-user-configuration), [RFC 0003](0003_EngineIdDesign.md)). The host matches this value against the resolved **`orderedPolicyIds`** after hashing user-supplied name strings. | -| `hipdnnHeuristicGetPolicyName(const char** policy_name)` | Optional. **NUL-terminated UTF-8** canonical name (same string the vendor tells users to put in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`**). For logging and enumeration; **not** required for matching if **`GetPolicyId`** is implemented correctly. When exported, the host **validates at load** that **`engineNameToId(*policy_name) ==`** the value from **`GetPolicyId`** ([§5.3.1](#531-well-known-policy-names-and-ids), [§11](#11-versioning-and-compatibility-checks)). Omit from minimal plugins if the host derives display names from a static registry. | +| `hipdnnHeuristicGetPolicyId(int64_t* policy_id)` | Stable **`int64_t`** policy identifier: **must** equal **`policyNameToId(canonical_utf8_name)`** for the plugin’s documented policy name (same hash as computational engine IDs—[§5.3](#53-ordered-policy-list-default-and-user-configuration), [RFC 0003](0003_EngineIdDesign.md)). The host matches this value against the resolved **`orderedPolicyIds`** after hashing user-supplied name strings. | +| `hipdnnHeuristicGetPolicyName(const char** policy_name)` | Optional. **NUL-terminated UTF-8** canonical name (same string the vendor tells users to put in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`**). For logging and enumeration; **not** required for matching if **`GetPolicyId`** is implemented correctly. When exported, the host **validates at load** that **`policyNameToId(*policy_name) ==`** the value from **`GetPolicyId`** ([§5.3.1](#531-well-known-policy-names-and-ids), [§11](#11-versioning-and-compatibility-checks)). Omit from minimal plugins if the host derives display names from a static registry. | | `hipdnnHeuristicGetPluginVersion(const char** version)` | Plugin implementation version (informational). | | `hipdnnHeuristicSetLoggingCallback(hipdnnCallback_t cb)` | Registers the consumer logging callback; optional `hipdnnHeuristicSetLogLevel(hipdnnSeverity_t)` mirroring engine plugin behavior. | | `hipdnnHeuristicGetLastErrorString(const char** msg)` | Per-thread last error after a failed call; pointer valid only for immediate use (same contract as `hipdnnPluginGetLastErrorString`). | @@ -457,8 +481,10 @@ Analogous to **`EnginePluginManager`**: - Extends the same **plugin manager base** pattern (shared library load, symbol resolution, lifecycle). - Uses a **separate search path** from engine plugins (for example `hipdnn_plugins/heuristics/` and/or a dedicated env var such as `HIPDNN_HEURISTIC_PLUGIN_DIR`—exact names TBD). - Resolves **heuristic-only** symbols from [§8](#8-c-abi-for-heuristic-plugins); does **not** use **`EnginePlugin`** symbol tables. -- Implements **`validateBeforeAdding`**-style checks before accepting a plugin—**parallel to `EnginePluginManager`**: **API major** match (via **`hipdnnHeuristicGetApiVersion`**); **unique** **`int64_t` policy ID** (via **`hipdnnHeuristicGetPolicyId`**, tracked in a set so a second module cannot register the same ID—same pattern as duplicate engine IDs in `EnginePluginManager`); optional **ID ↔ name** consistency when **`hipdnnHeuristicGetPolicyName`** is exported; and any additional rules from [§11](#11-versioning-and-compatibility-checks). +- Implements **`validateBeforeAdding`**-style checks before accepting a plugin—**parallel to `EnginePluginManager`**: **API major** match (via **`hipdnnHeuristicGetApiVersion`**); **unique** **`int64_t` policy ID** (via **`hipdnnHeuristicGetPolicyId`**, tracked in a set so a second module cannot register the same ID—same pattern as duplicate engine IDs in `EnginePluginManager`); optional **ID ↔ name** consistency when **`hipdnnHeuristicGetPolicyName`** is exported; and any additional rules from [§11](#11-versioning-and-compatibility-checks). One plugin may export **multiple** policies; each policy ID is checked individually for global uniqueness. - Owns **`HeuristicPlugin`** wrappers that bind the C ABI (handle + policy symbols) and expose **`HandleCreate` / `HandleDestroy`** to the resource manager. +- **Registers backend built-ins at construction.** Before any external `.so` discovery, the manager calls **`registerBuiltIns()`** (in `backend/src/heuristics/BuiltInHeuristics.cpp`) to register every shipped built-in heuristic. Today there are two — **`SelectionHeuristic::Config`** (in `backend/src/heuristics/config/`, driven by **`HIPDNN_HEUR_CONFIG_PATH`** — see [§5.3.5](#535-the-config-built-in-policy)) and **`SelectionHeuristic::StaticOrdering`** (in `backend/src/heuristics/static_ordering/`). A built-in implements the **same heuristic plugin C ABI shape** (a populated **`HeuristicPluginFunctionTable`**) as a third-party `.so` and is wrapped via **`HeuristicPlugin::createBuiltIn(populateFunctionTable(), source_label)`**, so the rest of the manager (lookup by **`int64_t` policy ID**, version validation, duplicate-ID rejection) treats it identically to a loaded module. There is **no** separate **`BaseHeuristicsPlugin`** library to ship and **no** override of **`loadPlugins`** to inject default paths. Built-ins always exist regardless of the caller's **`hipdnnPluginLoadingMode_ext_t`**, including **`HIPDNN_PLUGIN_LOADING_ABSOLUTE`**. +- **`Graph.preferred_engine_id` is not a built-in concern.** It is honored by the **frontend** as a post-hoc reorder of the heuristic-ranked engine configs (see **`Graph::initializeEngineConfig`**), independent of any backend policy or built-in. ### 10.2 `HeuristicPluginResourceManager` @@ -494,7 +520,7 @@ Follow the same **spirit** as `EnginePluginManager::validateBeforeAdding` in the 1. **Heuristic C ABI major:** Parse **`hipdnnHeuristicGetApiVersion`**; **major** must match the backend’s expected heuristic API major (analogous to engine plugins comparing `plugin.apiVersion()` major to `HIPDNN_BACKEND_VERSION_MAJOR`, but using the **heuristic** version string, not the engine plugin API version). 2. **Policy ID uniqueness:** Two loaded heuristic modules **must not** return the same **`int64_t`** from **`hipdnnHeuristicGetPolicyId`**. Enforce with the **same** “insert into a set, throw on duplicate” pattern as **`EnginePluginManager::validateBeforeAdding`** / **`actionAfterAdding`** for engine IDs ([§5.3.1](#531-well-known-policy-names-and-ids)). -3. **Policy ID ↔ optional policy name:** If **`hipdnnHeuristicGetPolicyName`** is provided, **`engineNameToId`** of the returned UTF-8 string **must** equal **`hipdnnHeuristicGetPolicyId`**; otherwise the loader **rejects** the module (catches mistaken or overlapping well-known implementations early—[§5.3.1](#531-well-known-policy-names-and-ids)). +3. **Policy ID ↔ optional policy name:** If **`hipdnnHeuristicGetPolicyName`** is provided, **`policyNameToId`** of the returned UTF-8 string **must** equal **`hipdnnHeuristicGetPolicyId`**; otherwise the loader **rejects** the module (catches mistaken or overlapping well-known implementations early—[§5.3.1](#531-well-known-policy-names-and-ids)). 4. **Binary compatibility:** Document minimum backend / data-SDK versions per heuristic plugin release (align with project-wide versioning RFCs under `docs/rfcs/`), including expectations for **graph** and **device-properties** FlatBuffer schemas ([§13](#13-serialized-graph-device-properties-and-graph-level-preferences)). **On failure:** Do not register the plugin; log via the shared logging path ([§12](#12-logging)); continue loading other policies if policy loading is best-effort, or fail handle creation if strict mode is required (policy TBD). @@ -545,7 +571,9 @@ Policies that need structured access may parse the FlatBuffer using existing dat ### 13.3 Graph-level preferences (for example `preferred_engine_id`) -The graph model already carries fields such as **`preferred_engine_id`** when built from operation descriptors. **This draft assigns responsibility as follows:** interpreting graph-level preferences (honor, override, validate) is the job of a **concrete policy**—typically a **rule-based** or **config**-style selector—not a separate hard-coded pass in `EngineHeuristicDescriptor`. The outer loop may place that policy early in the ordered list so user intent affects ordering before ML or cache policies. +The graph model already carries fields such as **`preferred_engine_id`** when built from operation descriptors. **Responsibility for honoring `preferred_engine_id` belongs to the frontend**, not to any backend policy or plugin. After **`EngineHeuristicDescriptor::finalize()`** returns the heuristic-ranked engine configs, **`Graph::initializeEngineConfig`** in **hipdnn_frontend** moves the preferred config (if its global engine index matches) to the front of the candidate list. Because the reorder happens outside the backend, the user's choice of **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** does not affect whether **`preferred_engine_id`** is honored. + +The file-driven override formerly carried by **`HIPDNN_ENGINE_OVERRIDE_FILE`** has been **renamed to `HIPDNN_HEUR_CONFIG_PATH`** and is implemented as a regular built-in policy — **`SelectionHeuristic::Config`** ([§5.3.5](#535-the-config-built-in-policy)) — that lives in the policy list and operates on the serialized graph buffer like any other policy. --- @@ -582,15 +610,21 @@ finalize(): candidates = engineRm.getApplicableEngineIds(graph) + serializedGraph = graph.getSerializedGraph() // hipdnnPluginConstData_t; graph must be usable for heuristics + devProps = userDeviceOverride if set else queryDeviceProperties() devicePropsSerialized = serializeDevicePropertiesFlatBuffer(devProps) // §13.2; hipdnnPluginConstData_t; host-owned for this finalize() - serializedGraph = graph.getSerializedGraph() // hipdnnPluginConstData_t; graph must be usable for heuristics - - orderedPolicyIds = resolveHeuristicPolicyOrder(thisDescriptor, handle) - // §5.3: descriptor attr > handle > env > default — user surface is UTF-8 policy *names* - // default names { "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" } if no override - // implementation: merge names, then orderedPolicyIds[i] = engineNameToId(name[i]) + orderedPolicyIds = resolveHeuristicPolicyOrder(thisDescriptor) + // §5.3.3 precedence: HIPDNN_HEUR_POLICY_ORDER env > HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT > built-in fallback + // fallback names { "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" } if no override + // implementation: + // 1. pick the highest-precedence non-empty source + // 2. env: each token is name OR signed-decimal int64 ID (name → policyNameToId, ID → as-is) + // descriptor: array is already int64 IDs (caller hashed names via policyNameToId) + // fallback: hash the built-in names via policyNameToId at the call site + // No backend-injected entries — duplicates from the user-supplied list are preserved as-is. + // Handle-level default (§11) is not implemented. syncPolicySlots(thisDescriptor, orderedPolicyIds, heurRm) // Ensure one SelectionHeuristic (hipdnnHeuristicPolicyDescriptor_t) per slot, each created with @@ -630,12 +664,12 @@ finalize(): ## 15. Public API notes -- **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER` (proposal):** Ordered list of **UTF-8 policy name** strings for **`EngineHeuristicDescriptor`**, overriding handle/env defaults ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). Attribute type: array of strings (or equivalent) consistent with other vector attributes in the backend. The backend **hashes** each name with **`engineNameToId`** when resolving **`finalize()`** order. When this list changes, the backend **recreates** the owned **policy descriptor** objects ([§5.4](#54-two-tier-plugin-objects-handle-vs-policy-descriptor)). -- **Handle-level override (proposal):** Extension API or **`HeuristicPluginResourceManager`** method to set the default **policy name** order for all heuristic descriptors on that handle unless the descriptor sets **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`**. -- **`HIPDNN_HEURISTIC_POLICY_ORDER` (optional env):** Comma-separated **policy names**; lowest precedence among user overrides ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)). -- **`HIPDNN_ATTR_ENGINEHEUR_MODE`:** Today the backend supports a narrow heuristic mode surface. This RFC does **not** remove the attribute; a future mapping might define default **policy order** per mode, or deprecate mode once **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`** and handle defaults are sufficient. That decision is left open in this draft. +- **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`:** Ordered list of pre-hashed **`int64_t`** policy IDs for **`EngineHeuristicDescriptor`**, overriding the built-in default ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). Attribute type: **`HIPDNN_TYPE_INT64`** (array of IDs). The backend stores IDs as-is — callers (typically the frontend) hash policy names via **`policyNameToId`** before calling `setAttribute` because the `void*` + `int64_t elementCount` ABI does not cleanly carry variable-length string arrays. When this list changes between finalizes (a fresh descriptor with a different list), the per-finalize sync step recreates the owned **policy descriptor** objects ([§5.4](#54-two-tier-plugin-objects-handle-vs-policy-descriptor)). +- **Handle-level override (deferred):** A per-handle default (e.g. **`hipdnnSetHeuristicPolicyOrder_ext(handle, ...)`**) was sketched in earlier drafts of this RFC but is **not implemented**. Today the resolution chain is env > descriptor attribute > built-in fallback only. Retained here as a candidate extension if a per-handle default surface is needed later. +- **`HIPDNN_HEUR_POLICY_ORDER` (optional env):** Comma-separated tokens; each token is a UTF-8 policy name **or** a signed decimal **`int64_t`** policy ID (per-token disambiguation by full-string `strtoll`). **Highest** precedence among user overrides so operators can override application-supplied lists at deploy time ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)). +- **`HIPDNN_ATTR_ENGINEHEUR_MODE`:** Today the backend supports a narrow heuristic mode surface. This RFC does **not** remove the attribute; a future mapping might define default **policy order** per mode, or deprecate mode once **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** is sufficient. That decision is left open in this draft. - **`HIPDNN_ATTR_ENGINEHEUR_DEVICEPROP`:** Proposed as the user-facing override for [§6.3](#63-proposed-override-descriptor-level-device-properties) when the descriptor type and setters are implemented. -- **No requirement for new enums** per new policy: adding a policy is **deployment + registry order** (user-facing **names** in config; **`int64_t`** IDs from **`hipdnnHeuristicGetPolicyId`** at load time), not necessarily a new public enum value. Well-known **names** **`SelectionHeuristic::Config`** and **`SelectionHeuristic::StaticOrdering`** are **strings** in attributes and docs; their **`int64_t`** values are **`engineNameToId(...)`**, not enum members. +- **No requirement for new enums** per new policy: adding a policy is **deployment + registry order** (names in env/config; **`int64_t`** IDs from **`policyNameToId`** at registration time), not necessarily a new public enum value. The well-known **names** **`SelectionHeuristic::Config`** and **`SelectionHeuristic::StaticOrdering`** are **strings** in env/config and docs; their **`int64_t`** values are **`policyNameToId(...)`**, not enum members. Both are shipped as **backend built-ins** ([§10.1](#101-heuristicpluginmanager)), so they are always available without loading any external `.so`; users may freely omit or reorder either in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** or **`HIPDNN_HEUR_POLICY_ORDER`**. **`HIPDNN_HEUR_CONFIG_PATH`** is consulted by the **`Config`** built-in only when that policy is in the resolved order. **`Graph.preferred_engine_id`** is independent of this attribute — the frontend reorders engine configs after **`finalize()`** returns ([§5.3.5](#535-the-config-built-in-policy)). - **Headers:** Publish **`HeuristicsPluginApi.h`** (name TBD) in **plugin_sdk** (or a sibling package) containing the types and declarations in [§8](#8-c-abi-for-heuristic-plugins), without including **engine** plugin API headers. **Frontend:** End-to-end flow from application / **hipdnn_frontend** through heuristic **`finalize()`** and policy configuration is described in [§16](#16-frontend-api-flow-mirror-engine-selection). @@ -658,21 +692,23 @@ So **which engines appear as options** for a given graph is **dynamic**: it come ### 16.2 Loaded heuristic policy enumeration -**Policy IDs** (heuristic) and **engine IDs** are both **`int64_t`** name hashes but serve **different roles**: heuristic **`hipdnnHeuristicGetPolicyId`** identifies a **selection policy** module; engine plugins expose **engine** IDs for execution. Each loaded **heuristic** module contributes one stable **`int64_t`** from **`hipdnnHeuristicGetPolicyId`** ([§8.2](#82-plugin-module-metadata)). The **ordered list** the outer loop uses (**`orderedPolicyIds`**) is **`int64_t`**, derived from user-supplied **policy name** strings per [§5.3.3](#533-how-the-user-sets-orderedpolicyids). Applications and tools still need a **runtime view of which policies are actually loaded** for this **`hipdnnHandle`** (or process), analogous to **`getEngineInfos()`** (see [§10.2](#102-heuristicpluginresourcemanager)). +**Policy IDs** (heuristic) and **engine IDs** are both **`int64_t`** name hashes but serve **different roles**: heuristic **`hipdnnHeuristicGetPolicyId`** identifies a **selection policy** module; engine plugins expose **engine** IDs for execution. Each loaded **heuristic** module — and each **backend built-in** registered through the same C ABI shape — contributes one stable **`int64_t`** from **`hipdnnHeuristicGetPolicyId`** ([§8.2](#82-plugin-module-metadata)). The **ordered list** the outer loop uses (**`orderedPolicyIds`**) is **`int64_t`**, derived from user-supplied **policy name** strings per [§5.3.3](#533-how-the-user-sets-orderedpolicyids). Applications and tools still need a **runtime view of which policies are actually available** for this **`hipdnnHandle`** (or process) — built-ins plus successfully loaded `.so` modules — analogous to **`getEngineInfos()`** (see [§10.2](#102-heuristicpluginresourcemanager)). + +**Built-ins as the canonical adapter example.** **`SelectionHeuristic::StaticOrdering`** is the reference built-in (`backend/src/heuristics/static_ordering/`). Its source unit exports **`populateFunctionTable()`**, which fills a **`HeuristicPluginFunctionTable`** struct with function pointers matching the heuristic plugin C ABI surface from [§8](#8-c-abi-for-heuristic-plugins) (handle lifecycle, policy descriptor lifecycle, set-engine-ids, set-serialized-graph, finalize, get-sorted-ids). At **`HeuristicPluginManager`** construction time, **`registerBuiltIns()`** wraps that table via **`HeuristicPlugin::createBuiltIn(populateFunctionTable(), source_label)`** and inserts the resulting **`std::shared_ptr`** into the same registry that `dlopen`-loaded modules use. Policy-ID lookup, version validation, duplicate-ID rejection, and per-handle plugin-handle creation all flow through that registry without distinguishing built-ins from external plugins. New built-in adapters follow the same recipe — add a `populateFunctionTable()` and a line in **`registerBuiltIns()`**. **Proposal:** -- **Backend / handle scope:** **`HeuristicPluginResourceManager`** exposes a query (exact name TBD) that returns **loaded heuristic policy metadata**: at minimum the **`int64_t` policy ID** per accepted module; optionally the **canonical policy name** (from **`hipdnnHeuristicGetPolicyName`** if present, else from a host-maintained id→name map for well-known policies), **plugin implementation version**, and install path (mirroring **`getHeuristicPluginInfos()`** / **`getLoadedHeuristicPluginFiles`** in [§10.2](#102-heuristicpluginresourcemanager)). A thin **`hipdnn…_ext`** C wrapper is optional if the C API surface should stay symmetric with other handle queries. -- **Frontend:** Add a small helper (for example **`getLoadedHeuristicPolicyInfos(handle)`**) that forwards to that query, documented next to **`getEngineConfigs`** and **`Graph::initializeEngineConfig`** so policy **configuration** (**name** strings in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`**) can be validated or presented in UIs without hard-coding numeric IDs. +- **Backend / handle scope:** **`HeuristicPluginResourceManager`** exposes a query (exact name TBD) that returns **available heuristic policy metadata** — both built-ins and accepted external modules: at minimum the **`int64_t` policy ID**; optionally the **canonical policy name** (from **`hipdnnHeuristicGetPolicyName`** if present, else from a host-maintained id→name map for well-known policies), **plugin implementation version**, install path (for `.so`-backed entries), and a flag distinguishing **built-in** from **loaded module** (mirroring **`getHeuristicPluginInfos()`** / **`getLoadedHeuristicPluginFiles`** in [§10.2](#102-heuristicpluginresourcemanager)). A thin **`hipdnn…_ext`** C wrapper is optional if the C API surface should stay symmetric with other handle queries. +- **Frontend:** Add a small helper (for example **`getLoadedHeuristicPolicyInfos(handle)`**) that forwards to that query, documented next to **`getEngineConfigs`** and **`Graph::initializeEngineConfig`** so policy **configuration** (**name** strings in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`**) can be validated or presented in UIs without hard-coding numeric IDs. -**How the list is determined dynamically:** The enumeration returns **only modules that passed load-time checks** ([§11](#11-versioning-and-compatibility-checks)): heuristic search paths ([§10.1](#101-heuristicpluginmanager)), **`validateBeforeAdding`**-style validation, and **unique `int64_t` policy ID** registration. It does **not** depend on the operation graph. **Name** strings the user places in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`** (or other config surfaces) are **hashed** and validated against this registry when the list is resolved ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)) and when **`finalize()`** skips unknown or failed modules ([§9](#9-policy-plugins-and-the-outer-loop)). +**How the list is determined dynamically:** Built-ins are present unconditionally as registered by **`registerBuiltIns()`** ([§10.1](#101-heuristicpluginmanager)). External entries are **only modules that passed load-time checks** ([§11](#11-versioning-and-compatibility-checks)): heuristic search paths ([§10.1](#101-heuristicpluginmanager)), **`validateBeforeAdding`**-style validation, and **unique `int64_t` policy ID** registration. Neither set depends on the operation graph. User-supplied entries — name tokens in **`HIPDNN_HEUR_POLICY_ORDER`** or pre-hashed IDs in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** — are resolved against the combined registry when the list is resolved ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)). Unknown IDs are not rejected up front; they produce a null `_policySlots` entry during the per-finalize sync and the outer loop skips them ([§9](#9-policy-plugins-and-the-outer-loop)). ### 16.3 End-to-end flow with policy order 1. **Optional:** The application calls the enumeration helper to list **available policies** (**`int64_t` IDs** and optional **names**) for loaded heuristic plugins plus any **built-in** adapters registered under the same scheme without a separate `.so`. -2. The application sets **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`** on the **`EngineHeuristicDescriptor`** when overriding defaults, and/or relies on handle-level default or environment ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)). +2. The application sets **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`** on the **`EngineHeuristicDescriptor`** when overriding defaults, and/or relies on handle-level default or environment ([§5.3.3](#533-how-the-user-sets-orderedpolicyids)). 3. **Unchanged from today:** **`backendFinalize`** on **`EngineHeuristicDescriptor`** runs the outer policy loop ([§14.2](#142-pseudocode-for-finalize-first-draft)). **Candidate engine IDs** still come from **`EnginePluginResourceManager::getApplicableEngineIds`** ([§10.4](#104-relationship-to-enginepluginresourcemanager)). -4. **Unchanged from today:** The application reads **`HIPDNN_ATTR_ENGINEHEUR_RESULTS`** to obtain **ordered engine configuration descriptors**; **`Graph::initializeEngineConfig`** continues to pick among those configs using **preferred** or **default engine id** on the graph. Graph-level preferences (for example **`preferred_engine_id`**) are interpreted inside policies such as **`SelectionHeuristic::Config`**, not by a separate hard-coded pass ([§13.3](#133-graph-level-preferences-for-example-preferred_engine_id)). +4. **Unchanged from today:** The application reads **`HIPDNN_ATTR_ENGINEHEUR_RESULTS`** to obtain **ordered engine configuration descriptors**; **`Graph::initializeEngineConfig`** continues to pick among those configs using **preferred** or **default engine id** on the graph. **`Graph.preferred_engine_id`** is honored by the **frontend** as a post-hoc reorder of those engine configs in **`Graph::initializeEngineConfig`** — the backend heuristic loop does not see it ([§13.3](#133-graph-level-preferences-for-example-preferred_engine_id)). In short: **policies** are **configured** by an ordered **policy name** string list (public surface) plus an **optional discovery API**; the backend uses **`int64_t`** policy IDs internally for matching. **Engines** remain **graph-dependent** and appear under **`HIPDNN_ATTR_ENGINEHEUR_RESULTS`** only after heuristic **`finalize()`** succeeds. @@ -684,8 +720,12 @@ In short: **policies** are **configured** by an ordered **policy name** string l - **Regression test** asserting that when **`SelectionHeuristic::StaticOrdering`** is in effect (for example via the default **policy name** order), ordering matches current `utilities::sortEngineIds` for a fixed candidate list. - **Failure test** asserting that when **`orderedPolicyIds`** is empty, all IDs are unknown/skipped, or every policy declines, **`finalize()`** fails via **`HipdnnException`** / the same status path as other descriptor **`finalize()`** errors (no silent sort fallback). - **Integration tests** with real graphs and devices when GPU is available. -- **ABI / loader tests** that load a minimal mock heuristic `.so`, verify **`hipdnnHeuristicGetApiVersion`**, **`hipdnnHeuristicGetPolicyId`** ( **`int64_t`** matches **`engineNameToId`** for the plugin’s documented name), **`HandleCreate` / `HandleDestroy`**, **`PolicyDescriptorCreate` / `Destroy`**, **`Finalize` / `GetSortedEngineIds`**, and reject wrong major versions; **negative tests** that a second module with a **duplicate** **`hipdnnHeuristicGetPolicyId`** and a module whose optional **`hipdnnHeuristicGetPolicyName`** does not hash to its policy ID are **rejected at load** (same spirit as **`EnginePluginManager`** duplicate engine IDs—[§5.3.1](#531-well-known-policy-names-and-ids), [§11](#11-versioning-and-compatibility-checks)). -- **Policy order tests** that assert default **name** list **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`**, descriptor override wins over handle, and unknown **names** (after hashing) are handled per policy. +- **ABI / loader tests** that load a minimal mock heuristic `.so`, verify **`hipdnnHeuristicGetApiVersion`**, **`hipdnnHeuristicGetPolicyId`** ( **`int64_t`** matches **`policyNameToId`** for the plugin’s documented name), **`HandleCreate` / `HandleDestroy`**, **`PolicyDescriptorCreate` / `Destroy`**, **`Finalize` / `GetSortedEngineIds`**, and reject wrong major versions; **negative tests** that a second module with a **duplicate** **`hipdnnHeuristicGetPolicyId`** and a module whose optional **`hipdnnHeuristicGetPolicyName`** does not hash to its policy ID are **rejected at load** (same spirit as **`EnginePluginManager`** duplicate engine IDs—[§5.3.1](#531-well-known-policy-names-and-ids), [§11](#11-versioning-and-compatibility-checks)). +- **Policy order tests** that assert the effective default **name** list is **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** (the fallback in [§5.3.2](#532-default-policy-order-user-strings) — backend does not inject any other entries), descriptor override wins over handle which wins over env, and unknown **names** (after hashing) are handled per policy. +- **Built-in registration tests** that assert **`HeuristicPluginManager`** registers both **`SelectionHeuristic::Config`** and **`SelectionHeuristic::StaticOrdering`** as built-ins at construction (no `.so` loaded), the registered IDs equal **`policyNameToId("SelectionHeuristic::Config")`** and **`policyNameToId("SelectionHeuristic::StaticOrdering")`** respectively, and both are available regardless of **`hipdnnPluginLoadingMode_ext_t`** (including **`HIPDNN_PLUGIN_LOADING_ABSOLUTE`**) — see [§10.1](#101-heuristicpluginmanager). +- **`SelectionHeuristic::Config` built-in tests** ([§5.3.5](#535-the-config-built-in-policy)) covering: empty candidates → declines; **`HIPDNN_HEUR_CONFIG_PATH`** unset / file missing / unparsable → declines; serialized graph buffer invalid or graph has no nodes → declines; no rule matches → declines; matched engine ID not in candidates → declines; matched rule moves the matched engine to the front while preserving the order of other candidates; the JSON file is re-read on every invocation (no process cache). +- **Engine override config rule-matching tests** covering exact-dim and wildcard (`-1`) matching, exact-bucket-before-wildcard-bucket cross-partition ordering, declaration-order tiebreak within a bucket, stride patterns (exact, wildcard element, empty pattern matches any), op-name and tensor-count rejection, and JSON loading (valid, missing file, env-var unset). +- **Frontend `preferred_engine_id` tests** that assert **`Graph::initializeEngineConfig`** reorders the heuristic-ranked engine configs to put the requested engine first when its global index appears among the configs returned by **`HIPDNN_ATTR_ENGINEHEUR_RESULTS`**, and falls back to the heuristic order otherwise. - **Lifetime tests** that assert destroying **`EngineHeuristicDescriptor`** invokes **`hipdnnHeuristicPolicyDescriptorDestroy`** for every owned slot (and does not leak plugin handles owned by **`hipdnnHandle`**). - **Enumeration tests** that assert the handle-scoped **loaded heuristic policy** query ([§16.2](#162-loaded-heuristic-policy-enumeration)) matches modules that passed **`validateBeforeAdding`**-style checks, returns **`int64_t`** IDs consistent with **`hipdnnHeuristicGetPolicyId`**, and that optional frontend helpers return consistent metadata. @@ -709,12 +749,15 @@ In short: **policies** are **configured** by an ordered **policy name** string l | **Engine plugin** | Shared library providing engines and execution; **distinct C ABI** from heuristic plugins. | | **Heuristic / selection policy plugin** | Shared library implementing one outer-loop selection strategy via the C ABI in [§8](#8-c-abi-for-heuristic-plugins). | | **Heuristic C ABI** | `extern "C"` symbol set: module metadata (**`int64_t` policy ID** via **`hipdnnHeuristicGetPolicyId`**), **`hipdnnHeuristicHandle_t`** and **`hipdnnHeuristicPolicyDescriptor_t`** lifecycle, selection functions ([§8](#8-c-abi-for-heuristic-plugins)). | -| **HeuristicPluginManager** | Loads and validates heuristic `.so` files; analogous to **EnginePluginManager** but **heuristic-only** symbols. | -| **HeuristicPluginResourceManager** | Handle-scoped facade for heuristic plugins; stores **`hipdnnHeuristicHandle_t`** per module, paths; analogous to **EnginePluginResourceManager**. | +| **HeuristicPluginManager** | Loads and validates heuristic `.so` files **and** registers backend built-ins (via **`registerBuiltIns()`**) at construction time; analogous to **EnginePluginManager** but **heuristic-only** symbols. | +| **HeuristicPluginResourceManager** | Handle-scoped facade for heuristic plugins (built-ins and external modules); stores **`hipdnnHeuristicHandle_t`** per module, paths; analogous to **EnginePluginResourceManager**. | | **Outer loop** | Ordered list of policies; first applicable successful policy wins. If none succeed, **`finalize()`** fails via **`HipdnnException`** (normal backend error path); no built-in sort after the loop. | -| **Policy name** | UTF-8 string in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER`**, env, or handle defaults; hashed with **`engineNameToId`** to form the **`int64_t`** used in **`orderedPolicyIds`** and matched against **`hipdnnHeuristicGetPolicyId`**. | -| **Policy ID (heuristic)** | **`int64_t`** stable identifier for one heuristic module; **must** equal **`engineNameToId(policy_name)`** for that module’s canonical name ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). | -| **orderedPolicyIds** | Resolved **`std::vector`** for one **`finalize()`**, each element the hash of a user-configured **policy name**; default **names** **`{ "SelectionHeuristic::Config", "SelectionHeuristic::StaticOrdering" }`** ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). | +| **Policy name** | UTF-8 string in **`HIPDNN_ATTR_ENGINEHEUR_POLICY_ORDER_EXT`**, env, or handle defaults; hashed with **`policyNameToId`** to form the **`int64_t`** used in **`orderedPolicyIds`** and matched against **`hipdnnHeuristicGetPolicyId`**. | +| **Policy ID (heuristic)** | **`int64_t`** stable identifier for one heuristic policy (built-in or external); **must** equal **`policyNameToId(policy_name)`** for that policy’s canonical name ([§5.3](#53-ordered-policy-list-default-and-user-configuration)). | +| **orderedPolicyIds** | Resolved **`std::vector`** for one **`finalize()`**, each element the hash of a user-configured **policy name**. The backend does **not** inject any entries. Effective default when nothing is overridden: **`{ policyNameToId("SelectionHeuristic::Config"), policyNameToId("SelectionHeuristic::StaticOrdering") }`** ([§5.3](#53-ordered-policy-list-default-and-user-configuration), [§5.3.2](#532-default-policy-order-user-strings)). | +| **`SelectionHeuristic::Config`** | Backend built-in heuristic in `backend/src/heuristics/config/` that consults **`HIPDNN_HEUR_CONFIG_PATH`** for JSON rules matching the graph's convolution shapes to a preferred engine ID. On a match it returns the candidates with the matched engine moved to the front; on every miss path (env unset, file missing, no rule match, matched engine not in candidates) it **declines** so the policy loop falls through to the next slot ([§5.3.5](#535-the-config-built-in-policy)). | +| **`HIPDNN_HEUR_CONFIG_PATH`** | Environment variable consulted by **`SelectionHeuristic::Config`** (formerly **`HIPDNN_ENGINE_OVERRIDE_FILE`**) on every **`Finalize`**. Points at a JSON rule file that maps convolution op + tensor-shape patterns to a preferred engine ID. Re-read each invocation; no process cache. | +| **Built-in heuristic** | Backend-shipped policy implemented as a populated **`HeuristicPluginFunctionTable`** (the heuristic plugin C ABI shape) and wrapped via **`HeuristicPlugin::createBuiltIn(...)`** at **`HeuristicPluginManager`** construction time. No `.so` is loaded. Today **`SelectionHeuristic::Config`** and **`SelectionHeuristic::StaticOrdering`** ship this way ([§10.1](#101-heuristicpluginmanager), [§16.2](#162-loaded-heuristic-policy-enumeration)). | | **DeviceProperties** | C++ struct of device facts in the backend; serialized to FlatBuffer and passed in **`hipdnnPluginConstData_t`** for the heuristic C ABI ([§13.2](#132-serialized-device-properties-flatbuffer)). Plugins do not call HIP. | | **SelectionHeuristic** | C++ facade over **`hipdnnHeuristicPolicyDescriptor_t`** for one policy **slot** on **`EngineHeuristicDescriptor`**; session state stays on **`hipdnnHeuristicHandle_t`**. | | **Plugin heuristic handle** | **`hipdnnHeuristicHandle_t`**: session object per heuristic module per **`hipdnnHandle`**; **SetDeviceProperties** (serialized device-properties FlatBuffer in **`hipdnnPluginConstData_t`**), applied **before** **`PolicyFinalize`** on descriptors bound to that handle; heuristics **read** device facts from the handle as needed; **single-thread** use ([§8.3](#83-plugin-handle-session-object)). | diff --git a/projects/hipdnn/flatbuffers_sdk/include/hipdnn_flatbuffers_sdk/data_objects/device_properties_generated.h b/projects/hipdnn/flatbuffers_sdk/include/hipdnn_flatbuffers_sdk/data_objects/device_properties_generated.h index 61b089bd1e2..4eb41d73401 100644 --- a/projects/hipdnn/flatbuffers_sdk/include/hipdnn_flatbuffers_sdk/data_objects/device_properties_generated.h +++ b/projects/hipdnn/flatbuffers_sdk/include/hipdnn_flatbuffers_sdk/data_objects/device_properties_generated.h @@ -39,8 +39,6 @@ struct DevicePropertiesT : public ::flatbuffers::NativeTable { /// /// Evolution: New optional fields can be added without breaking compatibility. /// Plugins verify and parse this buffer using flatbuffers::Verifier. -/// -/// RFC 0007 Reference: Section 6 (Device Properties), Section 13.2 (Serialization) struct DeviceProperties FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { typedef DevicePropertiesT NativeTableType; typedef DevicePropertiesBuilder Builder; diff --git a/projects/hipdnn/flatbuffers_sdk/schemas/device_properties.fbs b/projects/hipdnn/flatbuffers_sdk/schemas/device_properties.fbs index 12104678271..6bc2f60d431 100644 --- a/projects/hipdnn/flatbuffers_sdk/schemas/device_properties.fbs +++ b/projects/hipdnn/flatbuffers_sdk/schemas/device_properties.fbs @@ -11,8 +11,6 @@ namespace hipdnn_flatbuffers_sdk.data_objects; /// /// Evolution: New optional fields can be added without breaking compatibility. /// Plugins verify and parse this buffer using flatbuffers::Verifier. -/// -/// RFC 0007 Reference: Section 6 (Device Properties), Section 13.2 (Serialization) table DeviceProperties { /// Device ID from hipGetDevice device_id: int = -1; diff --git a/projects/hipdnn/frontend/include/hipdnn_frontend/Graph.hpp b/projects/hipdnn/frontend/include/hipdnn_frontend/Graph.hpp index c024279ebb1..f88faa626ca 100644 --- a/projects/hipdnn/frontend/include/hipdnn_frontend/Graph.hpp +++ b/projects/hipdnn/frontend/include/hipdnn_frontend/Graph.hpp @@ -94,7 +94,6 @@ #include #include #include -#include #include #include #include @@ -171,23 +170,6 @@ class Graph : public INode std::optional _preferredEngineId; - static std::optional getDefaultEngineId() - { - static const std::optional s_defaultId = []() -> std::optional { - auto envStr = hipdnn_data_sdk::utilities::trim( - hipdnn_data_sdk::utilities::getEnv("HIPDNN_DEFAULT_ENGINE")); - if(envStr.empty()) - { - return std::nullopt; - } - auto engineId = hipdnn_data_sdk::utilities::engineNameToId(envStr); - HIPDNN_FE_LOG_INFO("HIPDNN_DEFAULT_ENGINE='" << envStr - << "' mapped to engine ID: " << engineId); - return engineId; - }(); - return s_defaultId; - } - /// Apply validated knob settings to the engine config descriptor via /// the descriptor-based C API path. Error applyKnobSettingsToEngineConfig(const std::vector& validatedSettings) @@ -346,12 +328,6 @@ class Graph : public INode } resetGraphDesc(); - if(!_preferredEngineId.has_value()) - { - _preferredEngineId - = hipdnn_frontend::engine_override::getPreferredIdFromOverrideConfig(*this); - } - std::unordered_map tensorDescs; std::vector operations; @@ -408,53 +384,33 @@ class Graph : public INode Error initializeEngineConfig(hipdnnBackendDescriptor_t engineHeuristicDesc) { + // The backend's SelectionHeuristic::Config built-in honors + // HIPDNN_HEUR_CONFIG_PATH inside the policy loop, so the + // heuristic-ranked list already reflects env/config-file overrides. + // The explicit Graph.preferred_engine_id setter is honored here as a + // post-hoc reorder: if the user pinned an engine and it appears in + // the ranked list, prefer it over index 0; otherwise log and fall + // back to the heuristic's choice. std::vector> engineConfigs; std::vector engineIds; - auto defaultEngineId = getDefaultEngineId(); HIPDNN_CHECK_ERROR(hipdnn_frontend::detail::getEngineConfigs( - engineConfigs, - engineIds, - engineHeuristicDesc, - _preferredEngineId.has_value() || defaultEngineId.has_value())); + engineConfigs, engineIds, engineHeuristicDesc, _preferredEngineId.has_value())); - // Select engine config based on preferred ID or use first available size_t selectedIndex = 0; - if(defaultEngineId) + if(_preferredEngineId.has_value()) { - auto defaultId = defaultEngineId.value(); - auto it = std::find(engineIds.begin(), engineIds.end(), defaultId); + const int64_t preferredId = _preferredEngineId.value(); + auto it = std::find(engineIds.begin(), engineIds.end(), preferredId); if(it != engineIds.end()) { selectedIndex = static_cast(std::distance(engineIds.begin(), it)); - HIPDNN_FE_LOG_INFO("Default engine id " << defaultId - << " found, using it for execution plan."); + HIPDNN_FE_LOG_INFO("Preferred engine id " + << preferredId << " found, using it for execution plan."); } else { - HIPDNN_FE_LOG_INFO("Default engine id " - << defaultId << " not found, using top engine config instead."); - } - } - - if(_preferredEngineId.has_value()) - { - bool found = false; - - for(size_t i = 0; i < engineIds.size(); ++i) - { - - if(engineIds[i] == _preferredEngineId.value()) - { - selectedIndex = i; - found = true; - break; - } - } - - if(!found) - { - HIPDNN_FE_LOG_WARN("Preferred engine id " - << _preferredEngineId.value() + HIPDNN_FE_LOG_INFO("Preferred engine id " + << preferredId << " not found, using top engine config instead."); } } diff --git a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideConfig.hpp b/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideConfig.hpp deleted file mode 100644 index 64e20b940b4..00000000000 --- a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideConfig.hpp +++ /dev/null @@ -1,470 +0,0 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB -#include -#include -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB - -namespace hipdnn_frontend::engine_override -{ - -/// Dimension value meaning "match any value in this slot". -inline constexpr int64_t WILDCARD_DIM = -1; - -/// Pattern for a single tensor: a list of expected dimensions and optional strides, -/// with -1 as a per-slot wildcard. -/// When `stride` is empty no stride matching is performed. -struct TensorPattern -{ - std::vector dim; - std::vector stride; ///< Empty = do not match on stride. - - /// Returns true iff tensor.get_dim() matches this pattern element-by-element, - /// and (when stride is non-empty) tensor.get_stride() matches stride element-by-element. - /// Rejects immediately when rank differs; skips per-element check for WILDCARD_DIM slots. - bool matches(const graph::TensorAttributes& tensor) const - { - const auto& tdim = tensor.get_dim(); - if(dim.size() != tdim.size()) - { - return false; - } - for(size_t i = 0; i < dim.size(); ++i) - { - if(dim[i] != WILDCARD_DIM && dim[i] != tdim[i]) - { - return false; - } - } - if(!stride.empty()) - { - const auto& tstride = tensor.get_stride(); - // Nonmatching strides assume wildcard semantics. - const size_t rank = std::min(stride.size(), tstride.size()); - for(size_t i = 0; i < rank; ++i) - { - if(stride[i] != WILDCARD_DIM && stride[i] != tstride[i]) - { - return false; - } - } - } - return true; - } -}; - -/// A single engine-override rule. -struct OperationRule -{ - std::string op; ///< "conv_fprop" / "conv_dgrad" / "conv_wgrad" - std::string engineName; ///< Engine name resolved to an ID via engineNameToId() - std::vector tensors; ///< Ordered patterns for operation inputs - - /// Returns true iff every tensor in `tensors` matches the corresponding pattern. - /// Rejects immediately when the tensor count differs. - bool matches(const std::vector>& inputs) const - { - if(tensors.size() != inputs.size()) - { - return false; - } - for(size_t i = 0; i < tensors.size(); ++i) - { - if(!tensors[i].matches(*inputs[i])) - { - return false; - } - } - return true; - } -}; - -// ── Internal helpers ────────────────────────────────────────────────────────── - -/// FNV-1a hash over a flat vector key. -struct DimKeyHash -{ - size_t operator()(const std::vector& key) const noexcept - { - size_t h = 14695981039346656037ULL; // FNV-1a offset basis (64-bit) - for(int64_t v : key) - { - const auto* p = reinterpret_cast(&v); - for(size_t b = 0; b < sizeof(int64_t); ++b) - { - h ^= static_cast(p[b]); - h *= 1099511628211ULL; // FNV-1a prime - } - } - return h; - } -}; - -// ── EngineOverrideConfig ────────────────────────────────────────────────────── - -/// Loaded set of engine-override rules. -/// Rules are evaluated in declaration order; first match wins. -/// -/// Internally rules are split per op name (strategy 1) and further divided -/// into exact rules — where every dimension is concrete and no stride constraint -/// is present — stored in a hash map for O(1) lookup, and wildcard rules kept -/// in a declaration-order vector (strategy 2). The two structures are reconciled -/// via declaration index so that first-match semantics are preserved across the -/// partition. -class EngineOverrideConfig -{ -public: - /// Default-construct an empty config (no rules). - EngineOverrideConfig() = default; - - /// Construct directly from a vector of rules (useful for tests). - explicit EngineOverrideConfig(std::vector rules) - { - for(size_t i = 0; i < rules.size(); ++i) - { - indexRule(std::move(rules[i]), i); - } - } - - /// Load from an explicit file path. - /// Returns nullopt on missing file, parse error, or when JSON support is - /// compiled out (HIPDNN_FRONTEND_SKIP_JSON_LIB defined). - static std::optional load(const std::string& filepath) - { -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB - std::ifstream file(filepath); - if(!file.is_open()) - { - HIPDNN_FE_LOG_WARN("EngineOverrideConfig: cannot open file: " << filepath); - return std::nullopt; - } - - try - { - auto config = parseJson(nlohmann::json::parse(file)); - HIPDNN_FE_LOG_INFO("EngineOverrideConfig: loaded " << config.ruleCount() - << " rule(s) from " << filepath); - return config; - } - catch(const nlohmann::json::exception& e) - { - HIPDNN_FE_LOG_WARN("EngineOverrideConfig: JSON parse error in " << filepath << ": " - << e.what()); - return std::nullopt; - } -#else - (void)filepath; - HIPDNN_FE_LOG_WARN( - "EngineOverrideConfig: JSON support not compiled in; engine override file ignored."); - return std::nullopt; -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB - } - - /// Load from a JSON string in memory. - /// Returns nullopt on parse error or when JSON support is compiled out. - static std::optional loadFromContent(const std::string& content) - { -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB - try - { - auto config = parseJson(nlohmann::json::parse(content)); - HIPDNN_FE_LOG_INFO("EngineOverrideConfig: loaded " << config.ruleCount() - << " rule(s) from inline content"); - return config; - } - catch(const nlohmann::json::exception& e) - { - HIPDNN_FE_LOG_WARN( - "EngineOverrideConfig: JSON parse error in inline content: " << e.what()); - return std::nullopt; - } -#else - (void)content; - HIPDNN_FE_LOG_WARN( - "EngineOverrideConfig: JSON support not compiled in; engine override content ignored."); - return std::nullopt; -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB - } - - /// Return a pointer to the process-lifetime config loaded from - /// HIPDNN_ENGINE_OVERRIDE_FILE (read and cached on the first call, - /// thread-safe per C++11). Returns nullptr when the variable is unset, - /// empty, the file cannot be opened, or JSON support is compiled out. - /// Leading and trailing whitespace in the path value is ignored. - static const EngineOverrideConfig* loadFromEnv() - { - static constexpr const char* ENV_VAR = "HIPDNN_ENGINE_OVERRIDE_FILE"; - static const std::optional s_cached = []() { - std::string path = hipdnn_data_sdk::utilities::getEnv(ENV_VAR, ""); - const auto first = path.find_first_not_of(" \t\r\n"); - if(first == std::string::npos) - { - return std::optional{}; - } - path = path.substr(first, path.find_last_not_of(" \t\r\n") - first + 1); - return load(path); - }(); - return s_cached ? &*s_cached : nullptr; - } - - /// Scan rules in declaration order; return the first matching enginedId or nullopt. - /// - /// Strategy 1: only the bucket for `op` is examined (hash map lookup). - /// Strategy 2: within the bucket, exact rules are probed in O(1) via hash map; - /// wildcard rules are scanned linearly but the scan terminates as - /// soon as a lower-order exact match is known to exist. - std::optional - matchOperation(const std::string& op, - const std::vector>& tensors) const - { - // Strategy 1: find the op bucket - const auto opIt = _index.find(op); - if(opIt == _index.end()) - { - return std::nullopt; - } - const OpBucket& bucket = opIt->second; - - // Strategy 2a: O(1) probe of the exact map - std::optional exactHit; - { - const auto key = buildDimKey(tensors); - const auto eit = bucket.exact.find(key); - if(eit != bucket.exact.end()) - { - exactHit = eit->second; - } - } - - // Strategy 2b: linear scan of wildcard rules in declaration order. - // Wildcards are stored in ascending order, so once the current entry's - // order exceeds the exact hit's order, no further wildcard can win. - for(const auto& entry : bucket.wildcards) - { - if(exactHit && entry.order > exactHit->order) - { - break; // exact match has earlier declaration; no wildcard can beat it - } - if(entry.rule.matches(tensors)) - { - // This wildcard has lower or equal order to any exact hit (loop - // would have broken otherwise), so it is the first-match winner. - HIPDNN_FE_LOG_INFO("EngineOverrideConfig: matched op=" - << op << " enginedId=" << entry.enginedId << " (wildcard rule)"); - return entry.enginedId; - } - } - - if(exactHit) - { - HIPDNN_FE_LOG_INFO("EngineOverrideConfig: matched op=" - << op << " enginedId=" << exactHit->enginedId << " (exact rule)"); - return exactHit->enginedId; - } - return std::nullopt; - } - -private: - /// enginedId and declaration index for an exact-match rule. - struct ExactEntry - { - int64_t enginedId; - size_t order; ///< position in the original rule list (0 = first) - }; - - /// Wildcard rule paired with its declaration index and resolved engine ID. - struct WildcardEntry - { - OperationRule rule; - int64_t enginedId; ///< resolved from rule.engineName at index time - size_t order; - }; - - /// Per-op rule storage partitioned into exact and wildcard buckets. - struct OpBucket - { - /// Exact rules: no WILDCARD_DIM anywhere and no stride constraints. - /// Key = rank-prefixed flattened dims of all input tensors. - /// When two rules share a key, only the first (lowest order) is kept. - std::unordered_map, ExactEntry, DimKeyHash> exact; - - /// Wildcard rules in ascending declaration order. - std::vector wildcards; - }; - - std::unordered_map _index; - - // ── helpers ─────────────────────────────────────────────────────────────── - -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB - /// Parse a nlohmann::json object into an EngineOverrideConfig. - /// Throws nlohmann::json::exception on malformed input; callers handle it. - /// - /// Expected JSON format: - /// @code{.json} - /// { - /// "engine_overrides": [ - /// { - /// "op": "conv_fprop", - /// "engine_name": "MIOPEN_ENGINE", - /// "tensors": [ - /// { "dim": [1, 3, 224, 224], "stride": [150528, 50176, 224, 1] }, - /// { "dim": [64, 3, 7, 7] } - /// ] - /// }, - /// { - /// "op": "conv_fprop", - /// "engine_name": "HIPBLASLT_ENGINE", - /// "tensors": [ - /// { "dim": [-1, -1, -1, -1] }, - /// { "dim": [-1, -1, -1, -1] } - /// ] - /// } - /// ] - /// } - /// @endcode - /// - /// Notes: - /// - `engine_name` must be a registered engine name (e.g. "MIOPEN_ENGINE"). - /// - `-1` in `dim` or `stride` is a wildcard matching any value. - /// - `stride` is optional per tensor; omitting it skips stride matching. - static EngineOverrideConfig parseJson(const nlohmann::json& j) - { - std::vector rules; - for(const auto& entry : j.at("engine_overrides")) - { - OperationRule rule; - rule.op = entry.at("op").get(); - rule.engineName = entry.at("engine_name").get(); - for(const auto& t : entry.at("tensors")) - { - TensorPattern pat; - pat.dim = t.at("dim").get>(); - if(t.contains("stride")) - { - pat.stride = t.at("stride").get>(); - } - rule.tensors.push_back(std::move(pat)); - } - rules.push_back(std::move(rule)); - } - return EngineOverrideConfig(std::move(rules)); - } -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB - - /// Returns true if any dim slot in any pattern is WILDCARD_DIM, or if any - /// pattern carries a stride constraint (stride-constrained rules use the - /// linear wildcard scan so that TensorPattern::matches() is always called). - static bool hasWildcard(const std::vector& patterns) - { - for(const auto& p : patterns) - { - for(const int64_t d : p.dim) - { - if(d == WILDCARD_DIM) - { - return true; - } - } - if(!p.stride.empty()) - { - return true; - } - } - return false; - } - - /// Build a dim key from a rule's tensor patterns. - /// Format: [rank₀, dim₀₀, dim₀₁, …, rank₁, dim₁₀, …] - static std::vector buildDimKey(const std::vector& patterns) - { - std::vector key; - for(const auto& p : patterns) - { - key.push_back(static_cast(p.dim.size())); - key.insert(key.end(), p.dim.begin(), p.dim.end()); - } - return key; - } - - /// Build a dim key from live tensor attributes. - static std::vector - buildDimKey(const std::vector>& tensors) - { - std::vector key; - for(const auto& t : tensors) - { - const auto& d = t->get_dim(); - key.push_back(static_cast(d.size())); - key.insert(key.end(), d.begin(), d.end()); - } - return key; - } - - /// Insert one rule into the appropriate bucket of _index. - /// Resolves engineName to an int64_t ID via engineNameToId(). - void indexRule(OperationRule rule, size_t order) - { - const int64_t resolvedId = hipdnn_data_sdk::utilities::engineNameToId(rule.engineName); - OpBucket& bucket = _index[rule.op]; // keyed by op (strategy 1) - if(hasWildcard(rule.tensors)) - { - bucket.wildcards.push_back(WildcardEntry{std::move(rule), resolvedId, order}); - } - else - { - const auto key = buildDimKey(rule.tensors); - // try_emplace keeps the first (lowest-order) entry for duplicate keys. - bucket.exact.try_emplace(key, ExactEntry{resolvedId, order}); - } - } - - /// Total rule count across all buckets (exact + wildcard). - size_t ruleCount() const - { - size_t n = 0; - for(const auto& [op, bucket] : _index) - { - n += bucket.exact.size() + bucket.wildcards.size(); - } - return n; - } -}; - -/// Match op/tensors against a config and return the first matching enginedId. -/// When `config` is null the process-lifetime config loaded from -/// HIPDNN_ENGINE_OVERRIDE_FILE is used (read once on first call, thread-safe -/// per C++11). Passing an explicit config bypasses the env-var lookup entirely, -/// which is useful for testing or when the caller manages the config lifetime. -/// Returns nullopt when no rule matches or JSON support is compiled out. -inline std::optional - checkEngineOverride(const std::string& op, - const std::vector>& tensors, - const EngineOverrideConfig* config = nullptr) -{ - if(config == nullptr) - { - config = EngineOverrideConfig::loadFromEnv(); - } - if(config == nullptr) - { - return std::nullopt; - } - return config->matchOperation(op, tensors); -} - -} // namespace hipdnn_frontend::engine_override diff --git a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideUtils.hpp b/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideUtils.hpp deleted file mode 100644 index e347ce8b7af..00000000000 --- a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/EngineOverrideUtils.hpp +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include -#include -#include - -#include - -namespace hipdnn_frontend::engine_override -{ - -/// Walk the graph using the node visitor to find the first convolution -/// operation and return the preferred engine ID from the lazily-loaded -/// engine override config (pointed to by HIPDNN_ENGINE_OVERRIDE_FILE). -/// -/// Returns nullopt when: -/// - no convolution node is present in the graph, -/// - no rule in the config matches the operation's tensors, or -/// - JSON support is compiled out (HIPDNN_FRONTEND_SKIP_JSON_LIB defined). -inline std::optional getPreferredIdFromOverrideConfig(const graph::INode& root) -{ - std::optional result; - - root.visit([&result](const graph::INode& node) { - if(result.has_value()) - { - return; - } - - switch(node.getNodeType()) - { - case graph::NodeType::CONVOLUTION_FPROP: - { - const auto& conv = static_cast(node); - result = checkEngineOverride("conv_fprop", - {conv.attributes.get_x(), conv.attributes.get_w()}); - break; - } - case graph::NodeType::CONVOLUTION_DGRAD: - { - const auto& conv = static_cast(node); - result = checkEngineOverride("conv_dgrad", - {conv.attributes.get_dy(), conv.attributes.get_w()}); - break; - } - case graph::NodeType::CONVOLUTION_WGRAD: - { - const auto& conv = static_cast(node); - result = checkEngineOverride("conv_wgrad", - {conv.attributes.get_x(), conv.attributes.get_dy()}); - break; - } - default: - break; - } - }); - - return result; -} - -} // namespace hipdnn_frontend::engine_override diff --git a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/GraphDetail.hpp b/projects/hipdnn/frontend/include/hipdnn_frontend/detail/GraphDetail.hpp index 214eeeb0e8e..c773237a99f 100644 --- a/projects/hipdnn/frontend/include/hipdnn_frontend/detail/GraphDetail.hpp +++ b/projects/hipdnn/frontend/include/hipdnn_frontend/detail/GraphDetail.hpp @@ -56,8 +56,12 @@ inline Error "No engine configurations available for the graph."}; } - // Get only top hit if preferred engine id isn't set. - // Otherwise get all available engine configs to search for preferred id. + // Fetch only the top engine config unless the caller needs the full ranked + // list (e.g. get_ranked_engine_ids, or the explicit Graph.preferred_engine_id + // post-hoc reorder in initializeEngineConfig). HIPDNN_HEUR_CONFIG_PATH + // reordering happens inside the SelectionHeuristic::Config built-in and is + // already reflected in the ranked list — no extra frontend search is needed + // for that knob. const int64_t requiredCount = getAll ? availableEngineCount : 1; std::vector engineConfigsShallow; for(size_t i = 0; i < static_cast(requiredCount); ++i) diff --git a/projects/hipdnn/frontend/tests/CMakeLists.txt b/projects/hipdnn/frontend/tests/CMakeLists.txt index 228cea56507..b1ee1332731 100644 --- a/projects/hipdnn/frontend/tests/CMakeLists.txt +++ b/projects/hipdnn/frontend/tests/CMakeLists.txt @@ -39,7 +39,6 @@ add_executable( TestDescriptorHelpers.cpp TestDescriptorUnpackHelpers.cpp TestConvolutionWgradNode.cpp - TestEngineOverrideConfig.cpp TestError.cpp TestFrontendLogging.cpp TestGraph.cpp diff --git a/projects/hipdnn/frontend/tests/TestEngineOverrideConfig.cpp b/projects/hipdnn/frontend/tests/TestEngineOverrideConfig.cpp deleted file mode 100644 index 7c8b9c581a6..00000000000 --- a/projects/hipdnn/frontend/tests/TestEngineOverrideConfig.cpp +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright © Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include - -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB -#include -#include -#endif - -using namespace hipdnn_frontend::engine_override; -using namespace hipdnn_frontend::graph; -using namespace hipdnn_data_sdk::utilities; - -// ── helpers ───────────────────────────────────────────────────────────────── - -static std::shared_ptr makeTensor(const std::vector& dims) -{ - auto t = std::make_shared(); - t->set_dim(dims); - return t; -} - -static std::shared_ptr makeTensorWithStride(const std::vector& dims, - const std::vector& strides) -{ - auto t = std::make_shared(); - t->set_dim(dims); - t->set_stride(strides); - return t; -} - -static TensorPattern makePattern(std::vector dims) -{ - TensorPattern p; - p.dim = std::move(dims); - return p; -} - -static TensorPattern makePatternWithStride(std::vector dims, std::vector strides) -{ - TensorPattern p; - p.dim = std::move(dims); - p.stride = std::move(strides); - return p; -} - -// Construct a single-rule config inline (no JSON required). -static EngineOverrideConfig makeConfig(std::vector rules) -{ - return EngineOverrideConfig(std::move(rules)); -} - -// ── Test 1: exact dim match, single rule ──────────────────────────────────── - -TEST(TestEngineOverrideConfig, ExactDimMatchSingleRule) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = MIOPEN_ENGINE_NAME; - rule.tensors = {makePattern({1, 3, 224, 224}), makePattern({64, 3, 7, 7})}; - - auto config = makeConfig({std::move(rule)}); - - const std::vector> tensors - = {makeTensor({1, 3, 224, 224}), makeTensor({64, 3, 7, 7})}; - - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, MIOPEN_ENGINE_ID); -} - -// ── Test 2: first matching rule wins ──────────────────────────────────────── - -TEST(TestEngineOverrideConfig, FirstMatchingRuleWins) -{ - OperationRule rule1; - rule1.op = "conv_fprop"; - rule1.engineName = MIOPEN_ENGINE_NAME; - rule1.tensors = {makePattern({1, 3, 224, 224})}; - - OperationRule rule2; - rule2.op = "conv_fprop"; - rule2.engineName = HIPBLASLT_ENGINE_NAME; - rule2.tensors = {makePattern({1, 3, 224, 224})}; - - auto config = makeConfig({std::move(rule1), std::move(rule2)}); - - const std::vector> tensors = {makeTensor({1, 3, 224, 224})}; - - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, MIOPEN_ENGINE_ID); // first rule wins -} - -// ── Test 3: no rule matches (wrong dims) ──────────────────────────────────── - -TEST(TestEngineOverrideConfig, NoRuleMatchesWrongDims) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = MIOPEN_ENGINE_NAME; - rule.tensors = {makePattern({1, 3, 224, 224})}; - - auto config = makeConfig({std::move(rule)}); - - const std::vector> tensors = { - makeTensor({1, 3, 112, 112}) // different spatial dims - }; - - auto result = config.matchOperation("conv_fprop", tensors); - EXPECT_FALSE(result.has_value()); -} - -// ── Test 4: wildcard (-1) in one dimension ────────────────────────────────── - -TEST(TestEngineOverrideConfig, WildcardInOneDimension) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = HIPBLASLT_ENGINE_NAME; - rule.tensors = {makePattern({-1, 64, 56, 56})}; // batch dim is wildcard - - auto config = makeConfig({std::move(rule)}); - - for(const int64_t batch : {1, 4, 8, 32}) - { - const std::vector> tensors - = {makeTensor({batch, 64, 56, 56})}; - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()) << "batch=" << batch << " should match"; - EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); - } - - // Non-matching channel dim should still fail - const std::vector> tensors = {makeTensor({4, 128, 56, 56})}; - EXPECT_FALSE(config.matchOperation("conv_fprop", tensors).has_value()); -} - -// ── Test 5: all-wildcard rule matches any shape ───────────────────────────── - -TEST(TestEngineOverrideConfig, AllWildcardRuleMatchesAnyShape) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = FUSILLI_ENGINE_NAME; - rule.tensors = {makePattern({-1, -1, -1, -1})}; - - auto config = makeConfig({std::move(rule)}); - - for(const auto& shape : - std::vector>{{1, 3, 224, 224}, {8, 64, 56, 56}, {32, 256, 14, 14}}) - { - const std::vector> tensors = {makeTensor(shape)}; - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, FUSILLI_ENGINE_ID); - } -} - -// ── Test 6: wrong op name → nullopt ───────────────────────────────────────── - -TEST(TestEngineOverrideConfig, WrongOpNameReturnsNullopt) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = MIOPEN_ENGINE_NAME; - rule.tensors = {makePattern({1, 3, 224, 224})}; - - auto config = makeConfig({std::move(rule)}); - - const std::vector> tensors = {makeTensor({1, 3, 224, 224})}; - - EXPECT_FALSE(config.matchOperation("conv_dgrad", tensors).has_value()); - EXPECT_FALSE(config.matchOperation("conv_wgrad", tensors).has_value()); - EXPECT_FALSE(config.matchOperation("matmul", tensors).has_value()); -} - -// ── Test 7: wrong tensor count in rule → nullopt ──────────────────────────── - -TEST(TestEngineOverrideConfig, WrongTensorCountReturnsNullopt) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = MIOPEN_ENGINE_NAME; - rule.tensors = {makePattern({1, 3, 224, 224}), makePattern({64, 3, 7, 7})}; // 2 patterns - - auto config = makeConfig({std::move(rule)}); - - // Provide only 1 tensor where 2 are expected - const std::vector> tensors = {makeTensor({1, 3, 224, 224})}; - EXPECT_FALSE(config.matchOperation("conv_fprop", tensors).has_value()); - - // Provide 3 tensors where 2 are expected - const std::vector> tensors3 - = {makeTensor({1, 3, 224, 224}), makeTensor({64, 3, 7, 7}), makeTensor({64, 1, 1, 1})}; - EXPECT_FALSE(config.matchOperation("conv_fprop", tensors3).has_value()); -} - -// ── Tests 11–12: cross-partition ordering (exact vs wildcard) ─────────────── -// -// These tests verify that first-match-wins semantics are preserved when an -// exact rule and a wildcard rule sit in different partitions. - -// Test 11: wildcard declared before exact — wildcard must win -TEST(TestEngineOverrideConfig, WildcardBeforeExactBothMatch) -{ - OperationRule wildcard; - wildcard.op = "conv_fprop"; - wildcard.engineName = FUSILLI_ENGINE_NAME; - wildcard.tensors = {makePattern({-1, 3, 224, 224})}; // order 0, wildcard - - OperationRule exact; - exact.op = "conv_fprop"; - exact.engineName = HIPBLASLT_ENGINE_NAME; - exact.tensors = {makePattern({1, 3, 224, 224})}; // order 1, exact - - auto config = makeConfig({std::move(wildcard), std::move(exact)}); - - const std::vector> tensors = {makeTensor({1, 3, 224, 224})}; - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, FUSILLI_ENGINE_ID); // wildcard (order 0) beats exact (order 1) -} - -// Test 12: exact declared before wildcard — exact must win -TEST(TestEngineOverrideConfig, ExactBeforeWildcardBothMatch) -{ - OperationRule exact; - exact.op = "conv_fprop"; - exact.engineName = HIPBLASLT_ENGINE_NAME; - exact.tensors = {makePattern({1, 3, 224, 224})}; // order 0, exact - - OperationRule wildcard; - wildcard.op = "conv_fprop"; - wildcard.engineName = FUSILLI_ENGINE_NAME; - wildcard.tensors = {makePattern({-1, 3, 224, 224})}; // order 1, wildcard - - auto config = makeConfig({std::move(exact), std::move(wildcard)}); - - const std::vector> tensors = {makeTensor({1, 3, 224, 224})}; - auto result = config.matchOperation("conv_fprop", tensors); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); // exact (order 0) beats wildcard (order 1) -} - -// ── Stride matching tests ──────────────────────────────────────────────────── - -// Test 13: exact stride match selects the correct engine -TEST(TestEngineOverrideConfig, ExactStrideMatchSelectsEngine) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = MIOPEN_ENGINE_NAME; - rule.tensors = {makePatternWithStride({1, 3, 224, 224}, {150528, 50176, 224, 1})}; - - auto config = makeConfig({std::move(rule)}); - - auto matching = makeTensorWithStride({1, 3, 224, 224}, {150528, 50176, 224, 1}); - auto result = config.matchOperation("conv_fprop", {matching}); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, MIOPEN_ENGINE_ID); - - // Different stride must not match - auto wrongStride = makeTensorWithStride({1, 3, 224, 224}, - {1, 224, int64_t{224} * 3, int64_t{224} * 3 * 224}); - EXPECT_FALSE(config.matchOperation("conv_fprop", {wrongStride}).has_value()); -} - -// Test 14: wildcard stride element (-1) matches any value in that slot -TEST(TestEngineOverrideConfig, WildcardStrideElement) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = HIPBLASLT_ENGINE_NAME; - // Wildcard on last two stride slots - rule.tensors = {makePatternWithStride({1, 3, 224, 224}, {150528, 50176, -1, -1})}; - - auto config = makeConfig({std::move(rule)}); - - // Should match regardless of the last two stride values - for(const int64_t s2 : {224, 112, 56}) - { - auto t = makeTensorWithStride({1, 3, 224, 224}, {150528, 50176, s2, 1}); - auto result = config.matchOperation("conv_fprop", {t}); - ASSERT_TRUE(result.has_value()) << "stride[2]=" << s2; - EXPECT_EQ(*result, HIPBLASLT_ENGINE_ID); - } - - // First two stride slots must still match - auto wrongStride = makeTensorWithStride({1, 3, 224, 224}, {999, 50176, 224, 1}); - EXPECT_FALSE(config.matchOperation("conv_fprop", {wrongStride}).has_value()); -} - -// Test 15: empty stride in pattern matches any tensor stride (no constraint) -TEST(TestEngineOverrideConfig, EmptyStridePatternMatchesAnyStride) -{ - OperationRule rule; - rule.op = "conv_fprop"; - rule.engineName = FUSILLI_ENGINE_NAME; - rule.tensors = {makePattern({1, 3, 224, 224})}; // no stride field - - auto config = makeConfig({std::move(rule)}); - - // Should match regardless of stride - for(const auto& strides : std::vector>{ - {150528, 50176, 224, 1}, {1, 3, 672, 150528}, {999, 888, 777, 666}}) - { - auto t = makeTensorWithStride({1, 3, 224, 224}, strides); - auto result = config.matchOperation("conv_fprop", {t}); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, FUSILLI_ENGINE_ID); - } -} - -// ── Tests 8–10: JSON-dependent ────────────────────────────────────────────── - -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB - -// Test 8: load from valid JSON file → parses rules, matches correctly - -TEST(TestEngineOverrideConfig, LoadFromValidJsonFile) -{ - constexpr const char* CONTENTS = R"({ - "engine_overrides": [ - { - "comment": "test rule for ResNet first conv", - "op": "conv_fprop", - "engine_name": "MIOPEN_ENGINE", - "tensors": [ - { "dim": [1, 3, 224, 224] }, - { "dim": [64, 3, 7, 7] } - ] - }, - { - "comment": "wildcard catch-all", - "op": "conv_fprop", - "engine_name": "FUSILLI_ENGINE", - "tensors": [ - { "dim": [-1, -1, -1, -1] }, - { "dim": [-1, -1, -1, -1] } - ] - } - ] -})"; - - auto config = EngineOverrideConfig::loadFromContent(CONTENTS); - ASSERT_TRUE(config.has_value()); - - // Exact match hits the first rule - const std::vector> exact - = {makeTensor({1, 3, 224, 224}), makeTensor({64, 3, 7, 7})}; - auto r1 = config->matchOperation("conv_fprop", exact); - ASSERT_TRUE(r1.has_value()); - EXPECT_EQ(*r1, MIOPEN_ENGINE_ID); - - // Different shape falls through to the wildcard rule - const std::vector> other - = {makeTensor({8, 64, 56, 56}), makeTensor({64, 64, 3, 3})}; - auto r2 = config->matchOperation("conv_fprop", other); - ASSERT_TRUE(r2.has_value()); - EXPECT_EQ(*r2, FUSILLI_ENGINE_ID); -} - -// Test 9: load from missing file → nullopt, no crash - -TEST(TestEngineOverrideConfig, LoadFromMissingFileReturnsNullopt) -{ - auto config = EngineOverrideConfig::load("/nonexistent/path/hipdnn_no_such_file.json"); - EXPECT_FALSE(config.has_value()); -} - -// Test 10: HIPDNN_ENGINE_OVERRIDE_FILE unset → loadFromEnv() returns nullptr - -TEST(TestEngineOverrideConfig, EnvVarUnsetReturnsNullptr) -{ - // HIPDNN_ENGINE_OVERRIDE_FILE is not set in the unit-test environment. - // loadFromEnv() caches on first call, so this also verifies the pointer - // is stable across repeated calls. - const auto* config = EngineOverrideConfig::loadFromEnv(); - EXPECT_EQ(config, nullptr); - EXPECT_EQ(EngineOverrideConfig::loadFromEnv(), config); // same cached pointer -} - -// Test 16: JSON with stride constraint is parsed and matched correctly -TEST(TestEngineOverrideConfig, JsonWithStrideConstraint) -{ - constexpr const char* CONTENTS = R"({ - "engine_overrides": [ - { - "op": "conv_fprop", - "engine_name": "MIOPEN_ENGINE", - "tensors": [ - { "dim": [1, 3, 224, 224], "stride": [150528, 50176, 224, 1] }, - { "dim": [64, 3, 7, 7] } - ] - } - ] -})"; - - auto config = EngineOverrideConfig::loadFromContent(CONTENTS); - ASSERT_TRUE(config.has_value()); - - auto x = makeTensorWithStride({1, 3, 224, 224}, {150528, 50176, 224, 1}); - auto w = makeTensor({64, 3, 7, 7}); - - auto r1 = config->matchOperation("conv_fprop", {x, w}); - ASSERT_TRUE(r1.has_value()); - EXPECT_EQ(*r1, MIOPEN_ENGINE_ID); - - // Wrong stride must not match - auto xWrong = makeTensorWithStride({1, 3, 224, 224}, - {1, 224, int64_t{224} * 3, int64_t{224} * 3 * 224}); - EXPECT_FALSE(config->matchOperation("conv_fprop", {xWrong, w}).has_value()); -} - -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB diff --git a/projects/hipdnn/frontend/tests/TestGraph.cpp b/projects/hipdnn/frontend/tests/TestGraph.cpp index b88411c8f63..ff7141659a7 100644 --- a/projects/hipdnn/frontend/tests/TestGraph.cpp +++ b/projects/hipdnn/frontend/tests/TestGraph.cpp @@ -1604,378 +1604,6 @@ TEST_F(TestGraph, CanSuccessfullyCreateExecutionPlans) EXPECT_TRUE(execPlanResult.is_good()); } -TEST_F(TestGraph, PreferredEngineIdSelectsSpecificConfig) -{ - ::testing::FLAGS_gmock_verbose = "error"; - Graph graph; - const std::vector heurModes = {HeuristicMode::FALLBACK}; - std::vector backendModes; - backendModes.reserve(heurModes.size()); - for(const auto& mode : heurModes) - { - backendModes.push_back(toBackendType(mode)); - } - auto tensorAttributes = createBasicBatchnormGraph(graph); - ASSERT_TRUE(graph.validate().is_good()); - - // Set preferred engine ID - const int64_t preferredEngineId = 42; - graph.set_preferred_engine_id_ext(preferredEngineId); - - graph.build_operation_graph(_handle); - - auto heurDesc = reinterpret_cast(0x5678); - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, _)) - .WillOnce( - [&heurDesc](hipdnnBackendDescriptorType_t, hipdnnBackendDescriptor_t* descriptor) { - *descriptor = heurDesc; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL( - *_mockBackend, - backendSetAttribute( - heurDesc, HIPDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH, HIPDNN_TYPE_BACKEND_DESCRIPTOR, 1, _)) - .WillOnce(Return(HIPDNN_STATUS_SUCCESS)); - - EXPECT_CALL( - *_mockBackend, - backendSetAttribute(heurDesc, HIPDNN_ATTR_ENGINEHEUR_MODE, HIPDNN_TYPE_HEUR_MODE, 1, _)) - .WillOnce([&backendModes](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t count, - const void* arrayOfElements) { - EXPECT_EQ(count, static_cast(backendModes.size())); - auto modesPtr = static_cast(arrayOfElements); - for(size_t i = 0; i < backendModes.size(); ++i) - { - EXPECT_EQ(modesPtr[i], backendModes[i]); - } - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(heurDesc)); - - // First call: elementCount query - return 2 configs available - EXPECT_CALL(*_mockBackend, - backendGetAttribute(heurDesc, - HIPDNN_ATTR_ENGINEHEUR_RESULTS, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 0, - _, - nullptr)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t* elementCount, - void*) { - *elementCount = 2; - return HIPDNN_STATUS_SUCCESS; - }); - - auto engineConfigDesc1 = reinterpret_cast(0x2345); - auto engineConfigDesc2 = reinterpret_cast(0x2346); - auto engineDesc1 = reinterpret_cast(0x3345); - auto engineDesc2 = reinterpret_cast(0x3346); - - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_ENGINECFG_DESCRIPTOR, _)) - .WillOnce([&engineConfigDesc1](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = engineConfigDesc1; - return HIPDNN_STATUS_SUCCESS; - }) - .WillOnce([&engineConfigDesc2](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = engineConfigDesc2; - return HIPDNN_STATUS_SUCCESS; - }); - - // Second call: actual data retrieval - EXPECT_CALL(*_mockBackend, - backendGetAttribute(heurDesc, - HIPDNN_ATTR_ENGINEHEUR_RESULTS, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 2, - _, - NotNull())) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t* retrievedCount, - void*) { - *retrievedCount = 2; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(engineConfigDesc1)); - - // Get engine from first config (ID = 10) - EXPECT_CALL(*_mockBackend, - backendGetAttribute(engineConfigDesc1, - HIPDNN_ATTR_ENGINECFG_ENGINE, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 1, - nullptr, - _)) - .WillOnce([&engineDesc1](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = engineDesc1; - return HIPDNN_STATUS_SUCCESS; - }); - - // Get ID from first engine - EXPECT_CALL(*_mockBackend, - backendGetAttribute( - engineDesc1, HIPDNN_ATTR_ENGINE_GLOBAL_INDEX, HIPDNN_TYPE_INT64, 1, nullptr, _)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = 10; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(engineConfigDesc2)); - - // Get engine from second config (ID = 42 - our preferred one) - EXPECT_CALL(*_mockBackend, - backendGetAttribute(engineConfigDesc2, - HIPDNN_ATTR_ENGINECFG_ENGINE, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 1, - nullptr, - _)) - .WillOnce([&engineDesc2](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = engineDesc2; - return HIPDNN_STATUS_SUCCESS; - }); - - // Get ID from second engine - EXPECT_CALL(*_mockBackend, - backendGetAttribute( - engineDesc2, HIPDNN_ATTR_ENGINE_GLOBAL_INDEX, HIPDNN_TYPE_INT64, 1, nullptr, _)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = preferredEngineId; - return HIPDNN_STATUS_SUCCESS; - }); - - auto executionPlanDesc = reinterpret_cast(0x9876); - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, _)) - .WillOnce([&executionPlanDesc](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = executionPlanDesc; - return HIPDNN_STATUS_SUCCESS; - }); - - auto execPlanResult = graph.create_execution_plans(heurModes); - EXPECT_TRUE(execPlanResult.is_good()); -} - -TEST_F(TestGraph, PreferredEngineIdFallsBackToTopConfig) -{ - ::testing::FLAGS_gmock_verbose = "error"; - Graph graph; - const std::vector heurModes = {HeuristicMode::FALLBACK}; - std::vector backendModes; - backendModes.reserve(heurModes.size()); - for(const auto& mode : heurModes) - { - backendModes.push_back(toBackendType(mode)); - } - auto tensorAttributes = createBasicBatchnormGraph(graph); - ASSERT_TRUE(graph.validate().is_good()); - - // Set preferred engine ID that doesn't exist - const int64_t preferredEngineId = 999; - graph.set_preferred_engine_id_ext(preferredEngineId); - - graph.build_operation_graph(_handle); - - auto heurDesc = reinterpret_cast(0x5678); - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, _)) - .WillOnce( - [&heurDesc](hipdnnBackendDescriptorType_t, hipdnnBackendDescriptor_t* descriptor) { - *descriptor = heurDesc; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL( - *_mockBackend, - backendSetAttribute( - heurDesc, HIPDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH, HIPDNN_TYPE_BACKEND_DESCRIPTOR, 1, _)) - .WillOnce(Return(HIPDNN_STATUS_SUCCESS)); - - EXPECT_CALL( - *_mockBackend, - backendSetAttribute(heurDesc, HIPDNN_ATTR_ENGINEHEUR_MODE, HIPDNN_TYPE_HEUR_MODE, 1, _)) - .WillOnce([&backendModes](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t count, - const void* arrayOfElements) { - EXPECT_EQ(count, static_cast(backendModes.size())); - auto modesPtr = static_cast(arrayOfElements); - for(size_t i = 0; i < backendModes.size(); ++i) - { - EXPECT_EQ(modesPtr[i], backendModes[i]); - } - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(heurDesc)); - - // First call: elementCount query - return 2 configs available - EXPECT_CALL(*_mockBackend, - backendGetAttribute(heurDesc, - HIPDNN_ATTR_ENGINEHEUR_RESULTS, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 0, - _, - nullptr)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t* elementCount, - void*) { - *elementCount = 2; - return HIPDNN_STATUS_SUCCESS; - }); - - auto engineConfigDesc1 = reinterpret_cast(0x2345); - auto engineConfigDesc2 = reinterpret_cast(0x2346); - auto engineDesc1 = reinterpret_cast(0x3345); - auto engineDesc2 = reinterpret_cast(0x3346); - - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_ENGINECFG_DESCRIPTOR, _)) - .WillOnce([&engineConfigDesc1](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = engineConfigDesc1; - return HIPDNN_STATUS_SUCCESS; - }) - .WillOnce([&engineConfigDesc2](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = engineConfigDesc2; - return HIPDNN_STATUS_SUCCESS; - }); - - // Second call: actual data retrieval - EXPECT_CALL(*_mockBackend, - backendGetAttribute(heurDesc, - HIPDNN_ATTR_ENGINEHEUR_RESULTS, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 2, - _, - NotNull())) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t* retrievedCount, - void*) { - *retrievedCount = 2; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(engineConfigDesc1)); - - // Get engine from first config (ID = 10) - EXPECT_CALL(*_mockBackend, - backendGetAttribute(engineConfigDesc1, - HIPDNN_ATTR_ENGINECFG_ENGINE, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 1, - nullptr, - _)) - .WillOnce([&engineDesc1](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = engineDesc1; - return HIPDNN_STATUS_SUCCESS; - }); - - // Get ID from first engine (neither will match preferred ID 999) - EXPECT_CALL(*_mockBackend, - backendGetAttribute( - engineDesc1, HIPDNN_ATTR_ENGINE_GLOBAL_INDEX, HIPDNN_TYPE_INT64, 1, nullptr, _)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = 10; - return HIPDNN_STATUS_SUCCESS; - }); - - EXPECT_CALL(*_mockBackend, backendFinalize(engineConfigDesc2)); - - // Get engine from second config (ID = 42) - EXPECT_CALL(*_mockBackend, - backendGetAttribute(engineConfigDesc2, - HIPDNN_ATTR_ENGINECFG_ENGINE, - HIPDNN_TYPE_BACKEND_DESCRIPTOR, - 1, - nullptr, - _)) - .WillOnce([&engineDesc2](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = engineDesc2; - return HIPDNN_STATUS_SUCCESS; - }); - - // Get ID from second engine - EXPECT_CALL(*_mockBackend, - backendGetAttribute( - engineDesc2, HIPDNN_ATTR_ENGINE_GLOBAL_INDEX, HIPDNN_TYPE_INT64, 1, nullptr, _)) - .WillOnce([](hipdnnBackendDescriptor_t, - hipdnnBackendAttributeName_t, - hipdnnBackendAttributeType_t, - int64_t, - int64_t*, - void* arrayOfElements) { - *static_cast(arrayOfElements) = 42; - return HIPDNN_STATUS_SUCCESS; - }); - - auto executionPlanDesc = reinterpret_cast(0x9876); - EXPECT_CALL(*_mockBackend, backendCreateDescriptor(HIPDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, _)) - .WillOnce([&executionPlanDesc](hipdnnBackendDescriptorType_t, - hipdnnBackendDescriptor_t* descriptor) { - *descriptor = executionPlanDesc; - return HIPDNN_STATUS_SUCCESS; - }); - - auto execPlanResult = graph.create_execution_plans(heurModes); - EXPECT_TRUE(execPlanResult.is_good()); -} - TEST_F(TestGraph, CheckSupportFailsIfNoExecutionPlanCreated) { ::testing::FLAGS_gmock_verbose = "error"; @@ -5320,87 +4948,6 @@ TEST_F(TestGraph, MoveAssignmentToEmptyGraph) EXPECT_EQ(targetGraph.get_compute_data_type(), DataType::FLOAT); } -// ── Engine Override Config integration ─────────────────────────────────────── - -// Test: EngineOverrideConfig::matchOperation identifies conv_fprop tensors -// with the same dims that getPreferredIdFromOverrideConfig() would pass to -// checkEngineOverride() at build time. -TEST_F(TestGraph, EngineOverrideConfigMatchesConvFpropTensors) -{ - using namespace hipdnn_frontend::engine_override; - - // Tensors with conv_fprop dims (x={1,3,32,32}, w={64,3,3,3}) - auto x = std::make_shared(); - x->set_dim({1, 3, 32, 32}).set_stride({3072, 1024, 32, 1}).set_data_type(DataType::FLOAT); - - auto w = std::make_shared(); - w->set_dim({64, 3, 3, 3}).set_stride({27, 9, 3, 1}).set_data_type(DataType::FLOAT); - - // Exact rule for this shape - OperationRule exactRule; - exactRule.op = "conv_fprop"; - exactRule.engineName = hipdnn_data_sdk::utilities::HIPBLASLT_ENGINE_NAME; - exactRule.tensors = {TensorPattern{{1, 3, 32, 32}, {}}, TensorPattern{{64, 3, 3, 3}, {}}}; - - const EngineOverrideConfig config({std::move(exactRule)}); - - auto result = config.matchOperation("conv_fprop", {x, w}); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, hipdnn_data_sdk::utilities::HIPBLASLT_ENGINE_ID); - - // Wrong op must not match - EXPECT_FALSE(config.matchOperation("conv_dgrad", {x, w}).has_value()); - - // Different batch size must not match (no wildcard in rule) - auto x8 = std::make_shared(); - x8->set_dim({8, 3, 32, 32}).set_data_type(DataType::FLOAT); - EXPECT_FALSE(config.matchOperation("conv_fprop", {x8, w}).has_value()); -} - -// Test 3: loading a JSON config from an in-memory string and matching against -// conv_fprop tensors. This exercises the full loadFromContent() → matchOperation() -// path with the same shapes that the graph presents during build_operation_graph(). -#ifndef HIPDNN_FRONTEND_SKIP_JSON_LIB -TEST_F(TestGraph, EngineOverrideConfigFromContentMatchesConvFpropGraph) -{ - using namespace hipdnn_frontend::engine_override; - - const int64_t kEngine = hipdnn_data_sdk::utilities::MIOPEN_ENGINE_ID; - - const std::string kJson = R"({ - "engine_overrides": [ - { - "op": "conv_fprop", - "engine_name": "MIOPEN_ENGINE", - "tensors": [ - { "dim": [1, 3, 32, 32] }, - { "dim": [64, 3, 3, 3] } - ] - } - ] -})"; - - auto config = EngineOverrideConfig::loadFromContent(kJson); - ASSERT_TRUE(config.has_value()); - - // Same conv_fprop tensor dims - auto x = std::make_shared(); - x->set_dim({1, 3, 32, 32}).set_stride({3072, 1024, 32, 1}).set_data_type(DataType::FLOAT); - - auto w = std::make_shared(); - w->set_dim({64, 3, 3, 3}).set_stride({27, 9, 3, 1}).set_data_type(DataType::FLOAT); - - auto result = config->matchOperation("conv_fprop", {x, w}); - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(*result, kEngine); - - // A different batch size must not match the exact rule - auto x8 = std::make_shared(); - x8->set_dim({8, 3, 32, 32}).set_data_type(DataType::FLOAT); - EXPECT_FALSE(config->matchOperation("conv_fprop", {x8, w}).has_value()); -} -#endif // HIPDNN_FRONTEND_SKIP_JSON_LIB - #ifdef HIPDNN_ENABLE_SDPA TEST_F(TestGraph, SdpaFwdNodeCreation) { diff --git a/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicValidation.hpp b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicValidation.hpp new file mode 100644 index 00000000000..add93ef8ca7 --- /dev/null +++ b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicValidation.hpp @@ -0,0 +1,73 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +/** + * @file HeuristicValidation.hpp + * @brief Argument validation macros shared by heuristic plugin C ABI entry points. + * + * Heuristic plugins repeat the same null-pointer / serialized-buffer / array + * checks at every C ABI entry point. The macros below collapse the three + * recurring patterns into one-liners that log through the plugin's existing + * logging macro and return @c HIPDNN_PLUGIN_STATUS_BAD_PARAM on failure. + * + * Each macro takes the per-plugin log macro as an argument so error messages + * keep the plugin's own prefix and routing — there is no global logging + * dependency in this header. + */ + +// NOLINTBEGIN(bugprone-macro-parentheses) + +/// Reject a null pointer at a C ABI entry point. +/// +/// @param ptr Pointer to validate. Must be a plain identifier or expression. +/// @param log Per-plugin logging macro (e.g. CONFIG_LOG). +/// @param msg String literal logged at HIPDNN_SEV_ERROR when @p ptr is null. +#define HIPDNN_PLUGIN_REQUIRE_NOT_NULL(ptr, log, msg) \ + do \ + { \ + if((ptr) == nullptr) \ + { \ + log(HIPDNN_SEV_ERROR, msg); \ + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; \ + } \ + } while(0) + +/// Reject a malformed @c hipdnnPluginConstData_t* argument. +/// +/// @param data Pointer to the const-data struct. +/// @param requireSize When true, also rejects buffers with @c size == 0. +/// @param log Per-plugin logging macro. +/// @param msg String literal logged at HIPDNN_SEV_ERROR on failure. +#define HIPDNN_PLUGIN_REQUIRE_CONST_DATA(data, requireSize, log, msg) \ + do \ + { \ + if((data) == nullptr || (data)->ptr == nullptr || ((requireSize) && (data)->size == 0)) \ + { \ + log(HIPDNN_SEV_ERROR, msg); \ + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; \ + } \ + } while(0) + +/// Reject a (pointer, count) pair that claims a non-empty array but supplies +/// a null pointer. +/// +/// @param ptr Array pointer. +/// @param count Element count claimed by the caller. +/// @param log Per-plugin logging macro. +/// @param msg String literal logged at HIPDNN_SEV_ERROR on failure. +#define HIPDNN_PLUGIN_REQUIRE_ARRAY(ptr, count, log, msg) \ + do \ + { \ + if((ptr) == nullptr && (count) > 0) \ + { \ + log(HIPDNN_SEV_ERROR, msg); \ + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; \ + } \ + } while(0) + +// NOLINTEND(bugprone-macro-parentheses) diff --git a/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicsPluginApi.h b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicsPluginApi.h index 432d968cc3f..b489fca300a 100644 --- a/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicsPluginApi.h +++ b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/HeuristicsPluginApi.h @@ -17,7 +17,7 @@ * policy plugins that control engine ordering. * * IMPORTANT: Heuristic plugins must implement ALL base plugin functions from PluginApi.h: - * - hipdnnPluginGetName - Returns the plugin name (e.g., "StaticOrdering") + * - hipdnnPluginGetName - Returns the plugin (shared library) name, used for diagnostics * - hipdnnPluginGetVersion - Returns the plugin implementation version * - hipdnnPluginGetApiVersion - Returns the API version this plugin supports * - hipdnnPluginGetType - Returns HIPDNN_PLUGIN_TYPE_HEURISTIC @@ -27,6 +27,13 @@ * * PLUS the heuristic-specific functions defined below. * + * Multi-policy plugins: A single heuristic plugin shared library may expose one or more + * selection policies. Each policy is identified by a stable int64 policy ID (typically + * derived from a canonical policy name via hipdnn_data_sdk::utilities::policyNameToId). + * The host enumerates policies via hipdnnHeuristicPluginGetAllPolicyIds and resolves + * names via hipdnnHeuristicPluginGetPolicyName. A single plugin handle is shared across + * all policies of the same library; per-policy state lives in the policy descriptor. + * * Status codes: Use hipdnnPluginStatus_t for all return values. * Serialized data: Device properties and graphs cross the ABI as hipdnnPluginConstData_t*. */ @@ -79,14 +86,57 @@ typedef struct hipdnnHeuristicPolicyDescriptor_opaque* hipdnnHeuristicPolicyDesc /** @} */ // End of HeuristicPluginDataTypes group /** - * @defgroup HeuristicPluginExtensions Heuristic Plugin Extensions - * @brief Heuristic-specific functions beyond the base PluginApi.h. - * - * See the file comment for the complete list of required base plugin functions. + * @defgroup HeuristicPluginPolicyEnumeration Heuristic Plugin Policy Enumeration + * @brief Functions for discovering the set of policies a plugin exposes. * @{ */ -/** @} */ // End of HeuristicPluginExtensions group +/** + * @brief Retrieves the IDs of all selection policies the plugin exposes. + * + * A heuristic plugin may expose one or more policies; each is identified by a + * stable int64 policy ID (typically the FNV-1a hash of a canonical policy + * name; see hipdnn_data_sdk::utilities::policyNameToId). + * + * This function follows the same two-pass query/retrieve pattern as + * hipdnnEnginePluginGetAllEngineIds: + * 1. Pass max_policies = 0 (and policy_ids may be NULL) to discover the + * total count, written to *num_policies. + * 2. Allocate an array of that size and pass max_policies = N to fill it; + * *num_policies is set to the number of IDs actually written. + * + * @param[out] policy_ids Array to receive policy IDs, or NULL when querying count. + * @param[in] max_policies Capacity of the policy_ids array; 0 to query count only. + * @param[out] num_policies On count query: total available policies. + * On retrieve: number of IDs written. + * + * @return HIPDNN_PLUGIN_STATUS_SUCCESS on success, error code otherwise. + * + * @note Policy IDs must be unique within a plugin and stable for the lifetime + * of the loaded library. + */ +HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t + hipdnnHeuristicPluginGetAllPolicyIds(int64_t* policy_ids, + uint32_t max_policies, + uint32_t* num_policies); + +/** + * @brief Retrieves the canonical name of a specific policy. + * + * The host validates that policyNameToId(name) == policy_id; mismatches cause + * the plugin to be rejected at load time. + * + * @param[in] policy_id The policy ID (must come from hipdnnHeuristicPluginGetAllPolicyIds). + * @param[out] name Pointer to receive a NUL-terminated string owned by the plugin. + * Must remain valid for the lifetime of the loaded library. + * + * @return HIPDNN_PLUGIN_STATUS_SUCCESS on success, + * HIPDNN_PLUGIN_STATUS_BAD_PARAM if policy_id is not exposed by this plugin. + */ +HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t + hipdnnHeuristicPluginGetPolicyName(int64_t policy_id, const char** name); + +/** @} */ // End of HeuristicPluginPolicyEnumeration group /** * @defgroup HeuristicPluginHandleLifecycle Heuristic Plugin Handle Lifecycle @@ -157,11 +207,12 @@ HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t */ /** - * @brief Creates a new policy descriptor. + * @brief Creates a new policy descriptor for a specific policy of the plugin. * * The host calls this once per policy slot in EngineHeuristicDescriptor, binding the - * descriptor to the given plugin handle. The descriptor stores per-slot state: - * candidate engine IDs, serialized graph, and finalize result. + * descriptor to the given plugin handle and selecting which policy from the plugin + * the descriptor will execute. The descriptor stores per-slot state: candidate + * engine IDs, serialized graph, and finalize result. * * This BINDS the policy to the handle BEFORE Finalize, so selection code can treat the * handle as the source of device-properties session state (SetDeviceProperties). @@ -170,12 +221,16 @@ HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t * destroyed with the descriptor. * * @param[in] plugin_handle The plugin handle this descriptor is bound to. + * @param[in] policy_id The ID of the policy this descriptor will execute. Must be one + * of the IDs returned by hipdnnHeuristicPluginGetAllPolicyIds. * @param[out] out_desc Pointer to receive the created policy descriptor. * - * @return HIPDNN_PLUGIN_STATUS_SUCCESS on success, error code otherwise. + * @return HIPDNN_PLUGIN_STATUS_SUCCESS on success, + * HIPDNN_PLUGIN_STATUS_BAD_PARAM if policy_id is not exposed by this plugin. */ HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t hipdnnHeuristicPolicyDescriptorCreate(hipdnnHeuristicHandle_t plugin_handle, + int64_t policy_id, hipdnnHeuristicPolicyDescriptor_t* out_desc); /** @@ -279,15 +334,25 @@ HIPDNN_PLUGIN_NODISCARD HIPDNN_HEURISTIC_PLUGIN_EXPORT hipdnnPluginStatus_t * The output IDs must be a permutation or subset of the input IDs from SetEngineIds. * The host validates this constraint. * - * This function supports two usage patterns: - * 1. Query count: Pass engine_ids = nullptr to get the count in num_engines - * 2. Retrieve IDs: Pass non-null engine_ids and capacity in *num_engines, - * receive actual count in *num_engines + * Callers MUST use the two-call pattern: + * 1. Query count: Pass engine_ids = NULL; on return *num_engines holds the + * total number of IDs the policy will produce. + * 2. Retrieve IDs: Allocate an array of that exact size, set *num_engines to + * that capacity, and call again with engine_ids pointing at the array. + * On return *num_engines holds the number of IDs actually written. + * + * If the caller supplies a non-NULL engine_ids with a capacity smaller than the + * policy's full result, the implementation silently truncates: it writes + * min(*num_engines, total) IDs and sets *num_engines to that truncated count. + * The return value is still HIPDNN_PLUGIN_STATUS_SUCCESS, so the caller cannot + * distinguish "buffer was exactly right" from "buffer was too small" without + * having queried the count first. Always query first. * * @param[in] desc The policy descriptor. - * @param[out] engine_ids Array to receive the sorted engine IDs, or nullptr to query count. - * @param[in,out] num_engines Input: capacity of engine_ids array (ignored if engine_ids is null). - * Output: number of IDs available/written. + * @param[out] engine_ids Array to receive the sorted engine IDs, or NULL to query count. + * @param[in,out] num_engines Input: capacity of engine_ids array (ignored if engine_ids is NULL). + * Output: number of IDs available (count query) or written + * (retrieve, possibly truncated to the input capacity). * * @return HIPDNN_PLUGIN_STATUS_SUCCESS on success, * HIPDNN_PLUGIN_STATUS_NOT_INITIALIZED if descriptor not finalized, diff --git a/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/heuristic_api_version.h b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/heuristic_api_version.h index 87f804eeb6f..09bf4866e7a 100644 --- a/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/heuristic_api_version.h +++ b/projects/hipdnn/plugin_sdk/include/hipdnn_plugin_sdk/heuristic_api_version.h @@ -5,7 +5,7 @@ /** * @file heuristic_api_version.h - * @brief Version constants for the heuristic plugin C ABI (RFC 0007) + * @brief Version constants for the heuristic plugin C ABI * * The heuristic plugin API has its own versioning scheme, independent of the * backend library version. This allows the heuristic plugin interface to evolve diff --git a/projects/hipdnn/tests/backend/CMakeLists.txt b/projects/hipdnn/tests/backend/CMakeLists.txt index 0ce6fad9745..0441b42d465 100644 --- a/projects/hipdnn/tests/backend/CMakeLists.txt +++ b/projects/hipdnn/tests/backend/CMakeLists.txt @@ -62,6 +62,7 @@ target_compile_definitions( TEST_KNOBS_PLUGIN_NAME="${TEST_KNOBS_PLUGIN_NAME}" TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME="${TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME}" TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME="${TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME}" + TEST_GOOD_HEURISTIC_PLUGIN_NAME="${TEST_GOOD_HEURISTIC_PLUGIN_NAME}" ) # Ensure test plugins are built before tests @@ -78,6 +79,7 @@ add_dependencies( test_knobs_plugin test_knob_constraint_validation_plugin test_incompatible_version_plugin + test_good_heuristic_plugin ) clang_tidy_check(hipdnn_public_backend_tests) diff --git a/projects/hipdnn/tests/backend/IntegrationEngineHeuristicApi.cpp b/projects/hipdnn/tests/backend/IntegrationEngineHeuristicApi.cpp index 22af60d18a8..a0e4ba28a5b 100644 --- a/projects/hipdnn/tests/backend/IntegrationEngineHeuristicApi.cpp +++ b/projects/hipdnn/tests/backend/IntegrationEngineHeuristicApi.cpp @@ -3,6 +3,9 @@ #include "TestUtil.hpp" #include "hipdnn_backend.h" +#include +#include +#include #include #include @@ -21,16 +24,33 @@ class IntegrationEngineHeuristicApi : public ::testing::Test hipdnnBackendDescriptor_t _engineHeuristic = nullptr; hipdnnBackendDescriptor_t _graph = nullptr; hipdnnHandle_t _handle = nullptr; + hipStream_t _stream = nullptr; + std::optional _policyOrderEnv; void SetUp() override { - const std::array paths + // finalize() resolves the device from the handle's stream, so the + // fixture binds a real stream and skips when no devices are present. + SKIP_IF_NO_DEVICES(); + + const std::array enginePaths = {hipdnn_tests::plugin_constants::testGoodPluginPath().c_str()}; ASSERT_EQ(hipdnnSetEnginePluginPaths_ext( - paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + enginePaths.size(), enginePaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), HIPDNN_STATUS_SUCCESS); + _policyOrderEnv.emplace("HIPDNN_HEUR_POLICY_ORDER", + hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + ASSERT_EQ(hipdnnCreate(&_handle), HIPDNN_STATUS_SUCCESS); + ASSERT_EQ(hipStreamCreate(&_stream), hipSuccess); + ASSERT_EQ(hipdnnSetStream(_handle, _stream), HIPDNN_STATUS_SUCCESS); EXPECT_EQ( hipdnnBackendCreateDescriptor(HIPDNN_BACKEND_ENGINEHEUR_DESCRIPTOR, &_engineHeuristic), HIPDNN_STATUS_SUCCESS); @@ -52,6 +72,11 @@ class IntegrationEngineHeuristicApi : public ::testing::Test EXPECT_EQ(hipdnnDestroy(_handle), HIPDNN_STATUS_SUCCESS); _handle = nullptr; } + if(_stream != nullptr) + { + EXPECT_EQ(hipStreamDestroy(_stream), hipSuccess); + _stream = nullptr; + } } void setHeuristicMode() diff --git a/projects/hipdnn/tests/backend/IntegrationLoggingPipeline.cpp b/projects/hipdnn/tests/backend/IntegrationLoggingPipeline.cpp index dee77584d6c..725fb670de5 100644 --- a/projects/hipdnn/tests/backend/IntegrationLoggingPipeline.cpp +++ b/projects/hipdnn/tests/backend/IntegrationLoggingPipeline.cpp @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include namespace fs = std::filesystem; @@ -182,6 +184,14 @@ TEST_F(IntegrationGpuLoggingPipeline, FullWorkflowLogging) { SKIP_IF_NO_DEVICES(); + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter policyEnv( + "HIPDNN_HEUR_POLICY_ORDER", hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + hipdnnHandle_t handle = nullptr; ASSERT_EQ(hipdnnCreate(&handle), HIPDNN_STATUS_SUCCESS); diff --git a/projects/hipdnn/tests/backend/IntegrationPluginLoading.cpp b/projects/hipdnn/tests/backend/IntegrationPluginLoading.cpp index 8e9f03eda60..d4a8559f067 100644 --- a/projects/hipdnn/tests/backend/IntegrationPluginLoading.cpp +++ b/projects/hipdnn/tests/backend/IntegrationPluginLoading.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -32,9 +33,21 @@ class IntegrationPluginLoading : public ::testing::Test hipdnnBackendDescriptor_t _graph = nullptr; hipdnnBackendDescriptor_t _heuristicDescriptor = nullptr; hipdnnHandle_t _handle = nullptr; + hipStream_t _stream = nullptr; void SetUp() override {} + // Bind a real stream to the handle. Required for tests that finalize a + // heuristic descriptor with a non-empty applicable-engine list, since + // EngineHeuristicDescriptor::finalize() resolves the device through + // hipStreamGetDevice(handle->getStream(), ...). Caller must invoke + // SKIP_IF_NO_DEVICES() before this so the test skips on no-GPU runners. + void bindStream() + { + ASSERT_EQ(hipStreamCreate(&_stream), hipSuccess); + ASSERT_EQ(hipdnnSetStream(_handle, _stream), HIPDNN_STATUS_SUCCESS); + } + void TearDown() override { if(_engineConfig != nullptr) @@ -62,6 +75,11 @@ class IntegrationPluginLoading : public ::testing::Test EXPECT_EQ(hipdnnDestroy(_handle), HIPDNN_STATUS_SUCCESS); _handle = nullptr; } + if(_stream != nullptr) + { + EXPECT_EQ(hipStreamDestroy(_stream), hipSuccess); + _stream = nullptr; + } } }; @@ -262,6 +280,8 @@ TEST_F(IntegrationPluginLoading, MultiplePluginsNoApplicableEngines) TEST_F(IntegrationPluginLoading, MultiplePluginsOneApplicableEngine) { + SKIP_IF_NO_DEVICES(); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter envSetter( "HIPDNN_PLUGIN_DIR", getTestPluginDefaultDir()); @@ -271,7 +291,16 @@ TEST_F(IntegrationPluginLoading, MultiplePluginsOneApplicableEngine) hipdnnSetEnginePluginPaths_ext(paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ADDITIVE), HIPDNN_STATUS_SUCCESS); + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter policyEnv( + "HIPDNN_HEUR_POLICY_ORDER", hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + ASSERT_EQ(hipdnnCreate(&_handle), HIPDNN_STATUS_SUCCESS); + bindStream(); EXPECT_EQ(hipdnnBackendCreateDescriptor(HIPDNN_BACKEND_ENGINECFG_DESCRIPTOR, &_engineConfig), HIPDNN_STATUS_SUCCESS); ASSERT_NE(_engineConfig, nullptr); @@ -295,6 +324,7 @@ TEST_F(IntegrationPluginLoading, MultiplePluginsOneApplicableEngine) TEST_F(IntegrationPluginLoading, MultiplePluginsMultipleApplicableEngines) { + SKIP_IF_NO_DEVICES(); const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter envSetter( "HIPDNN_PLUGIN_DIR", getTestPluginDefaultDir()); @@ -305,7 +335,16 @@ TEST_F(IntegrationPluginLoading, MultiplePluginsMultipleApplicableEngines) hipdnnSetEnginePluginPaths_ext(paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ADDITIVE), HIPDNN_STATUS_SUCCESS); + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + const hipdnn_test_sdk::utilities::ScopedEnvironmentVariableSetter policyEnv( + "HIPDNN_HEUR_POLICY_ORDER", hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + ASSERT_EQ(hipdnnCreate(&_handle), HIPDNN_STATUS_SUCCESS); + bindStream(); EXPECT_EQ(hipdnnBackendCreateDescriptor(HIPDNN_BACKEND_ENGINECFG_DESCRIPTOR, &_engineConfig), HIPDNN_STATUS_SUCCESS); ASSERT_NE(_engineConfig, nullptr); diff --git a/projects/hipdnn/tests/frontend/CMakeLists.txt b/projects/hipdnn/tests/frontend/CMakeLists.txt index 28101a4d7c2..be9a968a3f8 100644 --- a/projects/hipdnn/tests/frontend/CMakeLists.txt +++ b/projects/hipdnn/tests/frontend/CMakeLists.txt @@ -93,6 +93,7 @@ target_compile_definitions( TEST_KNOBS_PLUGIN_NAME="${TEST_KNOBS_PLUGIN_NAME}" TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME="${TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME}" TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME="${TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME}" + TEST_GOOD_HEURISTIC_PLUGIN_NAME="${TEST_GOOD_HEURISTIC_PLUGIN_NAME}" ) # Ensure test plugins are built before tests @@ -108,6 +109,7 @@ add_dependencies( test_incomplete_api_plugin test_knobs_plugin test_knob_constraint_validation_plugin + test_good_heuristic_plugin ) clang_tidy_check(hipdnn_public_frontend_tests) diff --git a/projects/hipdnn/tests/frontend/IntegrationGraphEngineFiltering.cpp b/projects/hipdnn/tests/frontend/IntegrationGraphEngineFiltering.cpp index cdebadc4a47..3ac1dbfaf0c 100644 --- a/projects/hipdnn/tests/frontend/IntegrationGraphEngineFiltering.cpp +++ b/projects/hipdnn/tests/frontend/IntegrationGraphEngineFiltering.cpp @@ -1,11 +1,15 @@ // Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include +#include #include #include +#include #include #include +#include #include #include @@ -66,6 +70,33 @@ class IntegrationGraphEngineFiltering : public ::testing::TestWithParam yTensor; }; + // This suite verifies preferred_engine_id behavior, which the frontend + // resolves as a post-hoc reorder of the heuristic-ranked engine configs + // (see Graph::initializeEngineConfig). The HIPDNN_HEUR_CONFIG_PATH + // env knob lives in the SelectionHeuristic::Config built-in instead. We + // only need to chain test_good_heuristic_plugin so the heuristic loop has + // a ranked list to reorder against. + static void SetUpTestSuite() + { + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + sPolicyOrderEnv.emplace("HIPDNN_HEUR_POLICY_ORDER", + hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + } + + static void TearDownTestSuite() + { + sPolicyOrderEnv.reset(); + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + ASSERT_EQ(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + } + void SetUp() override { SKIP_IF_NO_DEVICES(); @@ -100,6 +131,9 @@ class IntegrationGraphEngineFiltering : public ::testing::TestWithParam + sPolicyOrderEnv; + static std::shared_ptr createSimplePointwiseGraph(const std::string& graphName, const std::vector& dims) { @@ -149,6 +183,14 @@ class IntegrationGraphEngineFiltering : public ::testing::TestWithParambuild_operation_graph(_handle); ASSERT_EQ(result.code, ErrorCode::OK) << result.err_msg; + // Capture the heuristic-ranked engine list before plan creation. The + // preferred-engine setter is a post-hoc reorder over this list, so when + // the preferred ID isn't present, execute() must follow rankedEngineIds[0]. + std::vector rankedEngineIds; + result = graph->get_ranked_engine_ids(rankedEngineIds); + ASSERT_EQ(result.code, ErrorCode::OK) << result.err_msg; + ASSERT_FALSE(rankedEngineIds.empty()); + result = graph->create_execution_plans(); ASSERT_EQ(result.code, ErrorCode::OK) << result.err_msg; @@ -164,7 +206,6 @@ class IntegrationGraphEngineFiltering : public ::testing::TestWithParamexecute(_handle, variantPack, nullptr); - // For non-deterministic engine selection we don't check if it's successful. if(testCase.shouldSucceed.has_value() && testCase.shouldSucceed.value()) { ASSERT_EQ(result.code, ErrorCode::OK) << result.err_msg; @@ -173,12 +214,40 @@ class IntegrationGraphEngineFiltering : public ::testing::TestWithParam(); + if(rankedEngineIds.front() == failingEngineId) + { + ASSERT_NE(result.code, ErrorCode::OK) + << "Top-ranked engine is the failing plugin; execute should have failed"; + } + else + { + ASSERT_EQ(result.code, ErrorCode::OK) << result.err_msg; + } + } } private: hipdnnHandle_t _handle = nullptr; }; +std::optional + IntegrationGraphEngineFiltering::sPolicyOrderEnv; + } // namespace INSTANTIATE_TEST_SUITE_P( diff --git a/projects/hipdnn/tests/frontend/IntegrationGraphSupportCheck.cpp b/projects/hipdnn/tests/frontend/IntegrationGraphSupportCheck.cpp index 8e8e69cbde4..c2129423677 100644 --- a/projects/hipdnn/tests/frontend/IntegrationGraphSupportCheck.cpp +++ b/projects/hipdnn/tests/frontend/IntegrationGraphSupportCheck.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include using namespace hipdnn_frontend; @@ -19,8 +20,13 @@ class IntegrationGraphSupportCheck : public ::testing::Test protected: void SetUp() override { + // is_supported_ext drives finalize(), which reads the device from the + // handle's stream — bind a real one and skip on no-GPU runners. + SKIP_IF_NO_DEVICES(); loadPlugins({hipdnn_tests::plugin_constants::testGoodPluginPath().c_str()}); ASSERT_EQ(hipdnnCreate(&_handle), HIPDNN_STATUS_SUCCESS); + ASSERT_EQ(hipStreamCreate(&_stream), hipSuccess); + ASSERT_EQ(hipdnnSetStream(_handle, _stream), HIPDNN_STATUS_SUCCESS); } void TearDown() override @@ -29,6 +35,11 @@ class IntegrationGraphSupportCheck : public ::testing::Test { ASSERT_EQ(hipdnnDestroy(_handle), HIPDNN_STATUS_SUCCESS); } + if(_stream != nullptr) + { + ASSERT_EQ(hipStreamDestroy(_stream), hipSuccess); + _stream = nullptr; + } } static void loadPlugins(std::initializer_list pluginPaths) @@ -46,9 +57,11 @@ class IntegrationGraphSupportCheck : public ::testing::Test loadPlugins(pluginPaths); ASSERT_EQ(hipdnnCreate(&_handle), HIPDNN_STATUS_SUCCESS); + ASSERT_EQ(hipdnnSetStream(_handle, _stream), HIPDNN_STATUS_SUCCESS); } hipdnnHandle_t _handle = nullptr; + hipStream_t _stream = nullptr; }; TEST_F(IntegrationGraphSupportCheck, SupportedWithGoodPlugin) diff --git a/projects/hipdnn/tests/frontend/main.cpp b/projects/hipdnn/tests/frontend/main.cpp index 059e18cc98d..66c883e9beb 100644 --- a/projects/hipdnn/tests/frontend/main.cpp +++ b/projects/hipdnn/tests/frontend/main.cpp @@ -3,11 +3,15 @@ Copyright © Advanced Micro Devices, Inc., or its affiliates. SPDX-License-Identifier: MIT */ +#include #include +#include +#include #include #include #include +#include int main(int argc, char** argv) { @@ -28,6 +32,21 @@ int main(int argc, char** argv) testing::TestEventListeners& listeners = testing::UnitTest::GetInstance()->listeners(); listeners.Append(new hipdnn_test_sdk::utilities::HipErrorHandler); + // Frontend integration tests don't exercise heuristic-plugin behavior; they just + // need a generic working heuristic so engine selection succeeds. Wire the + // engine-agnostic test_good_heuristic_plugin in once before any test runs (no + // active handles allowed when changing heuristic plugin paths). + const std::array heuristicPaths + = {hipdnn_tests::plugin_constants::testGoodHeuristicPluginPath().c_str()}; + if(hipdnnSetHeuristicPluginPaths_ext( + heuristicPaths.size(), heuristicPaths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE) + != HIPDNN_STATUS_SUCCESS) + { + return 1; + } + hipdnn_data_sdk::utilities::setEnv( + "HIPDNN_HEUR_POLICY_ORDER", hipdnn_tests::plugin_constants::testGoodHeuristicPolicyName()); + auto result = RUN_ALL_TESTS(); return result; } diff --git a/projects/hipdnn/tests/test_plugins/CMakeLists.txt b/projects/hipdnn/tests/test_plugins/CMakeLists.txt index b26a479197c..5311d6e9872 100644 --- a/projects/hipdnn/tests/test_plugins/CMakeLists.txt +++ b/projects/hipdnn/tests/test_plugins/CMakeLists.txt @@ -5,23 +5,10 @@ cmake_minimum_required(VERSION 3.25.2) find_package(hip REQUIRED) -# Define plugin names as variables -set(TEST_GOOD_PLUGIN_NAME "test_good_plugin") -set(TEST_EXECUTE_FAILS_PLUGIN_NAME "test_execute_fails_plugin") -set(TEST_NO_APPLICABLE_ENGINES_A_PLUGIN_NAME "test_no_applicable_engines_a_plugin") -set(TEST_NO_APPLICABLE_ENGINES_B_PLUGIN_NAME "test_no_applicable_engines_b_plugin") -set(TEST_DUPLICATE_ID_A_PLUGIN_NAME "test_duplicate_id_a_plugin") -set(TEST_DUPLICATE_ID_B_PLUGIN_NAME "test_duplicate_id_b_plugin") -set(TEST_INCOMPLETE_API_PLUGIN_NAME "test_incomplete_api_plugin") -set(TEST_GOOD_DEFAULT_PLUGIN_NAME "test_good_default_plugin") -set(TEST_KNOBS_PLUGIN_NAME "test_knobs_plugin") -set(TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME "test_knob_constraint_validation_plugin") -set(TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME "test_incompatible_version_plugin") - -# Heuristic plugin test names (RFC 0007) -set(TEST_GOOD_HEURISTIC_PLUGIN_NAME "test_good_heuristic_plugin") -set(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME "test_incomplete_heuristic_api_plugin") -set(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME "test_no_optional_heuristic_plugin") +# Plugin target names are defined in cmake/TestPluginNames.cmake (included by +# the root CMakeLists.txt) so they are available to both this directory and +# the test executables that depend on these targets, regardless of the order +# subdirectories are added. # Function to create a test plugin with common configuration function(add_test_plugin target_name) @@ -96,7 +83,7 @@ set(test_good_default_plugin_OUTPUT_DIR "${HIPDNN_TEST_PLUGIN_DIR}/default") set(test_good_default_plugin_INSTALL_DIR "${HIPDNN_TEST_PLUGIN_INSTALL_DIR}/default") add_test_plugin(${TEST_GOOD_DEFAULT_PLUGIN_NAME} TestGoodDefaultPlugin.cpp) -# Heuristic test plugins (RFC 0007 Part 1) +# Heuristic test plugins # test_good_heuristic_plugin implements the full heuristic plugin API add_test_plugin(${TEST_GOOD_HEURISTIC_PLUGIN_NAME} TestGoodHeuristicPlugin.cpp) @@ -106,22 +93,15 @@ add_test_plugin(${TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME} TestIncompleteHeuri # test_no_optional_heuristic_plugin omits optional symbols (GetPolicyName, SetLogLevel) add_test_plugin(${TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME} TestNoOptionalHeuristicPlugin.cpp) -# Export plugin names to parent scope for use in test compilation -set(TEST_GOOD_PLUGIN_NAME "${TEST_GOOD_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_EXECUTE_FAILS_PLUGIN_NAME "${TEST_EXECUTE_FAILS_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_NO_APPLICABLE_ENGINES_A_PLUGIN_NAME "${TEST_NO_APPLICABLE_ENGINES_A_PLUGIN_NAME}" - PARENT_SCOPE -) -set(TEST_NO_APPLICABLE_ENGINES_B_PLUGIN_NAME "${TEST_NO_APPLICABLE_ENGINES_B_PLUGIN_NAME}" - PARENT_SCOPE -) -set(TEST_DUPLICATE_ID_A_PLUGIN_NAME "${TEST_DUPLICATE_ID_A_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_DUPLICATE_ID_B_PLUGIN_NAME "${TEST_DUPLICATE_ID_B_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_INCOMPLETE_API_PLUGIN_NAME "${TEST_INCOMPLETE_API_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_GOOD_DEFAULT_PLUGIN_NAME "${TEST_GOOD_DEFAULT_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_KNOBS_PLUGIN_NAME "${TEST_KNOBS_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME "${TEST_KNOB_CONSTRAINT_VALIDATION_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME "${TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_GOOD_HEURISTIC_PLUGIN_NAME "${TEST_GOOD_HEURISTIC_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME "${TEST_INCOMPLETE_HEURISTIC_API_PLUGIN_NAME}" PARENT_SCOPE) -set(TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME "${TEST_NO_OPTIONAL_HEURISTIC_PLUGIN_NAME}" PARENT_SCOPE) +# test_bad_api_version_heuristic_plugin returns incompatible API version +add_test_plugin(${TEST_BAD_API_VERSION_HEURISTIC_PLUGIN_NAME} TestBadApiVersionHeuristicPlugin.cpp) + +# test_empty_name_heuristic_plugin returns empty policy name +add_test_plugin(${TEST_EMPTY_NAME_HEURISTIC_PLUGIN_NAME} TestEmptyNameHeuristicPlugin.cpp) + +# test_duplicate_policy_id_a/b_plugin return same policy name (same ID) +add_test_plugin(${TEST_DUPLICATE_POLICY_ID_A_PLUGIN_NAME} TestDuplicatePolicyIdAPlugin.cpp) +add_test_plugin(${TEST_DUPLICATE_POLICY_ID_B_PLUGIN_NAME} TestDuplicatePolicyIdBPlugin.cpp) + +# Plugin names are exported via cmake/TestPluginNames.cmake; no PARENT_SCOPE +# re-export needed. diff --git a/projects/hipdnn/tests/test_plugins/TestBadApiVersionHeuristicPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestBadApiVersionHeuristicPlugin.cpp new file mode 100644 index 00000000000..df9daa8721d --- /dev/null +++ b/projects/hipdnn/tests/test_plugins/TestBadApiVersionHeuristicPlugin.cpp @@ -0,0 +1,49 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestBadApiVersionHeuristicPlugin.cpp + * @brief Test plugin that returns incompatible API version + * + * This plugin intentionally returns an API version with wrong major version + * to trigger HeuristicPluginManager::validateBeforeAdding() API version check + * (lines 55-66 in HeuristicPluginManager.hpp). + */ + +#include "TestHeuristicPluginBase.hpp" + +#include + +// NOLINTNEXTLINE +thread_local char + hipdnn_plugin_sdk::PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] + = ""; + +class BadApiVersionHeuristicPlugin : public TestHeuristicPluginBase +{ +public: + const char* getPolicyName() const override + { + return "TestBadApiVersionPolicy"; + } + + const char* getPluginVersion() const override + { + return "1.0.0"; + } + + const char* getApiVersion() const override + { + // Return incompatible API version (wrong major version) + return "99.0.0"; + } +}; + +// Initialize plugin instance on load +__attribute__((constructor)) static void initializePlugin() +{ + TestHeuristicPluginBase::setInstance(std::make_unique()); +} + +// Register all API functions using the macro +REGISTER_HEURISTIC_PLUGIN_API() diff --git a/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdAPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdAPlugin.cpp new file mode 100644 index 00000000000..a3940d8bcaf --- /dev/null +++ b/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdAPlugin.cpp @@ -0,0 +1,45 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestDuplicatePolicyIdAPlugin.cpp + * @brief Test plugin with duplicate policy ID (plugin A) + * + * This plugin returns the same policy name as TestDuplicatePolicyIdBPlugin, + * which generates the same policy ID (via FNV-1a hash of the name). + * This triggers HeuristicPluginManager::validateBeforeAdding() duplicate ID check + * (lines 70-78 in HeuristicPluginManager.hpp). + */ + +#include "TestHeuristicPluginBase.hpp" + +#include + +// NOLINTNEXTLINE +thread_local char + hipdnn_plugin_sdk::PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] + = ""; + +class DuplicatePolicyIdAPlugin : public TestHeuristicPluginBase +{ +public: + const char* getPolicyName() const override + { + // Same name as plugin B -> same policy ID + return "TestDuplicatePolicyName"; + } + + const char* getPluginVersion() const override + { + return "1.0.0-duplicate-a"; + } +}; + +// Initialize plugin instance on load +__attribute__((constructor)) static void initializePlugin() +{ + TestHeuristicPluginBase::setInstance(std::make_unique()); +} + +// Register all API functions using the macro +REGISTER_HEURISTIC_PLUGIN_API() diff --git a/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdBPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdBPlugin.cpp new file mode 100644 index 00000000000..4eb9237432d --- /dev/null +++ b/projects/hipdnn/tests/test_plugins/TestDuplicatePolicyIdBPlugin.cpp @@ -0,0 +1,45 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestDuplicatePolicyIdBPlugin.cpp + * @brief Test plugin with duplicate policy ID (plugin B) + * + * This plugin returns the same policy name as TestDuplicatePolicyIdAPlugin, + * which generates the same policy ID (via FNV-1a hash of the name). + * This triggers HeuristicPluginManager::validateBeforeAdding() duplicate ID check + * (lines 70-78 in HeuristicPluginManager.hpp). + */ + +#include "TestHeuristicPluginBase.hpp" + +#include + +// NOLINTNEXTLINE +thread_local char + hipdnn_plugin_sdk::PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] + = ""; + +class DuplicatePolicyIdBPlugin : public TestHeuristicPluginBase +{ +public: + const char* getPolicyName() const override + { + // Same name as plugin A -> same policy ID + return "TestDuplicatePolicyName"; + } + + const char* getPluginVersion() const override + { + return "1.0.0-duplicate-b"; + } +}; + +// Initialize plugin instance on load +__attribute__((constructor)) static void initializePlugin() +{ + TestHeuristicPluginBase::setInstance(std::make_unique()); +} + +// Register all API functions using the macro +REGISTER_HEURISTIC_PLUGIN_API() diff --git a/projects/hipdnn/tests/test_plugins/TestEmptyNameHeuristicPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestEmptyNameHeuristicPlugin.cpp new file mode 100644 index 00000000000..ae950cf1481 --- /dev/null +++ b/projects/hipdnn/tests/test_plugins/TestEmptyNameHeuristicPlugin.cpp @@ -0,0 +1,44 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file TestEmptyNameHeuristicPlugin.cpp + * @brief Test plugin that returns empty policy name + * + * This plugin intentionally returns an empty string for policy name + * to trigger HeuristicPluginManager::validateBeforeAdding() policy name check + * (lines 82-89 in HeuristicPluginManager.hpp). + */ + +#include "TestHeuristicPluginBase.hpp" + +#include + +// NOLINTNEXTLINE +thread_local char + hipdnn_plugin_sdk::PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] + = ""; + +class EmptyNameHeuristicPlugin : public TestHeuristicPluginBase +{ +public: + const char* getPolicyName() const override + { + // Return empty policy name to trigger validation error + return ""; + } + + const char* getPluginVersion() const override + { + return "1.0.0"; + } +}; + +// Initialize plugin instance on load +__attribute__((constructor)) static void initializePlugin() +{ + TestHeuristicPluginBase::setInstance(std::make_unique()); +} + +// Register all API functions using the macro +REGISTER_HEURISTIC_PLUGIN_API() diff --git a/projects/hipdnn/tests/test_plugins/TestGoodHeuristicPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestGoodHeuristicPlugin.cpp index cf01872ae1f..b5167a63b49 100644 --- a/projects/hipdnn/tests/test_plugins/TestGoodHeuristicPlugin.cpp +++ b/projects/hipdnn/tests/test_plugins/TestGoodHeuristicPlugin.cpp @@ -1,10 +1,12 @@ // Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include #include #include #include +#include #include #include #include @@ -18,6 +20,7 @@ thread_local char namespace { // NOLINTBEGIN(readability-identifier-naming) +const char* PLUGIN_NAME = "TestGoodHeuristicPlugin"; const char* POLICY_NAME = "TestGoodHeuristicPolicy"; const char* PLUGIN_VERSION = "1.0.0"; @@ -56,7 +59,7 @@ hipdnnPluginStatus_t hipdnnPluginGetName(const char** name) "name pointer is null"); return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; } - *name = POLICY_NAME; + *name = PLUGIN_NAME; return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -96,6 +99,51 @@ hipdnnPluginStatus_t hipdnnPluginGetType(hipdnnPluginType_t* type) return HIPDNN_PLUGIN_STATUS_SUCCESS; } +// ========== Policy Enumeration ========== + +hipdnnPluginStatus_t hipdnnHeuristicPluginGetAllPolicyIds(int64_t* policy_ids, + uint32_t max_policies, + uint32_t* num_policies) +{ + if(num_policies == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_INVALID_VALUE, + "num_policies pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + *num_policies = 1; + if(policy_ids == nullptr || max_policies == 0) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if(max_policies < 1) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "max_policies smaller than available count"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + policy_ids[0] = hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnHeuristicPluginGetPolicyName(int64_t policy_id, const char** name) +{ + if(name == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_INVALID_VALUE, + "name pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(policy_id != hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME)) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + *name = POLICY_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + hipdnnPluginStatus_t hipdnnPluginSetLoggingCallback(hipdnnCallback_t callback) { g_loggingCallback = callback; @@ -172,6 +220,7 @@ hipdnnPluginStatus_t hipdnnPluginStatus_t hipdnnHeuristicPolicyDescriptorCreate(hipdnnHeuristicHandle_t handle, + int64_t policy_id, hipdnnHeuristicPolicyDescriptor_t* out_descriptor) { if(handle == nullptr) @@ -186,6 +235,12 @@ hipdnnPluginStatus_t "out_descriptor pointer is null"); return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; } + if(policy_id != hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME)) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } auto* desc = new PolicyDescriptorImpl{}; *out_descriptor = reinterpret_cast(desc); diff --git a/projects/hipdnn/tests/test_plugins/TestHeuristicPluginBase.hpp b/projects/hipdnn/tests/test_plugins/TestHeuristicPluginBase.hpp new file mode 100644 index 00000000000..7dfa47576aa --- /dev/null +++ b/projects/hipdnn/tests/test_plugins/TestHeuristicPluginBase.hpp @@ -0,0 +1,594 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace +{ +// Simple handle implementation for test plugins +struct HeuristicHandleImpl +{ + int handleId; + bool devicePropsSet{false}; +}; + +// Simple policy descriptor implementation for test plugins +struct PolicyDescriptorImpl +{ + std::vector inputEngineIds; + std::vector serializedGraph; + std::vector sortedEngineIds; + bool finalized{false}; +}; + +// Callback state +// NOLINTBEGIN(readability-identifier-naming) +hipdnnCallback_t g_loggingCallback = nullptr; +hipdnnSeverity_t g_logLevel = HIPDNN_SEV_INFO; +// NOLINTEND(readability-identifier-naming) + +} // anonymous namespace + +/** + * @brief Base class for test heuristic plugins with common implementations. + * + * Single-policy convenience: derived classes only need to override getPolicyName() + * and a single-policy plugin will be exposed. Multi-policy plugins should override + * getAllPolicyNames() instead. + * + * Optional overrides: + * - getPluginName() - the plugin (library) name returned by hipdnnPluginGetName. + * Defaults to "TestHeuristicPlugin"; override to test invalid plugin names. + * - getApiVersion() - return a wrong API version to test ABI rejection. + * - getPluginVersion() - customize plugin implementation version. + */ +class TestHeuristicPluginBase +{ +public: + virtual ~TestHeuristicPluginBase() = default; + + // Single-policy convenience override (default implementation throws if not overridden + // and getAllPolicyNames is also unset). + virtual const char* getPolicyName() const + { + return ""; + } + + // Override for multi-policy plugins. Default implementation returns a single entry + // taken from getPolicyName(), so single-policy plugins keep their existing behavior. + virtual std::vector getAllPolicyNames() const + { + return {std::string(getPolicyName())}; + } + + // Plugin (library) name returned by hipdnnPluginGetName. Distinct from policy names. + virtual const char* getPluginName() const + { + return "TestHeuristicPlugin"; + } + + virtual const char* getPluginVersion() const + { + return "1.0.0"; + } + + virtual const char* getApiVersion() const + { + return HIPDNN_HEURISTIC_API_VERSION; + } + + // Static instance management using Meyer's singleton pattern to avoid ODR issues + static std::unique_ptr& getInstanceStorage() + { + // NOLINTNEXTLINE(readability-identifier-naming) + static std::unique_ptr sInstance = nullptr; + return sInstance; + } + + static void setInstance(std::unique_ptr instance) + { + getInstanceStorage() = std::move(instance); + } + + static TestHeuristicPluginBase* getInstance() + { + return getInstanceStorage().get(); + } + + // ========== Common API Implementations ========== + + static hipdnnPluginStatus_t getName(const char** name) + { + if(name == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "name pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(getInstance() == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, "plugin instance is null"); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } + *name = getInstance()->getPluginName(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // ========== Policy Enumeration ========== + + // Cached policy names so the const char* returned to the C ABI remains valid for the + // lifetime of the loaded library. + static std::vector& getCachedPolicyNames() + { + // NOLINTNEXTLINE(readability-identifier-naming) + static std::vector sNames; + return sNames; + } + + static const std::vector& policyNames() + { + auto& cached = getCachedPolicyNames(); + if(cached.empty() && getInstance() != nullptr) + { + cached = getInstance()->getAllPolicyNames(); + } + return cached; + } + + static hipdnnPluginStatus_t + getAllPolicyIds(int64_t* policyIds, uint32_t maxPolicies, uint32_t* numPolicies) + { + if(numPolicies == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "num_policies pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(getInstance() == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, "plugin instance is null"); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } + + const auto& names = policyNames(); + const auto total = static_cast(names.size()); + *numPolicies = total; + if(policyIds == nullptr || maxPolicies == 0) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if(maxPolicies < total) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "max_policies smaller than available count"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + for(uint32_t i = 0; i < total; ++i) + { + policyIds[i] = hipdnn_data_sdk::utilities::policyNameToId(names[i]); + } + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t getPolicyName(int64_t policyId, const char** name) + { + if(name == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "name pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + for(const auto& candidate : policyNames()) + { + if(hipdnn_data_sdk::utilities::policyNameToId(candidate) == policyId) + { + *name = candidate.c_str(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + } + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + + static hipdnnPluginStatus_t getVersion(const char** version) + { + if(version == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "version pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(getInstance() == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, "plugin instance is null"); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } + *version = getInstance()->getPluginVersion(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t getApiVersion(const char** version) + { + if(version == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "version pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(getInstance() == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, "plugin instance is null"); + return HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR; + } + *version = getInstance()->getApiVersion(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t getType(hipdnnPluginType_t* type) + { + if(type == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "type pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + *type = HIPDNN_PLUGIN_TYPE_HEURISTIC; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t setLoggingCallback(hipdnnCallback_t callback) + { + g_loggingCallback = callback; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t setLogLevel(hipdnnSeverity_t level) + { + g_logLevel = level; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static void getLastErrorString(const char** errorStr) + { + if(errorStr != nullptr) + { + *errorStr = hipdnn_plugin_sdk::PluginLastErrorManager::getLastError(); + } + } + + static hipdnnPluginStatus_t handleCreate(hipdnnHeuristicHandle_t* outHandle) + { + if(outHandle == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "out_handle pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* handle = new HeuristicHandleImpl{42, false}; + *outHandle = reinterpret_cast(handle); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t handleDestroy(hipdnnHeuristicHandle_t handle) + { + if(handle == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "handle is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(handle); + delete impl; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t + handleSetDeviceProperties(hipdnnHeuristicHandle_t handle, + const hipdnnPluginConstData_t* devicePropsSerialized) + { + if(handle == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "handle is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(devicePropsSerialized == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "devicePropsSerialized pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(handle); + impl->devicePropsSet = true; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t + policyDescriptorCreate(hipdnnHeuristicHandle_t handle, + int64_t policyId, + hipdnnHeuristicPolicyDescriptor_t* outDescriptor) + { + if(handle == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "handle is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(outDescriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "out_descriptor pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + bool found = false; + for(const auto& candidate : policyNames()) + { + if(hipdnn_data_sdk::utilities::policyNameToId(candidate) == policyId) + { + found = true; + break; + } + } + if(!found) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + + auto* desc = new PolicyDescriptorImpl{}; + *outDescriptor = reinterpret_cast(desc); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t + policyDescriptorDestroy(hipdnnHeuristicPolicyDescriptor_t descriptor) + { + if(descriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "descriptor is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(descriptor); + delete impl; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t policySetEngineIds(hipdnnHeuristicPolicyDescriptor_t descriptor, + const int64_t* engineIds, + size_t engineIdCount) + { + if(descriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "descriptor is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(engineIds == nullptr && engineIdCount > 0) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "engine_ids is null but count > 0"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(descriptor); + impl->inputEngineIds.assign(engineIds, engineIds + engineIdCount); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t + policySetSerializedGraph(hipdnnHeuristicPolicyDescriptor_t descriptor, + const hipdnnPluginConstData_t* serializedGraph) + { + if(descriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "descriptor is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(serializedGraph == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "serialized_graph pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(descriptor); + const auto* bytes = static_cast(serializedGraph->ptr); + impl->serializedGraph.assign(bytes, bytes + serializedGraph->size); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t policyFinalize(hipdnnHeuristicPolicyDescriptor_t descriptor, + int32_t* applied) + { + if(descriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "descriptor is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(applied == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "applied pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(descriptor); + + // Simple policy: reverse the input order + impl->sortedEngineIds = impl->inputEngineIds; + std::reverse(impl->sortedEngineIds.begin(), impl->sortedEngineIds.end()); + + impl->finalized = true; + *applied = 1; // Policy applied + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + static hipdnnPluginStatus_t policyGetSortedEngineIds( + hipdnnHeuristicPolicyDescriptor_t descriptor, int64_t* engineIds, size_t* count) + { + if(descriptor == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "descriptor is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(count == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "count pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + + auto* impl = reinterpret_cast(descriptor); + + if(engineIds == nullptr) + { + // Query mode: return count only + *count = impl->sortedEngineIds.size(); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // Retrieve mode: copy IDs + const size_t numToCopy = std::min(*count, impl->sortedEngineIds.size()); + std::memcpy(engineIds, impl->sortedEngineIds.data(), numToCopy * sizeof(int64_t)); + *count = numToCopy; + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } +}; + +// Macro to register all heuristic plugin API functions +// NOLINTBEGIN(readability-identifier-naming,cppcoreguidelines-macro-usage) +#define REGISTER_HEURISTIC_PLUGIN_API() \ + extern "C" { \ + hipdnnPluginStatus_t hipdnnPluginGetName(const char** name) \ + { \ + return TestHeuristicPluginBase::getName(name); \ + } \ + \ + hipdnnPluginStatus_t hipdnnPluginGetVersion(const char** version) \ + { \ + return TestHeuristicPluginBase::getVersion(version); \ + } \ + \ + hipdnnPluginStatus_t hipdnnPluginGetApiVersion(const char** version) \ + { \ + return TestHeuristicPluginBase::getApiVersion(version); \ + } \ + \ + hipdnnPluginStatus_t hipdnnPluginGetType(hipdnnPluginType_t* type) \ + { \ + return TestHeuristicPluginBase::getType(type); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicPluginGetAllPolicyIds(int64_t* policyIds, \ + uint32_t maxPolicies, \ + uint32_t* numPolicies) \ + { \ + return TestHeuristicPluginBase::getAllPolicyIds(policyIds, maxPolicies, numPolicies); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicPluginGetPolicyName(int64_t policyId, const char** name) \ + { \ + return TestHeuristicPluginBase::getPolicyName(policyId, name); \ + } \ + \ + hipdnnPluginStatus_t hipdnnPluginSetLoggingCallback(hipdnnCallback_t callback) \ + { \ + return TestHeuristicPluginBase::setLoggingCallback(callback); \ + } \ + \ + hipdnnPluginStatus_t hipdnnPluginSetLogLevel(hipdnnSeverity_t level) \ + { \ + return TestHeuristicPluginBase::setLogLevel(level); \ + } \ + \ + void hipdnnPluginGetLastErrorString(const char** errorStr) \ + { \ + TestHeuristicPluginBase::getLastErrorString(errorStr); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicHandleCreate(hipdnnHeuristicHandle_t* outHandle) \ + { \ + return TestHeuristicPluginBase::handleCreate(outHandle); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicHandleDestroy(hipdnnHeuristicHandle_t handle) \ + { \ + return TestHeuristicPluginBase::handleDestroy(handle); \ + } \ + \ + hipdnnPluginStatus_t \ + hipdnnHeuristicHandleSetDeviceProperties(hipdnnHeuristicHandle_t handle, \ + const hipdnnPluginConstData_t* deviceProps) \ + { \ + return TestHeuristicPluginBase::handleSetDeviceProperties(handle, deviceProps); \ + } \ + \ + hipdnnPluginStatus_t \ + hipdnnHeuristicPolicyDescriptorCreate(hipdnnHeuristicHandle_t handle, \ + int64_t policyId, \ + hipdnnHeuristicPolicyDescriptor_t* outDescriptor) \ + { \ + return TestHeuristicPluginBase::policyDescriptorCreate(handle, policyId, outDescriptor); \ + } \ + \ + hipdnnPluginStatus_t \ + hipdnnHeuristicPolicyDescriptorDestroy(hipdnnHeuristicPolicyDescriptor_t descriptor) \ + { \ + return TestHeuristicPluginBase::policyDescriptorDestroy(descriptor); \ + } \ + \ + hipdnnPluginStatus_t \ + hipdnnHeuristicPolicySetEngineIds(hipdnnHeuristicPolicyDescriptor_t descriptor, \ + const int64_t* engineIds, \ + size_t engineIdCount) \ + { \ + return TestHeuristicPluginBase::policySetEngineIds(descriptor, engineIds, engineIdCount); \ + } \ + \ + hipdnnPluginStatus_t \ + hipdnnHeuristicPolicySetSerializedGraph(hipdnnHeuristicPolicyDescriptor_t descriptor, \ + const hipdnnPluginConstData_t* serializedGraph) \ + { \ + return TestHeuristicPluginBase::policySetSerializedGraph(descriptor, serializedGraph); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicPolicyFinalize(hipdnnHeuristicPolicyDescriptor_t desc, \ + int32_t* applied) \ + { \ + return TestHeuristicPluginBase::policyFinalize(desc, applied); \ + } \ + \ + hipdnnPluginStatus_t hipdnnHeuristicPolicyGetSortedEngineIds( \ + hipdnnHeuristicPolicyDescriptor_t descriptor, int64_t* engineIds, size_t* count) \ + { \ + return TestHeuristicPluginBase::policyGetSortedEngineIds(descriptor, engineIds, count); \ + } \ + } +// NOLINTEND(readability-identifier-naming,cppcoreguidelines-macro-usage) diff --git a/projects/hipdnn/tests/test_plugins/TestIncompleteHeuristicApiPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestIncompleteHeuristicApiPlugin.cpp index 7a9f41e3a0f..787e60f634c 100644 --- a/projects/hipdnn/tests/test_plugins/TestIncompleteHeuristicApiPlugin.cpp +++ b/projects/hipdnn/tests/test_plugins/TestIncompleteHeuristicApiPlugin.cpp @@ -98,9 +98,11 @@ hipdnnPluginStatus_t hipdnnPluginStatus_t hipdnnHeuristicPolicyDescriptorCreate(hipdnnHeuristicHandle_t handle, + int64_t policy_id, hipdnnHeuristicPolicyDescriptor_t* out_descriptor) { (void)handle; + (void)policy_id; if(out_descriptor == nullptr) { return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; diff --git a/projects/hipdnn/tests/test_plugins/TestNoOptionalHeuristicPlugin.cpp b/projects/hipdnn/tests/test_plugins/TestNoOptionalHeuristicPlugin.cpp index 80082c55180..25dde198cef 100644 --- a/projects/hipdnn/tests/test_plugins/TestNoOptionalHeuristicPlugin.cpp +++ b/projects/hipdnn/tests/test_plugins/TestNoOptionalHeuristicPlugin.cpp @@ -1,6 +1,7 @@ // Copyright © Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include #include #include #include @@ -25,6 +26,7 @@ thread_local char namespace { // NOLINTBEGIN(readability-identifier-naming) +const char* PLUGIN_NAME = "TestNoOptionalHeuristicPlugin"; const char* POLICY_NAME = "TestNoOptionalHeuristicPolicy"; // NOLINTEND(readability-identifier-naming) } // anonymous namespace @@ -46,15 +48,58 @@ hipdnnPluginStatus_t hipdnnPluginGetApiVersion(const char** version) return HIPDNN_PLUGIN_STATUS_SUCCESS; } -hipdnnPluginStatus_t hipdnnPluginGetName(const char** policy_name) +hipdnnPluginStatus_t hipdnnPluginGetName(const char** plugin_name) { - if(policy_name == nullptr) + if(plugin_name == nullptr) { hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_INVALID_VALUE, - "policy_name pointer is null"); + "plugin_name pointer is null"); return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; } - *policy_name = POLICY_NAME; + *plugin_name = PLUGIN_NAME; + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnHeuristicPluginGetAllPolicyIds(int64_t* policy_ids, + uint32_t max_policies, + uint32_t* num_policies) +{ + if(num_policies == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_INVALID_VALUE, + "num_policies pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + *num_policies = 1; + if(policy_ids == nullptr || max_policies == 0) + { + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if(max_policies < 1) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_INVALID_VALUE, "max_policies smaller than available count"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + policy_ids[0] = hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnHeuristicPluginGetPolicyName(int64_t policy_id, const char** name) +{ + if(name == nullptr) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_INVALID_VALUE, + "name pointer is null"); + return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; + } + if(policy_id != hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME)) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } + *name = POLICY_NAME; return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -147,6 +192,7 @@ hipdnnPluginStatus_t hipdnnPluginStatus_t hipdnnHeuristicPolicyDescriptorCreate(hipdnnHeuristicHandle_t handle, + int64_t policy_id, hipdnnHeuristicPolicyDescriptor_t* out_descriptor) { if(handle == nullptr) @@ -161,6 +207,12 @@ hipdnnPluginStatus_t "out_descriptor pointer is null"); return HIPDNN_PLUGIN_STATUS_INVALID_VALUE; } + if(policy_id != hipdnn_data_sdk::utilities::policyNameToId(POLICY_NAME)) + { + hipdnn_plugin_sdk::PluginLastErrorManager::setLastError(HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "unknown policy id"); + return HIPDNN_PLUGIN_STATUS_BAD_PARAM; + } // NOLINTNEXTLINE(performance-no-int-to-ptr) *out_descriptor = reinterpret_cast(0xDCBA); return HIPDNN_PLUGIN_STATUS_SUCCESS; diff --git a/projects/hipdnn/tests/test_plugins/TestPluginConstants.hpp b/projects/hipdnn/tests/test_plugins/TestPluginConstants.hpp index 800a14946e3..f014e2b9a2c 100644 --- a/projects/hipdnn/tests/test_plugins/TestPluginConstants.hpp +++ b/projects/hipdnn/tests/test_plugins/TestPluginConstants.hpp @@ -114,4 +114,19 @@ inline const std::string& testIncompatibleVersionPluginPath() = getTestCustomFilepathForPlugin(TEST_INCOMPATIBLE_VERSION_PLUGIN_NAME); return s_testIncompatibleVersionPluginPath; } + +// Heuristic test plugins. Policy name registered by test_good_heuristic_plugin -- +// callers that need a specific policy should set HIPDNN_HEUR_POLICY_ORDER to +// this value via a scoped env guard. +inline const char* testGoodHeuristicPolicyName() +{ + return "TestGoodHeuristicPolicy"; +} + +inline const std::string& testGoodHeuristicPluginPath() +{ + static const std::string s_testGoodHeuristicPluginPath + = getTestCustomFilepathForPlugin(TEST_GOOD_HEURISTIC_PLUGIN_NAME); + return s_testGoodHeuristicPluginPath; +} } // namespace hipdnn_tests::plugin_constants