
Commit 35551d4

[Unity] Support cumsum with pure int32 (#16439)
This PR fixes a bug in attribute handling in the index data type rewriter and enforces int32 buffers in the cumsum function definition, ensuring that cumsum can run on a machine that supports int32 but not int64.
1 parent 4e754a7 commit 35551d4
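A rough way to see what this enables (the entry points topi.cuda.cumsum / topi.cuda.schedule_scan and the cuda target below are assumptions for illustration, not part of this commit): lower the TOPI CUDA cumsum and inspect the scan IR. Before this change, the lowered kernel carried int64 loop variables, int64 local allocations, and an int64 thread_extent, which cannot be emitted for devices that only support 32-bit integers.

# Illustrative sketch only; the entry points and target are assumptions.
import tvm
from tvm import te, topi

data = te.placeholder((4, 256), dtype="float32", name="data")
with tvm.target.Target("cuda"):
    out = topi.cuda.cumsum(data, axis=-1)
    sch = topi.cuda.schedule_scan([out])
    mod = tvm.lower(sch, [data, out])
# With this commit, the scan kernel keeps its loop variables, local buffers,
# and thread extents in int32, so it can target int32-only devices.
print(mod)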

File tree (3 files changed, +24 −12 lines):

include/tvm/tir/data_type_rewriter.h
python/tvm/topi/cuda/scan.py
src/tir/ir/data_type_rewriter.cc

include/tvm/tir/data_type_rewriter.h (1 addition, 0 deletions)

@@ -104,6 +104,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   Stmt VisitStmt_(const BlockRealizeNode* op) override;
   Stmt VisitStmt_(const BlockNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
+  Stmt VisitStmt_(const AttrStmtNode* op) override;
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
   Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
   Stmt VisitStmt_(const IfThenElseNode* op) override;

python/tvm/topi/cuda/scan.py (12 additions, 12 deletions)

@@ -60,8 +60,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
         your operation.
     """

-    batch_size = prod(data.shape[:-1])
-    scan_axis_size = data.shape[-1]
+    batch_size = cast(prod(data.shape[:-1]), "int32")
+    scan_axis_size = cast(data.shape[-1], "int32")

     ib = tvm.tir.ir_builder.create()

@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
         # Up Sweep of exclusive scan
         lim = ceil_log2(scan_axis_size)

-        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
             width = 2 << l2_width

             with ib.new_scope():

@@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

                 by = te.thread_axis("blockIdx.y")
                 ib.scope_attr(by, "thread_extent", nthread_by)
-                start = ib.allocate("int64", (1,), name="start", scope="local")
-                middle = ib.allocate("int64", (1,), name="middle", scope="local")
-                end = ib.allocate("int64", (1,), name="end", scope="local")
+                start = ib.allocate("int32", (1,), name="start", scope="local")
+                middle = ib.allocate("int32", (1,), name="middle", scope="local")
+                end = ib.allocate("int32", (1,), name="end", scope="local")
                 start[0] = width * tid
                 with ib.if_scope(start[0] < scan_axis_size):
                     middle[0] = start[0] + tvm.tir.indexdiv(width, 2)

@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
                 reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
             output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)

-        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
             width = 2 << (lim - l2_width - 1)

             with ib.new_scope():

@@ -159,9 +159,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

                 by = te.thread_axis("blockIdx.y")
                 ib.scope_attr(by, "thread_extent", nthread_by)
-                start = ib.allocate("int64", (1,), name="start", scope="local")
-                middle = ib.allocate("int64", (1,), name="middle", scope="local")
-                end = ib.allocate("int64", (1,), name="end", scope="local")
+                start = ib.allocate("int32", (1,), name="start", scope="local")
+                middle = ib.allocate("int32", (1,), name="middle", scope="local")
+                end = ib.allocate("int32", (1,), name="end", scope="local")
                 tmp = ib.allocate(out_dtype, (1,), name="end", scope="local")
                 start[0] = width * tid
                 with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)):

@@ -206,8 +206,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi
     ex_scan_output = expand_dims(ex_scan_output, axis=0)

     def ir(data, data_ex_scan, reduction):
-        batch_size = prod(data.shape[:-1])
-        scan_axis_size = data.shape[-1]
+        batch_size = cast(prod(data.shape[:-1]), "int32")
+        scan_axis_size = cast(data.shape[-1], "int32")

         ib = tvm.tir.ir_builder.create()
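All of the scan.py changes follow one pattern: every size that later feeds loop ranges, ib.allocate extents, and thread extents is cast to int32 up front, so no int64 arithmetic leaks into the generated kernel. A minimal standalone sketch of the pattern (the symbolic int64 shape below is an assumption for illustration):

import tvm
from tvm import te, topi

# Symbolic int64 shape, as Unity/Relax shapes commonly are.
m = te.var("m", dtype="int64")
n = te.var("n", dtype="int64")
data = te.placeholder((m, n), dtype="float32", name="data")

# Without the casts, prod() over int64 dims yields an int64 PrimExpr, and every
# index derived from it (loop vars, allocation extents, thread extents) would
# be int64 as well. Casting once at the top keeps the whole scan IR in int32.
batch_size = topi.cast(topi.utils.prod(data.shape[:-1]), "int32")
scan_axis_size = topi.cast(data.shape[-1], "int32")
print(batch_size.dtype, scan_axis_size.dtype)  # int32 int32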

src/tir/ir/data_type_rewriter.cc (11 additions, 0 deletions)

@@ -258,6 +258,17 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
   }
 }

+Stmt IndexDataTypeRewriter::VisitStmt_(const AttrStmtNode* op) {
+  if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+    bool is_enabled = is_enabled_;
+    is_enabled_ = true;
+    auto stmt = DataTypeLegalizer::VisitStmt_(op);
+    is_enabled_ = is_enabled;
+    return stmt;
+  }
+  return DataTypeLegalizer::VisitStmt_(op);
+}
+
 Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
   Buffer new_buffer = VisitBuffer(op->buffer);
   DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));