diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 4728caaaf3289..743bf50f6c608 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -13,10 +13,25 @@ logger = getLogger(__name__) +def _is_broadcast_skip(input_shape, skip_shape): + """Check if skip_shape can broadcast to input_shape for SkipLayerNormalization. + + The kernel supports: input 3D (B,S,H) with skip 3D (1,S,H) or skip 2D (S,H). + """ + if len(input_shape) != 3: + return False + if len(skip_shape) == 3: + return skip_shape[0] == 1 and skip_shape[1] == input_shape[1] and skip_shape[2] == input_shape[2] + if len(skip_shape) == 2: + return skip_shape[0] == input_shape[1] and skip_shape[1] == input_shape[2] + return False + + class FusionSkipLayerNormalization(Fusion): """ - Fuse Add + LayerNormalization into one node: SkipLayerNormalization - Note: This fusion does not check the input shape of Add and LayerNormalization. + Fuse Add + LayerNormalization into one node: SkipLayerNormalization. + Supports broadcasting of the skip input: (1, sequence_length, hidden_size) + or (sequence_length, hidden_size) will be broadcast to match the input shape. """ def __init__( @@ -31,9 +46,33 @@ def __init__( # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) if self.shape_infer_helper is None: - # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. + # TODO(tianleiwu): support subgraph in shape inference. logger.warning("symbolic shape inference disabled or failed.") + def get_skip_index(self, add): + """Identify which Add input is the skip tensor (the one that may broadcast). 
+ + Returns (skip_index, broadcast): + skip_index: 0 or 1 (which Add input is skip), -1 if incompatible + broadcast: True if broadcasting is needed + """ + shape_a = self.shape_infer_helper.get_edge_shape(add.input[0]) + shape_b = self.shape_infer_helper.get_edge_shape(add.input[1]) + if shape_a is None or shape_b is None: + return -1, False + + if shape_a == shape_b: + return (1, False) if len(shape_a) == 3 else (-1, False) + + # Check if b is a broadcastable skip for a + if _is_broadcast_skip(shape_a, shape_b): + return 1, True + # Check if a is a broadcastable skip for b + if _is_broadcast_skip(shape_b, shape_a): + return 0, True + + return -1, False + def fuse(self, node, input_name_to_nodes, output_name_to_node): add = self.model.get_parent(node, 0, output_name_to_node) @@ -57,19 +96,15 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # Root Mean Square Layer Normalization simplified = node.op_type == "SimplifiedLayerNormalization" + skip_index = 1 # default: add.input[1] is the skip + _broadcast = False + if hasattr(self, "shape_infer_helper"): if self.shape_infer_helper is not None: - if ( - self.shape_infer_helper.get_edge_shape(add.input[0]) - and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3 - ): - logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0]) - return - - # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) - if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): + skip_index, _broadcast = self.get_skip_index(add) + if skip_index < 0: logger.debug( - "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", + "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not compatible", add.input[0], add.input[1], ) @@ -83,6 +118,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None: return + # When broadcasting is needed, check that neither Add input comes from a Gather + # (embedding lookup). Embedding Add+LayerNorm should be fused by EmbedLayerNormalization + # later in the pipeline, not as SkipLayerNormalization. 
+ if _broadcast: + for i in range(2): + parent = self.model.get_parent(add, i, output_name_to_node) + if parent is not None and parent.op_type == "Gather": + logger.debug( + "skip SkipLayerNormalization broadcast fusion since Add input %d comes from Gather (embedding)", + i, + ) + return + # This means that the residual Add before the LayerNormalization produces an output # that is consumed by some other nodes or graph output other than the LayerNormalization itself # We can still go ahead with the SkipLayerNormalization fusion but we need to @@ -106,10 +154,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([add, node]) + input_index = 1 - skip_index inputs = ( - [add.input[0], add.input[1], node.input[1], node.input[2]] + [add.input[input_index], add.input[skip_index], node.input[1], node.input[2]] if not simplified - else [add.input[0], add.input[1], node.input[1]] + else [add.input[input_index], add.input[skip_index], node.input[1]] ) normalize_node = helper.make_node( self.fused_op_type, diff --git a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py index a55ff5aa91519..454880f5b33b2 100644 --- a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py +++ b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py @@ -78,7 +78,7 @@ def create_test_model( ["output"], "layernorm", axis=-1, - epsion=0.000009999999747378752, + epsilon=0.000009999999747378752, ) initializers = [ # initializers @@ -270,6 +270,195 @@ def test_skip_layer_norm_graph_output_cast_bias2(self): ) os.remove(model_name) + def create_broadcast_test_model( + self, + batch_size: int = 2, + sequence_length: int = 3, + hidden_size: int = 4, + skip_shape: str = "2d", # "2d" for (seq, hidden), "3d_batch1" for (1, seq, hidden) + skip_on_input: int = 1, # Which Add input index gets the skip (smaller) shape + add_graph_output: bool = False, + simplified: bool = False, # Use SimplifiedLayerNormalization (RMS LayerNorm) instead + ): + """Create a test model where one Add input has a broadcast-compatible shape.""" + if skip_shape == "2d": + skip_dims = [sequence_length, hidden_size] + elif skip_shape == "3d_batch1": + skip_dims = [1, sequence_length, hidden_size] + else: + raise ValueError(f"Unknown skip_shape: {skip_shape}") + + full_dims = [batch_size, sequence_length, hidden_size] + + add_before_layer_norm = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm") + + if simplified: + layer_norm = helper.make_node( + "SimplifiedLayerNormalization", + ["layernorm_input", "layer_norm_weight"], + ["output"], + "layernorm", + axis=-1, + epsilon=0.000009999999747378752, + ) + initializers = [float_tensor("layer_norm_weight", [hidden_size])] + else: + layer_norm = helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm", + axis=-1, + epsilon=0.000009999999747378752, + ) + initializers = [ + float_tensor("layer_norm_weight", [hidden_size]), + float_tensor("layer_norm_bias", [hidden_size]), + ] + + input_1_shape = full_dims if skip_on_input != 0 else skip_dims + input_2_shape = skip_dims if skip_on_input != 0 else full_dims + + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, full_dims)] + if add_graph_output: + 
outputs.append(helper.make_tensor_value_info("layernorm_input", TensorProto.FLOAT, full_dims)) + + graph = helper.make_graph( + [add_before_layer_norm, layer_norm], + "SkipLayerNormBroadcastModel", + [ + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, input_1_shape), + helper.make_tensor_value_info("input_2", TensorProto.FLOAT, input_2_shape), + ], + outputs, + initializers, + ) + + onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(onnx_opset,)) + + def test_skip_layer_norm_broadcast_2d_skip(self): + """2D skip (seq, hidden) on input[1] should fuse with input order preserved.""" + model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1) + model_name = "skip_layer_norm_broadcast_2d.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + {"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0}, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_broadcast_2d_skip_swapped(self): + """2D skip (seq, hidden) on input[0] should fuse with inputs swapped.""" + model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=0) + model_name = "skip_layer_norm_broadcast_2d_swapped.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + {"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0}, + ["input_2", "input_1", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_broadcast_3d_batch1(self): + """3D skip (1, seq, hidden) on input[1] should fuse with input order preserved.""" + model = self.create_broadcast_test_model(skip_shape="3d_batch1", skip_on_input=1) + model_name = "skip_layer_norm_broadcast_3d_batch1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + {"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0}, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_broadcast_3d_batch1_swapped(self): + """3D skip (1, seq, hidden) on input[0] should fuse with inputs swapped.""" + model = self.create_broadcast_test_model(skip_shape="3d_batch1", skip_on_input=0) + model_name = "skip_layer_norm_broadcast_3d_batch1_swapped.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + {"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0}, + ["input_2", "input_1", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_broadcast_graph_output(self): + """Broadcast fusion should preserve Add output when it is a graph output.""" + model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1, add_graph_output=True) + model_name = "skip_layer_norm_broadcast_graph_output.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + {"Add": 0, "LayerNormalization": 0, "SkipLayerNormalization": 1, "Cast": 0}, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_broadcast_incompatible_shapes(self): + """Incompatible broadcast shapes should not fuse. 
+ + Uses (2,3,4) + (1,1,4): broadcastable for Add but not supported by SkipLayerNorm + kernel (which requires skip seq_len == input seq_len). + """ + add_node = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm") + layer_norm = helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm", + axis=-1, + epsilon=0.000009999999747378752, + ) + initializers = [ + float_tensor("layer_norm_weight", [4]), + float_tensor("layer_norm_bias", [4]), + ] + graph = helper.make_graph( + [add_node, layer_norm], + "IncompatibleShapesModel", + [ + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, 3, 4]), + helper.make_tensor_value_info("input_2", TensorProto.FLOAT, [1, 1, 4]), # seq_len mismatch + ], + [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 4])], + initializers, + ) + onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + model = helper.make_model(graph, opset_imports=(onnx_opset,)) + model_name = "skip_layer_norm_incompatible.onnx" + onnx.save(model, model_name) + + options = FusionOptions("bert") + optimized_model = optimize_model(model_name, optimization_options=options, opt_level=0) + self.assertEqual(len(optimized_model.get_nodes_by_op_type("SkipLayerNormalization")), 0) + self.assertEqual(len(optimized_model.get_nodes_by_op_type("Add")), 1) + os.remove(model_name) + + def test_skip_simplified_layer_norm_broadcast(self): + """SimplifiedLayerNormalization (RMS LayerNorm) with broadcast skip should fuse.""" + model = self.create_broadcast_test_model(skip_shape="2d", skip_on_input=1, simplified=True) + model_name = "skip_simplified_layer_norm_broadcast.onnx" + onnx.save(model, model_name) + + options = FusionOptions("bert") + optimized_model = optimize_model(model_name, optimization_options=options, opt_level=0) + + sln_nodes = optimized_model.get_nodes_by_op_type("SkipSimplifiedLayerNormalization") + self.assertEqual(len(sln_nodes), 1) + self.assertEqual(len(optimized_model.get_nodes_by_op_type("Add")), 0) + self.assertEqual(len(optimized_model.get_nodes_by_op_type("SimplifiedLayerNormalization")), 0) + # SimplifiedLayerNorm has no bias, so only 3 inputs: input, skip, weight + self.assertEqual(list(sln_nodes[0].input), ["input_1", "input_2", "layer_norm_weight"]) + os.remove(model_name) + if __name__ == "__main__": unittest.main()
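
For reviewers who want to poke at the new broadcast fusion outside the unittest harness, here is a minimal, self-contained sketch (not part of the PR). It rebuilds the 2D-skip pattern from `create_broadcast_test_model` and runs the optimizer on it, using only the helpers the diff itself exercises (`helper.make_node`, `helper.make_graph`, `optimize_model`, `FusionOptions`, `get_nodes_by_op_type`). The file name is arbitrary, and the imports assume the installed `onnxruntime` package layout rather than the repo-local paths the tests use.

```python
import os

import onnx
from onnx import TensorProto, helper

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

batch_size, sequence_length, hidden_size = 2, 3, 4

# Residual Add whose second input is a 2D (sequence_length, hidden_size) skip,
# followed by LayerNormalization -- the pattern this PR teaches the fuser to handle.
add_node = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm")
layer_norm = helper.make_node(
    "LayerNormalization",
    ["layernorm_input", "layer_norm_weight", "layer_norm_bias"],
    ["output"],
    "layernorm",
    axis=-1,
    epsilon=1e-5,
)

graph = helper.make_graph(
    [add_node, layer_norm],
    "BroadcastSkipExample",
    [
        helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]),
        helper.make_tensor_value_info("input_2", TensorProto.FLOAT, [sequence_length, hidden_size]),  # 2D skip
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, sequence_length, hidden_size])],
    [
        helper.make_tensor("layer_norm_weight", TensorProto.FLOAT, [hidden_size], [1.0] * hidden_size),
        helper.make_tensor("layer_norm_bias", TensorProto.FLOAT, [hidden_size], [0.0] * hidden_size),
    ],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("ai.onnx", 16)])
onnx.save(model, "broadcast_skip_example.onnx")

# opt_level=0 keeps ORT's native graph optimizers out of the way so only the
# Python fusion passes run, matching how the new tests invoke the optimizer.
optimized = optimize_model(
    "broadcast_skip_example.onnx", optimization_options=FusionOptions("bert"), opt_level=0
)
fused = optimized.get_nodes_by_op_type("SkipLayerNormalization")
print(len(fused))            # expected: 1
print(list(fused[0].input))  # expected: ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"]
os.remove("broadcast_skip_example.onnx")
```

Swapping the two graph inputs (skip on `input[0]`) should still fuse, with the fused node's first two inputs reordered to `["input_2", "input_1"]`, per `get_skip_index` and the swapped-input tests above.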