From 6d727b6d2c5fb9315bb4f2668017592211c6efa3 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Sun, 17 Nov 2019 08:12:36 +0000
Subject: [PATCH] [Relay tests] AlterOpLayout - Temporary attr update

---
 include/tvm/relay/op.h                        |  6 ++
 python/tvm/relay/op/op.py                     | 10 +++
 src/relay/ir/op.cc                            | 20 ++++
 .../python/relay/test_pass_alter_op_layout.py | 92 ++++++++++++-------
 4 files changed, 96 insertions(+), 32 deletions(-)

diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h
index 7d2a1f653a932..7f1ef456b59b0 100644
--- a/include/tvm/relay/op.h
+++ b/include/tvm/relay/op.h
@@ -258,6 +258,12 @@ class OpRegistry {
   inline OpRegistry& set_attr(const std::string& attr_name,  // NOLINT(*)
                               const ValueType& value, int plevel = 10);
 
+  /*!
+   * \brief Resets an attr of the registry.
+   * \param attr_name The name of the attribute.
+   */
+  inline void reset_attr(const std::string& attr_name);
+
   // set the name of the op to be the same as registry
   inline OpRegistry& set_name() {  // NOLINT(*)
     if (get()->name.length() == 0) {
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index fcbc3fd544479..00a2aa0d8bc13 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -64,6 +64,16 @@ def set_attr(self, attr_name, value, plevel=10):
         """
         _OpSetAttr(self, attr_name, value, plevel)
 
+    def reset_attr(self, attr_name):
+        """Reset an attribute of the operator.
+
+        Parameters
+        ----------
+        attr_name : str
+            The attribute name
+        """
+        _OpResetAttr(self, attr_name)
+
 def get(op_name):
     """Get the Op for a given name
 
diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc
index c4557ac16ad5d..53902a28206a4 100644
--- a/src/relay/ir/op.cc
+++ b/src/relay/ir/op.cc
@@ -95,6 +95,17 @@ const bool Op::HasGenericAttr(const std::string& key) {
   return true;
 }
 
+// Resets attr of the OpMap.
+void OpRegistry::reset_attr(const std::string& key) {
+  OpManager* mgr = OpManager::Global();
+  std::lock_guard<std::mutex> lock(mgr->mutex);
+  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
+  if (op_map == nullptr) {
+    return;
+  }
+  op_map->data_.clear();
+}
+
 void OpRegistry::UpdateAttr(const std::string& key,
                             TVMRetValue value,
                             int plevel) {
@@ -152,6 +163,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
       reg.set_attr(attr_name, value, plevel);
   });
 
+TVM_REGISTER_API("relay.op._OpResetAttr")
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    Op op = args[0];
+    std::string attr_name = args[1];
+    auto& reg =
+        OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();
+    reg.reset_attr(attr_name);
+  });
+
 TVM_REGISTER_API("relay.op._Register")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     std::string op_name = args[0];
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 2738690025df6..3fe001ef8f27b 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -31,6 +31,17 @@ def run_opt_pass(expr, passes):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+def reset_retrieve_alter_op_layout(op_name, alter_layout):
+    op = relay.op.get(op_name)
+    older_attr = op.get_attr("FTVMAlterOpLayout")
+    op.reset_attr("FTVMAlterOpLayout")
+    register_alter_op_layout(op_name, alter_layout)
+    return older_attr
+
+def recover_attr(op_name, older_attr):
+    op = relay.op.get(op_name)
+    op.reset_attr("FTVMAlterOpLayout")
+    op.set_attr("FTVMAlterOpLayout", older_attr)
 
 def test_alter_op():
     """Test directly replacing an operator with a new one"""
@@ -45,13 +56,13 @@ def before():
         y = relay.Function([x, weight], y)
         return y
 
-    # Register alter op layout. "level" is used to override the previously registered functions.
-    @register_alter_op_layout("nn.conv2d", level=100)
     def alter_conv2d(attrs, inputs, tinfos):
         data, weight = inputs
         weight = relay.multiply(weight, relay.const(2.0, "float32"))
        return relay.nn.conv2d(data, weight, **attrs)
 
+    older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d)
+
     def expected():
         x = relay.var("x", shape=(1, 64, 56, 56))
         weight = relay.var('weight', shape=(64, 64, 3, 3))
@@ -69,6 +80,8 @@ def expected():
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
 
+    recover_attr("nn.conv2d", older_attr)
+
 
 def test_alter_return_none():
     """Test doing nothing by returning 'None' """
@@ -80,12 +93,12 @@ def before():
 
     called = [False]
 
-    # Register alter op layout. "level" is used to override the previously registered functions.
-    @register_alter_op_layout("nn.global_max_pool2d", level=101)
     def alter_conv2d(attrs, inputs, tinfos):
         called[0] = True
         return None
 
+    older_attr = reset_retrieve_alter_op_layout("nn.global_max_pool2d", alter_conv2d)
+
     a = before()
     a = run_opt_pass(a, transform.AlterOpLayout())
 
@@ -94,6 +107,8 @@ def alter_conv2d(attrs, inputs, tinfos):
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
     assert(called[0])
 
+    recover_attr("nn.global_max_pool2d", older_attr)
+
 
 def test_alter_layout():
     """Test alternating the layout of a conv2d.
@@ -114,8 +129,6 @@ def before():
         y = relay.Function(analysis.free_vars(y), y)
         return y
 
-    # Register alter op layout. "level" is used to override the previously registered functions.
- @register_alter_op_layout("nn.conv2d", level=102) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) @@ -123,6 +136,8 @@ def alter_conv2d(attrs, inputs, tinfos): new_attrs['kernel_layout'] = 'OIHW16i' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) bias = relay.var("bias", shape=(64,)) @@ -157,6 +172,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_dual_path(): @@ -183,14 +199,14 @@ def before(): y = relay.Function(analysis.free_vars(ret), ret) return y - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=103) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) weight1 = relay.var('weight1') @@ -222,6 +238,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_resnet(): """Test alternating the layout of a residual block @@ -245,14 +262,14 @@ def before(): y = relay.nn.global_max_pool2d(y) return relay.Function(analysis.free_vars(y), y) - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=104) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) weight1 = relay.var('weight1') @@ -281,6 +298,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_broadcast_op(): @@ -296,14 +314,14 @@ def before(): y = relay.Function(analysis.free_vars(y), y) return y - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=105) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) bias = relay.var("bias", shape=(64,)) @@ -331,6 +349,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_scalar(): """Test alternating the layout of a conv2d. @@ -344,14 +363,14 @@ def before(): y = relay.Function(analysis.free_vars(y), y) return y - # Register alter op layout. "level" is used to override the previously registered functions. 
- @register_alter_op_layout("nn.conv2d", level=106) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) w = relay.var("weight") @@ -376,18 +395,19 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_concatenate(): """ NCHW, NHWC and corner case concatenate layout transform.""" - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=107) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + # NCHW layout transformation. def before_nchw(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -479,6 +499,7 @@ def expected_nhwc(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_nchw_upsamping_op(): @@ -492,14 +513,14 @@ def before(): y = relay.Function(analysis.free_vars(y), y) return y - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=108) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var("weight") @@ -520,6 +541,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_strided_slice(): @@ -532,14 +554,14 @@ def before(): y = relay.Function(analysis.free_vars(y), y) return y - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=109) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW4c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 32, 28, 28)) weight = relay.var("weight") @@ -559,6 +581,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_depthwise_conv2d(): """Test depthwise_conv2d operator""" @@ -570,12 +593,12 @@ def before(): return y import topi - # Register alter op layout. "level" is used to override the previously registered functions. 
- @register_alter_op_layout("nn.conv2d", level=110) def alter_conv2d(attrs, inputs, tinfos): with tvm.target.create("llvm"): return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, relay) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 32, 56, 56)) w = relay.var("w", shape=(32, 1, 3, 3)) @@ -596,6 +619,7 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert(analysis.alpha_equal(a, b)) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_prelu(): """Test PRelu operator""" @@ -608,14 +632,14 @@ def before(): y = relay.Function(analysis.free_vars(y), y) return y - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=111) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + def expected(): x = relay.var("x", shape=(1, 64, 56, 56)) w = relay.var("weight") @@ -639,18 +663,19 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert(analysis.alpha_equal(a, b)) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_pad(): """ Check NCHW, NHWC and corner case for pad layout conversion""" - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=112) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + # Check NCHW conversion. def before_nchw(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -753,18 +778,19 @@ def expected(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_pool(): """ Check NCHW, NHWC pool layout conversion""" - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=113) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + # Check NCHW conversion. def before_nchw(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -833,18 +859,19 @@ def expected_nhwc(): b = run_opt_pass(b, transform.InferType()) assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + recover_attr("nn.conv2d", older_attr) def test_alter_layout_sum(): """ Check NCHW, NHWC sum layout conversion""" - # Register alter op layout. "level" is used to override the previously registered functions. - @register_alter_op_layout("nn.conv2d", level=114) def alter_conv2d(attrs, inputs, tinfos): data, weight = inputs new_attrs = dict(attrs) new_attrs['data_layout'] = 'NCHW16c' return relay.nn.conv2d(data, weight, **new_attrs) + older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d) + # Check NCHW conversion. 
     def before_nchw():
         x = relay.var("x", shape=(1, 64, 56, 56))
@@ -914,16 +941,17 @@ def expected_nhwc():
     b = run_opt_pass(b, transform.InferType())
 
     assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+    recover_attr("nn.conv2d", older_attr)
 
 
 def test_alter_layout_nhwc_nchw_arm():
     """ Check NHWC to NHCW conversion for a small sequence of ops."""
-    # Register alter op layout. "level" is used to override the previously registered functions.
-    @register_alter_op_layout("nn.conv2d", level=115)
     def alter_conv2d(attrs, inputs, tinfos):
         from topi.arm_cpu.conv2d import _alter_conv2d_layout_arm
         return _alter_conv2d_layout_arm(attrs, inputs, tinfos, tvm.relay)
 
+    older_attr = reset_retrieve_alter_op_layout("nn.conv2d", alter_conv2d)
+
     # Check NHWC conversion.
     def before_nhwc():
         x = relay.var("x", shape=(1, 56, 56, 64))
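
Usage note (reviewer sketch, not part of the patch): every test above now follows the same reset / register / recover pattern around a test-local alter function. A minimal sketch of that pattern, assuming the helpers introduced in this patch (reset_retrieve_alter_op_layout, recover_attr) and the test file's existing before/expected/run_opt_pass helpers; the alter function and test name here are hypothetical:

    # Hypothetical alter function, mirroring the ones in the tests above.
    def my_alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    def test_with_temporary_alter():
        # Save whatever FTVMAlterOpLayout was registered for nn.conv2d
        # and install the test-local alter function in its place.
        older_attr = reset_retrieve_alter_op_layout("nn.conv2d", my_alter_conv2d)
        try:
            a = run_opt_pass(before(), transform.AlterOpLayout())
            b = run_opt_pass(expected(), transform.InferType())
            assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
        finally:
            # Restore the original registration even if the assertion fails.
            recover_attr("nn.conv2d", older_attr)

One caveat: the tests in this patch call recover_attr as the last statement of each test, so a failing assertion skips the recovery and can leak the test-local alter function into later tests; the try/finally above is one way to make the restoration unconditional.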