From e1cb085a5c9bc21fa1df4a44ae0d3f67aaee136e Mon Sep 17 00:00:00 2001 From: ganler Date: Tue, 12 Apr 2022 15:38:02 -0500 Subject: [PATCH 1/5] resolve int64/32 for AttrStmtNode --- src/tir/transforms/narrow_datatype.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index c2bf27393173..8504f407bec2 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -29,6 +29,7 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" +#include "tvm/runtime/logging.h" namespace tvm { namespace tir { @@ -276,7 +277,13 @@ class DataTypeRewriter : public StmtExprMutator { PrimExpr e = VisitExpr(iv->var); Var var = Downcast(e); if (ivmap_.find(iv) == ivmap_.end()) { - ivmap_[iv] = IterVar(iv->dom, var, iv->iter_type, iv->thread_tag); + Range dom = iv->dom; + if (dom.defined()) { + DataType vi_dtype = dom->extent.dtype(); + if (vi_dtype.is_int() && vi_dtype.bits() < var.dtype().bits()) + dom = Range(cast(var.dtype(), dom->min), cast(var.dtype(), dom->extent), dom->span); + } + ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag); } return AttrStmt(ivmap_[iv], op->attr_key, cast(var.dtype(), op->value), op->body); } From 17da5db153525a30b842b5fc93fa454d0400d77c Mon Sep 17 00:00:00 2001 From: ganler Date: Tue, 12 Apr 2022 15:39:28 -0500 Subject: [PATCH 2/5] rm debug header --- src/tir/transforms/narrow_datatype.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 8504f407bec2..625a7c5d5c5b 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -29,7 +29,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" -#include "tvm/runtime/logging.h" namespace tvm { namespace tir { From eeaf5ce96a71f7f3427beaa393ab4d66f66f4ffa Mon Sep 17 00:00:00 2001 From: ganler Date: Tue, 12 Apr 2022 17:55:27 -0500 Subject: [PATCH 3/5] refine --- src/tir/transforms/narrow_datatype.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 625a7c5d5c5b..8df7b57eafde 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -278,9 +278,13 @@ class DataTypeRewriter : public StmtExprMutator { if (ivmap_.find(iv) == ivmap_.end()) { Range dom = iv->dom; if (dom.defined()) { - DataType vi_dtype = dom->extent.dtype(); - if (vi_dtype.is_int() && vi_dtype.bits() < var.dtype().bits()) - dom = Range(cast(var.dtype(), dom->min), cast(var.dtype(), dom->extent), dom->span); + PrimExpr extend = dom->extent; + if (extend.dtype().is_int() && var.dtype().is_int() && + var.dtype().bits() != extend.dtype().bits()) { + int bits = std::max(extend.dtype().bits(), var.dtype().bits()); + DataType dtype = var.dtype().with_bits(bits); + dom = Range(cast(dtype, dom->min), cast(dtype, extend), dom->span); + } } ivmap_[iv] = IterVar(dom, var, iv->iter_type, iv->thread_tag); } From fd40e92912b7c88ebba8f60c96af31f13c07aefd Mon Sep 17 00:00:00 2001 From: ganler Date: Tue, 12 Apr 2022 23:25:47 -0500 Subject: [PATCH 4/5] add test case --- tests/python/relay/test_op_level10.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 0486ef40017b..8fbd8434d0e5 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -228,6 +228,21 @@ def test_broadcast_to(): op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu +def test_broadcast_to_const_shape_int64(): + shape_like = relay.const(np.array([1, 5]), dtype="int64") + x = relay.var("x", shape=(1,), dtype="int64") + z = relay.broadcast_to(x, shape=shape_like) + z = relay.sum(z, axis=0) + + f = relay.Function([x], z) + + x = np.random.randint(10, size=(1,), dtype="int64") + ref_res = np.broadcast_to(x, (5,)) + for target, dev in tvm.testing.enabled_targets(): + for kind in ["graph", "debug"]: + op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x) + tvm.testing.assert_allclose(op_res.numpy(), ref_res) @tvm.testing.uses_gpu def test_broadcast_to_like(): From c8ae6a42db895bcf4680a5c4f05851ddcb899e61 Mon Sep 17 00:00:00 2001 From: ganler Date: Wed, 13 Apr 2022 01:33:05 -0500 Subject: [PATCH 5/5] lint --- tests/python/relay/test_op_level10.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 8fbd8434d0e5..85a3dd5636f1 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -228,6 +228,7 @@ def test_broadcast_to(): op_res = relay.create_executor(kind, device=dev, target=target).evaluate(func)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5) + @tvm.testing.uses_gpu def test_broadcast_to_const_shape_int64(): shape_like = relay.const(np.array([1, 5]), dtype="int64") @@ -244,6 +245,7 @@ def test_broadcast_to_const_shape_int64(): op_res = relay.create_executor(kind, device=dev, target=target).evaluate(f)(x) tvm.testing.assert_allclose(op_res.numpy(), ref_res) + @tvm.testing.uses_gpu def test_broadcast_to_like(): shape = (4, 1, 6)