Skip to content

Commit b20678b

Browse files
ZihengJiangtqchen
authored andcommitted
[TOPI] Fix declaration for different dtypes (#546)
1 parent b384cd4 commit b20678b

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

python/tvm/expr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from __future__ import absolute_import as _abs
1919
from ._ffi.node import NodeBase, register_node
2020
from . import make as _make
21+
from . import _api_internal
2122

2223
class ExprOp(object):
2324
def __add__(self, other):
@@ -60,7 +61,8 @@ def __mod__(self, other):
6061
return _make.Mod(self, other)
6162

6263
def __neg__(self):
63-
return self.__mul__(-1)
64+
neg_one = _api_internal._const(-1, self.dtype)
65+
return self.__mul__(neg_one)
6466

6567
def __lshift__(self, other):
6668
return _make.Call(self.dtype, "shift_left", [self, other], Call.PureIntrinsic, None, 0)

topi/python/topi/nn/elemwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def relu(x):
1717
y : tvm.Tensor
1818
The result.
1919
"""
20-
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0))
20+
return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), tvm.const(0, x.dtype)))
2121

2222

2323
@tvm.tag_scope(tag=tag.ELEMWISE)

topi/python/topi/nn/pooling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def global_pool(data, pool_type):
3838
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]), \
3939
tag="global_pool_sum")
4040
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
41-
tsum[n, c, h, w] / (height*width), \
41+
tsum[n, c, h, w] / (height*width).astype(tsum.dtype), \
4242
tag=tag.ELEMWISE)
4343
else:
4444
raise ValueError("Pool type should be 'avg' or 'max'.")

0 commit comments

Comments
 (0)