Skip to content

Commit e10cdc5

Browse files
wrongtest-intellifwrongtest
andauthored
[tir][Compute-at] Make compute-ated block simple when the predicate could be merged (#16945)
make compute-ated block simple when the predicate could be merged as static loop domain Co-authored-by: wrongtest <[email protected]>
1 parent b00fc55 commit e10cdc5

File tree

3 files changed

+78
-9
lines changed

3 files changed

+78
-9
lines changed

src/tir/schedule/primitive/compute_at.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ struct BlockVarDomainInfo {
224224
analyzer->CanProveEqual(bound.max(), intersect.max())) {
225225
dom = bound;
226226
bound = arith::IntSet::Nothing();
227+
} else if (is_const_int(intersect.min()) && is_const_int(intersect.max())) {
228+
// if the bound induce constant iter range, merge bound to loop domain
229+
dom = intersect;
230+
bound = arith::IntSet::Nothing();
227231
}
228232
}
229233
};

src/tir/schedule/primitive/decompose_padding.cc

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -393,15 +393,6 @@ class DecomposePaddingBlockReplacer : public StmtMutator {
393393
return std::move(new_loop);
394394
}
395395

396-
Stmt VisitStmt_(const SeqStmtNode* seq) final {
397-
Array<Stmt> new_stmts;
398-
new_stmts.reserve(seq->seq.size());
399-
for (const Stmt& old_stmt : seq->seq) {
400-
new_stmts.push_back(VisitStmt(old_stmt));
401-
}
402-
return SeqStmt::Flatten(new_stmts);
403-
}
404-
405396
private:
406397
const ReplaceDesc& desc_;
407398
};

tests/python/tir-schedule/test_tir_schedule_compute_at.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,5 +1915,79 @@ def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
19151915
)
19161916

19171917

1918+
def test_compute_at_sliced_concatenate():
1919+
@T.prim_func
1920+
def before():
1921+
X = T.alloc_buffer((1, 16, 28, 64), "float32")
1922+
Y = T.alloc_buffer((1, 32, 28, 64), "float32")
1923+
Z = T.alloc_buffer((1, 53, 28, 64), "float32")
1924+
Concat = T.alloc_buffer((1, 101, 28, 64), "float32")
1925+
Slice = T.alloc_buffer((1, 87, 28, 64), "float32")
1926+
for ax0, ax1, ax2, ax3 in T.grid(1, 16, 28, 64):
1927+
with T.block("compute"):
1928+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1929+
X[v_ax0, v_ax1, v_ax2, v_ax3] = 1.0
1930+
for ax0, ax1, ax2, ax3 in T.grid(1, 101, 28, 64):
1931+
with T.block("T_concat"):
1932+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1933+
Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
1934+
85 <= v_ax1,
1935+
X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
1936+
T.if_then_else(
1937+
53 <= v_ax1,
1938+
Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
1939+
Z[v_ax0, v_ax1, v_ax2, v_ax3],
1940+
),
1941+
)
1942+
for ax0, ax1, ax2, ax3 in T.grid(1, 87, 28, 64):
1943+
with T.block("T_strided_slice"):
1944+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1945+
Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]
1946+
1947+
@T.prim_func
1948+
def expect():
1949+
X = T.alloc_buffer((1, 16, 28, 64))
1950+
Y = T.alloc_buffer((1, 32, 28, 64))
1951+
Z = T.alloc_buffer((1, 53, 28, 64))
1952+
Concat = T.alloc_buffer((1, 101, 28, 64))
1953+
Slice = T.alloc_buffer((1, 87, 28, 64))
1954+
for ax0 in range(1):
1955+
for ax0_1, ax1, ax2 in T.grid(2, 28, 64):
1956+
with T.block("compute"):
1957+
v_ax0 = T.axis.spatial(1, 0)
1958+
v_ax1 = T.axis.spatial(16, ax0_1)
1959+
v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
1960+
X[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(1)
1961+
for ax0_1, ax1, ax2 in T.grid(87, 28, 64):
1962+
with T.block("T_concat"):
1963+
v_ax0 = T.axis.spatial(1, 0)
1964+
v_ax1 = T.axis.spatial(101, ax0_1)
1965+
v_ax2, v_ax3 = T.axis.remap("SS", [ax1, ax2])
1966+
Concat[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(
1967+
85 <= v_ax1,
1968+
X[v_ax0, v_ax1 - 85, v_ax2, v_ax3],
1969+
T.if_then_else(
1970+
53 <= v_ax1,
1971+
Y[v_ax0, v_ax1 - 53, v_ax2, v_ax3],
1972+
Z[v_ax0, v_ax1, v_ax2, v_ax3],
1973+
),
1974+
)
1975+
for ax1, ax2, ax3 in T.grid(87, 28, 64):
1976+
with T.block("T_strided_slice"):
1977+
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
1978+
Slice[v_ax0, v_ax1, v_ax2, v_ax3] = Concat[v_ax0, v_ax1, v_ax2, v_ax3]
1979+
1980+
sch = tir.Schedule(before, debug_mask="all")
1981+
blk1 = sch.get_block("compute")
1982+
blk2 = sch.get_block("T_concat")
1983+
blk3 = sch.get_block("T_strided_slice")
1984+
loop = sch.get_loops(blk3)[0]
1985+
sch.compute_at(blk2, loop)
1986+
sch.compute_at(blk1, loop)
1987+
after = sch.mod["main"]
1988+
assert_structural_equal_ignore_global_symbol(expect, after)
1989+
verify_trace_roundtrip(sch=sch, mod=before)
1990+
1991+
19181992
if __name__ == "__main__":
19191993
tvm.testing.main()

0 commit comments

Comments
 (0)