Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/cuda_plugin_ep/cuda_plugin_ep_design.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ The primary approach moves pure-computation helpers from CPU `.cc` files to head
- `roialign.h` — `CheckROIAlignValidInput`, `RoiAlignBase` constructor (templatized on info type)
- `upsamplebase.h` — `UpsampleBase::AdjustOutputSizeAsPolicy`
- `crop.h` — `CropBase` constructor (templatized on info type)
- `space_depth_ops.h` — `SpaceDepthBase` constructor (templatized on info type)
- `space_depth_ops.h` — `SpaceDepthBase` constructor plus shared `ReadBlocksize`, `ReadIsDCR`, and dimension-validation helpers (templatized on info/context type where needed)
- `clip.h` — Clip min/max attribute handling (removed `Clip_6Base` CPU dependency)
- `cuda_common_type_helpers.h` — CUDA type conversion and handle error string helpers (moved from `cuda_common.cc`)

Expand Down Expand Up @@ -249,7 +249,8 @@ This allows the base class constructor to work with both the framework `OpKernel
Some CPU base classes have heavy dependencies (protobuf, `UnpackTensor`) that make inlining impractical:

- **`ConstantOfShapeBase`** — depends on `TensorProto` and `UnpackTensor`. The plugin path in `constant_of_shape.h` stays self-contained: it reuses `ConstantOfShapeCore` but fetches the `value` attribute through the ORT C++ API instead of depending on the full CPU base implementation.
- **`UpsampleBase`** — partially addressed: `AdjustOutputSizeAsPolicy` moved to header (#27628). Still depends on `InputDefs()` and `OpKernelInfo::GetAllocator()` which are not in the adapter.

`UpsampleBase` no longer belongs in this category: the adapter now exposes `OpKernelInfo::GetAllocator(OrtMemType)`, and the remaining shape-rank query already has an adapter-safe fallback when `Node::InputDefs()` is unavailable. That lets the CUDA `Upsample` antialias path reuse the same persistent device lookup-table initialization in both bundled and plugin builds instead of keeping a plugin-only scratch-buffer fallback.

---

Expand Down Expand Up @@ -603,7 +604,7 @@ The branch still contains a small set of plugin guards in both infrastructure an
- `generator/constant_of_shape.h` still needs a plugin-specific path because `ConstantOfShapeBase` depends on framework-only tensor-attribute helpers.
- Tunable kernels such as `math/matmul.cc` still gate framework-only registration paths.
- `tensor/identity_op.h` guards the `TensorSeq` code path and `context->InputType()` call with `#ifndef BUILD_CUDA_EP_AS_PLUGIN` — the plugin build handles only the `Tensor` path. `identity_op.cc` uses conditional macros (`IDENTITY_V_TYPES` / `IDENTITY_V_TYPES_IRv9`) so opset 14+ registrations use `AllFixedSizeTensorTypes()` in the plugin build. Additionally, old Dropout opset 7–9 and 10–11 kernel registrations were moved from `identity_op.cc` to `nn/dropout.cc` so that each op's registrations live in that op's own source file.
- A few tensor kernels (`pad.cc`, `tile.cc`, `unsqueeze.cc`, `upsample.*`, `space_depth_ops.h`, `scatter_nd.*`) still contain localized plugin guards where adapter and framework paths have not fully converged.
- A few tensor kernels (`pad.cc`, `tile.cc`, `unsqueeze.cc`) still contain localized plugin guards where adapter and framework paths have not fully converged. Recent cleanup removed the plugin-only branches from `upsample.*`, `space_depth_ops.h`, and `scatter_nd.*` by moving reusable logic into shared adapter-safe helpers and by adding allocator access to `ep::adapter::OpKernelInfo`.

The broad trend remains positive: most operator-level plugin conditionals were removed by moving reusable CPU/helper logic into shared headers and by centralizing stream bridging in `CudaKernel` helpers.

