Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4f3d004
[EP ABI] Add weight pre-packing support to kernel-based plugin EPs
adrianlizarraga Dec 8, 2025
e954021
Add comment about sharing of prepacked weights (cpu ep only)
adrianlizarraga Dec 8, 2025
fb2998b
Update Mul kernel to pre-pack input b
adrianlizarraga Dec 9, 2025
5e64f79
Apply suggestions from code review
adrianlizarraga Dec 9, 2025
9b1c6a2
Add comments regarding prepack allocator lifetime
adrianlizarraga Dec 9, 2025
c638a1a
Merge branch 'adrianl/plugin-ep-kernel-prepack' of github.com:microso…
adrianlizarraga Dec 9, 2025
717ed4a
Added support for sharing pre-packed weights for cpu-accessible alloc…
adrianlizarraga Dec 11, 2025
bd8f6f0
Define what should happen if OrtKernelImpl::SetSharedPrePackedWeight(…
adrianlizarraga Dec 12, 2025
fc1fd16
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 16, 2025
8b3f56c
Clean up some exception handling
adrianlizarraga Dec 16, 2025
23503a1
Refactor example kernel classes (no inheritance)
adrianlizarraga Dec 17, 2025
7f37ffb
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 17, 2025
26eca56
Correct use of output param
adrianlizarraga Dec 17, 2025
7af257b
Add more edge-case handling for PrePack() call
adrianlizarraga Dec 17, 2025
515062e
API version checks
adrianlizarraga Dec 17, 2025
347ce4f
Use correct SAL annotation for array parameters
adrianlizarraga Dec 18, 2025
906187d
Clean up some includes
adrianlizarraga Dec 18, 2025
1611fc3
Update onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc
adrianlizarraga Dec 19, 2025
30ca590
Remove OrtAllocator parameter from SharedPrePackedWeightCache_StoreWe…
adrianlizarraga Dec 19, 2025
a5342b9
Clarify what happens when SharedPrePackedWeightCache_StoreWeightData …
adrianlizarraga Dec 20, 2025
51bc731
Merge branch 'main' into adrianl/plugin-ep-kernel-prepack
adrianlizarraga Dec 22, 2025
edf3f2c
Review comments
adrianlizarraga Dec 23, 2025
e94c0aa
C++ API
adrianlizarraga Dec 23, 2025
c8eb3c9
Improve doc for c++ api convenience class
adrianlizarraga Dec 23, 2025
98e3d13
Add buffer_sizes as a parameter to OrtKernelImpl::SetSharedWeightData
adrianlizarraga Dec 24, 2025
c61ae41
Add comment to implementation of OrtKernelImpl::SetSharedPrePackedWeight
adrianlizarraga Dec 24, 2025
02d75d2
Do not prescribe what the kernel impl should return for a situation t…
adrianlizarraga Dec 24, 2025
0a84eda
Update include/onnxruntime/core/session/onnxruntime_ep_c_api.h
adrianlizarraga Dec 24, 2025
5f80f9d
Adjust comments
adrianlizarraga Dec 24, 2025
c60472d
Tweak comment again
adrianlizarraga Dec 24, 2025
441c9e2
Add comments to clarify ownership scenarios
adrianlizarraga Dec 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2104,11 +2104,12 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_data_transfer.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/base.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
Expand Down
24 changes: 23 additions & 1 deletion include/onnxruntime/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "boost/mp11.hpp"
#include <gsl/gsl>

Check warning on line 7 in include/onnxruntime/core/framework/op_kernel.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C system header after other header. Should be: op_kernel.h, c system, c++ system, other. [build/include_order] [4] Raw Output: include/onnxruntime/core/framework/op_kernel.h:7: Found C system header after other header. Should be: op_kernel.h, c system, c++ system, other. [build/include_order] [4]

// It is safe to include the below header even if SHARED_PROVIDER macro is enabled
// as it doesn't include any pb headers.
Expand All @@ -26,7 +27,6 @@
#include "core/graph/constants.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/onnx_protobuf.h"
#include <gsl/gsl>
namespace onnxruntime {
class OpKernelContext;
}
Expand Down Expand Up @@ -105,6 +105,7 @@
return Status::OK();
}

// Note: New implementations should override OpKernel::UseSharedPrePackedBuffers_V2 instead.
// Override this function to use provided pre-packed weight.
// Status UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers,
// int input_idx,
Expand All @@ -130,6 +131,27 @@
return Status::OK();
}

