Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions onnxruntime/python/tools/transformers/fusion_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)
if add_before_layernorm is not None:
start_node = add_before_layernorm
elif self.model.find_graph_input(normalize_node.input[0]) is not None:
# Pre-LN first block: LN fed directly by graph input. QKV matching will
# still fail from this (first) LN anchor because its inputs are weights, not
# the QKV projection path. The real fusion happens when fuse() is called
# again from the second LN/SkipLN anchor after the residual Add, where the
# other_inputs and root_input changes (#2-#4) take effect.
start_node = normalize_node
else:
return

Expand All @@ -917,7 +924,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
other_inputs = []
for _i, node_input in enumerate(start_node.input):
if node_input not in output_name_to_node:
continue
if self.model.find_graph_input(node_input) is None:
continue

if node_input == qkv_nodes[0].output[0]:
continue
Expand Down Expand Up @@ -946,7 +954,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
root_input = mul_before_layernorm.output[0]
else:
return
elif normalize_node.op_type == "LayerNormalization":
elif normalize_node.op_type in ("LayerNormalization", "SkipLayerNormalization"):
children = input_name_to_nodes[root_input]
for child in children:
if child.op_type == "LayerNormalization":
Expand All @@ -961,9 +969,10 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
# | |
# | |
# +---------------------------------------------------------------------+
parent_node = output_name_to_node[root_input]
if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
root_input = parent_node.output[0]
if root_input in output_name_to_node:
parent_node = output_name_to_node[root_input]
if parent_node.op_type == "SkipLayerNormalization" and len(parent_node.output) == 4:
root_input = parent_node.output[0]

