diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index abe26b6c6727..8e13ae49afdf 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -383,7 +383,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array& output_ PrimExpr divide_factor = tvm::cast(x->dtype, 1); for (size_t i = 0; i < n_dim; ++i) { - divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent); + divide_factor *= tvm::cast(DataType::Int(32), reduce_axes[i]->dom->extent); } return div(pool_sum(indices), divide_factor); diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 9925f54be4db..b0850a89b5c5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -882,6 +882,14 @@ def te_workload(): _check_workload(te_workload, tir_workload) +def test_global_pool(): + # fix the issue-17938 + data = te.placeholder((1, 1, 32, 32), dtype="int8", name="data") + op_output = topi.nn.global_pool(data=data, pool_type="avg", layout="NCHW") + f = te.create_prim_func([data, op_output]) + assert f + + def test_nested_reduce_domain_dependency(): @T.prim_func def tir_workload(