Commit 28d32b5

Author: Siyuan Feng

[TIR] Support narrow dtype for let binding (#16947)

The pass `ForceNarrowIndexToInt32` currently fails to narrow the dtype of let bindings; this PR fixes that. It also addresses the review comments in #16934.

1 parent 876f528 · commit 28d32b5
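
For context, a minimal reproduction of the fixed behavior, lifted from the new unit test at the bottom of this commit: before this change, the let-bound `ceil_log2` below kept its `int64` dtype after the pass ran.

```python
import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Before:
    @T.prim_func
    def main(buf: T.handle):
        n = T.int64()
        Buf = T.match_buffer(buf, [n], "int32")
        # Let binding whose dtype the pass previously left untouched.
        ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
        for i in T.serial(ceil_log2):
            T.evaluate(0)


# With this commit, the binding (and the loop extent using it) become int32.
after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
after.show()
```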

File tree: 5 files changed, +60 -13 lines

include/tvm/tir/data_type_rewriter.h (1 addition, 0 deletions)

```diff
@@ -110,6 +110,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   Stmt VisitStmt_(const IfThenElseNode* op) override;
   Stmt VisitStmt_(const DeclBufferNode* op) override;
   Stmt VisitStmt_(const AllocateNode* op) override;
+  Stmt VisitStmt_(const LetStmtNode* op) override;
   PrimExpr VisitExpr_(const EQNode* op) override;
   PrimExpr VisitExpr_(const NENode* op) override;
   PrimExpr VisitExpr_(const LTNode* op) override;
```

python/tvm/relax/backend/dispatch_sort_scan.py (5 additions, 1 deletion)

```diff
@@ -155,9 +155,13 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
         tgt = self._get_target(call.struct_info)
         axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
         shape = call.struct_info.shape
+        # TODO(tvm-team): Support fully dynamic case with `shape=None`
+        if shape is None:
+            raise ValueError("non-symbolic shape is not supported for now")
         kwargs = {}
         if (
-            (axis == -1 or axis == len(shape) - 1)
+            shape is not None
+            and (axis == -1 or axis == len(shape) - 1)
             and is_gpu_target(tgt)
             and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
             and call.op.name == "relax.cumsum"
```
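
The guard above rejects fully dynamic (`shape=None`) inputs early with a clear error, and the dispatch condition only fires when the cumsum runs over the last axis. A small illustrative helper (not part of the patch) showing the axis normalization that condition encodes:

```python
# Hypothetical helper, for illustration only: the condition
# `axis == -1 or axis == len(shape) - 1` accepts both spellings
# of "the last dimension".
def is_last_axis(axis: int, ndim: int) -> bool:
    return axis == -1 or axis == ndim - 1


assert is_last_axis(-1, 3)      # negative indexing
assert is_last_axis(2, 3)       # explicit last axis
assert not is_last_axis(0, 3)   # leading axis: handled elsewhere
```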

src/tir/ir/data_type_rewriter.cc (19 additions, 0 deletions)

```diff
@@ -27,6 +27,10 @@
 #include <tvm/tir/op.h>
 
 #include "./functor_common.h"
+#include "tvm/ir/expr.h"
+#include "tvm/tir/expr.h"
+#include "tvm/tir/stmt.h"
+#include "tvm/tir/var.h"
 
 namespace tvm {
 namespace tir {
@@ -558,6 +562,21 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
   }
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) {
+  LetStmt let_stmt = Downcast<LetStmt>(DataTypeLegalizer::VisitStmt_(op));
+  if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) {
+    return let_stmt;
+  }
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  PrimExpr value = VisitExpr(op->value);
+  Var var = var_remap_[let_stmt->var.get()];
+  is_enabled_ = is_enabled;
+  ICHECK(value.dtype() == var.dtype());
+  // No need to re-visit body
+  return LetStmt(var, value, let_stmt->body, let_stmt->span);
+}
+
 #define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
   PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) {   \
     bool is_enabled = is_enabled_;                             \
```
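
The new visitor first lets `DataTypeLegalizer` rewrite the node; if the bound variable was remapped to a narrower dtype, it re-visits the bound value with `is_enabled_` forced on so the value is narrowed to match, then rebuilds the `LetStmt` without re-visiting the body. The `ICHECK` pins the invariant that a let binding's value dtype and variable dtype agree. A quick Python-side sanity check of that same invariant, built by hand with TVM's public TIR constructors (a sketch, not the pass itself):

```python
import tvm
from tvm import tir

# Rebuild a narrowed binding manually: an int32 var bound to an int32 cast
# of an int64 value, mirroring what the C++ visitor produces.
var = tir.Var("ceil_log2", "int32")
value = tir.Cast("int32", tir.IntImm("int64", 7))
let = tir.LetStmt(var, value, tir.Evaluate(tir.IntImm("int32", 0)))

# Mirrors `ICHECK(value.dtype() == var.dtype())` in the visitor.
assert let.value.dtype == let.var.dtype
```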

tests/python/relax/test_backend_dispatch_sort_scan.py (10 additions, 12 deletions)

```diff
@@ -273,7 +273,7 @@ def foo2(y: R.Tensor((2, 3), "float32")):
     if can_use_thrust(target, "tvm.contrib.thrust.sort"):
         workspace = bb.emit(
             relax.op.builtin.alloc_tensor(
-                R.shape([4194568]), R.dtype("uint8"), R.prim_value(0), R.str("global")
+                R.shape([8388872]), R.dtype("uint8"), R.prim_value(0), R.str("global")
             )
         )
         out = bb.emit_te(
@@ -400,8 +400,8 @@ def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
     assert_structural_equal(mod, expected_mod)
 
 
-@tvm.testing.requires_cuda
-def test_dispatch_cumsum_gpu():
+@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
+def test_dispatch_cumsum_gpu(target, dev):
     """Test cumsum kernel dispatch and numerical correctness"""
 
     @I.ir_module
@@ -416,15 +416,13 @@ def main(x: R.Tensor(("m", "n"), "int32")):
     size = (8, 2000)
     np_data = np.random.randint(0, 10, size).astype("int32")
     np_cumsum = np.cumsum(np_data, axis=-1)
-    for target in ["cuda", "vulkan -supports_int64=1"]:
-        with tvm.target.Target(target):
-            mod = DispatchSortScan()(Module)
-        ex = tvm.relax.build(mod, target)
-        device = tvm.device(target, 0)
-        vm = tvm.relax.VirtualMachine(ex, device)
-        tvm_data = tvm.nd.array(np_data, device)
-        cumsum = vm["main"](tvm_data)
-        tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
+    with tvm.target.Target(target):
+        mod = DispatchSortScan()(Module)
+    ex = tvm.relax.build(mod, target)
+    vm = tvm.relax.VirtualMachine(ex, dev)
+    tvm_data = tvm.nd.array(np_data, dev)
+    cumsum = vm["main"](tvm_data)
+    tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
 
 
 if __name__ == "__main__":
```
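
The test refactor replaces the hand-rolled target loop with `tvm.testing.parametrize_targets`, which expands each target string into its own pytest case, injects the matching `target` string and `dev` device, and skips targets that are not enabled in the build rather than failing. A minimal usage sketch (hypothetical test, not from this commit):

```python
import numpy as np
import tvm
import tvm.testing


# Hypothetical test illustrating the decorator's calling convention: one
# pytest case per target; `dev` is the tvm.runtime.Device for `target`.
@tvm.testing.parametrize_targets("llvm")
def test_device_roundtrip(target, dev):
    data = np.arange(4, dtype="int32")
    arr = tvm.nd.array(data, device=dev)
    np.testing.assert_array_equal(arr.numpy(), data)
```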

tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py (25 additions, 0 deletions)

```diff
@@ -278,5 +278,30 @@ def main(B: T.Buffer((4,), "int32")):
     tvm.ir.assert_structural_equal(Expected, after)
 
 
+def test_let_binding():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(buf: T.handle):
+            n = T.int64()
+            Buf = T.match_buffer(buf, [n], "int32")
+            ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
+            for i in T.serial(ceil_log2):
+                T.evaluate(0)
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def main(buf: T.handle):
+            n = T.int32()
+            Buf = T.match_buffer(buf, [n], "int32")
+            ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", n))))
+            for i in range(ceil_log2):
+                T.evaluate(0)
+
+    after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```
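
Beyond structural equality, the rewrite can also be spot-checked node by node (a sketch assuming the `Before` module from `test_let_binding` is in scope, and that the function body is the let binding wrapping the serial loop, as in the TVMScript above):

```python
import tvm
from tvm import tir

after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)

# The function body is the let binding; its body is the serial loop.
let = after["main"].body
assert isinstance(let, tir.LetStmt)
assert let.var.dtype == "int32"          # was int64 before the pass
assert let.body.extent.dtype == "int32"  # loop extent uses the narrowed var
```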