Expand Down
21 changes: 21 additions & 0 deletions include/onnxruntime/ep/adapter/op_kernel_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ struct OpKernelInfo {
const DataTransferManager& GetDataTransferManager() const noexcept {
return (static_cast<const Ep*>(cache_->ort_ep_))->GetDataTransferManager();
}

AllocatorPtr GetAllocator(OrtMemType mem_type) const {
const auto* ort_ep = cache_->ort_ep_;
ORT_ENFORCE(ort_ep != nullptr, "Kernel execution provider is not associated with an OrtEp instance.");

AllocatorPtr allocator;
const auto* ep = static_cast<const Ep*>(ort_ep);

if (mem_type == OrtMemTypeDefault) {
ORT_THROW_IF_ERROR(ep->GetTempSpaceAllocator(&allocator));
return allocator;
}

if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput || mem_type == OrtMemTypeCPU) {
ORT_THROW_IF_ERROR(ep->GetTempSpaceCPUAllocator(&allocator));
return allocator;
}

ORT_THROW("Unsupported OrtMemType in adapter::OpKernelInfo::GetAllocator: ", static_cast<int>(mem_type));
Comment thread
tianleiwu marked this conversation as resolved.
Outdated
}

Node node() const noexcept {
return Node{cache_->kernel_info_};
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/cpu_provider_shared.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
// From cpu/tensor/scatter_nd.h (direct)
Status ScatterNDBase__ValidateShapes(const TensorShape& input_shape,
const TensorShape& indice_shape,
const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); }
const TensorShape& update_shape) override { return scatter_nd_internal::ValidateShapes(input_shape, indice_shape, update_shape); }
// From cpu/tensor/padbase.h (direct)
Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); }

Expand Down
82 changes: 46 additions & 36 deletions onnxruntime/core/providers/cpu/tensor/scatter_nd.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,51 @@ namespace concurrency {
class ThreadPool;
}

namespace scatter_nd_internal {

inline Status ValidateShapes(const TensorShape& input_shape,
const TensorShape& indice_shape,
const TensorShape& update_shape) {
auto input_rank = input_shape.NumDimensions();
auto indice_rank = indice_shape.NumDimensions();
auto update_rank = update_shape.NumDimensions();

if (input_rank == 0 || indice_rank == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input tensor and indices tensor must have rank larger than 0. ",
"input shape: ", input_shape, ", indices shape: ", indice_shape);
}

auto last_indice_dimension = indice_shape[indice_rank - 1];
if (last_indice_dimension > static_cast<int64_t>(input_rank)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"last dimension of indices must not be larger than rank of input tensor");
}

bool is_update_shape_invalid = [&]() {
if (update_rank != (input_rank + indice_rank - 1 - static_cast<ptrdiff_t>(last_indice_dimension))) {
return true;
}
if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) {
return true;
}
if (input_shape.Slice(onnxruntime::narrow<size_t>(last_indice_dimension)) != update_shape.Slice(indice_rank - 1)) {
return true;
}
return false;
}();

if (is_update_shape_invalid) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ",
"updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape);
}

return Status::OK();
}

} // namespace scatter_nd_internal

class ScatterND final : public OpKernel {
public:
enum class Reduction : int {
Expand Down Expand Up @@ -51,42 +96,7 @@ class ScatterND final : public OpKernel {
static inline Status ValidateShapes(const TensorShape& input_shape,
const TensorShape& indice_shape,
const TensorShape& update_shape) {
auto input_rank = input_shape.NumDimensions();
auto indice_rank = indice_shape.NumDimensions();
auto update_rank = update_shape.NumDimensions();

if (input_rank == 0 || indice_rank == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input tensor and indices tensor must has rank larger than 0. ",
"input shape: ", input_shape, ", indices shape: ", indice_shape);
}

auto last_indice_dimension = indice_shape[indice_rank - 1];
if (last_indice_dimension > static_cast<int64_t>(input_rank)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"last dimension of indices must not be larger than rank of input tensor");
}

bool is_update_shape_invalid = [&]() {
if (update_rank != (input_rank + indice_rank - 1 - static_cast<ptrdiff_t>(last_indice_dimension))) {
return true;
}
if (indice_shape.Slice(0, indice_rank - 1) != update_shape.Slice(0, indice_rank - 1)) {
return true;
}
if (input_shape.Slice(onnxruntime::narrow<size_t>(last_indice_dimension)) != update_shape.Slice(indice_rank - 1)) {
return true;
}
return false;
}();

if (is_update_shape_invalid) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"updates tensor should have shape equal to indices.shape[:-1] + data.shape[indices.shape[-1]:]. ",
"updates shape: ", update_shape, ", indices shape: ", indice_shape, ", data shape: ", input_shape);
}