children = input_name_to_nodes[root_input]
children_types = [child.op_type for child in children]
Expand Down
197 changes: 197 additions & 0 deletions onnxruntime/test/python/transformers/bert_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,203 @@ def create_bert_attention(
return helper.make_model(graph, opset_imports=(opsetid,))


def create_bert_attention_pre_ln(
    input_hidden_size=16,
    num_heads=2,
    pruned_qk_hidden_size=16,
    pruned_v_hidden_size=16,
    switch_add_inputs=False,
):
    """Create a pre-layer-norm first-block attention graph (no mask).

    Unlike post-LN, the first block of a pre-LN model has no Add before the
    first LayerNormalization: the graph input feeds LN directly. The residual
    skip connection adds the graph input (not the LN output) to the attention
    output. No attention mask is included, so the QK path in graph order is
    MatMul -> Div -> Softmax with no mask Add in between.

    Graph structure::

        input_1 -> LN -> MatMul Q/K/V -> ... -> Add(attn_out, input_1) -> LN -> output

    Args:
        input_hidden_size: Size of the last dimension of ``input_1`` and ``output``.
        num_heads: Number of attention heads written into the Reshape shapes.
        pruned_qk_hidden_size: Output width of the Q and K projections.
        pruned_v_hidden_size: Output width of the V projection.
        switch_add_inputs: When True, the inputs of the Q/K bias Adds, the
            output-projection bias Add and the residual Add are passed through
            ``reverse_if`` with a true flag. The V bias Add is never switched.

    Returns:
        The model produced by ``helper.make_model`` for the graph above, with the
        ``ai.onnx`` opset capped at 16.
    """
    # float32 representation of 1e-5, shared by both LayerNormalization nodes.
    layer_norm_epsilon = 0.000009999999747378752

    nodes = [
        # First LayerNormalization takes the graph input directly (no preceding Add).
        helper.make_node(
            "LayerNormalization",
            ["input_1", "layer_norm_weight", "layer_norm_bias"],
            ["layernorm_out"],
            "layernorm",
            axis=-1,
            # The attribute must be spelled "epsilon"; any other keyword becomes an
            # unknown attribute on the node and the intended value is not applied.
            epsilon=layer_norm_epsilon,
        ),
        # q nodes
        helper.make_node("MatMul", ["layernorm_out", "matmul_q_weight"], ["matmul_q_out"], "matmul_q"),
        helper.make_node(
            "Add",
            reverse_if(["matmul_q_out", "add_q_weight"], switch_add_inputs),
            ["add_q_out"],
            "add_q",
        ),
        helper.make_node(
            "Reshape",
            ["add_q_out", "reshape_weight_qk"],
            ["reshape_q_out"],
            "reshape_q",
        ),
        helper.make_node(
            "Transpose",
            ["reshape_q_out"],
            ["transpose_q_out"],
            "transpose_q",
            perm=[0, 2, 1, 3],
        ),
        # k nodes
        helper.make_node("MatMul", ["layernorm_out", "matmul_k_weight"], ["matmul_k_out"], "matmul_k"),
        helper.make_node(
            "Add",
            reverse_if(["matmul_k_out", "add_k_weight"], switch_add_inputs),
            ["add_k_out"],
            "add_k",
        ),
        helper.make_node(
            "Reshape",
            ["add_k_out", "reshape_weight_qk"],
            ["reshape_k_out"],
            "reshape_k",
        ),
        helper.make_node(
            "Transpose",
            ["reshape_k_out"],
            ["transpose_k_out"],
            "transpose_k",
            perm=[0, 2, 3, 1],
        ),
        # qk nodes (no mask: MatMul -> Div -> Softmax)
        helper.make_node(
            "MatMul",
            ["transpose_q_out", "transpose_k_out"],
            ["matmul_qk_out"],
            "matmul_qk",
        ),
        helper.make_node("Div", ["matmul_qk_out", "div_weight"], ["div_qk_out"], "div_qk"),
        helper.make_node("Softmax", ["div_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3),
        # v nodes
        helper.make_node("MatMul", ["layernorm_out", "matmul_v_weight"], ["matmul_v_out"], "matmul_v"),
        helper.make_node("Add", ["matmul_v_out", "add_v_weight"], ["add_v_out"], "add_v"),
        helper.make_node("Reshape", ["add_v_out", "reshape_weight_v"], ["reshape_v_out"], "reshape_v"),
        helper.make_node(
            "Transpose",
            ["reshape_v_out"],
            ["transpose_v_out"],
            "transpose_v",
            perm=[0, 2, 1, 3],
        ),
        # qkv nodes
        helper.make_node(
            "MatMul",
            ["softmax_qk_out", "transpose_v_out"],
            ["matmul_qkv_1_out"],
            "matmul_qkv_1",
        ),
        helper.make_node(
            "Transpose",
            ["matmul_qkv_1_out"],
            ["transpose_qkv_out"],
            "transpose_qkv",
            perm=[0, 2, 1, 3],
        ),
        helper.make_node(
            "Reshape",
            ["transpose_qkv_out", "reshape_weight_qkv"],
            ["reshape_qkv_out"],
            "reshape_qkv",
        ),
        helper.make_node(
            "MatMul",
            ["reshape_qkv_out", "matmul_qkv_weight"],
            ["matmul_qkv_2_out"],
            "matmul_qkv_2",
        ),
        helper.make_node(
            "Add",
            reverse_if(["matmul_qkv_2_out", "add_qkv_weight"], switch_add_inputs),
            ["add_qkv_out"],
            "add_qkv",
        ),
        # Residual skip: adds attention output to the original graph input (not the LN output).
        helper.make_node(
            "Add",
            reverse_if(["add_qkv_out", "input_1"], switch_add_inputs),
            ["skip_output"],
            "add_skip",
        ),
        helper.make_node(
            "LayerNormalization",
            ["skip_output", "layer_norm_weight_2", "layer_norm_bias_2"],
            ["output"],
            "layernorm2",
            axis=-1,
            epsilon=layer_norm_epsilon,
        ),
    ]

    # Per-head widths; integer division keeps the Reshape shape values exact ints.
    pruned_qk_head_size = pruned_qk_hidden_size // num_heads
    pruned_v_head_size = pruned_v_hidden_size // num_heads
    initializers = [
        float_tensor("layer_norm_weight", [input_hidden_size]),
        float_tensor("layer_norm_bias", [input_hidden_size]),
        float_tensor("layer_norm_weight_2", [input_hidden_size]),
        float_tensor("layer_norm_bias_2", [input_hidden_size]),
        float_tensor("matmul_q_weight", [input_hidden_size, pruned_qk_hidden_size]),
        float_tensor("matmul_k_weight", [input_hidden_size, pruned_qk_hidden_size]),
        float_tensor("matmul_v_weight", [input_hidden_size, pruned_v_hidden_size]),
        float_tensor("matmul_qkv_weight", [pruned_v_hidden_size, input_hidden_size]),
        float_tensor("add_q_weight", [pruned_qk_hidden_size]),
        float_tensor("add_k_weight", [pruned_qk_hidden_size]),
        float_tensor("add_v_weight", [pruned_v_hidden_size]),
        float_tensor("add_qkv_weight", [input_hidden_size]),
        # Attention scaling: QK^T is divided by sqrt(head_size).
        helper.make_tensor("div_weight", TensorProto.FLOAT, [1], [math.sqrt(pruned_qk_head_size)]),
        helper.make_tensor(
            "reshape_weight_qk",
            TensorProto.INT64,
            [4],
            [0, 0, num_heads, pruned_qk_head_size],
        ),
        helper.make_tensor(
            "reshape_weight_v",
            TensorProto.INT64,
            [4],
            [0, 0, num_heads, pruned_v_head_size],
        ),
        helper.make_tensor("reshape_weight_qkv", TensorProto.INT64, [3], [0, 0, pruned_v_hidden_size]),
    ]

    batch_size = 1
    sequence_length = 3
    io_shape = [batch_size, sequence_length, input_hidden_size]
    graph = helper.make_graph(
        nodes,
        "PreLNAttentionFusion",
        [  # inputs: only one embedding input (no preceding Add)
            helper.make_tensor_value_info("input_1", TensorProto.FLOAT, io_shape),
        ],
        [  # outputs
            helper.make_tensor_value_info("output", TensorProto.FLOAT, io_shape),
        ],
        initializers,
    )

    opsetid = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
    return helper.make_model(graph, opset_imports=(opsetid,))


def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False):
# unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input).
has_unsqueeze_two_inputs = version.parse(onnx.__version__) >= version.parse("1.8.0")
Expand Down
64 changes: 63 additions & 1 deletion onnxruntime/test/python/transformers/test_attention_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
# --------------------------------------------------------------------------

import os
import tempfile
import unittest

import onnx
from bart_model_generator import create_bart_attention_sdpa
from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d
from bert_model_generator import create_bert_attention, create_bert_attention_pre_ln, create_tf2onnx_attention_3d
from gpt2_model_generator import create_gpt2_attention
from model_loader import get_test_data_path
from parity_utilities import find_transformers_source
Expand Down Expand Up @@ -152,6 +153,67 @@ def test_3d_attention_fusion_tf2onnx_model(self):

self.verify_fusion(optimized_model, "bert_3d_attention_opt.onnx")

def test_attention_fusion_pre_ln(self):
    """Test attention fusion for a pre-layer-norm first block.

    In a pre-LN model the first block has no Add before the first
    LayerNormalization: the graph input feeds LN directly. The optimized
    model must contain exactly one Attention node with ``num_heads == 2``.
    """
    model = create_bert_attention_pre_ln()
    options = FusionOptions("bert")
    options.use_raw_attention_mask(True)

    # TemporaryDirectory removes both the saved model and the directory itself,
    # even when optimize_model raises, so no files are left behind by the test.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, "pre_ln_attention.onnx")
        onnx.save(model, model_path)
        optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options)

    attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"]
    self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node")
    num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None)
    self.assertIsNotNone(num_heads_attr)
    self.assertEqual(num_heads_attr.i, 2)

