Skip to content

Commit dc1b629

Browse files
committed
[TOPI][TIR][TE][x86] Extend x86 SIMD (u)int8 coverage for dense & conv2d
1 parent 71caa19 commit dc1b629

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed, with 1214 additions (+1214) and 294 deletions (-294).

include/tvm/tir/builtin.h

Lines changed: 27 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,21 @@ TVM_DLL const Op& ret();
5050
*/
5151
TVM_DLL const Op& reinterpret();
5252

53+
/*!
54+
* \brief Zero extend the value using the target type.
55+
*/
56+
TVM_DLL const Op& zextend();
57+
58+
/*!
59+
* \brief Sign extend the value using the target type.
60+
*/
61+
TVM_DLL const Op& sextend();
62+
63+
/*!
64+
* \brief Truncate the value using the target type.
65+
*/
66+
TVM_DLL const Op& truncate();
67+
5368
/*!
5469
* \brief Marks a condition is likely going to happen.
5570
*/
@@ -769,9 +784,20 @@ TVM_DLL const Op& vectorlow();
769784
TVM_DLL const Op& vectorcombine();
770785

771786
/*!
772-
* \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA
787+
* \brief Shuffle two vectors using indices.
788+
*/
789+
TVM_DLL const Op& vectorshuffle();
790+
791+
/*!
792+
* \brief Permute vector using indices.
793+
*/
794+
TVM_DLL const Op& vectorpermute();
795+
796+
/*!
797+
* \brief Atomic add instruction.
773798
*/
774799
TVM_DLL const Op& atomic_add();
800+
775801
/*!
776802
* \brief Create an Nd memory allocation with storage scope
777803
*/

include/tvm/tir/expr.h

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,39 @@ class StringImm : public PrimExpr {
8282
TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
8383
};
8484

