Commit a3a903b

[BugFix] Use shape dtype on ArgReduce to determine return type
Fix ArgReduce automatic return type inference by making it use the dtype of the tensor's shape instead of a fixed Int32. Additional tests are included.
1 parent 9ee25eb commit a3a903b
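
As context for the change below, here is a minimal sketch that mirrors the new test and inspects the inferred return type of argmax from Python. It is not part of the commit and assumes a standard TVM build with Relay; the entry points used here (tvm.IRModule.from_expr, relay.transform.InferType, Function.ret_type) are regular Relay APIs rather than anything introduced by this diff. With this fix the result dtype is expected to follow the dtype of the broadcast shape (int64 here) instead of a hard-coded int32.

    import tvm
    from tvm import relay

    # Broadcast a float32 tensor to a shape given as an int64 constant,
    # then reduce with argmax over axis 1 (same pattern as the new test).
    x = relay.var("data", relay.TensorType((1, 1), "float32"))
    y = relay.op.broadcast_to(x, relay.const([1, 1], dtype="int64"))
    z = relay.op.argmax(y, axis=[1])

    mod = tvm.IRModule.from_expr(relay.Function([x], z))
    mod = relay.transform.InferType()(mod)

    # Expected after this commit: the inferred result dtype is int64,
    # matching the shape dtype, rather than the previously fixed int32.
    print(mod["main"].ret_type)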

File tree

2 files changed: +39, -1 lines

src/relay/op/tensor/reduce.cc

Lines changed: 1 addition & 1 deletion

@@ -338,7 +338,7 @@ bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& att
 
   // assign output type and shape
   auto oshape = ReduceShapeImpl(in_shape, param, reporter);
-  reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
+  reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype()));
   return true;
 }
 /*!

tests/python/relay/test_type_infer.py

Lines changed: 38 additions & 0 deletions

@@ -23,6 +23,8 @@
 from tvm.relay import analysis, op, transform
 from tvm.relay.op import op as _op
 
+import numpy as np
+
 
 def infer_mod(mod, annotate_spans=True):
     if annotate_spans:
@@ -544,6 +546,42 @@ def test_repeat_register():
         assert "Operator custom_log3 is registered before" in str(cm.execption)
 
 
+def test_argreduce_infer_return_type():
+    x_shape = (1, 1)
+    broadcast_shape = [1, 1]
+    shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))]
+
+    # Testing with argmax
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
+        argmax = relay.op.argmax(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmax)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype=sdtype),
+            ),
+        )
+
+    # Testing with argmin
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
+        argmin = relay.op.argmin(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmin)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype=sdtype),
+            ),
+        )
+
+
 if __name__ == "__main__":
     import sys
 