115 changes: 71 additions & 44 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -93,11 +93,14 @@ def __init__(
hidden_size: int,
num_heads: int,
attention_mask: AttentionMask,
use_multi_head_attention: bool = False,
):
super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"])
attention_op_name = "MultiHeadAttention" if use_multi_head_attention else "Attention"
super().__init__(model, attention_op_name, ["SkipLayerNormalization", "LayerNormalization"])
self.hidden_size = hidden_size
self.num_heads = num_heads
self.attention_mask = attention_mask
self.use_multi_head_attention = use_multi_head_attention

# Flags to show warning only once
self.num_heads_warning = True
@@ -108,18 +111,18 @@ def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[
Detect num_heads and hidden_size from Concat node in the following subgraph:

SkipLayerNormalization or EmbedLayerNormalization
/ \
MatMul Shape
| |
Add Gather(indices=0)
\ |
\ Unsqueeze
\ |
\ Concat (*, -1, 12, 64)
\ /
Reshape
|
Transpose
/ |
MatMul Shape
| |
Add Gather(indices=0)
| |
| Unsqueeze
| |
| Concat (*, -1, 12, 64)
| /
Reshape
|
Transpose
"""
if len(concat.input) == 4:
num_heads = self.model.get_constant_value(concat.input[2])
@@ -307,17 +310,18 @@ def create_attention_node(

attention_node_name = self.model.create_node_name("Attention")

weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.flatten().tolist(),
)
if not self.use_multi_head_attention:
weight = helper.make_tensor(
name=attention_node_name + "_qkv_weight",
data_type=TensorProto.FLOAT,
dims=[qw_in_size, qkv_weight_dim],
vals=qkv_weight.flatten().tolist(),
)

# Sometimes weights and bias are stored in fp16
if q_weight.data_type == 10:
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)
# Sometimes weights and bias are stored in fp16
if q_weight.data_type == 10:
weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
self.model.add_initializer(weight, self.this_graph_name)

bias = helper.make_tensor(
name=attention_node_name + "_qkv_bias",
Expand All @@ -329,26 +333,48 @@ def create_attention_node(
bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
self.model.add_initializer(bias, self.this_graph_name)

attention_inputs = [
input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
# For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
if self.use_multi_head_attention:
if add_qk_str is not None:
logger.debug("MultiHeadAttention does not support extra_add_qk: cannot fuse the attention.")
return None

attention_inputs = [
q_matmul.output[0],
k_matmul.output[0],
v_matmul.output[0],
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)

attention_node = helper.make_node(
"MultiHeadAttention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
else:
attention_inputs.append("")
attention_inputs = [
input,
attention_node_name + "_qkv_weight",
attention_node_name + "_qkv_bias",
]
if mask_index is not None:
attention_inputs.append(mask_index)
else:
attention_inputs.append("")

if add_qk_str is not None:
attention_inputs.append("")
attention_inputs.append(add_qk_str)
if add_qk_str is not None:
attention_inputs.append("") # no past
attention_inputs.append(add_qk_str)

attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node = helper.make_node(
"Attention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
)
attention_node.domain = "com.microsoft"
attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

@@ -595,10 +621,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):

self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
self.nodes_to_remove.extend(qk_nodes)
self.nodes_to_remove.extend(q_nodes)
self.nodes_to_remove.extend(k_nodes)
self.nodes_to_remove.extend(v_nodes)

# For MultiHeadAttention operator, MatMul nodes for Q/K/V projection shall not be fused.
self.nodes_to_remove.extend(q_nodes if not self.use_multi_head_attention else q_nodes[:-1])
self.nodes_to_remove.extend(k_nodes if not self.use_multi_head_attention else k_nodes[:-1])
self.nodes_to_remove.extend(v_nodes if not self.use_multi_head_attention else v_nodes[:-1])

# Use prune graph to remove mask nodes since they are shared by all attention nodes.
# self.nodes_to_remove.extend(mask_nodes)
self.prune_graph = True
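
For context beyond the diff: when use_multi_head_attention is enabled, create_attention_node keeps the three Q/K/V projection MatMuls in the graph and wires their outputs straight into the fused node, so no merged qkv weight initializer is created. A minimal sketch of the resulting node follows; the tensor names and the num_heads value are placeholders, not taken from this PR.

from onnx import helper

# Placeholder names for the outputs of the (unfused) Q/K/V projection MatMuls.
q_out, k_out, v_out = "query_matmul_out", "key_matmul_out", "value_matmul_out"

mha_node = helper.make_node(
    "MultiHeadAttention",
    # query, key, value, fused qkv bias, optional mask -- no merged qkv weight input.
    inputs=[q_out, k_out, v_out, "attention_qkv_bias", "mask_index"],
    outputs=["attention_output"],
    name="MultiHeadAttention_0",
)
mha_node.domain = "com.microsoft"
mha_node.attribute.extend([helper.make_attribute("num_heads", 12)])
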
116 changes: 70 additions & 46 deletions onnxruntime/python/tools/transformers/fusion_embedlayer.py
@@ -63,48 +63,64 @@ def check_attention_subgraph(
self.attention = self.model.find_first_child_by_type(
layernorm, "Attention", input_name_to_nodes, recursive=False
)
if self.attention is None:
# In case user disables attention fusion, check whether subgraph looks like Attention.
if layernorm.output[0] not in input_name_to_nodes:

if self.attention is not None:
return True

if layernorm.output[0] not in input_name_to_nodes:
return False
children = input_name_to_nodes[layernorm.output[0]]
children_types = sorted([child.op_type for child in children])

# Try find MultiHeadAttention
if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
for node in children:
if node.op_type == "SkipLayerNormalization":
path1 = self.model.match_parent_path(
node,
["Add", "MatMul", "MultiHeadAttention", "MatMul"],
[None, None, 0, 0],
)
if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
self.cross_attention = path1[2]
return True

# In case user disables attention fusion, check whether subgraph looks like Attention.
# For Albert, there is MatMul+Add after embedding layer before attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])

# Two Shape nodes might be merged by ORT
if is_distil_bert:
# SkipLayerNormailization might exist when model has been optimized by ORT first.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
children = input_name_to_nodes[layernorm.output[0]]

# For Albert, there is MatMul+Add after embedding layer before attention.
if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
grandchildren = input_name_to_nodes[children[0].output[0]]
if (
len(grandchildren) == 1
and grandchildren[0].op_type == "Add"
and grandchildren[0].output[0] in input_name_to_nodes
):
nodes = input_name_to_nodes[grandchildren[0].output[0]]
for node in nodes:
if node.op_type == "Attention":
self.attention = node
return True
children_types = sorted([child.op_type for child in nodes])
else:
children_types = sorted([child.op_type for child in children])

# Two Shape nodes might be merged by ORT
if is_distil_bert:
# SkipLayerNormailization might exist when model has been optimized by ORT first.
if (
children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
):
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False
else:
if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [
"MatMul",
"MatMul",
"MatMul",
"SkipLayerNormalization",
]:
logger.debug("No Attention like subgraph in children of LayerNormalization")
return False

return True

def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
@@ -713,11 +729,15 @@ def replace_mask(self, mask_int32, attention_nodes):

for attention_node in attention_nodes:
logger.debug("update mask_index in %s", attention_node.name)
attention_node.input[3] = embed_node.output[1]
if attention_node.op_type == "Attention":
attention_node.input[3] = embed_node.output[1]
elif attention_node.op_type == "MultiHeadAttention":
attention_node.input[4] = embed_node.output[1]

def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Reset attention and embed_node so that we know fusion is successful when they are not None.
self.attention = None
self.cross_attention = None
self.embed_node = None
super().fuse(node, input_name_to_nodes, output_name_to_node)

@@ -729,15 +749,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
self.increase_counter("EmbedLayerNormalization(no mask)")
return

if self.attention is None:
if self.attention is None and self.cross_attention is None:
logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
self.increase_counter("EmbedLayerNormalization(no mask)")
return

mask_int32 = self.attention.input[3]
if self.attention:
mask_int32 = self.attention.input[3]
else:
mask_int32 = self.cross_attention.input[4]

children_nodes = input_name_to_nodes[mask_int32]
if self.model.find_graph_input(mask_int32):
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
self.replace_mask(mask_int32, attention_nodes)
self.increase_counter("EmbedLayerNormalization(with mask)")
return
Expand All @@ -749,7 +773,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):

node = output_name_to_node[mask_int32]
if node.op_type in ["ReduceSum", "Cast"]:
attention_nodes = [node for node in children_nodes if node.op_type == "Attention"]
attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
if node.op_type == "ReduceSum":
mask_int32 = node.input[0]
if len(children_nodes) == len(attention_nodes):
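
Outside the diff, the replace_mask and fuse changes above come down to the two operators exposing the attention mask at different input positions. A small sketch of that mapping (the dictionary and helper below are illustrative, not part of the PR):

# Attention inputs:          input, qkv_weight, qkv_bias, mask_index, ...  -> mask at index 3
# MultiHeadAttention inputs: query, key, value, qkv_bias, mask_index, ...  -> mask at index 4
MASK_INPUT_INDEX = {"Attention": 3, "MultiHeadAttention": 4}

def set_mask_input(node, mask_name):
    # Point the fused attention node at the EmbedLayerNormalization mask output.
    node.input[MASK_INPUT_INDEX[node.op_type]] = mask_name
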
20 changes: 20 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
@@ -19,6 +19,14 @@ def __init__(self, model_type):
self.enable_gelu = True
self.enable_layer_norm = True
self.enable_attention = True

# Use MultiHeadAttention instead of Attention operator. The difference:
# (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is
# merged into one.
# (2) Attention could only handle self attention; MultiHeadAttention could handle both self and cross attention.
# (3) MultiHeadAttention has only cuda implementation right now.
self.use_multi_head_attention = False

self.enable_skip_layer_norm = True
self.enable_embed_layer_norm = True
self.enable_bias_skip_layer_norm = True
@@ -48,6 +56,8 @@ def parse(args):
options.enable_layer_norm = False
if args.disable_attention:
options.enable_attention = False
if args.use_multi_head_attention:
options.use_multi_head_attention = True
if args.disable_skip_layer_norm:
options.enable_skip_layer_norm = False
if args.disable_embed_layer_norm:
@@ -165,3 +175,13 @@ def add_arguments(parser: ArgumentParser):
help="no attention mask. Only works for model_type=bert",
)
parser.set_defaults(no_attention_mask=False)

parser.add_argument(
"--use_multi_head_attention",
required=False,
action="store_true",
help="Use MultiHeadAttention instead of Attention operator for testing purpose. "
"Note that MultiHeadAttention might be slower than Attention since MatMul of input projection is excluded. "
"MultiHeadAttention has only CUDA implementation so the model can only run with cuda execution provider.",
)
parser.set_defaults(use_multi_head_attention=False)
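
For reference, here is how the new option would typically be enabled from the Python API. This is a hedged sketch: the file names and the num_heads/hidden_size values are placeholders, and optimize_model/FusionOptions come from the existing transformers tooling rather than from this PR.

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

options = FusionOptions("bert")
options.use_multi_head_attention = True  # fuse into MultiHeadAttention instead of Attention

optimized = optimize_model(
    "bert_base.onnx",          # placeholder input path
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    optimization_options=options,
)
optimized.save_model_to_file("bert_base_mha.onnx")  # placeholder output path

The command-line equivalent is the new --use_multi_head_attention flag added above; per its help text, the resulting model is expected to run only with the CUDA execution provider.
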
12 changes: 9 additions & 3 deletions onnxruntime/python/tools/transformers/onnx_model_bert.py
@@ -385,9 +385,14 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo
if (options is None) or options.enable_skip_layer_norm:
self.fuse_skip_layer_norm()

if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
if options.use_multi_head_attention:
self.attention_fusion = FusionAttention(
self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention
)

if (options is None) or options.enable_attention:
if options is not None:
self.attention_mask.set_mask_format(options.attention_mask_format)
self.fuse_attention()

# Perform the MatMul fusion after the Attention fusion as we do not
@@ -438,6 +443,7 @@ def get_fused_operator_statistics(self):
ops = [
"EmbedLayerNormalization",
"Attention",
"MultiHeadAttention",
"Gelu",
"FastGelu",
"BiasGelu",
@@ -459,7 +465,7 @@ def is_fully_optimized(self):
"""
op_count = self.get_fused_operator_statistics()
embed = op_count["EmbedLayerNormalization"]
attention = op_count["Attention"] + op_count["QOrderedAttention"]
attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"]
gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"]
layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"]
is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention)
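
Continuing the usage sketch from the fusion_options section, the statistics change above means MultiHeadAttention now counts toward the fused-attention total. A quick way to verify that the fusion took effect (illustrative, not part of the PR):

stats = optimized.get_fused_operator_statistics()
fused_attention = stats["Attention"] + stats["MultiHeadAttention"] + stats["QOrderedAttention"]
print("fused attention nodes:", fused_attention)
print("fully optimized:", optimized.is_fully_optimized())
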