return Status::OK();
return scatter_nd_internal::ValidateShapes(input_shape, indice_shape, update_shape);
}
#endif // SHARED_PROVIDER

Expand Down
148 changes: 91 additions & 57 deletions onnxruntime/core/providers/cpu/tensor/space_depth_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,114 @@

#pragma once

Comment thread
tianleiwu marked this conversation as resolved.
#if !defined(SHARED_PROVIDER) && !defined(BUILD_CUDA_EP_AS_PLUGIN)
#include "core/framework/op_kernel.h"
#endif

namespace onnxruntime {

class SpaceDepthBase {
protected:
template <typename KernelInfoType>
explicit SpaceDepthBase(const KernelInfoType& info) {
ORT_ENFORCE(info.template GetAttr<int64_t>("blocksize", &blocksize_).IsOK(),
"Attribute blocksize is not set.");
namespace space_depth_internal {

template <typename KernelInfoType>
inline int64_t ReadBlocksize(const KernelInfoType& info) {
int64_t blocksize = 0;
ORT_ENFORCE(info.template GetAttr<int64_t>("blocksize", &blocksize).IsOK(),
"Attribute blocksize is not set.");
return blocksize;
}

template <typename KernelInfoType>
inline bool ReadIsDCR(const KernelInfoType& info) {
bool is_dcr = true;
std::string mode;

Check warning on line 25 in onnxruntime/core/providers/cpu/tensor/space_depth_ops.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cpu/tensor/space_depth_ops.h:25: Add #include <string> for string [build/include_what_you_use] [4]
// If mode doesn't exist, then it is the default "DCR" mode
// (or) it is an opset < 11 model for which the only mode is "DCR" mode.
if (info.GetAttr("mode", &mode).IsOK()) {
if (mode == "CRD") {
is_dcr = false;
} else if (mode != "DCR") {
ORT_THROW("DepthToSpace op: only 'DCR' and 'CRD' modes are supported");
}
}

template <bool IsNHWC = false>
Status InputValidationsAndOutputDimsCalc(const Tensor& input,
int64_t& batch,
int64_t& input_depth, int64_t& input_height, int64_t& input_width,
int64_t& output_depth, int64_t& output_height, int64_t& output_width,
bool is_space_to_depth) const {
const TensorShape& input_shape = input.Shape();
return is_dcr;
}

template <bool IsNHWC = false>
inline Status InputValidationsAndOutputDimsCalc(int64_t blocksize,
const Tensor& input,
int64_t& batch,
int64_t& input_depth, int64_t& input_height, int64_t& input_width,
int64_t& output_depth, int64_t& output_height, int64_t& output_width,
bool is_space_to_depth) {
const TensorShape& input_shape = input.Shape();

if (input_shape.NumDimensions() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceDepth ops require a 4-D input. Provided rank: ",
input_shape.NumDimensions());
}

if (input_shape.NumDimensions() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceDepth ops require a 4-D input. Provided rank: ",
input_shape.NumDimensions());
batch = input_shape[0];
if constexpr (IsNHWC) {
input_depth = input_shape[3];
input_height = input_shape[1];
input_width = input_shape[2];
} else {
input_depth = input_shape[1];
input_height = input_shape[2];
input_width = input_shape[3];
}

if (is_space_to_depth) { // SpaceToDepth op
if ((input_height % blocksize) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input height to be a multiple of block_size");
}

batch = input_shape[0];
if constexpr (IsNHWC) {
input_depth = input_shape[3];
input_height = input_shape[1];
input_width = input_shape[2];
} else {
input_depth = input_shape[1];
input_height = input_shape[2];
input_width = input_shape[3];
if ((input_width % blocksize) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input width to be a multiple of block_size");
}

if (is_space_to_depth) { // SpaceToDepth op
if ((input_height % this->blocksize_) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input height to be a multiple of block_size");
}
output_depth = input_depth * blocksize * blocksize;
output_height = input_height / blocksize;
output_width = input_width / blocksize;

} else { // DepthToSpace op
if ((input_depth % (blocksize * blocksize) != 0)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"DepthToSpace requires input depth to be a multiple of (block_size * block_size)");
}

if ((input_width % this->blocksize_) != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "SpaceToDepth requires input width to be a multiple of block_size");
}
output_depth = input_depth / blocksize / blocksize;
output_height = input_height * blocksize;
output_width = input_width * blocksize;
}

output_depth = input_depth * blocksize_ * blocksize_;
output_height = input_height / blocksize_;
output_width = input_width / blocksize_;
return Status::OK();
}

} else { // DepthToSpace op
if ((input_depth % (blocksize_ * blocksize_) != 0)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"DepthToSpace requires input depth to be a multiple of (block_size * block_size)");
}
} // namespace space_depth_internal

output_depth = input_depth / blocksize_ / blocksize_;
output_height = input_height * blocksize_;
output_width = input_width * blocksize_;
}
class SpaceDepthBase {
protected:
template <typename KernelInfoType>
explicit SpaceDepthBase(const KernelInfoType& info) : blocksize_(space_depth_internal::ReadBlocksize(info)) {}

return Status::OK();
template <bool IsNHWC = false>
Status InputValidationsAndOutputDimsCalc(const Tensor& input,
int64_t& batch,
int64_t& input_depth, int64_t& input_height, int64_t& input_width,
int64_t& output_depth, int64_t& output_height, int64_t& output_width,
bool is_space_to_depth) const {
return space_depth_internal::InputValidationsAndOutputDimsCalc<IsNHWC>(
blocksize_, input, batch, input_depth, input_height, input_width,
output_depth, output_height, output_width, is_space_to_depth);
}

int64_t blocksize_;
};

#if !defined(SHARED_PROVIDER) && !defined(BUILD_CUDA_EP_AS_PLUGIN)

class SpaceToDepth final : public OpKernel, SpaceDepthBase {
public:
explicit SpaceToDepth(const OpKernelInfo& info) : OpKernel(info), SpaceDepthBase(info) {
Expand All @@ -79,23 +121,15 @@

class DepthToSpace final : public OpKernel, SpaceDepthBase {
public:
explicit DepthToSpace(const OpKernelInfo& info) : OpKernel(info), SpaceDepthBase(info) {
std::string mode;
// if mode doesn't exist, then it is the default "DCR" mode
// (or) it is an opset < 11 model for which the only mode is "DCR" mode
if (info.GetAttr("mode", &mode).IsOK()) {
if (mode == "CRD")
is_dcr_ = false;

else if (mode != "DCR")
ORT_THROW("DepthToSpace op: only 'DCR' and 'CRD' modes are supported");
}
}
explicit DepthToSpace(const OpKernelInfo& info)
: OpKernel(info), SpaceDepthBase(info), is_dcr_(space_depth_internal::ReadIsDCR(info)) {}

Status Compute(OpKernelContext* context) const override;

private:
bool is_dcr_ = true;
};

#endif // !defined(SHARED_PROVIDER) && !defined(BUILD_CUDA_EP_AS_PLUGIN)

} // namespace onnxruntime
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/cuda/cudnn_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,10 @@ struct Consts<BFloat16> {

inline double ClampCudnnBatchNormEpsilon(double epsilon) {
if (epsilon < CUDNN_BN_MIN_EPSILON) {
#ifndef BUILD_CUDA_EP_AS_PLUGIN
if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON)
if (CUDNN_BN_MIN_EPSILON - epsilon > FLT_EPSILON) {
LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than CUDNN_BN_MIN_EPSILON. "
<< "Setting it to CUDNN_BN_MIN_EPSILON";
#endif
}
return CUDNN_BN_MIN_EPSILON;
}
return epsilon;
Expand Down
8 changes: 0 additions & 8 deletions onnxruntime/core/providers/cuda/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,8 @@ Status Conv<T, Layout>::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle));
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode}));
} catch (const std::exception& ex) {
#ifndef BUILD_CUDA_EP_AS_PLUGIN
std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what(),
" with the cudnn frontend json:\n", s_.cudnn_fe_graph->print());
#else
std::string message = MakeString("Failed to initialize CUDNN Frontend: ", ex.what());
#endif
return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message);
}

Expand All @@ -253,12 +249,8 @@ Status Conv<T, Layout>::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShap
CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle));
} catch (const std::exception& ex) {
if (!fuse_bias && !fuse_act && use_tf32) {
#ifndef BUILD_CUDA_EP_AS_PLUGIN
std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what(),
" with the cudnn frontend json:\n", s_.cudnn_fe_graph->print());
#else
std::string message = MakeString("OP not supported by CUDNN Frontend: ", ex.what());
#endif
return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message);
}

Expand Down
Loading
Loading