Merged
2 changes: 2 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_attention.py
@@ -212,6 +212,8 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
                     root_input = layernorm_node.output[0]
                 else:
                     return
+            elif mul_children is not None and len(mul_children) == 5:
+                root_input = mul_before_layernorm.output[0]
             else:
                 return

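The new elif keys off the fan-out of the Mul node: input_name_to_nodes maps every tensor name to the nodes that consume it, so len(mul_children) counts the consumers of the Mul output. A minimal sketch of how such a map can be built with the onnx package (the helper name is illustrative, not part of this PR):

from collections import defaultdict

import onnx

def build_input_name_to_nodes(graph: onnx.GraphProto) -> dict:
    # Map each tensor name to the nodes that read it; the new
    # len(mul_children) == 5 check inspects exactly this kind of fan-out.
    input_name_to_nodes = defaultdict(list)
    for node in graph.node:
        for input_name in node.input:
            input_name_to_nodes[input_name].append(node)
    return input_name_to_nodes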
74 changes: 49 additions & 25 deletions onnxruntime/python/tools/transformers/fusion_embedlayer.py
@@ -37,24 +37,32 @@ class FusionEmbedLayerNoMask(Fusion):
      SkipLayerNormalization
     """
     def __init__(self, model: OnnxModel, description='no mask'):
-        super().__init__(model, "EmbedLayerNormalization", "SkipLayerNormalization", description)
+        super().__init__(model, "EmbedLayerNormalization", ["SkipLayerNormalization", "LayerNormalization"], description)
         self.utils = FusionUtils(model)
         self.attention = None
 
     def match_segment_path(self, normalize_node, input_name_to_nodes, output_name_to_node, input_ids_cast_node):
         segment_ids = None
         segment_embedding_gather = None
 
-        segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
+        if normalize_node.op_type == "SkipLayerNormalization":
+            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [1])
 
-        if segment_embedding_path is None:
-            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
-            if segment_embedding_path is None:
-                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
-                return
-            _, segment_embedding_gather = segment_embedding_path
-        else:
-            segment_embedding_gather = segment_embedding_path[0]
+            if segment_embedding_path is None:
+                segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather'], [0, 1])
+                if segment_embedding_path is None:
+                    logger.info("Segment embedding is not found. Embed layer cannot be fused.")
+                    return
+                _, segment_embedding_gather = segment_embedding_path
+            else:
+                segment_embedding_gather = segment_embedding_path[0]
+        elif normalize_node.op_type == "LayerNormalization":
+            segment_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 1])
+
+            if segment_embedding_path is None:
+                logger.info("Segment embedding is not found. Embed layer cannot be fused.")
+                return
+            _, _, segment_embedding_gather = segment_embedding_path
 
         segment_ids = segment_embedding_gather.input[1]
 
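Both branches lean on OnnxModel.match_parent_path, which walks upward from a node through the producer of a chosen input at each hop and checks the expected op type. A simplified sketch of that contract (an approximation for illustration, not the library's actual implementation):

def match_parent_path(node, parent_op_types, parent_input_indices, output_name_to_node):
    # Follow the producer of input[index] at every step, nearest parent first;
    # give up with None as soon as an op type or edge fails to match.
    matched_parents = []
    current = node
    for op_type, input_index in zip(parent_op_types, parent_input_indices):
        if input_index >= len(current.input):
            return None
        parent = output_name_to_node.get(current.input[input_index])
        if parent is None or parent.op_type != op_type:
            return None
        matched_parents.append(parent)
        current = parent
    return matched_parents

Under that reading, ['Add', 'Add', 'Gather'] with [0, 0, 1] matches LayerNormalization <-- Add <-- Add <-- Gather, with the Gather feeding input 1 of the second Add, which is why the LayerNormalization branch unpacks three matched nodes.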
@@ -92,7 +100,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
             logger.debug(
                 "Failed to match path SkipLayerNormalization[0] <-- Add <-- Gather or SkipLayerNormalization[0] <-- Gather"
             )
-            return
+            if node.op_type != "LayerNormalization" or self.model.match_parent_path(node, ['Add', 'Gather'], [0, 1]) is None:
+                return
 
         self.attention = self.model.find_first_child_by_type(node, 'Attention', input_name_to_nodes, recursive=False)
         if self.attention is None:
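In positive form (plain De Morgan), the new guard lets the fusion proceed only for a LayerNormalization node whose Add --> Gather parent path matched; every other combination still returns early. A tiny self-contained check of that logic:

def should_continue(op_type: str, add_gather_matched: bool) -> bool:
    # Positive form of: if op_type != "LayerNormalization" or path is None: return
    return op_type == "LayerNormalization" and add_gather_matched

assert should_continue("LayerNormalization", True)
assert not should_continue("LayerNormalization", False)
assert not should_continue("SkipLayerNormalization", True)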
@@ -114,19 +123,23 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         if word_embedding_path is not None:
             add_node, word_embedding_gather = word_embedding_path
         else:
-            word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
+            word_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Add', 'Gather'], [0, 0, 0])
             if word_embedding_path is not None:
-                word_embedding_gather = word_embedding_path[0]
-                is_distill = True
-                from packaging.version import Version
-                import onnxruntime
-                if Version(onnxruntime.__version__) <= Version("1.4.0"):
-                    logger.warning(
-                        'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
-                    return
+                _, add_node, word_embedding_gather = word_embedding_path
             else:
-                logger.info("Word embedding path is not found. Embed layer cannot be fused.")
-                return
+                word_embedding_path = self.model.match_parent_path(normalize_node, ['Gather'], [0])
+                if word_embedding_path is not None:
+                    word_embedding_gather = word_embedding_path[0]
+                    is_distill = True
+                    from packaging.version import Version
+                    import onnxruntime
+                    if Version(onnxruntime.__version__) <= Version("1.4.0"):
+                        logger.warning(
+                            'Please install onnxruntime with version > 1.4.0 for embedlayer fusion support for distilbert')
+                        return
+                else:
+                    logger.info("Word embedding path is not found. Embed layer cannot be fused.")
+                    return
 
         input_ids = word_embedding_gather.input[1]
 
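The distilbert fallback gates itself on the installed runtime with packaging.version.Version, which compares release segments numerically rather than lexicographically. A short illustration of why that matters:

from packaging.version import Version

# Segment-wise numeric comparison; a raw string compare gets this wrong.
assert Version("1.10.0") > Version("1.4.0")
assert "1.10.0" < "1.4.0"  # lexicographic order is misleading here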
@@ -162,8 +175,12 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         if position_embedding_path is not None:
             position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
         else:
-            logger.info("Position embedding path is not found. Embed layer cannot be fused.")
-            return
+            position_embedding_path = self.model.match_parent_path(normalize_node, ['Add', 'Gather', 'Slice'], [0, 1, 1])
+            if position_embedding_path is not None:
+                _, position_embedding_weight_node, position_embedding_node_before_gather = position_embedding_path
+            else:
+                logger.info("Position embedding path is not found. Embed layer cannot be fused.")
+                return
 
         if position_embedding_shape is not None and position_embedding_shape.input[0] != input_ids:
             logger.info("position and word embedding is expected to be applied on same input")
@@ -191,6 +208,13 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
         node_name = self.model.create_node_name('EmbedLayerNormalization')
         output_name = node_name + "_output"
 
+        if normalize_node.op_type == "LayerNormalization":
+            gamma = normalize_node.input[1]
+            beta = normalize_node.input[2]
+        elif normalize_node.op_type == "SkipLayerNormalization":
+            gamma = normalize_node.input[2]
+            beta = normalize_node.input[3]
+
         embed_node_inputs = None
         if is_distill == False:
             segment_path = self.match_segment_path(normalize_node, input_name_to_nodes, output_name_to_node,
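The gamma/beta indices differ because the two ops lay out their inputs differently: LayerNormalization takes (X, scale, bias), while the com.microsoft SkipLayerNormalization contrib op takes (input, skip, gamma, beta, optional bias). A hypothetical helper equivalent to the branch above, shown only to make the mapping explicit:

def get_gamma_beta(normalize_node):
    # Mirrors the diff's branch; raises instead of leaving gamma/beta unbound.
    if normalize_node.op_type == "LayerNormalization":
        return normalize_node.input[1], normalize_node.input[2]
    if normalize_node.op_type == "SkipLayerNormalization":
        return normalize_node.input[2], normalize_node.input[3]
    raise ValueError(f"unexpected op type: {normalize_node.op_type}")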
@@ -206,8 +230,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
                 word_embedding_gather.input[0],
                 position_embedding_weight_node.input[0],
                 segment_embedding_gather.input[0],
-                normalize_node.input[2],
-                normalize_node.input[3]  # gamma and beta
+                gamma,
+                beta  # gamma and beta
             ]
         else:
             embed_node_inputs = [
@@ -216,8 +240,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
                 word_embedding_gather.input[0],
                 position_embedding_weight_node.input[0],
                 '',
-                normalize_node.input[2],
-                normalize_node.input[3]  # gamma and beta
+                gamma,
+                beta  # gamma and beta
             ]
 
         embed_node = helper.make_node('EmbedLayerNormalization',
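The helper.make_node call above is truncated in this hunk. As a hedged sketch only (the output list, node name, and domain below are assumptions, not the PR's exact code), the completed call would look roughly like:

from onnx import helper

embed_node = helper.make_node(
    'EmbedLayerNormalization',
    embed_node_inputs,  # ids, embedding weights, then gamma and beta as built above
    [output_name],
    name=node_name,
    domain='com.microsoft')  # EmbedLayerNormalization is a contrib op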