/// <summary>
/// Version 2 of OpKernel::UseSharedPrePackedBuffers() that additionally accepts the buffer sizes as a parameter.
/// The default implementation of this function just calls directly to OpKernel::UseSharedPrePackedBuffers()
/// to avoid the need to update all existing kernel-based provider-bridge EPs.
///
/// TODO: Consolidate UseSharedPrePackedBuffers and UseSharedPrePackedBuffers_V2 into a single function,
/// which will require updating kernel-based provider-bridge EPs (cpu, cuda, webgpu).
///
/// </summary>
/// <param name="prepacked_buffers"></param>
/// <param name="prepacked_buffer_sizes"></param>
/// <param name="input_idx"></param>
/// <param name="used_shared_buffers"></param>
/// <returns></returns>
virtual Status UseSharedPrePackedBuffers_V2(std::vector<BufferUniquePtr>& prepacked_buffers,
gsl::span<const size_t> /*prepacked_buffer_sizes*/,
int input_idx,
/*out*/ bool& used_shared_buffers) {
return UseSharedPrePackedBuffers(prepacked_buffers, input_idx, used_shared_buffers);
}

const OrtDevice GetDevice(OrtMemType mem_type) const;
const OpKernelInfo& Info() const {
return *op_kernel_info_;
Expand Down
29 changes: 29 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3379,5 +3379,34 @@ struct KernelRegistry : detail::Base<OrtKernelRegistry> {
Status AddKernel(const OrtKernelDef* kernel_def, OrtKernelCreateFunc kernel_create_func,
void* kernel_create_func_state);
};

namespace detail {
template <typename T>
struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
using B = Ort::detail::Base<T>;
using B::B;

//< Wraps SharedPrePackedWeightCache_StoreWeightData
Status StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes, size_t num_buffers);
};
} // namespace detail

/** \brief Convenience C++ wrapper class around a ::OrtSharedPrePackedWeightCache instance owned by ORT.
*
* An `OrtSharedPrePackedWeightCache*` instance is passed as an argument to OrtKernelImpl::PrePackWeight.
* Example use:
* OrtStatus* MyKernel::PrePackWeightImpl(OrtKernelImpl*, ..., OrtSharedPrePackedWeightCache* c_cache, ...) {
* ...
* if (c_cache != nullptr) {
* Ort::UnownedSharedPrePackedWeightCache cpp_cache(c_cache);
* Ort::Status status = cpp_cache.StoreWeightData(...);
* }
* ...
* }
*
* \remarks OrtSharedPrePackedWeightCache is always unowned, but mutable, for EpApi users.
*/
using UnownedSharedPrePackedWeightCache =
detail::SharedPrePackedWeightCacheImpl<Ort::detail::Unowned<OrtSharedPrePackedWeightCache>>;
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -3713,4 +3713,13 @@ inline Status KernelRegistry::AddKernel(const OrtKernelDef* kernel_def, OrtKerne
void* kernel_create_func_state) {
return Status{GetEpApi().KernelRegistry_AddKernel(p_, kernel_def, kernel_create_func, kernel_create_func_state)};
}

namespace detail {
template <typename T>
inline Status SharedPrePackedWeightCacheImpl<T>::StoreWeightData(void** buffer_data_ptrs, size_t* buffer_sizes,
size_t num_buffers) {
return Status{GetEpApi().SharedPrePackedWeightCache_StoreWeightData(this->p_, buffer_data_ptrs, buffer_sizes,
num_buffers)};
}
} // namespace detail
} // namespace Ort
125 changes: 125 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ORT_RUNTIME_CLASS(KernelRegistry);
ORT_RUNTIME_CLASS(KernelDefBuilder);
ORT_RUNTIME_CLASS(KernelDef);
ORT_RUNTIME_CLASS(DataType); // combination of ONNXType (e.g., Tensor, Map, Sequence) and ONNXTensorElementDataType
ORT_RUNTIME_CLASS(SharedPrePackedWeightCache);

