diff --git a/src/op/reduce.cc b/src/op/reduce.cc index b56bd36b3..f12d1c9f5 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -44,15 +44,32 @@ ReduceOp::ReduceOp(Array args, BufferMap vmap) { } PrimExpr ReduceOp::MakeInitValue() const { + auto dst_dtype = dst->dtype; + auto is_int = dst_dtype.is_int(); + bool is_uint = dst_dtype.is_uint(); + auto bits = dst_dtype.bits(); + switch (type) { case ReduceType::kSum: return make_zero(dst->dtype); case ReduceType::kAbsSum: return make_zero(dst->dtype); case ReduceType::kMax: - return make_const(dst->dtype, -INFINITY); + if (is_int) { + return make_const(dst->dtype, -(1 << (bits - 1))); + } else if (is_uint) { + return make_const(dst->dtype, 0); + } else { + return make_const(dst->dtype, -INFINITY); + } case ReduceType::kMin: - return make_const(dst->dtype, INFINITY); + if (is_int) { + return make_const(dst->dtype, (1 << (bits - 1)) - 1); + } else if (is_uint) { + return make_const(dst->dtype, (1 << bits) - 1); + } else { + return make_const(dst->dtype, INFINITY); + } case ReduceType::kAbsMax: return make_const(dst->dtype, 0); default: