diff --git a/onnxruntime/core/optimizer/attention_fusion.cc b/onnxruntime/core/optimizer/attention_fusion.cc index 9fd71b3b00cd0..7fe7c914fa796 100644 --- a/onnxruntime/core/optimizer/attention_fusion.cc +++ b/onnxruntime/core/optimizer/attention_fusion.cc @@ -2,15 +2,673 @@ // Licensed under the MIT License. #include "core/graph/graph_utils.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/attention_fusion.h" #include "core/optimizer/utils.h" #include "core/optimizer/attention_fusion_helper.h" -#include "core/graph/graph_utils.h" #include +#include namespace onnxruntime { +static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size); + +namespace { + +static bool ValidateAddBiasInitializerEitherInput(const Graph& graph, const Node& add, int64_t hidden_size) { + if (add.InputDefs().size() < 2) { + return false; + } + + const NodeArg& input_0 = *(add.InputDefs()[0]); + const NodeArg& input_1 = *(add.InputDefs()[1]); + const bool input_0_is_bias = graph_utils::IsInitializer(graph, input_0.Name(), true) && + optimizer_utils::ValidateShape(input_0, {hidden_size}); + const bool input_1_is_bias = graph_utils::IsInitializer(graph, input_1.Name(), true) && + optimizer_utils::ValidateShape(input_1, {hidden_size}); + return input_0_is_bias || input_1_is_bias; +} + +static bool ValidateProjectionGemmInitializer(const Graph& graph, const Node& gemm, int64_t hidden_size) { + if (gemm.InputDefs().size() < 3) { + return false; + } + + if (const auto* alpha_attr = graph_utils::GetNodeAttribute(gemm, "alpha"); + alpha_attr && std::abs(alpha_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* beta_attr = graph_utils::GetNodeAttribute(gemm, "beta"); + beta_attr && std::abs(beta_attr->f() - 1.0f) > 1e-6f) { + return false; + } + + if (const auto* trans_a_attr = graph_utils::GetNodeAttribute(gemm, "transA"); + trans_a_attr && trans_a_attr->i() != 0) { + return false; + } + + if (const auto* trans_b_attr = graph_utils::GetNodeAttribute(gemm, "transB"); + trans_b_attr && trans_b_attr->i() != 0) { + return false; + } + + const NodeArg& input_b = *(gemm.InputDefs()[1]); + const NodeArg& input_c = *(gemm.InputDefs()[2]); + if (!graph_utils::IsInitializer(graph, input_b.Name(), true) || + !graph_utils::IsInitializer(graph, input_c.Name(), true)) { + return false; + } + + return optimizer_utils::ValidateShape(input_b, {hidden_size, hidden_size}) && + optimizer_utils::ValidateShape(input_c, {hidden_size}); +} + +// Most attention fusions require all matched nodes to already be assigned to an execution provider +// that supports the fused op. MobileClipMHA is also matched before partitioning in graph-transform +// tests, so nodes may still be unassigned here. Accept nodes that are either unassigned or already +// assigned to a compatible provider, and preserve the original provider string on the fused nodes +// once the pattern is rewritten. +static bool IsSupportedOrUnassignedNode(const Node& node, + const InlinedHashSet& compatible_execution_providers) { + return node.GetExecutionProviderType().empty() || + graph_utils::IsSupportedProvider(node, compatible_execution_providers); +} + +static bool IsSupportedOrUnassignedNode(const Node& node, + std::string_view required_execution_provider) { + const auto& execution_provider = node.GetExecutionProviderType(); + return execution_provider.empty() || + execution_provider == required_execution_provider; +} + +static bool AreSupportedOrUnassignedNodes( + const Node& anchor_node, + const std::initializer_list& nodes, + const InlinedHashSet& compatible_execution_providers) { + if (!IsSupportedOrUnassignedNode(anchor_node, compatible_execution_providers)) { + return false; + } + + const auto& required_execution_provider = anchor_node.GetExecutionProviderType(); + for (const Node* node : nodes) { + if (node == nullptr) { + continue; + } + + if (!IsSupportedOrUnassignedNode(*node, required_execution_provider)) { + return false; + } + } + + return true; +} + +static bool HasExpectedPerm(const Node& node, const std::initializer_list& expected_perm) { + return optimizer_utils::IsAttributeWithExpectedValues(node, "perm", std::vector(expected_perm)); +} + +static bool HasExpectedAxesInput(const Graph& graph, const Node& node, const std::initializer_list& expected_axes) { + if (node.InputDefs().size() < 2) { + return false; + } + + InlinedVector axes; + if (!optimizer_utils::AppendTensorFromInitializer(graph, *node.InputDefs()[1], axes, true)) { + return false; + } + + return axes == InlinedVector(expected_axes.begin(), expected_axes.end()); +} + +static bool TryGetMobileClipQkvReshapeInfo(const Graph& graph, const Node& qkv_reshape, + int64_t& num_heads, int64_t& head_size, int64_t& hidden_size) { + if (qkv_reshape.InputDefs().size() < 2) { + return false; + } + + InlinedVector reshape_dims; + if (!optimizer_utils::AppendTensorFromInitializer(graph, *qkv_reshape.InputDefs()[1], reshape_dims, true)) { + return false; + } + + if (reshape_dims.size() != 5 || reshape_dims[2] != 3 || reshape_dims[3] <= 0 || reshape_dims[4] <= 0) { + return false; + } + + num_heads = reshape_dims[3]; + head_size = reshape_dims[4]; + + try { + hidden_size = SafeInt(num_heads) * head_size; + } catch (const OnnxRuntimeException&) { + return false; + } + + return hidden_size > 0; +} + +static std::optional TryCreateMobileClipMhaOutputType(const NodeArg& qkv_output, + int64_t hidden_size) { + const auto* qkv_output_type = qkv_output.TypeAsProto(); + if (qkv_output_type == nullptr || !qkv_output_type->has_tensor_type()) { + return std::nullopt; + } + + ONNX_NAMESPACE::TypeProto mha_output_type(*qkv_output_type); + auto* shape = mha_output_type.mutable_tensor_type()->mutable_shape(); + if (shape->dim_size() > 0) { + auto* last_dim = shape->mutable_dim(shape->dim_size() - 1); + last_dim->clear_dim_param(); + last_dim->set_dim_value(hidden_size); + } + + return mha_output_type; +} + +static Node* GetOnlyChildByOutputIndex(Graph& graph, const Node& node, size_t output_index, const char* child_op_type) { + const auto output_edges = graph_utils::GraphEdge::GetNodeOutputEdges(node, output_index); + if (output_edges.size() != 1) { + return nullptr; + } + + Node* child = graph.GetNode(output_edges[0].dst_node); + if (child == nullptr || child->OpType() != child_op_type) { + return nullptr; + } + + return child; +} + +static bool TryCreateNormalizedProjectionGemm(Graph& graph, + NodeArg& projection_input, + const NodeArg& original_projection_input, + const NodeArg& proj_weight, + const NodeArg& proj_bias, + NodeArg& projection_output, + const std::string& base_name, + const std::string& provider_type) { + const auto* proj_input_shape = original_projection_input.Shape(); + const auto* proj_weight_shape = proj_weight.Shape(); + if (proj_input_shape == nullptr || proj_weight_shape == nullptr || proj_weight_shape->dim_size() != 2) { + return false; + } + + auto input_shape = utils::GetTensorShapeFromTensorShapeProto(*proj_input_shape); + if (input_shape.Size() == -1 || input_shape.NumDimensions() < 2) { + return false; + } + + const auto& dim_k = proj_weight_shape->dim(0); + const auto& dim_n = proj_weight_shape->dim(1); + if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) { + return false; + } + + const int64_t m = input_shape.SizeToDimension(input_shape.NumDimensions() - 1); + if (m <= 0) { + return false; + } + + const int64_t k = dim_k.dim_value(); + const int64_t n = dim_n.dim_value(); + if (input_shape[input_shape.NumDimensions() - 1] != k) { + return false; + } + + const auto* bias_shape = proj_bias.Shape(); + if (bias_shape == nullptr || bias_shape->dim_size() != 1 || !utils::HasDimValue(bias_shape->dim(0)) || + bias_shape->dim(0).dim_value() != n) { + return false; + } + + const auto* input_type = original_projection_input.TypeAsProto(); + if (input_type == nullptr || !input_type->has_tensor_type()) { + return false; + } + + const auto element_type = static_cast(input_type->tensor_type().elem_type()); + + auto add_shape_initializer = [&](const std::string& name, const InlinedVector& shape) -> NodeArg& { + ONNX_NAMESPACE::TensorProto shape_initializer_proto; + shape_initializer_proto.set_name(graph.GenerateNodeArgName(name)); + shape_initializer_proto.add_dims(static_cast(shape.size())); + shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + const size_t shape_bytes = SafeInt(shape.size()) * sizeof(int64_t); + utils::SetRawDataInTensorProto(shape_initializer_proto, shape.data(), shape_bytes); + return graph_utils::AddInitializerWithOrtValue(graph, shape_initializer_proto); + }; + + auto make_tensor_arg = [&](const std::string& name, const InlinedVector& shape) -> NodeArg* { + ONNX_NAMESPACE::TypeProto type_proto; + type_proto.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t dim_value : shape) { + type_proto.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim_value); + } + + return &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name), &type_proto); + }; + + InlinedVector gemm_input_shape{m, k}; + InlinedVector gemm_output_shape{m, n}; + InlinedVector output_shape_values = input_shape.AsShapeVector(); + output_shape_values.back() = n; + + NodeArg* gemm_input_arg = make_tensor_arg("mobileclip_proj_gemm_input", gemm_input_shape); + NodeArg* gemm_output_arg = make_tensor_arg("mobileclip_proj_gemm_output", gemm_output_shape); + NodeArg& gemm_input_shape_arg = add_shape_initializer("mobileclip_proj_gemm_input_shape", gemm_input_shape); + NodeArg& gemm_output_shape_arg = add_shape_initializer("mobileclip_proj_gemm_output_shape", output_shape_values); + + Node& input_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmInputReshape"), + "Reshape", + "Reshape MobileCLIP projection input for Gemm", + {&projection_input, &gemm_input_shape_arg}, + {gemm_input_arg}); + input_reshape.SetExecutionProviderType(provider_type); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeName(base_name + "/MobileClipProjectionGemm"), + "Gemm", + "Normalized MobileCLIP projection Gemm", + {gemm_input_arg, const_cast(&proj_weight), const_cast(&proj_bias)}, + {gemm_output_arg}); + gemm_node.SetExecutionProviderType(provider_type); + + Node& output_reshape = graph.AddNode( + graph.GenerateNodeName("MobileClipProjGemmOutputReshape"), + "Reshape", + "Restore MobileCLIP projection output shape after Gemm", + {gemm_output_arg, &gemm_output_shape_arg}, + {&projection_output}); + output_reshape.SetExecutionProviderType(provider_type); + + return true; +} + +static bool TryRewriteProjectionMatMulAddToGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_matmul, + Node& proj_add) { + if (proj_matmul.InputDefs().size() < 2 || proj_add.InputDefs().size() < 2) { + return false; + } + + const int bias_idx = proj_matmul.OutputDefs()[0]->Name() == proj_add.InputDefs()[0]->Name() ? 1 : 0; + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_matmul.InputDefs()[0], + *proj_matmul.InputDefs()[1], + *proj_add.InputDefs()[bias_idx], + *proj_add.MutableOutputDefs()[0], + proj_matmul.Name(), + proj_matmul.GetExecutionProviderType()); +} + +static bool TryRewriteProjectionGemm(Graph& graph, + NodeArg& projection_input, + Node& proj_gemm) { + if (proj_gemm.InputDefs().size() < 3 || proj_gemm.OutputDefs().empty()) { + return false; + } + + return TryCreateNormalizedProjectionGemm(graph, + projection_input, + *proj_gemm.InputDefs()[0], + *proj_gemm.InputDefs()[1], + *proj_gemm.InputDefs()[2], + *proj_gemm.MutableOutputDefs()[0], + proj_gemm.Name(), + proj_gemm.GetExecutionProviderType()); +} + +static bool TryFuseMobileClipMHA(Node& qkv_matmul, + Graph& graph, + const InlinedHashSet& compatible_execution_providers, + const logging::Logger& logger) { + const auto fail = [&](const char* message) { + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() << "]: fusion skipped: " << message; + return false; + }; + + if (!graph_utils::IsSupportedOptypeVersionAndDomain(qkv_matmul, "MatMul", {1, 9, 13}, kOnnxDomain)) { + return false; + } + + if (!IsSupportedOrUnassignedNode(qkv_matmul, compatible_execution_providers)) { + return false; + } + + if (!optimizer_utils::CheckOutputEdges(graph, qkv_matmul, 1) || qkv_matmul.InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, qkv_matmul.InputDefs()[1]->Name(), true)) { + return fail("qkv MatMul output count or weight initializer check failed"); + } + + const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0); + if (sequence_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*sequence_transpose, {0, 2, 1}) || + !optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) { + return false; + } + + const Node* input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0); + if (input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) { + return fail("missing input Reshape before sequence transpose"); + } + + Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape"); + if (qkv_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) { + return fail("qkv Reshape after MatMul not matched"); + } + + Node* split = GetOnlyChildByOutputIndex(graph, *qkv_reshape, 0, "Split"); + if (split == nullptr || !graph_utils::IsSupportedOptypeVersionAndDomain(*split, "Split", {13, 18}, kOnnxDomain) || + split->OutputDefs().size() != 3 || !optimizer_utils::IsAttributeWithExpectedValue(*split, "axis", static_cast(2))) { + return fail("qkv Split(axis=2, outputs=3) not matched"); + } + + Node* q_transpose = GetOnlyChildByOutputIndex(graph, *split, 0, "Transpose"); + Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze"); + Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose"); + if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) || + !HasExpectedAxesInput(graph, *k_squeeze, {2})) { + return fail("q/k/v branch entry pattern after Split not matched"); + } + + Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze"); + Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze"); + if (q_squeeze == nullptr || v_squeeze == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) || + !HasExpectedAxesInput(graph, *q_squeeze, {0}) || + !HasExpectedAxesInput(graph, *v_squeeze, {0})) { + return fail("q/v squeeze pattern not matched"); + } + + Node* q_scale_mul = GetOnlyChildByOutputIndex(graph, *q_squeeze, 0, "Mul"); + Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose"); + if (q_scale_mul == nullptr || k_transpose == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) || + !graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) { + return fail("q scale Mul or k Transpose(0,2,3,1) not matched"); + } + + float scale = 0.0f; + if (q_scale_mul->InputDefs().size() < 2) { + return fail("q scale constant not found"); + } + + const NodeArg* q_squeeze_output = q_squeeze->OutputDefs()[0]; + const NodeArg* mul_input_0 = q_scale_mul->InputDefs()[0]; + const NodeArg* mul_input_1 = q_scale_mul->InputDefs()[1]; + const bool input_0_is_q_squeeze = mul_input_0 != nullptr && q_squeeze_output != nullptr && + mul_input_0->Name() == q_squeeze_output->Name(); + const bool input_1_is_q_squeeze = mul_input_1 != nullptr && q_squeeze_output != nullptr && + mul_input_1->Name() == q_squeeze_output->Name(); + + const NodeArg* scale_input = nullptr; + if (input_0_is_q_squeeze && !input_1_is_q_squeeze) { + scale_input = mul_input_1; + } else if (input_1_is_q_squeeze && !input_0_is_q_squeeze) { + scale_input = mul_input_0; + } + + if (scale_input == nullptr || + !optimizer_utils::GetScalarInitializerValue(graph, *scale_input, scale, true)) { + return fail("q scale constant not found"); + } + + Node* qk_matmul = GetOnlyChildByOutputIndex(graph, *q_scale_mul, 0, "MatMul"); + if (qk_matmul == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qk_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qk_matmul, 1) == nullptr || + graph_utils::GetInputNode(*qk_matmul, 1)->Index() != k_transpose->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qk_matmul, 1)) { + return fail("qk MatMul not matched"); + } + + Node* softmax = GetOnlyChildByOutputIndex(graph, *qk_matmul, 0, "Softmax"); + if (softmax == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*softmax, "Softmax", {1, 11, 13}, kOnnxDomain) || + !optimizer_utils::IsAttributeWithExpectedValue(*softmax, "axis", static_cast(-1)) || + !optimizer_utils::CheckOutputEdges(graph, *softmax, 1)) { + return fail("Softmax(axis=-1) not matched"); + } + + Node* qkv_matmul_1 = GetOnlyChildByOutputIndex(graph, *softmax, 0, "MatMul"); + if (qkv_matmul_1 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_matmul_1, "MatMul", {1, 9, 13}, kOnnxDomain) || + graph_utils::GetInputNode(*qkv_matmul_1, 1) == nullptr || + graph_utils::GetInputNode(*qkv_matmul_1, 1)->Index() != v_squeeze->Index() || + !optimizer_utils::CheckOutputEdges(graph, *qkv_matmul_1, 1)) { + return fail("attention-value MatMul not matched"); + } + + Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose"); + if (transpose_3 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) || + !HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) || + !optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) { + return fail("output Transpose(0,2,1,3) not matched"); + } + + Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape"); + if (reshape_2 == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) { + return fail("output Reshape not matched"); + } + + Node* proj_matmul = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "MatMul"); + Node* proj_gemm = proj_matmul == nullptr ? GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Gemm") : nullptr; + Node* proj_gemm_input_reshape = nullptr; + Node* proj_gemm_output_reshape = nullptr; + Node* proj_add = nullptr; + + if (proj_matmul != nullptr) { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_matmul, "MatMul", {1, 9, 13}, kOnnxDomain) || + proj_matmul->InputDefs().size() < 2 || + !graph_utils::IsInitializer(graph, proj_matmul->InputDefs()[1]->Name(), true) || + !optimizer_utils::CheckOutputEdges(graph, *proj_matmul, 1)) { + return fail("projection MatMul not matched"); + } + + proj_add = GetOnlyChildByOutputIndex(graph, *proj_matmul, 0, "Add"); + if (proj_add == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_add, "Add", {7, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_add, 1)) { + return fail("projection Add not matched"); + } + } else { + if (proj_gemm == nullptr) { + proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape"); + if (proj_gemm_input_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm = GetOnlyChildByOutputIndex(graph, *proj_gemm_input_reshape, 0, "Gemm"); + if (proj_gemm == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + + proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape"); + if (proj_gemm_output_reshape == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) { + return fail("normalized projection Gemm output Reshape not matched"); + } + } else if (!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm, "Gemm", {7, 9, 11, 13}, kOnnxDomain) || + !optimizer_utils::CheckOutputEdges(graph, *proj_gemm, 1)) { + return fail("projection MatMul/Gemm not matched"); + } + } + + int64_t num_heads = 0; + int64_t head_size = 0; + int64_t hidden_size = 0; + if (!TryGetMobileClipQkvReshapeInfo(graph, *qkv_reshape, num_heads, head_size, hidden_size)) { + return fail("unable to derive num_heads/head_size from qkv reshape initializer"); + } + + if (proj_matmul != nullptr) { + if (!ValidateMatMulInitializer(graph, *proj_matmul, hidden_size) || + !ValidateAddBiasInitializerEitherInput(graph, *proj_add, hidden_size)) { + return fail("projection weight/bias shape validation failed"); + } + } else { + if (!ValidateProjectionGemmInitializer(graph, *proj_gemm, hidden_size)) { + return fail("projection Gemm weight/bias shape validation failed"); + } + } + + const NodeArg& qkv_weight = *qkv_matmul.InputDefs()[1]; + if (!optimizer_utils::ValidateShape(qkv_weight, {hidden_size, 3 * hidden_size})) { + return fail("qkv weight shape is not [hidden, 3*hidden]"); + } + + if (!AreSupportedOrUnassignedNodes( + qkv_matmul, + {sequence_transpose, + input_reshape, + qkv_reshape, + split, + q_transpose, + k_squeeze, + v_transpose, + q_squeeze, + v_squeeze, + q_scale_mul, + k_transpose, + qk_matmul, + softmax, + qkv_matmul_1, + transpose_3, + reshape_2, + proj_matmul, + proj_add, + proj_gemm_input_reshape, + proj_gemm, + proj_gemm_output_reshape}, + compatible_execution_providers)) { + return fail("matched nodes are assigned to incompatible execution providers"); + } + + auto mha_output_type = TryCreateMobileClipMhaOutputType(*qkv_matmul.OutputDefs()[0], hidden_size); + auto* mha_output = &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("mobileclip_mha_output"), + mha_output_type ? &*mha_output_type : nullptr); + + if (proj_matmul != nullptr) { + if (!TryRewriteProjectionMatMulAddToGemm(graph, *mha_output, *proj_matmul, *proj_add)) { + return fail("projection MatMul/Add could not be rewritten to Gemm"); + } + } else if (proj_gemm_input_reshape == nullptr) { + if (!TryRewriteProjectionGemm(graph, *mha_output, *proj_gemm)) { + return fail("projection Gemm could not be normalized"); + } + } + + ONNX_NAMESPACE::TensorProto split_sizes_tensor; + split_sizes_tensor.set_name(graph.GenerateNodeArgName("mobileclip_mha_split_sizes")); + split_sizes_tensor.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_sizes_tensor.add_dims(3); + const std::array split_sizes{hidden_size, hidden_size, hidden_size}; + utils::SetRawDataInTensorProto(split_sizes_tensor, split_sizes.data(), split_sizes.size() * sizeof(int64_t)); + NodeArg& split_sizes_arg = graph_utils::AddInitializerWithOrtValue(graph, split_sizes_tensor); + + auto* mha_q = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_q"), nullptr); + auto* mha_k = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_k"), nullptr); + auto* mha_v = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("mobileclip_mha_v"), nullptr); + + Node& split_for_mha = graph.AddNode( + graph.GenerateNodeName("MobileClipSplitForMHA"), + "Split", + "Split packed MobileCLIP QKV for MultiHeadAttention", + {qkv_matmul.MutableOutputDefs()[0], &split_sizes_arg}, + {mha_q, mha_k, mha_v}, + nullptr, + kOnnxDomain); + split_for_mha.AddAttribute("axis", static_cast(2)); + + Node& mha_node = graph.AddNode( + graph.GenerateNodeName("MobileClipMultiHeadAttention"), + "MultiHeadAttention", + "Fused MobileCLIP attention subgraph", + {mha_q, mha_k, mha_v}, + {mha_output}, + nullptr, + kMSDomain); + mha_node.AddAttribute("num_heads", num_heads); + mha_node.AddAttribute("scale", scale); + + const auto& provider = qkv_matmul.GetExecutionProviderType(); + split_for_mha.SetExecutionProviderType(provider); + mha_node.SetExecutionProviderType(provider); + + if (proj_gemm_input_reshape != nullptr) { + graph_utils::ReplaceDownstreamNodeInput(graph, *reshape_2, 0, mha_node, 0); + } + + std::vector nodes_to_remove{ + qkv_reshape->Index(), + split->Index(), + q_transpose->Index(), + q_squeeze->Index(), + q_scale_mul->Index(), + k_squeeze->Index(), + k_transpose->Index(), + qk_matmul->Index(), + softmax->Index(), + v_transpose->Index(), + v_squeeze->Index(), + qkv_matmul_1->Index(), + transpose_3->Index(), + reshape_2->Index(), + }; + + if (proj_matmul != nullptr) { + nodes_to_remove.push_back(proj_matmul->Index()); + nodes_to_remove.push_back(proj_add->Index()); + } else if (proj_gemm_input_reshape == nullptr) { + nodes_to_remove.push_back(proj_gemm->Index()); + } + + for (const auto& node_index : nodes_to_remove) { + Node* node = graph.GetNode(node_index); + if (node == nullptr) { + continue; + } + + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(node_index); + } + + LOGS(logger, VERBOSE) << "MobileClipMHA[" << qkv_matmul.Name() + << "]: fused MobileCLIP attention subgraph to MultiHeadAttention"; + + return true; +} + +} // namespace + static bool ValidateMatMulInitializer(const Graph& graph, const Node& matmul, int64_t hidden_size) { const NodeArg& input_b = *(matmul.InputDefs()[1]); if (!graph_utils::IsInitializer(graph, input_b.Name(), true)) { @@ -179,6 +837,12 @@ Status AttentionFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, Node& node = *p_node; ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + if (TryFuseMobileClipMHA(node, graph, GetCompatibleExecutionProviders(), logger)) { + fused_count++; + modified = true; + continue; + } + // Add node.GetOutputEdgesCount() == 5/6 for distilbert if ((node.GetOutputEdgesCount() >= 2 && node.GetOutputEdgesCount() <= 6) && graph_utils::IsSupportedOptypeVersionAndDomain(node, "LayerNormalization", {1, 17}, kOnnxDomain) && diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 18933e45b8922..75ba3b802f9ae 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5826,6 +5826,363 @@ TEST_F(GraphTransformationTests, AttentionFusionDistilBertTest) { EXPECT_EQ(op_to_count["Shape"], 0); } +enum class MobileClipProjectionType { + MatMulAdd, + GemmWithReshapes, +}; + +struct MobileClipAttentionShapeConfig { + int64_t input_channels = 512; + int64_t hidden_size = 512; + int64_t num_heads = 16; + int64_t head_size = 32; + int64_t qkv_weight_input_dim = 512; +}; + +static void BuildMobileClipAttentionTestCase(ModelTestBuilder& builder, + MobileClipProjectionType projection_type, + const MobileClipAttentionShapeConfig& shape_config = {}, + bool use_non_default_projection_gemm_attributes = false, + bool use_runtime_projection_shape_input = false) { + const int64_t input_channels = shape_config.input_channels; + const int64_t hidden_size = shape_config.hidden_size; + const int64_t num_heads = shape_config.num_heads; + const int64_t head_size = shape_config.head_size; + const int64_t qkv_weight_input_dim = shape_config.qkv_weight_input_dim; + const int64_t qkv_hidden_size = num_heads * head_size; + const int64_t qkv_output_size = 3 * qkv_hidden_size; + + auto* input_x = builder.MakeInput({1, input_channels, 8, 8}, -1.0f, 1.0f); + auto* input_skip = builder.MakeInput({1, hidden_size, 8, 8}, -1.0f, 1.0f); + + auto* reshape0_shape = builder.Make1DInitializer({1, input_channels, 64}); + auto* qkv_weight = builder.MakeInitializer({qkv_weight_input_dim, qkv_output_size}, -0.05f, 0.05f); + auto* qkv_reshape_shape = builder.Make1DInitializer({1, 64, 3, num_heads, head_size}); + auto* split_sizes = builder.Make1DInitializer({1, 1, 1}); + auto* squeeze_axis_0 = builder.Make1DInitializer({0}); + auto* squeeze_axis_2 = builder.Make1DInitializer({2}); + auto* scale = builder.MakeScalarInitializer(1.0f / std::sqrt(static_cast(head_size))); + auto* reshape2_shape = use_runtime_projection_shape_input + ? builder.MakeInput({3}, {1, 64, hidden_size}) + : builder.Make1DInitializer({1, 64, hidden_size}); + auto* proj_gemm_input_shape = builder.Make1DInitializer({64, hidden_size}); + auto* proj_weight = builder.MakeInitializer({hidden_size, hidden_size}, -0.05f, 0.05f); + auto* proj_bias = builder.MakeInitializer({hidden_size}, -0.02f, 0.02f); + auto* proj_gemm_output_shape = builder.Make1DInitializer({1, 64, hidden_size}); + auto* reshape3_shape = builder.Make1DInitializer({1, hidden_size, 8, 8}); + auto* layer_scale = builder.MakeInitializer({hidden_size, 1, 1}, 0.9f, 1.1f); + + auto* reshape0_out = builder.MakeIntermediate(std::vector{1, input_channels, 64}); + auto* transpose0_out = builder.MakeIntermediate(std::vector{1, 64, input_channels}); + auto* qkv_out = builder.MakeIntermediate(std::vector{1, 64, qkv_output_size}); + auto* qkv_reshape_out = builder.MakeIntermediate(std::vector{1, 64, 3, num_heads, head_size}); + auto* split_q = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_k = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* split_v = builder.MakeIntermediate(std::vector{1, 64, 1, num_heads, head_size}); + auto* q_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* q_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* k_squeeze_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* k_transpose_out = builder.MakeIntermediate(std::vector{1, num_heads, head_size, 64}); + auto* q_scaled_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* qk_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* softmax_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, 64}); + auto* v_transpose_out = builder.MakeIntermediate(std::vector{1, 1, num_heads, 64, head_size}); + auto* v_squeeze_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* attn_out = builder.MakeIntermediate(std::vector{1, num_heads, 64, head_size}); + auto* transpose3_out = builder.MakeIntermediate(std::vector{1, 64, num_heads, head_size}); + auto* reshape2_out = use_runtime_projection_shape_input + ? builder.MakeIntermediate(std::nullopt) + : builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* proj_gemm_input_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_gemm_out = builder.MakeIntermediate(std::vector{64, hidden_size}); + auto* proj_linear_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + auto* transpose4_out = builder.MakeIntermediate(std::vector{1, hidden_size, 64}); + auto* reshape3_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* layer_scale_out = builder.MakeIntermediate(std::vector{1, hidden_size, 8, 8}); + auto* output = builder.MakeOutput(std::vector{1, hidden_size, 8, 8}); + + auto& reshape0 = builder.AddNode("Reshape", std::vector{input_x, reshape0_shape}, std::vector{reshape0_out}); + reshape0.AddAttribute("allowzero", static_cast(0)); + + auto& transpose0 = builder.AddNode("Transpose", std::vector{reshape0_out}, std::vector{transpose0_out}); + transpose0.AddAttribute("perm", std::vector{0, 2, 1}); + + builder.AddNode("MatMul", std::vector{transpose0_out, qkv_weight}, std::vector{qkv_out}); + + auto& qkv_reshape = builder.AddNode("Reshape", std::vector{qkv_out, qkv_reshape_shape}, std::vector{qkv_reshape_out}); + qkv_reshape.AddAttribute("allowzero", static_cast(0)); + + auto& split = builder.AddNode("Split", std::vector{qkv_reshape_out, split_sizes}, std::vector{split_q, split_k, split_v}); + split.AddAttribute("axis", static_cast(2)); + + auto& q_transpose = builder.AddNode("Transpose", std::vector{split_q}, std::vector{q_transpose_out}); + q_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{q_transpose_out, squeeze_axis_0}, std::vector{q_squeeze_out}); + builder.AddNode("Squeeze", std::vector{split_k, squeeze_axis_2}, std::vector{k_squeeze_out}); + + auto& k_transpose = builder.AddNode("Transpose", std::vector{k_squeeze_out}, std::vector{k_transpose_out}); + k_transpose.AddAttribute("perm", std::vector{0, 2, 3, 1}); + + builder.AddNode("Mul", std::vector{q_squeeze_out, scale}, std::vector{q_scaled_out}); + builder.AddNode("MatMul", std::vector{q_scaled_out, k_transpose_out}, std::vector{qk_out}); + + auto& softmax = builder.AddNode("Softmax", std::vector{qk_out}, std::vector{softmax_out}); + softmax.AddAttribute("axis", static_cast(-1)); + + auto& v_transpose = builder.AddNode("Transpose", std::vector{split_v}, std::vector{v_transpose_out}); + v_transpose.AddAttribute("perm", std::vector{2, 0, 3, 1, 4}); + + builder.AddNode("Squeeze", std::vector{v_transpose_out, squeeze_axis_0}, std::vector{v_squeeze_out}); + builder.AddNode("MatMul", std::vector{softmax_out, v_squeeze_out}, std::vector{attn_out}); + + auto& transpose3 = builder.AddNode("Transpose", std::vector{attn_out}, std::vector{transpose3_out}); + transpose3.AddAttribute("perm", std::vector{0, 2, 1, 3}); + + auto& reshape2 = builder.AddNode("Reshape", std::vector{transpose3_out, reshape2_shape}, std::vector{reshape2_out}); + reshape2.AddAttribute("allowzero", static_cast(0)); + + if (projection_type == MobileClipProjectionType::GemmWithReshapes) { + auto& proj_gemm_input = builder.AddNode("Reshape", std::vector{reshape2_out, proj_gemm_input_shape}, + std::vector{proj_gemm_input_out}); + proj_gemm_input.AddAttribute("allowzero", static_cast(0)); + + auto& proj_gemm = builder.AddNode("Gemm", std::vector{proj_gemm_input_out, proj_weight, proj_bias}, + std::vector{proj_gemm_out}); + if (use_non_default_projection_gemm_attributes) { + proj_gemm.AddAttribute("transB", static_cast(1)); + } + + auto& proj_gemm_output = builder.AddNode("Reshape", std::vector{proj_gemm_out, proj_gemm_output_shape}, + std::vector{proj_linear_out}); + proj_gemm_output.AddAttribute("allowzero", static_cast(0)); + } else { + auto* proj_matmul_out = builder.MakeIntermediate(std::vector{1, 64, hidden_size}); + builder.AddNode("MatMul", std::vector{reshape2_out, proj_weight}, std::vector{proj_matmul_out}); + builder.AddNode("Add", std::vector{proj_bias, proj_matmul_out}, std::vector{proj_linear_out}); + } + + auto& transpose4 = builder.AddNode("Transpose", std::vector{proj_linear_out}, std::vector{transpose4_out}); + transpose4.AddAttribute("perm", std::vector{0, 2, 1}); + + auto& reshape3 = builder.AddNode("Reshape", std::vector{transpose4_out, reshape3_shape}, std::vector{reshape3_out}); + reshape3.AddAttribute("allowzero", static_cast(0)); + + builder.AddNode("Mul", std::vector{layer_scale, reshape3_out}, std::vector{layer_scale_out}); + builder.AddNode("Add", std::vector{input_skip, layer_scale_out}, std::vector{output}); +} + +static Status CheckMobileClipAttentionFusedGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int mha_nodes = 0; + int gemm_nodes = 0; + int split_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() == "MultiHeadAttention" && node.Domain() == kMSDomain) { + ++mha_nodes; + TEST_RETURN_IF_NOT(node.GetAttributes().at("num_heads").i() == 16); + TEST_RETURN_IF_NOT(std::abs(node.GetAttributes().at("scale").f() - (1.0f / std::sqrt(32.0f))) < 1e-6f); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 3); + } else if (node.OpType() == "Split") { + ++split_nodes; + } else if (node.OpType() == "Gemm") { + ++gemm_nodes; + TEST_RETURN_IF_NOT(node.InputDefs().size() == 3); + TEST_RETURN_IF_NOT(node.OutputDefs().size() == 1); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.InputDefs()[0]->Shape()->dim_size() == 2); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape() != nullptr); + TEST_RETURN_IF_NOT(node.OutputDefs()[0]->Shape()->dim_size() == 2); + + const Node* gemm_input_node = graph_utils::GetInputNode(node, 0); + TEST_RETURN_IF_NOT(gemm_input_node != nullptr); + TEST_RETURN_IF_NOT(gemm_input_node->OpType() == "Reshape"); + + bool has_output_reshape = false; + for (const Node& consumer : graph.Nodes()) { + for (const NodeArg* input_def : consumer.InputDefs()) { + if (input_def != nullptr && input_def->Name() == node.OutputDefs()[0]->Name()) { + has_output_reshape = consumer.OpType() == "Reshape"; + break; + } + } + + if (has_output_reshape) { + break; + } + } + + TEST_RETURN_IF_NOT(has_output_reshape); + } + } + + TEST_RETURN_IF_NOT(mha_nodes == 1); + TEST_RETURN_IF_NOT(gemm_nodes == 1); + TEST_RETURN_IF_NOT(split_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionFusedGraphOnProvider(Graph& graph, const char* provider) { + ORT_RETURN_IF_ERROR(CheckMobileClipAttentionFusedGraph(graph)); + + for (Node& node : graph.Nodes()) { + TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == provider); + } + + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedProjectionGemmGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 1); + + int gemm_nodes = 0; + for (Node& node : graph.Nodes()) { + if (node.OpType() != "Gemm") { + continue; + } + + ++gemm_nodes; + const auto& attrs = node.GetAttributes(); + auto trans_b_attr = attrs.find("transB"); + TEST_RETURN_IF_NOT(trans_b_attr != attrs.end()); + TEST_RETURN_IF_NOT(trans_b_attr->second.i() == 1); + } + + TEST_RETURN_IF_NOT(gemm_nodes == 1); + return Status::OK(); +} + +static Status CheckMobileClipAttentionUnfusedMatMulGraph(Graph& graph) { + auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count["com.microsoft.MultiHeadAttention"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Gemm"] == 0); + TEST_RETURN_IF_NOT(op_to_count["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_to_count["Squeeze"] == 3); + TEST_RETURN_IF_NOT(op_to_count["Split"] == 1); + TEST_RETURN_IF_NOT(op_to_count["MatMul"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Transpose"] == 6); + TEST_RETURN_IF_NOT(op_to_count["Reshape"] == 4); + TEST_RETURN_IF_NOT(op_to_count["Mul"] == 2); + TEST_RETURN_IF_NOT(op_to_count["Add"] == 2); + return Status::OK(); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionFusedGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmCudaEpTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes); + }; + + auto pre_graph_checker = [](Graph& graph) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + return CheckMobileClipAttentionFusedGraphOnProvider(graph, kCudaExecutionProvider); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 14, *logger_, std::make_unique(InlinedHashSet{kCudaExecutionProvider}), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaInvalidQkvWeightShapeTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, + MobileClipProjectionType::MatMulAdd, + MobileClipAttentionShapeConfig{512, 510, 15, 34, 512}); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, CheckMobileClipAttentionUnfusedMatMulGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionGemmNonDefaultAttributesTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::GemmWithReshapes, {}, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedProjectionGemmGraph)); +} + +TEST_F(GraphTransformationTests, AttentionFusionMobileClipMhaProjectionRewriteFailureLeavesGraphUnfusedTest) { + auto build_test_case = [](ModelTestBuilder& builder) { + BuildMobileClipAttentionTestCase(builder, MobileClipProjectionType::MatMulAdd, {}, false, true); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, nullptr, + CheckMobileClipAttentionUnfusedMatMulGraph)); +} + TEST_F(GraphTransformationTests, GeluFusionTest) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/gelu.onnx"; std::shared_ptr p_model;