Skip to content

Commit 13ddd8c

Browse files
committed
Fix review comments
1 parent 28a4bb6 commit 13ddd8c

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

python/tvm/relax/transform/legalize_ops/manipulate.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ def te_layout_transform(data, name):
182182
)
183183

184184
def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
185-
if len(axis_sep) != 0:
186-
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
185+
sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
187186

188187
index_map: tvm.tir.IndexMap = call.attrs.index_map
189188
pad_value = call.attrs.pad_value
@@ -199,7 +198,7 @@ def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
199198
input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators
200199

201200
# Convert to list from array
202-
axis_separators = list(map(lambda x: x.value, axis_separators))
201+
axis_separators = [int(sep) for sep in axis_separators]
203202
primfunc_name = "te_layout_transform"
204203
_, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
205204
if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -214,7 +213,7 @@ def set_axis_sep(axis_sep: list, sch: tir.schedule, buffer_type: str):
214213
sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
215214
set_axis_sep(axis_separators, sch, "write")
216215
if input_axis_separators is not None:
217-
input_axis_separators = list(map(lambda x: x.value, input_axis_separators))
216+
input_axis_separators = [int(sep) for sep in input_axis_separators]
218217
set_axis_sep(input_axis_separators, sch, "read")
219218
gvar = bb.add_func(sch.mod["main"], primfunc_name)
220219
output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))

python/tvm/relax/transform/transform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import numpy as np # type: ignore
2525

2626
import tvm.ir
27+
from tvm.ir.container import Array
2728
from tvm.relax import Expr, Var, StructInfo
2829
from tvm.relax.dpl import DFPattern
2930
from tvm.runtime import NDArray, Object
@@ -1309,7 +1310,9 @@ def AlterOpImpl(
13091310
# Extract the index_map
13101311
if isinstance(transform, Callable):
13111312
transform = IndexMap.from_func_with_separators(transform)[0]
1312-
elif isinstance(transform, tuple) and isinstance(transform[0], IndexMap):
1313+
elif (isinstance(transform, tuple) or isinstance(transform, Array)) and isinstance(
1314+
transform[0], IndexMap
1315+
):
13131316
transform = transform[0]
13141317
l.append(transform)
13151318
op_buffer_transforms[operator_name] = l

src/relax/op/tensor/manipulate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Expr flatten(Expr x);
7272
*/
7373
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
7474
Optional<Array<IntImm>> axis_separators,
75-
Optional<Array<IntImm>> input_axis_separators);
75+
Optional<Array<IntImm>> input_axis_separators = NullOpt);
7676

7777
/*!
7878
* \brief Permutes the dimensions of an array.

src/relax/transform/alter_op_impl.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ class AlterOpImplMutator : public ExprMutator {
128128
const auto& replacement_func = op_impl_map_[op_kind];
129129

130130
Array<IndexMap> buffer_transforms;
131-
Optional<Array<Array<IntImm>>> axis_separators, input_axis_separators;
131+
Optional<Array<Array<IntImm>>> axis_separators;
132+
Optional<Array<Array<IntImm>>> input_axis_separators;
132133
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
133134
if (op_buffer_axis_separators__.count(op_kind))
134135
axis_separators = op_buffer_axis_separators__[op_kind];
@@ -293,7 +294,8 @@ class AlterOpImplMutator : public ExprMutator {
293294
Array<Expr> updated_inputs;
294295
int index = 0;
295296
for (const auto& input : inputs->fields) {
296-
Array<IntImm> axis_separator, input_axis_separator;
297+
Array<IntImm> axis_separator;
298+
Array<IntImm> input_axis_separator;
297299
if (axis_separators.defined()) {
298300
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
299301
axis_separator = axis_separators_value[index];

tests/python/relax/test_transform_alter_op_impl.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -609,8 +609,6 @@ def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer(
609609
for ax0, ax1 in T.grid(4, 4):
610610
with T.block("T_add"):
611611
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
612-
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
613-
T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
614612
output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
615613
output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
616614

@@ -633,8 +631,6 @@ def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float3
633631
for ax0, ax1 in T.grid(4, 4):
634632
with T.block("T_add"):
635633
v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
636-
T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
637-
T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
638634
output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
639635
output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
640636
# fmt: on

0 commit comments

Comments
 (0)