Commit d641354

[RELAX][PASS] Convert layout pass and ops enhanced to support sub indexing (#17568)
The convert layout pass and its ops are enhanced to support sub indexing. The majority of operations are now compatible with custom (sub-indexed) layouts; incompatible ops fall back to the regular layout. Conv1D, Conv3D, Pool1D, Pool3D, AdaptiveAvgPool1D, and AdaptiveAvgPool3D are left unchanged for now. 2D networks are expected to work.
1 parent 8b59368 commit d641354
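
As a rough sketch of what this enables, a sub-indexed (blocked) layout can be requested through the Relax ConvertLayout pass roughly as below. The specific layout strings ("NCHW4c", "OIHW4i4o") and the surrounding pipeline are illustrative assumptions, not taken from this commit; ops that cannot adopt the blocked layout are expected to fall back to the regular one.

import tvm
from tvm import relax

def convert_to_blocked_layout(mod: tvm.IRModule) -> tvm.IRModule:
    # Ask conv2d for a blocked data layout (NCHW4c) and kernel layout (OIHW4i4o);
    # the pass rewrites compatible ops and leaves the rest in the default layout.
    seq = tvm.transform.Sequential(
        [
            relax.transform.ConvertLayout({"relax.nn.conv2d": ["NCHW4c", "OIHW4i4o"]}),
            relax.transform.Normalize(),
        ]
    )
    return seq(mod)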

15 files changed, +3398 -42 lines changed

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 1 deletion
@@ -1688,9 +1688,10 @@ def index_map(
     mapping: Callable,
     *,
     inverse_index_map: Optional[Callable] = None,
+    index_dtype: str = "int64",
 ) -> IndexMap:
     """Create a TIR Index mapping"""
-    return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map)
+    return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map, index_dtype=index_dtype)


 def target(
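
A hedged usage sketch for the new index_dtype argument, calling IndexMap.from_func directly (the call the IR builder forwards to). The channel-packing function below is an illustrative assumption; the point is that the resulting IndexMap carries int32 indices instead of the default int64.

from tvm.tir import IndexMap

# Pack the channel axis into blocks of 4 (NCHW -> NCHW4c style) with 32-bit indices.
pack_c4 = IndexMap.from_func(
    lambda n, c, h, w: [n, c // 4, h, w, c % 4],
    index_dtype="int32",
)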

src/relax/op/image/resize.cc

Lines changed: 4 additions & 0 deletions
@@ -121,6 +121,10 @@ InferLayoutOutput InferLayoutResize2d(const Call& call,
   } else {
     // We dont have a desired layout for resize2d, propagate from the input instead.
     data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    // Not handling sub indexing now.
+    if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) {
+      data_layout = LayoutDecision(InitialLayout(4));
+    }
     new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name();
   }
   return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout},

src/relax/op/nn/convolution.cc

Lines changed: 52 additions & 23 deletions
@@ -308,30 +308,59 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
     Layout desired_data_layout = (*it).second[0];
     Layout desired_weight_layout = (*it).second[1];
     Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
-    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
-    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
-        << "Axis swap only";
-    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
-        << "Axis swap only";
-    data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
-    weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
-    output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
-    new_attrs->data_layout = (*it).second[0];
-    new_attrs->kernel_layout = (*it).second[1];
-    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
-  } else {
-    // We don't have a desired layout for conv2d.
-    // We can just propagate the layout from the input.
-    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
-    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
-    output_layout = data_layout;
-    new_attrs->data_layout =
-        TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
-    new_attrs->kernel_layout =
-        TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
-    new_attrs->out_layout =
-        TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
+    tir::Layout input_layout(attrs->data_layout, DataType::Int(64));
+    tir::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64));
+    tir::Layout out_layout(attrs->out_layout, DataType::Int(64));
+
+    if ((desired_data_layout.ndim() == input_layout.ndim()) &&
+        (desired_weight_layout.ndim() == kernel_layout.ndim()) &&
+        (desired_output_layout.ndim() == out_layout.ndim())) {
+      // Just a transpose
+      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+    } else {
+      // Layout Transform
+      auto data_si = GetStructInfo(call->args[0]);
+      auto kernel_si = GetStructInfo(call->args[1]);
+      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+      Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+      Optional<ShapeExpr> kernel_shape = GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+      bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
+                                                       kernel_shape.value()->values);
+
+      if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+      }
+    }
   }
+
+  // We don't have a desired layout for conv2d, or the desired layouts are not compatible.
+  // We can just propagate the layout from the input.
+  data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  output_layout = data_layout;
+  new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
   return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
 }

src/relax/op/nn/nn.cc

Lines changed: 24 additions & 2 deletions
@@ -93,6 +93,16 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call,
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // TODO(Siva): We could handle the case where the axis is not the sub indexed one.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   ObjectPtr<SoftmaxAttrs> new_attrs = make_object<SoftmaxAttrs>(*attrs);
   new_attrs->axis = FindAxis(layout->layout, attrs->axis);
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
@@ -290,8 +300,18 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call,
   ICHECK(attrs) << "Invalid Call";
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // When dealing with sub layouts, it is advised to handle batch_norm in
+  // other ways, such as decomposing or fusing it.
+  // This handling is a fail-safe fallback.
+  const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  int ndim = input_sinfo->ndim;
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   ObjectPtr<BatchNormAttrs> new_attrs = make_object<BatchNormAttrs>(*attrs);
-  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+  new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim);
   return InferLayoutOutput(
       {layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]},
       {{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs));
@@ -353,9 +373,11 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<LayerNormAttrs> new_attrs = make_object<LayerNormAttrs>(*attrs);
+  const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  int ndim = input_sinfo->ndim;
   std::vector<Integer> new_axis;
   for (const auto& axis : attrs->axes) {
-    new_axis.push_back(FindAxis(layout->layout, axis->value));
+    new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim));
   }
   new_attrs->axes = std::move(new_axis);
   return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
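
The (axis + ndim) % ndim wrapping added above simply normalizes a possibly negative axis into [0, ndim) before FindAxis looks it up in the (possibly sub-indexed) layout. A tiny standalone illustration, assuming a 4-dimensional tensor:

def normalize_axis(axis: int, ndim: int) -> int:
    # Map a possibly negative axis into the range [0, ndim).
    return (axis + ndim) % ndim

assert normalize_axis(-1, 4) == 3
assert normalize_axis(1, 4) == 1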

src/relax/op/nn/pooling.cc

Lines changed: 32 additions & 0 deletions
@@ -234,6 +234,23 @@ InferLayoutOutput InferLayoutPool2d(const Call& call,
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<Pool2DAttrs> new_attrs = make_object<Pool2DAttrs>(*attrs);
+
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    tir::Layout in_layout(attrs->layout, DataType::Int(64));
+    auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
+    auto data_si = GetStructInfo(call->args[0]);
+    TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+    Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+    if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
+      // Not handling out_layout being different from in_layout now. Any use case ?
+      new_attrs->layout = desired_layout.name();
+      new_attrs->out_layout = desired_layout.name();
+      return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+    } else {
+      layout = InitialLayout(4);
+    }
+  }
+
   new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
   new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
@@ -583,6 +600,21 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call,
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<AdaptivePool2DAttrs> new_attrs = make_object<AdaptivePool2DAttrs>(*attrs);
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    tir::Layout in_layout(attrs->layout, DataType::Int(64));
+    auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
+    auto data_si = GetStructInfo(call->args[0]);
+    TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+    Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+    if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
+      // Not handling out_layout being different from in_layout now. Any use case ?
+      new_attrs->layout = desired_layout.name();
+      new_attrs->out_layout = desired_layout.name();
+      return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+    } else {
+      layout = InitialLayout(4);
+    }
+  }
   new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
   new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
   return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));

src/relax/op/op_common.cc

Lines changed: 22 additions & 0 deletions
@@ -185,5 +185,27 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
   return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
 }
 
+bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
+                             Array<PrimExpr> shape) {
+  bool can_prove = true;
+  try {
+    tir::BijectiveLayout todesired(input_layout, desired_layout);
+    Array<PrimExpr> desired_shape = todesired.ForwardShape(shape);
+    Array<PrimExpr> back_shape = todesired.BackwardShape(desired_shape);
+    arith::Analyzer analyzer;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      if (tir::is_const_int(shape[i])) {
+        if (!analyzer.CanProveEqual(shape[i], back_shape[i])) {
+          can_prove = false;
+          break;
+        }
+      }
+    }
+  } catch (std::exception& err) {
+    return false;
+  }
+  return can_prove;
+}
+
 }  // namespace relax
 }  // namespace tvm
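
A rough Python analogue of what CanProveLayoutTransform checks, useful for experimenting with layouts from the Python side. This helper is an assumption for illustration, not an API added by this commit: the shape is pushed through the bijective layout and back, and every constant dimension must round-trip unchanged.

import tvm
from tvm import tir

def can_prove_layout_transform(input_layout: str, desired_layout: str, shape) -> bool:
    # Transform the shape to the desired layout and back; reject the transform
    # if it raises or if any constant dimension does not round-trip.
    analyzer = tvm.arith.Analyzer()
    try:
        bijection = tir.bijective_layout(input_layout, desired_layout)
        back_shape = bijection.backward_shape(bijection.forward_shape(shape))
    except tvm.TVMError:
        return False
    for dim, back in zip(shape, back_shape):
        if isinstance(dim, int) and int(analyzer.simplify(back)) != dim:
            return False
    return True

# e.g. NCHW -> NCHW4c is provable only when the channel dim fills whole blocks:
# can_prove_layout_transform("NCHW", "NCHW4c", [1, 8, 32, 32])  -> True
# can_prove_layout_transform("NCHW", "NCHW4c", [1, 6, 32, 32])  -> False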

src/relax/op/op_common.h

Lines changed: 10 additions & 0 deletions
@@ -570,6 +570,16 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind
  */
 Array<Expr> GetCallArgs(const Call& call);
 
+/**
+ * \brief Checks whether the given shape can be proven to transform from the source layout to the desired layout.
+ * \param input_layout The layout of the given shape.
+ * \param desired_layout The target layout the shape is to be transformed into.
+ * \param shape The shape to be checked.
+ * \return true or false depending on the compatibility.
+ */
+bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
+                             Array<PrimExpr> shape);
+
 }  // namespace relax
 }  // namespace tvm

src/relax/op/tensor/binary.cc

Lines changed: 15 additions & 0 deletions
@@ -155,6 +155,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
   ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim())
       << "Unknown dim tensors should not be handled by this function";
 
+  Optional<ShapeExpr> shape1 = GetRef<ShapeExpr>(x1_sinfo->shape.as<ShapeExprNode>());
+  Optional<ShapeExpr> shape2 = GetRef<ShapeExpr>(x2_sinfo->shape.as<ShapeExprNode>());
+  // Let's handle sub indexing as long as the primal dims match.
+  if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) {
+    if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) {
+      if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) {
+        return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs));
+      }
+    } else if (shape1.defined()) {
+      if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) {
+        return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs));
+      }
+    }
+  }
+
   if (x1_sinfo->ndim <= x2_sinfo->ndim) {
     if (x1_sinfo->ndim == 0) {
       LayoutDecision out_layout = layout2;
src/relax/op/tensor/index.cc

Lines changed: 4 additions & 0 deletions
@@ -438,6 +438,10 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call,
                                 << "but expression " << call << " has argument "
                                 << call->args[0] << " of unknown dimensionality.";
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(tensor_sinfo->ndim));
+  }
 
   auto opt_axes_tuple = UnpackTupleOfPrimValue<Integer>(GetStructInfo(call->args[1]));
   CHECK(opt_axes_tuple) << "Layout inference of " << call->op

src/relax/op/tensor/manipulate.cc

Lines changed: 39 additions & 2 deletions
@@ -393,6 +393,10 @@ InferLayoutOutput InferLayoutExpandDims(const Call& call,
 
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
   int ndim = tensor_sinfo->ndim;
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
   int n_new_dim = attrs->axis.size();
   int output_ndim = ndim + n_new_dim;
   std::vector<bool> is_new_dim(output_ndim, false);
@@ -622,6 +626,12 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call,
   int ndim = tensor_sinfo->ndim;
 
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // permute_dims can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
   Array<Integer> order;
   if (attrs->axes.defined()) {
     order = attrs->axes.value();
@@ -942,10 +952,33 @@ InferLayoutOutput InferLayoutSplit(const Call& call,
   ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
 
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
-  ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
-  new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
   StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule()));
   const auto* out_tuple = out_sinfo.as<TupleStructInfoNode>();
+
+  /*
+   * Fallback if the outputs can't be represented in input sub indexed layout
+   * This can happen after sub indexing, if we can't split the corresponding primal axis
+   */
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    for (const auto& si : out_tuple->fields) {
+      ICHECK(si->IsInstance<TensorStructInfoNode>())
+          << "Fields of TupleStructInfo must be TensorStructInfo"
+             "output structinfo, but got "
+          << si;
+      auto sinfo = Downcast<TensorStructInfo>(si);
+      Optional<ShapeExpr> shape_expr = GetRef<ShapeExpr>(sinfo->shape.as<ShapeExprNode>());
+      CHECK(shape_expr.defined());
+      auto shape_arr = shape_expr.value();
+      if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout,
+                                   shape_arr->values)) {
+        existing_layout = InitialLayout(tensor_sinfo->ndim);
+        break;
+      }
+    }
+  }
+
+  ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
+  new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
   ICHECK(out_tuple != nullptr) << "Invalid Call";
   NLayout tuple_layouts(Array<NLayout>(out_tuple->fields.size(), existing_layout));
   return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs));
@@ -1092,6 +1125,10 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call,
   }
 
   LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
   String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
   Array<Integer> new_axis;
   for (size_t i = 0; i < new_axis_str.size(); ++i) {
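
A worked illustration, with assumed numbers, of the split fallback above: with a blocked layout such as NCHW4c (block factor 4), splitting a channel axis of 8 into two chunks keeps every output filling whole 4c blocks, while splitting it into four chunks leaves partial blocks and forces the fallback to the plain NCHW layout.

def chunks_fit_blocks(channels: int, num_chunks: int, block: int = 4) -> bool:
    # Each output chunk must still be an exact multiple of the block factor.
    if channels % num_chunks != 0:
        return False
    return (channels // num_chunks) % block == 0

assert chunks_fit_blocks(8, 2)      # outputs of 4 channels: NCHW4c is kept
assert not chunks_fit_blocks(8, 4)  # outputs of 2 channels: fall back to NCHW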
