From 4ff5df6eeda223d7b32c2ddcc11ed7a9b670917b Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Thu, 18 Sep 2025 13:37:13 -0600 Subject: [PATCH 01/14] [FusilliPlugin] Basic CI, Build, and README (`fusilli-plugin` 1 of N) (#2275) Adds basic CI, CMake build scrips, and Readme. --- projects/fusilli-plugin/.clang-format | 1 + projects/fusilli-plugin/.gitignore | 8 + projects/fusilli-plugin/CMakeLists.txt | 35 ++++ projects/fusilli-plugin/README.md | 26 +++ .../cmake/FusilliPluginDependencyUtils.cmake | 194 ++++++++++++++++++ 5 files changed, 264 insertions(+) create mode 100644 projects/fusilli-plugin/.clang-format create mode 100644 projects/fusilli-plugin/.gitignore create mode 100644 projects/fusilli-plugin/CMakeLists.txt create mode 100644 projects/fusilli-plugin/README.md create mode 100644 projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake diff --git a/projects/fusilli-plugin/.clang-format b/projects/fusilli-plugin/.clang-format new file mode 100644 index 00000000000..9b3aa8b7213 --- /dev/null +++ b/projects/fusilli-plugin/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: LLVM diff --git a/projects/fusilli-plugin/.gitignore b/projects/fusilli-plugin/.gitignore new file mode 100644 index 00000000000..425ada20f17 --- /dev/null +++ b/projects/fusilli-plugin/.gitignore @@ -0,0 +1,8 @@ +# CMake build +build/ + +# clangd intellisense cache +.cache/ + +# CMake presets +CMakePresets.json diff --git a/projects/fusilli-plugin/CMakeLists.txt b/projects/fusilli-plugin/CMakeLists.txt new file mode 100644 index 00000000000..d56c0552477 --- /dev/null +++ b/projects/fusilli-plugin/CMakeLists.txt @@ -0,0 +1,35 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +cmake_minimum_required(VERSION 3.28) + +project(fusilli-plugin + VERSION 0.1.0 + DESCRIPTION "Fusilli-Plugin: A Fusilli/IREE powered hipDNN plugin for graph JIT compilation." + LANGUAGES C CXX) + +# Set C++ standard +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_SCAN_FOR_MODULES OFF) + +# Local Includes +list(APPEND CMAKE_MODULE_PATH + ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ +) +include(FusilliPluginDependencyUtils) + +# Options +option(FUSILLI_PLUGIN_USE_LOCAL_FUSILLI "Use local Fusilli build from ../sharkfuser" ON) + +# Dependencies +set(HIP_PLATFORM "amd") +find_package(hip REQUIRED) +fusilli_plugin_dependency(GTest GTEST_VERSION 1.16.0) +fusilli_plugin_dependency(IREERuntime) +fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 8a49655a7cc73f1c766256a7ead3a987e830efe9) +fusilli_plugin_dependency(Fusilli USE_LOCAL ${FUSILLI_PLUGIN_USE_LOCAL_FUSILLI}) diff --git a/projects/fusilli-plugin/README.md b/projects/fusilli-plugin/README.md new file mode 100644 index 00000000000..7c211381635 --- /dev/null +++ b/projects/fusilli-plugin/README.md @@ -0,0 +1,26 @@ +# Fusilli Plugin + +Fusilli-Plugin: A Fusilli/IREE powered hipDNN plugin for graph JIT compilation. + +:construction: **This project is under active development, many things don't work yet** :construction: + +The plugin builds as a shared library (`fusilli_plugin.so`) providing a `hipDNN` [kernel engine plugin](https://github.com/ROCm/hipDNN/blob/develop/docs/PluginDevelopment.md#creating-a-kernel-engine-plugin) [API](https://github.com/ROCm/hipDNN/blob/839cf6c4bc6fe403d0ef72cb5d7df004e2004743/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h). + +## Developer Guide + +### Setup + +For the time being, `fusilli-plugin` setup relies on / builds on [Fusilli setup](../sharkfuser/README.md#setup). +Keeping the projects in sync prevents "works on my machine" style bugs. +Requirements that are unique to `fusilli-plugin`, `hipDNN` and `googletest` for +example, are fetched configured and built as part of `fusilli-plugin` build. + +After following steps in [Fusilli Setup](../sharkfuser/README.md#setup), build and test +`fusilli-plugin` as follows: +```shell +$ cmake -GNinja -S. -Bbuild \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ +$ cmake --build build --target all +$ ctest --test-dir build +``` diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake new file mode 100644 index 00000000000..b887d4b7419 --- /dev/null +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -0,0 +1,194 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#===----------------------------------------------------------------------===# +# +# Provides correctly configured dependencies for fusilli-plugin build. +# +# Main entry point: +# fusilli_plugin_dependency(DEP_NAME [args...]) +# +# `fusilli_plugin_dependency` routes to lower level `_fetch_X` macros to +# actually fetch dependency `X`. Each `_fetch_X` macro preferentially +# `find_package`s installed/system versions of packages and falls back to +# vendoring dependencies in the build tree with `FetchContent`. +# +# Supported dependencies: GTest, hipdnn_frontend, Fusilli, IREERuntime +# +#===----------------------------------------------------------------------===# + +cmake_minimum_required(VERSION 3.25.2) + +include(FetchContent) + +# Provide a fusilli plugin dependency. `fusilli_plugin_dependency` will +# preferentially use system version (available through `find_package`) of a +# dependency, and fall back to building local copy with `FetchContent` + +# configuration. +# +# fusilli_plugin_dependency( +# DEP_NAME +# [...] +# ) +# +# DEP_NAME +# Supported dependencies: +# GTest +# hipdnn_frontend +# Fusilli +# IREERuntime +# +# +# The `_fetch_X` macro for dependency X defines the available options. +# Examples: GTEST_VERSION for GTest, HIP_DNN_HASH for hipdnn_frontend +# +function(fusilli_plugin_dependency DEP_NAME) + # Set indent for logging, any logs from dep "X" will be prefixed with [X]. + set(CMAKE_MESSAGE_INDENT "[${DEP_NAME}] ") + + # Route to appropriate _fetch_X macro. CMake macros aren't textual + # expansions like C preprocessor macros, so a dynamic call (like below) to a + # macro isn't a problem. + # Macro vs function: + # - macros execute in caller's scope and arguments are textually substituted + # - functions create a new scope and arguments are real variables + # - both functions and macros are executed at runtime + # + # WARNING: Logging below checks variables it expects a _fetch_X macro to set + # in this scope, requiring that _fetch_X is a macro and not a + # function. + if(COMMAND _fetch_${DEP_NAME}) + cmake_language(CALL _fetch_${DEP_NAME} ${ARGN}) + else() + set(CMAKE_MESSAGE_INDENT "") + message(FATAL_ERROR "Unknown dependency: ${DEP_NAME}") + endif() + + # reset indent. + set(CMAKE_MESSAGE_INDENT "") + + # FetchContent_MakeAvailable(DEP) creates a _POPULATED variable + # indicating the dependency was fetched rather than found on system. + # + # WARNING: FetchContent_Declare()/FetchContent_MakeAvailable() + # can use anything for the name argument, if the _fetch_X macro + # doesn't use ${DEP_NAME} the _POPULATED we're checking for + # here won't exist and the log may be misleading. + string(TOLOWER ${DEP_NAME} DEP_NAME_LOWER) + if (${DEP_NAME_LOWER}_POPULATED) + message(STATUS "${DEP_NAME} dependency populated via FetchContent") + message(STATUS " Source: ${${DEP_NAME_LOWER}_SOURCE_DIR}") + message(STATUS " Build: ${${DEP_NAME_LOWER}_BINARY_DIR}") + else() + message(STATUS "${DEP_NAME} dependency found on system via find_package") + message(STATUS " Config: ${${DEP_NAME}_DIR}") + endif() +endfunction() + +# GTest +# +# GTEST_VERSION +# Version tag of GTest +macro(_fetch_GTest) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "GTEST_VERSION" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + if(NOT DEFINED ARG_GTEST_VERSION) + message(FATAL_ERROR "GTEST_VERSION is required") + endif() + + FetchContent_Declare( + GTest + URL https://github.com/google/googletest/archive/refs/tags/v${ARG_GTEST_VERSION}.zip + ) + set(INSTALL_GTEST OFF) + set(BUILD_GMOCK OFF) + FetchContent_MakeAvailable(GTest) +endmacro() + +# hipdnn_frontend +# +# HIP_DNN_HASH +# Git commit hash or tag to fetch +macro(_fetch_hipdnn_frontend) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "HIP_DNN_HASH" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + if(NOT DEFINED ARG_HIP_DNN_HASH) + message(FATAL_ERROR "HIP_DNN_HASH is required") + endif() + + FetchContent_Declare( + hipdnn_frontend + GIT_REPOSITORY https://github.com/ROCm/hipDNN.git + GIT_TAG ${ARG_HIP_DNN_HASH} + # When FIND_PACKAGE_ARGS is passed, FetchContent_Declare tries to + # find_package an installed version before downloading. + FIND_PACKAGE_ARGS CONFIG + ) + + set(HIP_DNN_BUILD_BACKEND ON) + set(HIP_DNN_BUILD_FRONTEND ON) + set(HIP_DNN_SKIP_TESTS ON) + set(HIP_DNN_BUILD_PLUGINS OFF) + set(ENABLE_CLANG_TIDY OFF) + # PIC required to link static library into shared object. + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(hipdnn_frontend) +endmacro() + +# IREERuntime +# +# NOTE: For now, we're not providing a FetchContent fallback for IREERuntime. +# Fusilli expects that the system provides this dependency, and we're +# keeping the projects in sync as much as possible for now. If you're +# running in the fusilli docker container (described in sharkfuser README) +# passing -DIREERuntime_DIR=/workspace/.cache/docker/iree/build/lib/cmake/IREE +# should be enough. +macro(_fetch_IREERuntime) + find_package(IREERuntime CONFIG REQUIRED) +endmacro() + +# Fusilli +# +# USE_LOCAL +# If set, uses local source from ../sharkfuser directory. Without USE_LOCAL, +# requires system installation via find_package. +macro(_fetch_Fusilli) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "USE_LOCAL" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + + if(NOT DEFINED ARG_USE_LOCAL) + message(FATAL_ERROR "USE_LOCAL argument is required") + endif() + + if(NOT ARG_USE_LOCAL) + # For the time being we're keeping fusilli-plugin setup as in sync as + # possible with fusilli. + message(FATAL_ERROR "Only LOCAL builds supported currently") + endif() + + message(STATUS "Using local Fusilli build from ../sharkfuser") + FetchContent_Declare( + Fusilli + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../sharkfuser + ) + set(FUSILLI_BUILD_TESTS OFF) + FetchContent_MakeAvailable(Fusilli) +endmacro() From ce616f045c80df0875896a514326ebfdc6f1c84b Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Mon, 22 Sep 2025 12:50:26 -0600 Subject: [PATCH 02/14] [FusilliPlugin] Add local build option for `hipDNN` (#2295) This PR adds a convenience feature for development. When switching between `hipDNN` code and `fusilli-plugin` code it's nice to have one location for sources. --- .../cmake/FusilliPluginDependencyUtils.cmake | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake index b887d4b7419..668918a1d20 100644 --- a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -104,13 +104,13 @@ macro(_fetch_GTest) message(FATAL_ERROR "GTEST_VERSION is required") endif() - FetchContent_Declare( - GTest - URL https://github.com/google/googletest/archive/refs/tags/v${ARG_GTEST_VERSION}.zip + FetchContent_Declare( + GTest + URL https://github.com/google/googletest/archive/refs/tags/v${ARG_GTEST_VERSION}.zip ) - set(INSTALL_GTEST OFF) - set(BUILD_GMOCK OFF) - FetchContent_MakeAvailable(GTest) + set(INSTALL_GTEST OFF) + set(BUILD_GMOCK OFF) + FetchContent_MakeAvailable(GTest) endmacro() # hipdnn_frontend @@ -119,24 +119,35 @@ endmacro() # Git commit hash or tag to fetch macro(_fetch_hipdnn_frontend) cmake_parse_arguments( - ARG # prefix for parsed variables - "" # options (flags) - "HIP_DNN_HASH" # single-value arguments - "" # multi-value arguments + ARG # prefix for parsed variables + "" # options (flags) + "HIP_DNN_HASH;LOCAL_PATH" # single-value arguments + "" # multi-value arguments ${ARGN} ) - if(NOT DEFINED ARG_HIP_DNN_HASH) - message(FATAL_ERROR "HIP_DNN_HASH is required") + if(NOT DEFINED ARG_LOCAL_PATH AND NOT DEFINED ARG_HIP_DNN_HASH) + message(FATAL_ERROR "Required argument: one of LOCAL_PATH or HIP_DNN_HASH") endif() - FetchContent_Declare( - hipdnn_frontend - GIT_REPOSITORY https://github.com/ROCm/hipDNN.git - GIT_TAG ${ARG_HIP_DNN_HASH} - # When FIND_PACKAGE_ARGS is passed, FetchContent_Declare tries to - # find_package an installed version before downloading. - FIND_PACKAGE_ARGS CONFIG - ) + if(DEFINED ARG_LOCAL_PATH AND DEFINED ARG_HIP_DNN_HASH) + message(FATAL_ERROR "Argument error: passing both LOCAL_PATH and HIP_DNN_HASH is ambiguous.") + endif() + + if (DEFINED ARG_HIP_DNN_HASH) + FetchContent_Declare( + hipdnn_frontend + GIT_REPOSITORY https://github.com/ROCm/hipDNN.git + GIT_TAG ${ARG_HIP_DNN_HASH} + # When FIND_PACKAGE_ARGS is passed, FetchContent_Declare tries to + # find_package an installed version before downloading. + FIND_PACKAGE_ARGS CONFIG + ) + else() + FetchContent_Declare( + hipdnn_frontend + SOURCE_DIR ${ARG_LOCAL_PATH} + ) + endif() set(HIP_DNN_BUILD_BACKEND ON) set(HIP_DNN_BUILD_FRONTEND ON) @@ -181,7 +192,7 @@ macro(_fetch_Fusilli) if(NOT ARG_USE_LOCAL) # For the time being we're keeping fusilli-plugin setup as in sync as # possible with fusilli. - message(FATAL_ERROR "Only LOCAL builds supported currently") + message(FATAL_ERROR "Only LOCAL builds are supported currently") endif() message(STATUS "Using local Fusilli build from ../sharkfuser") From b5f8e4e0cfcbfbeeb3cd1a927cc019193a0b153a Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Thu, 25 Sep 2025 14:44:28 -0600 Subject: [PATCH 03/14] [FusilliPlugin] Stub plugin (`fusilli-plugin` 2 of N) (#2333) This PR adds "all the boring bits" of the fusilli plugin. It adds a basic engine plugin that hipDNN can load, which ensures that all of the required API definitions are present, and those APIs called as part of load function. The PR also tests for plugin loading, logging, and basic functionality. --- projects/fusilli-plugin/CMakeLists.txt | 26 ++ .../cmake/FusilliPluginDependencyUtils.cmake | 4 +- .../fusilli-plugin/src/fusilli_plugin.cpp | 409 ++++++++++++++++++ .../hipdnn_engine_plugin_execution_context.h | 29 ++ .../src/hipdnn_engine_plugin_handle.h | 54 +++ projects/fusilli-plugin/src/utils.h | 99 +++++ projects/fusilli-plugin/test/CMakeLists.txt | 30 ++ .../test/integration/CMakeLists.txt | 29 ++ .../integration/test_basic_integration.cpp | 78 ++++ .../test/test_fusilli_plugin_api.cpp | 160 +++++++ 10 files changed, 916 insertions(+), 2 deletions(-) create mode 100644 projects/fusilli-plugin/src/fusilli_plugin.cpp create mode 100644 projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h create mode 100644 projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h create mode 100644 projects/fusilli-plugin/src/utils.h create mode 100644 projects/fusilli-plugin/test/CMakeLists.txt create mode 100644 projects/fusilli-plugin/test/integration/CMakeLists.txt create mode 100644 projects/fusilli-plugin/test/integration/test_basic_integration.cpp create mode 100644 projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp diff --git a/projects/fusilli-plugin/CMakeLists.txt b/projects/fusilli-plugin/CMakeLists.txt index d56c0552477..785d5074b78 100644 --- a/projects/fusilli-plugin/CMakeLists.txt +++ b/projects/fusilli-plugin/CMakeLists.txt @@ -23,6 +23,10 @@ list(APPEND CMAKE_MODULE_PATH ) include(FusilliPluginDependencyUtils) +# Global constants +set(FUSILLI_PLUGIN_NAME fusilli_plugin) +set(FUSILLI_PLUGIN_ENGINE_ID 1001) + # Options option(FUSILLI_PLUGIN_USE_LOCAL_FUSILLI "Use local Fusilli build from ../sharkfuser" ON) @@ -33,3 +37,25 @@ fusilli_plugin_dependency(GTest GTEST_VERSION 1.16.0) fusilli_plugin_dependency(IREERuntime) fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 8a49655a7cc73f1c766256a7ead3a987e830efe9) fusilli_plugin_dependency(Fusilli USE_LOCAL ${FUSILLI_PLUGIN_USE_LOCAL_FUSILLI}) + +# Plugin definition +add_library(${FUSILLI_PLUGIN_NAME} SHARED + src/fusilli_plugin.cpp + src/utils.h + src/hipdnn_engine_plugin_handle.h + src/hipdnn_engine_plugin_execution_context.h +) +target_compile_options(${FUSILLI_PLUGIN_NAME} PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(${FUSILLI_PLUGIN_NAME} PRIVATE hipdnn_sdk hip::host fusilli::fusilli) +target_compile_definitions(${FUSILLI_PLUGIN_NAME} PRIVATE + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID=${FUSILLI_PLUGIN_ENGINE_ID} +) +set_target_properties(${FUSILLI_PLUGIN_NAME} PROPERTIES + CXX_VISIBILITY_PRESET hidden + LIBRARY_OUTPUT_DIRECTORY "${HIPDNN_BUILD_PLUGIN_ENGINE_DIR}" +) + +# Tests +enable_testing() +add_subdirectory(test) diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake index 668918a1d20..d6ec2eaca6e 100644 --- a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -4,7 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#===----------------------------------------------------------------------===# +#===------------------------------------------------------------------------===# # # Provides correctly configured dependencies for fusilli-plugin build. # @@ -18,7 +18,7 @@ # # Supported dependencies: GTest, hipdnn_frontend, Fusilli, IREERuntime # -#===----------------------------------------------------------------------===# +#===------------------------------------------------------------------------===# cmake_minimum_required(VERSION 3.25.2) diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp new file mode 100644 index 00000000000..c37f2164986 --- /dev/null +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -0,0 +1,409 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file is the main entry point for fusilli-plugin, implementations for all +// required hipDNN engine plugin API functions live here. +// +//===----------------------------------------------------------------------===// + +// hipDNN logging expects COMPONENT_NAME to be defined +#define COMPONENT_NAME FUSILLI_PLUGIN_NAME + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "hipdnn_engine_plugin_execution_context.h" +#include "hipdnn_engine_plugin_handle.h" +#include "utils.h" + +using namespace hipdnn_plugin; + +// TODO(#2317): ensure single source of truth for plugin version +static const char *fusilliPluginVersion = "0.0.1"; + +// s_lastError is thread_local static so can't be initialized in the header file +// as the header file is included in many context. Clear the string here. +thread_local char + PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] = + ""; + +extern "C" { + +// ---------------------------------------------------------------------- +// Implementations for the basic plugin API defined in +// hipDNN/sdk/include/hipdnn_sdk/plugin/PluginApi.h +// ---------------------------------------------------------------------- + +hipdnnPluginStatus_t hipdnnPluginGetName(const char **name) { + LOG_API_ENTRY("name_ptr={:p}", static_cast(name)); + FUSILLI_PLUGIN_CHECK_NULL(name); + + *name = FUSILLI_PLUGIN_NAME; + + LOG_API_SUCCESS_AUTO("pluginName={}", *name); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnPluginGetVersion(const char **version) { + LOG_API_ENTRY("version_ptr={:p}", static_cast(version)); + FUSILLI_PLUGIN_CHECK_NULL(version); + + *version = fusilliPluginVersion; + + LOG_API_SUCCESS_AUTO("version={}", *version); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnPluginGetType(hipdnnPluginType_t *type) { + LOG_API_ENTRY("type_ptr={:p}", static_cast(type)); + FUSILLI_PLUGIN_CHECK_NULL(type); + + *type = HIPDNN_PLUGIN_TYPE_ENGINE; + + LOG_API_SUCCESS_AUTO("type={}", *type); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +void hipdnnPluginGetLastErrorString(const char **error_str) { + if (error_str) { + *error_str = hipdnn_plugin::PluginLastErrorManager::getLastError(); + } +} + +// Once plugins are loaded via plugin manager then logging will work for them +hipdnnPluginStatus_t hipdnnPluginSetLoggingCallback(hipdnnCallback_t callback) { + // No LOG_API_ENTRY as logging won't be wired up yet. + FUSILLI_PLUGIN_CHECK_NULL(callback); + + hipdnn::logging::initializeCallbackLogging(FUSILLI_PLUGIN_NAME, callback); + + LOG_API_SUCCESS_AUTO("logging callback initialized"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---------------------------------------------------------------------- +// Implementations for engine plugin API defined in +// hipDNN/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h +// ---------------------------------------------------------------------- + +hipdnnPluginStatus_t hipdnnEnginePluginGetAllEngineIds(int64_t *engineIds, + uint32_t maxEngines, + uint32_t *numEngines) { + LOG_API_ENTRY("engineIds={:p}, maxEngines={}, numEngines={:p}", + static_cast(engineIds), maxEngines, + static_cast(numEngines)); + FUSILLI_PLUGIN_CHECK_NULL(numEngines); + if (maxEngines != 0) { + FUSILLI_PLUGIN_CHECK_NULL(engineIds); + } + + // Set `numEngines` regardless of how many engines are actually returned. + // The backend queries this function twice: + // - First call: engineIds=NULL, maxEngines=0 to get the count + // - Second call: engineIds allocated based on numEngines from first pass + *numEngines = 1; + + if (maxEngines >= 1) { + engineIds[0] = FUSILLI_PLUGIN_ENGINE_ID; + } + + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginCreate(hipdnnEnginePluginHandle_t *handle) { + LOG_API_ENTRY("handle_ptr={:p}", static_cast(handle)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + // According to runtime/src/iree/hal/driver_registry.h the underlying device + // creation methods should be thread safe, fusilli::Handle ensures that + // instance creation is thread safe, so this should be thread safe. + // TODO(#2335): handle multiple architectures + auto fusilliHandle = + FUSILLI_PLUGIN_TRY(fusilli::Handle::create(fusilli::Backend::GFX942)); + *handle = new HipdnnEnginePluginHandle(std::move(fusilliHandle)); + + LOG_API_SUCCESS_AUTO("createdHandle={:p}", static_cast(*handle)); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginDestroy(hipdnnEnginePluginHandle_t handle) { + LOG_API_ENTRY("handle={:p}", static_cast(handle)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + delete handle; + + LOG_API_SUCCESS_AUTO(""); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginSetStream(hipdnnEnginePluginHandle_t handle, + hipStream_t stream) { + LOG_API_ENTRY("handle={:p}, stream_id={:p}", static_cast(handle), + static_cast(stream)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + // TODO(#2151): Set stream on fusilli handle, or defer creation until stream + // is available and create handle around stream. Today fusilli handle creates + // a default IREE runtime device and execute programs on a stream associated + // with that device. The passed in stream is ignored. + + LOG_API_SUCCESS_AUTO(""); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginGetApplicableEngineIds( + hipdnnEnginePluginHandle_t handle, const hipdnnPluginConstData_t *opGraph, + int64_t *engineIds, uint32_t maxEngines, uint32_t *numEngines) { + LOG_API_ENTRY("handle={:p}, opGraph={:p}, engineIds={:p}, maxEngines={}, " + "numEngines={:p}", + static_cast(handle), static_cast(opGraph), + static_cast(engineIds), maxEngines, + static_cast(numEngines)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + if (maxEngines != 0) { + FUSILLI_PLUGIN_CHECK_NULL(engineIds); + } + FUSILLI_PLUGIN_CHECK_NULL(numEngines); + + *numEngines = 0; + if (maxEngines < 1) { + HIPDNN_LOG_INFO( + "Maximum number of engines reached ({}), ignoring additional " + "engines, numEngines count: {}", + maxEngines, *numEngines); + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // TODO: check graph for supported fusilli operations, return + // FUSILLI_PLUGIN_ENGINE_ID if graph can be supported. + + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginGetEngineDetails(hipdnnEnginePluginHandle_t handle, + int64_t engineId, + const hipdnnPluginConstData_t *opGraph, + hipdnnPluginConstData_t *engineDetails) { + // ---------------------------------------------------------------------- + // Plugin API call flow for engine configuration and execution. + // + // hipDNN Plugin + // ====================================================================== + // hipdnnEnginePluginGetEngineDetails -> populates engineDetails + // (flatbuffer object) with + // behavioral notes + knob + // definitions that are available + // to the higher level API. + // Return populated engineDetails + // <- (hipdnnPluginConstData_t). + // + // Decides final configuration, populating ~~ + // engineConfig flatbuffer + // (hipdnnPluginConstData_t) based on info + // provided in engineDetails. + // + // hipdnnEnginePluginCreateExecutionContext -> Creates execution context + // (hipdnnEnginePluginExecutionContext_t) + // <- based on engineConfig. + // + // Uses returned execution context to ~~ + // invoke kernels + // + // hipdnnEnginePluginDestroyEngineDetails -> cleans up engine details. + // + // hipdnnEnginePluginDestroyExecutionContext -> cleans up execution context. + // ---------------------------------------------------------------------- + + LOG_API_ENTRY("handle={:p}, engineId={}, opGraph={:p}, engineDetails={:p}", + static_cast(handle), engineId, + static_cast(opGraph), + static_cast(engineDetails)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails); + + if (engineId != FUSILLI_PLUGIN_ENGINE_ID) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, "unexpected engine id"); + } + + // Build engine details object, we're only storing the engine id for the time + // being. + flatbuffers::FlatBufferBuilder builder; + auto engineDetailsObj = + hipdnn_sdk::data_objects::CreateEngineDetails(builder, engineId); + builder.Finish(engineDetailsObj); + + // Populate out parameter. + auto detachedBuffer = + std::make_unique(builder.Release()); + engineDetails->ptr = detachedBuffer->data(); + engineDetails->size = detachedBuffer->size(); + + // Store owning pointer in handle, hipdnnEnginePluginDestroyEngineDetails will + // inform us when it's safe to clean this up. + handle->storeEngineDetailsBuffer(engineDetails->ptr, + std::move(detachedBuffer)); + + LOG_API_SUCCESS_AUTO("engineDetails->ptr={:p}", engineDetails->ptr); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginDestroyEngineDetails(hipdnnEnginePluginHandle_t handle, + hipdnnPluginConstData_t *engineDetails) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY("handle={:p}, engineDetails={:p}", static_cast(handle), + static_cast(engineDetails)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails->ptr); + + // Deallocate engine details. + handle->eraseEngineDetailsBuffer(engineDetails->ptr); + engineDetails->ptr = nullptr; + engineDetails->size = 0; + + LOG_API_SUCCESS_AUTO("engineDetails->ptr={:p}", engineDetails->ptr); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginGetWorkspaceSize(hipdnnEnginePluginHandle_t handle, + const hipdnnPluginConstData_t *engineConfig, + const hipdnnPluginConstData_t *opGraph, + size_t *workspaceSize) { + LOG_API_ENTRY( + "handle={:p}, engineConfig={:p}, opGraph={:p}, workspaceSize={:p}", + static_cast(handle), static_cast(engineConfig), + static_cast(opGraph), static_cast(workspaceSize)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineConfig); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(workspaceSize); + + // TODO(#2309): for now we're focusing on kernels that don't require scratch + // buffer space. Eventually we will need to teach IREE to report what scratch + // buffer space required, and how to use a passed in pre-allocated scratch + // space rather than a runtime allocated scratch space. + *workspaceSize = 0; + + LOG_API_SUCCESS_AUTO("workspaceSize={}", *workspaceSize); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginCreateExecutionContext( + hipdnnEnginePluginHandle_t handle, + const hipdnnPluginConstData_t *engineConfig, + const hipdnnPluginConstData_t *opGraph, + hipdnnEnginePluginExecutionContext_t *executionContext) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY( + "handle={:p}, engineConfig={:p}, opGraph={:p}, executionContext={:p}", + static_cast(handle), static_cast(engineConfig), + static_cast(opGraph), + static_cast(executionContext)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineConfig); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + + // Ensure that config contains expected engine id. + hipdnn_plugin::EngineConfigWrapper engineConfigWrapper(engineConfig->ptr, + engineConfig->size); + if (engineConfigWrapper.engineId() != FUSILLI_PLUGIN_ENGINE_ID) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, "unexpected engine id"); + } + + // TODO: Implement graph compilation + // This is a stub plugin, the full implementation would: + // 1. Create and compile a fusilli graph from the opGraph + // 2. Store tensor mappings (uid to fusilli tensor attributes) + // 3. Store the compiled graph in the execution context + *executionContext = new HipdnnEnginePluginExecutionContext{}; + + LOG_API_SUCCESS_AUTO("created_execution_context={:p}", + static_cast(*executionContext)); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginDestroyExecutionContext( + hipdnnEnginePluginHandle_t handle, + hipdnnEnginePluginExecutionContext_t executionContext) { + LOG_API_ENTRY("handle={:p}, executionContext={:p}", + static_cast(handle), + static_cast(executionContext)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + + delete executionContext; + + LOG_API_SUCCESS_AUTO("destroyed executionContext"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( + hipdnnEnginePluginHandle_t handle, + hipdnnEnginePluginExecutionContext_t executionContext, void *workspace, + const hipdnnPluginDeviceBuffer_t *deviceBuffers, + uint32_t numDeviceBuffers) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY( + "handle={:p}, executionContext={:p}, workspace={:p}, deviceBuffers={:p}, " + "numDeviceBuffers={}", + static_cast(handle), static_cast(executionContext), + workspace, static_cast(deviceBuffers), numDeviceBuffers); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + FUSILLI_PLUGIN_CHECK_NULL(deviceBuffers); + + // TODO: Implement graph execution. + // This is a stub plugin, the full implementation would: + // 1. Map device buffers to fusilli tensor attributes based on uid mapping + // stored on executionContext. + // 2. Create IREE buffer views from HIP device pointers. + // 3. Execute the compiled graph. + + LOG_API_SUCCESS_AUTO("executed graph"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +} // extern "C" diff --git a/projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h b/projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h new file mode 100644 index 00000000000..004a7b52093 --- /dev/null +++ b/projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h @@ -0,0 +1,29 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin's definition of +// HipdnnEnginePluginExecutionContext. To hipDNN this type is opaque, it deals +// in hipdnnEnginePluginExecutionContext_t which is a pointer to the undefined +// HipdnnEnginePluginExecutionContext. Each plugin must define +// HipdnnEnginePluginExecutionContext in order to create something when hipDNN +// asks for an execution context. +// +// The execution context should store what's needed to execute a given kernel +// (plan in hipDNN parlance) in a hot loop without any overhead. For fusilli +// plugin, that maps to constructing and storing a fusilli::Graph based on +// hipDNN graph. When an execution is requested, it should be a simple lookup +// for uid -> tensor attribute, then a graph execution. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H +#define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H + +struct HipdnnEnginePluginExecutionContext {}; + +#endif // FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H diff --git a/projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h b/projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h new file mode 100644 index 00000000000..80903b19b4d --- /dev/null +++ b/projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h @@ -0,0 +1,54 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin's definition of +// HipdnnEnginePluginHandle. To hipDNN this type is opaque, it deals in +// hipdnnEnginePluginHandle_t which is a pointer to the undefined +// HipdnnEnginePluginHandle. Each plugin must define HipdnnEnginePluginHandle in +// order to create something when hipDNN asks for an plugin handle. +// +// HipdnnEnginePluginHandle stores any persistent data associated with a +// particular engine plugin. In fusilli plugin that's the fusilli::Handle, and +// some temporary buffers that higher level APIs create and destroy at different +// times. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H +#define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H + +#include "fusilli/backend/handle.h" +#include +#include +#include + +struct HipdnnEnginePluginHandle { +public: + fusilli::Handle fusilliHandle; + + HipdnnEnginePluginHandle(fusilli::Handle &&handle) + : fusilliHandle(std::move(handle)) {} + + // Take ownership of a flatbuffers::DetachedBuffer and store it associated + // with its memory address. + void storeEngineDetailsBuffer( + const void *ptr, std::unique_ptr &&buffer) { + _engineDetailsBuffers[ptr] = std::move(buffer); + } + + // Destroy the flatbuffers::DetachedBuffer associated with ptr. + void eraseEngineDetailsBuffer(const void *ptr) { + _engineDetailsBuffers.erase(ptr); + } + +private: + std::unordered_map> + _engineDetailsBuffers; +}; + +#endif // FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H diff --git a/projects/fusilli-plugin/src/utils.h b/projects/fusilli-plugin/src/utils.h new file mode 100644 index 00000000000..7b56ad305c3 --- /dev/null +++ b/projects/fusilli-plugin/src/utils.h @@ -0,0 +1,99 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin utils and macros. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_UTILS_H +#define FUSILLI_PLUGIN_SRC_UTILS_H + +#include +#include +#include + +// Checks for null, sets the plugin last error manager and returns error if +// null. +// +// SIDE EFFECT: any util function returning an `hipdnnPluginStatus_t` is +// intended for error checking and reporting, and therefore sets +// PluginLastErrorManager::setLastError to an appropriate error on the unhappy +// path. +template hipdnnPluginStatus_t isNull(T *value) { + if (value == nullptr) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, + std::string(typeid(T).name()) + " is nullptr"); + } + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// If null, set plugin error manager last error to +// HIPDNN_PLUGIN_STATUS_BAD_PARAM and return said error from the enclosing +// scope. +#define FUSILLI_PLUGIN_CHECK_NULL(X) \ + do { \ + if (hipdnnPluginStatus_t status = isNull(X); \ + status != HIPDNN_PLUGIN_STATUS_SUCCESS) { \ + return status; \ + } \ + } while (false) + +// LOG_API_SUCCESS from hipDNN, but deducing the enclosing function rather than +// passing the function name. +#define LOG_API_SUCCESS_AUTO(format, ...) \ + LOG_API_SUCCESS(__func__, format, __VA_ARGS__) + +// Unwrap the value returned from an expression that evaluates to a +// fusilli::ErrorOr. In the unhappy path set plugin error manager last error to +// HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR and return said error from the enclosing +// scope. +// +// Usage: +// fusilli::ErrorOr getString(); +// +// hipdnnPluginStatus_t processString() { +// // Either gets the string or returns error. +// std::string str = FUSILLI_PLUGIN_TRY(getString()); +// doSomethingImportant(str); +// return HIPDNN_PLUGIN_STATUS_SUCCESS; +// } +#define FUSILLI_PLUGIN_TRY(expr) \ + ({ \ + auto errorOr = (expr); \ + if (fusilli::isError(errorOr)) { \ + return hipdnn_plugin::PluginLastErrorManager::setLastError( \ + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, \ + fusilli::ErrorObject(errorOr).getMessage()); \ + } \ + std ::move(*errorOr); \ + }) + +// Set plugin error manager last error and return failed status from enclosing +// scope if expression evaluates to a fusilli::ErrorObject in an error state; or +// in the case of fusilli::ErrorOr is convertible to an fusilli::ErrorObject +// in an error state. +// +// Usage: +// fusilli::ErrorObject doBar(); +// +// hipdnnPluginStatus_t doFoo() { +// // Returns error if doBar() fails +// FUSILLI_PLUGIN_CHECK_ERROR(doBar()); +// return HIPDNN_PLUGIN_STATUS_SUCCESS; +// } +#define FUSILLI_PLUGIN_CHECK_ERROR(expr) \ + do { \ + fusilli::ErrorObject err = (expr); \ + if (isError(err)) { \ + return hipdnn_plugin::PluginLastErrorManager::setLastError( \ + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, err.getMessage()); \ + } \ + } while (false) + +#endif // FUSILLI_PLUGIN_SRC_UTILS_H diff --git a/projects/fusilli-plugin/test/CMakeLists.txt b/projects/fusilli-plugin/test/CMakeLists.txt new file mode 100644 index 00000000000..32a67c35c8a --- /dev/null +++ b/projects/fusilli-plugin/test/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(GoogleTest) +find_package(Threads REQUIRED) + +# Plugin API tests +add_executable(test_fusilli_plugin_api + test_fusilli_plugin_api.cpp +) +target_compile_options(test_fusilli_plugin_api PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(test_fusilli_plugin_api PRIVATE + fusilli_plugin # Link plugin directly + GTest::gtest_main + hip::host + hipdnn_sdk + Threads::Threads +) +target_compile_definitions(test_fusilli_plugin_api PRIVATE + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID=${FUSILLI_PLUGIN_ENGINE_ID} +) +# Register with CTest +gtest_discover_tests(test_fusilli_plugin_api) + +# Integration tests +add_subdirectory(integration) diff --git a/projects/fusilli-plugin/test/integration/CMakeLists.txt b/projects/fusilli-plugin/test/integration/CMakeLists.txt new file mode 100644 index 00000000000..3e20a22fab9 --- /dev/null +++ b/projects/fusilli-plugin/test/integration/CMakeLists.txt @@ -0,0 +1,29 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(GoogleTest) +find_package(Threads REQUIRED) + +# Integration tests +add_executable(fusilli_plugin_integration_tests + # main.cpp + test_basic_integration.cpp +) +target_compile_options(fusilli_plugin_integration_tests PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(fusilli_plugin_integration_tests PRIVATE + GTest::gtest_main + hip::host + hipdnn_frontend + hipdnn_sdk + Threads::Threads +) +target_compile_definitions(fusilli_plugin_integration_tests PRIVATE + FUSILLI_PLUGIN_DIR="${HIPDNN_BUILD_PLUGIN_ENGINE_DIR}" + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID="${FUSILLI_PLUGIN_ENGINE_ID}" +) +# Register with CTest +gtest_discover_tests(fusilli_plugin_integration_tests) diff --git a/projects/fusilli-plugin/test/integration/test_basic_integration.cpp b/projects/fusilli-plugin/test/integration/test_basic_integration.cpp new file mode 100644 index 00000000000..391cc80e35b --- /dev/null +++ b/projects/fusilli-plugin/test/integration/test_basic_integration.cpp @@ -0,0 +1,78 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include +#include +#include + +std::vector getLoadedPlugins(hipdnnHandle_t handle) { + size_t numPlugins = 0; + size_t maxPathLength = 0; + auto status = hipdnnGetLoadedEnginePluginPaths_ext(handle, &numPlugins, + nullptr, &maxPathLength); + + if (status != HIPDNN_STATUS_SUCCESS) { + throw std::runtime_error("Failed to get loaded plugin paths"); + } + + if (numPlugins == 0) { + return {}; + } + + std::vector> pathBuffers(numPlugins, + std::vector(maxPathLength)); + std::vector pluginPathsC(numPlugins); + for (size_t i = 0; i < numPlugins; ++i) { + pluginPathsC[i] = pathBuffers[i].data(); + } + + status = hipdnnGetLoadedEnginePluginPaths_ext( + handle, &numPlugins, pluginPathsC.data(), &maxPathLength); + if (status != HIPDNN_STATUS_SUCCESS) { + throw std::runtime_error("Failed to get loaded plugin paths"); + } + + std::vector pluginPaths; + pluginPaths.reserve(numPlugins); + for (size_t i = 0; i < numPlugins; ++i) { + pluginPaths.emplace_back(pluginPathsC[i]); + } + return pluginPaths; +} + +TEST(IntegrationTests, PluginLoad) { + // Uncomment if you want debug logging info. + // setenv("HIPDNN_LOG_LEVEL", "info", 1); + + // Ensure hipDNN will load fusilli plugin. + const std::array paths = {FUSILLI_PLUGIN_DIR}; + hipdnnStatus_t status = hipdnnSetEnginePluginPaths_ext( + paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE); + EXPECT_EQ(status, HIPDNN_STATUS_SUCCESS); + + // Stand up enough of hipDNN to load plugins. + hipdnnHandle_t handle = nullptr; + status = hipdnnCreate(&handle); + ASSERT_EQ(status, HIPDNN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // If fusilli plugin fails to define a required method it will fail to load. + auto loadedPlugins = getLoadedPlugins(handle); + EXPECT_EQ(loadedPlugins.size(), 1); + + // Check that fusilli plugin did load. + auto expectedPath = std::filesystem::path(FUSILLI_PLUGIN_DIR) / + std::format("lib{}.so", FUSILLI_PLUGIN_NAME); + EXPECT_TRUE(std::ranges::any_of( + loadedPlugins, [&expectedPath](const std::string &loadedPluginPath) { + return std::filesystem::canonical(loadedPluginPath) == expectedPath; + })); + + EXPECT_EQ(hipdnnDestroy(handle), HIPDNN_STATUS_SUCCESS); +} diff --git a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp new file mode 100644 index 00000000000..6eb088d956d --- /dev/null +++ b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp @@ -0,0 +1,160 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +bool loggingCallbackCalled = false; +std::vector capturedLogMessages; +std::vector capturedLogSeverities; +std::mutex logMutex; +std::condition_variable logConditionVariable; + +void testLoggingCallback(hipdnnSeverity_t severity, const char *msg) { + // hipDNN sets spdlog up to log in a separate thread, so we need to put our + // mutual exclusion gloves on before touching any variables the main thread + // does. + std::scoped_lock lock(logMutex); + + loggingCallbackCalled = true; + if (msg) { + capturedLogMessages.push_back(std::string(msg)); + capturedLogSeverities.push_back(severity); + } + logConditionVariable.notify_one(); +} + +TEST(TestFusilliPluginApi, Logging) { + // Set tracking variables + { + std::scoped_lock lock(logMutex); + loggingCallbackCalled = false; + capturedLogMessages.clear(); + capturedLogSeverities.clear(); + } + + // Set up logging callback + ASSERT_EQ(hipdnnPluginSetLoggingCallback(testLoggingCallback), + HIPDNN_PLUGIN_STATUS_SUCCESS); + + std::unique_lock lock(logMutex); + + // Wait for the logging callback to signal that it has been called. + auto timeout = std::chrono::steady_clock::now() + std::chrono::seconds(5); + EXPECT_TRUE(logConditionVariable.wait_until( + lock, timeout, [&]() { return loggingCallbackCalled; })); + + EXPECT_TRUE(loggingCallbackCalled); + EXPECT_FALSE(capturedLogMessages.empty()); + EXPECT_TRUE(capturedLogMessages.front().find( + "logging callback initialized") != std::string::npos); +}; + +TEST(TestFusilliPluginApi, GetNameSuccess) { + const char *name = nullptr; + EXPECT_EQ(hipdnnPluginGetName(&name), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_STREQ(name, FUSILLI_PLUGIN_NAME); +} + +TEST(TestFusilliPluginApi, GetNameNullptr) { + EXPECT_EQ(hipdnnPluginGetName(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetVersionSuccess) { + const char *version = nullptr; + EXPECT_EQ(hipdnnPluginGetVersion(&version), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(version, nullptr); + // TODO(#2317): check returned version against single source of truth. +} + +TEST(TestFusilliPluginApi, GetVersionNullptr) { + EXPECT_EQ(hipdnnPluginGetVersion(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetTypeSuccess) { + hipdnnPluginType_t type; + EXPECT_EQ(hipdnnPluginGetType(&type), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(type, HIPDNN_PLUGIN_TYPE_ENGINE); +} + +TEST(TestFusilliPluginApi, GetTypeNullptr) { + EXPECT_EQ(hipdnnPluginGetType(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetLastErrorStringSuccess) { + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + // Initially should be empty or contain a previous error + EXPECT_GE(strlen(errorStr), 0); +} + +TEST(TestFusilliPluginApi, GetLastErrorStringNullptr) { + // This should not crash even with nullptr + EXPECT_NO_THROW(hipdnnPluginGetLastErrorString(nullptr)); +} + +TEST(TestFusilliPluginApi, SetLoggingCallbackNullptr) { + // Setting nullptr should return BAD_PARAM + EXPECT_EQ(hipdnnPluginSetLoggingCallback(nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetAllEngineIds) { + // First call with null buffer to get count + uint32_t numEngines = 0; + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(nullptr, 0, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(numEngines, 1); + + // Second call to get actual engine IDs + std::vector engineIds(numEngines); + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(engineIds.data(), numEngines, + &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(numEngines, 1); + EXPECT_EQ(engineIds[0], FUSILLI_PLUGIN_ENGINE_ID); +} + +TEST(TestFusilliPluginApi, GetAllEngineIdsNullNumEngines) { + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(nullptr, 0, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + EXPECT_GT(strlen(errorStr), 0u); +} From a0a320ba86e3767724e7b6d53f90f279da1d8430 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Mon, 29 Sep 2025 17:23:39 -0600 Subject: [PATCH 04/14] [FusilliPlugin] Graph conversion (`fusilli-plugin` 3 of N) (#2365) This PR implements [`hipdnnEnginePluginGetApplicableEngineIds`](https://github.com/ROCm/rocm-libraries/blob/7e29d65a8db488a5e038c3616398144d7c2289a5/projects/hipdnn/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h#L90-L115) and [`hipdnnEnginePluginCreateExecutionContext`](https://github.com/ROCm/rocm-libraries/blob/7e29d65a8db488a5e038c3616398144d7c2289a5/projects/hipdnn/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h#L170-L191), which together completes `hipDNN` -> `fusilli` graph translation for the initial plugin goal of e2e run of single node ConvFprop graph. --- projects/fusilli-plugin/CMakeLists.txt | 6 +- .../fusilli-plugin/include/graph_import.h | 254 ++++++++++++++++++ .../hipdnn_engine_plugin_execution_context.h | 12 +- .../hipdnn_engine_plugin_handle.h | 0 .../fusilli-plugin/{src => include}/utils.h | 0 .../fusilli-plugin/src/fusilli_plugin.cpp | 64 ++++- projects/fusilli-plugin/test/CMakeLists.txt | 16 +- .../test/integration/CMakeLists.txt | 1 - .../test/test_fusilli_plugin_api.cpp | 210 +++++++++++++++ .../fusilli-plugin/test/test_graph_import.cpp | 40 +++ projects/fusilli-plugin/test/utils.h | 29 ++ 11 files changed, 618 insertions(+), 14 deletions(-) create mode 100644 projects/fusilli-plugin/include/graph_import.h rename projects/fusilli-plugin/{src => include}/hipdnn_engine_plugin_execution_context.h (80%) rename projects/fusilli-plugin/{src => include}/hipdnn_engine_plugin_handle.h (100%) rename projects/fusilli-plugin/{src => include}/utils.h (100%) create mode 100644 projects/fusilli-plugin/test/test_graph_import.cpp create mode 100644 projects/fusilli-plugin/test/utils.h diff --git a/projects/fusilli-plugin/CMakeLists.txt b/projects/fusilli-plugin/CMakeLists.txt index 785d5074b78..2f8a4b9678f 100644 --- a/projects/fusilli-plugin/CMakeLists.txt +++ b/projects/fusilli-plugin/CMakeLists.txt @@ -38,12 +38,12 @@ fusilli_plugin_dependency(IREERuntime) fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 8a49655a7cc73f1c766256a7ead3a987e830efe9) fusilli_plugin_dependency(Fusilli USE_LOCAL ${FUSILLI_PLUGIN_USE_LOCAL_FUSILLI}) +# Includes +include_directories(include) + # Plugin definition add_library(${FUSILLI_PLUGIN_NAME} SHARED src/fusilli_plugin.cpp - src/utils.h - src/hipdnn_engine_plugin_handle.h - src/hipdnn_engine_plugin_execution_context.h ) target_compile_options(${FUSILLI_PLUGIN_NAME} PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) target_link_libraries(${FUSILLI_PLUGIN_NAME} PRIVATE hipdnn_sdk hip::host fusilli::fusilli) diff --git a/projects/fusilli-plugin/include/graph_import.h b/projects/fusilli-plugin/include/graph_import.h new file mode 100644 index 00000000000..3435dbf8d12 --- /dev/null +++ b/projects/fusilli-plugin/include/graph_import.h @@ -0,0 +1,254 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains facilities for converting hipDNN serialized graphs to +// fusilli graphs. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H +#define FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "fusilli/support/logging.h" +#include "hipdnn_engine_plugin_execution_context.h" + +// Convert from hipDNN DataType to fusilli DataType +inline fusilli::ErrorOr +hipDnnDataTypeToFusilliDataType(hipdnn_sdk::data_objects::DataType hipdnnType) { + switch (hipdnnType) { + case hipdnn_sdk::data_objects::DataType::HALF: + return ok(fusilli::DataType::Half); + case hipdnn_sdk::data_objects::DataType::BFLOAT16: + return ok(fusilli::DataType::BFloat16); + case hipdnn_sdk::data_objects::DataType::FLOAT: + return ok(fusilli::DataType::Float); + case hipdnn_sdk::data_objects::DataType::DOUBLE: + return ok(fusilli::DataType::Double); + case hipdnn_sdk::data_objects::DataType::UINT8: + return ok(fusilli::DataType::Uint8); + case hipdnn_sdk::data_objects::DataType::INT32: + return ok(fusilli::DataType::Int32); + case hipdnn_sdk::data_objects::DataType::UNSET: + return ok(fusilli::DataType::NotSet); + default: + return error(fusilli::ErrorCode::RuntimeFailure, + "Unknown type in hipdnn -> fusilli graph translation."); + } +} + +// Graph import is done through importGraph function, this class exists for +// organization and is used by importGraph. +// +// Graph import is designed around individual Node import functions (such as +// importConvFPropAttr) which convert a given node type, and track input and +// output tensors in shared state (via importNodeInput and importNodeOutput +// functions). Graph nodes are processed in topological order to ensure that +// outputs of producer nodes are tracked and available for consuming nodes. +// +// NOTE: input hipDNN graph .node()s may not be in topological order. There's +// plans for a topological sort method but that's not available yet. As we're +// only handling single-node graphs currently, that's not a problem. +class GraphImport { +private: + friend fusilli::ErrorOr + importGraph(const hipdnnPluginConstData_t *opGraph); + + // The imported graph. + fusilli::Graph fusilliGraph; + + // Maps hipDNN tensor UIDs to fusilli::TensorAttrs for graph boundary tensors + // (inputs and outputs). Used by hipdnnEnginePluginExecuteOpGraph to match + // incoming device buffers (identified by UID) to their corresponding + // fusilli::TensorAttr. + // + // All tensors in this map should be non-virtual boundary tensors---a virtual + // tensor is an internal intermediate---internal tensors will be tracked + // elsewhere. Currently we only support single node graphs, so there are no + // virtual tensors to track. + std::unordered_map> + uidToIOTensor; + + // Helper class for reading from flatbuffer. + hipdnn_plugin::GraphWrapper opGraphWrapper; + + GraphImport(const hipdnnPluginConstData_t *opGraph) + : opGraphWrapper(opGraph->ptr, opGraph->size) {} + + fusilli::ErrorObject importGraph() { + const hipdnn_sdk::data_objects::Graph &hipDnnGraph = + opGraphWrapper.getGraph(); + + // Import graph level properties. + fusilliGraph.setName(hipDnnGraph.name()->str()) + .setIODataType( + FUSILLI_TRY(hipDnnDataTypeToFusilliDataType(hipDnnGraph.io_type()))) + .setIntermediateDataType(FUSILLI_TRY( + hipDnnDataTypeToFusilliDataType(hipDnnGraph.intermediate_type()))) + .setComputeDataType(FUSILLI_TRY( + hipDnnDataTypeToFusilliDataType(hipDnnGraph.compute_type()))); + + return importNodes(); + } + + // Import all graph nodes. + fusilli::ErrorObject importNodes() { + if (opGraphWrapper.nodeCount() > 1) + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Multi-node graphs not supported currently."); + for (size_t i = 0; i < opGraphWrapper.nodeCount(); ++i) { + const hipdnn_sdk::data_objects::Node &node = opGraphWrapper.getNode(i); + FUSILLI_CHECK_ERROR(importNode(node)); + } + + return fusilli::ok(); + } + + // Import single graph node. + fusilli::ErrorObject importNode(const hipdnn_sdk::data_objects::Node &node) { + switch (node.attributes_type()) { + case hipdnn_sdk::data_objects::NodeAttributes::ConvolutionFwdAttributes: + FUSILLI_CHECK_ERROR( + importConvFPropAttr(node.attributes_as_ConvolutionFwdAttributes())); + break; + default: + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Unsupported node type."); + } + return fusilli::ok(); + } + + fusilli::ErrorObject + importConvFPropAttr(const hipdnn_sdk::data_objects::ConvolutionFwdAttributes + *hipDnnConvFwdAttr) { + // Import node inputs. + std::shared_ptr x = + FUSILLI_TRY(importNodeInput(hipDnnConvFwdAttr->x_tensor_uid(), "x")); + std::shared_ptr w = + FUSILLI_TRY(importNodeInput(hipDnnConvFwdAttr->w_tensor_uid(), "w")); + + // hipdnnEnginePluginGetApplicableEngineIds should have already eliminated + // any nodes with asymmetric padding, this is just a double check. + if (!std::ranges::equal(*hipDnnConvFwdAttr->pre_padding(), + *hipDnnConvFwdAttr->post_padding())) // C++ 20 + return fusilli::error(fusilli::ErrorCode::AttributeNotSet, + "Conv node with asymmetric padding found."); + // Import node. + auto fusilliConvFwdAttr = + fusilli::ConvFPropAttr() + .setPadding(*hipDnnConvFwdAttr->post_padding()) + .setStride(*hipDnnConvFwdAttr->stride()) + .setDilation(*hipDnnConvFwdAttr->dilation()) + .setName("conv_fprop"); + std::shared_ptr y = + fusilliGraph.convFProp(x, w, fusilliConvFwdAttr); + + // Import node output. + FUSILLI_CHECK_ERROR( + importNodeOutput(hipDnnConvFwdAttr->y_tensor_uid(), "y", y)); + + return fusilli::ok(); + } + + // Import, and track, node input tensor. Node input tensor is created in the + // case of a boundary tensor, and read from shared state otherwise. + fusilli::ErrorOr> + importNodeInput(int64_t uid, const char *name) { + // Get hipDNN tensor. TensorMap is created from the graph that uid variable + // is read from, so .at() call should be safe. + const hipdnn_sdk::data_objects::TensorAttributes *hipDnnTensorAttr = + opGraphWrapper.getTensorMap().at(uid); + + // A virtual node indicates a non-boundary node. + if (hipDnnTensorAttr->virtual_()) { + // When multi-op graphs are supported, we would look up the output of a + // previously imported node and return it here. + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Virtual inputs currently unsupported."); + } + + // Import new tensor. + auto fusilliTensorAttr = fusilli::TensorAttr().setName( + std::format("{}_{}", name, uid)); // C++ 20 + FUSILLI_CHECK_ERROR(importAttrs(fusilliTensorAttr, hipDnnTensorAttr)); + std::shared_ptr graphInput = + fusilliGraph.tensor(fusilliTensorAttr); + + // Track boundary tensor. + uidToIOTensor[uid] = graphInput; + + return ok(graphInput); + }; + + // Import and track node output tensor. + fusilli::ErrorObject + importNodeOutput(int64_t uid, const char *name, + const std::shared_ptr &nodeOutput) { + // Get hipDNN tensor. TensorMap is created from the graph that uid variable + // is read from, so .at() call should be safe. + const hipdnn_sdk::data_objects::TensorAttributes *hipDnnTensorAttr = + opGraphWrapper.getTensorMap().at(uid); + + // Import attrs. + nodeOutput->setName(std::format("{}_{}", name, uid)); // C++ 20 + FUSILLI_CHECK_ERROR(importAttrs(*nodeOutput, hipDnnTensorAttr)); + + // A virtual node indicates a non-boundary node. + if (hipDnnTensorAttr->virtual_()) { + // This tensor is an input to nodes farther down the topological sort. + // When multi-op graphs are supported, we would track UID -> node output + // tensor in a non-IO tensor map. + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Virtual outputs currently unsupported. An output " + "tensor may have not been marked as an output."); + } + + // Track boundary node. + uidToIOTensor[uid] = nodeOutput; + + return fusilli::ok(); + }; + + // Import all tensor attrs src -> dest. + fusilli::ErrorObject + importAttrs(fusilli::TensorAttr &dest, + const hipdnn_sdk::data_objects::TensorAttributes *src) { + dest.setIsVirtual(src->virtual_()) + .setDim(*src->dims()) + .setStride(*src->strides()) + .setDataType( + FUSILLI_TRY(hipDnnDataTypeToFusilliDataType(src->data_type()))); + return fusilli::ok(); + } +}; + +// Given a hipDNN serialized graph, return imported fusilli::Graph and UID -> +// fusilli::TensorAttr map for IO tensors. +// +// NOTE: HipdnnEnginePluginExecutionContext used as return type because it +// contains (only) the exact required fields. If it requires more members in +// the future it's probably worth creating a new data transmission type. +inline fusilli::ErrorOr +importGraph(const hipdnnPluginConstData_t *opGraph) { + auto gc = GraphImport(opGraph); + FUSILLI_CHECK_ERROR(gc.importGraph()); + return HipdnnEnginePluginExecutionContext{.graph = std::move(gc.fusilliGraph), + .uidToFusilliTensorAttr = + std::move(gc.uidToIOTensor)}; +} + +#endif // FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H diff --git a/projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h similarity index 80% rename from projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h rename to projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h index 004a7b52093..9c7c8fde39b 100644 --- a/projects/fusilli-plugin/src/hipdnn_engine_plugin_execution_context.h +++ b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h @@ -24,6 +24,16 @@ #ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H #define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H -struct HipdnnEnginePluginExecutionContext {}; +#include + +struct HipdnnEnginePluginExecutionContext { + // Fusilli graph. + fusilli::Graph graph; + + // Map from hipDNN tensor UID to fusilli::TensorAttrs for graph boundary + // tensors (inputs and outputs). + std::unordered_map> + uidToFusilliTensorAttr; +}; #endif // FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H diff --git a/projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h similarity index 100% rename from projects/fusilli-plugin/src/hipdnn_engine_plugin_handle.h rename to projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h diff --git a/projects/fusilli-plugin/src/utils.h b/projects/fusilli-plugin/include/utils.h similarity index 100% rename from projects/fusilli-plugin/src/utils.h rename to projects/fusilli-plugin/include/utils.h diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp index c37f2164986..74057a04287 100644 --- a/projects/fusilli-plugin/src/fusilli_plugin.cpp +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -15,8 +15,10 @@ #define COMPONENT_NAME FUSILLI_PLUGIN_NAME #include +#include #include #include +#include #include #include #include @@ -34,7 +36,9 @@ #include #include #include +#include +#include "graph_import.h" #include "hipdnn_engine_plugin_execution_context.h" #include "hipdnn_engine_plugin_handle.h" #include "utils.h" @@ -203,8 +207,42 @@ hipdnnPluginStatus_t hipdnnEnginePluginGetApplicableEngineIds( return HIPDNN_PLUGIN_STATUS_SUCCESS; } - // TODO: check graph for supported fusilli operations, return - // FUSILLI_PLUGIN_ENGINE_ID if graph can be supported. + // Check for single conv_fprop node graph + GraphWrapper opGraphWrapper(opGraph->ptr, opGraph->size); + if (opGraphWrapper.nodeCount() != 1) { + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) only applicable only " + "for single node conv_fprop graphs.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if (!opGraphWrapper.hasOnlySupportedAttributes( + std::set{ + hipdnn_sdk::data_objects::NodeAttributes:: + ConvolutionFwdAttributes})) { + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) only applicable only " + "for single node conv_fprop graphs.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // Check single conv_fprop node for symmetric padding + const hipdnn_sdk::data_objects::ConvolutionFwdAttributes *convFwdAttrs = + opGraphWrapper.getNode(0).attributes_as_ConvolutionFwdAttributes(); + // pre/post_padding are flatbuffer::vectors (not std::vectors) and don't + // override ==, so we use std::ranges::equal for structural vs referential + // equality. + if (!std::ranges::equal(*convFwdAttrs->pre_padding(), + *convFwdAttrs->post_padding())) { // C++ 20 + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) requires symmetric " + "padding for conv_fprop nodes.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // We have a single conv_fprop node with symmetric padding, the fusilli engine + // is applicable. + engineIds[0] = FUSILLI_PLUGIN_ENGINE_ID; + *numEngines = 1; LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); return HIPDNN_PLUGIN_STATUS_SUCCESS; @@ -351,12 +389,22 @@ hipdnnPluginStatus_t hipdnnEnginePluginCreateExecutionContext( HIPDNN_PLUGIN_STATUS_BAD_PARAM, "unexpected engine id"); } - // TODO: Implement graph compilation - // This is a stub plugin, the full implementation would: - // 1. Create and compile a fusilli graph from the opGraph - // 2. Store tensor mappings (uid to fusilli tensor attributes) - // 3. Store the compiled graph in the execution context - *executionContext = new HipdnnEnginePluginExecutionContext{}; + auto importAndCompile = [&handle](const hipdnnPluginConstData_t *opGraph) + -> fusilli::ErrorOr { + // Import fusilli::Graph and compute UID -> fusilli::TensorAttr map for + // graph boundary tensors. + HipdnnEnginePluginExecutionContext graphImport = + FUSILLI_TRY(importGraph(opGraph)); + + // Compile graph + FUSILLI_CHECK_ERROR(graphImport.graph.validate()); + FUSILLI_CHECK_ERROR(graphImport.graph.compile(handle->fusilliHandle)); + + return fusilli::ok(std::move(graphImport)); + }; + + *executionContext = new HipdnnEnginePluginExecutionContext( + FUSILLI_PLUGIN_TRY(importAndCompile(opGraph))); LOG_API_SUCCESS_AUTO("created_execution_context={:p}", static_cast(*executionContext)); diff --git a/projects/fusilli-plugin/test/CMakeLists.txt b/projects/fusilli-plugin/test/CMakeLists.txt index 32a67c35c8a..cfb816a2037 100644 --- a/projects/fusilli-plugin/test/CMakeLists.txt +++ b/projects/fusilli-plugin/test/CMakeLists.txt @@ -10,10 +10,12 @@ find_package(Threads REQUIRED) # Plugin API tests add_executable(test_fusilli_plugin_api test_fusilli_plugin_api.cpp + utils.h ) target_compile_options(test_fusilli_plugin_api PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) target_link_libraries(test_fusilli_plugin_api PRIVATE fusilli_plugin # Link plugin directly + fusilli::fusilli GTest::gtest_main hip::host hipdnn_sdk @@ -23,8 +25,20 @@ target_compile_definitions(test_fusilli_plugin_api PRIVATE FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" FUSILLI_PLUGIN_ENGINE_ID=${FUSILLI_PLUGIN_ENGINE_ID} ) -# Register with CTest gtest_discover_tests(test_fusilli_plugin_api) +# Graph import tests +add_executable(test_graph_import + test_graph_import.cpp + utils.h +) +target_compile_options(test_graph_import PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(test_graph_import PRIVATE + fusilli::fusilli + GTest::gtest_main + hipdnn_sdk +) +gtest_discover_tests(test_graph_import) + # Integration tests add_subdirectory(integration) diff --git a/projects/fusilli-plugin/test/integration/CMakeLists.txt b/projects/fusilli-plugin/test/integration/CMakeLists.txt index 3e20a22fab9..7b11220a761 100644 --- a/projects/fusilli-plugin/test/integration/CMakeLists.txt +++ b/projects/fusilli-plugin/test/integration/CMakeLists.txt @@ -9,7 +9,6 @@ find_package(Threads REQUIRED) # Integration tests add_executable(fusilli_plugin_integration_tests - # main.cpp test_basic_integration.cpp ) target_compile_options(fusilli_plugin_integration_tests PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) diff --git a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp index 6eb088d956d..e3ad40d90b0 100644 --- a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp +++ b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp @@ -4,18 +4,33 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include +#include #include +#include +#include +#include +#include +#include #include #include +#include #include #include +#include +#include #include #include #include #include #include +#include "fusilli/attributes/tensor_attributes.h" +#include "graph_import.h" +#include "hipdnn_engine_plugin_execution_context.h" +#include "utils.h" + bool loggingCallbackCalled = false; std::vector capturedLogMessages; std::vector capturedLogSeverities; @@ -158,3 +173,198 @@ TEST(TestFusilliPluginApi, GetAllEngineIdsNullNumEngines) { ASSERT_NE(errorStr, nullptr); EXPECT_GT(strlen(errorStr), 0u); } + +// TODO(#2363): investigate using createValidConvFwdGraph from upstream hipDNN +flatbuffers::FlatBufferBuilder +createValidConvFwdGraph(int64_t xUID = 0, int64_t wUID = 1, int64_t yUID = 2, + hipdnn_sdk::data_objects::DataType dataType = + hipdnn_sdk::data_objects::DataType::FLOAT, + const std::vector &xDims = {4, 4, 4, 4}, + const std::vector &xStrides = {64, 16, 4, 1}, + const std::vector &wDims = {4, 4, 1, 1}, + const std::vector &wStrides = {4, 1, 1, 1}, + const std::vector &yDims = {4, 4, 4, 4}, + const std::vector &yStrides = {64, 16, 4, 1}, + const std::vector &convPrePadding = {0, 0}, + const std::vector &convPostPadding = {0, 0}, + const std::vector &convStrides = {1, 1}, + const std::vector &convDilation = {1, 1}) { + flatbuffers::FlatBufferBuilder builder; + std::vector<::flatbuffers::Offset> + tensorAttributes; + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, xUID, "x", dataType, &xStrides, &xDims)); + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, wUID, "w", dataType, &wStrides, &wDims)); + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, yUID, "y", dataType, &yStrides, &yDims)); + + auto convAttributes = CreateConvolutionFwdAttributesDirect( + builder, + /*x_tensor_uid*/ xUID, + /*w_tensor_uid*/ wUID, + /*y_tensor_uid*/ yUID, &convPrePadding, &convPostPadding, &convStrides, + &convDilation, hipdnn_sdk::data_objects::ConvMode::CROSS_CORRELATION); + + std::vector<::flatbuffers::Offset> nodes; + auto node = CreateNodeDirect( + builder, "conv_fwd", + hipdnn_sdk::data_objects::NodeAttributes::ConvolutionFwdAttributes, + convAttributes.Union()); + nodes.push_back(node); + + auto graphOffset = + CreateGraphDirect(builder, "test", + /*compute_type*/ dataType, + /*intermediate_type*/ dataType, + /*io_type=*/dataType, &tensorAttributes, &nodes); + builder.Finish(graphOffset); + return builder; +} + +TEST(TestFusilliPluginApi, GetApplicableEngineIds) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // Create a serialized hipDNN bach norm graph. + auto builder = hipdnn_backend::test_utilities::createValidBatchnormGraph(); + hipdnnPluginConstData_t opGraph; + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should not offer to compile and execute bach norm (yet). + std::array engineIDs; + uint32_t numEngines = -1; + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 0); + + // Create a serialized hipDNN conv_fprop graph with symmetric padding. + builder = createValidConvFwdGraph(); + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should offer to compile and execute single node conv_fprop. + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 1); + ASSERT_EQ(engineIDs[0], FUSILLI_PLUGIN_ENGINE_ID); + + // Create a serialized hipDNN conv_fprop graph with asymmetric padding. + builder = createValidConvFwdGraph( + /*xUID=*/0, /*wUID=*/1, /*yUID=*/2, + /*dataType=*/hipdnn_sdk::data_objects::DataType::FLOAT, + /*xDims=*/{4, 4, 4, 4}, /*xStrides=*/{64, 16, 4, 1}, + /*wDims=*/{4, 4, 1, 1}, /*wStrides=*/{4, 1, 1, 1}, + /*yDims=*/{4, 4, 4, 4}, /*yStrides=*/{64, 16, 4, 1}, + /*convPrePadding=*/{1, 0}, // asymmetric: pre doesn't match post + /*convPostPadding=*/{2, 1}, // asymmetric: pre doesn't match post + /*convStrides=*/{1, 1}, /*convDilation=*/{1, 1}); + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should not offer to compile and execute single node + // conv_fprop with asymmetric padding. + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 0); +} + +TEST(TestFusilliPluginApi, CreateExecutionContext) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // UIDs. + int64_t xUID = 1; + int64_t wUID = 2; + int64_t yUID = 3; + + // Dims and strides. + const std::vector expectedXDims = {4, 4, 4, 4}; + const std::vector expectedXStrides = {64, 16, 4, 1}; + const std::vector expectedWDims = {4, 4, 1, 1}; + const std::vector expectedWStrides = {4, 1, 1, 1}; + const std::vector expectedYDims = {4, 4, 4, 4}; + const std::vector expectedYStrides = {64, 16, 4, 1}; + const hipdnn_sdk::data_objects::DataType dataType = + hipdnn_sdk::data_objects::DataType::FLOAT; + fusilli::DataType expectedDataType = + FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType(dataType)); + + // Create a serialized hipDNN conv_fprop. + auto builder = createValidConvFwdGraph( + xUID, wUID, yUID, dataType, expectedXDims, expectedXStrides, + expectedWDims, expectedWStrides, expectedYDims, expectedYStrides); + hipdnnPluginConstData_t opGraph; + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Create engine config. + flatbuffers::FlatBufferBuilder configBuilder; + auto engineConfig = hipdnn_sdk::data_objects::CreateEngineConfig( + configBuilder, FUSILLI_PLUGIN_ENGINE_ID); + configBuilder.Finish(engineConfig); + hipdnnPluginConstData_t engineConfigData; + engineConfigData.ptr = configBuilder.GetBufferPointer(); + engineConfigData.size = configBuilder.GetSize(); + + // The function we're actually testing. + hipdnnEnginePluginExecutionContext_t executionContext = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreateExecutionContext( + handle, &engineConfigData, &opGraph, &executionContext), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(executionContext, nullptr); + + auto *ctx = + static_cast(executionContext); + + // Check that we have 3 tensors tracked (x, w, y). + EXPECT_EQ(ctx->uidToFusilliTensorAttr.size(), 3); + + // Check x tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(xUID)); // C++ 20 + std::shared_ptr xTensor = + ctx->uidToFusilliTensorAttr[xUID]; + EXPECT_EQ(xTensor->getDim(), expectedXDims); + EXPECT_EQ(xTensor->getStride(), expectedXStrides); + EXPECT_EQ(xTensor->getDataType(), expectedDataType); + EXPECT_FALSE(xTensor->isVirtual()); + + // Check w tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(wUID)); // C++ 20 + std::shared_ptr wTensor = + ctx->uidToFusilliTensorAttr[wUID]; + EXPECT_EQ(wTensor->getDim(), expectedWDims); + EXPECT_EQ(wTensor->getStride(), expectedWStrides); + EXPECT_EQ(wTensor->getDataType(), expectedDataType); + EXPECT_FALSE(wTensor->isVirtual()); + + // Check y tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(wUID)); // C++ 20 + std::shared_ptr yTensor = + ctx->uidToFusilliTensorAttr[yUID]; + EXPECT_EQ(yTensor->getDim(), expectedYDims); + EXPECT_EQ(yTensor->getStride(), expectedYStrides); + EXPECT_EQ(yTensor->getDataType(), expectedDataType); + EXPECT_FALSE(yTensor->isVirtual()); + + // Verify graph properties. + EXPECT_EQ(ctx->graph.context.getIODataType(), expectedDataType); + EXPECT_EQ(ctx->graph.context.getIntermediateDataType(), expectedDataType); + EXPECT_EQ(ctx->graph.context.getComputeDataType(), expectedDataType); + + // Clean up. + EXPECT_EQ(hipdnnEnginePluginDestroyExecutionContext(handle, executionContext), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(hipdnnEnginePluginDestroy(handle), HIPDNN_PLUGIN_STATUS_SUCCESS); +} diff --git a/projects/fusilli-plugin/test/test_graph_import.cpp b/projects/fusilli-plugin/test/test_graph_import.cpp new file mode 100644 index 00000000000..fa36ecfa5ab --- /dev/null +++ b/projects/fusilli-plugin/test/test_graph_import.cpp @@ -0,0 +1,40 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "graph_import.h" +#include "utils.h" + +#include +#include +#include + +TEST(TestGraphImport, ConvertHipDnnToFusilli) { + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::HALF)), + fusilli::DataType::Half); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::BFLOAT16)), + fusilli::DataType::BFloat16); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::FLOAT)), + fusilli::DataType::Float); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::DOUBLE)), + fusilli::DataType::Double); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::UINT8)), + fusilli::DataType::Uint8); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::INT32)), + fusilli::DataType::Int32); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::UNSET)), + fusilli::DataType::NotSet); + + auto invalidResult = hipDnnDataTypeToFusilliDataType( + static_cast(42)); + EXPECT_TRUE(isError(invalidResult)); +} diff --git a/projects/fusilli-plugin/test/utils.h b/projects/fusilli-plugin/test/utils.h new file mode 100644 index 00000000000..de16707fc8e --- /dev/null +++ b/projects/fusilli-plugin/test/utils.h @@ -0,0 +1,29 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains utilities for fusilli plugin tests. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_TESTS_UTILS_H +#define FUSILLI_PLUGIN_TESTS_UTILS_H + +// Unwrap the type returned from an expression that evaluates to an ErrorOr, +// fail the test using GTest's EXPECT_TRUE if the result is an ErrorObject. +// +// This is very similar to FUSILLI_TRY, but FUSILLI_TRY propagates an error to +// callers on the error path, this fails the test on the error path. The two +// macros are analogous to rust's `?` (try) operator and `.unwrap()` call. +#define FUSILLI_PLUGIN_EXPECT_UNWRAP(expr) \ + ({ \ + auto _errorOr = (expr); \ + EXPECT_TRUE(isOk(_errorOr)); \ + std::move(*_errorOr); \ + }) + +#endif // FUSILLI_PLUGIN_TESTS_UTILS_H From 0ee34edff66c9af4a0b2ca7b57e3018ef120a421 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Tue, 30 Sep 2025 16:16:09 -0600 Subject: [PATCH 05/14] [FusilliPlugin] Graph execution (`fusilli-plugin` 4 of 4) (#2384) This PR implements `hipdnnEnginePluginExecuteOpGraph` enabling graph/plan execution. We can now complete the initial plugin goal of e2e run of single node ConvFprop graph. Note: a simple e2e integration test is added. This should be expanded to a parameterized matrix of test, once we're on the latest hipDNN [example](https://github.com/ROCm/rocm-libraries/blob/develop/projects/hipdnn/plugins/miopen_legacy_plugin/integration_tests/IntegrationGpuConvForward.cpp). Tickets: [#2363](https://github.com/nod-ai/shark-ai/issues/2363), [#2375](https://github.com/nod-ai/shark-ai/issues/2375). --- .../fusilli-plugin/include/graph_import.h | 2 +- .../hipdnn_engine_plugin_execution_context.h | 2 +- projects/fusilli-plugin/include/utils.h | 52 +++++++ .../fusilli-plugin/src/fusilli_plugin.cpp | 90 +++++++++++- .../test/integration/CMakeLists.txt | 3 +- ...t_basic_integration.cpp => test_basic.cpp} | 0 .../test/integration/test_convfprop.cpp | 128 ++++++++++++++++++ 7 files changed, 268 insertions(+), 9 deletions(-) rename projects/fusilli-plugin/test/integration/{test_basic_integration.cpp => test_basic.cpp} (100%) create mode 100644 projects/fusilli-plugin/test/integration/test_convfprop.cpp diff --git a/projects/fusilli-plugin/include/graph_import.h b/projects/fusilli-plugin/include/graph_import.h index 3435dbf8d12..4046cb78f96 100644 --- a/projects/fusilli-plugin/include/graph_import.h +++ b/projects/fusilli-plugin/include/graph_import.h @@ -27,7 +27,7 @@ #include "fusilli/support/logging.h" #include "hipdnn_engine_plugin_execution_context.h" -// Convert from hipDNN DataType to fusilli DataType +// Convert from hipDNN DataType to fusilli DataType. inline fusilli::ErrorOr hipDnnDataTypeToFusilliDataType(hipdnn_sdk::data_objects::DataType hipdnnType) { switch (hipdnnType) { diff --git a/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h index 9c7c8fde39b..d5a38da68b0 100644 --- a/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h +++ b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h @@ -17,7 +17,7 @@ // (plan in hipDNN parlance) in a hot loop without any overhead. For fusilli // plugin, that maps to constructing and storing a fusilli::Graph based on // hipDNN graph. When an execution is requested, it should be a simple lookup -// for uid -> tensor attribute, then a graph execution. +// for UID -> tensor attribute, then a graph execution. // //===----------------------------------------------------------------------===// diff --git a/projects/fusilli-plugin/include/utils.h b/projects/fusilli-plugin/include/utils.h index 7b56ad305c3..d92ef85b93f 100644 --- a/projects/fusilli-plugin/include/utils.h +++ b/projects/fusilli-plugin/include/utils.h @@ -13,9 +13,12 @@ #ifndef FUSILLI_PLUGIN_SRC_UTILS_H #define FUSILLI_PLUGIN_SRC_UTILS_H +#include "fusilli/attributes/types.h" +#include "fusilli/support/logging.h" #include #include #include +#include // Checks for null, sets the plugin last error manager and returns error if // null. @@ -33,6 +36,21 @@ template hipdnnPluginStatus_t isNull(T *value) { return HIPDNN_PLUGIN_STATUS_SUCCESS; } +// Find deviceBuffer with UID. +inline fusilli::ErrorOr +findDeviceBuffer(int64_t uid, const hipdnnPluginDeviceBuffer_t *deviceBuffers, + uint32_t numDeviceBuffers) { + for (uint32_t i = 0; i < numDeviceBuffers; i++) { + if (uid == deviceBuffers[i].uid) { + return fusilli::ok(deviceBuffers[i]); + } + } + + return fusilli::error(fusilli::ErrorCode::AttributeNotSet, + "Device buffer with the uid: " + std::to_string(uid) + + " not found in the provided device buffers."); +} + // If null, set plugin error manager last error to // HIPDNN_PLUGIN_STATUS_BAD_PARAM and return said error from the enclosing // scope. @@ -96,4 +114,38 @@ template hipdnnPluginStatus_t isNull(T *value) { } \ } while (false) +// Convert from fusilli DataType to iree hal data type. +inline fusilli::ErrorOr +fusilliDataTypeToIreeHalDataType(fusilli::DataType fusilliDataType) { + switch (fusilliDataType) { + case fusilli::DataType::Half: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_16); + case fusilli::DataType::BFloat16: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_BFLOAT_16); + case fusilli::DataType::Float: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_32); + case fusilli::DataType::Double: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_64); + case fusilli::DataType::Uint8: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_UINT_8); + case fusilli::DataType::Int8: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_8); + case fusilli::DataType::Int16: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_16); + case fusilli::DataType::Int32: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_32); + case fusilli::DataType::Int64: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_64); + case fusilli::DataType::Boolean: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_BOOL_8); + case fusilli::DataType::FP8E5M2: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2); + case fusilli::DataType::NotSet: + default: + return fusilli::error( + fusilli::ErrorCode::InvalidAttribute, + "unknown data type in fusilli -> iree runtime data type conversion"); + } +} + #endif // FUSILLI_PLUGIN_SRC_UTILS_H diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp index 74057a04287..417ce9f1bc1 100644 --- a/projects/fusilli-plugin/src/fusilli_plugin.cpp +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// // hipDNN logging expects COMPONENT_NAME to be defined +#include +#include #define COMPONENT_NAME FUSILLI_PLUGIN_NAME #include @@ -443,12 +445,88 @@ hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( FUSILLI_PLUGIN_CHECK_NULL(executionContext); FUSILLI_PLUGIN_CHECK_NULL(deviceBuffers); - // TODO: Implement graph execution. - // This is a stub plugin, the full implementation would: - // 1. Map device buffers to fusilli tensor attributes based on uid mapping - // stored on executionContext. - // 2. Create IREE buffer views from HIP device pointers. - // 3. Execute the compiled graph. + // Params and allocators hoisted out of loop below. + iree_hal_allocator_t *deviceAllocator = + iree_hal_device_allocator(handle->fusilliHandle); + iree_allocator_t ireeHostAllocator = iree_allocator_system(); + iree_hal_buffer_params_t bufferParams = { + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + .access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }; + + // Fill variant pack for graph execution. Fusilli expects a variant pack to + // map from fusilli::TensorAttr -> fusilli::Buffer for all boundary tensors. + // + // The execution context (created by hipdnnEnginePluginCreateExecutionContext) + // holds a UID -> fusilli::TensorAttr mapping for all boundary tensors + // already. To build the mapping we need to: + // 1. Find the external HIP-allocated device buffer in `deviceBuffers` + // associated with UID. + // 2. Import buffer from 1) into IREE runtime and create fusilli::Buffer. + // + // We may want to cache all of this in the future. As long as the device + // pointers + UIDs haven't changed it should be possible to re-use an already + // imported buffer + buffer view + the call that fusilli::Graph::execute + // builds internally. + std::unordered_map, + std::shared_ptr> + variantPack; + for (auto &[uid, tensorAttr] : executionContext->uidToFusilliTensorAttr) { + // 1. Find associated buffer. + hipdnnPluginDeviceBuffer_t hipMallocedBuffer = FUSILLI_PLUGIN_TRY( + findDeviceBuffer(uid, deviceBuffers, numDeviceBuffers)); + + // 2.1. Import external buffer into IREE runtime. This isn't allocating a + // buffer, it's making an existing allocation available to the IREE runtime. + iree_hal_external_buffer_t externalBuffer = { + .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION, + .flags = 0, + .size = static_cast(sizeof(float) * + tensorAttr->getVolume()), + .handle = + { + .device_allocation = + { + .ptr = (uint64_t)hipMallocedBuffer.ptr, + }, + }, + }; + iree_hal_buffer_t *importedBuffer = nullptr; + FUSILLI_PLUGIN_CHECK_ERROR(iree_hal_allocator_import_buffer( + deviceAllocator, bufferParams, &externalBuffer, + iree_hal_buffer_release_callback_null(), &importedBuffer)); + + // 2.2. Create a buffer view for external buffer. + iree_hal_buffer_view_t *outBufferView = nullptr; + FUSILLI_PLUGIN_CHECK_ERROR(iree_hal_buffer_view_create( + /*buffer=*/importedBuffer, /*shape_rank=*/tensorAttr->getDim().size(), + /*shape=*/(const iree_hal_dim_t *)tensorAttr->getDim().data(), + /*element_type=*/ + FUSILLI_PLUGIN_TRY( + fusilliDataTypeToIreeHalDataType(tensorAttr->getDataType())), + /*encoding_type=*/IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + /*host_allocator=*/ireeHostAllocator, + /*out_buffer_view=*/&outBufferView)); + + // Release our reference to buffer. The buffer view holds a reference to + // buffer and will handle release + possible destruction when it's + // destroyed. + iree_hal_buffer_release(importedBuffer); + + // 2.3. Create fusilli::Buffer from buffer view. Buffer::import is a RAII + // type that retains the buffer view, incrementing its reference count, on + // construction and releases the buffer view on destruction. + variantPack[tensorAttr] = std::make_shared( + FUSILLI_PLUGIN_TRY(fusilli::Buffer::import(outBufferView))); + + // Release our reference to buffer view. The buffer view and buffer will + // (now) be tied to fusilli::Buffer's lifetime as it holds the only + // reference to the buffer view. + iree_hal_buffer_view_release(outBufferView); + } + + FUSILLI_PLUGIN_CHECK_ERROR(executionContext->graph.execute(variantPack)); LOG_API_SUCCESS_AUTO("executed graph"); return HIPDNN_PLUGIN_STATUS_SUCCESS; diff --git a/projects/fusilli-plugin/test/integration/CMakeLists.txt b/projects/fusilli-plugin/test/integration/CMakeLists.txt index 7b11220a761..0c06e8fe719 100644 --- a/projects/fusilli-plugin/test/integration/CMakeLists.txt +++ b/projects/fusilli-plugin/test/integration/CMakeLists.txt @@ -9,7 +9,8 @@ find_package(Threads REQUIRED) # Integration tests add_executable(fusilli_plugin_integration_tests - test_basic_integration.cpp + test_convfprop.cpp + test_basic.cpp ) target_compile_options(fusilli_plugin_integration_tests PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) target_link_libraries(fusilli_plugin_integration_tests PRIVATE diff --git a/projects/fusilli-plugin/test/integration/test_basic_integration.cpp b/projects/fusilli-plugin/test/integration/test_basic.cpp similarity index 100% rename from projects/fusilli-plugin/test/integration/test_basic_integration.cpp rename to projects/fusilli-plugin/test/integration/test_basic.cpp diff --git a/projects/fusilli-plugin/test/integration/test_convfprop.cpp b/projects/fusilli-plugin/test/integration/test_convfprop.cpp new file mode 100644 index 00000000000..3d959cd9676 --- /dev/null +++ b/projects/fusilli-plugin/test/integration/test_convfprop.cpp @@ -0,0 +1,128 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +using namespace hipdnn_frontend; +using namespace hipdnn_sdk::utilities; +using namespace hipdnn_sdk::test_utilities; + +TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { + // Uncomment to enable debug logging + // setenv("HIPDNN_LOG_LEVEL", "info", 1); + + // Initialize HIP + ASSERT_EQ(hipInit(0), hipSuccess); + int deviceId; + ASSERT_EQ(hipGetDevice(&deviceId), hipSuccess); + + // Set plugin path + const std::array paths = {FUSILLI_PLUGIN_DIR}; + ASSERT_EQ(hipdnnSetEnginePluginPaths_ext(paths.size(), paths.data(), + HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + + // Create handle + hipdnnHandle_t handle; + ASSERT_EQ(hipdnnCreate(&handle), HIPDNN_STATUS_SUCCESS); + + // Dimensions + const int64_t n = 16; // batch + const int64_t c = 128; // in channels + const int64_t h = 64; // image height + const int64_t w = 64; // image width + const int64_t k = 256; // out channels + const int64_t r = 1; // filter height + const int64_t s = 1; // filter width + + // UIDs + const int64_t xUID = 0; + const int64_t wUID = 1; + const int64_t yUID = 2; + + // Initialize tensors + PinnedTensor xTensor({n, c, h, w}); + PinnedTensor wTensor({k, c, r, s}); + PinnedTensor yTensor({n, k, h, w}); + xTensor.fillWithValue(1.0f); + wTensor.fillWithValue(1.0f); + yTensor.fillWithValue(-100.0f); + + // Expected output + PinnedTensor expectedOutput({n, k, h, w}); + expectedOutput.fillWithValue(128.0f); + + // Create graph + auto graph = std::make_shared(); + graph->set_name("conv_1x1_test"); + graph->set_io_data_type(DataType_t::FLOAT) + .set_compute_data_type(DataType_t::FLOAT); + + // Create tensor attributes + auto xAttr = std::make_shared( + graph::makeTensorAttributes("input", DataType_t::FLOAT, xTensor)); + xAttr->set_uid(xUID); + auto wAttr = std::make_shared( + graph::makeTensorAttributes("filter", DataType_t::FLOAT, wTensor)); + wAttr->set_uid(wUID); + + // Create convolution attributes + graph::ConvFpropAttributes convAttr; + convAttr.set_name("conv_fprop") + .set_padding({0, 0}) + .set_stride({1, 1}) + .set_dilation({1, 1}); + + // Create graph + auto yAttr = graph->conv_fprop(xAttr, wAttr, convAttr); + yAttr->set_uid(yUID); + yAttr->set_dim(yTensor.dims()).set_stride(yTensor.strides()).set_output(true); + + // Build + validate + build plans for graph + auto result = graph->validate(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->build_operation_graph(handle); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->create_execution_plans(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->check_support(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->build_plans(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + // Create variant pack + std::unordered_map variantPack; + variantPack[xUID] = xTensor.memory().deviceData(); + variantPack[wUID] = wTensor.memory().deviceData(); + variantPack[yUID] = yTensor.memory().deviceData(); + + // Execute graph + result = graph->execute(handle, variantPack, nullptr); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + // Mark hipDNN tensor CPU cache ask stale, data must be read from device. + yTensor.memory().markDeviceModified(); + + // Check results + CpuFpReferenceValidation validator(1e-6f, 1e-6f); + EXPECT_TRUE(validator.allClose(expectedOutput.memory(), yTensor.memory())); + + // Cleanup + ASSERT_EQ(hipdnnDestroy(handle), HIPDNN_STATUS_SUCCESS); +} From 8b485a49c0848901a42b89037a30efbf16ca668b Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Thu, 2 Oct 2025 09:52:01 -0600 Subject: [PATCH 06/14] [Fusilli,FusilliPlugin] Update hipdnn and clang version (#2405) Update build to pull hipdnn from `rocm-libraries` and update clang to version 20 from TheRock. Co-authored-by: Sambhav Jain --- projects/fusilli-plugin/CMakeLists.txt | 2 +- .../cmake/FusilliPluginDependencyUtils.cmake | 62 ++++++++++++------- .../fusilli-plugin/src/fusilli_plugin.cpp | 10 +-- .../test/test_fusilli_plugin_api.cpp | 2 +- 4 files changed, 46 insertions(+), 30 deletions(-) diff --git a/projects/fusilli-plugin/CMakeLists.txt b/projects/fusilli-plugin/CMakeLists.txt index 2f8a4b9678f..a9cee7b9b45 100644 --- a/projects/fusilli-plugin/CMakeLists.txt +++ b/projects/fusilli-plugin/CMakeLists.txt @@ -35,7 +35,7 @@ set(HIP_PLATFORM "amd") find_package(hip REQUIRED) fusilli_plugin_dependency(GTest GTEST_VERSION 1.16.0) fusilli_plugin_dependency(IREERuntime) -fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 8a49655a7cc73f1c766256a7ead3a987e830efe9) +fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 4e0a0452cfcb8fdb86e9c40a6e43debab4d4ecbc) fusilli_plugin_dependency(Fusilli USE_LOCAL ${FUSILLI_PLUGIN_USE_LOCAL_FUSILLI}) # Includes diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake index d6ec2eaca6e..969ce5644cc 100644 --- a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -133,30 +133,46 @@ macro(_fetch_hipdnn_frontend) message(FATAL_ERROR "Argument error: passing both LOCAL_PATH and HIP_DNN_HASH is ambiguous.") endif() - if (DEFINED ARG_HIP_DNN_HASH) - FetchContent_Declare( - hipdnn_frontend - GIT_REPOSITORY https://github.com/ROCm/hipDNN.git - GIT_TAG ${ARG_HIP_DNN_HASH} - # When FIND_PACKAGE_ARGS is passed, FetchContent_Declare tries to - # find_package an installed version before downloading. - FIND_PACKAGE_ARGS CONFIG - ) - else() - FetchContent_Declare( - hipdnn_frontend - SOURCE_DIR ${ARG_LOCAL_PATH} - ) + # We would normally check for, and preferentially use, an installed config + # package using FIND_PACKAGE_ARGS CONFIG on FetchContent_Declare. But, CMake + # throws an error if both FIND_PACKAGE_ARGS and DOWNLOAD_COMMAND arguments + # are passed, it's one or the other. + find_package( + hipdnn_frontend QUIET ${PARSE_FIND_PACKAGE_ARGS} + ) + if(NOT hipdnn_frontend_FOUND) # we can't early return in a macro + if (DEFINED ARG_HIP_DNN_HASH) + FetchContent_Declare( + hipdnn_frontend + # location of hipdnn CMakeLists.txt in rocm-libraries + SOURCE_SUBDIR projects/hipdnn + # rocm-libraries takes 10+ min to fetch without sparse checkout + # (even with a shallow clone). We provide a custom + # DOWNLOAD_COMMAND until such time as CMAKE natively supports + # sparse checkouts. + DOWNLOAD_COMMAND + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && + cd && + git sparse-checkout init --cone && + git sparse-checkout set projects/hipdnn && + git checkout ${ARG_HIP_DNN_HASH} + ) + else() + FetchContent_Declare( + hipdnn_frontend + SOURCE_DIR ${ARG_LOCAL_PATH} + ) + endif() + + set(HIP_DNN_BUILD_BACKEND ON) + set(HIP_DNN_BUILD_FRONTEND ON) + set(HIP_DNN_SKIP_TESTS ON) + set(HIP_DNN_BUILD_PLUGINS OFF) + set(ENABLE_CLANG_TIDY OFF) + # PIC required to link static library into shared object. + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(hipdnn_frontend) endif() - - set(HIP_DNN_BUILD_BACKEND ON) - set(HIP_DNN_BUILD_FRONTEND ON) - set(HIP_DNN_SKIP_TESTS ON) - set(HIP_DNN_BUILD_PLUGINS OFF) - set(ENABLE_CLANG_TIDY OFF) - # PIC required to link static library into shared object. - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - FetchContent_MakeAvailable(hipdnn_frontend) endmacro() # IREERuntime diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp index 417ce9f1bc1..0aa097a333f 100644 --- a/projects/fusilli-plugin/src/fusilli_plugin.cpp +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -106,7 +106,7 @@ hipdnnPluginStatus_t hipdnnPluginSetLoggingCallback(hipdnnCallback_t callback) { hipdnn::logging::initializeCallbackLogging(FUSILLI_PLUGIN_NAME, callback); - LOG_API_SUCCESS_AUTO("logging callback initialized"); + LOG_API_SUCCESS_AUTO("{}", "logging callback initialized"); return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -164,7 +164,7 @@ hipdnnEnginePluginDestroy(hipdnnEnginePluginHandle_t handle) { delete handle; - LOG_API_SUCCESS_AUTO(""); + LOG_API_SUCCESS_AUTO("", ""); return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -180,7 +180,7 @@ hipdnnEnginePluginSetStream(hipdnnEnginePluginHandle_t handle, // a default IREE runtime device and execute programs on a stream associated // with that device. The passed in stream is ignored. - LOG_API_SUCCESS_AUTO(""); + LOG_API_SUCCESS_AUTO("", ""); return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -424,7 +424,7 @@ hipdnnPluginStatus_t hipdnnEnginePluginDestroyExecutionContext( delete executionContext; - LOG_API_SUCCESS_AUTO("destroyed executionContext"); + LOG_API_SUCCESS_AUTO("", "destroyed executionContext"); return HIPDNN_PLUGIN_STATUS_SUCCESS; } @@ -528,7 +528,7 @@ hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( FUSILLI_PLUGIN_CHECK_ERROR(executionContext->graph.execute(variantPack)); - LOG_API_SUCCESS_AUTO("executed graph"); + LOG_API_SUCCESS_AUTO("{}", "executed graph"); return HIPDNN_PLUGIN_STATUS_SUCCESS; } diff --git a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp index e3ad40d90b0..880b9069a4d 100644 --- a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp +++ b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp @@ -232,7 +232,7 @@ TEST(TestFusilliPluginApi, GetApplicableEngineIds) { ASSERT_NE(handle, nullptr); // Create a serialized hipDNN bach norm graph. - auto builder = hipdnn_backend::test_utilities::createValidBatchnormGraph(); + auto builder = hipdnn_sdk::test_utilities::createValidBatchnormBwdGraph(); hipdnnPluginConstData_t opGraph; opGraph.ptr = builder.GetBufferPointer(); opGraph.size = builder.GetSize(); From 5a228946134f59770f1507d8cde17dc1cca0b9f4 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Wed, 8 Oct 2025 10:56:38 -0600 Subject: [PATCH 07/14] [Fusilli,FusilliPlugin] Fix missing `utils.h` in `fusilli-plugin` build. (#2449) This PR fixes `fusilli-plugin` build failure when fusilli benchmarking (`FUSILLI_BUILD_BENCHMARKS`) was enabled, but fusilli testing (`FUSILLI_BUILD_TESTS`) was disabled: [link to CI failure](https://github.com/nod-ai/shark-ai/actions/runs/18206945208/job/51839515287?pr=2410). Both _should_ be disabled, but ideally the build wouldn't break if one wasn't. The issue was: both fusilli tests and fusilli benchmarks depend on `libutils`, but the target was only defined if tests were enabled. --- .../build_tools/cmake/FusilliPluginDependencyUtils.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake index 969ce5644cc..d3740a510d4 100644 --- a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -217,5 +217,7 @@ macro(_fetch_Fusilli) SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../sharkfuser ) set(FUSILLI_BUILD_TESTS OFF) + set(FUSILLI_BUILD_BENCHMARKS OFF) + set(FUSILLI_SYSTEMS_AMDGPU ON) FetchContent_MakeAvailable(Fusilli) endmacro() From 2f85cf8c5212acf3ab8f065e8f6c75e85ece6875 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 14 Oct 2025 10:21:54 -0700 Subject: [PATCH 08/14] [Fusilli] Generalize to AMDGPU backend and extract HIP arch from rocm_agent_enumerator (#2488) Got this building on my SharkWorkstation with an RX 9070XT (RDNA4) targeting the `gfx1201` arch with minimal changes. The `Backend` enum doesn't need to be specific to the arch, and the iree-compile flag `--iree-hip-target` can just be derived using `rocm_agent_enumerator` (which lists the gfx arch from the GPUs visible) based on https://iree.dev/guides/deployment-configurations/gpu-rocm/#choosing-hip-targets. --- projects/fusilli-plugin/src/fusilli_plugin.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp index 0aa097a333f..5adef515c2e 100644 --- a/projects/fusilli-plugin/src/fusilli_plugin.cpp +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -150,7 +150,7 @@ hipdnnEnginePluginCreate(hipdnnEnginePluginHandle_t *handle) { // instance creation is thread safe, so this should be thread safe. // TODO(#2335): handle multiple architectures auto fusilliHandle = - FUSILLI_PLUGIN_TRY(fusilli::Handle::create(fusilli::Backend::GFX942)); + FUSILLI_PLUGIN_TRY(fusilli::Handle::create(fusilli::Backend::AMDGPU)); *handle = new HipdnnEnginePluginHandle(std::move(fusilliHandle)); LOG_API_SUCCESS_AUTO("createdHandle={:p}", static_cast(*handle)); From a830fbfb5a2ff11cc4d0a60b776af31d1900bc13 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Thu, 16 Oct 2025 16:48:41 -0700 Subject: [PATCH 09/14] Initial draft/placeholder for RFC to move Fusilli-plugin into TheRock monorepo. (#2521) Signed-off-by: MaheshRavishankar --- projects/fusilli-plugin/docs/TheRockRFC.md | 83 ++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 projects/fusilli-plugin/docs/TheRockRFC.md diff --git a/projects/fusilli-plugin/docs/TheRockRFC.md b/projects/fusilli-plugin/docs/TheRockRFC.md new file mode 100644 index 00000000000..154fbd4db25 --- /dev/null +++ b/projects/fusilli-plugin/docs/TheRockRFC.md @@ -0,0 +1,83 @@ +| author | created | modified | status | +|--------|---------|----------|--------| +| Sambhav Jain, Aaron St. George and Mahesh Ravishankar | | | draft | + + +# Add IREE as a kernel provider for hipDNN + +## Overview + +[IREE](https://github.com/iree-org/iree/) is an open source ML +compiler stack built using MLIR that is intended to support the +compilation and execution of ML models. While IREE is setup to be +multi-targeting, over the past couple of years a lot of effort has +gone into improving the codegeneration for AMDGPUs, specifically +Instinct class GPUs. While a lot of the IREE compiler stack is meant +to optimize execution of full-scale ML models, one key component of +the work is to have efficient kernel code generation for MI300+ +cards. [Fusilli](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) +is a C++ graph API and JIT fronend that leverages the kernel +codegeneration capabilities of IREE and packages it to be useable as a +kernel provider within hipDNN. This allows use of IREE for targeted +portions of the program, even for training use cases. The advantages +of using IREE this way are + +1) IREE has been built from the ground-up as a fusion compiler. The + kinds of fusions that libraries like hipDNN are expected to provide + are supported out-of-the box in IREE. + +2) Using a compiler as a kernel provider through Fusilli's JIT + interface helps pick IREE generated kernels without having to ship + pre-built kernels with hipDNN - saving both build time and space. + +This RFC is to propose adding a path to using IREE as a kernel +provider to hipDNN within TheRock. There are three components to +reason about + +1. The hipDNN backend plugin to Fusilli. Currently it lives + [here](https://github.com/nod-ai/shark-ai/tree/main/fusilli-plugin) + +2. Fusilli. Currently lives + [here](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) + +3. IREE, which lives in its own Github Org and is a Linux Foundation + Project. It currently lives [here](https://github.com/iree-org/iree) + +## Workplan + +### Immediate next steps + +The immediate workplan is to move just the hipDNN backend plugin to +Fusilli (so component 1 above) into TheRock. The plugin will be built +conditionally (not on by default) and will pull in Fusilli and IREE as +external dependencies. Nested dependencies of Fusilli are (TODO: +Adjust/clarify the following). + +1. IREE runtime sources are built into Fusilli + +2. IREE compile using the command-line interface, i.e. the + `iree-compile` binary needs to be available (typically through + pip-install of the IREE compiler package) + +### Medium term/Long term requirements. + +While the initial integration will just focus on pulling in the hipDNN +plugin to IREE into monorepo, long term the expectation is that +Fusilli and IREE are brought in through official release mechanisms +that allow TheRock to seamlessly pull them in (through the usual +versioning mechanisms). Some question that need to be answered for +those are + +1. Where does Fusilli live? Fusilli is a C++ API around IREE and such + is tightlt coupled with IREE. A natural home for Fusilli is within + the same repo/github organization as IREE itself. + +2. The expectation is that Fusilli will stop using `iree-compile` as a + binary, but rather use the C-API of the IREE compiler to JIT + compile the (fused) kernel computation. This would require + significant changes to current IREE workflow. Apart from resolving + where the IREE project lives, i.e. if it should move into the + monorepo as well, another challenge to solve there is which LLVM + version should IREE use. IREE currently tracks top-of-main of LLVM + pretty closely. This would need to change to use either the LLVM + version within monorepo or a release version of LLVM/MLIR. From d7cb049cbfbedf334e2d993c8b1129f71b6f684b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Fri, 17 Oct 2025 11:30:02 -0700 Subject: [PATCH 10/14] [FusilliPlugin] TheRock RFC revisions (#2546) Add Fusilli+IREE as a kernel provider and JIT engine for hipDNN --- projects/fusilli-plugin/docs/TheRockRFC.md | 168 ++++++++++++--------- 1 file changed, 93 insertions(+), 75 deletions(-) diff --git a/projects/fusilli-plugin/docs/TheRockRFC.md b/projects/fusilli-plugin/docs/TheRockRFC.md index 154fbd4db25..db971a1247c 100644 --- a/projects/fusilli-plugin/docs/TheRockRFC.md +++ b/projects/fusilli-plugin/docs/TheRockRFC.md @@ -1,83 +1,101 @@ -| author | created | modified | status | -|--------|---------|----------|--------| -| Sambhav Jain, Aaron St. George and Mahesh Ravishankar | | | draft | - - -# Add IREE as a kernel provider for hipDNN +--- +author(s): Sambhav Jain, Aaron St. George, and Mahesh Ravishankar +created: 2025-10-16 +modified: 2025-10-16 +status: draft +discussion: https://github.com/ROCm/TheRock/discussions/1817 +--- + +# Fusilli+IREE as a kernel provider and JIT engine for hipDNN + +This RFC proposes adding IREE as a kernel provider to hipDNN to leverage JIT +compiled and codegenerated kernels in ML training and inference solutions. +This is made possible with the development of Fusilli - a C++ graph API and +JIT engine for IREE. We believe hand-authored kernel libraries are great for +highly tuned performance but they are difficult to 1) scale to newer models +or target architectures and 2) package and release effectively. This RFC is +founded on the overarching goal to complement our software stack with JIT +solutions while being competitive to hand-authored kernel libraries. Apart +from the usual benefits of having a compiler-backed JIT engine that gets +progressively better, a systemic benefit of this is it helps reduce build +times and binary sizes, making it easier to ship software effectively. ## Overview -[IREE](https://github.com/iree-org/iree/) is an open source ML -compiler stack built using MLIR that is intended to support the -compilation and execution of ML models. While IREE is setup to be -multi-targeting, over the past couple of years a lot of effort has -gone into improving the codegeneration for AMDGPUs, specifically -Instinct class GPUs. While a lot of the IREE compiler stack is meant -to optimize execution of full-scale ML models, one key component of -the work is to have efficient kernel code generation for MI300+ -cards. [Fusilli](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) -is a C++ graph API and JIT fronend that leverages the kernel -codegeneration capabilities of IREE and packages it to be useable as a -kernel provider within hipDNN. This allows use of IREE for targeted -portions of the program, even for training use cases. The advantages -of using IREE this way are - -1) IREE has been built from the ground-up as a fusion compiler. The +[IREE](https://github.com/iree-org/iree/) is an open source ML compiler stack +built using MLIR that is intended to support the compilation and execution of +ML models. While IREE supports multiple target backends, over the past couple +of years a lot of effort has gone into improving the codegeneration for AMD +GPUs, specifically Instinct (MI-series) GPUs. Much of the IREE compiler stack +is geared towards optimizing execution of full-scale ML models. However, a key +objective of this work is to have efficient kernel code generation for MI300+ +GPUs. + +[Fusilli](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) is a C++ +graph API that leverages the kernel codegeneration capabilities of IREE and +packages it to be useable as a JIT engine for hipDNN. This allows use of IREE +for specific portions of the program, even for training use cases. The +advantages of this approach are: + +1. IREE has been built from the ground-up as a fusion compiler. The kinds of fusions that libraries like hipDNN are expected to provide are supported out-of-the box in IREE. - -2) Using a compiler as a kernel provider through Fusilli's JIT - interface helps pick IREE generated kernels without having to ship - pre-built kernels with hipDNN - saving both build time and space. - -This RFC is to propose adding a path to using IREE as a kernel -provider to hipDNN within TheRock. There are three components to -reason about - -1. The hipDNN backend plugin to Fusilli. Currently it lives - [here](https://github.com/nod-ai/shark-ai/tree/main/fusilli-plugin) - -2. Fusilli. Currently lives - [here](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) - -3. IREE, which lives in its own Github Org and is a Linux Foundation - Project. It currently lives [here](https://github.com/iree-org/iree) +1. Fusilli allows compiling codegenerated kernels just-in-time (on-demand) + without having to ship pre-built kernels with hipDNN - saving both build + times and binary sizes. ## Workplan -### Immediate next steps - -The immediate workplan is to move just the hipDNN backend plugin to -Fusilli (so component 1 above) into TheRock. The plugin will be built -conditionally (not on by default) and will pull in Fusilli and IREE as -external dependencies. Nested dependencies of Fusilli are (TODO: -Adjust/clarify the following). - -1. IREE runtime sources are built into Fusilli - -2. IREE compile using the command-line interface, i.e. the - `iree-compile` binary needs to be available (typically through - pip-install of the IREE compiler package) - -### Medium term/Long term requirements. - -While the initial integration will just focus on pulling in the hipDNN -plugin to IREE into monorepo, long term the expectation is that -Fusilli and IREE are brought in through official release mechanisms -that allow TheRock to seamlessly pull them in (through the usual -versioning mechanisms). Some question that need to be answered for -those are - -1. Where does Fusilli live? Fusilli is a C++ API around IREE and such - is tightlt coupled with IREE. A natural home for Fusilli is within - the same repo/github organization as IREE itself. - -2. The expectation is that Fusilli will stop using `iree-compile` as a - binary, but rather use the C-API of the IREE compiler to JIT - compile the (fused) kernel computation. This would require - significant changes to current IREE workflow. Apart from resolving - where the IREE project lives, i.e. if it should move into the - monorepo as well, another challenge to solve there is which LLVM - version should IREE use. IREE currently tracks top-of-main of LLVM - pretty closely. This would need to change to use either the LLVM - version within monorepo or a release version of LLVM/MLIR. +From a code organization standpoint, there are three components to reason about: + +1. IREE. This includes the compiler and runtime stack. It is a Linux Foundation + project and lives [here](https://github.com/iree-org/iree). +1. Fusilli. This is a general purpose API and backend-neutral JIT engine for + IREE that currently lives [here](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser). + It depends minimally on IREE compiler (CLI) and IREE runtime (C-API), and + does NOT require a direct HIP dependency (abstracted by IREE's HAL design). +1. The hipDNN engine plugin for Fusilli. This specializes Fusilli for use within + hipDNN specifically for AMD GPUs. Currently it is being developed + [here](https://github.com/nod-ai/shark-ai/tree/main/fusilli-plugin). + In addition to Fusilli's dependencies, the plugin also depends on HIP, hipDNN + frontend/SDK and hipDNN's dependencies transitively. + +### Short term plan + +The immediate workplan is to move the hipDNN engine plugin (i.e., component 3 +above) into `rocm-libraries` (under `dnn-providers` once build tree is normalized +per RFC0003) following guidelines from the MIOpen plugin restructuring effort. +This will be built conditionally (NOT on by default) and will pull in Fusilli +and IREE as external dependencies. + +The expected build artifact from the plugin integration is a self-contained +`libfusilliplugin.so` that is linked against Fusilli headers and IREE runtime +sources built and statically linked. The dependency on the IREE compiler is +through the `iree-compile` binary (made available typically through a pip-install), +as Fusilli currently invokes the compiler through its command-line-interface. + +A small note on C++ standards: Fusilli and the hipDNN engine plugin for Fusilli +are built on the C++20 standard. We believe this should not pose any issues from an +integration standpoint but happy to revisit this further if the need arises. + +### Long term requirements + +While the initial integration will just focus on pulling in the hipDNN IREE +plugin into the monorepo, long term the expectation is that Fusilli and IREE +are sourced through official release mechanisms that allow TheRock to +seamlessly pull them in (through lockstep versioning). Some questions that need +to be answered for those are: + +1. Where should Fusilli live? Fusilli is a general purpose C++ Graph API around + IREE and as such is tightly coupled with IREE. A natural home for Fusilli is + within the same GitHub organization as IREE itself. This will allow Fusilli + to not only address a gap in the IREE ecosystem for JIT/training use-cases, + but also participate in the release processes in place for IREE already. +1. The expectation is that Fusilli will start using the C-API for the IREE compiler + (through `libIREECompiler.so`) and reserve the use of `iree-compile` binary + only for debugging and sharing reproducers. This would require significant + changes to current IREE workflow. Apart from resolving where the IREE project + lives, i.e. if it should move into the monorepo as well (unlikely), another + challenge to solve there is which LLVM version should IREE use. IREE currently + tracks top-of-main of LLVM pretty closely. This would need to change to use + either the LLVM version within monorepo or a release version of LLVM/MLIR. From 6b6e69f63c862da2018f1ebc5a8a7ac88c36117b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 20 Oct 2025 13:36:15 -0700 Subject: [PATCH 11/14] [FusilliPlugin] Cleanup draft RFC (#2558) Official RFC is here: https://github.com/ROCm/TheRock/blob/main/docs/rfcs/RFC0004-Fusilli-IREE-Kernel-Provider-hipDNN.md --- projects/fusilli-plugin/docs/TheRockRFC.md | 101 --------------------- 1 file changed, 101 deletions(-) delete mode 100644 projects/fusilli-plugin/docs/TheRockRFC.md diff --git a/projects/fusilli-plugin/docs/TheRockRFC.md b/projects/fusilli-plugin/docs/TheRockRFC.md deleted file mode 100644 index db971a1247c..00000000000 --- a/projects/fusilli-plugin/docs/TheRockRFC.md +++ /dev/null @@ -1,101 +0,0 @@ ---- -author(s): Sambhav Jain, Aaron St. George, and Mahesh Ravishankar -created: 2025-10-16 -modified: 2025-10-16 -status: draft -discussion: https://github.com/ROCm/TheRock/discussions/1817 ---- - -# Fusilli+IREE as a kernel provider and JIT engine for hipDNN - -This RFC proposes adding IREE as a kernel provider to hipDNN to leverage JIT -compiled and codegenerated kernels in ML training and inference solutions. -This is made possible with the development of Fusilli - a C++ graph API and -JIT engine for IREE. We believe hand-authored kernel libraries are great for -highly tuned performance but they are difficult to 1) scale to newer models -or target architectures and 2) package and release effectively. This RFC is -founded on the overarching goal to complement our software stack with JIT -solutions while being competitive to hand-authored kernel libraries. Apart -from the usual benefits of having a compiler-backed JIT engine that gets -progressively better, a systemic benefit of this is it helps reduce build -times and binary sizes, making it easier to ship software effectively. - -## Overview - -[IREE](https://github.com/iree-org/iree/) is an open source ML compiler stack -built using MLIR that is intended to support the compilation and execution of -ML models. While IREE supports multiple target backends, over the past couple -of years a lot of effort has gone into improving the codegeneration for AMD -GPUs, specifically Instinct (MI-series) GPUs. Much of the IREE compiler stack -is geared towards optimizing execution of full-scale ML models. However, a key -objective of this work is to have efficient kernel code generation for MI300+ -GPUs. - -[Fusilli](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser) is a C++ -graph API that leverages the kernel codegeneration capabilities of IREE and -packages it to be useable as a JIT engine for hipDNN. This allows use of IREE -for specific portions of the program, even for training use cases. The -advantages of this approach are: - -1. IREE has been built from the ground-up as a fusion compiler. The - kinds of fusions that libraries like hipDNN are expected to provide - are supported out-of-the box in IREE. -1. Fusilli allows compiling codegenerated kernels just-in-time (on-demand) - without having to ship pre-built kernels with hipDNN - saving both build - times and binary sizes. - -## Workplan - -From a code organization standpoint, there are three components to reason about: - -1. IREE. This includes the compiler and runtime stack. It is a Linux Foundation - project and lives [here](https://github.com/iree-org/iree). -1. Fusilli. This is a general purpose API and backend-neutral JIT engine for - IREE that currently lives [here](https://github.com/nod-ai/shark-ai/tree/main/sharkfuser). - It depends minimally on IREE compiler (CLI) and IREE runtime (C-API), and - does NOT require a direct HIP dependency (abstracted by IREE's HAL design). -1. The hipDNN engine plugin for Fusilli. This specializes Fusilli for use within - hipDNN specifically for AMD GPUs. Currently it is being developed - [here](https://github.com/nod-ai/shark-ai/tree/main/fusilli-plugin). - In addition to Fusilli's dependencies, the plugin also depends on HIP, hipDNN - frontend/SDK and hipDNN's dependencies transitively. - -### Short term plan - -The immediate workplan is to move the hipDNN engine plugin (i.e., component 3 -above) into `rocm-libraries` (under `dnn-providers` once build tree is normalized -per RFC0003) following guidelines from the MIOpen plugin restructuring effort. -This will be built conditionally (NOT on by default) and will pull in Fusilli -and IREE as external dependencies. - -The expected build artifact from the plugin integration is a self-contained -`libfusilliplugin.so` that is linked against Fusilli headers and IREE runtime -sources built and statically linked. The dependency on the IREE compiler is -through the `iree-compile` binary (made available typically through a pip-install), -as Fusilli currently invokes the compiler through its command-line-interface. - -A small note on C++ standards: Fusilli and the hipDNN engine plugin for Fusilli -are built on the C++20 standard. We believe this should not pose any issues from an -integration standpoint but happy to revisit this further if the need arises. - -### Long term requirements - -While the initial integration will just focus on pulling in the hipDNN IREE -plugin into the monorepo, long term the expectation is that Fusilli and IREE -are sourced through official release mechanisms that allow TheRock to -seamlessly pull them in (through lockstep versioning). Some questions that need -to be answered for those are: - -1. Where should Fusilli live? Fusilli is a general purpose C++ Graph API around - IREE and as such is tightly coupled with IREE. A natural home for Fusilli is - within the same GitHub organization as IREE itself. This will allow Fusilli - to not only address a gap in the IREE ecosystem for JIT/training use-cases, - but also participate in the release processes in place for IREE already. -1. The expectation is that Fusilli will start using the C-API for the IREE compiler - (through `libIREECompiler.so`) and reserve the use of `iree-compile` binary - only for debugging and sharing reproducers. This would require significant - changes to current IREE workflow. Apart from resolving where the IREE project - lives, i.e. if it should move into the monorepo as well (unlikely), another - challenge to solve there is which LLVM version should IREE use. IREE currently - tracks top-of-main of LLVM pretty closely. This would need to change to use - either the LLVM version within monorepo or a release version of LLVM/MLIR. From 8f431723b4056c1894068b57f408cf01313f5ec5 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Mon, 20 Oct 2025 16:24:04 -0600 Subject: [PATCH 12/14] [FusilliPlugin] `FetchContent` `hipDNN` dependency (#2563) When building `hipDNN` from source it handles downloading and configuring transitive dependencies, while an installed version requires the consuming package to provide needed dependencies. In the plugin's current form, the former is more convenient. --- .../cmake/FusilliPluginDependencyUtils.cmake | 74 +++++++++---------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake index d3740a510d4..9983941fc66 100644 --- a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -115,6 +115,11 @@ endmacro() # hipdnn_frontend # +# NOTE: we currently build hipDNN as a CMake source dependency (via +# FetchContent) rather than using find_package() to locate an installed version. +# The hipDNN build automatically handles transitive dependencies, which is +# quite convenient. +# # HIP_DNN_HASH # Git commit hash or tag to fetch macro(_fetch_hipdnn_frontend) @@ -133,46 +138,37 @@ macro(_fetch_hipdnn_frontend) message(FATAL_ERROR "Argument error: passing both LOCAL_PATH and HIP_DNN_HASH is ambiguous.") endif() - # We would normally check for, and preferentially use, an installed config - # package using FIND_PACKAGE_ARGS CONFIG on FetchContent_Declare. But, CMake - # throws an error if both FIND_PACKAGE_ARGS and DOWNLOAD_COMMAND arguments - # are passed, it's one or the other. - find_package( - hipdnn_frontend QUIET ${PARSE_FIND_PACKAGE_ARGS} - ) - if(NOT hipdnn_frontend_FOUND) # we can't early return in a macro - if (DEFINED ARG_HIP_DNN_HASH) - FetchContent_Declare( - hipdnn_frontend - # location of hipdnn CMakeLists.txt in rocm-libraries - SOURCE_SUBDIR projects/hipdnn - # rocm-libraries takes 10+ min to fetch without sparse checkout - # (even with a shallow clone). We provide a custom - # DOWNLOAD_COMMAND until such time as CMAKE natively supports - # sparse checkouts. - DOWNLOAD_COMMAND - git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && - cd && - git sparse-checkout init --cone && - git sparse-checkout set projects/hipdnn && - git checkout ${ARG_HIP_DNN_HASH} - ) - else() - FetchContent_Declare( - hipdnn_frontend - SOURCE_DIR ${ARG_LOCAL_PATH} - ) - endif() - - set(HIP_DNN_BUILD_BACKEND ON) - set(HIP_DNN_BUILD_FRONTEND ON) - set(HIP_DNN_SKIP_TESTS ON) - set(HIP_DNN_BUILD_PLUGINS OFF) - set(ENABLE_CLANG_TIDY OFF) - # PIC required to link static library into shared object. - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - FetchContent_MakeAvailable(hipdnn_frontend) + if (DEFINED ARG_HIP_DNN_HASH) + FetchContent_Declare( + hipdnn_frontend + # location of hipdnn CMakeLists.txt in rocm-libraries + SOURCE_SUBDIR projects/hipdnn + # rocm-libraries takes 10+ min to fetch without sparse checkout + # (even with a shallow clone). We provide a custom + # DOWNLOAD_COMMAND until such time as CMAKE natively supports + # sparse checkouts. + DOWNLOAD_COMMAND + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && + cd && + git sparse-checkout init --cone && + git sparse-checkout set projects/hipdnn && + git checkout ${ARG_HIP_DNN_HASH} + ) + else() + FetchContent_Declare( + hipdnn_frontend + SOURCE_DIR ${ARG_LOCAL_PATH} + ) endif() + + set(HIP_DNN_BUILD_BACKEND ON) + set(HIP_DNN_BUILD_FRONTEND ON) + set(HIP_DNN_SKIP_TESTS ON) + set(HIP_DNN_BUILD_PLUGINS OFF) + set(ENABLE_CLANG_TIDY OFF) + # PIC required to link static library into shared object. + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(hipdnn_frontend) endmacro() # IREERuntime From 55b156da0f073071a64dee3e1c987b403a8b9415 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 22 Oct 2025 18:49:49 -0700 Subject: [PATCH 13/14] [Fusilli] Pick compiler fixes and change dispatch count checks for conv3d (#2560) Compiler fix: https://github.com/iree-org/iree/pull/22320 Issue: https://github.com/iree-org/iree/issues/22312 --- projects/fusilli-plugin/test/integration/CMakeLists.txt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/projects/fusilli-plugin/test/integration/CMakeLists.txt b/projects/fusilli-plugin/test/integration/CMakeLists.txt index 0c06e8fe719..f2ba4dce4fd 100644 --- a/projects/fusilli-plugin/test/integration/CMakeLists.txt +++ b/projects/fusilli-plugin/test/integration/CMakeLists.txt @@ -26,4 +26,8 @@ target_compile_definitions(fusilli_plugin_integration_tests PRIVATE FUSILLI_PLUGIN_ENGINE_ID="${FUSILLI_PLUGIN_ENGINE_ID}" ) # Register with CTest -gtest_discover_tests(fusilli_plugin_integration_tests) +gtest_discover_tests(fusilli_plugin_integration_tests + # Ensure that tests pick up libhipdnn_backend.so in build directory, not + # from the global TheRock install. + PROPERTIES ENVIRONMENT "LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib" +) From 05f8da959b49431b9c42413270aeb4bc49dfa72f Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Thu, 23 Oct 2025 14:05:40 -0600 Subject: [PATCH 14/14] [Fusilli,FusilliPlugin] Async execution & device selection (#2487) Adds support for async execution and explicit device/stream selection and updates fusilli plugin to use new API: - Handle creation supports device id and external HIP stream parameters. - Graph execution is async on AMDGPU backend. - Fusilli plugin now uses stream if set by user. --- .../fusilli-plugin/include/graph_import.h | 1 - .../include/hipdnn_engine_plugin_handle.h | 33 +++++- projects/fusilli-plugin/include/utils.h | 34 +++++- .../fusilli-plugin/src/fusilli_plugin.cpp | 48 +++++--- .../test/integration/test_basic.cpp | 2 +- .../test/integration/test_convfprop.cpp | 111 ++++++++++++++---- .../test/test_fusilli_plugin_api.cpp | 39 +++++- 7 files changed, 220 insertions(+), 48 deletions(-) diff --git a/projects/fusilli-plugin/include/graph_import.h b/projects/fusilli-plugin/include/graph_import.h index 4046cb78f96..8fdffece063 100644 --- a/projects/fusilli-plugin/include/graph_import.h +++ b/projects/fusilli-plugin/include/graph_import.h @@ -24,7 +24,6 @@ #include #include -#include "fusilli/support/logging.h" #include "hipdnn_engine_plugin_execution_context.h" // Convert from hipDNN DataType to fusilli DataType. diff --git a/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h index 80903b19b4d..7dd79a24795 100644 --- a/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h +++ b/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h @@ -22,17 +22,20 @@ #ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H #define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H -#include "fusilli/backend/handle.h" #include +#include +#include + +#include +#include #include #include struct HipdnnEnginePluginHandle { public: - fusilli::Handle fusilliHandle; + const int deviceId; - HipdnnEnginePluginHandle(fusilli::Handle &&handle) - : fusilliHandle(std::move(handle)) {} + HipdnnEnginePluginHandle(int deviceId) : deviceId(deviceId) {} // Take ownership of a flatbuffers::DetachedBuffer and store it associated // with its memory address. @@ -46,7 +49,29 @@ struct HipdnnEnginePluginHandle { _engineDetailsBuffers.erase(ptr); } + // Get or create fusilli::Handle just in time. As the engine API may set the + // stream (through `hipdnnEnginePluginSetStream`) after initial handle + // creation (in `hipdnnEnginePluginCreate`) we defer the fusilli::Handle + // creation until we know if a stream has been set. + fusilli::ErrorOr> getFusilliHandle() { + if (!_fusilliHandle.has_value()) + _fusilliHandle = FUSILLI_TRY( + fusilli::Handle::create(fusilli::Backend::AMDGPU, deviceId, + reinterpret_cast(_stream))); + return fusilli::ok( + std::reference_wrapper(*_fusilliHandle)); + } + + void setStream(hipStream_t stream) { _stream = stream; } + private: + // Default to creating a handle on the null (default) stream. + hipStream_t _stream = 0; + + // Fusilli handle, will be created on the first call to `getFusilliHandle`. + std::optional _fusilliHandle; + + // Storage for engine details. std::unordered_map> _engineDetailsBuffers; }; diff --git a/projects/fusilli-plugin/include/utils.h b/projects/fusilli-plugin/include/utils.h index d92ef85b93f..273b96d9770 100644 --- a/projects/fusilli-plugin/include/utils.h +++ b/projects/fusilli-plugin/include/utils.h @@ -13,13 +13,17 @@ #ifndef FUSILLI_PLUGIN_SRC_UTILS_H #define FUSILLI_PLUGIN_SRC_UTILS_H -#include "fusilli/attributes/types.h" -#include "fusilli/support/logging.h" +#include +#include #include #include #include #include +#include + +namespace fusilli_plugin { + // Checks for null, sets the plugin last error manager and returns error if // null. // @@ -92,6 +96,28 @@ findDeviceBuffer(int64_t uid, const hipdnnPluginDeviceBuffer_t *deviceBuffers, std ::move(*errorOr); \ }) +template fusilli::ErrorObject convertToErrorObject(T &&error) { + using DecayT = std::decay_t; + if constexpr (std::is_convertible_v) { + return std::forward(error); + } else if constexpr (std::is_same_v) { + // Convert HIP error to fusilli ErrorObject + if (error != hipSuccess) { + return fusilli::error(fusilli::ErrorCode::InternalError, + hipGetErrorString(error)); + } + return fusilli::ok(); + } else { + static_assert( + std::is_convertible_v || + std::is_same_v, + "convertToErrorObject requires fusilli::ErrorObject or hipError_t"); + // Unreachable + return fusilli::error(fusilli::ErrorCode::InternalError, + "Unknown error type"); + } +} + // Set plugin error manager last error and return failed status from enclosing // scope if expression evaluates to a fusilli::ErrorObject in an error state; or // in the case of fusilli::ErrorOr is convertible to an fusilli::ErrorObject @@ -107,7 +133,7 @@ findDeviceBuffer(int64_t uid, const hipdnnPluginDeviceBuffer_t *deviceBuffers, // } #define FUSILLI_PLUGIN_CHECK_ERROR(expr) \ do { \ - fusilli::ErrorObject err = (expr); \ + fusilli::ErrorObject err = convertToErrorObject(expr); \ if (isError(err)) { \ return hipdnn_plugin::PluginLastErrorManager::setLastError( \ HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, err.getMessage()); \ @@ -148,4 +174,6 @@ fusilliDataTypeToIreeHalDataType(fusilli::DataType fusilliDataType) { } } +} // namespace fusilli_plugin + #endif // FUSILLI_PLUGIN_SRC_UTILS_H diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp index 5adef515c2e..da8d251007d 100644 --- a/projects/fusilli-plugin/src/fusilli_plugin.cpp +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -12,8 +12,6 @@ //===----------------------------------------------------------------------===// // hipDNN logging expects COMPONENT_NAME to be defined -#include -#include #define COMPONENT_NAME FUSILLI_PLUGIN_NAME #include @@ -32,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -46,6 +46,7 @@ #include "utils.h" using namespace hipdnn_plugin; +using namespace fusilli_plugin; // TODO(#2317): ensure single source of truth for plugin version static const char *fusilliPluginVersion = "0.0.1"; @@ -145,13 +146,12 @@ hipdnnEnginePluginCreate(hipdnnEnginePluginHandle_t *handle) { LOG_API_ENTRY("handle_ptr={:p}", static_cast(handle)); FUSILLI_PLUGIN_CHECK_NULL(handle); - // According to runtime/src/iree/hal/driver_registry.h the underlying device - // creation methods should be thread safe, fusilli::Handle ensures that - // instance creation is thread safe, so this should be thread safe. - // TODO(#2335): handle multiple architectures - auto fusilliHandle = - FUSILLI_PLUGIN_TRY(fusilli::Handle::create(fusilli::Backend::AMDGPU)); - *handle = new HipdnnEnginePluginHandle(std::move(fusilliHandle)); + // Get device id. + int deviceId; + FUSILLI_PLUGIN_CHECK_ERROR(hipGetDevice(&deviceId)); + + // Create handle. + *handle = new HipdnnEnginePluginHandle(deviceId); LOG_API_SUCCESS_AUTO("createdHandle={:p}", static_cast(*handle)); return HIPDNN_PLUGIN_STATUS_SUCCESS; @@ -175,10 +175,22 @@ hipdnnEnginePluginSetStream(hipdnnEnginePluginHandle_t handle, static_cast(stream)); FUSILLI_PLUGIN_CHECK_NULL(handle); - // TODO(#2151): Set stream on fusilli handle, or defer creation until stream - // is available and create handle around stream. Today fusilli handle creates - // a default IREE runtime device and execute programs on a stream associated - // with that device. The passed in stream is ignored. + // Get device associated with stream. + hipDevice_t deviceId; + FUSILLI_PLUGIN_CHECK_ERROR(hipStreamGetDevice(stream, &deviceId)); + + // This should never happen, check so that when it does we get a nice error + // message. + if (deviceId != handle->deviceId) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "Stream is associated with different device. Device reported " + "through `hipStreamGetDevice` does not match active " + "device reported through `hipGetDevice`."); + } + + // Set stream, it will be used to create fusilli::Handle later. + handle->setStream(stream); LOG_API_SUCCESS_AUTO("", ""); return HIPDNN_PLUGIN_STATUS_SUCCESS; @@ -400,7 +412,8 @@ hipdnnPluginStatus_t hipdnnEnginePluginCreateExecutionContext( // Compile graph FUSILLI_CHECK_ERROR(graphImport.graph.validate()); - FUSILLI_CHECK_ERROR(graphImport.graph.compile(handle->fusilliHandle)); + FUSILLI_CHECK_ERROR( + graphImport.graph.compile(FUSILLI_TRY(handle->getFusilliHandle()))); return fusilli::ok(std::move(graphImport)); }; @@ -446,8 +459,8 @@ hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( FUSILLI_PLUGIN_CHECK_NULL(deviceBuffers); // Params and allocators hoisted out of loop below. - iree_hal_allocator_t *deviceAllocator = - iree_hal_device_allocator(handle->fusilliHandle); + iree_hal_allocator_t *deviceAllocator = iree_hal_device_allocator( + FUSILLI_PLUGIN_TRY(handle->getFusilliHandle()).get()); iree_allocator_t ireeHostAllocator = iree_allocator_system(); iree_hal_buffer_params_t bufferParams = { .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, @@ -526,7 +539,8 @@ hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( iree_hal_buffer_view_release(outBufferView); } - FUSILLI_PLUGIN_CHECK_ERROR(executionContext->graph.execute(variantPack)); + FUSILLI_PLUGIN_CHECK_ERROR(executionContext->graph.execute( + FUSILLI_PLUGIN_TRY(handle->getFusilliHandle()), variantPack)); LOG_API_SUCCESS_AUTO("{}", "executed graph"); return HIPDNN_PLUGIN_STATUS_SUCCESS; diff --git a/projects/fusilli-plugin/test/integration/test_basic.cpp b/projects/fusilli-plugin/test/integration/test_basic.cpp index 391cc80e35b..b1561fcc8c2 100644 --- a/projects/fusilli-plugin/test/integration/test_basic.cpp +++ b/projects/fusilli-plugin/test/integration/test_basic.cpp @@ -11,7 +11,7 @@ #include #include -std::vector getLoadedPlugins(hipdnnHandle_t handle) { +static std::vector getLoadedPlugins(hipdnnHandle_t handle) { size_t numPlugins = 0; size_t maxPathLength = 0; auto status = hipdnnGetLoadedEnginePluginPaths_ext(handle, &numPlugins, diff --git a/projects/fusilli-plugin/test/integration/test_convfprop.cpp b/projects/fusilli-plugin/test/integration/test_convfprop.cpp index 3d959cd9676..1812f5a753b 100644 --- a/projects/fusilli-plugin/test/integration/test_convfprop.cpp +++ b/projects/fusilli-plugin/test/integration/test_convfprop.cpp @@ -4,42 +4,75 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include #include #include - +#include #include #include #include #include #include #include + +#include #include +#include using namespace hipdnn_frontend; using namespace hipdnn_sdk::utilities; using namespace hipdnn_sdk::test_utilities; -TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { +struct TestParams { + bool shouldSetStream; + hipDevice_t deviceId; + + // GTest uses the << operator to update the parameterized test name. + friend std::ostream &operator<<(std::ostream &ss, const TestParams &p) { + ss << "Device" << p.deviceId << "_"; + ss << (p.shouldSetStream ? "WithStream" : "WithoutStream"); + return ss; + } +}; + +class ConvFpropIntegrationTest : public ::testing::TestWithParam {}; + +TEST_P(ConvFpropIntegrationTest, Basic1x1Convolution) { + TestParams params = GetParam(); // Uncomment to enable debug logging // setenv("HIPDNN_LOG_LEVEL", "info", 1); - // Initialize HIP + // Initialize HIP. ASSERT_EQ(hipInit(0), hipSuccess); - int deviceId; - ASSERT_EQ(hipGetDevice(&deviceId), hipSuccess); - // Set plugin path + // Set device. + ASSERT_EQ(hipSetDevice(params.deviceId), hipSuccess); + + // Create stream. + hipStream_t stream = nullptr; + if (params.shouldSetStream) { + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + } + + // Set plugin path. const std::array paths = {FUSILLI_PLUGIN_DIR}; ASSERT_EQ(hipdnnSetEnginePluginPaths_ext(paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE), HIPDNN_STATUS_SUCCESS); - // Create handle + // Create handle. hipdnnHandle_t handle; ASSERT_EQ(hipdnnCreate(&handle), HIPDNN_STATUS_SUCCESS); - // Dimensions + // Check that loading the plugin didn't change the active device. + hipDevice_t deviceId = -1; + ASSERT_EQ(hipGetDevice(&deviceId), hipSuccess); + ASSERT_EQ(deviceId, params.deviceId); + + if (params.shouldSetStream) { + ASSERT_EQ(hipdnnSetStream(handle, stream), HIPDNN_STATUS_SUCCESS); + } + + // Dimensions. const int64_t n = 16; // batch const int64_t c = 128; // in channels const int64_t h = 64; // image height @@ -48,12 +81,12 @@ TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { const int64_t r = 1; // filter height const int64_t s = 1; // filter width - // UIDs + // UIDs. const int64_t xUID = 0; const int64_t wUID = 1; const int64_t yUID = 2; - // Initialize tensors + // Initialize tensors. PinnedTensor xTensor({n, c, h, w}); PinnedTensor wTensor({k, c, r, s}); PinnedTensor yTensor({n, k, h, w}); @@ -61,17 +94,17 @@ TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { wTensor.fillWithValue(1.0f); yTensor.fillWithValue(-100.0f); - // Expected output + // Expected output. PinnedTensor expectedOutput({n, k, h, w}); expectedOutput.fillWithValue(128.0f); - // Create graph + // Create graph. auto graph = std::make_shared(); graph->set_name("conv_1x1_test"); graph->set_io_data_type(DataType_t::FLOAT) .set_compute_data_type(DataType_t::FLOAT); - // Create tensor attributes + // Create tensor attributes. auto xAttr = std::make_shared( graph::makeTensorAttributes("input", DataType_t::FLOAT, xTensor)); xAttr->set_uid(xUID); @@ -79,19 +112,19 @@ TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { graph::makeTensorAttributes("filter", DataType_t::FLOAT, wTensor)); wAttr->set_uid(wUID); - // Create convolution attributes + // Create convolution attributes. graph::ConvFpropAttributes convAttr; convAttr.set_name("conv_fprop") .set_padding({0, 0}) .set_stride({1, 1}) .set_dilation({1, 1}); - // Create graph + // Create graph. auto yAttr = graph->conv_fprop(xAttr, wAttr, convAttr); yAttr->set_uid(yUID); yAttr->set_dim(yTensor.dims()).set_stride(yTensor.strides()).set_output(true); - // Build + validate + build plans for graph + // Build + validate + build plans for graph. auto result = graph->validate(); ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; @@ -107,22 +140,58 @@ TEST(ConvFpropIntegrationTest, Basic1x1Convolution) { result = graph->build_plans(); ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; - // Create variant pack + // Create variant pack. std::unordered_map variantPack; variantPack[xUID] = xTensor.memory().deviceData(); variantPack[wUID] = wTensor.memory().deviceData(); variantPack[yUID] = yTensor.memory().deviceData(); - // Execute graph + // Execute graph. result = graph->execute(handle, variantPack, nullptr); ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; // Mark hipDNN tensor CPU cache ask stale, data must be read from device. yTensor.memory().markDeviceModified(); - // Check results + // Check results. CpuFpReferenceValidation validator(1e-6f, 1e-6f); EXPECT_TRUE(validator.allClose(expectedOutput.memory(), yTensor.memory())); - // Cleanup + // Clean up. + if (params.shouldSetStream) { + ASSERT_EQ(hipStreamDestroy(stream), HIPDNN_STATUS_SUCCESS); + } ASSERT_EQ(hipdnnDestroy(handle), HIPDNN_STATUS_SUCCESS); } + +static std::vector generateTestParams() { + std::vector params; + int deviceCount; + assert(hipGetDeviceCount(&deviceCount) == hipSuccess); + + // Always test with device 0. + params.push_back({ + .shouldSetStream = false, + .deviceId = 0, + }); + params.push_back({ + .shouldSetStream = true, + .deviceId = 0, + }); + + // Test on last device if multiple devices are available. + if (deviceCount > 1) { + params.push_back({ + .shouldSetStream = false, + .deviceId = deviceCount - 1, + }); + params.push_back({ + .shouldSetStream = true, + .deviceId = deviceCount - 1, + }); + } + + return params; +} + +INSTANTIATE_TEST_SUITE_P(, ConvFpropIntegrationTest, + ::testing::ValuesIn(generateTestParams())); diff --git a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp index 880b9069a4d..00df0ddcab2 100644 --- a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp +++ b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp @@ -15,13 +15,13 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include #include @@ -368,3 +368,40 @@ TEST(TestFusilliPluginApi, CreateExecutionContext) { HIPDNN_PLUGIN_STATUS_SUCCESS); EXPECT_EQ(hipdnnEnginePluginDestroy(handle), HIPDNN_PLUGIN_STATUS_SUCCESS); } + +TEST(TestFusilliPluginApi, SetStreamSuccess) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + + // Create a HIP stream. + hipStream_t stream; + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + + // Set the stream on the handle. + EXPECT_EQ(hipdnnEnginePluginSetStream(handle, stream), + HIPDNN_PLUGIN_STATUS_SUCCESS); + + // Clean up. + EXPECT_EQ(hipStreamDestroy(stream), hipSuccess); + EXPECT_EQ(hipdnnEnginePluginDestroy(handle), HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST(TestFusilliPluginApi, SetStreamNullHandle) { + // Create a HIP stream. + hipStream_t stream; + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + + // Attempt to set stream with null handle should fail. + EXPECT_EQ(hipdnnEnginePluginSetStream(nullptr, stream), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set. + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + EXPECT_GT(strlen(errorStr), 0u); + + // Clean up. + EXPECT_EQ(hipStreamDestroy(stream), hipSuccess); +}