diff --git a/projects/fusilli-plugin/.clang-format b/projects/fusilli-plugin/.clang-format new file mode 100644 index 00000000000..9b3aa8b7213 --- /dev/null +++ b/projects/fusilli-plugin/.clang-format @@ -0,0 +1 @@ +BasedOnStyle: LLVM diff --git a/projects/fusilli-plugin/.gitignore b/projects/fusilli-plugin/.gitignore new file mode 100644 index 00000000000..425ada20f17 --- /dev/null +++ b/projects/fusilli-plugin/.gitignore @@ -0,0 +1,8 @@ +# CMake build +build/ + +# clangd intellisense cache +.cache/ + +# CMake presets +CMakePresets.json diff --git a/projects/fusilli-plugin/CMakeLists.txt b/projects/fusilli-plugin/CMakeLists.txt new file mode 100644 index 00000000000..a9cee7b9b45 --- /dev/null +++ b/projects/fusilli-plugin/CMakeLists.txt @@ -0,0 +1,61 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +cmake_minimum_required(VERSION 3.28) + +project(fusilli-plugin + VERSION 0.1.0 + DESCRIPTION "Fusilli-Plugin: A Fusilli/IREE powered hipDNN plugin for graph JIT compilation." + LANGUAGES C CXX) + +# Set C++ standard +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_SCAN_FOR_MODULES OFF) + +# Local Includes +list(APPEND CMAKE_MODULE_PATH + ${CMAKE_CURRENT_LIST_DIR}/build_tools/cmake/ +) +include(FusilliPluginDependencyUtils) + +# Global constants +set(FUSILLI_PLUGIN_NAME fusilli_plugin) +set(FUSILLI_PLUGIN_ENGINE_ID 1001) + +# Options +option(FUSILLI_PLUGIN_USE_LOCAL_FUSILLI "Use local Fusilli build from ../sharkfuser" ON) + +# Dependencies +set(HIP_PLATFORM "amd") +find_package(hip REQUIRED) +fusilli_plugin_dependency(GTest GTEST_VERSION 1.16.0) +fusilli_plugin_dependency(IREERuntime) +fusilli_plugin_dependency(hipdnn_frontend HIP_DNN_HASH 4e0a0452cfcb8fdb86e9c40a6e43debab4d4ecbc) +fusilli_plugin_dependency(Fusilli USE_LOCAL ${FUSILLI_PLUGIN_USE_LOCAL_FUSILLI}) + +# Includes +include_directories(include) + +# Plugin definition +add_library(${FUSILLI_PLUGIN_NAME} SHARED + src/fusilli_plugin.cpp +) +target_compile_options(${FUSILLI_PLUGIN_NAME} PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(${FUSILLI_PLUGIN_NAME} PRIVATE hipdnn_sdk hip::host fusilli::fusilli) +target_compile_definitions(${FUSILLI_PLUGIN_NAME} PRIVATE + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID=${FUSILLI_PLUGIN_ENGINE_ID} +) +set_target_properties(${FUSILLI_PLUGIN_NAME} PROPERTIES + CXX_VISIBILITY_PRESET hidden + LIBRARY_OUTPUT_DIRECTORY "${HIPDNN_BUILD_PLUGIN_ENGINE_DIR}" +) + +# Tests +enable_testing() +add_subdirectory(test) diff --git a/projects/fusilli-plugin/README.md b/projects/fusilli-plugin/README.md new file mode 100644 index 00000000000..7c211381635 --- /dev/null +++ b/projects/fusilli-plugin/README.md @@ -0,0 +1,26 @@ +# Fusilli Plugin + +Fusilli-Plugin: A Fusilli/IREE powered hipDNN plugin for graph JIT compilation. + +:construction: **This project is under active development, many things don't work yet** :construction: + +The plugin builds as a shared library (`fusilli_plugin.so`) providing a `hipDNN` [kernel engine plugin](https://github.com/ROCm/hipDNN/blob/develop/docs/PluginDevelopment.md#creating-a-kernel-engine-plugin) [API](https://github.com/ROCm/hipDNN/blob/839cf6c4bc6fe403d0ef72cb5d7df004e2004743/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h). + +## Developer Guide + +### Setup + +For the time being, `fusilli-plugin` setup relies on / builds on [Fusilli setup](../sharkfuser/README.md#setup). +Keeping the projects in sync prevents "works on my machine" style bugs. +Requirements that are unique to `fusilli-plugin`, `hipDNN` and `googletest` for +example, are fetched configured and built as part of `fusilli-plugin` build. + +After following steps in [Fusilli Setup](../sharkfuser/README.md#setup), build and test +`fusilli-plugin` as follows: +```shell +$ cmake -GNinja -S. -Bbuild \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ +$ cmake --build build --target all +$ ctest --test-dir build +``` diff --git a/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake new file mode 100644 index 00000000000..9983941fc66 --- /dev/null +++ b/projects/fusilli-plugin/build_tools/cmake/FusilliPluginDependencyUtils.cmake @@ -0,0 +1,219 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#===------------------------------------------------------------------------===# +# +# Provides correctly configured dependencies for fusilli-plugin build. +# +# Main entry point: +# fusilli_plugin_dependency(DEP_NAME [args...]) +# +# `fusilli_plugin_dependency` routes to lower level `_fetch_X` macros to +# actually fetch dependency `X`. Each `_fetch_X` macro preferentially +# `find_package`s installed/system versions of packages and falls back to +# vendoring dependencies in the build tree with `FetchContent`. +# +# Supported dependencies: GTest, hipdnn_frontend, Fusilli, IREERuntime +# +#===------------------------------------------------------------------------===# + +cmake_minimum_required(VERSION 3.25.2) + +include(FetchContent) + +# Provide a fusilli plugin dependency. `fusilli_plugin_dependency` will +# preferentially use system version (available through `find_package`) of a +# dependency, and fall back to building local copy with `FetchContent` + +# configuration. +# +# fusilli_plugin_dependency( +# DEP_NAME +# [...] +# ) +# +# DEP_NAME +# Supported dependencies: +# GTest +# hipdnn_frontend +# Fusilli +# IREERuntime +# +# +# The `_fetch_X` macro for dependency X defines the available options. +# Examples: GTEST_VERSION for GTest, HIP_DNN_HASH for hipdnn_frontend +# +function(fusilli_plugin_dependency DEP_NAME) + # Set indent for logging, any logs from dep "X" will be prefixed with [X]. + set(CMAKE_MESSAGE_INDENT "[${DEP_NAME}] ") + + # Route to appropriate _fetch_X macro. CMake macros aren't textual + # expansions like C preprocessor macros, so a dynamic call (like below) to a + # macro isn't a problem. + # Macro vs function: + # - macros execute in caller's scope and arguments are textually substituted + # - functions create a new scope and arguments are real variables + # - both functions and macros are executed at runtime + # + # WARNING: Logging below checks variables it expects a _fetch_X macro to set + # in this scope, requiring that _fetch_X is a macro and not a + # function. + if(COMMAND _fetch_${DEP_NAME}) + cmake_language(CALL _fetch_${DEP_NAME} ${ARGN}) + else() + set(CMAKE_MESSAGE_INDENT "") + message(FATAL_ERROR "Unknown dependency: ${DEP_NAME}") + endif() + + # reset indent. + set(CMAKE_MESSAGE_INDENT "") + + # FetchContent_MakeAvailable(DEP) creates a _POPULATED variable + # indicating the dependency was fetched rather than found on system. + # + # WARNING: FetchContent_Declare()/FetchContent_MakeAvailable() + # can use anything for the name argument, if the _fetch_X macro + # doesn't use ${DEP_NAME} the _POPULATED we're checking for + # here won't exist and the log may be misleading. + string(TOLOWER ${DEP_NAME} DEP_NAME_LOWER) + if (${DEP_NAME_LOWER}_POPULATED) + message(STATUS "${DEP_NAME} dependency populated via FetchContent") + message(STATUS " Source: ${${DEP_NAME_LOWER}_SOURCE_DIR}") + message(STATUS " Build: ${${DEP_NAME_LOWER}_BINARY_DIR}") + else() + message(STATUS "${DEP_NAME} dependency found on system via find_package") + message(STATUS " Config: ${${DEP_NAME}_DIR}") + endif() +endfunction() + +# GTest +# +# GTEST_VERSION +# Version tag of GTest +macro(_fetch_GTest) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "GTEST_VERSION" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + if(NOT DEFINED ARG_GTEST_VERSION) + message(FATAL_ERROR "GTEST_VERSION is required") + endif() + + FetchContent_Declare( + GTest + URL https://github.com/google/googletest/archive/refs/tags/v${ARG_GTEST_VERSION}.zip + ) + set(INSTALL_GTEST OFF) + set(BUILD_GMOCK OFF) + FetchContent_MakeAvailable(GTest) +endmacro() + +# hipdnn_frontend +# +# NOTE: we currently build hipDNN as a CMake source dependency (via +# FetchContent) rather than using find_package() to locate an installed version. +# The hipDNN build automatically handles transitive dependencies, which is +# quite convenient. +# +# HIP_DNN_HASH +# Git commit hash or tag to fetch +macro(_fetch_hipdnn_frontend) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "HIP_DNN_HASH;LOCAL_PATH" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + if(NOT DEFINED ARG_LOCAL_PATH AND NOT DEFINED ARG_HIP_DNN_HASH) + message(FATAL_ERROR "Required argument: one of LOCAL_PATH or HIP_DNN_HASH") + endif() + + if(DEFINED ARG_LOCAL_PATH AND DEFINED ARG_HIP_DNN_HASH) + message(FATAL_ERROR "Argument error: passing both LOCAL_PATH and HIP_DNN_HASH is ambiguous.") + endif() + + if (DEFINED ARG_HIP_DNN_HASH) + FetchContent_Declare( + hipdnn_frontend + # location of hipdnn CMakeLists.txt in rocm-libraries + SOURCE_SUBDIR projects/hipdnn + # rocm-libraries takes 10+ min to fetch without sparse checkout + # (even with a shallow clone). We provide a custom + # DOWNLOAD_COMMAND until such time as CMAKE natively supports + # sparse checkouts. + DOWNLOAD_COMMAND + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && + cd && + git sparse-checkout init --cone && + git sparse-checkout set projects/hipdnn && + git checkout ${ARG_HIP_DNN_HASH} + ) + else() + FetchContent_Declare( + hipdnn_frontend + SOURCE_DIR ${ARG_LOCAL_PATH} + ) + endif() + + set(HIP_DNN_BUILD_BACKEND ON) + set(HIP_DNN_BUILD_FRONTEND ON) + set(HIP_DNN_SKIP_TESTS ON) + set(HIP_DNN_BUILD_PLUGINS OFF) + set(ENABLE_CLANG_TIDY OFF) + # PIC required to link static library into shared object. + set(CMAKE_POSITION_INDEPENDENT_CODE ON) + FetchContent_MakeAvailable(hipdnn_frontend) +endmacro() + +# IREERuntime +# +# NOTE: For now, we're not providing a FetchContent fallback for IREERuntime. +# Fusilli expects that the system provides this dependency, and we're +# keeping the projects in sync as much as possible for now. If you're +# running in the fusilli docker container (described in sharkfuser README) +# passing -DIREERuntime_DIR=/workspace/.cache/docker/iree/build/lib/cmake/IREE +# should be enough. +macro(_fetch_IREERuntime) + find_package(IREERuntime CONFIG REQUIRED) +endmacro() + +# Fusilli +# +# USE_LOCAL +# If set, uses local source from ../sharkfuser directory. Without USE_LOCAL, +# requires system installation via find_package. +macro(_fetch_Fusilli) + cmake_parse_arguments( + ARG # prefix for parsed variables + "" # options (flags) + "USE_LOCAL" # single-value arguments + "" # multi-value arguments + ${ARGN} + ) + + if(NOT DEFINED ARG_USE_LOCAL) + message(FATAL_ERROR "USE_LOCAL argument is required") + endif() + + if(NOT ARG_USE_LOCAL) + # For the time being we're keeping fusilli-plugin setup as in sync as + # possible with fusilli. + message(FATAL_ERROR "Only LOCAL builds are supported currently") + endif() + + message(STATUS "Using local Fusilli build from ../sharkfuser") + FetchContent_Declare( + Fusilli + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../sharkfuser + ) + set(FUSILLI_BUILD_TESTS OFF) + set(FUSILLI_BUILD_BENCHMARKS OFF) + set(FUSILLI_SYSTEMS_AMDGPU ON) + FetchContent_MakeAvailable(Fusilli) +endmacro() diff --git a/projects/fusilli-plugin/include/graph_import.h b/projects/fusilli-plugin/include/graph_import.h new file mode 100644 index 00000000000..8fdffece063 --- /dev/null +++ b/projects/fusilli-plugin/include/graph_import.h @@ -0,0 +1,253 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains facilities for converting hipDNN serialized graphs to +// fusilli graphs. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H +#define FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "hipdnn_engine_plugin_execution_context.h" + +// Convert from hipDNN DataType to fusilli DataType. +inline fusilli::ErrorOr +hipDnnDataTypeToFusilliDataType(hipdnn_sdk::data_objects::DataType hipdnnType) { + switch (hipdnnType) { + case hipdnn_sdk::data_objects::DataType::HALF: + return ok(fusilli::DataType::Half); + case hipdnn_sdk::data_objects::DataType::BFLOAT16: + return ok(fusilli::DataType::BFloat16); + case hipdnn_sdk::data_objects::DataType::FLOAT: + return ok(fusilli::DataType::Float); + case hipdnn_sdk::data_objects::DataType::DOUBLE: + return ok(fusilli::DataType::Double); + case hipdnn_sdk::data_objects::DataType::UINT8: + return ok(fusilli::DataType::Uint8); + case hipdnn_sdk::data_objects::DataType::INT32: + return ok(fusilli::DataType::Int32); + case hipdnn_sdk::data_objects::DataType::UNSET: + return ok(fusilli::DataType::NotSet); + default: + return error(fusilli::ErrorCode::RuntimeFailure, + "Unknown type in hipdnn -> fusilli graph translation."); + } +} + +// Graph import is done through importGraph function, this class exists for +// organization and is used by importGraph. +// +// Graph import is designed around individual Node import functions (such as +// importConvFPropAttr) which convert a given node type, and track input and +// output tensors in shared state (via importNodeInput and importNodeOutput +// functions). Graph nodes are processed in topological order to ensure that +// outputs of producer nodes are tracked and available for consuming nodes. +// +// NOTE: input hipDNN graph .node()s may not be in topological order. There's +// plans for a topological sort method but that's not available yet. As we're +// only handling single-node graphs currently, that's not a problem. +class GraphImport { +private: + friend fusilli::ErrorOr + importGraph(const hipdnnPluginConstData_t *opGraph); + + // The imported graph. + fusilli::Graph fusilliGraph; + + // Maps hipDNN tensor UIDs to fusilli::TensorAttrs for graph boundary tensors + // (inputs and outputs). Used by hipdnnEnginePluginExecuteOpGraph to match + // incoming device buffers (identified by UID) to their corresponding + // fusilli::TensorAttr. + // + // All tensors in this map should be non-virtual boundary tensors---a virtual + // tensor is an internal intermediate---internal tensors will be tracked + // elsewhere. Currently we only support single node graphs, so there are no + // virtual tensors to track. + std::unordered_map> + uidToIOTensor; + + // Helper class for reading from flatbuffer. + hipdnn_plugin::GraphWrapper opGraphWrapper; + + GraphImport(const hipdnnPluginConstData_t *opGraph) + : opGraphWrapper(opGraph->ptr, opGraph->size) {} + + fusilli::ErrorObject importGraph() { + const hipdnn_sdk::data_objects::Graph &hipDnnGraph = + opGraphWrapper.getGraph(); + + // Import graph level properties. + fusilliGraph.setName(hipDnnGraph.name()->str()) + .setIODataType( + FUSILLI_TRY(hipDnnDataTypeToFusilliDataType(hipDnnGraph.io_type()))) + .setIntermediateDataType(FUSILLI_TRY( + hipDnnDataTypeToFusilliDataType(hipDnnGraph.intermediate_type()))) + .setComputeDataType(FUSILLI_TRY( + hipDnnDataTypeToFusilliDataType(hipDnnGraph.compute_type()))); + + return importNodes(); + } + + // Import all graph nodes. + fusilli::ErrorObject importNodes() { + if (opGraphWrapper.nodeCount() > 1) + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Multi-node graphs not supported currently."); + for (size_t i = 0; i < opGraphWrapper.nodeCount(); ++i) { + const hipdnn_sdk::data_objects::Node &node = opGraphWrapper.getNode(i); + FUSILLI_CHECK_ERROR(importNode(node)); + } + + return fusilli::ok(); + } + + // Import single graph node. + fusilli::ErrorObject importNode(const hipdnn_sdk::data_objects::Node &node) { + switch (node.attributes_type()) { + case hipdnn_sdk::data_objects::NodeAttributes::ConvolutionFwdAttributes: + FUSILLI_CHECK_ERROR( + importConvFPropAttr(node.attributes_as_ConvolutionFwdAttributes())); + break; + default: + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Unsupported node type."); + } + return fusilli::ok(); + } + + fusilli::ErrorObject + importConvFPropAttr(const hipdnn_sdk::data_objects::ConvolutionFwdAttributes + *hipDnnConvFwdAttr) { + // Import node inputs. + std::shared_ptr x = + FUSILLI_TRY(importNodeInput(hipDnnConvFwdAttr->x_tensor_uid(), "x")); + std::shared_ptr w = + FUSILLI_TRY(importNodeInput(hipDnnConvFwdAttr->w_tensor_uid(), "w")); + + // hipdnnEnginePluginGetApplicableEngineIds should have already eliminated + // any nodes with asymmetric padding, this is just a double check. + if (!std::ranges::equal(*hipDnnConvFwdAttr->pre_padding(), + *hipDnnConvFwdAttr->post_padding())) // C++ 20 + return fusilli::error(fusilli::ErrorCode::AttributeNotSet, + "Conv node with asymmetric padding found."); + // Import node. + auto fusilliConvFwdAttr = + fusilli::ConvFPropAttr() + .setPadding(*hipDnnConvFwdAttr->post_padding()) + .setStride(*hipDnnConvFwdAttr->stride()) + .setDilation(*hipDnnConvFwdAttr->dilation()) + .setName("conv_fprop"); + std::shared_ptr y = + fusilliGraph.convFProp(x, w, fusilliConvFwdAttr); + + // Import node output. + FUSILLI_CHECK_ERROR( + importNodeOutput(hipDnnConvFwdAttr->y_tensor_uid(), "y", y)); + + return fusilli::ok(); + } + + // Import, and track, node input tensor. Node input tensor is created in the + // case of a boundary tensor, and read from shared state otherwise. + fusilli::ErrorOr> + importNodeInput(int64_t uid, const char *name) { + // Get hipDNN tensor. TensorMap is created from the graph that uid variable + // is read from, so .at() call should be safe. + const hipdnn_sdk::data_objects::TensorAttributes *hipDnnTensorAttr = + opGraphWrapper.getTensorMap().at(uid); + + // A virtual node indicates a non-boundary node. + if (hipDnnTensorAttr->virtual_()) { + // When multi-op graphs are supported, we would look up the output of a + // previously imported node and return it here. + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Virtual inputs currently unsupported."); + } + + // Import new tensor. + auto fusilliTensorAttr = fusilli::TensorAttr().setName( + std::format("{}_{}", name, uid)); // C++ 20 + FUSILLI_CHECK_ERROR(importAttrs(fusilliTensorAttr, hipDnnTensorAttr)); + std::shared_ptr graphInput = + fusilliGraph.tensor(fusilliTensorAttr); + + // Track boundary tensor. + uidToIOTensor[uid] = graphInput; + + return ok(graphInput); + }; + + // Import and track node output tensor. + fusilli::ErrorObject + importNodeOutput(int64_t uid, const char *name, + const std::shared_ptr &nodeOutput) { + // Get hipDNN tensor. TensorMap is created from the graph that uid variable + // is read from, so .at() call should be safe. + const hipdnn_sdk::data_objects::TensorAttributes *hipDnnTensorAttr = + opGraphWrapper.getTensorMap().at(uid); + + // Import attrs. + nodeOutput->setName(std::format("{}_{}", name, uid)); // C++ 20 + FUSILLI_CHECK_ERROR(importAttrs(*nodeOutput, hipDnnTensorAttr)); + + // A virtual node indicates a non-boundary node. + if (hipDnnTensorAttr->virtual_()) { + // This tensor is an input to nodes farther down the topological sort. + // When multi-op graphs are supported, we would track UID -> node output + // tensor in a non-IO tensor map. + return fusilli::error(fusilli::ErrorCode::NotImplemented, + "Virtual outputs currently unsupported. An output " + "tensor may have not been marked as an output."); + } + + // Track boundary node. + uidToIOTensor[uid] = nodeOutput; + + return fusilli::ok(); + }; + + // Import all tensor attrs src -> dest. + fusilli::ErrorObject + importAttrs(fusilli::TensorAttr &dest, + const hipdnn_sdk::data_objects::TensorAttributes *src) { + dest.setIsVirtual(src->virtual_()) + .setDim(*src->dims()) + .setStride(*src->strides()) + .setDataType( + FUSILLI_TRY(hipDnnDataTypeToFusilliDataType(src->data_type()))); + return fusilli::ok(); + } +}; + +// Given a hipDNN serialized graph, return imported fusilli::Graph and UID -> +// fusilli::TensorAttr map for IO tensors. +// +// NOTE: HipdnnEnginePluginExecutionContext used as return type because it +// contains (only) the exact required fields. If it requires more members in +// the future it's probably worth creating a new data transmission type. +inline fusilli::ErrorOr +importGraph(const hipdnnPluginConstData_t *opGraph) { + auto gc = GraphImport(opGraph); + FUSILLI_CHECK_ERROR(gc.importGraph()); + return HipdnnEnginePluginExecutionContext{.graph = std::move(gc.fusilliGraph), + .uidToFusilliTensorAttr = + std::move(gc.uidToIOTensor)}; +} + +#endif // FUSILLI_PLUGIN_SRC_GRAPH_IMPORT_H diff --git a/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h new file mode 100644 index 00000000000..d5a38da68b0 --- /dev/null +++ b/projects/fusilli-plugin/include/hipdnn_engine_plugin_execution_context.h @@ -0,0 +1,39 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin's definition of +// HipdnnEnginePluginExecutionContext. To hipDNN this type is opaque, it deals +// in hipdnnEnginePluginExecutionContext_t which is a pointer to the undefined +// HipdnnEnginePluginExecutionContext. Each plugin must define +// HipdnnEnginePluginExecutionContext in order to create something when hipDNN +// asks for an execution context. +// +// The execution context should store what's needed to execute a given kernel +// (plan in hipDNN parlance) in a hot loop without any overhead. For fusilli +// plugin, that maps to constructing and storing a fusilli::Graph based on +// hipDNN graph. When an execution is requested, it should be a simple lookup +// for UID -> tensor attribute, then a graph execution. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H +#define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H + +#include + +struct HipdnnEnginePluginExecutionContext { + // Fusilli graph. + fusilli::Graph graph; + + // Map from hipDNN tensor UID to fusilli::TensorAttrs for graph boundary + // tensors (inputs and outputs). + std::unordered_map> + uidToFusilliTensorAttr; +}; + +#endif // FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_EXECUTION_CONTEXT_H diff --git a/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h b/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h new file mode 100644 index 00000000000..7dd79a24795 --- /dev/null +++ b/projects/fusilli-plugin/include/hipdnn_engine_plugin_handle.h @@ -0,0 +1,79 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin's definition of +// HipdnnEnginePluginHandle. To hipDNN this type is opaque, it deals in +// hipdnnEnginePluginHandle_t which is a pointer to the undefined +// HipdnnEnginePluginHandle. Each plugin must define HipdnnEnginePluginHandle in +// order to create something when hipDNN asks for an plugin handle. +// +// HipdnnEnginePluginHandle stores any persistent data associated with a +// particular engine plugin. In fusilli plugin that's the fusilli::Handle, and +// some temporary buffers that higher level APIs create and destroy at different +// times. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H +#define FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H + +#include +#include +#include + +#include +#include +#include +#include + +struct HipdnnEnginePluginHandle { +public: + const int deviceId; + + HipdnnEnginePluginHandle(int deviceId) : deviceId(deviceId) {} + + // Take ownership of a flatbuffers::DetachedBuffer and store it associated + // with its memory address. + void storeEngineDetailsBuffer( + const void *ptr, std::unique_ptr &&buffer) { + _engineDetailsBuffers[ptr] = std::move(buffer); + } + + // Destroy the flatbuffers::DetachedBuffer associated with ptr. + void eraseEngineDetailsBuffer(const void *ptr) { + _engineDetailsBuffers.erase(ptr); + } + + // Get or create fusilli::Handle just in time. As the engine API may set the + // stream (through `hipdnnEnginePluginSetStream`) after initial handle + // creation (in `hipdnnEnginePluginCreate`) we defer the fusilli::Handle + // creation until we know if a stream has been set. + fusilli::ErrorOr> getFusilliHandle() { + if (!_fusilliHandle.has_value()) + _fusilliHandle = FUSILLI_TRY( + fusilli::Handle::create(fusilli::Backend::AMDGPU, deviceId, + reinterpret_cast(_stream))); + return fusilli::ok( + std::reference_wrapper(*_fusilliHandle)); + } + + void setStream(hipStream_t stream) { _stream = stream; } + +private: + // Default to creating a handle on the null (default) stream. + hipStream_t _stream = 0; + + // Fusilli handle, will be created on the first call to `getFusilliHandle`. + std::optional _fusilliHandle; + + // Storage for engine details. + std::unordered_map> + _engineDetailsBuffers; +}; + +#endif // FUSILLI_PLUGIN_SRC_HIPDNN_ENGINE_PLUGIN_HANDLE_H diff --git a/projects/fusilli-plugin/include/utils.h b/projects/fusilli-plugin/include/utils.h new file mode 100644 index 00000000000..273b96d9770 --- /dev/null +++ b/projects/fusilli-plugin/include/utils.h @@ -0,0 +1,179 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains the fusilli plugin utils and macros. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_SRC_UTILS_H +#define FUSILLI_PLUGIN_SRC_UTILS_H + +#include +#include +#include +#include +#include +#include + +#include + +namespace fusilli_plugin { + +// Checks for null, sets the plugin last error manager and returns error if +// null. +// +// SIDE EFFECT: any util function returning an `hipdnnPluginStatus_t` is +// intended for error checking and reporting, and therefore sets +// PluginLastErrorManager::setLastError to an appropriate error on the unhappy +// path. +template hipdnnPluginStatus_t isNull(T *value) { + if (value == nullptr) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, + std::string(typeid(T).name()) + " is nullptr"); + } + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// Find deviceBuffer with UID. +inline fusilli::ErrorOr +findDeviceBuffer(int64_t uid, const hipdnnPluginDeviceBuffer_t *deviceBuffers, + uint32_t numDeviceBuffers) { + for (uint32_t i = 0; i < numDeviceBuffers; i++) { + if (uid == deviceBuffers[i].uid) { + return fusilli::ok(deviceBuffers[i]); + } + } + + return fusilli::error(fusilli::ErrorCode::AttributeNotSet, + "Device buffer with the uid: " + std::to_string(uid) + + " not found in the provided device buffers."); +} + +// If null, set plugin error manager last error to +// HIPDNN_PLUGIN_STATUS_BAD_PARAM and return said error from the enclosing +// scope. +#define FUSILLI_PLUGIN_CHECK_NULL(X) \ + do { \ + if (hipdnnPluginStatus_t status = isNull(X); \ + status != HIPDNN_PLUGIN_STATUS_SUCCESS) { \ + return status; \ + } \ + } while (false) + +// LOG_API_SUCCESS from hipDNN, but deducing the enclosing function rather than +// passing the function name. +#define LOG_API_SUCCESS_AUTO(format, ...) \ + LOG_API_SUCCESS(__func__, format, __VA_ARGS__) + +// Unwrap the value returned from an expression that evaluates to a +// fusilli::ErrorOr. In the unhappy path set plugin error manager last error to +// HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR and return said error from the enclosing +// scope. +// +// Usage: +// fusilli::ErrorOr getString(); +// +// hipdnnPluginStatus_t processString() { +// // Either gets the string or returns error. +// std::string str = FUSILLI_PLUGIN_TRY(getString()); +// doSomethingImportant(str); +// return HIPDNN_PLUGIN_STATUS_SUCCESS; +// } +#define FUSILLI_PLUGIN_TRY(expr) \ + ({ \ + auto errorOr = (expr); \ + if (fusilli::isError(errorOr)) { \ + return hipdnn_plugin::PluginLastErrorManager::setLastError( \ + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, \ + fusilli::ErrorObject(errorOr).getMessage()); \ + } \ + std ::move(*errorOr); \ + }) + +template fusilli::ErrorObject convertToErrorObject(T &&error) { + using DecayT = std::decay_t; + if constexpr (std::is_convertible_v) { + return std::forward(error); + } else if constexpr (std::is_same_v) { + // Convert HIP error to fusilli ErrorObject + if (error != hipSuccess) { + return fusilli::error(fusilli::ErrorCode::InternalError, + hipGetErrorString(error)); + } + return fusilli::ok(); + } else { + static_assert( + std::is_convertible_v || + std::is_same_v, + "convertToErrorObject requires fusilli::ErrorObject or hipError_t"); + // Unreachable + return fusilli::error(fusilli::ErrorCode::InternalError, + "Unknown error type"); + } +} + +// Set plugin error manager last error and return failed status from enclosing +// scope if expression evaluates to a fusilli::ErrorObject in an error state; or +// in the case of fusilli::ErrorOr is convertible to an fusilli::ErrorObject +// in an error state. +// +// Usage: +// fusilli::ErrorObject doBar(); +// +// hipdnnPluginStatus_t doFoo() { +// // Returns error if doBar() fails +// FUSILLI_PLUGIN_CHECK_ERROR(doBar()); +// return HIPDNN_PLUGIN_STATUS_SUCCESS; +// } +#define FUSILLI_PLUGIN_CHECK_ERROR(expr) \ + do { \ + fusilli::ErrorObject err = convertToErrorObject(expr); \ + if (isError(err)) { \ + return hipdnn_plugin::PluginLastErrorManager::setLastError( \ + HIPDNN_PLUGIN_STATUS_INTERNAL_ERROR, err.getMessage()); \ + } \ + } while (false) + +// Convert from fusilli DataType to iree hal data type. +inline fusilli::ErrorOr +fusilliDataTypeToIreeHalDataType(fusilli::DataType fusilliDataType) { + switch (fusilliDataType) { + case fusilli::DataType::Half: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_16); + case fusilli::DataType::BFloat16: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_BFLOAT_16); + case fusilli::DataType::Float: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_32); + case fusilli::DataType::Double: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_64); + case fusilli::DataType::Uint8: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_UINT_8); + case fusilli::DataType::Int8: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_8); + case fusilli::DataType::Int16: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_16); + case fusilli::DataType::Int32: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_32); + case fusilli::DataType::Int64: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_INT_64); + case fusilli::DataType::Boolean: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_BOOL_8); + case fusilli::DataType::FP8E5M2: + return fusilli::ok(IREE_HAL_ELEMENT_TYPE_FLOAT_8_E5M2); + case fusilli::DataType::NotSet: + default: + return fusilli::error( + fusilli::ErrorCode::InvalidAttribute, + "unknown data type in fusilli -> iree runtime data type conversion"); + } +} + +} // namespace fusilli_plugin + +#endif // FUSILLI_PLUGIN_SRC_UTILS_H diff --git a/projects/fusilli-plugin/src/fusilli_plugin.cpp b/projects/fusilli-plugin/src/fusilli_plugin.cpp new file mode 100644 index 00000000000..da8d251007d --- /dev/null +++ b/projects/fusilli-plugin/src/fusilli_plugin.cpp @@ -0,0 +1,549 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file is the main entry point for fusilli-plugin, implementations for all +// required hipDNN engine plugin API functions live here. +// +//===----------------------------------------------------------------------===// + +// hipDNN logging expects COMPONENT_NAME to be defined +#define COMPONENT_NAME FUSILLI_PLUGIN_NAME + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "graph_import.h" +#include "hipdnn_engine_plugin_execution_context.h" +#include "hipdnn_engine_plugin_handle.h" +#include "utils.h" + +using namespace hipdnn_plugin; +using namespace fusilli_plugin; + +// TODO(#2317): ensure single source of truth for plugin version +static const char *fusilliPluginVersion = "0.0.1"; + +// s_lastError is thread_local static so can't be initialized in the header file +// as the header file is included in many context. Clear the string here. +thread_local char + PluginLastErrorManager::s_lastError[HIPDNN_PLUGIN_ERROR_STRING_MAX_LENGTH] = + ""; + +extern "C" { + +// ---------------------------------------------------------------------- +// Implementations for the basic plugin API defined in +// hipDNN/sdk/include/hipdnn_sdk/plugin/PluginApi.h +// ---------------------------------------------------------------------- + +hipdnnPluginStatus_t hipdnnPluginGetName(const char **name) { + LOG_API_ENTRY("name_ptr={:p}", static_cast(name)); + FUSILLI_PLUGIN_CHECK_NULL(name); + + *name = FUSILLI_PLUGIN_NAME; + + LOG_API_SUCCESS_AUTO("pluginName={}", *name); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnPluginGetVersion(const char **version) { + LOG_API_ENTRY("version_ptr={:p}", static_cast(version)); + FUSILLI_PLUGIN_CHECK_NULL(version); + + *version = fusilliPluginVersion; + + LOG_API_SUCCESS_AUTO("version={}", *version); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnPluginGetType(hipdnnPluginType_t *type) { + LOG_API_ENTRY("type_ptr={:p}", static_cast(type)); + FUSILLI_PLUGIN_CHECK_NULL(type); + + *type = HIPDNN_PLUGIN_TYPE_ENGINE; + + LOG_API_SUCCESS_AUTO("type={}", *type); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +void hipdnnPluginGetLastErrorString(const char **error_str) { + if (error_str) { + *error_str = hipdnn_plugin::PluginLastErrorManager::getLastError(); + } +} + +// Once plugins are loaded via plugin manager then logging will work for them +hipdnnPluginStatus_t hipdnnPluginSetLoggingCallback(hipdnnCallback_t callback) { + // No LOG_API_ENTRY as logging won't be wired up yet. + FUSILLI_PLUGIN_CHECK_NULL(callback); + + hipdnn::logging::initializeCallbackLogging(FUSILLI_PLUGIN_NAME, callback); + + LOG_API_SUCCESS_AUTO("{}", "logging callback initialized"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +// ---------------------------------------------------------------------- +// Implementations for engine plugin API defined in +// hipDNN/sdk/include/hipdnn_sdk/plugin/EnginePluginApi.h +// ---------------------------------------------------------------------- + +hipdnnPluginStatus_t hipdnnEnginePluginGetAllEngineIds(int64_t *engineIds, + uint32_t maxEngines, + uint32_t *numEngines) { + LOG_API_ENTRY("engineIds={:p}, maxEngines={}, numEngines={:p}", + static_cast(engineIds), maxEngines, + static_cast(numEngines)); + FUSILLI_PLUGIN_CHECK_NULL(numEngines); + if (maxEngines != 0) { + FUSILLI_PLUGIN_CHECK_NULL(engineIds); + } + + // Set `numEngines` regardless of how many engines are actually returned. + // The backend queries this function twice: + // - First call: engineIds=NULL, maxEngines=0 to get the count + // - Second call: engineIds allocated based on numEngines from first pass + *numEngines = 1; + + if (maxEngines >= 1) { + engineIds[0] = FUSILLI_PLUGIN_ENGINE_ID; + } + + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginCreate(hipdnnEnginePluginHandle_t *handle) { + LOG_API_ENTRY("handle_ptr={:p}", static_cast(handle)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + // Get device id. + int deviceId; + FUSILLI_PLUGIN_CHECK_ERROR(hipGetDevice(&deviceId)); + + // Create handle. + *handle = new HipdnnEnginePluginHandle(deviceId); + + LOG_API_SUCCESS_AUTO("createdHandle={:p}", static_cast(*handle)); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginDestroy(hipdnnEnginePluginHandle_t handle) { + LOG_API_ENTRY("handle={:p}", static_cast(handle)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + delete handle; + + LOG_API_SUCCESS_AUTO("", ""); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginSetStream(hipdnnEnginePluginHandle_t handle, + hipStream_t stream) { + LOG_API_ENTRY("handle={:p}, stream_id={:p}", static_cast(handle), + static_cast(stream)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + + // Get device associated with stream. + hipDevice_t deviceId; + FUSILLI_PLUGIN_CHECK_ERROR(hipStreamGetDevice(stream, &deviceId)); + + // This should never happen, check so that when it does we get a nice error + // message. + if (deviceId != handle->deviceId) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, + "Stream is associated with different device. Device reported " + "through `hipStreamGetDevice` does not match active " + "device reported through `hipGetDevice`."); + } + + // Set stream, it will be used to create fusilli::Handle later. + handle->setStream(stream); + + LOG_API_SUCCESS_AUTO("", ""); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginGetApplicableEngineIds( + hipdnnEnginePluginHandle_t handle, const hipdnnPluginConstData_t *opGraph, + int64_t *engineIds, uint32_t maxEngines, uint32_t *numEngines) { + LOG_API_ENTRY("handle={:p}, opGraph={:p}, engineIds={:p}, maxEngines={}, " + "numEngines={:p}", + static_cast(handle), static_cast(opGraph), + static_cast(engineIds), maxEngines, + static_cast(numEngines)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + if (maxEngines != 0) { + FUSILLI_PLUGIN_CHECK_NULL(engineIds); + } + FUSILLI_PLUGIN_CHECK_NULL(numEngines); + + *numEngines = 0; + if (maxEngines < 1) { + HIPDNN_LOG_INFO( + "Maximum number of engines reached ({}), ignoring additional " + "engines, numEngines count: {}", + maxEngines, *numEngines); + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // Check for single conv_fprop node graph + GraphWrapper opGraphWrapper(opGraph->ptr, opGraph->size); + if (opGraphWrapper.nodeCount() != 1) { + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) only applicable only " + "for single node conv_fprop graphs.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + if (!opGraphWrapper.hasOnlySupportedAttributes( + std::set{ + hipdnn_sdk::data_objects::NodeAttributes:: + ConvolutionFwdAttributes})) { + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) only applicable only " + "for single node conv_fprop graphs.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // Check single conv_fprop node for symmetric padding + const hipdnn_sdk::data_objects::ConvolutionFwdAttributes *convFwdAttrs = + opGraphWrapper.getNode(0).attributes_as_ConvolutionFwdAttributes(); + // pre/post_padding are flatbuffer::vectors (not std::vectors) and don't + // override ==, so we use std::ranges::equal for structural vs referential + // equality. + if (!std::ranges::equal(*convFwdAttrs->pre_padding(), + *convFwdAttrs->post_padding())) { // C++ 20 + HIPDNN_LOG_INFO("Fusilli plan builder is (currently) requires symmetric " + "padding for conv_fprop nodes.", + opGraphWrapper.nodeCount()); + return HIPDNN_PLUGIN_STATUS_SUCCESS; + } + + // We have a single conv_fprop node with symmetric padding, the fusilli engine + // is applicable. + engineIds[0] = FUSILLI_PLUGIN_ENGINE_ID; + *numEngines = 1; + + LOG_API_SUCCESS_AUTO("numEngines={}", *numEngines); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginGetEngineDetails(hipdnnEnginePluginHandle_t handle, + int64_t engineId, + const hipdnnPluginConstData_t *opGraph, + hipdnnPluginConstData_t *engineDetails) { + // ---------------------------------------------------------------------- + // Plugin API call flow for engine configuration and execution. + // + // hipDNN Plugin + // ====================================================================== + // hipdnnEnginePluginGetEngineDetails -> populates engineDetails + // (flatbuffer object) with + // behavioral notes + knob + // definitions that are available + // to the higher level API. + // Return populated engineDetails + // <- (hipdnnPluginConstData_t). + // + // Decides final configuration, populating ~~ + // engineConfig flatbuffer + // (hipdnnPluginConstData_t) based on info + // provided in engineDetails. + // + // hipdnnEnginePluginCreateExecutionContext -> Creates execution context + // (hipdnnEnginePluginExecutionContext_t) + // <- based on engineConfig. + // + // Uses returned execution context to ~~ + // invoke kernels + // + // hipdnnEnginePluginDestroyEngineDetails -> cleans up engine details. + // + // hipdnnEnginePluginDestroyExecutionContext -> cleans up execution context. + // ---------------------------------------------------------------------- + + LOG_API_ENTRY("handle={:p}, engineId={}, opGraph={:p}, engineDetails={:p}", + static_cast(handle), engineId, + static_cast(opGraph), + static_cast(engineDetails)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails); + + if (engineId != FUSILLI_PLUGIN_ENGINE_ID) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, "unexpected engine id"); + } + + // Build engine details object, we're only storing the engine id for the time + // being. + flatbuffers::FlatBufferBuilder builder; + auto engineDetailsObj = + hipdnn_sdk::data_objects::CreateEngineDetails(builder, engineId); + builder.Finish(engineDetailsObj); + + // Populate out parameter. + auto detachedBuffer = + std::make_unique(builder.Release()); + engineDetails->ptr = detachedBuffer->data(); + engineDetails->size = detachedBuffer->size(); + + // Store owning pointer in handle, hipdnnEnginePluginDestroyEngineDetails will + // inform us when it's safe to clean this up. + handle->storeEngineDetailsBuffer(engineDetails->ptr, + std::move(detachedBuffer)); + + LOG_API_SUCCESS_AUTO("engineDetails->ptr={:p}", engineDetails->ptr); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginDestroyEngineDetails(hipdnnEnginePluginHandle_t handle, + hipdnnPluginConstData_t *engineDetails) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY("handle={:p}, engineDetails={:p}", static_cast(handle), + static_cast(engineDetails)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails); + FUSILLI_PLUGIN_CHECK_NULL(engineDetails->ptr); + + // Deallocate engine details. + handle->eraseEngineDetailsBuffer(engineDetails->ptr); + engineDetails->ptr = nullptr; + engineDetails->size = 0; + + LOG_API_SUCCESS_AUTO("engineDetails->ptr={:p}", engineDetails->ptr); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t +hipdnnEnginePluginGetWorkspaceSize(hipdnnEnginePluginHandle_t handle, + const hipdnnPluginConstData_t *engineConfig, + const hipdnnPluginConstData_t *opGraph, + size_t *workspaceSize) { + LOG_API_ENTRY( + "handle={:p}, engineConfig={:p}, opGraph={:p}, workspaceSize={:p}", + static_cast(handle), static_cast(engineConfig), + static_cast(opGraph), static_cast(workspaceSize)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineConfig); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(workspaceSize); + + // TODO(#2309): for now we're focusing on kernels that don't require scratch + // buffer space. Eventually we will need to teach IREE to report what scratch + // buffer space required, and how to use a passed in pre-allocated scratch + // space rather than a runtime allocated scratch space. + *workspaceSize = 0; + + LOG_API_SUCCESS_AUTO("workspaceSize={}", *workspaceSize); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginCreateExecutionContext( + hipdnnEnginePluginHandle_t handle, + const hipdnnPluginConstData_t *engineConfig, + const hipdnnPluginConstData_t *opGraph, + hipdnnEnginePluginExecutionContext_t *executionContext) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY( + "handle={:p}, engineConfig={:p}, opGraph={:p}, executionContext={:p}", + static_cast(handle), static_cast(engineConfig), + static_cast(opGraph), + static_cast(executionContext)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(engineConfig); + FUSILLI_PLUGIN_CHECK_NULL(opGraph); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + + // Ensure that config contains expected engine id. + hipdnn_plugin::EngineConfigWrapper engineConfigWrapper(engineConfig->ptr, + engineConfig->size); + if (engineConfigWrapper.engineId() != FUSILLI_PLUGIN_ENGINE_ID) { + return hipdnn_plugin::PluginLastErrorManager::setLastError( + HIPDNN_PLUGIN_STATUS_BAD_PARAM, "unexpected engine id"); + } + + auto importAndCompile = [&handle](const hipdnnPluginConstData_t *opGraph) + -> fusilli::ErrorOr { + // Import fusilli::Graph and compute UID -> fusilli::TensorAttr map for + // graph boundary tensors. + HipdnnEnginePluginExecutionContext graphImport = + FUSILLI_TRY(importGraph(opGraph)); + + // Compile graph + FUSILLI_CHECK_ERROR(graphImport.graph.validate()); + FUSILLI_CHECK_ERROR( + graphImport.graph.compile(FUSILLI_TRY(handle->getFusilliHandle()))); + + return fusilli::ok(std::move(graphImport)); + }; + + *executionContext = new HipdnnEnginePluginExecutionContext( + FUSILLI_PLUGIN_TRY(importAndCompile(opGraph))); + + LOG_API_SUCCESS_AUTO("created_execution_context={:p}", + static_cast(*executionContext)); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginDestroyExecutionContext( + hipdnnEnginePluginHandle_t handle, + hipdnnEnginePluginExecutionContext_t executionContext) { + LOG_API_ENTRY("handle={:p}, executionContext={:p}", + static_cast(handle), + static_cast(executionContext)); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + + delete executionContext; + + LOG_API_SUCCESS_AUTO("", "destroyed executionContext"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +hipdnnPluginStatus_t hipdnnEnginePluginExecuteOpGraph( + hipdnnEnginePluginHandle_t handle, + hipdnnEnginePluginExecutionContext_t executionContext, void *workspace, + const hipdnnPluginDeviceBuffer_t *deviceBuffers, + uint32_t numDeviceBuffers) { + // See comment in hipdnnEnginePluginGetEngineDetails for more about how this + // function fits into the flow. + + LOG_API_ENTRY( + "handle={:p}, executionContext={:p}, workspace={:p}, deviceBuffers={:p}, " + "numDeviceBuffers={}", + static_cast(handle), static_cast(executionContext), + workspace, static_cast(deviceBuffers), numDeviceBuffers); + FUSILLI_PLUGIN_CHECK_NULL(handle); + FUSILLI_PLUGIN_CHECK_NULL(executionContext); + FUSILLI_PLUGIN_CHECK_NULL(deviceBuffers); + + // Params and allocators hoisted out of loop below. + iree_hal_allocator_t *deviceAllocator = iree_hal_device_allocator( + FUSILLI_PLUGIN_TRY(handle->getFusilliHandle()).get()); + iree_allocator_t ireeHostAllocator = iree_allocator_system(); + iree_hal_buffer_params_t bufferParams = { + .usage = IREE_HAL_BUFFER_USAGE_DEFAULT, + .access = IREE_HAL_MEMORY_ACCESS_READ | IREE_HAL_MEMORY_ACCESS_WRITE, + .type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL, + }; + + // Fill variant pack for graph execution. Fusilli expects a variant pack to + // map from fusilli::TensorAttr -> fusilli::Buffer for all boundary tensors. + // + // The execution context (created by hipdnnEnginePluginCreateExecutionContext) + // holds a UID -> fusilli::TensorAttr mapping for all boundary tensors + // already. To build the mapping we need to: + // 1. Find the external HIP-allocated device buffer in `deviceBuffers` + // associated with UID. + // 2. Import buffer from 1) into IREE runtime and create fusilli::Buffer. + // + // We may want to cache all of this in the future. As long as the device + // pointers + UIDs haven't changed it should be possible to re-use an already + // imported buffer + buffer view + the call that fusilli::Graph::execute + // builds internally. + std::unordered_map, + std::shared_ptr> + variantPack; + for (auto &[uid, tensorAttr] : executionContext->uidToFusilliTensorAttr) { + // 1. Find associated buffer. + hipdnnPluginDeviceBuffer_t hipMallocedBuffer = FUSILLI_PLUGIN_TRY( + findDeviceBuffer(uid, deviceBuffers, numDeviceBuffers)); + + // 2.1. Import external buffer into IREE runtime. This isn't allocating a + // buffer, it's making an existing allocation available to the IREE runtime. + iree_hal_external_buffer_t externalBuffer = { + .type = IREE_HAL_EXTERNAL_BUFFER_TYPE_DEVICE_ALLOCATION, + .flags = 0, + .size = static_cast(sizeof(float) * + tensorAttr->getVolume()), + .handle = + { + .device_allocation = + { + .ptr = (uint64_t)hipMallocedBuffer.ptr, + }, + }, + }; + iree_hal_buffer_t *importedBuffer = nullptr; + FUSILLI_PLUGIN_CHECK_ERROR(iree_hal_allocator_import_buffer( + deviceAllocator, bufferParams, &externalBuffer, + iree_hal_buffer_release_callback_null(), &importedBuffer)); + + // 2.2. Create a buffer view for external buffer. + iree_hal_buffer_view_t *outBufferView = nullptr; + FUSILLI_PLUGIN_CHECK_ERROR(iree_hal_buffer_view_create( + /*buffer=*/importedBuffer, /*shape_rank=*/tensorAttr->getDim().size(), + /*shape=*/(const iree_hal_dim_t *)tensorAttr->getDim().data(), + /*element_type=*/ + FUSILLI_PLUGIN_TRY( + fusilliDataTypeToIreeHalDataType(tensorAttr->getDataType())), + /*encoding_type=*/IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, + /*host_allocator=*/ireeHostAllocator, + /*out_buffer_view=*/&outBufferView)); + + // Release our reference to buffer. The buffer view holds a reference to + // buffer and will handle release + possible destruction when it's + // destroyed. + iree_hal_buffer_release(importedBuffer); + + // 2.3. Create fusilli::Buffer from buffer view. Buffer::import is a RAII + // type that retains the buffer view, incrementing its reference count, on + // construction and releases the buffer view on destruction. + variantPack[tensorAttr] = std::make_shared( + FUSILLI_PLUGIN_TRY(fusilli::Buffer::import(outBufferView))); + + // Release our reference to buffer view. The buffer view and buffer will + // (now) be tied to fusilli::Buffer's lifetime as it holds the only + // reference to the buffer view. + iree_hal_buffer_view_release(outBufferView); + } + + FUSILLI_PLUGIN_CHECK_ERROR(executionContext->graph.execute( + FUSILLI_PLUGIN_TRY(handle->getFusilliHandle()), variantPack)); + + LOG_API_SUCCESS_AUTO("{}", "executed graph"); + return HIPDNN_PLUGIN_STATUS_SUCCESS; +} + +} // extern "C" diff --git a/projects/fusilli-plugin/test/CMakeLists.txt b/projects/fusilli-plugin/test/CMakeLists.txt new file mode 100644 index 00000000000..cfb816a2037 --- /dev/null +++ b/projects/fusilli-plugin/test/CMakeLists.txt @@ -0,0 +1,44 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(GoogleTest) +find_package(Threads REQUIRED) + +# Plugin API tests +add_executable(test_fusilli_plugin_api + test_fusilli_plugin_api.cpp + utils.h +) +target_compile_options(test_fusilli_plugin_api PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(test_fusilli_plugin_api PRIVATE + fusilli_plugin # Link plugin directly + fusilli::fusilli + GTest::gtest_main + hip::host + hipdnn_sdk + Threads::Threads +) +target_compile_definitions(test_fusilli_plugin_api PRIVATE + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID=${FUSILLI_PLUGIN_ENGINE_ID} +) +gtest_discover_tests(test_fusilli_plugin_api) + +# Graph import tests +add_executable(test_graph_import + test_graph_import.cpp + utils.h +) +target_compile_options(test_graph_import PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(test_graph_import PRIVATE + fusilli::fusilli + GTest::gtest_main + hipdnn_sdk +) +gtest_discover_tests(test_graph_import) + +# Integration tests +add_subdirectory(integration) diff --git a/projects/fusilli-plugin/test/integration/CMakeLists.txt b/projects/fusilli-plugin/test/integration/CMakeLists.txt new file mode 100644 index 00000000000..f2ba4dce4fd --- /dev/null +++ b/projects/fusilli-plugin/test/integration/CMakeLists.txt @@ -0,0 +1,33 @@ +# Copyright 2025 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +include(GoogleTest) +find_package(Threads REQUIRED) + +# Integration tests +add_executable(fusilli_plugin_integration_tests + test_convfprop.cpp + test_basic.cpp +) +target_compile_options(fusilli_plugin_integration_tests PRIVATE ${HIPDNN_WARNING_COMPILE_OPTIONS}) +target_link_libraries(fusilli_plugin_integration_tests PRIVATE + GTest::gtest_main + hip::host + hipdnn_frontend + hipdnn_sdk + Threads::Threads +) +target_compile_definitions(fusilli_plugin_integration_tests PRIVATE + FUSILLI_PLUGIN_DIR="${HIPDNN_BUILD_PLUGIN_ENGINE_DIR}" + FUSILLI_PLUGIN_NAME="${FUSILLI_PLUGIN_NAME}" + FUSILLI_PLUGIN_ENGINE_ID="${FUSILLI_PLUGIN_ENGINE_ID}" +) +# Register with CTest +gtest_discover_tests(fusilli_plugin_integration_tests + # Ensure that tests pick up libhipdnn_backend.so in build directory, not + # from the global TheRock install. + PROPERTIES ENVIRONMENT "LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib" +) diff --git a/projects/fusilli-plugin/test/integration/test_basic.cpp b/projects/fusilli-plugin/test/integration/test_basic.cpp new file mode 100644 index 00000000000..b1561fcc8c2 --- /dev/null +++ b/projects/fusilli-plugin/test/integration/test_basic.cpp @@ -0,0 +1,78 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include +#include +#include + +static std::vector getLoadedPlugins(hipdnnHandle_t handle) { + size_t numPlugins = 0; + size_t maxPathLength = 0; + auto status = hipdnnGetLoadedEnginePluginPaths_ext(handle, &numPlugins, + nullptr, &maxPathLength); + + if (status != HIPDNN_STATUS_SUCCESS) { + throw std::runtime_error("Failed to get loaded plugin paths"); + } + + if (numPlugins == 0) { + return {}; + } + + std::vector> pathBuffers(numPlugins, + std::vector(maxPathLength)); + std::vector pluginPathsC(numPlugins); + for (size_t i = 0; i < numPlugins; ++i) { + pluginPathsC[i] = pathBuffers[i].data(); + } + + status = hipdnnGetLoadedEnginePluginPaths_ext( + handle, &numPlugins, pluginPathsC.data(), &maxPathLength); + if (status != HIPDNN_STATUS_SUCCESS) { + throw std::runtime_error("Failed to get loaded plugin paths"); + } + + std::vector pluginPaths; + pluginPaths.reserve(numPlugins); + for (size_t i = 0; i < numPlugins; ++i) { + pluginPaths.emplace_back(pluginPathsC[i]); + } + return pluginPaths; +} + +TEST(IntegrationTests, PluginLoad) { + // Uncomment if you want debug logging info. + // setenv("HIPDNN_LOG_LEVEL", "info", 1); + + // Ensure hipDNN will load fusilli plugin. + const std::array paths = {FUSILLI_PLUGIN_DIR}; + hipdnnStatus_t status = hipdnnSetEnginePluginPaths_ext( + paths.size(), paths.data(), HIPDNN_PLUGIN_LOADING_ABSOLUTE); + EXPECT_EQ(status, HIPDNN_STATUS_SUCCESS); + + // Stand up enough of hipDNN to load plugins. + hipdnnHandle_t handle = nullptr; + status = hipdnnCreate(&handle); + ASSERT_EQ(status, HIPDNN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // If fusilli plugin fails to define a required method it will fail to load. + auto loadedPlugins = getLoadedPlugins(handle); + EXPECT_EQ(loadedPlugins.size(), 1); + + // Check that fusilli plugin did load. + auto expectedPath = std::filesystem::path(FUSILLI_PLUGIN_DIR) / + std::format("lib{}.so", FUSILLI_PLUGIN_NAME); + EXPECT_TRUE(std::ranges::any_of( + loadedPlugins, [&expectedPath](const std::string &loadedPluginPath) { + return std::filesystem::canonical(loadedPluginPath) == expectedPath; + })); + + EXPECT_EQ(hipdnnDestroy(handle), HIPDNN_STATUS_SUCCESS); +} diff --git a/projects/fusilli-plugin/test/integration/test_convfprop.cpp b/projects/fusilli-plugin/test/integration/test_convfprop.cpp new file mode 100644 index 00000000000..1812f5a753b --- /dev/null +++ b/projects/fusilli-plugin/test/integration/test_convfprop.cpp @@ -0,0 +1,197 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace hipdnn_frontend; +using namespace hipdnn_sdk::utilities; +using namespace hipdnn_sdk::test_utilities; + +struct TestParams { + bool shouldSetStream; + hipDevice_t deviceId; + + // GTest uses the << operator to update the parameterized test name. + friend std::ostream &operator<<(std::ostream &ss, const TestParams &p) { + ss << "Device" << p.deviceId << "_"; + ss << (p.shouldSetStream ? "WithStream" : "WithoutStream"); + return ss; + } +}; + +class ConvFpropIntegrationTest : public ::testing::TestWithParam {}; + +TEST_P(ConvFpropIntegrationTest, Basic1x1Convolution) { + TestParams params = GetParam(); + // Uncomment to enable debug logging + // setenv("HIPDNN_LOG_LEVEL", "info", 1); + + // Initialize HIP. + ASSERT_EQ(hipInit(0), hipSuccess); + + // Set device. + ASSERT_EQ(hipSetDevice(params.deviceId), hipSuccess); + + // Create stream. + hipStream_t stream = nullptr; + if (params.shouldSetStream) { + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + } + + // Set plugin path. + const std::array paths = {FUSILLI_PLUGIN_DIR}; + ASSERT_EQ(hipdnnSetEnginePluginPaths_ext(paths.size(), paths.data(), + HIPDNN_PLUGIN_LOADING_ABSOLUTE), + HIPDNN_STATUS_SUCCESS); + + // Create handle. + hipdnnHandle_t handle; + ASSERT_EQ(hipdnnCreate(&handle), HIPDNN_STATUS_SUCCESS); + + // Check that loading the plugin didn't change the active device. + hipDevice_t deviceId = -1; + ASSERT_EQ(hipGetDevice(&deviceId), hipSuccess); + ASSERT_EQ(deviceId, params.deviceId); + + if (params.shouldSetStream) { + ASSERT_EQ(hipdnnSetStream(handle, stream), HIPDNN_STATUS_SUCCESS); + } + + // Dimensions. + const int64_t n = 16; // batch + const int64_t c = 128; // in channels + const int64_t h = 64; // image height + const int64_t w = 64; // image width + const int64_t k = 256; // out channels + const int64_t r = 1; // filter height + const int64_t s = 1; // filter width + + // UIDs. + const int64_t xUID = 0; + const int64_t wUID = 1; + const int64_t yUID = 2; + + // Initialize tensors. + PinnedTensor xTensor({n, c, h, w}); + PinnedTensor wTensor({k, c, r, s}); + PinnedTensor yTensor({n, k, h, w}); + xTensor.fillWithValue(1.0f); + wTensor.fillWithValue(1.0f); + yTensor.fillWithValue(-100.0f); + + // Expected output. + PinnedTensor expectedOutput({n, k, h, w}); + expectedOutput.fillWithValue(128.0f); + + // Create graph. + auto graph = std::make_shared(); + graph->set_name("conv_1x1_test"); + graph->set_io_data_type(DataType_t::FLOAT) + .set_compute_data_type(DataType_t::FLOAT); + + // Create tensor attributes. + auto xAttr = std::make_shared( + graph::makeTensorAttributes("input", DataType_t::FLOAT, xTensor)); + xAttr->set_uid(xUID); + auto wAttr = std::make_shared( + graph::makeTensorAttributes("filter", DataType_t::FLOAT, wTensor)); + wAttr->set_uid(wUID); + + // Create convolution attributes. + graph::ConvFpropAttributes convAttr; + convAttr.set_name("conv_fprop") + .set_padding({0, 0}) + .set_stride({1, 1}) + .set_dilation({1, 1}); + + // Create graph. + auto yAttr = graph->conv_fprop(xAttr, wAttr, convAttr); + yAttr->set_uid(yUID); + yAttr->set_dim(yTensor.dims()).set_stride(yTensor.strides()).set_output(true); + + // Build + validate + build plans for graph. + auto result = graph->validate(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->build_operation_graph(handle); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->create_execution_plans(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->check_support(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + result = graph->build_plans(); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + + // Create variant pack. + std::unordered_map variantPack; + variantPack[xUID] = xTensor.memory().deviceData(); + variantPack[wUID] = wTensor.memory().deviceData(); + variantPack[yUID] = yTensor.memory().deviceData(); + + // Execute graph. + result = graph->execute(handle, variantPack, nullptr); + ASSERT_EQ(result.code, error_code_t::OK) << result.err_msg; + // Mark hipDNN tensor CPU cache ask stale, data must be read from device. + yTensor.memory().markDeviceModified(); + + // Check results. + CpuFpReferenceValidation validator(1e-6f, 1e-6f); + EXPECT_TRUE(validator.allClose(expectedOutput.memory(), yTensor.memory())); + + // Clean up. + if (params.shouldSetStream) { + ASSERT_EQ(hipStreamDestroy(stream), HIPDNN_STATUS_SUCCESS); + } + ASSERT_EQ(hipdnnDestroy(handle), HIPDNN_STATUS_SUCCESS); +} + +static std::vector generateTestParams() { + std::vector params; + int deviceCount; + assert(hipGetDeviceCount(&deviceCount) == hipSuccess); + + // Always test with device 0. + params.push_back({ + .shouldSetStream = false, + .deviceId = 0, + }); + params.push_back({ + .shouldSetStream = true, + .deviceId = 0, + }); + + // Test on last device if multiple devices are available. + if (deviceCount > 1) { + params.push_back({ + .shouldSetStream = false, + .deviceId = deviceCount - 1, + }); + params.push_back({ + .shouldSetStream = true, + .deviceId = deviceCount - 1, + }); + } + + return params; +} + +INSTANTIATE_TEST_SUITE_P(, ConvFpropIntegrationTest, + ::testing::ValuesIn(generateTestParams())); diff --git a/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp new file mode 100644 index 00000000000..00df0ddcab2 --- /dev/null +++ b/projects/fusilli-plugin/test/test_fusilli_plugin_api.cpp @@ -0,0 +1,407 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fusilli/attributes/tensor_attributes.h" +#include "graph_import.h" +#include "hipdnn_engine_plugin_execution_context.h" +#include "utils.h" + +bool loggingCallbackCalled = false; +std::vector capturedLogMessages; +std::vector capturedLogSeverities; +std::mutex logMutex; +std::condition_variable logConditionVariable; + +void testLoggingCallback(hipdnnSeverity_t severity, const char *msg) { + // hipDNN sets spdlog up to log in a separate thread, so we need to put our + // mutual exclusion gloves on before touching any variables the main thread + // does. + std::scoped_lock lock(logMutex); + + loggingCallbackCalled = true; + if (msg) { + capturedLogMessages.push_back(std::string(msg)); + capturedLogSeverities.push_back(severity); + } + logConditionVariable.notify_one(); +} + +TEST(TestFusilliPluginApi, Logging) { + // Set tracking variables + { + std::scoped_lock lock(logMutex); + loggingCallbackCalled = false; + capturedLogMessages.clear(); + capturedLogSeverities.clear(); + } + + // Set up logging callback + ASSERT_EQ(hipdnnPluginSetLoggingCallback(testLoggingCallback), + HIPDNN_PLUGIN_STATUS_SUCCESS); + + std::unique_lock lock(logMutex); + + // Wait for the logging callback to signal that it has been called. + auto timeout = std::chrono::steady_clock::now() + std::chrono::seconds(5); + EXPECT_TRUE(logConditionVariable.wait_until( + lock, timeout, [&]() { return loggingCallbackCalled; })); + + EXPECT_TRUE(loggingCallbackCalled); + EXPECT_FALSE(capturedLogMessages.empty()); + EXPECT_TRUE(capturedLogMessages.front().find( + "logging callback initialized") != std::string::npos); +}; + +TEST(TestFusilliPluginApi, GetNameSuccess) { + const char *name = nullptr; + EXPECT_EQ(hipdnnPluginGetName(&name), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_STREQ(name, FUSILLI_PLUGIN_NAME); +} + +TEST(TestFusilliPluginApi, GetNameNullptr) { + EXPECT_EQ(hipdnnPluginGetName(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetVersionSuccess) { + const char *version = nullptr; + EXPECT_EQ(hipdnnPluginGetVersion(&version), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(version, nullptr); + // TODO(#2317): check returned version against single source of truth. +} + +TEST(TestFusilliPluginApi, GetVersionNullptr) { + EXPECT_EQ(hipdnnPluginGetVersion(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetTypeSuccess) { + hipdnnPluginType_t type; + EXPECT_EQ(hipdnnPluginGetType(&type), HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(type, HIPDNN_PLUGIN_TYPE_ENGINE); +} + +TEST(TestFusilliPluginApi, GetTypeNullptr) { + EXPECT_EQ(hipdnnPluginGetType(nullptr), HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetLastErrorStringSuccess) { + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + // Initially should be empty or contain a previous error + EXPECT_GE(strlen(errorStr), 0); +} + +TEST(TestFusilliPluginApi, GetLastErrorStringNullptr) { + // This should not crash even with nullptr + EXPECT_NO_THROW(hipdnnPluginGetLastErrorString(nullptr)); +} + +TEST(TestFusilliPluginApi, SetLoggingCallbackNullptr) { + // Setting nullptr should return BAD_PARAM + EXPECT_EQ(hipdnnPluginSetLoggingCallback(nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); +} + +TEST(TestFusilliPluginApi, GetAllEngineIds) { + // First call with null buffer to get count + uint32_t numEngines = 0; + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(nullptr, 0, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(numEngines, 1); + + // Second call to get actual engine IDs + std::vector engineIds(numEngines); + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(engineIds.data(), numEngines, + &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(numEngines, 1); + EXPECT_EQ(engineIds[0], FUSILLI_PLUGIN_ENGINE_ID); +} + +TEST(TestFusilliPluginApi, GetAllEngineIdsNullNumEngines) { + EXPECT_EQ(hipdnnEnginePluginGetAllEngineIds(nullptr, 0, nullptr), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + EXPECT_GT(strlen(errorStr), 0u); +} + +// TODO(#2363): investigate using createValidConvFwdGraph from upstream hipDNN +flatbuffers::FlatBufferBuilder +createValidConvFwdGraph(int64_t xUID = 0, int64_t wUID = 1, int64_t yUID = 2, + hipdnn_sdk::data_objects::DataType dataType = + hipdnn_sdk::data_objects::DataType::FLOAT, + const std::vector &xDims = {4, 4, 4, 4}, + const std::vector &xStrides = {64, 16, 4, 1}, + const std::vector &wDims = {4, 4, 1, 1}, + const std::vector &wStrides = {4, 1, 1, 1}, + const std::vector &yDims = {4, 4, 4, 4}, + const std::vector &yStrides = {64, 16, 4, 1}, + const std::vector &convPrePadding = {0, 0}, + const std::vector &convPostPadding = {0, 0}, + const std::vector &convStrides = {1, 1}, + const std::vector &convDilation = {1, 1}) { + flatbuffers::FlatBufferBuilder builder; + std::vector<::flatbuffers::Offset> + tensorAttributes; + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, xUID, "x", dataType, &xStrides, &xDims)); + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, wUID, "w", dataType, &wStrides, &wDims)); + + tensorAttributes.push_back(CreateTensorAttributesDirect( + builder, yUID, "y", dataType, &yStrides, &yDims)); + + auto convAttributes = CreateConvolutionFwdAttributesDirect( + builder, + /*x_tensor_uid*/ xUID, + /*w_tensor_uid*/ wUID, + /*y_tensor_uid*/ yUID, &convPrePadding, &convPostPadding, &convStrides, + &convDilation, hipdnn_sdk::data_objects::ConvMode::CROSS_CORRELATION); + + std::vector<::flatbuffers::Offset> nodes; + auto node = CreateNodeDirect( + builder, "conv_fwd", + hipdnn_sdk::data_objects::NodeAttributes::ConvolutionFwdAttributes, + convAttributes.Union()); + nodes.push_back(node); + + auto graphOffset = + CreateGraphDirect(builder, "test", + /*compute_type*/ dataType, + /*intermediate_type*/ dataType, + /*io_type=*/dataType, &tensorAttributes, &nodes); + builder.Finish(graphOffset); + return builder; +} + +TEST(TestFusilliPluginApi, GetApplicableEngineIds) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // Create a serialized hipDNN bach norm graph. + auto builder = hipdnn_sdk::test_utilities::createValidBatchnormBwdGraph(); + hipdnnPluginConstData_t opGraph; + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should not offer to compile and execute bach norm (yet). + std::array engineIDs; + uint32_t numEngines = -1; + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 0); + + // Create a serialized hipDNN conv_fprop graph with symmetric padding. + builder = createValidConvFwdGraph(); + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should offer to compile and execute single node conv_fprop. + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 1); + ASSERT_EQ(engineIDs[0], FUSILLI_PLUGIN_ENGINE_ID); + + // Create a serialized hipDNN conv_fprop graph with asymmetric padding. + builder = createValidConvFwdGraph( + /*xUID=*/0, /*wUID=*/1, /*yUID=*/2, + /*dataType=*/hipdnn_sdk::data_objects::DataType::FLOAT, + /*xDims=*/{4, 4, 4, 4}, /*xStrides=*/{64, 16, 4, 1}, + /*wDims=*/{4, 4, 1, 1}, /*wStrides=*/{4, 1, 1, 1}, + /*yDims=*/{4, 4, 4, 4}, /*yStrides=*/{64, 16, 4, 1}, + /*convPrePadding=*/{1, 0}, // asymmetric: pre doesn't match post + /*convPostPadding=*/{2, 1}, // asymmetric: pre doesn't match post + /*convStrides=*/{1, 1}, /*convDilation=*/{1, 1}); + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Fusilli plugin should not offer to compile and execute single node + // conv_fprop with asymmetric padding. + ASSERT_EQ(hipdnnEnginePluginGetApplicableEngineIds( + handle, &opGraph, engineIDs.data(), 5, &numEngines), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_EQ(numEngines, 0); +} + +TEST(TestFusilliPluginApi, CreateExecutionContext) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(handle, nullptr); + + // UIDs. + int64_t xUID = 1; + int64_t wUID = 2; + int64_t yUID = 3; + + // Dims and strides. + const std::vector expectedXDims = {4, 4, 4, 4}; + const std::vector expectedXStrides = {64, 16, 4, 1}; + const std::vector expectedWDims = {4, 4, 1, 1}; + const std::vector expectedWStrides = {4, 1, 1, 1}; + const std::vector expectedYDims = {4, 4, 4, 4}; + const std::vector expectedYStrides = {64, 16, 4, 1}; + const hipdnn_sdk::data_objects::DataType dataType = + hipdnn_sdk::data_objects::DataType::FLOAT; + fusilli::DataType expectedDataType = + FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType(dataType)); + + // Create a serialized hipDNN conv_fprop. + auto builder = createValidConvFwdGraph( + xUID, wUID, yUID, dataType, expectedXDims, expectedXStrides, + expectedWDims, expectedWStrides, expectedYDims, expectedYStrides); + hipdnnPluginConstData_t opGraph; + opGraph.ptr = builder.GetBufferPointer(); + opGraph.size = builder.GetSize(); + + // Create engine config. + flatbuffers::FlatBufferBuilder configBuilder; + auto engineConfig = hipdnn_sdk::data_objects::CreateEngineConfig( + configBuilder, FUSILLI_PLUGIN_ENGINE_ID); + configBuilder.Finish(engineConfig); + hipdnnPluginConstData_t engineConfigData; + engineConfigData.ptr = configBuilder.GetBufferPointer(); + engineConfigData.size = configBuilder.GetSize(); + + // The function we're actually testing. + hipdnnEnginePluginExecutionContext_t executionContext = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreateExecutionContext( + handle, &engineConfigData, &opGraph, &executionContext), + HIPDNN_PLUGIN_STATUS_SUCCESS); + ASSERT_NE(executionContext, nullptr); + + auto *ctx = + static_cast(executionContext); + + // Check that we have 3 tensors tracked (x, w, y). + EXPECT_EQ(ctx->uidToFusilliTensorAttr.size(), 3); + + // Check x tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(xUID)); // C++ 20 + std::shared_ptr xTensor = + ctx->uidToFusilliTensorAttr[xUID]; + EXPECT_EQ(xTensor->getDim(), expectedXDims); + EXPECT_EQ(xTensor->getStride(), expectedXStrides); + EXPECT_EQ(xTensor->getDataType(), expectedDataType); + EXPECT_FALSE(xTensor->isVirtual()); + + // Check w tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(wUID)); // C++ 20 + std::shared_ptr wTensor = + ctx->uidToFusilliTensorAttr[wUID]; + EXPECT_EQ(wTensor->getDim(), expectedWDims); + EXPECT_EQ(wTensor->getStride(), expectedWStrides); + EXPECT_EQ(wTensor->getDataType(), expectedDataType); + EXPECT_FALSE(wTensor->isVirtual()); + + // Check y tensor properties. + ASSERT_TRUE(ctx->uidToFusilliTensorAttr.contains(wUID)); // C++ 20 + std::shared_ptr yTensor = + ctx->uidToFusilliTensorAttr[yUID]; + EXPECT_EQ(yTensor->getDim(), expectedYDims); + EXPECT_EQ(yTensor->getStride(), expectedYStrides); + EXPECT_EQ(yTensor->getDataType(), expectedDataType); + EXPECT_FALSE(yTensor->isVirtual()); + + // Verify graph properties. + EXPECT_EQ(ctx->graph.context.getIODataType(), expectedDataType); + EXPECT_EQ(ctx->graph.context.getIntermediateDataType(), expectedDataType); + EXPECT_EQ(ctx->graph.context.getComputeDataType(), expectedDataType); + + // Clean up. + EXPECT_EQ(hipdnnEnginePluginDestroyExecutionContext(handle, executionContext), + HIPDNN_PLUGIN_STATUS_SUCCESS); + EXPECT_EQ(hipdnnEnginePluginDestroy(handle), HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST(TestFusilliPluginApi, SetStreamSuccess) { + // Create plugin handle. + hipdnnEnginePluginHandle_t handle = nullptr; + ASSERT_EQ(hipdnnEnginePluginCreate(&handle), HIPDNN_PLUGIN_STATUS_SUCCESS); + + // Create a HIP stream. + hipStream_t stream; + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + + // Set the stream on the handle. + EXPECT_EQ(hipdnnEnginePluginSetStream(handle, stream), + HIPDNN_PLUGIN_STATUS_SUCCESS); + + // Clean up. + EXPECT_EQ(hipStreamDestroy(stream), hipSuccess); + EXPECT_EQ(hipdnnEnginePluginDestroy(handle), HIPDNN_PLUGIN_STATUS_SUCCESS); +} + +TEST(TestFusilliPluginApi, SetStreamNullHandle) { + // Create a HIP stream. + hipStream_t stream; + ASSERT_EQ(hipStreamCreate(&stream), hipSuccess); + + // Attempt to set stream with null handle should fail. + EXPECT_EQ(hipdnnEnginePluginSetStream(nullptr, stream), + HIPDNN_PLUGIN_STATUS_BAD_PARAM); + + // Verify error was set. + const char *errorStr = nullptr; + hipdnnPluginGetLastErrorString(&errorStr); + ASSERT_NE(errorStr, nullptr); + EXPECT_GT(strlen(errorStr), 0u); + + // Clean up. + EXPECT_EQ(hipStreamDestroy(stream), hipSuccess); +} diff --git a/projects/fusilli-plugin/test/test_graph_import.cpp b/projects/fusilli-plugin/test/test_graph_import.cpp new file mode 100644 index 00000000000..fa36ecfa5ab --- /dev/null +++ b/projects/fusilli-plugin/test/test_graph_import.cpp @@ -0,0 +1,40 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "graph_import.h" +#include "utils.h" + +#include +#include +#include + +TEST(TestGraphImport, ConvertHipDnnToFusilli) { + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::HALF)), + fusilli::DataType::Half); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::BFLOAT16)), + fusilli::DataType::BFloat16); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::FLOAT)), + fusilli::DataType::Float); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::DOUBLE)), + fusilli::DataType::Double); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::UINT8)), + fusilli::DataType::Uint8); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::INT32)), + fusilli::DataType::Int32); + EXPECT_EQ(FUSILLI_PLUGIN_EXPECT_UNWRAP(hipDnnDataTypeToFusilliDataType( + hipdnn_sdk::data_objects::DataType::UNSET)), + fusilli::DataType::NotSet); + + auto invalidResult = hipDnnDataTypeToFusilliDataType( + static_cast(42)); + EXPECT_TRUE(isError(invalidResult)); +} diff --git a/projects/fusilli-plugin/test/utils.h b/projects/fusilli-plugin/test/utils.h new file mode 100644 index 00000000000..de16707fc8e --- /dev/null +++ b/projects/fusilli-plugin/test/utils.h @@ -0,0 +1,29 @@ +// Copyright 2025 Advanced Micro Devices, Inc. +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===----------------------------------------------------------------------===// +// +// This file contains utilities for fusilli plugin tests. +// +//===----------------------------------------------------------------------===// + +#ifndef FUSILLI_PLUGIN_TESTS_UTILS_H +#define FUSILLI_PLUGIN_TESTS_UTILS_H + +// Unwrap the type returned from an expression that evaluates to an ErrorOr, +// fail the test using GTest's EXPECT_TRUE if the result is an ErrorObject. +// +// This is very similar to FUSILLI_TRY, but FUSILLI_TRY propagates an error to +// callers on the error path, this fails the test on the error path. The two +// macros are analogous to rust's `?` (try) operator and `.unwrap()` call. +#define FUSILLI_PLUGIN_EXPECT_UNWRAP(expr) \ + ({ \ + auto _errorOr = (expr); \ + EXPECT_TRUE(isOk(_errorOr)); \ + std::move(*_errorOr); \ + }) + +#endif // FUSILLI_PLUGIN_TESTS_UTILS_H