71 changes: 14 additions & 57 deletions onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc
@@ -33,7 +33,6 @@ class MultiHeadAttentionOpBuilder : public BaseOpBuilder {
/** MultiHeadAttention SubGraph.
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size
N is number of attention heads, H is head size
- Notes: If the datatype of the inputs (qkv and past kv) is float16, we cast them to float32 to ensure data precision.

query key value
| | |
@@ -74,14 +73,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu

int32_t input_query_type = 0;
ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_query_type, logger), "Could not get input data type.");
- int32_t output_type = 0;
- ORT_RETURN_IF_NOT(GetType(*node.OutputDefs()[0], output_type, logger), "Could not get input data type.");
-
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/query_input");
- query_input = model_builder.GetBuilder().call<emscripten::val>("cast", query_input, emscripten::val("float32"),
- common_options);
- }

std::vector<int64_t> input_q_shape, input_k_shape, input_v_shape;
uint32_t batch_size, sequence_length, kv_sequence_length, hidden_size, head_size;
@@ -91,12 +82,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
hidden_size = SafeInt<uint32_t>(input_q_shape[2]);
head_size = hidden_size / num_heads;
key_input = model_builder.GetOperand(input_defs[1]->Name());
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/key_input");
- key_input = model_builder.GetBuilder().call<emscripten::val>("cast", key_input, emscripten::val("float32"),
- common_options);
- }
-
ORT_RETURN_IF_NOT(GetShape(*input_defs[1], input_k_shape, logger), "Cannot get key shape");
const auto k_rank = input_k_shape.size();

@@ -119,12 +104,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
k_reshape_skip = true;
}
value_input = model_builder.GetOperand(input_defs[2]->Name());
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/value_input");
- value_input = model_builder.GetBuilder().call<emscripten::val>("cast", value_input, emscripten::val("float32"),
- common_options);
- }
-
ORT_RETURN_IF_NOT(GetShape(*input_defs[2], input_v_shape, logger), "Cannot get value shape");
const auto v_rank = input_v_shape.size();
if (v_rank == 3) { // Value with shape (batch_size, kv_sequence_length, v_hidden_size)
@@ -149,13 +128,8 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
}

emscripten::val attention_bias = emscripten::val::undefined();
- if (!TensorExists(input_defs, 5)) {
+ if (TensorExists(input_defs, 5)) {
attention_bias = model_builder.GetOperand(input_defs[5]->Name());
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/attention_bias");
- attention_bias = model_builder.GetBuilder().call<emscripten::val>("cast", attention_bias,
- emscripten::val("float32"), common_options);
- }
}

batch_size = SafeInt<uint32_t>(input_q_shape[0]);
@@ -188,12 +162,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu

if (TensorExists(input_defs, 6)) {
emscripten::val past_key_input = model_builder.GetOperand(input_defs[6]->Name());
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/past_key_input");
- past_key_input = model_builder.GetBuilder().call<emscripten::val>("cast", past_key_input,
- emscripten::val("float32"), common_options);
- }
-
common_options.set("label", node.Name() + "_/MHA/key/concat");
std::vector<emscripten::val> inputs({past_key_input, present_key});
uint32_t axis = 2;
@@ -220,11 +188,6 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu

if (TensorExists(input_defs, 7)) {
emscripten::val past_value_input = model_builder.GetOperand(input_defs[7]->Name());
- if (input_query_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/preprocess/cast/past_value_input");
- past_value_input = model_builder.GetBuilder().call<emscripten::val>("cast", past_value_input,
- emscripten::val("float32"), common_options);
- }

common_options.set("label", node.Name() + "_/MHA/value/concat");
std::vector<emscripten::val> inputs({past_value_input, present_value});
@@ -237,33 +200,17 @@ Status MultiHeadAttentionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu
}

emscripten::val scale_constant =
- model_builder.CreateOrGetConstant<float>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, scale_value, {1});
+ model_builder.CreateOrGetConstant<float>(input_query_type, scale_value, {1});

emscripten::val output = ScaledDotProductAttention(model_builder, node, logger, new_query, new_key, present_value,
scale_constant, attention_bias, reshape_output_shape);

- if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/postprocess/cast/output");
- output =
- model_builder.GetBuilder().call<emscripten::val>("cast", output, emscripten::val("float16"), common_options);
- }
model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));

if (TensorExists(node.OutputDefs(), 1)) {
- if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/postprocess/cast/present_key");
- present_key = model_builder.GetBuilder().call<emscripten::val>("cast", present_key, emscripten::val("float16"),
- common_options);
- }
model_builder.AddOperand(node.OutputDefs()[1]->Name(), std::move(present_key));
}

if (TensorExists(node.OutputDefs(), 2)) {
- if (output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
- common_options.set("label", node.Name() + "_/MHA/postprocess/cast/present_value");
- present_value = model_builder.GetBuilder().call<emscripten::val>("cast", present_value,
- emscripten::val("float16"), common_options);
- }
model_builder.AddOperand(node.OutputDefs()[2]->Name(), std::move(present_value));
}

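Note on the deletions above: they all remove a single pattern. When the query input was float16, every operand (query, key, value, attention bias, past key and value) was first cast to float32, and every output was cast back to float16, matching the subgraph note removed at the top of the file. The scale constant is likewise now created with input_query_type instead of a hardcoded float32, so the subgraph keeps the source data type end to end. Below is a minimal sketch of the removed wrapper, assuming an emscripten build where builder is a bound WebNN MLGraphBuilder; the helper name CastIfFp16 is illustrative and not from the PR:

```cpp
#include <emscripten/val.h>

#include <string>

// Sketch of the cast wrapper this diff deletes at each input (illustrative
// helper, not part of the PR). WebNN's MLGraphBuilder.cast(input, type, options)
// re-types an operand; the old builder applied it around every float16 tensor.
emscripten::val CastIfFp16(emscripten::val builder, emscripten::val operand,
                           bool input_is_fp16, const std::string& label) {
  if (!input_is_fp16) return operand;  // float32 graphs were passed through as-is
  emscripten::val options = emscripten::val::object();
  options.set("label", label);  // e.g. "<node>_/MHA/preprocess/cast/query_input"
  return builder.call<emscripten::val>("cast", operand,
                                       emscripten::val("float32"), options);
}
```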
@@ -280,6 +227,7 @@ bool MultiHeadAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vie
const uint32_t num_heads = helper.Get("num_heads", 0);
if (num_heads == 0) {
LOGS(logger, VERBOSE) << "Attributes num_heads is required.";
+ return false;
}

std::vector<int64_t> input_shape;
@@ -310,6 +258,7 @@ bool MultiHeadAttentionOpBuilder::HasSupportedInputsImpl(const GraphViewer&, con
if (TensorExists(input_defs, i)) {
int32_t input_type = 0;
if (!GetType(*input_defs[i], input_type, logger)) {
LOGS(logger, VERBOSE) << "Could not get input " << i << " data type.";
return false;
}
input_types.push_back(input_type);
@@ -335,26 +284,34 @@ bool MultiHeadAttentionOpBuilder::HasSupportedOutputsImpl(const Node& node, cons
bool has_present_k = TensorExists(output_defs, 1);
bool has_present_v = TensorExists(output_defs, 2);
if (has_present_k != has_present_v) { // present_k and present_v must appear together.
LOGS(logger, VERBOSE) << op_type << " requires both present_k and present_v outputs.";
return false;
}

int32_t output_type = 0;
- if (has_present_k) {
+ if (!GetType(*output_defs[0], output_type, logger)) {
+ LOGS(logger, VERBOSE) << "Could not get output 0's data type.";
+ return false;
+ }
+ if (has_present_k && has_present_v) {
int32_t present_k_type = 0;
int32_t present_v_type = 0;
- if (!GetType(*output_defs[0], output_type, logger) || !GetType(*output_defs[1], present_k_type, logger) ||
+ if (!GetType(*output_defs[1], present_k_type, logger) ||
!GetType(*output_defs[2], present_v_type, logger)) {
LOGS(logger, VERBOSE) << "Could not get output 1 or 2's data type.";
return false;
}

std::array<int32_t, 3> output_types{output_type, present_k_type, present_v_type};
if (!AreDataTypesSame(op_type, output_types, logger)) {
LOGS(logger, VERBOSE) << op_type << " requires the data types of output 0, 1 and 2 to be the same.";
return false;
}
}

if (output_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
output_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
LOGS(logger, VERBOSE) << op_type << " only supports float16 and float32 output types, but got " << output_type << ".";
return false;
}
return true;
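
The rework of HasSupportedOutputsImpl above also changes when output 0's type is read: previously it was only fetched inside the present_k/present_v branch, while the final float/float16 gate ran unconditionally, so a node without present outputs appears to have been rejected with an unset type. Now the type is always fetched, each failure path logs a reason, and only float32/float16 pass. Below is a condensed, standalone sketch of the resulting decision logic; the function and optional-based signature are stand-ins for ORT's GetType/AreDataTypesSame utilities, and 1 and 10 are the ONNX TensorProto_DataType codes for FLOAT and FLOAT16:

```cpp
#include <cstdint>
#include <optional>

constexpr int32_t kFloat = 1;    // ONNX_NAMESPACE::TensorProto_DataType_FLOAT
constexpr int32_t kFloat16 = 10; // ONNX_NAMESPACE::TensorProto_DataType_FLOAT16

// Mirror of the post-change validation order (simplified stand-in, not ORT code).
bool OutputsSupported(std::optional<int32_t> output0,
                      std::optional<int32_t> present_k,
                      std::optional<int32_t> present_v) {
  if (present_k.has_value() != present_v.has_value()) return false;  // must appear together
  if (!output0.has_value()) return false;  // output 0's type is now always required
  if (present_k.has_value() &&
      (*present_k != *output0 || *present_v != *output0)) {
    return false;  // output 0, present_k and present_v must share one data type
  }
  return *output0 == kFloat || *output0 == kFloat16;  // only fp32/fp16 supported
}
```

Under this sketch, OutputsSupported(kFloat16, std::nullopt, std::nullopt) accepts a float16 node with no present outputs, which the pre-change flow seems to have rejected because output_type was never populated.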