-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ARITH] Improve div/mod in rewrite simplifier #3149
Conversation
Thanks for the PR. One thing I want to note is that we want to focus on the good use-cases in index simplification.
For example, it is a bit unclear whether taking the gcd out from To elaborate about the second principle, given that canonical simplifier now inheritates rewrite simplifier, we might be able to use canonical for many cases that rewrite simplifier is not good at. It would be great if we can provide some justifications for the set of changes. |
@tqchen I'm currently trying to move my implementation of tensor expression autodiff to the new simplifiers instead of the obsolete Halide simplifier. Most of these changes are required for simplifying complex automatically generated expressions involving divisions. Most of the new rules are just slightly modified copies of the existing rules, just to cover more cases.
It is needed to simplify expressions like
I can remove the corresponding test cases and the rules specifically designed to cover the negative divisor case, since they are not really required by the autodiff. However I would keep them just in case.
The rewrite simplifier has a clear design, so I would prefer fixing the rewrite simplifier rather than the canonical simplifier. Currently there are CI errors which seem to be connected to the fact that the old Halide notion of division is still in use in some parts of TVM. I didn't expect it and have to think what to do with them. |
I see, thanks for sharing your thoughts! I agree with most of your comment. I do hope that we can move forward to have a combined use case of canonical/rewrite simplifications. Given that canonical can be much more global and will be our default simplifier in the future. Specifically, the canonical simplifier can aggressively simplify complicated patterns of div/mod which are otherwise hard for rewrite simplifier |
@sgrechanik-h after some thoughts i think we can bring these changes in as long as they are properly reviewed. Can you tag more reviewers and address the CI issue? Thanks! |
@tqchen Sure. I want to run some additional tests though, so I will do it around monday. |
ea6ba99
to
52cfa09
Compare
The tests were failing because I enabled const folding when dividing negative numbers. Some parts of tvm (like the bound deducer) still assume that the division is euclidean, and since they also use const folding, enabling it for negative numbers (using truncated division) breaks them. Such parts should be rewritten or removed in the future, because currently they may lead to incorrect code. For now I just added additional const folding rule in the rewrite simplifier (which is not used in legacy code, hopefully). @derisavi @merrymercy @kazum @wweic Please review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm.
CanProveGreaterEqual(x.Eval(), 0) && c1.Eval()->value > 0); | ||
c1.Eval()->value != 0); | ||
TVM_TRY_REWRITE_IF(x - ((x + c2) / c1) * c1, (x + c2) % c1 - c2, | ||
c1.Eval()->value != 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we replace c2
by y
(a non-const), would the rule still be correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be. I'll fix it and add more tests.
src/arithmetic/rewrite_simplify.cc
Outdated
auto pgcd = PConstWithTypeLike<PVar<Expr>>(x, gcd); | ||
return ((x * b1 + y * b2) * pgcd).Eval(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have the same question as Tianqi. I read your answer as follows:
It is needed to simplify expressions like 2 * x + 4 * y <= 0. I think I can move it to the LT case.
But still didn't quite understand what you mean. Do you mean that from 2x + 4y <= 0, we can derive x + 2*y <= 0? Can you elaborate please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean that from 2x + 4y <= 0, we can derive x + 2*y <= 0?
Yes, exactly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What I don't understand is how dealing with x+2y <= 0 is any better/easier/simpler than 2x+4y <= 0?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I decided to investigate, and turned out I was wrong, and the rule is beneficial for a different reason, namely it transforms x*3 - (x/2)*6
into (x - (x/2)*2)*3
which is then transformed into (x % 2) * 3
by another rule (the resulting expression is easier for bound analysis).
I think, It is possible to replace this rule with a bunch of more specialized rules, however I would prefer a more general rule.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense to do the simplification if c1 (resp. c2) is divisible by c2 (resp. c1)?
Even if it does, I still can't convince myself that xc1 + yc2 -> (x+yc3)c1 (where c3 = c2/c1) is always a beneficial transformation. For example, consider (6x + 3y) - (6y + 2z). If we simplify each parenthesized expression first and get 3*(2x+y) - 2(3*y+z), we may not be able to simplify further. I just see the problem but not a solution in my suggestion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(6x + 3y) - (6y + 2z)
cannot be simplified by the rewrite simplifier anyway. And the canonical simplifier is able to simplify both it and 3*(2x+y) - 2(3*y+z)
. However, this is indeed controversial, so I'll disable the transformation, and may be add several special-case rules instead.
I replaced the GCD rule with several more specialized rules. Please rereview. |
@@ -99,6 +93,12 @@ TryCompare(const Expr& x, int64_t val) { | |||
if (dbound->max_value <= val) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for enhancement , do we need to change dbound->max_value <= val into dbound->max_value == val? similar existing issue in line 90.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is really an enhancement, current formulation seems more straightforward, e.g. if max_value were an Expr, we could write can_prove(max_value <= val)
whereas can_prove(max_value == val)
would be "less complete".
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, ((y - x) % c1 - y) * c2, | ||
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The rules are all correct to me. Thanks.
What happens if we have something like c2 * x - c3 * (x / c1)
( I used commutativity) ? Does the simplifier automatically use TVM_TRY_RECURSIVE_REWRITE(c1 * x, x * c1);
in its attempt to simplify it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this particular example works thanks to this canonicalization rule. I've added a couple of tests to be sure.
int64_t c2val = c2.Eval()->value; | ||
return make_const(op->type, c1val / c2val); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2, | ||
c1.Eval()->value > 0 && | ||
TVM_TRY_REWRITE_IF(x * c2 < c1, x < c1 / c2, | ||
c1.Eval()->value <= 0 && | ||
c2.Eval()->value > 0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is true if x
is an integer. Can we assume that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can, since these rules are guarded by the condition IsIndexType(op->a.type())
, and also all constants are declared as integers (PVar<Integer>
).
Other than the comments above, it LGTM. |
c1.Eval()->value != 0 && | ||
c3.Eval()->value == c1.Eval()->value * c2.Eval()->value); | ||
TVM_TRY_REWRITE_IF(((x - y) / c1) * c3 - x * c2, ((y - x) % c1 - y) * c2, | ||
c1.Eval()->value != 0 && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sgrechanik-h please update the comment to mark all the cases that require special truc div rule. So we can be sure about this later
// NOTE: trunc div required
98d0663
to
f14aa9a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @sgrechanik-h.
Thanks, @sgrechanik-h @wweic @derisavi @huajsj . This PR is now merged |
* [ARITH] Improve div/mod in rewrite simplifier * Fix lint error * Fuller file name in src/arithmetic/modular_set.h Co-Authored-By: Wei Chen <[email protected]> * Generalize some rules * Replace gcd factoring with specialized rules * Mark rules that don't work for non-truncated division * More tests
* [ARITH] Improve div/mod in rewrite simplifier * Fix lint error * Fuller file name in src/arithmetic/modular_set.h Co-Authored-By: Wei Chen <[email protected]> * Generalize some rules * Replace gcd factoring with specialized rules * Mark rules that don't work for non-truncated division * More tests
This PR improves division-related rewriting rules.