/** \brief Struct that an EP implements for IDataTransfer to copy between devices it uses and CPU.
*
Expand Down Expand Up @@ -308,6 +309,101 @@ struct OrtKernelImpl {
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtKernelImpl* this_ptr);

/** \brief Optional function to pre-pack a constant tensor (i.e., a weight) to the kernel's preferred data layout.
*
* For example, a Conv kernel can define this function to pack input W to the channel-last data layout
* before inference.
*
* Pre-packing can operate in three different modes: no pre-packing mode, sharing mode, and non-sharing mode.
* 1) No pre-packing mode: The kernel can forgo any weight pre-packing for the given `input_index` by setting
* `is_packed` to false and returning a successful OrtStatus. In this mode, the kernel's
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called for that specific
* `input_index`.
* 2) Sharing mode: Sharing is allowed if the `prepacked_weight_cache` argument is not NULL and the EP stores
* weight data in CPU-accessible memory. In this case, the kernel can optionally choose
* to share the packed weight with other kernels that use the same weight
* (compared by content hash). To do so, the kernel must allocate the packed weight with the
* provided `allocator`, then it stores the packed weight data into `prepacked_weight_cache`
* via SharedPrePackedWeightCache_StoreWeightData(), sets `is_packed` to true, and returns a
* successful OrtStatus. ORT will subsequently call OrtKernelImpl::SetSharedPrePackedWeight()
* to provide this kernel with the actual shared weight data, whose memory location could
* differ (i.e., if shared data was allocated by a previously processed kernel).
* 3) Non-sharing mode: In non-sharing mode, the `prepacked_weight_cache` argument is ignored. In this mode,
* the implementation allocates the packed data with the provided `allocator`, sets
* `is_packed` to true, and returns a successful OrtStatus. The kernel is ultimately
* responsible for releasing the packed data for the weight with `allocator`.
* ORT may release the original (unpacked) weight, which must not be accessed in
* OrtKernelImpl::Compute(). Note that in this mode, the kernel's
* OrtKernelImpl::SetSharedPrePackedWeight() function is not called by ORT for that specific
* `input_index`.
*
* \note This function is based on the internal OpKernel::PrePack() virtual function used within ORT.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] tensor The OrtValue instance representing the constant tensor (weight). Do not cache in the kernel.
* \param[in] input_index The input index of the tensor in this kernel.
* \param[in] allocator Allocator for allocating the pre-packed data. Its use is required in sharing mode and
* recommended, but not required, in the non-sharing mode. This will be an allocator set by
* the application for the session/environment (e.g., via CreateAndRegisterAllocator[V2]
* or RegisterAllocator), or an allocator on the OrtEpDevice (read-only or default) otherwise.
* The allocator remains valid throughout the lifetime of the OrtKernelImpl instance.
* \param[in] prepacked_weights_cache May be NULL. If not NULL, the kernel may choose to share a packed weight by
* first storing it in the OrtSharedPrePackedWeightCache instance and then
* receiving the actual shared weight data in the call to
* OrtKernelImpl::SetSharedPrePackedWeight(). See the above description for
* "sharing mode".
* \param[out] is_packed Output parameter that the implementation sets to true if the kernel packed the tensor data.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is optional. If not implemented (set to NULL), ORT assumes the kernel
* does not pre-pack weight data (i.e., `is_packed` defaults to false).
*
* \since Version 1.24.
*/
ORT_API2_STATUS(PrePackWeight, _In_ OrtKernelImpl* this_ptr, _In_ const OrtValue* tensor,
_In_ int input_index, _Inout_ OrtAllocator* allocator,
_In_opt_ OrtSharedPrePackedWeightCache* prepacked_weight_cache, _Out_ bool* is_packed);

/** \brief Optional function that receives data for a shared pre-packed weight from ORT.
*
* ORT calls this function after calling OrtKernelImpl::PrePackWeight for a specific `input_index` if:
* - OrtKernelImpl::PrePackWeight set the output parameter `is_packed` to true.
* - OrtKernelImpl::PrePackWeight stored weight data to share into the provided OrtSharedPrePackedWeightCache
* parameter (`prepacked_weight_cache`) via the API SharedPrePackedWeightCache_StoreWeightData.
*
* Refer to the description of the "sharing-mode" in the documentation for OrtKernelImpl::PrePackWeight().
*
* \note ORT will not call this function for an `input_index` that a previous call to
* OrtKernelImpl::PrePackWeight() did not elect to pre-pack and share.
*
* \note This function is based on the internal OpKernel::UseSharedPrePackedBuffers() virtual function used
* within ORT.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
* single shared weight. The buffers are provided in the same order and with the same
* contents (in a potentially different memory location) as the buffers
* passed into SharedPrePackedWeightCache_StoreWeightData() within the
* OrtKernelImpl::PrePackWeight() call for the same `input_index`.
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
* \param[in] input_index The input index of the tensor in this kernel. This index identifies the identity of
* the weight.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is generally optional. It is only required if OrtKernelImpl::PrePack()
* elects to share pre-packed weights.
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SetSharedPrePackedWeight, _In_ OrtKernelImpl* this_ptr,
_In_reads_(num_buffers) const void* const* buffer_data_ptrs,
_In_reads_(num_buffers) const size_t* buffer_data_sizes,
_In_ size_t num_buffers, _In_ int input_index);
};

/** \brief Type definition for a function that creates an OrtKernelImpl instance for an operator kernel.
Expand Down Expand Up @@ -846,6 +942,35 @@ struct OrtEpApi {
*/
ORT_API2_STATUS(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Outptr_result_maybenull_ const OrtKernelDef** out_kernel_def);

