Skip to content

Commit b9fa576

Browse files
authored
[BugFix] Use shape dtype on ArgReduce to determine return type (#12083)
Fix ArgReduce automatic return type inference by forcing it to use the datatype of the shape of the Tensor instead of the fixed Int32. Includes additional tests.
1 parent de3c0f4 commit b9fa576

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/relay/op/tensor/reduce.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ bool GenericReduceRel(const Array<Type>& types, int num_inputs, const Attrs& att
338338

339339
// assign output type and shape
340340
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
341-
reporter->Assign(types[1], TensorType(oshape, DataType::Int(32)));
341+
reporter->Assign(types[1], TensorType(oshape, data->shape[0].dtype()));
342342
return true;
343343
}
344344
/*!

tests/python/relay/test_type_infer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tvm.relay import analysis, op, transform
2424
from tvm.relay.op import op as _op
2525

26+
import numpy as np
27+
2628

2729
def infer_mod(mod, annotate_spans=True):
2830
if annotate_spans:
@@ -544,6 +546,42 @@ def test_repeat_register():
544546
assert "Operator custom_log3 is registered before" in str(cm.execption)
545547

546548

549+
def test_argreduce_infer_return_type():
550+
x_shape = (1, 1)
551+
broadcast_shape = [1, 1]
552+
shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))]
553+
554+
# Testing with argmax
555+
for (sdtype, conv) in shape_dtypes:
556+
x = relay.var("data", relay.TensorType(x_shape, "float32"))
557+
broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
558+
argmax = relay.op.argmax(broadcast_to, axis=[1])
559+
560+
f = relay.Function([x], argmax)
561+
assert_has_type(
562+
f,
563+
relay.FuncType(
564+
[relay.TensorType(broadcast_shape, "float32")],
565+
relay.TensorType([conv(1)], dtype=sdtype),
566+
),
567+
)
568+
569+
# Testing with argmin
570+
for (sdtype, conv) in shape_dtypes:
571+
x = relay.var("data", relay.TensorType(x_shape, "float32"))
572+
broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
573+
argmin = relay.op.argmin(broadcast_to, axis=[1])
574+
575+
f = relay.Function([x], argmin)
576+
assert_has_type(
577+
f,
578+
relay.FuncType(
579+
[relay.TensorType(broadcast_shape, "float32")],
580+
relay.TensorType([conv(1)], dtype=sdtype),
581+
),
582+
)
583+
584+
547585
if __name__ == "__main__":
548586
import sys
549587

0 commit comments

Comments (0)