70 changes: 38 additions & 32 deletions onnxruntime/python/tools/transformers/fusion_attention_clip.py
@@ -269,42 +269,48 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         attention_last_node = reshape_qkv
 
         add_qk = ""
+        causal_mask_nodes_1 = None
+        causal_mask_nodes_2 = None
         if add_mask is not None:
             # 4D Add after Q x K'
             add_qk_nodes = self.model.match_parent_path(
                 add_mask,
                 [
                     "Where",
                     "Sub",
                     "Cast",
                     "Expand",
                     "Unsqueeze",
                     "Unsqueeze",
                     "Reshape",
                     "Reshape",
                     "Cast",
                 ],
                 [1, 2, 1, 0, 0, 0, 0, 0, 0],
             )
             if add_qk_nodes is not None:
                 add_qk = add_mask.input[1]
             else:
-                # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
-                # of computing causal mask.
-                causal_mask_nodes_1 = self.model.match_parent_path(
-                    add_mask,
-                    ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-                    [causal_mask_input_index, 0, 0, 0, 0, 0],
-                )
-                # If the model is exported with batch_size == 1, there is no Concat node
-                causal_mask_nodes_2 = self.model.match_parent_path(
-                    add_mask,
-                    ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
-                    [causal_mask_input_index, 0, 0, 0, 0],
-                )
-
-                if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None:
-                    logger.debug("fuse_attention: failed to match causal mask subgraph")
-                    return
+                if add_mask.input[1] == "attention_mask":
+                    add_qk = add_mask.input[1]
+                else:
+                    # Here we do not match the whole subgraph since it is very complex. Instead, we just check whether a key path
+                    # of computing causal mask.
+                    causal_mask_nodes_1 = self.model.match_parent_path(
+                        add_mask,
+                        ["Concat", "Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                        [causal_mask_input_index, 0, 0, 0, 0, 0],
+                    )
+                    # If the model is exported with batch_size == 1, there is no Concat node
+                    causal_mask_nodes_2 = self.model.match_parent_path(
+                        add_mask,
+                        ["Expand", "Unsqueeze", "Unsqueeze", "Where", "Less"],
+                        [causal_mask_input_index, 0, 0, 0, 0],
+                    )
+
+                    if causal_mask_nodes_1 is None and causal_mask_nodes_2 is None:
+                        logger.debug("fuse_attention: failed to match causal mask subgraph")
+                        return
 
         new_node = self.create_attention_node(
             mask_index=None,
@@ -320,7 +326,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             output=attention_last_node.output[0],
             add_qk_str=add_qk,
             scale=None,
-            causal=(add_mask is not None),
+            causal=(causal_mask_nodes_1 is not None) or (causal_mask_nodes_2 is not None),
         )
         if new_node is None:
             logger.debug("fuse_attention: failed to create fused node")
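Net effect of the change above: the fused Attention node is marked causal only when a causal-mask subgraph was actually matched, while a 4D additive mask (either the exporter's Where/Sub/Cast expansion pattern or a graph input literally named "attention_mask") is forwarded to the fused node through add_qk_str. A minimal standalone sketch of that decision table follows; classify_mask is a hypothetical helper, not part of the onnxruntime API, and the two booleans stand in for the match_parent_path results:

def classify_mask(mask_input_name, expand_path_matched, causal_path_matched):
    """Return (add_qk, causal) the way the updated fusion decides them."""
    if expand_path_matched:
        return mask_input_name, False  # additive mask built by the exporter's expansion pattern
    if mask_input_name == "attention_mask":
        return mask_input_name, False  # additive mask passed directly as a graph input
    if causal_path_matched:
        return "", True  # causal-mask subgraph: mark the fused node unidirectional
    return None, False  # unknown mask pattern: fusion is skipped


assert classify_mask("attention_mask", False, False) == ("attention_mask", False)
assert classify_mask("some_mask", False, True) == ("", True)
assert classify_mask("some_mask", False, False) == (None, False)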
Binary file not shown.
70 changes: 59 additions & 11 deletions onnxruntime/test/python/transformers/test_phi_vision.py
@@ -149,7 +149,7 @@ def __init__(self):
         self.attn = PhiVCLIPAttention()
         self.ln = torch.nn.LayerNorm(20, eps=1e-05)
 
-    def forward(self, x):
+    def forward(self, x, attention_mask=None):
         # SkipLayerNorm ------+
         #        |            |
         #    Attention        |
@@ -163,8 +163,7 @@ def forward(self, x):
         x = self.ln(x)
         residual = x
 
-        # Attention + MatMul
-        x = self.attn(x)
+        x = self.attn(x, attention_mask=attention_mask)
 
         # SkipLayerNorm
         x = residual + x
@@ -194,14 +193,31 @@ def verify_fusion(self, optimized_model, expected_model_filename):
         )
 
     def export(self, model, inputs):
-        torch.onnx.export(
-            model,
-            args=inputs,
-            f=os.path.join(os.path.dirname(__file__), "export.onnx"),
-            export_params=True,
-            opset_version=14,
-            do_constant_folding=True,
-        )
+        path = os.path.join(os.path.dirname(__file__), "export.onnx")
+
+        if len(inputs) == 2:
+            torch.onnx.export(
+                model,
+                args=inputs,
+                f=path,
+                export_params=True,
+                opset_version=14,
+                do_constant_folding=True,
+                input_names=["input", "attention_mask"],
+                dynamic_axes={
+                    "input": {0: "batch", 1: "seq"},
+                    "attention_mask": {0: "batch", 2: "seq", 3: "seq"},
+                },
+            )
+        else:
+            torch.onnx.export(
+                model,
+                args=inputs,
+                f=path,
+                export_params=True,
+                opset_version=14,
+                do_constant_folding=True,
+            )
 
     def tearDown(self):
         path = os.path.join(os.path.dirname(__file__), "export.onnx")
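With the two-input branch above, the exported graph should expose "input" and "attention_mask" with the declared symbolic axes. A quick inspection sketch, assuming the export.onnx path the test writes to:

import onnx

m = onnx.load("export.onnx")
for graph_input in m.graph.input:
    # dim_param holds a symbolic axis name; dim_value holds a fixed size.
    dims = [d.dim_param or d.dim_value for d in graph_input.type.tensor_type.shape.dim]
    print(graph_input.name, dims)
# Expected along the lines of:
#   input ['batch', 'seq', 20]
#   attention_mask ['batch', 1, 'seq', 'seq']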
@@ -249,6 +265,38 @@ def test_phi_vision_attention(self):
         )
         self.verify_fusion(optimized_model, "phi-3.5-v-instruct-vision-attention.onnx")
 
+    def test_phi_vision_attention_with_mask(self):
+        model = PhiVCLIPAttentionAndLayerNorm()
+
+        batch, seq_len, dim = 1, 2, 20
+        mask = torch.zeros(batch, 1, seq_len, seq_len)
+        mask[:, 1:] = float("-inf")
+
+        inputs = (torch.randn(batch, seq_len, dim), mask)
+        self.export(model, inputs)
+        original_model = onnx.load(os.path.join(os.path.dirname(__file__), "export.onnx"))
+        options = FusionOptions("clip")
+        optimized_model = optimize_model(
+            original_model,
+            model_type="clip",
+            num_heads=2,
+            hidden_size=20,
+            optimization_options=options,
+            opt_level=0,
+            use_gpu=True,
+        )
+        self.verify_fusion(optimized_model, "phi-4-v-instruct-vision-attention.onnx")
+
+        graph = optimized_model.model.graph
+        attention_node = next((n for n in graph.node if n.name == "Attention_0"), None)
+        self.assertIsNotNone(attention_node, "Could not find the Attention fused node")
+        attr_names = [attr.name for attr in attention_node.attribute]
+        self.assertNotIn(
+            "unidirectional",
+            attr_names,
+            f"The attention node should not have a 'unidirectional' attribute: {attr_names}",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
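The mask the new test feeds in is a 4D additive bias of shape (batch, 1, seq, seq): 0 keeps a query/key pair, -inf removes it, matching the convention the exported Add node applies to the raw attention scores. A hedged sketch of building such a mask; the position masked here is illustrative, not the test's exact slicing:

import torch


def make_additive_mask(batch: int, seq_len: int) -> torch.Tensor:
    # (batch, 1, seq, seq): broadcasts over heads; 0 = attend, -inf = masked out.
    mask = torch.zeros(batch, 1, seq_len, seq_len)
    mask[..., -1] = float("-inf")  # example: hide the last key position from every query
    return mask


print(make_additive_mask(1, 2))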