onnxruntime/python/tools/transformers/fusion_bart_attention.py (111 changes: 91 additions & 20 deletions)

@@ -33,6 +33,21 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[1, 1, 0, 0, 0],
)

# For LayerNormalization (when SkipLayerNorm fusion doesn't run, e.g. SDPA models where
# symbolic shape inference fails), there's an extra Add node for the residual connection
# between the LayerNorm and the attention output path.
add_before_layernorm = None
if qkv_nodes is None:
qkv_nodes_with_residual = self.model.match_parent_path(
normalize_node,
["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
[0, None, 0, 0, 0, 0],
)
if qkv_nodes_with_residual is not None:
add_before_layernorm = qkv_nodes_with_residual[0]
qkv_nodes = qkv_nodes_with_residual[1:]

if qkv_nodes is not None:
(
add_out,
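Reviewer note: both branches above rely on `OnnxModel.match_parent_path`, which walks upward from a node, one parent per step. A minimal sketch of its semantics, assuming a plain lookup dict for brevity (the real helper in `onnx_model.py` takes the model object and supports more options):

```python
# Simplified sketch, not the real OnnxModel implementation. op_types[i] must
# be the op type of the parent reached through input input_indices[i]; an
# index of None means "any input whose producer has the expected op type".
def match_parent_path(output_name_to_node, node, op_types, input_indices):
    matched = []
    current = node
    for op_type, input_index in zip(op_types, input_indices):
        indices = range(len(current.input)) if input_index is None else [input_index]
        parent = None
        for i in indices:
            if i >= len(current.input):
                break
            candidate = output_name_to_node.get(current.input[i])
            if candidate is not None and candidate.op_type == op_type:
                parent = candidate
                break
        if parent is None:
            return None  # first mismatch rejects the whole pattern
        matched.append(parent)
        current = parent
    return matched
```

Under this reading, the new `["Add", "Add", ...]` pattern with indices `[0, None, ...]` accepts the attention-path Add on either input of the residual Add.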
@@ -45,16 +60,23 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
logger.debug("fuse_attention: failed to match qkv path")
return

other_inputs = []
for input_ in normalize_node.input:
if input_ not in output_name_to_node:
continue
if input_ == qkv_nodes[0].output[0]:
continue
other_inputs.append(input_)
if len(other_inputs) != 1:
return
root_input = other_inputs[0]
if add_before_layernorm is not None:
# LayerNorm case: root_input is the non-attention input of the residual Add
if add_before_layernorm.input[0] == add_out.output[0]:
root_input = add_before_layernorm.input[1]
else:
root_input = add_before_layernorm.input[0]
else:
other_inputs = []
for input_ in normalize_node.input:
if input_ not in output_name_to_node:
continue
if input_ == qkv_nodes[0].output[0]:
continue
other_inputs.append(input_)
if len(other_inputs) != 1:
return
root_input = other_inputs[0]

# Sometimes the input name to the attention MatMul nodes does not match the input name to the end
# SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
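A toy illustration of the new `add_before_layernorm` branch, with made-up tensor names: the residual Add has exactly two inputs, and the root is whichever one the attention path did not produce.

```python
# Toy nodes standing in for onnx NodeProto; names are illustrative only.
from dataclasses import dataclass

@dataclass
class Node:
    input: list
    output: list

add_out = Node(input=["proj_out", "proj_bias"], output=["attn_out"])
add_before_layernorm = Node(input=["attn_out", "hidden_states"], output=["ln_in"])

# Same selection logic as the diff: keep the non-attention input.
if add_before_layernorm.input[0] == add_out.output[0]:
    root_input = add_before_layernorm.input[1]
else:
    root_input = add_before_layernorm.input[0]
assert root_input == "hidden_states"
```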
@@ -148,13 +170,25 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):

         qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
         qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
+        # SDPA: NaN guard (Where(IsNaN, 0, softmax)) wraps the Softmax output.
+        # Where input[2] is the Softmax output (value when condition is False).
+        qk_nodes_sdpa_no_mask = self.model.match_parent_path(matmul_qkv, ["Where", "Softmax", "MatMul"], [0, 2, 0])
+        qk_nodes_sdpa_with_mask = self.model.match_parent_path(
+            matmul_qkv, ["Where", "Softmax", "Add", "MatMul"], [0, 2, 0, 0]
+        )
         qk_nodes, add_qk = [], None
         if qk_nodes_no_mask is not None:
             _, matmul_qk = qk_nodes_no_mask
             qk_nodes = qk_nodes_no_mask
         elif qk_nodes_with_mask is not None:
             _, add_qk, matmul_qk = qk_nodes_with_mask
             qk_nodes = qk_nodes_with_mask
+        elif qk_nodes_sdpa_no_mask is not None:
+            _, _, matmul_qk = qk_nodes_sdpa_no_mask
+            qk_nodes = qk_nodes_sdpa_no_mask
+        elif qk_nodes_sdpa_with_mask is not None:
+            _, _, add_qk, matmul_qk = qk_nodes_sdpa_with_mask
+            qk_nodes = qk_nodes_sdpa_with_mask
         else:
             logger.debug("fuse_attention: failed to match qk path")
             return
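For anyone checking the `[0, 2, 0]` indices: ONNX `Where(condition, X, Y)` returns `X` where the condition is true and `Y` elsewhere, so the Softmax result sits at input 2. A sketch of the guard subgraph these new patterns target, built with `onnx.helper` for illustration (tensor names are made up):

```python
# SDPA NaN guard: Where(IsNaN(s), 0, s). Walking up from the Where through
# input 2 reaches the Softmax, hence parent indices [0, 2, 0].
from onnx import TensorProto, helper

softmax = helper.make_node("Softmax", ["scores"], ["probs"], axis=-1)
nan_mask = helper.make_node("IsNaN", ["probs"], ["nan_mask"])
zero = helper.make_node(
    "Constant", [], ["zero"],
    value=helper.make_tensor("zero_value", TensorProto.FLOAT, [], [0.0]),
)
guard = helper.make_node("Where", ["nan_mask", "zero", "probs"], ["safe_probs"])
```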
@@ -169,13 +203,22 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Mul", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, 1],
)
# SDPA: Mul(scale) applied before Transpose, MatMul may be at any Add input.
q_nodes_sdpa = self.model.match_parent_path(
matmul_qk,
["Mul", "Transpose", "Reshape", "Add", "MatMul"],
[0, 0, 0, 0, None],
)
q_nodes = []
if q_nodes_hf is not None:
q_nodes = q_nodes_hf
(transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
elif q_nodes_oai is not None:
q_nodes = q_nodes_oai
(mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
elif q_nodes_sdpa is not None:
q_nodes = q_nodes_sdpa
(mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
else:
logger.debug("fuse_attention: failed to match q path")
return
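The trailing `None` index exists because exporters can emit the bias Add with the MatMul on either operand. Both node layouts below describe the same computation (illustrative names only):

```python
from onnx import helper

matmul = helper.make_node("MatMul", ["hidden_states", "q_weight"], ["q_matmul"])
add_a = helper.make_node("Add", ["q_matmul", "q_bias"], ["query"])  # MatMul at input 0
add_b = helper.make_node("Add", ["q_bias", "q_matmul"], ["query"])  # MatMul at input 1
```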
@@ -200,6 +243,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
[1, 0, 0, 0, 0],
)
# SDPA: K is scaled (Mul) and transposed via Reshape->Transpose(0,2,1)->Reshape chain.
k_nodes_sdpa = self.model.match_parent_path(
matmul_qk,
["Mul", "Reshape", "Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
[1, 0, 0, 0, 0, 0, 0, None],
)
past_k, present_k = "", ""
k_nodes, add_k, matmul_k = [], None, None
if k_nodes_no_past_hf is not None:
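A numpy sketch of what the matched Reshape/Transpose chain computes. Shapes are illustrative, since the exact reshape targets depend on the export, but the net effect is a per-head key transpose feeding `Q @ K^T`:

```python
import numpy as np

batch, heads, seq_len, head_dim = 2, 4, 5, 8
k = np.random.rand(batch, seq_len, heads * head_dim).astype(np.float32)

split = k.reshape(batch, seq_len, heads, head_dim)         # Reshape: split heads
moved = split.transpose(0, 2, 1, 3)                        # Transpose: heads next to batch
folded = moved.reshape(batch * heads, seq_len, head_dim)   # Reshape: fold heads into batch
swapped = folded.transpose(0, 2, 1)                        # Transpose(perm=[0, 2, 1])
k_t = swapped.reshape(batch * heads, head_dim, seq_len)    # Reshape: final MatMul layout

assert k_t.shape == (batch * heads, head_dim, seq_len)
```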
@@ -221,6 +270,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             # Hugging Face's cross-attention where past_k is used directly as key
             k_nodes = [output_name_to_node[matmul_qk.input[1]]]
             past_k = k_nodes[0].input[0]
+        elif k_nodes_sdpa is not None:
+            k_nodes = k_nodes_sdpa
+            (_, _, _, _, transpose_k, reshape_k, add_k, matmul_k) = k_nodes
         elif k_nodes_past_or_present_oai is not None:
             k_nodes = k_nodes_past_or_present_oai
             (_, transpose_k, reshape_k, matmul_k) = k_nodes
@@ -291,19 +343,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         )
 
         # There are 5 types of attention:
-        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
-        # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
-        # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
-        # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value
-        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
-        encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask
-        decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask
-        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask
+        # 1) Encoder attention with one_root_input=True and no mask
+        # 2) Decoder self attention with one_root_input=True and has mask
+        # 3) Decoder cross attention with two_root_inputs=True and no mask
+        # 4) Decoder self attention with past with one_root_input=True and has mask and past_k and past_v
+        # 5) Decoder cross attention with past with three_root_inputs=True and no mask
+        # Derive mask presence from which QK pattern matched rather than re-walking the graph.
+        # This reuses the results of the match_parent_path calls above, which already tried
+        # both the masked and unmasked variants; qk_nodes holds the first pattern that matched.
+        has_mask = qk_nodes in (qk_nodes_with_mask, qk_nodes_sdpa_with_mask)
+        no_mask = not has_mask
+        encoder_attention = one_root_input and no_mask
+        decoder_self_attention = one_root_input and has_mask
+        decoder_cross_attention = two_root_inputs and no_mask
         decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
-        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask
+        decoder_cross_attention_with_past = three_root_inputs and no_mask
 
         # For decoder self-attentions, the attention mask needs to be included in the attention node
-        causal_mask = qk_nodes == qk_nodes_with_mask
+        causal_mask = has_mask
         mask_nodes = []
         if causal_mask:
             mask_nodes_bart = self.model.match_parent_path(
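The five cases above can be read as a small decision function. The sketch below restates the comment block under the new `has_mask` derivation; it is an illustration, not code from the PR:

```python
def classify(one_root, two_roots, three_roots, has_mask, past_k, past_v):
    no_mask = not has_mask
    if one_root and no_mask:
        return "encoder attention"
    if one_root and has_mask:
        with_past = " with past" if past_k and past_v else ""
        return "decoder self attention" + with_past
    if two_roots and no_mask:
        return "decoder cross attention"
    if three_roots and no_mask:
        return "decoder cross attention with past"
    return "unrecognized"

assert classify(True, False, False, False, "", "") == "encoder attention"
assert classify(True, False, False, True, "pk", "pv") == "decoder self attention with past"
```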
@@ -349,6 +406,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         attention_last_node = reshape_qkv
         num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
 
+        # Fall back to user-specified values when detected values are invalid
+        # (e.g., SDPA models use -1 in reshape shapes for dynamic dimensions).
+        if (num_heads <= 0 or hidden_size <= 0) and self.num_heads > 0 and self.hidden_size > 0:
+            logger.debug(
+                "fuse_attention: reshape dims invalid (num_heads=%d, hidden_size=%d), "
+                "falling back to user-specified num_heads=%d, hidden_size=%d",
+                num_heads,
+                hidden_size,
+                self.num_heads,
+                self.hidden_size,
+            )
+            num_heads = self.num_heads
+            hidden_size = self.hidden_size
+
         if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
             logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
             return
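An illustrative-only sketch of why detection comes back invalid on SDPA exports: the head-split Reshape often carries dynamic dims (`-1` or `0`), so there is no constant from which to read the head count. The `derive` helper below is a stand-in for the relevant part of `get_num_heads_and_hidden_size`, not its actual implementation:

```python
def derive(shape):
    # Read num_heads and head_dim from a 4-D reshape target (stand-in logic).
    num_heads, head_dim = shape[2], shape[3]
    return num_heads, num_heads * head_dim

print(derive([0, 0, 16, 64]))    # (16, 1024): static export, detection works
print(derive([-1, -1, -1, 64]))  # (-1, -64): dynamic dims, fall back to user values
```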