@@ -23,6 +23,8 @@
 from tvm.relay import analysis, op, transform
 from tvm.relay.op import op as _op
 
+import numpy as np
+
 
 def infer_mod(mod, annotate_spans=True):
     if annotate_spans:
@@ -544,6 +546,47 @@ def test_repeat_register():
         assert "Operator custom_log3 is registered before" in str(cm.exception)
 
 
+def test_argreduce_infer_return_type():
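+    # Regression check: broadcast_to with an int64 shape constant yields a
+    # tensor whose shape dims are int64; argmax/argmin type inference must keep
+    # those int64 dims in the result shape while the index dtype stays int32.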
+    x_shape = (1, 1)
+    broadcast_shape = [1, 1]
+    shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))]
+
+    # Testing with argmax
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
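+        # the shape constant's dtype (int32 or int64) carries through to the
+        # dtype of the broadcast result's shape dims, which argmax must preserve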
+        argmax = relay.op.argmax(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmax)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype="int32"),
+            ),
+        )
+
+    # Testing with argmin
+    for (sdtype, conv) in shape_dtypes:
+        x = relay.var("data", relay.TensorType(x_shape, "float32"))
+        broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype))
+        argmin = relay.op.argmin(broadcast_to, axis=[1])
+
+        f = relay.Function([x], argmin)
+        assert_has_type(
+            f,
+            relay.FuncType(
+                [relay.TensorType(broadcast_shape, "float32")],
+                relay.TensorType([conv(1)], dtype="int32"),
+            ),
+        )
+
+
 if __name__ == "__main__":
     import sys
 
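Not part of the diff: a minimal standalone sketch of the scenario the new test exercises, assuming a TVM build that includes this fix. It builds the same argmax-over-broadcast_to graph with an int64 shape constant and prints the inferred return type; the result's shape dim should be carried as int64 while the index dtype stays int32.

import tvm
from tvm import relay

x = relay.var("data", relay.TensorType((1, 1), "float32"))
y = relay.op.argmax(
    relay.op.broadcast_to(x, relay.const([1, 1], dtype="int64")), axis=[1]
)
mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
print(mod["main"].ret_type)  # Tensor[(1,), int32], with the dim 1 held as an int64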