Skip to content

Commit

Permalink
Make rewrite_once default to false
Browse files Browse the repository at this point in the history
Change-Id: Idf6f01f254c403158883681e75c2a5978efbd2d0
  • Loading branch information
ekalda committed Aug 26, 2021
1 parent b10297a commit 5ad8fe3
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 71 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class DFPatternCallbackNode : public Object {
class DFPatternCallback : public ObjectRef {
public:
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type,
bool rewrite_once);
bool rewrite_once = false);
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/fold_explicit_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ class SimplifyExplicitPadding {
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = pattern.callback(pre, post, node_map);
};
callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true, false));
callbacks_.push_back(DFPatternCallback(pattern.pattern(), PackedFunc(func), true));
}

Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); }
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/simplify_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,14 @@ class DFPatternRewrite {
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = this->Callback(pre, post, node_map);
};
return DFPatternCallback(pattern_, PackedFunc(func), require_type_, rewrite_once_);
return DFPatternCallback(pattern_, PackedFunc(func), require_type_);
}

protected:
/*! \brief The pattern for matching and rewriting. */
DFPattern pattern_;
/*! \brief Whether or not the rewrite requires types to be inferred. */
bool require_type_ = true;
/*! \brief If True, rewrite only once. */
bool rewrite_once_ = false;
};

/*! \brief Helper class for composing rewrites and getting callbacks. */
Expand Down
67 changes: 1 addition & 66 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,69 +1760,4 @@ def callback(self, pre, post, node_map):


if __name__ == "__main__":
test_expr_pattern()
test_var_pattern()
test_constant_pattern()
test_wildcard_pattern()
test_CallPattern()
test_TuplePattern()
test_TupleGetItemPattern()
test_AltPattern()
test_TypePattern()
test_DataTypePattern()
test_ShapePattern()
test_AttrPattern()
test_match_op()
test_no_match_op()
test_match_op_or()
test_match_call_commutive()
test_no_match_call_commutive()
test_match_call()
test_no_match_call()
test_match_option()
test_no_match_option()
test_match_const()
test_match_tuple()
test_no_match_tuple()
test_match_type()
test_no_match_type()
test_match_dtype()
test_no_match_dtype()
test_match_shape()
test_no_match_shape()
test_match_op_attr()
test_no_match_op_attr()
test_match_func_attr()
test_no_match_func_attr()
test_match_call_attr()
test_no_match_call_attr()
test_match_diamond()
test_no_match_diamond()
test_match_fake_diamond()
test_match_dominator()
test_not_match_dominator()
test_rewrite()
test_rewrite_func()
test_nested_rewrite()
test_not_fuse_multi_diamond()
test_fuse_batchnorm()
test_no_fuse_batchnorm()
test_fuse_double_batchnorm()
test_partial_fuse_double_batchnorm()
test_fuse_batchnorm_commutation()
test_quadruple_rewrite_dominator()
test_algebraic_simplify()
test_double_partition()
test_partition_dominator()
test_quadruple_partition_dominator()
test_partition_batchnorm()
test_partition_double_batchnorm()
test_partition_check()
test_partition_check_types()
test_partition_option()
test_match_match()
test_partition_constant_embedding()
test_IfPattern()
test_match_if()
test_no_match_if()
test_rewrite_once()
pytest.main([__file__])

0 comments on commit 5ad8fe3

Please sign in to comment.