diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 0487f7761352f..8a403e3318f04 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -93,11 +93,14 @@ def __init__( hidden_size: int, num_heads: int, attention_mask: AttentionMask, + use_multi_head_attention: bool = False, ): - super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"]) + attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention" + super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"]) self.hidden_size = hidden_size self.num_heads = num_heads self.attention_mask = attention_mask + self.use_multi_head_attention = use_multi_head_attention # Flags to show warning only once self.num_heads_warning = True @@ -108,18 +111,18 @@ def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[ Detect num_heads and hidden_size from Concat node in the following subgraph: SkipLayerNormalization or EmbedLayerNormalization - / \ - MatMul Shape - | | - Add Gather(indices=0) - \ | - \ Unsqueeze - \ | - \ Concat (*, -1, 12, 64) - \ / - Reshape - | - Transpose + / | + MatMul Shape + | | + Add Gather(indices=0) + | | + | Unsqueeze + | | + | Concat (*, -1, 12, 64) + | / + Reshape + | + Transpose """ if len(concat.input) == 4: num_heads = self.model.get_constant_value(concat.input[2]) @@ -307,17 +310,18 @@ def create_attention_node( attention_node_name = self.model.create_node_name("Attention") - weight = helper.make_tensor( - name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, - dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), - ) + if not self.use_multi_head_attention: + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) - # Sometimes weights and bias are stored in fp16 - if q_weight.data_type == 10: - weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) - self.model.add_initializer(weight, self.this_graph_name) + # Sometimes weights and bias are stored in fp16 + if q_weight.data_type == 10: + weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) + self.model.add_initializer(weight, self.this_graph_name) bias = helper.make_tensor( name=attention_node_name + "_qkv_bias", @@ -329,26 +333,48 @@ def create_attention_node( bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias, self.this_graph_name) - attention_inputs = [ - input, - attention_node_name + "_qkv_weight", - attention_node_name + "_qkv_bias", - ] - if mask_index is not None: - attention_inputs.append(mask_index) + # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights. 
+ if self.use_multi_head_attention: + if add_qk_str is not None: + logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.") + return None + + attention_inputs = [ + q_matmul.output[0], + k_matmul.output[0], + v_matmul.output[0], + attention_node_name + "_qkv_bias", + ] + if mask_index is not None: + attention_inputs.append(mask_index) + + attention_node = helper.make_node( + "MultiHeadAttention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) else: - attention_inputs.append("") + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] + if mask_index is not None: + attention_inputs.append(mask_index) + else: + attention_inputs.append("") - if add_qk_str is not None: - attention_inputs.append("") - attention_inputs.append(add_qk_str) + if add_qk_str is not None: + attention_inputs.append("") # no past + attention_inputs.append(add_qk_str) - attention_node = helper.make_node( - "Attention", - inputs=attention_inputs, - outputs=[output], - name=attention_node_name, - ) + attention_node = helper.make_node( + "Attention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) @@ -595,10 +621,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) self.nodes_to_remove.extend(qk_nodes) - self.nodes_to_remove.extend(q_nodes) - self.nodes_to_remove.extend(k_nodes) - self.nodes_to_remove.extend(v_nodes) + + # For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused. + self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1]) + self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1]) + self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1]) # Use prune graph to remove mask nodes since they are shared by all attention nodes. - # self.nodes_to_remove.extend(mask_nodes) self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index e0cfc15a33d4f..f4ae184bdf825 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -63,48 +63,64 @@ def check_attention_subgraph( self.attention = self.model.find_first_child_by_type( layernorm, "Attention", input_name_to_nodes, recursive=False ) - if self.attention is None: - # In case user disables attention fusion, check whether subgraph looks like Attention. 
- if layernorm.output[0] not in input_name_to_nodes: + + if self.attention is not None: + return True + + if layernorm.output[0] not in input_name_to_nodes: + return False + children = input_name_to_nodes[layernorm.output[0]] + children_types = sorted([child.op_type for child in children]) + + # Try find MultiHeadAttention + if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]: + for node in children: + if node.op_type == "SkipLayerNormalization": + path1 = self.model.match_parent_path( + node, + ["Add", "MatMul", "MultiHeadAttention", "MatMul"], + [None, None, 0, 0], + ) + if path1 is not None and path1[-1].input[0] == layernorm.output[0]: + self.cross_attention = path1[2] + return True + + # In case user disables attention fusion, check whether subgraph looks like Attention. + # For Albert, there is MatMul+Add after embedding layer before attention. + if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes: + grandchildren = input_name_to_nodes[children[0].output[0]] + if ( + len(grandchildren) == 1 + and grandchildren[0].op_type == "Add" + and grandchildren[0].output[0] in input_name_to_nodes + ): + nodes = input_name_to_nodes[grandchildren[0].output[0]] + for node in nodes: + if node.op_type == "Attention": + self.attention = node + return True + children_types = sorted([child.op_type for child in nodes]) + + # Two Shape nodes might be merged by ORT + if is_distil_bert: + # SkipLayerNormailization might exist when model has been optimized by ORT first. + if ( + children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"] + and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"] + and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"] + ): + logger.debug("No Attention like subgraph in children of LayerNormalization") return False - children = input_name_to_nodes[layernorm.output[0]] - - # For Albert, there is MatMul+Add after embedding layer before attention. - if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes: - grandchildren = input_name_to_nodes[children[0].output[0]] - if ( - len(grandchildren) == 1 - and grandchildren[0].op_type == "Add" - and grandchildren[0].output[0] in input_name_to_nodes - ): - nodes = input_name_to_nodes[grandchildren[0].output[0]] - for node in nodes: - if node.op_type == "Attention": - self.attention = node - return True - children_types = sorted([child.op_type for child in nodes]) - else: - children_types = sorted([child.op_type for child in children]) - - # Two Shape nodes might be merged by ORT - if is_distil_bert: - # SkipLayerNormailization might exist when model has been optimized by ORT first. 
- if ( - children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"] - and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"] - and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"] - ): - logger.debug("No Attention like subgraph in children of LayerNormalization") - return False - else: - if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [ - "MatMul", - "MatMul", - "MatMul", - "SkipLayerNormalization", - ]: - logger.debug("No Attention like subgraph in children of LayerNormalization") - return False + else: + if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [ + "MatMul", + "MatMul", + "MatMul", + "SkipLayerNormalization", + ]: + logger.debug("No Attention like subgraph in children of LayerNormalization") + return False + return True def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node): @@ -713,11 +729,15 @@ def replace_mask(self, mask_int32, attention_nodes): for attention_node in attention_nodes: logger.debug("update mask_index in %s", attention_node.name) - attention_node.input[3] = embed_node.output[1] + if attention_node.op_type == "Attention": + attention_node.input[3] = embed_node.output[1] + elif attention_node.op_type == "MultiHeadAttention": + attention_node.input[4] = embed_node.output[1] def fuse(self, node, input_name_to_nodes, output_name_to_node): # Reset attention and embed_node so that we know fusion is successful when they are not None. self.attention = None + self.cross_attention = None self.embed_node = None super().fuse(node, input_name_to_nodes, output_name_to_node) @@ -729,15 +749,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.increase_counter("EmbedLayerNormalization(no mask)") return - if self.attention is None: + if self.attention is None and self.cross_attention is None: logger.debug("EmbedLayerNormalization will not have mask since attention node is not found") self.increase_counter("EmbedLayerNormalization(no mask)") return - mask_int32 = self.attention.input[3] + if self.attention: + mask_int32 = self.attention.input[3] + else: + mask_int32 = self.cross_attention.input[4] + children_nodes = input_name_to_nodes[mask_int32] if self.model.find_graph_input(mask_int32): - attention_nodes = [node for node in children_nodes if node.op_type == "Attention"] + attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]] self.replace_mask(mask_int32, attention_nodes) self.increase_counter("EmbedLayerNormalization(with mask)") return @@ -749,7 +773,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): node = output_name_to_node[mask_int32] if node.op_type in ["ReduceSum", "Cast"]: - attention_nodes = [node for node in children_nodes if node.op_type == "Attention"] + attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]] if node.op_type == "ReduceSum": mask_int32 = node.input[0] if len(children_nodes) == len(attention_nodes): diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 2473da8004ff2..9a5359b58caa6 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -19,6 +19,14 @@ def __init__(self, model_type): self.enable_gelu = True self.enable_layer_norm = True self.enable_attention = True + + # Use 
MultiHeadAttention instead of Attention operator. The difference: + # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is + # merged into one. + # (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention. + # (3) MultiHeadAttention has only cuda implementation right now. + self.use_multi_head_attention = False + self.enable_skip_layer_norm = True self.enable_embed_layer_norm = True self.enable_bias_skip_layer_norm = True @@ -48,6 +56,8 @@ def parse(args): options.enable_layer_norm = False if args.disable_attention: options.enable_attention = False + if args.use_multi_head_attention: + options.use_multi_head_attention = True if args.disable_skip_layer_norm: options.enable_skip_layer_norm = False if args.disable_embed_layer_norm: @@ -165,3 +175,13 @@ def add_arguments(parser: ArgumentParser): help="no attention mask. Only works for model_type=bert", ) parser.set_defaults(no_attention_mask=False) + + parser.add_argument( + "--use_multi_head_attention", + required=False, + action="store_true", + help="Use MultiHeadAttention instead of Attention operator for testing purpose. " + "Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. " + "MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.", + ) + parser.set_defaults(use_multi_head_attention=False) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 60bd2d7b8e2dc..81c83d222529f 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -385,9 +385,14 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if options is not None: + self.attention_mask.set_mask_format(options.attention_mask_format) + if options.use_multi_head_attention: + self.attention_fusion = FusionAttention( + self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + ) + if (options is None) or options.enable_attention: - if options is not None: - self.attention_mask.set_mask_format(options.attention_mask_format) self.fuse_attention() # Perform the MatMul fusion after the Attention fusion as we do not @@ -438,6 +443,7 @@ def get_fused_operator_statistics(self): ops = [ "EmbedLayerNormalization", "Attention", + "MultiHeadAttention", "Gelu", "FastGelu", "BiasGelu", @@ -459,7 +465,7 @@ def is_fully_optimized(self): """ op_count = self.get_fused_operator_statistics() embed = op_count["EmbedLayerNormalization"] - attention = op_count["Attention"] + op_count["QOrderedAttention"] + attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention) diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 33656abaeee34..74d20295a0a63 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -33,6 +33,17 @@ def 
verify_fusion(self, optimized_model, expected_model_filename): self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph)) + def test_multi_head_attention_fusion(self): + model = create_bert_attention() + dir = "." + model_path = os.path.join(dir, "attention.onnx") + onnx.save(model, model_path) + options = FusionOptions("bert") + options.use_multi_head_attention = True + optimized_model = optimize_model(model_path, optimization_options=options) + os.remove(model_path) + self.verify_fusion(optimized_model, "attention_mha.onnx") + def test_attention_fusion(self): model = create_bert_attention() dir = "." diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx new file mode 100644 index 0000000000000..76d808538e0e4 Binary files /dev/null and b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx differ
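
Usage note: the new `use_multi_head_attention` option is exercised in `test_multi_head_attention_fusion` above; the sketch below shows the same flow outside the unit test. This is a minimal sketch, assuming the tools are used through the installed `onnxruntime.transformers` package; the file names `bert.onnx` and `bert_mha.onnx` are placeholders, not part of this change.

    # Minimal sketch: enable the new MultiHeadAttention fusion programmatically,
    # mirroring test_multi_head_attention_fusion in this diff.
    from onnxruntime.transformers.fusion_options import FusionOptions
    from onnxruntime.transformers.optimizer import optimize_model

    options = FusionOptions("bert")
    # Fuse attention subgraphs into MultiHeadAttention instead of Attention;
    # per the comment added in fusion_options.py, Q/K/V projection MatMuls stay
    # outside the fused node and the op currently has only a CUDA implementation.
    options.use_multi_head_attention = True

    optimized_model = optimize_model(
        "bert.onnx",                     # placeholder input model path
        model_type="bert",
        optimization_options=options,
    )
    optimized_model.save_model_to_file("bert_mha.onnx")  # placeholder output path

From the command line, the same behavior is reached via the `--use_multi_head_attention` flag added to the optimizer arguments in this patch; as the help text notes, the resulting model is expected to run only with the CUDA execution provider.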