85+
/*! \brief Array of integer constants */
86+
class ArrayIntImmNode : public PrimExprNode {
87+
public:
88+
/*! \brief The constant value content. */
89+
Array<Integer> data;
90+
91+
void VisitAttrs(AttrVisitor* v) {
92+
v->Visit("dtype", &dtype);
93+
v->Visit("data", &data);
94+
v->Visit("span", &span);
95+
}
96+
97+
bool SEqualReduce(const ArrayIntImmNode* other, SEqualReducer equal) const {
98+
return equal(data, other->data);
99+
}
100+
101+
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
102+
103+
static constexpr const char* _type_key = "tir.ArrayIntImm";
104+
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIntImmNode, PrimExprNode);
105+
};
106+
107+
/*!
108+
* \brief Managed reference to ArrayIntImmNode.
109+
* \sa ArrayIntImmNode
110+
*/
111+
class ArrayIntImm : public PrimExpr {
112+
public:
113+
TVM_DLL ArrayIntImm(Array<Integer> data, Span span = Span());
114+
TVM_DEFINE_OBJECT_REF_METHODS(ArrayIntImm, PrimExpr, ArrayIntImmNode);
115+
TVM_DEFINE_OBJECT_REF_COW_METHOD(ArrayIntImmNode);
116+
};
117+
85118
/*!
86119
* \brief Cast value from one data type to another.
87120
* \note The lanes of value should keep fixed.

include/tvm/tir/expr_functor.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -149,6 +149,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
149149
virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
150150
virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
151151
virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152+
virtual R VisitExpr_(const ArrayIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
152153
virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
153154
virtual R VisitExprDefault_(const Object* op, Args...) {
154155
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
@@ -192,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
192193
IR_EXPR_FUNCTOR_DISPATCH(IntImmNode);
193194
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
194195
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
196+
IR_EXPR_FUNCTOR_DISPATCH(ArrayIntImmNode);
195197
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
196198
return vtable;
197199
}
@@ -243,6 +245,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
243245
void VisitExpr_(const IntImmNode* op) override;
244246
void VisitExpr_(const FloatImmNode* op) override;
245247
void VisitExpr_(const StringImmNode* op) override;
248+
void VisitExpr_(const ArrayIntImmNode* op) override;
246249
void VisitExpr_(const AnyNode* op) override;
247250
};
248251

@@ -289,6 +292,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
289292
PrimExpr VisitExpr_(const IntImmNode* op) override;
290293
PrimExpr VisitExpr_(const FloatImmNode* op) override;
291294
PrimExpr VisitExpr_(const StringImmNode* op) override;
295+
PrimExpr VisitExpr_(const ArrayIntImmNode* op) override;
292296
PrimExpr VisitExpr_(const AnyNode* op) override;
293297
};
294298

python/tvm/autotvm/task/task.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,8 @@ def _encode(x):
6565
return x
6666
if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
6767
return x.value
68+
if isinstance(x, expr.ArrayIntImm):
69+
return x.data
6870
if isinstance(x, runtime.container.String):
6971
return str(x)
7072
if x is None:

python/tvm/ir/json_compact.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,7 @@ def _convert(item, nodes):
191191
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
192192
"SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")],
193193
"StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
194+
"ArrayIntImm": [_rename("tir.ArrayIntImm"), _update_from_std_str("data")],
194195
"Cast": _rename("tir.Cast"),
195196
"Add": _rename("tir.Add"),
196197
"Sub": _rename("tir.Sub"),

python/tvm/relay/op/nn/_nn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ def legalize_dense(attrs, inputs, types):
8080
Parameters
8181
----------
8282
attrs : tvm.ir.Attrs
83-
Attributes of current convolution
83+
Attributes of current dense operation
8484
inputs : list of tvm.relay.Expr
8585
The args of the Relay expr to be legalized
8686
types : list of types

python/tvm/relay/op/strategy/x86.py

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from tvm.meta_schedule import is_meta_schedule_enabled
2525
from tvm.relay.ty import is_dynamic
2626
from tvm.te import SpecializedCondition
27+
from tvm.target.x86 import get_x86_simd_32bit_lanes
2728

2829
from .. import op as _op
2930
from .generic import *
@@ -588,11 +589,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
588589
def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
589590
"""dense_pack x86 strategy"""
590591
strategy = _op.OpStrategy()
592+
vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16
591593
if (
592-
inputs[0].dtype == "uint8"
593-
and inputs[1].dtype == "int8"
594+
inputs[0].dtype in ("uint8", "int8")
595+
and inputs[1].dtype in ("int8", "uint8")
594596
and out_type.dtype == "int32"
595-
and attrs["weight_layout"] == "NC16n4c"
597+
and attrs["weight_layout"] == f"NC{vec_width}n4c"
596598
):
597599
strategy.add_implementation(
598600
wrap_compute_dense(topi.x86.dense_int8),
@@ -622,10 +624,14 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
622624
if (
623625
not attrs.transpose_a
624626
and attrs.transpose_b
625-
and inputs[0].dtype == "uint8"
626-
and inputs[1].dtype == "int8"
627-
and inputs[1].shape[-2] % 16 == 0
628-
and inputs[1].shape[-1] % 4 == 0
627+
and inputs[0].dtype in ("uint8", "int8")
628+
and inputs[1].dtype in ("int8", "uint8")
629+
and (
630+
# legalized SIMD
631+
get_x86_simd_32bit_lanes()
632+
# unknown SIMD
633+
or (inputs[1].shape[-2] % 16 == 0 and inputs[1].shape[-1] % 4 == 0)
634+
)
629635
):
630636
strategy.add_implementation(
631637
wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@
7676
Shuffle,
7777
SizeVar,
7878
StringImm,
79+
ArrayIntImm,
7980
Sub,
8081
Var,
8182
)
@@ -1869,6 +1870,11 @@ def wrapped(*args, **kwargs):
18691870

18701871

18711872
reinterpret = _dtype_forward(_tir_op.reinterpret)
1873+
sextend = _dtype_forward(_tir_op.sextend)
1874+
zextend = _dtype_forward(_tir_op.zextend)
1875+
truncate = _dtype_forward(_tir_op.truncate)
1876+
vectorpermute = _dtype_forward(_tir_op.vectorpermute)
1877+
vectorshuffle = _dtype_forward(_tir_op.vectorshuffle)
18721878
call_extern = _dtype_forward(_tir_op.call_extern)
18731879
call_intrin = _dtype_forward(_tir_op.call_intrin)
18741880
call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin)
@@ -2072,6 +2078,11 @@ def wrapped(*args, **kwargs):
20722078
"q_multiply_shift_per_axis",
20732079
"ret",
20742080
"reinterpret",
2081+
"sextend",
2082+
"zextend",
2083+
"truncate",
2084+
"vectorpermute",
2085+
"vectorshuffle",
20752086
"round",
20762087
"rsqrt",
20772088
"shift_left",
@@ -2155,6 +2166,7 @@ def wrapped(*args, **kwargs):
21552166
"FloatImm",
21562167
"IntImm",
21572168
"StringImm",
2169+
"ArrayIntImm",
21582170
"Cast",
21592171
"Add",
21602172
"Sub",

python/tvm/target/x86.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919
from .codegen import target_has_features
2020

2121

22-
@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
23-
def get_simd_32bit_lanes():
22+
@register_func("tvm.topi.x86.utils.get_x86_simd_32bit_lanes")
23+
def get_x86_simd_32bit_lanes():
2424
"""X86 SIMD optimal vector length lookup.
2525
Parameters
2626
----------
@@ -29,9 +29,13 @@ def get_simd_32bit_lanes():
2929
vec_len : int
3030
The optimal vector length of CPU from the global context target.
3131
"""
32-
vec_len = 4
33-
if target_has_features(["avx512bw", "avx512f"]):
32+
vec_len = None
33+
if target_has_features("avx512vnni") or target_has_features("avxvnni"):
34+
vec_len = 16
35+
elif target_has_features(["avx512bw", "avx512f"]):
3436
vec_len = 16
3537
elif target_has_features("avx2"):
3638
vec_len = 8
39+
elif target_has_features("ssse3"):
40+
vec_len = 4
3741
return vec_len

python/tvm/tir/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
from .buffer import Buffer, decl_buffer, DataProducer
2323
from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
24-
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
24+
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, ArrayIntImm, Cast
2525
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
2626
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
2727
from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle
@@ -73,8 +73,8 @@
7373
ptx_wait_barrier,
7474
create_barriers,
7575
)
76-
from .op import vectorlow, vectorhigh, vectorcombine
77-
from .op import infinity, reinterpret
76+
from .op import vectorlow, vectorhigh, vectorcombine, vectorpermute, vectorshuffle
77+
from .op import infinity, reinterpret, zextend, sextend, truncate
7878
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
7979
from .op import sin, sinh, asin, asinh
8080
from .op import cos, cosh, acos, acosh
@@ -88,6 +88,7 @@
8888
from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right
8989
from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace
9090
from .op import start_profile_intrinsic, end_profile_intrinsic
91+
from .op import atomic_add
9192
from .generic import add, subtract, multiply
9293

9394
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

0 commit comments

Comments
 (0)