/** \brief Sets one or more data buffers that collectively hold the pre-packed data for a single shared weight.
*
* \note Used within the implementation of OrtKernelImpl::PrePackWeight() when the kernel wants to share pre-packed
* weight data with other kernels. The buffer data MUST be allocated with the OrtAllocator provided to
* OrtKernelImpl::PrePack.
*
* \note Ownership of weight data transfers to the OrtSharedPrePackedWeightCache instance on success.
* If this function returns an error status, the caller retains ownership of the weight data.
*
* \note Subsequent calls with the same OrtSharedPrePackedWeightCache instance release and replace the old data.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] buffer_data_ptrs An array of buffer data pointers that collectively hold the pre-packed data for a
* single shared weight. Note that sometimes a single weight may have multiple pre-packed
* buffers and it is up to the kernel implementation to determine how to split the data
* into multiple buffers (if desired).
* \param[in] buffer_data_sizes An array of buffer byte sizes, one per element in `buffer_data_ptrs`.
* \param[in] num_buffers The number of buffers used to store the data for the shared pre-packed weight.
* Specifies the number of elements in the `buffer_data_ptrs` and `buffer_data_sizes` arrays.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SharedPrePackedWeightCache_StoreWeightData,
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
_In_ size_t num_buffers);
};

/**
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,16 +421,23 @@ void SessionState::CleanInitializedTensorsFromGraph() {
static Status KernelUseSharedPrePackedBuffers(OpKernel& kernel, int input_idx,
const PrePackedWeights& prepacked_weights,
const std::string& node_name) {
const size_t num_buffers = prepacked_weights.buffers_.size();
assert(prepacked_weights.buffer_sizes_.size() == num_buffers);

std::vector<BufferUniquePtr> shared_prepacked_buffers;
shared_prepacked_buffers.reserve(4); // Unlikely to see more than 4 prepacked buffers per initializer
std::vector<size_t> shared_prepacked_buffer_sizes;
shared_prepacked_buffers.reserve(num_buffers);
shared_prepacked_buffer_sizes.reserve(num_buffers);

for (const auto& prepacked_buffer : prepacked_weights.buffers_) {
for (size_t i = 0; i < num_buffers; i++) {
// BufferDeleter is nullptr because the kernel should not delete the shared buffer - it can only use it
shared_prepacked_buffers.emplace_back(prepacked_buffer.get(), BufferDeleter(nullptr));
shared_prepacked_buffers.emplace_back(prepacked_weights.buffers_[i].get(), BufferDeleter(nullptr));
shared_prepacked_buffer_sizes.push_back(prepacked_weights.buffer_sizes_[i]);
}

bool used_shared_buffers = false;
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers(shared_prepacked_buffers, input_idx, used_shared_buffers));
ORT_RETURN_IF_ERROR(kernel.UseSharedPrePackedBuffers_V2(shared_prepacked_buffers, shared_prepacked_buffer_sizes,
input_idx, used_shared_buffers));

// BUG CHECK: Ensure that the kernel used the provided shared buffers
// Mostly a debug check to ensure that the kernel has an overridden implementation of the
Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,47 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_LookUpKernel, _In_ OrtEpGraphSupportInfo*
API_IMPL_END
}

ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData,
_In_ OrtSharedPrePackedWeightCache* prepacked_weight_cache,
_In_reads_(num_buffers) void** buffer_data_ptrs, _In_reads_(num_buffers) size_t* buffer_data_sizes,
_In_ size_t num_buffers) {
API_IMPL_BEGIN
if (prepacked_weight_cache == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
"Must specify a valid OrtPrePackedWeightsCache instance");
}

if (buffer_data_ptrs == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data pointers");
}

if (buffer_data_sizes == nullptr) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify a valid array of buffer data sizes");
}

if (num_buffers == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Must specify at least one weight data buffer");
}

OrtStatus* status = nullptr;

ORT_TRY {
prepacked_weight_cache->SetBuffers(buffer_data_ptrs, buffer_data_sizes, num_buffers);
}
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
// This API function promises that ORT will take ownership of the data only if it returns successfully.
// If any exception occurred while filling out `prepacked_weight_cache`, we try to release ownership so that
// the caller retains ownership of all of the original data and can delete it.
prepacked_weight_cache->ReleaseAllData();
status = OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what());
});
}

return status;
API_IMPL_END
}

static constexpr OrtEpApi ort_ep_api = {
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
// and no functions can be removed (the implementation needs to change to return an error).
Expand Down Expand Up @@ -636,6 +677,7 @@ static constexpr OrtEpApi ort_ep_api = {
&OrtExecutionProviderApi::KernelDef_GetOutputMemType,
&OrtExecutionProviderApi::GetTensorDataType,
&OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
&OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
Loading
Loading