diff --git a/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc index d435750221e70..5cadeb632c9c1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc @@ -33,7 +33,6 @@ class MultiHeadAttentionOpBuilder : public BaseOpBuilder { /** MultiHeadAttention SubGraph. Abbreviations: B is batch_size, S is sequence_length, W is hidden_size N is number of attention heads, H is head size - Notes: If the datatype of the inputs (qkv and past kv) is float16, we cast them to float32 to ensure data precision. query key value | | | @@ -74,14 +73,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu int32_t input_query_type = 0; ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_query_type, logger), "Could not get input data type."); - int32_t output_type = 0; - ORT_RETURN_IF_NOT(GetType(*node.OutputDefs()[0], output_type, logger), "Could not get input data type."); - - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/query_input"); - query_input = model_builder.GetBuilder().call("cast", query_input, emscripten::val("float32"), - common_options); - } std::vector input_q_shape, input_k_shape, input_v_shape; uint32_t batch_size, sequence_length, kv_sequence_length, hidden_size, head_size; @@ -91,12 +82,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu hidden_size = SafeInt(input_q_shape[2]); head_size = hidden_size / num_heads; key_input = model_builder.GetOperand(input_defs[1]->Name()); - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/key_input"); - key_input = model_builder.GetBuilder().call("cast", key_input, emscripten::val("float32"), - common_options); - } - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_k_shape, logger), "Cannot get key shape"); const auto k_rank = input_k_shape.size(); @@ -119,12 +104,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu k_reshape_skip = true; } value_input = model_builder.GetOperand(input_defs[2]->Name()); - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/value_input"); - value_input = model_builder.GetBuilder().call("cast", value_input, emscripten::val("float32"), - common_options); - } - ORT_RETURN_IF_NOT(GetShape(*input_defs[2], input_v_shape, logger), "Cannot get value shape"); const auto v_rank = input_v_shape.size(); if (v_rank == 3) { // Value with shape (batch_size, kv_sequence_length, v_hidden_size) @@ -149,13 +128,8 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu } emscripten::val attention_bias = emscripten::val::undefined(); - if (!TensorExists(input_defs, 5)) { + if (TensorExists(input_defs, 5)) { attention_bias = model_builder.GetOperand(input_defs[5]->Name()); - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/attention_bias"); - attention_bias = model_builder.GetBuilder().call("cast", attention_bias, - emscripten::val("float32"), common_options); - } } batch_size = SafeInt(input_q_shape[0]); @@ -188,12 +162,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu if (TensorExists(input_defs, 6)) { emscripten::val past_key_input = model_builder.GetOperand(input_defs[6]->Name()); - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/past_key_input"); - past_key_input = model_builder.GetBuilder().call("cast", past_key_input, - emscripten::val("float32"), common_options); - } - common_options.set("label", node.Name() + "_/MHA/key/concat"); std::vector inputs({past_key_input, present_key}); uint32_t axis = 2; @@ -220,11 +188,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu if (TensorExists(input_defs, 7)) { emscripten::val past_value_input = model_builder.GetOperand(input_defs[7]->Name()); - if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/preprocess/cast/past_value_input"); - past_value_input = model_builder.GetBuilder().call("cast", past_value_input, - emscripten::val("float32"), common_options); - } common_options.set("label", node.Name() + "_/MHA/value/concat"); std::vector inputs({past_value_input, present_value}); @@ -237,33 +200,17 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu } emscripten::val scale_constant = - model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, scale_value, {1}); + model_builder.CreateOrGetConstant(input_query_type, scale_value, {1}); emscripten::val output = ScaledDotProductAttention(model_builder, node, logger, new_query, new_key, present_value, scale_constant, attention_bias, reshape_output_shape); - - if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/postprocess/cast/output"); - output = - model_builder.GetBuilder().call("cast", output, emscripten::val("float16"), common_options); - } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); if (TensorExists(node.OutputDefs(), 1)) { - if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/postprocess/cast/present_key"); - present_key = model_builder.GetBuilder().call("cast", present_key, emscripten::val("float16"), - common_options); - } model_builder.AddOperand(node.OutputDefs()[1]->Name(), std::move(present_key)); } if (TensorExists(node.OutputDefs(), 2)) { - if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { - common_options.set("label", node.Name() + "_/MHA/postprocess/cast/present_value"); - present_value = model_builder.GetBuilder().call("cast", present_value, - emscripten::val("float16"), common_options); - } model_builder.AddOperand(node.OutputDefs()[2]->Name(), std::move(present_value)); } @@ -280,6 +227,7 @@ bool MultiHeadAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vie const uint32_t num_heads = helper.Get("num_heads", 0); if (num_heads == 0) { LOGS(logger, VERBOSE) << "Attributes num_heads is required."; + return false; } std::vector input_shape; @@ -310,6 +258,7 @@ bool MultiHeadAttentionOpBuilder::HasSupportedInputsImpl(const GraphViewer&, con if (TensorExists(input_defs, i)) { int32_t input_type = 0; if (!GetType(*input_defs[i], input_type, logger)) { + LOGS(logger, VERBOSE) << "Could not get input " << i << " data type."; return false; } input_types.push_back(input_type); @@ -335,26 +284,34 @@ bool MultiHeadAttentionOpBuilder::HasSupportedOutputsImpl(const Node& node, cons bool has_present_k = TensorExists(output_defs, 1); bool has_present_v = TensorExists(output_defs, 2); if (has_present_k != has_present_v) { // present_k and present_v must appear together. + LOGS(logger, VERBOSE) << op_type << " requires both present_k and present_v outputs."; return false; } int32_t output_type = 0; - if (has_present_k) { + if (!GetType(*output_defs[0], output_type, logger)) { + LOGS(logger, VERBOSE) << "Could not get output 0's data type."; + return false; + } + if (has_present_k && has_present_v) { int32_t present_k_type = 0; int32_t present_v_type = 0; - if (!GetType(*output_defs[0], output_type, logger) || !GetType(*output_defs[1], present_k_type, logger) || + if (!GetType(*output_defs[1], present_k_type, logger) || !GetType(*output_defs[2], present_v_type, logger)) { + LOGS(logger, VERBOSE) << "Could not get output 1 or 2's data type."; return false; } std::array output_types{output_type, present_k_type, present_v_type}; if (!AreDataTypesSame(op_type, output_types, logger)) { + LOGS(logger, VERBOSE) << op_type << " requires the data types of output 0, 1 and 2 to be the same."; return false; } } if (output_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT && output_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + LOGS(logger, VERBOSE) << op_type << " only supports float16 and float32 output types, but got " << output_type << "."; return false; } return true;