Skip to content

Commit 196b694

Browse files
authored
[FRONTEND][ONNX] Make bias input optional in LayerNormalization (#17980)
This change updates the LayerNormalization converter to support ONNX models where the optional bias input is not provided. When bias is missing, an empty (zero-valued) bias tensor is generated. This behavior aligns with the ONNX specification for LayerNormalization (opset 17+), where the bias input is officially optional.
1 parent 68dd534 commit 196b694

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2489,6 +2489,10 @@ def _impl_v17(cls, bb, inputs, attr, params):
24892489
axis = attr.get("axis", -1)
24902490
epsilon = attr.get("epsilon", 1e-05)
24912491

2492+
if bias is None:
2493+
seq_len = data.struct_info.shape[1].value
2494+
bias = relax.const([0.0] * seq_len, dtype="float32")
2495+
24922496
output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
24932497
# Onnx layernorm has 3 outputs but only the first is used.
24942498
# We construct two empty constants for this.

tests/python/relax/test_frontend_onnx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,24 @@ def test_layer_norm():
13031303
model = helper.make_model(graph, producer_name="layer_norm_test")
13041304
check_correctness(model)
13051305

1306+
# Test case with no bias that is an optional input
1307+
layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12)
1308+
1309+
graph = helper.make_graph(
1310+
[layer_norm_node],
1311+
"layer_norm_test",
1312+
inputs=[
1313+
helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
1314+
helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
1315+
],
1316+
outputs=[
1317+
helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
1318+
],
1319+
)
1320+
1321+
model = helper.make_model(graph, producer_name="layer_norm_test")
1322+
check_correctness(model)
1323+
13061324

13071325
# TODO Enable dynamism
13081326
@pytest.mark.parametrize("dynamic", [False])

0 commit comments

Comments
 (0)