diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index fba2a60cecb2..2b1afc6e55f2 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -338,7 +338,7 @@ bool GenericReduceRel(const Array& types, int num_inputs, const Attrs& att // assign output type and shape auto oshape = ReduceShapeImpl(in_shape, param, reporter); - reporter->Assign(types[1], TensorType(oshape, DataType::Int(32))); + reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype())); return true; } /*! diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index af64ce714df8..b0b7ef048192 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -23,6 +23,8 @@ from tvm.relay import analysis, op, transform from tvm.relay.op import op as _op +import numpy as np + def infer_mod(mod, annotate_spans=True): if annotate_spans: @@ -544,6 +546,42 @@ def test_repeat_register(): assert "Operator custom_log3 is registered before" in str(cm.execption) +def test_argreduce_infer_return_type(): + x_shape = (1, 1) + broadcast_shape = [1, 1] + shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))] + + # Testing with argmax + for (sdtype, conv) in shape_dtypes: + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmax = relay.op.argmax(broadcast_to, axis=[1]) + + f = relay.Function([x], argmax) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) + + # Testing with argmin + for (sdtype, conv) in shape_dtypes: + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmin = relay.op.argmin(broadcast_to, axis=[1]) + + f = relay.Function([x], argmin) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) + + if __name__ == "__main__": import sys