Skip to content

Commit 3b44ff1

Browse files
committed
Add test case where PrimFunc is used both in-place and DPS
1 parent 0a836ab commit 3b44ff1

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

tests/python/relax/test_transform_fuse_tir.py

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2162,5 +2162,97 @@ def main(
21622162
_check(Module, Expected)
21632163

21642164

def test_use_as_inplace_and_dps():
    """FuseTIR must handle a PrimFunc that is called both in destination-passing
    style (``R.call_tir``) and in-place (``R.call_tir_inplace``) within the same
    fused Relax function."""

    @I.ir_module
    class Module:
        # we will use it both in-place and normally (DPS)
        @T.prim_func(private=True)
        def add(
            A: T.Buffer((T.int64(10), T.int64(20)), "float32"),
            B: T.Buffer((), "float32"),
            Out: T.Buffer((T.int64(10), T.int64(20)), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()]

        @R.function(private=True)
        def fused_sums(
            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
        ) -> R.Tensor((10, 20), dtype="float32"):
            R.func_attr({"Primitive": 1})
            cls = Module
            with R.dataflow():
                # First call: ordinary DPS — a fresh output buffer is allocated.
                lv = R.call_tir(
                    cls.add,
                    (x, p0),
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                # Subsequent calls reuse the same PrimFunc in-place: argument at
                # inplace_indices=[2] doubles as the output buffer.
                lv1 = R.call_tir_inplace(
                    cls.add,
                    (x, p0, lv),
                    inplace_indices=[2],
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                lv2 = R.call_tir_inplace(
                    cls.add,
                    (x, p0, lv1),
                    inplace_indices=[2],
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                R.output(lv2)
            return lv2

        @R.function
        def main(
            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
        ) -> R.Tensor((10, 20), dtype="float32"):
            cls = Module
            with R.dataflow():
                gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0)
                R.output(gv1)
            return gv1

    @I.ir_module
    class Expected:
        # After fusion, the three calls collapse into one PrimFunc whose single
        # output buffer p_output0 is written by all three (formerly separate) loops.
        @T.prim_func(private=True)
        def fused_sums(
            x: T.Buffer((T.int64(10), T.int64(20)), "float32"),
            p0: T.Buffer((), "float32"),
            p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            # NOTE(review): all three blocks carry the name "T_add" exactly as in
            # the original diff — presumably the parser deduplicates block names
            # on round-trip; confirm against _check's structural comparison.
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]
            for ax0, ax1 in T.grid(T.int64(10), T.int64(20)):
                with T.block("T_add"):
                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                    p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()]

        @R.function
        def main(
            x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32")
        ) -> R.Tensor((10, 20), dtype="float32"):
            cls = Expected
            with R.dataflow():
                gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir(
                    cls.fused_sums,
                    (x, p0),
                    out_sinfo=R.Tensor((10, 20), dtype="float32"),
                )
                R.output(gv1)
            return gv1

    _check(Module, Expected)
# Entry point: run this file's tests directly via TVM's pytest wrapper.
if __name__ == "__main__":
    tvm.testing.main()

0 commit comments

Comments
 (0)