Skip to content

Commit 970cd1d

Browse files
authored
[TIR][Hexagon] Enhancement of NarrowDataType pass for binary ops (#14298)
This is an enhancement of PR#13327. Motivation: While experimenting with MetaScheduler for the Hexagon target, it was found that avg_pool2d has rather poor performance due to a lack of vectorized code. The IndexDataTypeNormalizer pass converts all indices to int64, and NarrowDataTypeRewriter should do the opposite (convert them back to int32). If it fails to do so, a lot of int64 arithmetic is left in the average pooling code, and that arithmetic cannot be vectorized. What was done: Added support for binary ops ("div", "max", "min", "+", etc.) in NarrowDataTypeRewriter. When the operands of a binary operation have different bit widths, the rewriter now downcasts instead of upcasting (as it did before). Performance impact: avg_pool2d from quantized InceptionV3 with the shape [1, 8, 35, 35, 32] (NCHW32c layout), tuned with MetaScheduler on Snapdragon 8gen1:

| shape | Before fix, ms | After fix, ms | speedup |
|-------------------|----------------|---------------|---------|
| avg_pool2d, int32 | 6.67 | 4.41 | +34% |
1 parent 075e2ec commit 970cd1d

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

src/tir/transforms/narrow_datatype.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,44 @@ class NarrowDataTypeRewriter : public IndexDataTypeRewriter {
258258
return Parent::VisitExpr_(op);
259259
}
260260

261+
// Generates a VisitExpr_ overload for the binary node type OP whose result is
// rebuilt with FUNC.
//  - Both operands are visited first. If neither operand changed and their
//    dtypes agree, the original node is returned as-is.
//  - If the rewritten operands end up with different dtypes, is_enabled_ is
//    saved and forced to true, both ORIGINAL operands (op->a / op->b) are
//    visited a second time, the flag is restored, and FUNC is applied to the
//    results of that second visit.
//  - Otherwise FUNC is applied to the operands from the first visit.
// NOTE(review): presumably forcing is_enabled_ makes both sides narrow to a
// common dtype instead of being upcast -- confirm against the parent rewriter.
// NOTE(review): is_enabled_ is restored by plain assignment, so it would stay
// true if either second-pass VisitExpr call exits via an exception -- confirm
// whether that can happen here.
#define TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
262+
PrimExpr VisitExpr_(const OP* op) { \
263+
PrimExpr a = this->VisitExpr(op->a); \
264+
PrimExpr b = this->VisitExpr(op->b); \
265+
if (op->a.same_as(a) && op->b.same_as(b) && a.dtype() == b.dtype()) { \
266+
return GetRef<PrimExpr>(op); \
267+
} else { \
268+
if (a.dtype() != b.dtype()) { \
269+
bool is_enabled = is_enabled_; \
270+
is_enabled_ = true; \
271+
PrimExpr lhs = this->VisitExpr(op->a); \
272+
PrimExpr rhs = this->VisitExpr(op->b); \
273+
is_enabled_ = is_enabled; \
274+
return FUNC(lhs, rhs); \
275+
} else { \
276+
return FUNC(a, b); \
277+
} \
278+
} \
279+
}
280+
281+
// Instantiate the dtype-matching VisitExpr_ overload for each binary node
// type, pairing the node with the function used to rebuild it: arithmetic
// (+, -, *, div, truncmod, floordiv, floormod), min/max, and the six
// comparison operators.
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+);
282+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-);
283+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*);
284+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div);
285+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(ModNode, truncmod);
286+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorDivNode, floordiv);
287+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(FloorModNode, floormod);
288+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min);
289+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max);
290+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
291+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
292+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
293+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<); // NOLINT(*)
294+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>); // NOLINT(*)
295+
TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
296+
297+
// The helper macro is only needed for the instantiations above.
#undef TVM_DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH
298+
261299
private:
262300
// the internal visitor to deduce the narrowed dtype
263301
DataTypeVisitor visitor_;

tests/python/unittest/test_tir_transform_narrow_datatype.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,5 +346,66 @@ def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32"
346346
tvm.ir.assert_structural_equal(after, expected_after)
347347

348348

349+
# Checks that NarrowDataType(32) followed by Simplify turns the int64 loop
# variables, buffer indices and divisor arithmetic of `before` into the int32
# form spelled out in `expected_after`.
def test_avg_pool2d():
350+
@T.prim_func
351+
def before(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
352+
for j in T.parallel(T.int64(0), T.int64(280)):
353+
for i in T.serial(T.int64(0), T.int64(35)):
354+
for vi in T.vectorized(T.int64(0), T.int64(32)):
355+
# The int32 PSUM value is cast up to int64, divided by an int64 divisor
# built from min/max/floormod terms, and the quotient is cast back to int32.
PAVG[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)] = T.cast(
356+
T.Div(
357+
T.cast(PSUM[(((j * T.int64(1120)) + (i * T.int64(32))) + vi)], "int64"),
358+
T.max(
359+
(
360+
(
361+
(
362+
T.min(
363+
T.int64(1),
364+
(T.int64(34) - T.floormod(j, T.int64(35))),
365+
)
366+
+ T.int64(2)
367+
)
368+
- T.max(
369+
(T.int64(1) - T.floormod(j, T.int64(35))), T.int64(0)
370+
)
371+
)
372+
* (
373+
(T.min(T.int64(1), (T.int64(34) - i)) + T.int64(2))
374+
- T.max((T.int64(1) - i), T.int64(0))
375+
)
376+
),
377+
T.int64(1),
378+
),
379+
),
380+
"int32",
381+
)
382+
383+
@T.prim_func
384+
def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), "int32")):
385+
for j in T.parallel(T.int32(0), T.int32(280)):
386+
for i in T.serial(T.int32(0), T.int32(35)):
387+
for vi in T.vectorized(T.int32(0), T.int32(32)):
388+
# Same computation with every constant in int32 and no casts. The outer
# T.max(..., 1) from `before` is absent here -- presumably removed by the
# Simplify pass run below; confirm.
PAVG[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)] = T.Div(
389+
PSUM[(((j * T.int32(1120)) + (i * T.int32(32))) + vi)],
390+
(
391+
(
392+
(
393+
T.min(T.int32(1), (T.int32(34) - T.floormod(j, T.int32(35))))
394+
+ T.int32(2)
395+
)
396+
- T.max((T.int32(1) - T.floormod(j, T.int32(35))), T.int32(0))
397+
)
398+
* (
399+
(T.min(T.int32(1), (T.int32(34) - i)) + T.int32(2))
400+
- T.max((T.int32(1) - i), T.int32(0))
401+
)
402+
),
403+
)
404+
405+
# Run the pass with a 32-bit target width, simplify the result, then compare
# the "main" function structurally against the expected PrimFunc.
after = tvm.tir.transform.NarrowDataType(32)(tvm.IRModule.from_expr(before))
406+
after = tvm.tir.transform.Simplify()(after)
407+
tvm.ir.assert_structural_equal(after["main"], expected_after)
408+
409+
349410
if __name__ == "__main__":
350411
tvm.testing.main()

0 commit comments

Comments
 (0)