-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] Simplify casts of constants 0 and 1 #3758
Conversation
include/tvm/expr_operator.h
Outdated
return false; | ||
} | ||
} | ||
|
||
inline bool is_const_int(const Expr& x, int64_t value) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
keep this impl, as the is_const_value involves recursions and could be more complicated than this one.
@@ -1757,6 +1757,20 @@ Mutate_(const Variable* op, const Expr& self) { | |||
return self; | |||
} | |||
|
|||
Expr RewriteSimplifier::Impl:: | |||
Mutate_(const Cast* op, const Expr& self) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us just add cast between (uint/float/int here), ignore vector types for now as they are not as common.
include/tvm/expr_operator.h
Outdated
@@ -551,18 +561,31 @@ inline bool is_negative_const(const Expr& a) { | |||
} | |||
} | |||
|
|||
template <typename ValueType> | |||
inline bool is_const_value(const Expr& e, ValueType value) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us not introduce this for now(see comments in the Mutator).
src/arithmetic/rewrite_simplify.cc
Outdated
} | ||
if (is_const_value(op->value, 1)) { | ||
return make_const(op->type, 1); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you support float(16) --> 16.000, int32(16.000) --> 16 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the implementation via tvm::cast
cases like these should work now.
This reverts commit 7e1b346.
Turned out there is the |
* [ARITH] Simplify casts of constants 0 and 1 * [EXPR] is_const_value to check whether non-ints are consts * Revert "[EXPR] is_const_value to check whether non-ints are consts" This reverts commit 7e1b346. * Use tvm::cast
* [ARITH] Simplify casts of constants 0 and 1 * [EXPR] is_const_value to check whether non-ints are consts * Revert "[EXPR] is_const_value to check whether non-ints are consts" This reverts commit 7e1b346. * Use tvm::cast
* [ARITH] Simplify casts of constants 0 and 1 * [EXPR] is_const_value to check whether non-ints are consts * Revert "[EXPR] is_const_value to check whether non-ints are consts" This reverts commit 7e1b346. * Use tvm::cast
This PR implements simplification of some casts of constants in the rewrite simplifier. The use case I have in mind is conversion from boolean expressions to floats (
float32(i < j)
instead ofselect(i < k, 1.0, 0.0)
). I restricted consts to be just 0 and 1 because I'm not sure about usefulness and safety of other values (0 and 1 should be safe for all types, except maybe for some custom types).This PR also introduces the
is_const_value
functions which is likeis_const_int
, but works for float expressions too. Besides this PR, it is also used in the tensor expression autodiff implementation.