79 changes: 64 additions & 15 deletions onnxruntime/python/tools/transformers/fusion_skiplayernorm.py
@@ -13,10 +13,25 @@
logger = getLogger(__name__)


def _is_broadcast_skip(input_shape, skip_shape):
"""Check if skip_shape can broadcast to input_shape for SkipLayerNormalization.

    The kernel supports a 3D input (B, S, H) with a 3D skip (1, S, H) or a 2D skip (S, H).
"""
if len(input_shape) != 3:
return False
if len(skip_shape) == 3:
return skip_shape[0] == 1 and skip_shape[1] == input_shape[1] and skip_shape[2] == input_shape[2]
if len(skip_shape) == 2:
return skip_shape[0] == input_shape[1] and skip_shape[1] == input_shape[2]
return False
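
A minimal sketch of the shape rule this helper encodes (dimension values are illustrative; assumes _is_broadcast_skip as defined above):

    assert _is_broadcast_skip([2, 3, 4], [1, 3, 4])      # 3D skip (1, S, H) broadcasts over batch
    assert _is_broadcast_skip([2, 3, 4], [3, 4])         # 2D skip (S, H) broadcasts over batch
    assert not _is_broadcast_skip([2, 3, 4], [1, 1, 4])  # seq_len mismatch is rejected
    assert not _is_broadcast_skip([2, 3, 4], [4])        # 1D skip is not supported
    assert not _is_broadcast_skip([3, 4], [3, 4])        # the full input itself must be 3D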


class FusionSkipLayerNormalization(Fusion):
"""
-    Fuse Add + LayerNormalization into one node: SkipLayerNormalization
-    Note: This fusion does not check the input shape of Add and LayerNormalization.
+    Fuse Add + LayerNormalization into one node: SkipLayerNormalization.
+    Supports broadcasting of the skip input: a skip of shape (1, sequence_length, hidden_size)
+    or (sequence_length, hidden_size) is broadcast to match the full input shape.
"""

def __init__(
@@ -31,9 +46,33 @@ def __init__(
        # Updating shape inference is needed since other fusions might add new edges that do not yet have shape info.
self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
if self.shape_infer_helper is None:
-            # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
+            # TODO(tianleiwu): support subgraph in shape inference.
logger.warning("symbolic shape inference disabled or failed.")

def get_skip_index(self, add):
"""Identify which Add input is the skip tensor (the one that may broadcast).

Returns (skip_index, broadcast):
            skip_index: 0 or 1 (which Add input is the skip), -1 if incompatible
broadcast: True if broadcasting is needed
"""
shape_a = self.shape_infer_helper.get_edge_shape(add.input[0])
shape_b = self.shape_infer_helper.get_edge_shape(add.input[1])
if shape_a is None or shape_b is None:
return -1, False

if shape_a == shape_b:
return (1, False) if len(shape_a) == 3 else (-1, False)

# Check if b is a broadcastable skip for a
if _is_broadcast_skip(shape_a, shape_b):
return 1, True
# Check if a is a broadcastable skip for b
if _is_broadcast_skip(shape_b, shape_a):
return 0, True

return -1, False
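
For intuition, a sketch that mirrors get_skip_index's decision logic on already-resolved shapes (the real method queries shape inference on a model; shapes and names below are illustrative only):

    def expected_skip(shape_a, shape_b):
        # Same ordering of checks as get_skip_index, with shapes passed in directly.
        if shape_a == shape_b:
            return (1, False) if len(shape_a) == 3 else (-1, False)
        if _is_broadcast_skip(shape_a, shape_b):
            return 1, True   # add.input[1] is the broadcast skip
        if _is_broadcast_skip(shape_b, shape_a):
            return 0, True   # add.input[0] is the skip; fuse() will swap the inputs
        return -1, False

    assert expected_skip([2, 3, 4], [2, 3, 4]) == (1, False)   # identical 3D shapes
    assert expected_skip([2, 3, 4], [3, 4]) == (1, True)       # 2D skip on input[1]
    assert expected_skip([3, 4], [2, 3, 4]) == (0, True)       # skip arrived on input[0]
    assert expected_skip([2, 3, 4], [1, 1, 4]) == (-1, False)  # Add-broadcastable, kernel-incompatible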

def fuse(self, node, input_name_to_nodes, output_name_to_node):
add = self.model.get_parent(node, 0, output_name_to_node)

@@ -57,19 +96,15 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
# Root Mean Square Layer Normalization
simplified = node.op_type == "SimplifiedLayerNormalization"

skip_index = 1 # default: add.input[1] is the skip
_broadcast = False

-        if hasattr(self, "shape_infer_helper"):
+        if self.shape_infer_helper is not None:
-            if (
-                self.shape_infer_helper.get_edge_shape(add.input[0])
-                and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
-            ):
-                logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
-                return
-
-            # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
-            if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
+            skip_index, _broadcast = self.get_skip_index(add)
+            if skip_index < 0:
                logger.debug(
-                    "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same",
+                    "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not compatible",
                    add.input[0],
                    add.input[1],
add.input[0],
add.input[1],
)
@@ -83,6 +118,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
return

# When broadcasting is needed, check that neither Add input comes from a Gather
# (embedding lookup). Embedding Add+LayerNorm should be fused by EmbedLayerNormalization
# later in the pipeline, not as SkipLayerNormalization.
if _broadcast:
for i in range(2):
parent = self.model.get_parent(add, i, output_name_to_node)
if parent is not None and parent.op_type == "Gather":
logger.debug(
"skip SkipLayerNormalization broadcast fusion since Add input %d comes from Gather (embedding)",
i,
)
return

# This means that the residual Add before the LayerNormalization produces an output
# that is consumed by some other nodes or graph output other than the LayerNormalization itself
# We can still go ahead with the SkipLayerNormalization fusion but we need to
@@ -106,10 +154,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
self.nodes_to_remove.extend([add, node])

input_index = 1 - skip_index
inputs = (
-                [add.input[0], add.input[1], node.input[1], node.input[2]]
+                [add.input[input_index], add.input[skip_index], node.input[1], node.input[2]]
                 if not simplified
-                else [add.input[0], add.input[1], node.input[1]]
+                else [add.input[input_index], add.input[skip_index], node.input[1]]
)
normalize_node = helper.make_node(
self.fused_op_type,
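
The input_index = 1 - skip_index reordering above keeps the full-rank tensor as the fused node's first input and the (possibly broadcast) skip as the second; a minimal sketch with hypothetical tensor names:

    add_inputs = ["pos_bias_2d", "hidden_states_3d"]  # hypothetical: skip arrived as add.input[0]
    skip_index = 0
    input_index = 1 - skip_index
    assert [add_inputs[input_index], add_inputs[skip_index]] == ["hidden_states_3d", "pos_bias_2d"]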
191 changes: 190 additions & 1 deletion onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py
@@ -78,7 +78,7 @@ def create_test_model(
["output"],
"layernorm",
axis=-1,
-        epsion=0.000009999999747378752,
+        epsilon=0.000009999999747378752,
)

initializers = [ # initializers
@@ -270,6 +270,195 @@ def test_skip_layer_norm_graph_output_cast_bias2(self):
)
os.remove(model_name)

def create_broadcast_test_model(
self,
batch_size: int = 2,
sequence_length: int = 3,
hidden_size: int = 4,
skip_shape: str = "2d", # "2d" for (seq, hidden), "3d_batch1" for (1, seq, hidden)
skip_on_input: int = 1, # Which Add input index gets the skip (smaller) shape
add_graph_output: bool = False,
simplified: bool = False, # Use SimplifiedLayerNormalization (RMS LayerNorm) instead
):
"""Create a test model where one Add input has a broadcast-compatible shape."""
if skip_shape == "2d":
skip_dims = [sequence_length, hidden_size]
elif skip_shape == "3d_batch1":
skip_dims = [1, sequence_length, hidden_size]
else:
raise ValueError(f"Unknown skip_shape: {skip_shape}")

full_dims = [batch_size, sequence_length, hidden_size]

add_before_layer_norm = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm")

if simplified:
layer_norm = helper.make_node(
"SimplifiedLayerNormalization",
["layernorm_input", "layer_norm_weight"],
["output"],
"layernorm",
axis=-1,
epsilon=0.000009999999747378752,
)
initializers = [float_tensor("layer_norm_weight", [hidden_size])]
else:
layer_norm = helper.make_node(
"LayerNormalization",
["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
["output"],
"layernorm",
axis=-1,
epsilon=0.000009999999747378752,
)
initializers = [
float_tensor("layer_norm_weight", [hidden_size]),
float_tensor("layer_norm_bias", [hidden_size]),
]

input_1_shape = full_dims if skip_on_input != 0 else skip_dims
input_2_shape = skip_dims if skip_on_input != 0 else full_dims

outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, full_dims)]
if add_graph_output:
outputs.append(helper.make_tensor_value_info("layernorm_input", TensorProto.FLOAT, full_dims))

graph = helper.make_graph(
[add_before_layer_norm, layer_norm],
"SkipLayerNormBroadcastModel",
[
helper.make_tensor_value_info("input_1", TensorProto.FLOAT, input_1_shape),
helper.make_tensor_value_info("input_2", TensorProto.FLOAT, input_2_shape),
],
outputs,
initializers,
)

onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
return helper.make_model(graph, opset_imports=(onnx_opset,))

def test_skip_layer_norm_broadcast_2d_skip(self):
"""2D skip (seq, hidden) on input[1] should fuse with input order preserved."""
model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1)
model_name = "skip_layer_norm_broadcast_2d.onnx"
onnx.save(model, model_name)
self.verify_skip_layer_norm_fusion(
model_name,
{"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0},
["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"],
["output"],
)
os.remove(model_name)

def test_skip_layer_norm_broadcast_2d_skip_swapped(self):
"""2D skip (seq, hidden) on input[0] should fuse with inputs swapped."""
model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=0)
model_name = "skip_layer_norm_broadcast_2d_swapped.onnx"
onnx.save(model, model_name)
self.verify_skip_layer_norm_fusion(
model_name,
{"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0},
["input_2", "input_1", "layer_norm_weight", "layer_norm_bias"],
["output"],
)
os.remove(model_name)

def test_skip_layer_norm_broadcast_3d_batch1(self):
"""3D skip (1, seq, hidden) on input[1] should fuse with input order preserved."""
model = self.create_broadcast_test_model(skip_shape="3d_batch1", skip_on_input=1)
model_name = "skip_layer_norm_broadcast_3d_batch1.onnx"
onnx.save(model, model_name)
self.verify_skip_layer_norm_fusion(
model_name,
{"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0},
["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"],
["output"],
)
os.remove(model_name)

def test_skip_layer_norm_broadcast_3d_batch1_swapped(self):
"""3D skip (1, seq, hidden) on input[0] should fuse with inputs swapped."""
model = self.create_broadcast_test_model(skip_shape="3d_batch1", skip_on_input=0)
model_name = "skip_layer_norm_broadcast_3d_batch1_swapped.onnx"
onnx.save(model, model_name)
self.verify_skip_layer_norm_fusion(
model_name,
{"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0},
["input_2", "input_1", "layer_norm_weight", "layer_norm_bias"],
["output"],
)
os.remove(model_name)

def test_skip_layer_norm_broadcast_graph_output(self):
"""Broadcast fusion should preserve Add output when it is a graph output."""
model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1, add_graph_output=True)
model_name = "skip_layer_norm_broadcast_graph_output.onnx"
onnx.save(model, model_name)
self.verify_skip_layer_norm_fusion(
model_name,
{"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0},
["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"],
["output", "", "", "layernorm_input"],
)
os.remove(model_name)

def test_skip_layer_norm_broadcast_incompatible_shapes(self):
"""Incompatible broadcast shapes should not fuse.

        Uses (2, 3, 4) + (1, 1, 4): broadcastable for Add, but not supported by the
        SkipLayerNormalization kernel, which requires the skip's seq_len to match the input's.
"""
add_node = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm")
layer_norm = helper.make_node(
"LayerNormalization",
["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
["output"],
"layernorm",
axis=-1,
epsilon=0.000009999999747378752,
)
initializers = [
float_tensor("layer_norm_weight", [4]),
float_tensor("layer_norm_bias", [4]),
]
graph = helper.make_graph(
[add_node, layer_norm],
"IncompatibleShapesModel",
[
helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, 3, 4]),
helper.make_tensor_value_info("input_2", TensorProto.FLOAT, [1, 1, 4]), # seq_len mismatch
],
[helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 4])],
initializers,
)
onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16))
model = helper.make_model(graph, opset_imports=(onnx_opset,))
model_name = "skip_layer_norm_incompatible.onnx"
onnx.save(model, model_name)

options = FusionOptions("bert")
optimized_model = optimize_model(model_name, optimization_options=options, opt_level=0)
self.assertEqual(len(optimized_model.get_nodes_by_op_type("SkipLayerNormalization")), 0)
self.assertEqual(len(optimized_model.get_nodes_by_op_type("Add")), 1)
os.remove(model_name)

def test_skip_simplified_layer_norm_broadcast(self):
"""SimplifiedLayerNormalization (RMS LayerNorm) with broadcast skip should fuse."""
model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1, simplified=True)
model_name = "skip_simplified_layer_norm_broadcast.onnx"
onnx.save(model, model_name)

options = FusionOptions("bert")
optimized_model = optimize_model(model_name, optimization_options=options, opt_level=0)

sln_nodes = optimized_model.get_nodes_by_op_type("SkipSimplifiedLayerNormalization")
self.assertEqual(len(sln_nodes), 1)
self.assertEqual(len(optimized_model.get_nodes_by_op_type("Add")), 0)
self.assertEqual(len(optimized_model.get_nodes_by_op_type("SimplifiedLayerNormalization")), 0)
# SimplifiedLayerNorm has no bias, so only 3 inputs: input, skip, weight
self.assertEqual(list(sln_nodes[0].input), ["input_1", "input_2", "layer_norm_weight"])
os.remove(model_name)


if __name__ == "__main__":
unittest.main()