2 changes: 1 addition & 1 deletion backends/aoti/CMakeLists.txt
@@ -26,7 +26,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
find_package_torch()

# Common AOTI functionality - combines all AOTI common components
set(_aoti_common_sources aoti_model_container.cpp common_shims.cpp)
set(_aoti_common_sources common_shims.cpp)
add_library(aoti_common STATIC ${_aoti_common_sources})
target_include_directories(
aoti_common
@@ -60,36 +60,17 @@ using AOTInductorModelContainerRunFunc = AOTIRuntimeError (*)(
AOTInductorStreamHandle stream_handle,
AOTIProxyExecutorHandle proxy_executor_handle);

// Global function pointers (will be loaded dynamically)
extern AOTInductorModelContainerCreateWithDeviceFunc
AOTInductorModelContainerCreateWithDevice;
extern AOTInductorModelContainerDeleteFunc AOTInductorModelContainerDelete;
extern AOTInductorModelContainerGetNumInputsFunc
AOTInductorModelContainerGetNumInputs;
extern AOTInductorModelContainerGetNumOutputsFunc
AOTInductorModelContainerGetNumOutputs;
extern AOTInductorModelContainerRunFunc AOTInductorModelContainerRun;

// Retrieves the name of an input tensor by index from the AOTI model container.
// Needed by Metal backend
using AOTInductorModelContainerGetInputNameFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_name);

// Retrieves the number of constants from the AOTI model container.
// Needed by Metal backend
Reviewer comment (Contributor): @manuelcandales see if Metal should also save this inside the handle
using AOTInductorModelContainerGetNumConstantsFunc = AOTIRuntimeError (*)(
AOTInductorModelContainerHandle container_handle,
size_t* num_constants);

// Global function pointers (will be loaded dynamically).
// Needed by Metal backend
extern AOTInductorModelContainerGetInputNameFunc
AOTInductorModelContainerGetInputName;
extern AOTInductorModelContainerGetNumConstantsFunc
AOTInductorModelContainerGetNumConstants;

} // extern "C"

// AOTI Delegate Handle structure
@@ -99,6 +80,13 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency

// Function pointers specific to this handle's shared library
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
AOTInductorModelContainerDeleteFunc delete_container;
AOTInductorModelContainerGetNumInputsFunc get_num_inputs;
AOTInductorModelContainerGetNumOutputsFunc get_num_outputs;
AOTInductorModelContainerRunFunc run;
};

} // namespace aoti
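As a rough illustration of the reviewer's suggestion above (hypothetical; not part of this change), the two Metal-only entry points could likewise be stored per handle instead of in the process-wide externs:

// Hypothetical sketch only: a Metal-flavored handle carrying the
// Metal-specific entry points next to the common ones defined above.
struct MetalDelegateHandle : AOTIDelegateHandle {
  AOTInductorModelContainerGetInputNameFunc get_input_name = nullptr;
  AOTInductorModelContainerGetNumConstantsFunc get_num_constants = nullptr;
};

// Populated from the handle's own shared library, mirroring the CUDA backend:
//   metal_handle->get_input_name =
//       reinterpret_cast<AOTInductorModelContainerGetInputNameFunc>(
//           dlsym(so_handle, "AOTInductorModelContainerGetInputName"));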
39 changes: 0 additions & 39 deletions backends/aoti/aoti_model_container.cpp

This file was deleted.

11 changes: 4 additions & 7 deletions backends/aoti/targets.bzl
@@ -25,12 +25,9 @@ def define_common_targets():

# AOTI model container functionality
runtime.cxx_library(
name = "model_container",
srcs = [
"aoti_model_container.cpp",
],
name = "delegate_handle",
headers = [
"aoti_model_container.h",
"aoti_delegate_handle.h",
],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
@@ -44,7 +41,7 @@
],
)

# Common AOTI functionality (combining both common_shims and model_container)
# Common AOTI functionality (combining both common_shims and delegate_handle)
runtime.cxx_library(
name = "aoti_common",
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
@@ -53,6 +50,6 @@
visibility = ["@EXECUTORCH_CLIENTS"],
exported_deps = [
":common_shims",
":model_container",
":delegate_handle",
],
)
75 changes: 44 additions & 31 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -21,18 +21,18 @@
#include <vector>

// Include our shim layer headers
#include <executorch/backends/aoti/aoti_model_container.h>
#include <executorch/backends/aoti/aoti_delegate_handle.h>
#include <executorch/backends/aoti/common_shims.h>
#include <executorch/backends/cuda/runtime/shims/memory.h>
#include <executorch/backends/cuda/runtime/utils.h>

namespace executorch::backends::cuda {

#define LOAD_SYMBOL(name, handle) \
do { \
name = reinterpret_cast<name##Func>(dlsym(handle, #name)); \
ET_CHECK_OR_RETURN_ERROR( \
name != nullptr, AccessFailed, "Failed to load " #name); \
#define LOAD_SYMBOL(handle, member, name, so_handle) \
do { \
handle->member = reinterpret_cast<name##Func>(dlsym(so_handle, #name)); \
ET_CHECK_OR_RETURN_ERROR( \
handle->member != nullptr, AccessFailed, "Failed to load " #name); \
} while (0)

using namespace std;
@@ -57,12 +57,31 @@ using executorch::runtime::etensor::Tensor;
class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
Error register_shared_library_functions(void* so_handle) const {
LOAD_SYMBOL(AOTInductorModelContainerCreateWithDevice, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerDelete, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumInputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerGetNumOutputs, so_handle);
LOAD_SYMBOL(AOTInductorModelContainerRun, so_handle);
Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
LOAD_SYMBOL(
handle,
create_with_device,
AOTInductorModelContainerCreateWithDevice,
so_handle);

LOAD_SYMBOL(
handle, delete_container, AOTInductorModelContainerDelete, so_handle);

LOAD_SYMBOL(
handle,
get_num_inputs,
AOTInductorModelContainerGetNumInputs,
so_handle);

LOAD_SYMBOL(
handle,
get_num_outputs,
AOTInductorModelContainerGetNumOutputs,
so_handle);

LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle);

return Error::Ok;
}
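For reference, a single invocation such as LOAD_SYMBOL(handle, run, AOTInductorModelContainerRun, so_handle) expands to roughly the following (a sketch based on the macro defined in the first hunk of this file):

do {
  // handle->member = dlsym'd entry point, cast to the matching typedef.
  handle->run = reinterpret_cast<AOTInductorModelContainerRunFunc>(
      dlsym(so_handle, "AOTInductorModelContainerRun"));
  ET_CHECK_OR_RETURN_ERROR(
      handle->run != nullptr,
      AccessFailed,
      "Failed to load AOTInductorModelContainerRun");
} while (0);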
@@ -135,19 +154,22 @@

processed->Free();

// Register all shared library functions
ET_CHECK_OK_OR_RETURN_ERROR(register_shared_library_functions(so_handle));
// Create handle and load function pointers into it
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = so_handle;
handle->so_path = so_path.string();

// Load function pointers specific to this handle's shared library
ET_CHECK_OK_OR_RETURN_ERROR(
load_function_pointers_into_handle(so_handle, handle));

AOTInductorModelContainerHandle container_handle = nullptr;

ET_CHECK_OK_OR_RETURN_ERROR(AOTInductorModelContainerCreateWithDevice(
&container_handle, 1, "cuda", nullptr));
ET_CHECK_OK_OR_RETURN_ERROR(
handle->create_with_device(&container_handle, 1, "cuda", nullptr));

ET_LOG(Info, "container_handle = %p", container_handle);

AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = so_handle;
handle->so_path = so_path.string();
handle->container_handle = container_handle;

// Create a CUDA stream for asynchronous execution
@@ -165,20 +187,11 @@
Span<EValue*> args) const override {
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;

// Need to re-register all the symbols from the so_handle hosted by this
// CudaBackend instance. The reason is that these symbols are
// static/singleton across the whole process. When we share multiple methods
// (meaning multiple so_handle) in the same process, we need to re-register
// the symbols from the so_handle that is being used in this execution.
ET_CHECK_OK_OR_RETURN_ERROR(
register_shared_library_functions(handle->so_handle));

size_t n_inputs;
AOTInductorModelContainerGetNumInputs(handle->container_handle, &n_inputs);
handle->get_num_inputs(handle->container_handle, &n_inputs);

size_t n_outputs;
AOTInductorModelContainerGetNumOutputs(
handle->container_handle, &n_outputs);
handle->get_num_outputs(handle->container_handle, &n_outputs);

ET_CHECK_OR_RETURN_ERROR(
n_inputs + n_outputs == args.size(),
@@ -261,7 +274,7 @@
gpu_outputs[i] = gpu_output_handle;
}
// Run AOTI container with GPU tensors
AOTIRuntimeError error = AOTInductorModelContainerRun(
AOTIRuntimeError error = handle->run(
handle->container_handle,
gpu_inputs.data(), // Use GPU input tensors
n_inputs,
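The comment removed from execute() explains why the old design had to re-run dlsym() on every call: the extern function pointers were process-wide, so whichever shared library registered last won. Storing the pointers on each AOTIDelegateHandle removes that hazard. A minimal, self-contained sketch of the difference (all names below are illustrative, not from the PR):

#include <cstdio>

using RunFunc = void (*)();

void run_from_lib_a() { std::puts("lib A"); }
void run_from_lib_b() { std::puts("lib B"); }

// Old style: one process-wide pointer shared by every loaded program.
RunFunc global_run = nullptr;

// New style: each delegate handle owns the pointers for its own library.
struct Handle {
  RunFunc run = nullptr;
};

int main() {
  Handle a;
  a.run = run_from_lib_a;
  Handle b;
  b.run = run_from_lib_b;

  global_run = run_from_lib_a;  // program A "registers"
  global_run = run_from_lib_b;  // program B "registers", clobbering A
  global_run();                 // prints "lib B" even when A meant to run

  a.run();  // per-handle pointer: always "lib A"
  b.run();  // per-handle pointer: always "lib B"
  return 0;
}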