Skip to content

Commit 7dfc863

Browse files
authored
[Unity] Alter op impl handling empty transform for output (#16331)
Alterop impl, handling empty transform for output
1 parent 0cf5f47 commit 7dfc863

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

src/relax/transform/alter_op_impl.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ class AlterOpImplMutator : public ExprMutator {
324324

325325
/*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */
326326
StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) {
327+
if (transform.get() == nullptr) return tensor_sinfo;
327328
auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
328329
arith::Analyzer analyzer;
329330
auto new_shape = transform->MapShape(shape, &analyzer);

tests/python/relax/test_transform_alter_op_impl.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,5 +472,105 @@ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"),
472472
)
473473

474474

475+
def test_reshape():
476+
@I.ir_module
477+
class Before:
478+
@T.prim_func(private=True)
479+
def reshape(
480+
A: T.Buffer((T.int64(850), T.int64(2048)), "float16"),
481+
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
482+
):
483+
T.func_attr({"operator_name": "relax.reshape"})
484+
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
485+
with T.block("T_reshape"):
486+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
487+
T.reads(
488+
A[
489+
(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
490+
v_ax2 % T.int64(2048),
491+
]
492+
)
493+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
494+
T_reshape[v_ax0, v_ax1, v_ax2] = A[
495+
(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
496+
v_ax2 % T.int64(2048),
497+
]
498+
499+
@R.function
500+
def main(
501+
x: R.Tensor((850, 2048), dtype="float16")
502+
) -> R.Tensor((850, 1, 2048), dtype="float16"):
503+
cls = Before
504+
with R.dataflow():
505+
lv = R.call_tir(
506+
cls.reshape, (x,), out_sinfo=R.Tensor((850, 1, 2048), dtype="float16")
507+
)
508+
gv: R.Tensor((850, 1, 2048), dtype="float16") = lv
509+
R.output(gv)
510+
return gv
511+
512+
@I.ir_module
513+
class Expected:
514+
@T.prim_func(private=True)
515+
def relax_reshape_replacement(
516+
A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
517+
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
518+
):
519+
T.func_attr({"operator_name": "relax.reshape"})
520+
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
521+
with T.block("T_reshape"):
522+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
523+
T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
524+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
525+
T_reshape[v_ax0, v_ax1, v_ax2] = A[
526+
v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
527+
]
528+
529+
@R.function
530+
def main(
531+
x: R.Tensor((850, 2048), dtype="float16")
532+
) -> R.Tensor((850, 1, 2048), dtype="float16"):
533+
cls = Expected
534+
with R.dataflow():
535+
lv: R.Tensor((850, 2, 1024), dtype="float16") = R.layout_transform(
536+
x,
537+
index_map=T.index_map(lambda i, j: (i, j // 1024, j % 1024)),
538+
pad_value=None,
539+
axis_separators=[],
540+
)
541+
lv_1 = R.call_tir(
542+
cls.relax_reshape_replacement,
543+
(lv,),
544+
out_sinfo=R.Tensor((850, 1, 2048), dtype="float16"),
545+
)
546+
gv: R.Tensor((850, 1, 2048), dtype="float16") = lv_1
547+
R.output(gv)
548+
return gv
549+
550+
@T.prim_func(private=True)
551+
def reshape_new(
552+
A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
553+
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
554+
):
555+
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
556+
with T.block("T_reshape"):
557+
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
558+
T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
559+
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
560+
T_reshape[v_ax0, v_ax1, v_ax2] = A[
561+
v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
562+
]
563+
564+
# fmt: on
565+
index_map = lambda i, j: (i, j // 1024, j % 1024)
566+
_check(
567+
Before,
568+
Expected,
569+
operator_name="relax.reshape",
570+
replacement_primfunc=reshape_new,
571+
layout_changes=[index_map, None],
572+
)
573+
574+
475575
if __name__ == "__main__":
476576
tvm.testing.main()

0 commit comments

Comments
 (0)