diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index de7f0a044c118..039c1dab16f3c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -892,6 +892,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) if add_before_layernorm is not None: start_node = add_before_layernorm + elif self.model.find_graph_input(normalize_node.input[0]) is not None: + # Pre-LN first block: LN fed directly by graph input. QKV matching will + # still fail from this (first) LN anchor because its inputs are weights, not + # the QKV projection path. The real fusion happens when fuse() is called + # again from the second LN/SkipLN anchor after the residual Add, where the + # other_inputs and root_input changes (#2-#4) take effect. + start_node = normalize_node else: return @@ -917,7 +924,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): other_inputs = [] for _i, node_input in enumerate(start_node.input): if node_input not in output_name_to_node: - continue + if self.model.find_graph_input(node_input) is None: + continue if node_input == qkv_nodes[0].output[0]: continue @@ -946,7 +954,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): root_input = mul_before_layernorm.output[0] else: return - elif normalize_node.op_type == "LayerNormalization": + elif normalize_node.op_type in ("LayerNormalization", "SkipLayerNormalization"): children = input_name_to_nodes[root_input] for child in children: if child.op_type == "LayerNormalization": @@ -961,9 +969,10 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # | | # | | # +---------------------------------------------------------------------+ - parent_node = output_name_to_node[root_input] - if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4: - root_input = parent_node.output[0] + if root_input in output_name_to_node: + parent_node = output_name_to_node[root_input] + if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4: + root_input = parent_node.output[0] children = input_name_to_nodes[root_input] children_types = [child.op_type for child in children] diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py index 0bb71bd8736d4..8cb6a29fe7ead 100644 --- a/onnxruntime/test/python/transformers/bert_model_generator.py +++ b/onnxruntime/test/python/transformers/bert_model_generator.py @@ -259,6 +259,203 @@ def create_bert_attention( return helper.make_model(graph, opset_imports=(opsetid,)) +def create_bert_attention_pre_ln( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=16, + pruned_v_hidden_size=16, + switch_add_inputs=False, +): + """Create a pre-layer-norm first block attention graph (no mask). + + Unlike post-LN, the first block of a pre-LN model has no Add before the + first LayerNormalization — the graph input feeds LN directly. The residual + skip connection adds the graph input (not the LN output) to the attention + output. No attention mask is included so the graph exercises the + ``is_no_mask_attention`` code path (Softmax -> Div -> MatMul). + + Graph structure:: + + input_1 -> LN -> MatMul Q/K/V -> ... -> Add(attn_out, input_1) -> LN -> output + """ + nodes = [ + # First LayerNormalization takes graph input directly (no preceding Add) + helper.make_node( + "LayerNormalization", + ["input_1", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), + # q nodes + helper.make_node("MatMul", ["layernorm_out", "matmul_q_weight"], ["matmul_q_out"], "matmul_q"), + helper.make_node( + "Add", + reverse_if(["matmul_q_out", "add_q_weight"], switch_add_inputs), + ["add_q_out"], + "add_q", + ), + helper.make_node( + "Reshape", + ["add_q_out", "reshape_weight_qk"], + ["reshape_q_out"], + "reshape_q", + ), + helper.make_node( + "Transpose", + ["reshape_q_out"], + ["transpose_q_out"], + "transpose_q", + perm=[0, 2, 1, 3], + ), + # k nodes + helper.make_node("MatMul", ["layernorm_out", "matmul_k_weight"], ["matmul_k_out"], "matmul_k"), + helper.make_node( + "Add", + reverse_if(["matmul_k_out", "add_k_weight"], switch_add_inputs), + ["add_k_out"], + "add_k", + ), + helper.make_node( + "Reshape", + ["add_k_out", "reshape_weight_qk"], + ["reshape_k_out"], + "reshape_k", + ), + helper.make_node( + "Transpose", + ["reshape_k_out"], + ["transpose_k_out"], + "transpose_k", + perm=[0, 2, 3, 1], + ), + # qk nodes (no mask — uses the is_no_mask_attention path: Softmax -> Div -> MatMul) + helper.make_node( + "MatMul", + ["transpose_q_out", "transpose_k_out"], + ["matmul_qk_out"], + "matmul_qk", + ), + helper.make_node("Div", ["matmul_qk_out", "div_weight"], ["div_qk_out"], "div_qk"), + helper.make_node("Softmax", ["div_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3), + # v nodes + helper.make_node("MatMul", ["layernorm_out", "matmul_v_weight"], ["matmul_v_out"], "matmul_v"), + helper.make_node("Add", ["matmul_v_out", "add_v_weight"], ["add_v_out"], "add_v"), + helper.make_node("Reshape", ["add_v_out", "reshape_weight_v"], ["reshape_v_out"], "reshape_v"), + helper.make_node( + "Transpose", + ["reshape_v_out"], + ["transpose_v_out"], + "transpose_v", + perm=[0, 2, 1, 3], + ), + # qkv nodes + helper.make_node( + "MatMul", + ["softmax_qk_out", "transpose_v_out"], + ["matmul_qkv_1_out"], + "matmul_qkv_1", + ), + helper.make_node( + "Transpose", + ["matmul_qkv_1_out"], + ["transpose_qkv_out"], + "transpose_qkv", + perm=[0, 2, 1, 3], + ), + helper.make_node( + "Reshape", + ["transpose_qkv_out", "reshape_weight_qkv"], + ["reshape_qkv_out"], + "reshape_qkv", + ), + helper.make_node( + "MatMul", + ["reshape_qkv_out", "matmul_qkv_weight"], + ["matmul_qkv_2_out"], + "matmul_qkv_2", + ), + helper.make_node( + "Add", + reverse_if(["matmul_qkv_2_out", "add_qkv_weight"], switch_add_inputs), + ["add_qkv_out"], + "add_qkv", + ), + # Residual skip: adds attention output with original graph input (not LN output) + helper.make_node( + "Add", + reverse_if(["add_qkv_out", "input_1"], switch_add_inputs), + ["skip_output"], + "add_skip", + ), + helper.make_node( + "LayerNormalization", + ["skip_output", "layer_norm_weight_2", "layer_norm_bias_2"], + ["output"], + "layernorm2", + axis=-1, + epsion=0.000009999999747378752, + ), + ] + + pruned_qk_head_size = int(pruned_qk_hidden_size / num_heads) + pruned_v_head_size = int(pruned_v_hidden_size / num_heads) + initializers = [ + float_tensor("layer_norm_weight", [input_hidden_size]), + float_tensor("layer_norm_bias", [input_hidden_size]), + float_tensor("layer_norm_weight_2", [input_hidden_size]), + float_tensor("layer_norm_bias_2", [input_hidden_size]), + float_tensor("matmul_q_weight", [input_hidden_size, pruned_qk_hidden_size]), + float_tensor("matmul_k_weight", [input_hidden_size, pruned_qk_hidden_size]), + float_tensor("matmul_v_weight", [input_hidden_size, pruned_v_hidden_size]), + float_tensor("matmul_qkv_weight", [pruned_v_hidden_size, input_hidden_size]), + float_tensor("add_q_weight", [pruned_qk_hidden_size]), + float_tensor("add_k_weight", [pruned_qk_hidden_size]), + float_tensor("add_v_weight", [pruned_v_hidden_size]), + float_tensor("add_qkv_weight", [input_hidden_size]), + helper.make_tensor("div_weight", TensorProto.FLOAT, [1], [math.sqrt(pruned_qk_head_size)]), + helper.make_tensor( + "reshape_weight_qk", + TensorProto.INT64, + [4], + [0, 0, num_heads, pruned_qk_head_size], + ), + helper.make_tensor( + "reshape_weight_v", + TensorProto.INT64, + [4], + [0, 0, num_heads, pruned_v_head_size], + ), + helper.make_tensor("reshape_weight_qkv", TensorProto.INT64, [3], [0, 0, pruned_v_hidden_size]), + ] + + batch_size = 1 + sequence_length = 3 + graph = helper.make_graph( + [node for node in nodes if node], + "PreLNAttentionFusion", + [ # inputs: only one embedding input (no preceding Add) + helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + ], + [ # outputs + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + ], + initializers, + ) + + opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(opsetid,)) + + def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False): # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). has_unsqueeze_two_inputs = version.parse(onnx.__version__) >= version.parse("1.8.0") diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index a74a872c5734b..02cffd5e3a3aa 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -5,11 +5,12 @@ # -------------------------------------------------------------------------- import os +import tempfile import unittest import onnx from bart_model_generator import create_bart_attention_sdpa -from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d +from bert_model_generator import create_bert_attention, create_bert_attention_pre_ln, create_tf2onnx_attention_3d from gpt2_model_generator import create_gpt2_attention from model_loader import get_test_data_path from parity_utilities import find_transformers_source @@ -152,6 +153,67 @@ def test_3d_attention_fusion_tf2onnx_model(self): self.verify_fusion(optimized_model, "bert_3d_attention_opt.onnx") + def test_attention_fusion_pre_ln(self): + """Test attention fusion for pre-layer-norm first block. + + In a pre-LN model the first block has no Add before the first + LayerNormalization — the graph input feeds LN directly. + """ + model = create_bert_attention_pre_ln() + dir = tempfile.mkdtemp() + model_path = os.path.join(dir, "pre_ln_attention.onnx") + onnx.save(model, model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options) + os.remove(model_path) + + attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"] + self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node") + num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None) + self.assertIsNotNone(num_heads_attr) + self.assertEqual(num_heads_attr.i, 2) + + def test_attention_fusion_pre_ln_reverse_add_order(self): + """Pre-LN fusion with reversed Add input ordering.""" + model = create_bert_attention_pre_ln(switch_add_inputs=True) + dir = tempfile.mkdtemp() + model_path = os.path.join(dir, "pre_ln_attention_reverse.onnx") + onnx.save(model, model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options) + os.remove(model_path) + + attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"] + self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node") + num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None) + self.assertIsNotNone(num_heads_attr) + self.assertEqual(num_heads_attr.i, 2) + + def test_attention_fusion_pre_ln_with_skiplayernorm(self): + """Pre-LN fusion when SkipLayerNorm fusion runs first (exercises Change 3). + + The optimizer runs fuse_skip_layer_norm before fuse_attention. When enabled, + the Add + LayerNorm after the residual becomes a SkipLayerNormalization node, + and attention fusion must handle that anchor type. + """ + model = create_bert_attention_pre_ln() + dir = tempfile.mkdtemp() + model_path = os.path.join(dir, "pre_ln_attention_skiplayernorm.onnx") + onnx.save(model, model_path) + options = FusionOptions("bert") + options.use_raw_attention_mask(True) + options.enable_skip_layer_norm = True + optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options) + os.remove(model_path) + + attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"] + self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node with SkipLN anchor") + num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None) + self.assertIsNotNone(num_heads_attr) + self.assertEqual(num_heads_attr.i, 2) + def test_gpt2_attention_fusion(self): hidden_size = 64 num_heads = 4