def test_attention_fusion_pre_ln_reverse_add_order(self):
    """Pre-LN fusion with reversed Add input ordering.

    Builds the pre-LN graph with ``switch_add_inputs=True`` and checks that
    exactly one Attention node with ``num_heads == 2`` is still produced.
    """
    model = create_bert_attention_pre_ln(switch_add_inputs=True)
    options = FusionOptions("bert")
    options.use_raw_attention_mask(True)

    # TemporaryDirectory removes both the saved model and the directory itself,
    # even when optimize_model raises, so no files are left behind by the test.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, "pre_ln_attention_reverse.onnx")
        onnx.save(model, model_path)
        optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options)

    attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"]
    self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node")
    num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None)
    self.assertIsNotNone(num_heads_attr)
    self.assertEqual(num_heads_attr.i, 2)

def test_attention_fusion_pre_ln_with_skiplayernorm(self):
    """Pre-LN fusion when SkipLayerNorm fusion runs first.

    The optimizer runs fuse_skip_layer_norm before fuse_attention. When enabled,
    the Add + LayerNorm after the residual becomes a SkipLayerNormalization node,
    and attention fusion must handle that anchor type. The optimized model must
    contain exactly one Attention node with ``num_heads == 2``.
    """
    model = create_bert_attention_pre_ln()
    options = FusionOptions("bert")
    options.use_raw_attention_mask(True)
    options.enable_skip_layer_norm = True

    # TemporaryDirectory removes both the saved model and the directory itself,
    # even when optimize_model raises, so no files are left behind by the test.
    with tempfile.TemporaryDirectory() as temp_dir:
        model_path = os.path.join(temp_dir, "pre_ln_attention_skiplayernorm.onnx")
        onnx.save(model, model_path)
        optimized_model = optimize_model(model_path, opt_level=0, optimization_options=options)

    attention_nodes = [n for n in optimized_model.model.graph.node if n.op_type == "Attention"]
    self.assertEqual(len(attention_nodes), 1, "Expected exactly 1 fused Attention node with SkipLN anchor")
    num_heads_attr = next((a for a in attention_nodes[0].attribute if a.name == "num_heads"), None)
    self.assertIsNotNone(num_heads_attr)
    self.assertEqual(num_heads_attr.i, 2)

def test_gpt2_attention_fusion(self):
hidden_size = 64
num_heads = 4
Expand Down
Loading