-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix inconsistent python/cpp API behavior for if_then_else, power #3829
Fix inconsistent python/cpp API behavior for if_then_else, power #3829
Conversation
@kevinthesun @icemelon9 |
@@ -419,7 +419,7 @@ def power(x, y): | |||
z : Expr | |||
The result. | |||
""" | |||
return call_pure_intrin(x.dtype, "pow", x, y) | |||
return _make._OpPow(convert(x), convert(y)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a test case to check the type conversion.
if cond.dtype != "bool": | ||
raise TypeError("The condition's data type has to be bool") | ||
return call_pure_intrin(t.dtype, "tvm_if_then_else", cond, t, f) | ||
return _make._OpIfThenElse(convert(cond), convert(t), convert(f)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
@kevinthesun I've added the checks. |
Also, I find that the behavior of For example: import numpy as np
import tvm
print((tvm.var(dtype='int32') + tvm.var(dtype='float16')).dtype)
# result: float16
print((tvm.var(dtype='int32') + tvm.var(dtype='float32')).dtype)
# result: float32
print((np.array(1).astype('int32') + np.array(1).astype('float32')).dtype)
# result: float64
print((np.array(1).astype('int32') + np.array(1).astype('float16')).dtype)
# result: float64 Seems that numpy promotes the float to the highest precision whenever it meets the integer + float. |
Thanks! |
…che#3829) * fix inconsistent python/cpp APIs for if_then_else * fix error message * fix power consistency * fix * fix bug * add test
…che#3829) * fix inconsistent python/cpp APIs for if_then_else * fix error message * fix power consistency * fix * fix bug * add test
…che#3829) * fix inconsistent python/cpp APIs for if_then_else * fix error message * fix power consistency * fix * fix bug * add test
In the C++ side, the if_then_else will call BinaryOpMatchTypes, which will try to convert the dtypes of the inner expressions to be the same:
https://github.com/dmlc/tvm/blob/d05893b3aaddb098780a2fbaf2d2af81bb50d393/src/lang/expr_operator.cc#L242
However, in the python side, there is no such conversion, which makes the behavior inconsistent.
The same issue occurs for power. Thus, fix by directly registering the APIs inside _make.
To give an example of the inconsistent behavior, consider the following example:
In the original version, the result will be
If we use the C++ API, it will be the following, in which b will be automatically casted to float32.