Skip to content

Commit fbd9fe4

Browse files
authored
Fix division truncation in window size calculation for small dtypes in average_pool (#18014)
* Update pooling.h * Update test_te_create_primfunc.py * fix lint error
1 parent 2d964b4 commit fbd9fe4

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

include/tvm/topi/nn/pooling.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
383383

384384
PrimExpr divide_factor = tvm::cast(x->dtype, 1);
385385
for (size_t i = 0; i < n_dim; ++i) {
386-
divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
386+
divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent);
387387
}
388388

389389
return div(pool_sum(indices), divide_factor);

tests/python/te/test_te_create_primfunc.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,14 @@ def te_workload():
882882
_check_workload(te_workload, tir_workload)
883883

884884

885+
def test_global_pool():
886+
# fix the issue-17938
887+
data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data")
888+
op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW")
889+
f = te.create_prim_func([data, op_output])
890+
assert f
891+
892+
885893
def test_nested_reduce_domain_dependency():
886894
@T.prim_func
887895
def tir_workload(

0 commit comments

Comments
 (0)