Fix fusion for two LayerNorm sharing same input but with different weights (#15919)

In gpt_j_residual (https://arxiv.org/pdf/2204.06745.pdf), two LayerNorm (LN) nodes share the same input. ORT runs CSE graph optimization before LN fusion, which modifies the LN graph pattern and causes LN fusion to fail.


![image](https://github.com/microsoft/onnxruntime/assets/10530022/40990fd6-796f-4edf-be0b-3203e8503678)
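
For context, the fusion is expected to collapse each decomposed subgraph into a single LayerNormalization node, so the final graph contains two LayerNormalization nodes that read the same input but carry their own scale and bias. The snippet below is a minimal sketch of that expected shape (illustrative names, standard opset-17 LayerNormalization op); it is not code from this commit.

# Sketch of the expected post-fusion graph: two LayerNormalization nodes
# sharing input "A" but with distinct gamma/beta initializers.
# Illustrative only -- not part of this commit.
import onnx
from onnx import TensorProto, helper

ln_nodes = [
    helper.make_node("LayerNormalization", ["A", "LN1/gamma", "LN1/beta"], ["LN1/C"], "LN1", axis=-1, epsilon=1e-5),
    helper.make_node("LayerNormalization", ["A", "LN2/gamma", "LN2/beta"], ["LN2/C"], "LN2", axis=-1, epsilon=1e-5),
]
inits = [
    helper.make_tensor("LN1/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
    helper.make_tensor("LN1/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
    helper.make_tensor("LN2/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
    helper.make_tensor("LN2/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
]
graph = helper.make_graph(
    ln_nodes,
    "ExpectedFusedGraph",
    [helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4])],
    [
        helper.make_tensor_value_info("LN1/C", TensorProto.FLOAT, [16, 32, 4]),
        helper.make_tensor_value_info("LN2/C", TensorProto.FLOAT, [16, 32, 4]),
    ],
    inits,
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
onnx.checker.check_model(model)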
zhijxu-MS authored May 22, 2023
1 parent 5607a71 commit 4dc4470
Showing 4 changed files with 88 additions and 1 deletion.
15 changes: 15 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
@@ -72,6 +72,21 @@ TEST_F(GraphTransformationTests, LayerNormFusionTest) {
}
}

TEST_F(GraphTransformationTests, TwoLayerNormShareSameInput) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_shared_input.onnx";
std::shared_ptr<Model> p_model;
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
Graph& graph = p_model->MainGraph();

onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<LayerNormFusion>(), TransformerLevel::Level2));
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count.size() == 1);
ASSERT_TRUE(op_to_count["LayerNormalization"] == 2);
}

TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_with_cast.onnx";
std::shared_ptr<Model> p_model;
Binary file not shown.
@@ -0,0 +1,71 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import OperatorSetIdProto, TensorProto, helper


# In gpt_j_residual, two LayerNorm subgraphs share the same input.
def GenerateModel(model_name): # noqa: N802
    nodes = [
        # LN1 subgraph
        helper.make_node("ReduceMean", ["A"], ["LN1/rd1_out"], "LN1/reduce", axes=[-1]),
        helper.make_node("Sub", ["A", "LN1/rd1_out"], ["LN1/sub1_out"], "LN1/sub"),
        helper.make_node("Pow", ["LN1/sub1_out", "LN1/pow_in_2"], ["LN1/pow_out"], "LN1/pow"),
        helper.make_node("ReduceMean", ["LN1/pow_out"], ["LN1/rd2_out"], "LN1/reduce2", axes=[-1]),
        helper.make_node("Add", ["LN1/rd2_out", "LN1/const_0"], ["LN1/add1_out"], "LN1/add"),
        helper.make_node("Sqrt", ["LN1/add1_out"], ["LN1/sqrt_out"], "LN1/sqrt"),
        helper.make_node("Div", ["LN1/sub1_out", "LN1/sqrt_out"], ["LN1/div_out"], "LN1/div"),
        helper.make_node("Mul", ["LN1/gamma", "LN1/div_out"], ["LN1/mul_out"], "LN1/mul"),
        helper.make_node("Add", ["LN1/beta", "LN1/mul_out"], ["LN1/C"], "LN1/add2"),
        # LN2 subgraph
        helper.make_node("ReduceMean", ["A"], ["LN2/rd1_out"], "LN2/reduce", axes=[-1]),
        helper.make_node("Sub", ["A", "LN2/rd1_out"], ["LN2/sub1_out"], "LN2/sub"),
        helper.make_node("Pow", ["LN2/sub1_out", "LN2/pow_in_2"], ["LN2/pow_out"], "LN2/pow"),
        helper.make_node("ReduceMean", ["LN2/pow_out"], ["LN2/rd2_out"], "LN2/reduce2", axes=[-1]),
        helper.make_node("Add", ["LN2/rd2_out", "LN2/const_0"], ["LN2/add1_out"], "LN2/add"),
        helper.make_node("Sqrt", ["LN2/add1_out"], ["LN2/sqrt_out"], "LN2/sqrt"),
        helper.make_node("Div", ["LN2/sub1_out", "LN2/sqrt_out"], ["LN2/div_out"], "LN2/div"),
        helper.make_node("Mul", ["LN2/gamma", "LN2/div_out"], ["LN2/mul_out"], "LN2/mul"),
        helper.make_node("Add", ["LN2/beta", "LN2/mul_out"], ["LN2/C"], "LN2/add2"),
    ]

    initializers = [
        # LN1 initializers
        helper.make_tensor("LN1/pow_in_2", TensorProto.FLOAT, [], [2]),
        helper.make_tensor("LN1/const_0", TensorProto.FLOAT, [], [0]),
        helper.make_tensor("LN1/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
        helper.make_tensor("LN1/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
        # LN2 initializers
        helper.make_tensor("LN2/pow_in_2", TensorProto.FLOAT, [], [2]),
        helper.make_tensor("LN2/const_0", TensorProto.FLOAT, [], [0]),
        helper.make_tensor("LN2/gamma", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
        helper.make_tensor("LN2/beta", TensorProto.FLOAT, [4], [1, 2, 3, 4]),
    ]

    graph = helper.make_graph(
        nodes,
        "2LayerNormShareSameInput",  # name
        [  # inputs
            helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]),
        ],
        [  # outputs
            helper.make_tensor_value_info("LN1/C", TensorProto.FLOAT, [16, 32, 4]),
            helper.make_tensor_value_info("LN2/C", TensorProto.FLOAT, [16, 32, 4]),
        ],
        initializers,
    )

    onnxdomain = OperatorSetIdProto()
    onnxdomain.version = 12
    # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification.
    onnxdomain.domain = ""
    msdomain = OperatorSetIdProto()
    msdomain.version = 1
    msdomain.domain = "com.microsoft"
    opsets = [onnxdomain, msdomain]

    model = helper.make_model(graph, opset_imports=opsets)
    onnx.save(model, model_name)


GenerateModel("layer_norm_shared_input.onnx")
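
If you want to check the fusion outside the C++ test, one rough way is to run the generated model through ONNX Runtime's offline optimizer and count the resulting ops. This is a sketch, not part of the commit; it assumes LayerNorm fusion is applied at the extended optimization level for the CPU execution provider, and the file names are only illustrative.

# Rough verification sketch (not part of the commit): optimize the generated
# model with ONNX Runtime and count LayerNormalization nodes afterwards.
import onnx
import onnxruntime as ort

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
so.optimized_model_filepath = "layer_norm_shared_input_optimized.onnx"
ort.InferenceSession("layer_norm_shared_input.onnx", so, providers=["CPUExecutionProvider"])

optimized = onnx.load("layer_norm_shared_input_optimized.onnx")
op_counts = {}
for node in optimized.graph.node:
    op_counts[node.op_type] = op_counts.get(node.op_type, 0) + 1
print(op_counts)  # expect two LayerNormalization nodes if the fusion fired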
@@ -109,6 +109,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
// CSE. For example, if A and B nodes both do Add operation with a same value but different initializers, by
// default, CSE will not merge them, because the different initializers are represented by different NodeArg.
transformers.emplace_back(std::make_unique<ConstantSharing>(compatible_eps));
// LayerNormFusion must run before CommonSubexpressionElimination, as the latter breaks the fusion pattern when two LayerNorm subgraphs share the same input.
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
// Remove duplicate nodes. Must be applied before any recompute transformations.
if (config.gelu_recompute || config.attn_dropout_recompute || config.transformer_layer_recompute) {
transformers.emplace_back(std::make_unique<CommonSubexpressionEliminationApplyOnce>(compatible_eps));
@@ -117,7 +119,6 @@
}

transformers.emplace_back(std::make_unique<GeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
#if defined(USE_CUDA) || defined(USE_ROCM)
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps,
true /* skip_device_check*/));
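
For context on the reordering above: once CommonSubexpressionElimination runs first, the duplicated normalization chains of the two LayerNorm subgraphs are merged, so (presumably) the shared Div output feeds both Mul/Add tails and the per-LayerNorm fusion pattern no longer matches. Below is a rough Python sketch of that post-CSE shape, using illustrative names rather than anything from the commit.

# Sketch (assumption, not from the commit) of the graph after CSE has merged
# the duplicated LN1/LN2 subexpressions: a single normalization chain whose
# Div output fans out to two Mul/Add tails.
from onnx import helper

shared_chain = [
    helper.make_node("ReduceMean", ["A"], ["rd1"], "reduce", axes=[-1]),
    helper.make_node("Sub", ["A", "rd1"], ["sub1"], "sub"),
    helper.make_node("Pow", ["sub1", "pow_2"], ["pow1"], "pow"),
    helper.make_node("ReduceMean", ["pow1"], ["rd2"], "reduce2", axes=[-1]),
    helper.make_node("Add", ["rd2", "eps"], ["add1"], "add_eps"),
    helper.make_node("Sqrt", ["add1"], ["sqrt1"], "sqrt"),
    helper.make_node("Div", ["sub1", "sqrt1"], ["div1"], "div"),
]
tails = [
    # Both tails consume "div1" -- the shared Div now has two consumers.
    helper.make_node("Mul", ["LN1/gamma", "div1"], ["LN1/mul_out"], "LN1/mul"),
    helper.make_node("Add", ["LN1/beta", "LN1/mul_out"], ["LN1/C"], "LN1/add2"),
    helper.make_node("Mul", ["LN2/gamma", "div1"], ["LN2/mul_out"], "LN2/mul"),
    helper.make_node("Add", ["LN2/beta", "LN2/mul_out"], ["LN2/C"], "LN2/add2"),
]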