Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions onnxruntime/core/providers/shared/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,6 @@ bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min
node, min, max, logger);
}

// deprecated version that is not able to check if the initializer is constant
bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max,
const logging::Logger& logger) {
return GetClipMinMaxImpl(
[&initializers](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* {
auto entry = initializers.find(name);
return entry == initializers.end() ? nullptr : entry->second;
},
node, min, max, logger);
}

NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node)
: node_attributes_(node.GetAttributes()) {}

Expand Down
6 changes: 0 additions & 6 deletions onnxruntime/core/providers/shared/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ class NodeUnit;
bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node,
float& min, float& max, const logging::Logger& logger);

/// <deprecated>GraphViewer GetConstantInitializer/IsConstantInitializer should be used to ensure the initializer is
/// constant. Low risk for Clip min/max but in general the infrastructure to check if an operator is supported needs
/// to be updated to not use InitializedTensorSet which may contain non-constant initializers.</deprecated>
bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node,
float& min, float& max, const logging::Logger& logger);

// Get the type of the given NodeArg
// Will return false if the given NodeArg has no type
bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logger);
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const loggin
return true;
}

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
bool IsNodeSupported(const GraphViewer& graph_viewer, const Node& node, const WebnnDeviceType device_type,
const emscripten::val& wnn_limits, const logging::Logger& logger) {
const auto& op_builders = GetOpBuilders();
if (Contains(op_builders, node.OpType())) {
const auto* op_builder = op_builders.at(node.OpType());
return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
return op_builder->IsOpSupported(graph_viewer, node, device_type, wnn_limits, logger);
} else {
return false;
}
Expand Down Expand Up @@ -107,7 +107,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
std::unordered_set<const Node*> supported_nodes;

for (const auto& node : graph_viewer.Nodes()) {
const bool supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
const bool supported = IsNodeSupported(graph_viewer, node, device_type, wnn_limits, logger);
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
Expand Down
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,13 @@ inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::va
return true;
}

inline bool IsEmptyTensor(const InitializedTensorSet& initializers, const std::string& name) {
if (name.empty() || !Contains(initializers, name)) {
inline bool IsEmptyTensor(const GraphViewer& graph_viewer, const std::string& name) {
const auto* tensor_init = graph_viewer.GetConstantInitializer(name);
if (name.empty() || !tensor_init) {
return true;
}

const auto& tensor = *initializers.at(name);
const auto& tensor = *tensor_init;
const auto dims = tensor.dims();
// An empty tensor contains a 0 in the dimensions list.
return std::any_of(dims.begin(), dims.end(), [](auto d) { return d == 0; });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupportedImpl(const GraphViewer&, const Node& node,
WebnnDeviceType device_type, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -66,7 +66,7 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}

// Operator support related.
bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const GraphViewer& /* initializers */,
const Node& node,
WebnnDeviceType device_type,
const logging::Logger& logger) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace webnn {
Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const {
ORT_RETURN_IF_NOT(
IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
IsOpSupported(model_builder.GetGraphViewer(), node, model_builder.GetWebnnDeviceType(),
model_builder.GetOpSupportLimits(), logger),
"Unsupported operator ", node.OpType());
ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
Expand All @@ -26,10 +26,10 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&

// Operator support related.

bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
bool BaseOpBuilder::IsOpSupported(const GraphViewer& graph_viewer, const Node& node,
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
if (!HasSupportedInputs(graph_viewer, node, wnn_limits, logger))
return false;

if (!HasSupportedOutputs(node, wnn_limits, logger))
Expand All @@ -38,22 +38,22 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
if (!HasSupportedOpSet(node, logger))
return false;

return IsOpSupportedImpl(initializers, node, device_type, logger);
return IsOpSupportedImpl(graph_viewer, node, device_type, logger);
}

bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool BaseOpBuilder::HasSupportedInputs(const GraphViewer& graph_viewer, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
if (!IsTensorShapeSupported(*input, node_name, logger, allow_empty_tensor_as_input_)) {
return false;
}
}

return HasSupportedInputsImpl(initializers, node, wnn_limits, logger);
return HasSupportedInputsImpl(graph_viewer, node, wnn_limits, logger);
}

bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
bool BaseOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ class BaseOpBuilder : public IOpBuilder {

// Operator support related.
public:
bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupported(const GraphViewer& graph_viewer, const Node& node,
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;

protected:
virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
virtual bool IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const Node& /* node */,
const WebnnDeviceType /* device_type */, const logging::Logger& /* logger */) const {
return true;
}

virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
virtual bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
Expand All @@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder {

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedInputs(const GraphViewer&, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;

const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BinaryOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -57,7 +57,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
return Status::OK();
}

bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const std::string_view op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ClipOpBuilder : public BaseOpBuilder {

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -61,15 +61,12 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Operator support related.

bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool ClipOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer,
const Node& node,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
// TODO: Update IsOpSupportedImpl to pass GraphViewer instead of InitializedTensorSet so the implementations
// can ensure initializers are constant. See #19401 for details of how this update was made to the NNAPI EP.
// GetClipMinMax(graph_viewer, node, minValue, maxValue, logger)
float min, max;
return GetClipMinMax(initializers, node, min, max, logger);
return GetClipMinMax(graph_viewer, node, min, max, logger);
}

void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class ConcatOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -54,7 +54,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const std::string_view op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class ConvOpBuilder : public BaseOpBuilder {

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupportedImpl(const GraphViewer&, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -344,7 +344,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N

// Operator support related.

bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool ConvOpBuilder::IsOpSupportedImpl(const GraphViewer&,
const Node& node,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
Expand Down Expand Up @@ -381,7 +381,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool ConvOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const std::string_view op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CumSumOpBuilder : public BaseOpBuilder {

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -70,7 +70,7 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
}

// Operator support related.
bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool CumSumOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer,
const Node& node,
WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
Expand All @@ -82,7 +82,8 @@ bool CumSumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers

const std::string axis_name = GetTensorName(input_defs, 1);
// Inputs contain optional 'axis' input.
if (!Contains(initializers, axis_name)) {
const auto* init = graph_viewer.GetConstantInitializer(axis_name);
if (init == nullptr) {
LOGS(logger, VERBOSE) << "The axis must be a constant initializer.";
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DropoutOpBuilder : public BaseOpBuilder {

// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
bool IsOpSupportedImpl(const GraphViewer&, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -73,7 +73,7 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}

// Operator support related.
bool DropoutOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
bool DropoutOpBuilder::IsOpSupportedImpl(const GraphViewer&,
const Node& node,
const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class EinsumOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool IsOpSupportedImpl(const GraphViewer&, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

Expand Down Expand Up @@ -694,7 +694,7 @@ Status EinsumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Operator support related.

bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
bool EinsumOpBuilder::IsOpSupportedImpl(const GraphViewer&,
const Node& node,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
Expand Down Expand Up @@ -734,7 +734,7 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}

bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
bool EinsumOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

Expand Down
Loading
Loading