diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 142d64caa64aa..9056ea6d86f65 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -226,13 +226,27 @@ bool AreDataTypesSame(const std::string_view op_type, return true; } -bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { +bool IsSupportedDataType(const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view webnn_input_output_name) { auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); if (it == onnx_to_webnn_data_type_map.end()) return false; const std::string_view webnn_data_type = it->second; + // MLOpSupportLimits has different structure. Certain WebNN ops have input and output name, + // special cases like 'constant', 'input' and 'output' have no input or output name. + emscripten::val webnn_supported_data_types = + webnn_input_output_name.empty() + ? wnn_limits[std::string(webnn_op_type)]["dataTypes"] + : wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"]; + + if (webnn_supported_data_types.isUndefined()) { + return false; + } + // Check if WebNN supports the data type. bool is_supported = webnn_supported_data_types.call("includes", emscripten::val(std::string(webnn_data_type))) @@ -240,7 +254,8 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we if (webnn_data_type == "int64" && !is_supported && - webnn_supported_data_types.call("includes", emscripten::val("int32")).as()) { + webnn_supported_data_types.call("includes", emscripten::val("int32")).as() && + !wnn_limits["constant"]["dataTypes"].call("includes", emscripten::val("int64")).as()) { // Current context doesn't support int64, but int32 is supported. // We can use int32 as a workaround. is_supported = true; @@ -280,8 +295,7 @@ bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type, << webnn_input_output_name << "]"; return false; } - if (!IsSupportedDataType( - onnx_data_type, wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"])) { + if (!IsSupportedDataType(onnx_data_type, wnn_limits, webnn_op_type, webnn_input_output_name)) { LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index da1fac6d1ad05..baedb98a34c28 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -268,7 +268,10 @@ inline bool GetWebNNOpInputs(const std::string_view onnx_op_type, bool AreDataTypesSame(const std::string_view op_type, gsl::span input_types, const logging::Logger& logger); -bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsSupportedDataType(const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string_view webnn_op_type, + const std::string_view webnn_input_output_name); bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type, const int32_t onnx_data_type, const emscripten::val& wnn_limits, diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 5ee4a9daa1407..d12806cbcfbb1 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -229,7 +229,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("shape", emscripten::val::array(dims)); const auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { + if (IsSupportedDataType(data_type, wnn_limits_, "constant", "")) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "WebNN backend does not support data type: ", data_type); ORT_RETURN_IF_ERROR(RegisterConstant(tensor, operand, desc, logger_)); } else {