Skip to content

Commit 1985c01

Browse files
authored
[Relay][Layout] Add FInferCorrectLayout for L2 norm layout transform. (#12497)
* [Relay][Layout] FInferCorrectLayout for L2 norm layout change. * [Relay][Layout] Test for L2 norm layout transform. * [Relay][Layout] Re-edit test to add multi-dimensional axis list. * Fix cpplint errors * Use clang-format-10 rules. * replace uint with size_t.
1 parent c83ee08 commit 1985c01

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

src/relay/op/nn/nn.cc

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,42 @@ Expr MakeL2Normalize(Expr data, double eps, Array<Integer> axis) {
653653
return Call(op, {data}, Attrs(attrs), {});
654654
}
655655

656+
InferCorrectLayoutOutput L2NormalizeInferCorrectLayout(
657+
const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
658+
const Array<tvm::relay::Type>& old_in_types) {
659+
const auto* attrs_ptr = attrs.as<L2NormalizeAttrs>();
660+
ICHECK(attrs_ptr);
661+
ObjectPtr<L2NormalizeAttrs> param = make_object<L2NormalizeAttrs>(*attrs_ptr);
662+
663+
Array<Array<IndexExpr>> old_in_shapes;
664+
for (auto old_in_t : old_in_types) {
665+
ICHECK(old_in_t.as<TensorTypeNode>());
666+
old_in_shapes.push_back(old_in_t.as<TensorTypeNode>()->shape);
667+
}
668+
std::vector<size_t> axis_list;
669+
for (auto i : param->axis) {
670+
int64_t axis = i->value;
671+
if (axis < 0) {
672+
axis = axis + static_cast<size_t>(old_in_shapes[0].size());
673+
}
674+
axis_list.emplace_back(axis);
675+
}
676+
677+
Layout ret = Layout::Undef();
678+
if (new_in_layouts.defined() && old_in_layouts.defined()) {
679+
for (size_t i = 0; i < axis_list.size(); ++i) {
680+
const auto& axis_dim = old_in_layouts[0][axis_list[i]];
681+
auto axis_index = new_in_layouts[0].IndexOf(axis_dim);
682+
param->axis.Set(i, axis_index);
683+
}
684+
ret = new_in_layouts[0];
685+
} else if (old_in_layouts.defined()) {
686+
ret = old_in_layouts[0];
687+
}
688+
689+
return InferCorrectLayoutOutput({ret}, {ret}, Attrs(param));
690+
}
691+
656692
TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize").set_body_typed(MakeL2Normalize);
657693

658694
RELAY_REGISTER_OP("nn.l2_normalize")
@@ -669,7 +705,7 @@ Normalizes along dimension axis using an L2 norm
669705
.set_num_inputs(1)
670706
.add_argument("data", "Tensor", "The input tensor.")
671707
.set_support_level(2)
672-
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
708+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", L2NormalizeInferCorrectLayout)
673709
.add_type_rel("Identity", IdentityRel);
674710

675711
// Dropout

tests/python/relay/test_pass_convert_op_layout.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2697,5 +2697,44 @@ def expected():
26972697
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n Expect = \n" + str(b)
26982698

26992699

2700+
def test_conv_l2n_convert_layout():
2701+
"""Check that layout transforms are propagated through bn."""
2702+
axis_list = ([3], [-1], [2, 3])
2703+
expected_axis = ([1], [1], [3, 1])
2704+
for i, axis in enumerate(axis_list):
2705+
2706+
def before():
2707+
x = relay.var("x", shape=(1, 56, 56, 64))
2708+
weight = relay.var("weight", shape=(3, 3, 64, 64))
2709+
y = relay.nn.conv2d(
2710+
x,
2711+
weight,
2712+
channels=64,
2713+
kernel_size=(3, 3),
2714+
padding=(1, 1),
2715+
data_layout="NHWC",
2716+
kernel_layout="HWIO",
2717+
)
2718+
z = relay.nn.l2_normalize(y, eps=0.001, axis=axis)
2719+
z = relay.Function(analysis.free_vars(z), z)
2720+
return z
2721+
2722+
def expected():
2723+
x = relay.var("x", shape=(1, 56, 56, 64))
2724+
w = relay.var("weight", shape=(3, 3, 64, 64))
2725+
x = relay.layout_transform(x, "NHWC", "NCHW")
2726+
w = relay.layout_transform(w, "HWIO", "OIHW")
2727+
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
2728+
z = relay.nn.l2_normalize(y, eps=0.001, axis=expected_axis[i])
2729+
z = relay.layout_transform(z, "NCHW", "NHWC")
2730+
z = relay.Function(analysis.free_vars(z), z)
2731+
return z
2732+
2733+
a = before()
2734+
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
2735+
b = run_opt_pass(expected(), transform.InferType())
2736+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
2737+
2738+
27002739
if __name__ == "__main__":
27012740
pytest.main([__file__])

0 commit comments

Comments
 (0)