From 0d338828eebaa3ff705e8521f2a1b3530f73dc7d Mon Sep 17 00:00:00 2001 From: Balint Cristian Date: Fri, 13 Oct 2023 17:32:35 +0300 Subject: [PATCH] [TOPI][TIR][TE][x86] Extend x86 SIMD (u)int8 coverage for dense & conv2d --- include/tvm/tir/builtin.h | 28 ++- include/tvm/tir/expr.h | 33 +++ include/tvm/tir/expr_functor.h | 4 + python/tvm/autotvm/task/task.py | 2 + python/tvm/ir/json_compact.py | 1 + python/tvm/relay/op/nn/_nn.py | 2 +- python/tvm/relay/op/strategy/x86.py | 20 +- python/tvm/script/ir_builder/tir/ir.py | 12 ++ python/tvm/target/x86.py | 12 +- python/tvm/tir/__init__.py | 7 +- python/tvm/tir/expr.py | 30 +++ python/tvm/tir/op.py | 155 +++++++++++++- python/tvm/topi/x86/batch_matmul.py | 34 ++- python/tvm/topi/x86/conv2d_alter_op.py | 10 +- python/tvm/topi/x86/conv2d_avx_1x1.py | 30 ++- python/tvm/topi/x86/conv2d_avx_common.py | 10 +- python/tvm/topi/x86/conv2d_int8.py | 10 +- python/tvm/topi/x86/conv3d.py | 4 +- python/tvm/topi/x86/dense.py | 44 ++-- python/tvm/topi/x86/dense_alter_op.py | 84 ++++++-- python/tvm/topi/x86/depthwise_conv2d.py | 4 +- python/tvm/topi/x86/group_conv2d.py | 5 +- python/tvm/topi/x86/sparse.py | 4 +- python/tvm/topi/x86/tensor_intrin.py | 157 ++++++++------ python/tvm/utils/roofline/x86.py | 3 +- src/ir/attr_functor.h | 2 + src/relay/printer/relay_text_printer.cc | 12 ++ src/relay/printer/text_printer.h | 4 +- src/relay/printer/tir_text_printer.cc | 13 ++ src/relay/printer/tvmscript_printer.cc | 16 +- src/script/printer/legacy_repr.cc | 13 ++ src/script/printer/tir/expr.cc | 8 + src/target/llvm/codegen_arm.cc | 8 +- src/target/llvm/codegen_llvm.cc | 53 ++++- src/tir/ir/expr.cc | 26 ++- src/tir/ir/expr_functor.cc | 2 + src/tir/op/builtin.cc | 28 +++ src/tir/transforms/common_subexpr_elim.cc | 3 +- .../transforms/common_subexpr_elim_tools.cc | 6 +- src/tir/transforms/install_debug_spans.h | 3 +- tests/python/contrib/test_gemm_acc32_simd.py | 142 +++++++++++++ tests/python/relay/test_op_level1.py | 72 +++++-- tests/python/relay/test_op_level10.py | 100 ++++++--- tests/python/relay/test_op_level2.py | 197 ++++++++++++------ tests/python/unittest/test_tir_constructor.py | 7 + tests/python/unittest/test_tir_nodes.py | 7 + tests/python/unittest/test_tir_op_types.py | 46 ++++ .../test_tvmscript_printer_metadata.py | 18 ++ .../unittest/test_tvmscript_printer_tir.py | 5 + 49 files changed, 1208 insertions(+), 288 deletions(-) create mode 100644 tests/python/contrib/test_gemm_acc32_simd.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 65012c6c0f0f..43a4e0f1d3b0 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -50,6 +50,21 @@ TVM_DLL const Op& ret(); */ TVM_DLL const Op& reinterpret(); +/*! + * \brief Zero extend the value using the target type. + */ +TVM_DLL const Op& zextend(); + +/*! + * \brief Sign extend the value using the target type. + */ +TVM_DLL const Op& sextend(); + +/*! + * \brief Truncate the value using the target type. + */ +TVM_DLL const Op& truncate(); + /*! * \brief Marks a condition is likely going to happen. */ @@ -769,9 +784,20 @@ TVM_DLL const Op& vectorlow(); TVM_DLL const Op& vectorcombine(); /*! - * \brief atomic add instruction, corresponding e.g. to atomicAdd in CUDA + * \brief Shuffle two vectors using indices. + */ +TVM_DLL const Op& vectorshuffle(); + +/*! + * \brief Permute vector using indices. + */ +TVM_DLL const Op& vectorpermute(); + +/*! + * \brief Atomic add instruction. */ TVM_DLL const Op& atomic_add(); + /*! 
 * \brief Create an Nd memory allocation with storage scope
 */
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 4e29eddadd8c..180d55719897 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -82,6 +82,39 @@ class StringImm : public PrimExpr {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StringImmNode);
 };
 
+/*! \brief Array of integer constants */
+class ArrayIntImmNode : public PrimExprNode {
+ public:
+  /*! \brief The constant value content. */
+  Array<Integer> data;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("dtype", &dtype);
+    v->Visit("data", &data);
+    v->Visit("span", &span);
+  }
+
+  bool SEqualReduce(const ArrayIntImmNode* other, SEqualReducer equal) const {
+    return equal(data, other->data);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(data); }
+
+  static constexpr const char* _type_key = "tir.ArrayIntImm";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIntImmNode, PrimExprNode);
+};
+
+/*!
+ * \brief Managed reference to ArrayIntImmNode.
+ * \sa ArrayIntImmNode
+ */
+class ArrayIntImm : public PrimExpr {
+ public:
+  TVM_DLL ArrayIntImm(Array<Integer> data, Span span = Span());
+  TVM_DEFINE_OBJECT_REF_METHODS(ArrayIntImm, PrimExpr, ArrayIntImmNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ArrayIntImmNode);
+};
+
 /*!
  * \brief Cast value from one data type to another.
  * \note The lanes of value should keep fixed.
diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index 3f66164b42c0..7299e41a980d 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -149,6 +149,7 @@ class ExprFunctor {
   virtual R VisitExpr_(const IntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const FloatImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const StringImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const ArrayIntImmNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const AnyNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Object* op, Args...)
{ LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); @@ -192,6 +193,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(IntImmNode); IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); + IR_EXPR_FUNCTOR_DISPATCH(ArrayIntImmNode); IR_EXPR_FUNCTOR_DISPATCH(AnyNode); return vtable; } @@ -243,6 +245,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const IntImmNode* op) override; void VisitExpr_(const FloatImmNode* op) override; void VisitExpr_(const StringImmNode* op) override; + void VisitExpr_(const ArrayIntImmNode* op) override; void VisitExpr_(const AnyNode* op) override; }; @@ -289,6 +292,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const IntImmNode* op) override; PrimExpr VisitExpr_(const FloatImmNode* op) override; PrimExpr VisitExpr_(const StringImmNode* op) override; + PrimExpr VisitExpr_(const ArrayIntImmNode* op) override; PrimExpr VisitExpr_(const AnyNode* op) override; }; diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py index 575325c80e5b..5ecf745504b2 100644 --- a/python/tvm/autotvm/task/task.py +++ b/python/tvm/autotvm/task/task.py @@ -65,6 +65,8 @@ def _encode(x): return x if isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)): return x.value + if isinstance(x, expr.ArrayIntImm): + return x.data if isinstance(x, runtime.container.String): return str(x) if x is None: diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 6ce2a8b9e241..c1ee212143d1 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -191,6 +191,7 @@ def _convert(item, nodes): "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], + "ArrayIntImm": [_rename("tir.ArrayIntImm"), _update_from_std_str("data")], "Cast": _rename("tir.Cast"), "Add": _rename("tir.Add"), "Sub": _rename("tir.Sub"), diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index c68685f0ae09..d31dea5ae703 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -80,7 +80,7 @@ def legalize_dense(attrs, inputs, types): Parameters ---------- attrs : tvm.ir.Attrs - Attributes of current convolution + Attributes of current dense operation inputs : list of tvm.relay.Expr The args of the Relay expr to be legalized types : list of types diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 1b69c7a6ca42..e0fd7e97aa15 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -24,6 +24,7 @@ from tvm.meta_schedule import is_meta_schedule_enabled from tvm.relay.ty import is_dynamic from tvm.te import SpecializedCondition +from tvm.target.x86 import get_x86_simd_32bit_lanes from .. 
import op as _op from .generic import * @@ -588,11 +589,12 @@ def dense_strategy_cpu(attrs, inputs, out_type, target): def dense_pack_strategy_cpu(attrs, inputs, out_type, target): """dense_pack x86 strategy""" strategy = _op.OpStrategy() + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 if ( - inputs[0].dtype == "uint8" - and inputs[1].dtype == "int8" + inputs[0].dtype in ("uint8", "int8") + and inputs[1].dtype in ("int8", "uint8") and out_type.dtype == "int32" - and attrs["weight_layout"] == "NC16n4c" + and attrs["weight_layout"] == f"NC{vec_width}n4c" ): strategy.add_implementation( wrap_compute_dense(topi.x86.dense_int8), @@ -622,10 +624,14 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target): if ( not attrs.transpose_a and attrs.transpose_b - and inputs[0].dtype == "uint8" - and inputs[1].dtype == "int8" - and inputs[1].shape[-2] % 16 == 0 - and inputs[1].shape[-1] % 4 == 0 + and inputs[0].dtype in ("uint8", "int8") + and inputs[1].dtype in ("int8", "uint8") + and ( + # legalized SIMD + get_x86_simd_32bit_lanes() + # unknown SIMD + or (inputs[1].shape[-2] % 16 == 0 and inputs[1].shape[-1] % 4 == 0) + ) ): strategy.add_implementation( wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True), diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 5471288878f5..0edb71bb9a6d 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -76,6 +76,7 @@ Shuffle, SizeVar, StringImm, + ArrayIntImm, Sub, Var, ) @@ -1869,6 +1870,11 @@ def wrapped(*args, **kwargs): reinterpret = _dtype_forward(_tir_op.reinterpret) +sextend = _dtype_forward(_tir_op.sextend) +zextend = _dtype_forward(_tir_op.zextend) +truncate = _dtype_forward(_tir_op.truncate) +vectorpermute = _dtype_forward(_tir_op.vectorpermute) +vectorshuffle = _dtype_forward(_tir_op.vectorshuffle) call_extern = _dtype_forward(_tir_op.call_extern) call_intrin = _dtype_forward(_tir_op.call_intrin) call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) @@ -2072,6 +2078,11 @@ def wrapped(*args, **kwargs): "q_multiply_shift_per_axis", "ret", "reinterpret", + "sextend", + "zextend", + "truncate", + "vectorpermute", + "vectorshuffle", "round", "rsqrt", "shift_left", @@ -2155,6 +2166,7 @@ def wrapped(*args, **kwargs): "FloatImm", "IntImm", "StringImm", + "ArrayIntImm", "Cast", "Add", "Sub", diff --git a/python/tvm/target/x86.py b/python/tvm/target/x86.py index c040eface808..d0b9e1039a24 100644 --- a/python/tvm/target/x86.py +++ b/python/tvm/target/x86.py @@ -19,8 +19,8 @@ from .codegen import target_has_features -@register_func("tvm.topi.x86.utils.get_simd_32bit_lanes") -def get_simd_32bit_lanes(): +@register_func("tvm.topi.x86.utils.get_x86_simd_32bit_lanes") +def get_x86_simd_32bit_lanes(): """X86 SIMD optimal vector length lookup. Parameters ---------- @@ -29,9 +29,13 @@ def get_simd_32bit_lanes(): vec_len : int The optimal vector length of CPU from the global context target. 
""" - vec_len = 4 - if target_has_features(["avx512bw", "avx512f"]): + vec_len = None + if target_has_features("avx512vnni") or target_has_features("avxvnni"): + vec_len = 16 + elif target_has_features(["avx512bw", "avx512f"]): vec_len = 16 elif target_has_features("avx2"): vec_len = 8 + elif target_has_features("ssse3"): + vec_len = 4 return vec_len diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index f0500290b888..a37418717efd 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -21,7 +21,7 @@ from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout -from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast +from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, ArrayIntImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Select, BufferLoad, ProducerLoad, Ramp, Broadcast, Shuffle @@ -73,8 +73,8 @@ ptx_wait_barrier, create_barriers, ) -from .op import vectorlow, vectorhigh, vectorcombine -from .op import infinity, reinterpret +from .op import vectorlow, vectorhigh, vectorcombine, vectorpermute, vectorshuffle +from .op import infinity, reinterpret, zextend, sextend, truncate from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh @@ -88,6 +88,7 @@ from .op import q_multiply_shift, q_multiply_shift_per_axis, shift_left, shift_right from .op import TVMBackendAllocWorkspace, TVMBackendFreeWorkspace from .op import start_profile_intrinsic, end_profile_intrinsic +from .op import atomic_add from .generic import add, subtract, multiply from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f93e39ee0fbd..9425d1e1cfd6 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -602,6 +602,36 @@ def __hash__(self): return PrimExpr.__hash__(self) +@tvm._ffi.register_object("tir.ArrayIntImm") # type: ignore +class ArrayIntImm(ConstExpr): + """Array of integer constants. + + Parameters + ---------- + data : list + The list with values of the function. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, data, span=None): + self.__init_handle_by_constructor__(_ffi_api.ArrayIntImm, data, span) # type: ignore + + def __eq__(self, other): + if isinstance(other, ConstExpr): + return str(self.data) == str(other.data) + return str(self.data) == str(other) + + def __ne__(self, other): + if isinstance(other, ConstExpr): + return str(self.data) != str(other.data) + return str(self.data) != str(other) + + def __hash__(self): + return PrimExpr.__hash__(self) + + @tvm._ffi.register_object("tir.Cast") class Cast(PrimExprWithOp): """Cast expression. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 905d14296d98..a1f3e5298b4b 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -16,7 +16,6 @@ # under the License. 
 # pylint: disable=redefined-builtin, invalid-name
 """Operators used in TIR expression."""
-import warnings
 from typing import Any, Optional
 
 import tvm._ffi
@@ -271,11 +270,11 @@ def call_llvm_intrin(dtype, name, *args, span=None):
     else:
         llvm_id = name
     if llvm_id == 0:
-        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
+        raise ValueError(f"Unknown llvm intrinsic function {name}")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_intrin"),
-        tvm.tir.const(llvm_id, "uint32"),
+        codegen.llvm_get_intrinsic_name(llvm_id),
         *args,
         span=span,
     )
@@ -293,7 +292,7 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
         The name of the llvm intrinsic function.
 
     args : list
-        Poistional arguments.
+        Positional arguments.
 
     span : Optional[Span]
         The location of this operator in the source code.
@@ -313,11 +312,11 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None):
     else:
         llvm_id = name
     if llvm_id == 0:
-        warnings.warn(f"Unknown llvm intrinsic function {name}, falling back to 0")
+        raise ValueError(f"Unknown llvm intrinsic function {name}")
     return call_intrin(
         dtype,
         Op.get("tir.call_llvm_pure_intrin"),
-        tvm.tir.const(llvm_id, "uint32"),
+        codegen.llvm_get_intrinsic_name(llvm_id),
         *args,
         span=span,
     )
@@ -1609,6 +1608,80 @@ def vectorcombine(dtype, vec1, vec2):
     return call_intrin(dtype, "tir.vectorcombine", vec1, vec2)
 
 
+def vectorpermute(dtype, vec, indices):
+    """Permute vector using position indices
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    vec : list
+        The input vector.
+
+    indices : list
+        The list with positional indices for element permutation,
+        numbered from left to right starting with 0 as the first element.
+        The length of the resulting vector is given by the length of the indices.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(dtype, "tir.vectorpermute", vec, indices)
+
+
+def vectorshuffle(dtype, vec0, vec1, indices):
+    """Shuffle two vectors using position indices
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    vec0 : list
+        The first input vector.
+
+    vec1 : list
+        The second input vector.
+
+    indices : list
+        The list with positional indices for element permutation,
+        numbered from left to right, starting with 0 as the first element
+        of vec0, or `len(vec0)` as the first element starting in vec1.
+        The length of the resulting vector is given by the length of the indices.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    return call_intrin(dtype, "tir.vectorshuffle", vec0, vec1, indices)
+
+
+def atomic_add(dtype, vec0, vec1):
+    """Atomic add instruction.
+
+    Parameters
+    ----------
+    dtype : str
+        The data type of the result.
+
+    vec0 : list
+        The first input vector.
+
+    vec1 : list
+        The second input vector.
+
+    Returns
+    -------
+    call : PrimExpr
+        The call expression.
+    """
+    assert vec0.dtype == vec1.dtype == dtype
+    return call_intrin(dtype, "tir.atomic_add", vec0, vec1)
+
+
 def ret(val):
     """Create a tir return expression
 
@@ -1775,7 +1848,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
 
 
 def reinterpret(dtype, value) -> Any:
-    """infinity value of dtype
+    """Reinterpret the value using the target data type
 
     Parameters
     ----------
     dtype : str
         The target data type.
 
     value : PrimExpr
         The input value.
 
     span : Optional[Span]
         The location of this operator in the source code.
 
     Returns
     -------
     value : tvm.Expr
-        The reinterpret cast value of dtype.
+        The reinterpreted value of dtype.
     """
     return call_intrin(dtype, "tir.reinterpret", value)
 
 
+def zextend(dtype, value) -> Any:
+    """Zero extend the value
+
+    Parameters
+    ----------
+    dtype : str
+        The target data type.
+ + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The zero extend value of dtype. + """ + return call_intrin(dtype, "tir.zextend", value) + + +def sextend(dtype, value) -> Any: + """Sign extend the value + + Parameters + ---------- + dtype : str + The target data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The sign extend value of dtype. + """ + return call_intrin(dtype, "tir.sextend", value) + + +def truncate(dtype, value) -> Any: + """Truncate the value + + Parameters + ---------- + dtype : str + The target data type. + + value : PrimExpr + The input value. + + span : Optional[Span] + The location of this operator in the source code. + + Returns + ------- + value : tvm.Expr + The truncated value of dtype. + """ + return call_intrin(dtype, "tir.truncate", value) + + def exp(x): """Take exponential of input x. diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index e10313323089..76407740da52 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -22,6 +22,7 @@ from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, mkl from tvm.target.codegen import target_has_features +from tvm.target.x86 import get_x86_simd_32bit_lanes from .. import generic, nn from ..transform import layout_transform @@ -32,13 +33,20 @@ @autotvm.register_topi_compute("batch_matmul_int8.x86") def batch_matmul_int8_compute(cfg, x, y, *_): - """Compute for uint8 x int8 -> int32 batch_matmul""" + """Compute for (u)int8 x (u)int8 -> int32 batch_matmul""" batch, m, k = x.shape - packed_y_layout = "BNK16n4k" + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + packed_y_layout = f"BNK{vec_width}n4k" packed_y = layout_transform(y, "BNK", packed_y_layout) _, n_o, _, n_i, _ = packed_y.shape ak = te.reduce_axis((0, k), name="k") - if target_has_features(["avx512bw", "avx512f"]): + if ( + target_has_features(["avx512bw", "avx512f"]) + or target_has_features("avx512vnni") + or target_has_features("avxvnni") + or target_has_features("avx2") + or target_has_features("ssse3") + ): attrs_info = {"schedule_rule": "batch_matmul_int8"} else: attrs_info = None @@ -47,9 +55,9 @@ def batch_matmul_int8_compute(cfg, x, y, *_): (batch, m, n_o * n_i), lambda b, i, j: te.sum( x[b, i, ak].astype("int32") - * packed_y[b, tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( - "int32" - ), + * packed_y[ + b, tvm.tir.indexdiv(j, vec_width), tvm.tir.indexdiv(ak, 4), j % vec_width, ak % 4 + ].astype("int32"), axis=ak, ), tag="batch_matmul_int8", @@ -222,7 +230,8 @@ def _callback(op): _, _, y, x = s[Crf].op.axis s[Crf].fuse(y, x) s[Crf].vectorize(s[Crf].op.axis[0]) - s[O].pragma(bxyo, "auto_unroll_max_step", 16) + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + s[O].pragma(bxyo, "auto_unroll_max_step", vec_width) traverse_inline(s, outs[0].op, _callback) return s @@ -238,7 +247,13 @@ def _callback(op): layout_trans = op.input_tensors[1] if target_has_features("amx-int8"): batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans) - elif target_has_features(["avx512bw", "avx512f"]): + elif ( + target_has_features(["avx512bw", "avx512f"]) + or target_has_features("avxvnni") + or target_has_features("avx512vnni") + or 
target_has_features("avx2") + or target_has_features("ssse3") + ): batch_matmul_int8_schedule(cfg, s, op.output(0), outs[0], layout_trans) traverse_inline(s, outs[0].op, _callback) @@ -246,7 +261,8 @@ def _callback(op): def _default_batch_matmul_config(cfg, M, N, K): - cfg["tile_k"] = SplitEntity([K // 16, 16]) + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + cfg["tile_k"] = SplitEntity([K // vec_width, vec_width]) x_bn = get_max_power2_factor(N, 8) cfg["tile_x"] = SplitEntity([N // x_bn, x_bn]) y_bn = get_max_power2_factor(M, 8) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 3772aaec046d..e38bb80f6b3f 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -24,6 +24,7 @@ from tvm import te from tvm import relay from tvm import autotvm +from tvm.target.x86 import get_x86_simd_32bit_lanes from .conv2d import _get_default_config from .conv2d_int8 import is_int8_hw_support, _get_default_config_int8 from ..utils import get_const_tuple @@ -148,6 +149,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): if topi_tmpl == "conv2d_NCHWc_int8.x86": # TODO(@icemelon9, @anijain2305): Need to support data layout NHWC with kernel layout HWIO assert data_layout == "NCHW" and kernel_layout == "OIHW" + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 if cfg.is_fallback: _get_default_config_int8( cfg, @@ -159,7 +161,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): out_dtype, False, data_layout, - int32_lanes=16, + int32_lanes=vec_width, ) batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) @@ -282,11 +284,13 @@ def _conv2d_legalize(attrs, inputs, arg_types): # Collect the input exprs. data, kernel = inputs - # Intel vector intructions require data and kernel to have different dtypes. + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + + # x86 vector intructions require data as uint8 and kernel as int8. 
if data_tensor.dtype == "int8" and kernel_tensor.dtype == "int8": data_dtype = "uint8" if is_int8_hw_support(data_dtype, kernel_dtype): return conv2d_alter_int8_common( - data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, data_dtype, 4, 16 + data, data_tensor, kernel, kernel_tensor, output_tensor, attrs, data_dtype, 4, vec_width ) return None diff --git a/python/tvm/topi/x86/conv2d_avx_1x1.py b/python/tvm/topi/x86/conv2d_avx_1x1.py index 047377f83e86..07921dd7914e 100644 --- a/python/tvm/topi/x86/conv2d_avx_1x1.py +++ b/python/tvm/topi/x86/conv2d_avx_1x1.py @@ -21,7 +21,7 @@ import tvm from tvm import te from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from ..generic import conv2d as conv2d_generic from ..nn.pad import pad @@ -31,7 +31,7 @@ def _fallback_schedule(cfg, wkl): - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1 @@ -158,7 +158,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last): kernel_vec, conv_out, last, - int32_lanes=get_simd_32bit_lanes(), + int32_lanes=get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4, intrin=dot_16x1x16_uint8_int8_int32(), ) @@ -199,10 +199,14 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o idxd = tvm.tir.indexdiv idxm = tvm.tir.indexmod - packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4) + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + + packw_shape = (kernel_h, kernel_w, idxd(num_filter, vec_width), vec_width * idxd(channel, 4), 4) PackW = te.compute( packw_shape, - lambda a, b, c, d, e: Filter[a, b, c * 16 + idxm(d, 16), idxd(d, 16) * 4 + e], + lambda a, b, c, d, e: Filter[ + a, b, c * vec_width + idxm(d, vec_width), idxd(d, vec_width) * 4 + e + ], name="packed_filter", ) @@ -215,9 +219,13 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o PaddedInput[ nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc ].astype(out_dtype) - * PackW[ry, rx, idxd(ff, 16), idxd(rc, 4) * 16 + idxm(ff, 16), idxm(rc, 4)].astype( - out_dtype - ), + * PackW[ + ry, + rx, + idxd(ff, vec_width), + idxd(rc, 4) * vec_width + idxm(ff, vec_width), + idxm(rc, 4), + ].astype(out_dtype), axis=[ry, rx, rc], ), name="Conv2d_1x1_Output_int8", @@ -237,13 +245,13 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): # pylint: disable=unreachable return s - int32_lanes = 16 + int32_lanes = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 # assertion to fail the unhandled case _, _, _, ic_num = get_const_tuple(data.shape) _, _, _, oc_num = get_const_tuple(conv_out.shape) assert ic_num % 4 == 0 - assert oc_num % 16 == 0 + assert oc_num % int32_lanes == 0 ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] # schedule data @@ -269,7 +277,7 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): if C != O: batch, last_oh, last_ow, last_oc = s[O].op.axis - oc_chunk, oc_block = s[O].split(ochannel, 16) + oc_chunk, oc_block = s[O].split(ochannel, int32_lanes) # not saw perf improvement to split oh/ow here s[O].vectorize(oc_block) diff --git a/python/tvm/topi/x86/conv2d_avx_common.py 
b/python/tvm/topi/x86/conv2d_avx_common.py index 73283e7888dd..77b7fbfaf51c 100644 --- a/python/tvm/topi/x86/conv2d_avx_common.py +++ b/python/tvm/topi/x86/conv2d_avx_common.py @@ -18,7 +18,7 @@ """Conv2D schedule on for Intel CPU""" import tvm from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from ..generic import conv2d as conv2d_generic from ..utils import get_const_tuple @@ -26,7 +26,7 @@ def _fallback_schedule(cfg, wkl): - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1 @@ -62,7 +62,9 @@ def _fallback_schedule_int8(cfg, wkl): HSTR, WSTR = wkl.stride_h, wkl.stride_w out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1 - oc_bn = 16 + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + + oc_bn = vec_width assert wkl.out_filter % oc_bn == 0 ic_bn = 1 @@ -174,7 +176,7 @@ def _schedule_conv_NCHWc_int8(s, cfg, data_vec, kernel_vec, conv_out, last): kernel_vec, conv_out, last, - int32_lanes=get_simd_32bit_lanes(), + int32_lanes=get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4, intrin=dot_16x1x16_uint8_int8_int32(), inline_fused=True, ) diff --git a/python/tvm/topi/x86/conv2d_int8.py b/python/tvm/topi/x86/conv2d_int8.py index 7c01967e87d3..ab2e8fcaf34a 100644 --- a/python/tvm/topi/x86/conv2d_int8.py +++ b/python/tvm/topi/x86/conv2d_int8.py @@ -20,7 +20,7 @@ import tvm from tvm import autotvm, te -from tvm.target.x86 import target_has_features +from tvm.target.x86 import target_has_features, get_x86_simd_32bit_lanes from .. import nn, tag from ..generic import conv2d as conv2d_generic @@ -153,8 +153,12 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out oh = (ih - dilated_kernel_h + pt + pb) // sh + 1 ow = (iw - dilated_kernel_w + pl + pr) // sw + 1 + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + cfg.define_split("tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0) - cfg.define_split("tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0) + cfg.define_split( + "tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % vec_width == 0 + ) cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64) if is_kernel_1x1: cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1]) @@ -173,7 +177,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out padding, dilation, out_dtype, - int32_lanes=16, + int32_lanes=vec_width, ) # Pack data if raw 4-D data is provided. 
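The `tile_oc` split filter above only admits inner factors that are multiples of the SIMD width, which guarantees the tensorized intrinsic always sees a full vector of output channels. A small sketch of which splits survive for an assumed num_filter of 64 on an AVX512-class target:

# Sketch: splits surviving the tile_oc filter, num_filter=64, vec_width=16
# (vec_width is assumed; it would be 8 on AVX2 and 4 on SSSE3 targets).
num_filter, vec_width = 64, 16
valid = [
    (num_filter // inner, inner)
    for inner in range(1, num_filter + 1)
    if num_filter % inner == 0 and inner % vec_width == 0
]
assert valid == [(4, 16), (2, 32), (1, 64)]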
diff --git a/python/tvm/topi/x86/conv3d.py b/python/tvm/topi/x86/conv3d.py index 20f2c4ac128c..aafc8a858ace 100644 --- a/python/tvm/topi/x86/conv3d.py +++ b/python/tvm/topi/x86/conv3d.py @@ -22,7 +22,7 @@ import tvm from tvm import autotvm, te from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from ..nn.pad import pad from ..nn.utils import get_pad_tuple3d, infer_pad3d @@ -536,7 +536,7 @@ def _get_conv3d_workload(data, kernel, stride, padding, groups, out_dtype, data_ def _fallback_schedule(cfg, wkl): - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 DPAD, HPAD, WPAD = wkl.dpad, wkl.hpad, wkl.wpad DSTR, HSTR, WSTR = wkl.dstride, wkl.hstride, wkl.wstride out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1 diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index 4151ea0b7006..03728785c5ec 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -23,7 +23,7 @@ from tvm import autotvm, te from tvm.autotvm.task.space import SplitEntity from tvm.contrib import cblas, dnnl, mkl -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from tvm.target.codegen import target_has_features from .. import generic, tag @@ -111,7 +111,7 @@ def _default_dense_pack_config(cfg, M, N, K): if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): K = 16 - vec_width = get_simd_32bit_lanes() + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 tilex_ii = 1 for bn in range(vec_width * 2, 0, -1): if N % bn == 0: @@ -149,7 +149,7 @@ def _default_dense_nopack_config(cfg, M, N, K): if isinstance(K, (tvm.tir.Var, tvm.tir.Any)): K = 16 - vec_width = get_simd_32bit_lanes() + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 tilek_bn = 1 for bn in range(vec_width * 2, 0, -1): if K % bn == 0: @@ -291,20 +291,27 @@ def dense_int8(cfg, data, weight, bias=None, out_dtype=None): assert len(weight.shape) == 4 assert data.dtype == "uint8" and weight.dtype == "int8" _, _, n_inner, k_inner = get_const_tuple(weight.shape) # out_dim - assert n_inner == 16 and k_inner == 4 + simd_lanes = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + assert n_inner == simd_lanes and k_inner == 4 return dense_int8_compute(cfg, data, weight, bias) @autotvm.register_topi_schedule("dense_int8.x86") def schedule_dense_int8(cfg, outs): - """Create a schedule for dense__int8""" + """Create a schedule for dense_int8""" s = te.create_schedule([x.op for x in outs]) def _callback(op): if "dense_int8" in op.tag: if target_has_features("amx-int8"): dense_amx_int8_schedule(cfg, s, op.output(0), outs[0]) - elif target_has_features(["avx512bw", "avx512f"]): + elif ( + target_has_features(["avx512bw", "avx512f"]) + or target_has_features("avx512vnni") + or target_has_features("avxvnni") + or target_has_features("avx2") + or target_has_features("ssse3") + ): dense_int8_schedule(cfg, s, op.output(0), outs[0]) traverse_inline(s, outs[0].op, _callback) @@ -316,18 +323,26 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): m, k = X.shape n_o, _, n_i, _ = packed_w.shape ak = te.reduce_axis((0, k), name="k") - if target_has_features(["avx512bw", "avx512f"]): + if ( + target_has_features(["avx512bw", "avx512f"]) + or target_has_features("avx512vnni") + or target_has_features("avxvnni") + or target_has_features("avx2") + or 
target_has_features("ssse3") + ): target_attr = {"schedule_rule": "meta_schedule.x86.dense_int8"} else: target_attr = None + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + C = te.compute( (m, n_o * n_i), lambda i, j: te.sum( X[i, ak].astype("int32") - * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4].astype( - "int32" - ), + * packed_w[ + tvm.tir.indexdiv(j, vec_width), tvm.tir.indexdiv(ak, 4), j % vec_width, ak % 4 + ].astype("int32"), axis=ak, ), tag="dense_int8", @@ -341,8 +356,7 @@ def dense_int8_compute(cfg, X, packed_w, bias=None): def dense_int8_schedule(cfg, s, C, O, do_parallel=True): - """Schedule dense compute using avx512 or lower instructions - including VNNI vpdpbusd instruction if possible""" + """Schedule dense compute using x86 VNNI or AVX512, AVX2, SSSE3""" # C: The output of GEMM # O: The output of the fused op def split_y(out): @@ -357,8 +371,10 @@ def split_y(out): (a_k,) = C.op.reduce_axis + vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16 + a_yo, a_yi = split_y(C) - a_xo, a_xi = s[C].split(C.op.axis[-1], factor=16) + a_xo, a_xi = s[C].split(C.op.axis[-1], factor=vec_width) a_ko, a_ki = s[C].split(a_k, factor=4) s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki) @@ -370,7 +386,7 @@ def split_y(out): fused = s[O].fuse(a_yo, a_xo) else: a_yo, a_yi = split_y(O) - a_xo, a_xi = s[O].split(O.op.axis[-1], factor=16) + a_xo, a_xi = s[O].split(O.op.axis[-1], factor=vec_width) s[O].reorder(a_yo, a_xo, a_yi, a_xi) s[O].vectorize(a_xi) diff --git a/python/tvm/topi/x86/dense_alter_op.py b/python/tvm/topi/x86/dense_alter_op.py index 0e9b1f7b65f0..5540251f17da 100644 --- a/python/tvm/topi/x86/dense_alter_op.py +++ b/python/tvm/topi/x86/dense_alter_op.py @@ -20,6 +20,7 @@ import tvm from tvm import autotvm, relay, te from tvm.target.codegen import target_has_features +from tvm.target.x86 import get_x86_simd_32bit_lanes from .. import nn from ..nn import dense_alter_layout @@ -28,14 +29,24 @@ def check_int8_applicable(x, y, allow_padding=False): + """Check (u)int8 SIMD elegibility.""" + # x86 SIMD simd_avai = target_has_features(["avx512bw", "avx512f"]) simd_avai |= target_has_features("amx-int8") - # TODO(vvchernov): may be also target_has_features("avx2") or lower? 
+    simd_avai |= target_has_features("avx512vnni")
+    simd_avai |= target_has_features("avxvnni")
+    simd_avai |= target_has_features("avx2")
+    simd_avai |= target_has_features("ssse3")
+    # arm SIMD
+    simd_avai |= target_has_features("dotprod")
+
+    vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16
+
     return (
         simd_avai
-        and "int8" in x.dtype
-        and "int8" in y.dtype
-        and (allow_padding or (y.shape[-2] % 16 == 0 and y.shape[-1] % 4 == 0))
+        and x.dtype in ("int8", "uint8")
+        and y.dtype in ("int8", "uint8")
+        and (allow_padding or (y.shape[-2] % vec_width == 0 and y.shape[-1] % 4 == 0))
     )
 
 
@@ -48,8 +59,13 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
     M, K = get_const_tuple(data_tensor.shape)
     N, _ = get_const_tuple(weight_tensor.shape)
 
-    if check_int8_applicable(data_tensor, weight_tensor) and data_tensor.dtype == "uint8":
-        weight_layout = "NC16n4c"
+    if (
+        check_int8_applicable(data_tensor, weight_tensor, allow_padding=True)
+        and data_tensor.dtype in ("uint8", "int8")
+        and weight_tensor.dtype in ("uint8", "int8")
+    ):
+        vec_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16
+        weight_layout = f"NC{vec_width}n4c"
         return relay.nn.contrib_dense_pack(inputs[0], inputs[1], weight_layout, None, out_dtype)
 
     _, outs = relay.backend.te_compiler.select_implementation(
@@ -77,26 +93,43 @@ def _alter_dense_layout(attrs, inputs, tinfos, out_type):
 
 
 def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False):
-    """Legalizes s8, s8 -> s32 GEMM op for VNNI."""
-    if (
-        check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True)
-        and arg_types[0].dtype == "int8"
-    ):
+    """Legalizes s8, s8 -> s32 GEMM op for SIMD."""
+    if check_int8_applicable(arg_types[0], arg_types[1], allow_padding=True):
+
         x, y = inputs
-        x = relay.cast(x, "int32")
-        x = relay.add(x, relay.const(128, "int32"))
-        x = relay.cast(x, "uint8")
 
-        adjust_shift = relay.const(128, "int32") * relay.sum(relay.cast(y, "int32"), axis=[-1])
+        # x{data} int8 -> uint8
+        if arg_types[0].dtype == "int8":
+            x = relay.cast(x, "int32")
+            x = relay.add(x, relay.const(128, "int32"))
+            x = relay.cast(x, "uint8")
 
-        if need_expand:
-            adjust_shift = relay.expand_dims(adjust_shift, axis=1)
+            x_adjust_shift = relay.const(128, "int32") * relay.sum(
+                relay.cast(y, "int32"), axis=[-1]
+            )
+
+            if need_expand:
+                x_adjust_shift = relay.expand_dims(x_adjust_shift, axis=1)
+
+        # y{weight} uint8 -> int8
+        if arg_types[1].dtype == "uint8":
+            y = relay.cast(y, "int32")
+            y = relay.subtract(y, relay.const(128, "int32"))
+            y = relay.cast(y, "int8")
+
+            y_adjust_shift = relay.const(128, "int32") * relay.sum(
+                relay.cast(x, "int32"), axis=[-1]
+            )
+
+            if need_expand:
+                y_adjust_shift = relay.expand_dims(y_adjust_shift, axis=1)
 
         analyzer = tvm.arith.Analyzer()
         x_shape = arg_types[0].shape
         y_shape = arg_types[1].shape
-        inst_n = 16
-        inst_k = 4
+
+        inst_n = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 16
+        inst_k = 4  # the reduction tile is fixed at 4 by the NC{n}n4c packed layout
         pad_n = analyzer.simplify((inst_n - y_shape[-2] % inst_n) % inst_n)
         pad_k = analyzer.simplify((inst_k - y_shape[-1] % inst_k) % inst_k)
         if pad_k != 0 or pad_n != 0:
@@ -117,20 +150,27 @@ def int8_int8_legalize(inputs, arg_types, op, attrs, need_expand=False):
         else:
             out = op(x, y, **attrs)
 
-        return relay.subtract(out, adjust_shift)
+        if arg_types[0].dtype == "int8":
+            # compensate for the int8 -> uint8 data shift
+            out = relay.subtract(out, x_adjust_shift)
+        if arg_types[1].dtype == "uint8":
+            # compensate for the uint8 -> int8 weight shift
+            out = relay.add(out,
y_adjust_shift) + + return out return None @nn.dense_legalize.register("cpu") def _dense_legalize(attrs, inputs, arg_types): - """Legalizes s8, s8 -> s32 dense for VNNI.""" + """Legalizes s8, s8 -> s32 dense for SIMD.""" return int8_int8_legalize(inputs, arg_types, relay.nn.dense, attrs) @nn.batch_matmul_legalize.register("cpu") def _batch_matmul_legalize(attrs, inputs, arg_types): - """Legalizes s8, s8 -> s32 batch_matmul for VNNI.""" + """Legalizes s8, s8 -> s32 batch_matmul for SIMD.""" if attrs["transpose_a"] or not attrs["transpose_b"]: return None return int8_int8_legalize(inputs, arg_types, relay.nn.batch_matmul, attrs, need_expand=True) diff --git a/python/tvm/topi/x86/depthwise_conv2d.py b/python/tvm/topi/x86/depthwise_conv2d.py index 59d7412befc0..913bf1ccb96b 100644 --- a/python/tvm/topi/x86/depthwise_conv2d.py +++ b/python/tvm/topi/x86/depthwise_conv2d.py @@ -20,7 +20,7 @@ import tvm from tvm import autotvm, te from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from ..nn.conv2d import unpack_NCHWc_to_nchw from ..nn.depthwise_conv2d import _get_workload, depthwise_conv2d_infer_layout @@ -39,7 +39,7 @@ def _fallback_schedule(cfg, wkl): wkl : topi.nn.depthwise_conv2d.Workload Convolution workload """ - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr HSTR, WSTR = wkl.stride_h, wkl.stride_w diff --git a/python/tvm/topi/x86/group_conv2d.py b/python/tvm/topi/x86/group_conv2d.py index 60b99f796bf9..246f85b73d3c 100644 --- a/python/tvm/topi/x86/group_conv2d.py +++ b/python/tvm/topi/x86/group_conv2d.py @@ -21,7 +21,7 @@ import tvm from tvm import autotvm, te from tvm.autotvm.task.space import OtherOptionEntity, SplitEntity -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from .. import tag from ..nn.conv2d import _get_workload as _get_conv2d_workload @@ -60,7 +60,8 @@ def _get_default_config( def _fallback_schedule(cfg, wkl): - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 + pad_left, pad_right = wkl.padl, wkl.padr stride_w = wkl.stride_w out_width = (wkl.width + pad_left + pad_right - wkl.kernel_w) // stride_w + 1 diff --git a/python/tvm/topi/x86/sparse.py b/python/tvm/topi/x86/sparse.py index fdbbaf1002de..4867b3936755 100644 --- a/python/tvm/topi/x86/sparse.py +++ b/python/tvm/topi/x86/sparse.py @@ -19,7 +19,7 @@ from functools import partial, reduce from tvm import autotvm, te, tir -from tvm.target.x86 import get_simd_32bit_lanes +from tvm.target.x86 import get_x86_simd_32bit_lanes from ..transform import reshape from ..utils import get_const_int, traverse_inline @@ -30,7 +30,7 @@ def schedule_sparse_dense(outs): s = te.create_schedule([x.op for x in outs]) def _callback(op): - simd_width = get_simd_32bit_lanes() + simd_width = get_x86_simd_32bit_lanes() if get_x86_simd_32bit_lanes() else 4 if op.tag == "sparse_dense_sp_lhs_csrmm" or op.tag == "sparse_dense_sp_lhs_csrmm": (y_o, y_i) = s[op].split(s[op].op.axis[1], 2) fused = s[op].fuse(s[op].op.axis[0], y_o) diff --git a/python/tvm/topi/x86/tensor_intrin.py b/python/tvm/topi/x86/tensor_intrin.py index f2e84a62ecbd..e8e86185209b 100644 --- a/python/tvm/topi/x86/tensor_intrin.py +++ b/python/tvm/topi/x86/tensor_intrin.py @@ -16,21 +16,23 @@ # under the License. 
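Before moving into the tensor intrinsics, a sketch of the zero-point identity that `int8_int8_legalize` above relies on: shifting int8 data by +128 into uint8, running the uint8 x int8 GEMM, and then subtracting 128 * sum(weight) per output channel reproduces the signed result exactly (NumPy, assumed small shapes):

# Sketch: the zero-point identity behind int8_int8_legalize (NumPy).
import numpy as np

m, n, k = 4, 8, 16
x = np.random.randint(-128, 128, (m, k)).astype(np.int32)  # int8 data
y = np.random.randint(-128, 128, (n, k)).astype(np.int32)  # int8 weight

out = (x + 128) @ y.T        # uint8 data x int8 weight, int32 accumulate
out -= 128 * y.sum(axis=-1)  # x_adjust_shift, broadcast over output rows

assert (out == x @ y.T).all()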
"""Core kernel of dot product of 4 Int8 operations""" # pylint: disable=invalid-name,unused-variable +import logging + import tvm from tvm import te import tvm.target.codegen -from tvm.target.x86 import target_has_features, get_simd_32bit_lanes +from tvm.target.x86 import target_has_features, get_x86_simd_32bit_lanes + +logger = logging.getLogger("topi") def dot_16x1x16_uint8_int8_int32(): """Dispatch the most optimized intrin depending on the target""" - assert target_has_features( - "sse4.2" - ), "An old Intel machine that does not have fast Int8 support." + assert target_has_features("sse4.2"), "An old x86 machine that does not have fast int8 support." if target_has_features("avx512vnni") or target_has_features("avxvnni"): - # VNNI capable platform + # x86 VNNI return dot_16x1x16_uint8_int8_int32_cascadelake() - # vpmaddubsw/vpmaddwd fallback + # x86 AVX512 -> AVX2 -> SSSE3 (fallthrough) return dot_16x1x16_uint8_int8_int32_skylake() @@ -63,7 +65,8 @@ def dot_16x1x16_uint8_int8_int32_skylake(): The Skylake int8 TensorIntrin that can be used in tensorizing schedule """ - int32_lanes = get_simd_32bit_lanes() + target = tvm.target.Target.current() + int32_lanes = get_x86_simd_32bit_lanes() num_int8_elements = 4 # 4 int8 elements in int32 data = te.placeholder((num_int8_elements,), dtype="uint8", name="data") kernel = te.placeholder((int32_lanes, num_int8_elements), dtype="int8", name="kernel") @@ -85,26 +88,33 @@ def _intrin_func(ins, outs): def _instr(index): # int_lx32 - output datatype after pmaddubs - 16 bits to number of lanes # int_8xl - input datatype to pmaddubs - 8 bits to number of lanes + # int_16xl - upcast datatype from 8 -> 16 bits to number of lanes # int_32xl - output datatype after pmaddw - 32 bits per number of lanes if int32_lanes == 4: int_lx32 = "int16x8" int_8xl = "int8x16" + int_16xl = "int16x16" int_32xl = "int32x4" pmaddubs = "llvm.x86.ssse3.pmadd.ub.sw.128" pmaddw = "llvm.x86.sse2.pmadd.wd" + phaddw = "llvm.x86.ssse3.phadd.d.128" elif int32_lanes == 8: int_lx32 = "int16x16" int_8xl = "int8x32" + int_16xl = "int16x32" int_32xl = "int32x8" pmaddubs = "llvm.x86.avx2.pmadd.ub.sw" pmaddw = "llvm.x86.avx2.pmadd.wd" + phaddw = "llvm.x86.avx2.phadd.d" elif int32_lanes == 16: int_lx32 = "int16x32" int_8xl = "int8x64" + int_16xl = "int16x64" int_32xl = "int32x16" pmaddubs = "llvm.x86.avx512.pmaddubs.w.512" pmaddw = "llvm.x86.avx512.pmaddw.d.512" + phaddw = None # does not exist for _m512 ib = tvm.tir.ir_builder.create() if index == 1: @@ -116,21 +126,77 @@ def _instr(index): vec_ai32 = re_int32.astype(int_32xl) vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32) vec_b = ins[1].vload([0, 0], int_8xl) - vec_one = tvm.tir.const(1, int_lx32) - pair_reduction = tvm.tir.call_llvm_pure_intrin( - int_lx32, - pmaddubs, - tvm.tir.const(2, "uint32"), - vec_a, - vec_b, - ) - quad_reduction = tvm.tir.call_llvm_pure_intrin( - int_32xl, - pmaddw, - tvm.tir.const(2, "uint32"), - pair_reduction, - vec_one, - ) + + # fast-math (may saturate on overflow) + if "fast-math" in target.keys: + msg = ( + "Using `fast-math` may overflow, make sure ranges" + " for either data is [0,128] or weight is [-64,+64]" + ) + logger.warning(msg) + pair_reduction = tvm.tir.call_llvm_pure_intrin( + int_lx32, + pmaddubs, + tvm.tir.const(2, "uint32"), + vec_a, + vec_b, + ) + quad_reduction = tvm.tir.call_llvm_pure_intrin( + int_32xl, + pmaddw, + tvm.tir.const(2, "uint32"), + pair_reduction, + tvm.tir.const(1, int_lx32), + ) + # no-fast-math (no overflow and no saturate) + else: + vec_a_w = 
tvm.tir.zextend(int_16xl, vec_a) + vec_b_w = tvm.tir.sextend(int_16xl, vec_b) + pair_reduction_lo = tvm.tir.call_llvm_pure_intrin( + int_32xl, + pmaddw, + tvm.tir.const(2, "uint32"), + tvm.tir.vectorlow("", vec_a_w), + tvm.tir.vectorlow("", vec_b_w), + ) + pair_reduction_hi = tvm.tir.call_llvm_pure_intrin( + int_32xl, + pmaddw, + tvm.tir.const(2, "uint32"), + tvm.tir.vectorhigh("", vec_a_w), + tvm.tir.vectorhigh("", vec_b_w), + ) + if int32_lanes in [4, 8]: + # reduce pairs _m128 & _m256 + quad_reduction = tvm.tir.call_llvm_pure_intrin( + int_32xl, + phaddw, + tvm.tir.const(2, "uint32"), + pair_reduction_lo, + pair_reduction_hi, + ) + # _m256 result needs reorder + if int32_lanes == 8: + quad_reduction = tvm.tir.vectorpermute( + int_32xl, quad_reduction, [0, 1, 4, 5, 2, 3, 6, 7] + ) + # there is no phaddw pair reductor for _m512 + elif int32_lanes == 16: + pairs_even = tvm.tir.vectorshuffle( + int_32xl, + pair_reduction_lo, + pair_reduction_hi, + list(range(0, int32_lanes * 2, 2)), + ) + pairs_odd = tvm.tir.vectorshuffle( + int_32xl, + pair_reduction_lo, + pair_reduction_hi, + list(range(1, int32_lanes * 2, 2)), + ) + # final reduce prearranged pairs + quad_reduction = tvm.tir.atomic_add(int_32xl, pairs_even, pairs_odd) + if index == 0: ib.emit(outs[0].vstore(0, quad_reduction)) else: @@ -300,42 +366,17 @@ def _instr(index): vec_ai32 = re_int32.astype("int32x16") vec_b = ins[1].vload([0, 0], "int8x64") - vnni_inst_name = "llvm.x86.avx512.vpdpbusd.512" - llvm_id = tvm.target.codegen.llvm_lookup_intrinsic_id(vnni_inst_name) - - if llvm_id != 0: # VNNI is available for current LLVM version - vec_bi32 = tvm.tir.call_intrin("int32x16", "tir.reinterpret", vec_b) - vec_c = outs[0].vload([0], "int32x16") - quad_reduction = tvm.tir.call_llvm_pure_intrin( - "int32x16", - "llvm.x86.avx512.vpdpbusd.512", - tvm.tir.const(3, "uint32"), - vec_c, - vec_ai32, - vec_bi32, - ) - ib.emit(outs[0].vstore(0, quad_reduction)) - else: # Fall back to the normal AVX512 - vec_a = tvm.tir.call_intrin("int8x64", "tir.reinterpret", vec_ai32) - vec_one = tvm.tir.const(1, "int16x32") - pair_reduction = tvm.tir.call_llvm_pure_intrin( - "int16x32", - "llvm.x86.avx512.pmaddubs.w.512", - tvm.tir.const(2, "uint32"), - vec_a, - vec_b, - ) - quad_reduction = tvm.tir.call_llvm_pure_intrin( - "int32x16", - "llvm.x86.avx512.pmaddw.d.512", - tvm.tir.const(2, "uint32"), - pair_reduction, - vec_one, - ) - if index == 0: - ib.emit(outs[0].vstore(0, quad_reduction)) - else: - ib.emit(outs[0].vstore(0, quad_reduction + outs[0].vload([0], "int32x16"))) + vec_bi32 = tvm.tir.call_intrin("int32x16", "tir.reinterpret", vec_b) + vec_c = outs[0].vload([0], "int32x16") + quad_reduction = tvm.tir.call_llvm_pure_intrin( + "int32x16", + "llvm.x86.avx512.vpdpbusd.512", + tvm.tir.const(3, "uint32"), + vec_c, + vec_ai32, + vec_bi32, + ) + ib.emit(outs[0].vstore(0, quad_reduction)) return ib.get() # body, reset, update diff --git a/python/tvm/utils/roofline/x86.py b/python/tvm/utils/roofline/x86.py index 5d2dd27e523b..76e86f0c662d 100644 --- a/python/tvm/utils/roofline/x86.py +++ b/python/tvm/utils/roofline/x86.py @@ -62,7 +62,8 @@ def _detect_vec_width_registers( and target.keys[0] == "cpu" ): with target: - vec_width = x86.get_simd_32bit_lanes() * 4 # in number of bytes + simd_width = x86.get_x86_simd_32bit_lanes() if x86.get_x86_simd_32bit_lanes() else 4 + vec_width = simd_width * 4 # in number of bytes else: raise RuntimeError(f"Cannot determine vector width for target {target}") if num_vector_registers is None: diff --git a/src/ir/attr_functor.h 
b/src/ir/attr_functor.h index 12b4f6f65b11..5f908a33ae76 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -79,6 +79,7 @@ class AttrFunctor { virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tir::ArrayIntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { @@ -116,6 +117,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(IntImmNode); ATTR_FUNCTOR_DISPATCH(FloatImmNode); ATTR_FUNCTOR_DISPATCH(StringImmNode); + ATTR_FUNCTOR_DISPATCH(ArrayIntImmNode); ATTR_FUNCTOR_DISPATCH(VarNode); ATTR_FUNCTOR_DISPATCH(SizeVarNode); ATTR_FUNCTOR_DISPATCH(AddNode); diff --git a/src/relay/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc index 618e8fe138d8..766458518c60 100644 --- a/src/relay/printer/relay_text_printer.cc +++ b/src/relay/printer/relay_text_printer.cc @@ -799,6 +799,18 @@ Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { return Doc::StrLiteral(op->value); } +Doc RelayTextPrinter::VisitAttr_(const tir::ArrayIntImmNode* op) { + Doc doc; + doc << "["; + std::vector arr_vals; + for (const auto& val : op->data) { + arr_vals.push_back(PrintAttributeValue(val)); + } + doc << Doc::Concat(arr_vals); + doc << "]"; + return doc; +} + /*! * \brief Attribute printer which prints the attributes in the call. */ diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h index a6684bf4e5ce..0ee316745b4c 100644 --- a/src/relay/printer/text_printer.h +++ b/src/relay/printer/text_printer.h @@ -190,6 +190,7 @@ class RelayTextPrinter : public ExprFunctor, Doc VisitAttr_(const tir::IntImmNode* op) final; Doc VisitAttr_(const tir::FloatImmNode* op) final; Doc VisitAttr_(const tir::StringImmNode* op) final; + Doc VisitAttr_(const tir::ArrayIntImmNode* op) final; private: /*! \brief Whether to print meta data. 
*/ @@ -242,7 +243,7 @@ class MetaCollector : public StmtExprVisitor { void Collect(const ObjectRef& n) { // these nodes can be print directly(StringLiteral or use identifier to identify) if (!n.defined() || n.as() || n.as() || n.as() || - n.as() || n.as() || n.as()) { + n.as() || n.as() || n.as() || n.as()) { return; } if (n->IsInstance()) { @@ -290,6 +291,7 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitExpr_(const IntImmNode* op) override; Doc VisitExpr_(const FloatImmNode* op) override; Doc VisitExpr_(const StringImmNode* op) override; + Doc VisitExpr_(const ArrayIntImmNode* op) override; Doc VisitExpr_(const CastNode* op) override; Doc VisitExpr_(const tir::VarNode* op) override; Doc VisitExpr_(const AddNode* op) override; diff --git a/src/relay/printer/tir_text_printer.cc b/src/relay/printer/tir_text_printer.cc index e9a9ee231358..868e1e3b6b04 100644 --- a/src/relay/printer/tir_text_printer.cc +++ b/src/relay/printer/tir_text_printer.cc @@ -296,6 +296,19 @@ Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } +Doc TIRTextPrinter::VisitExpr_(const ArrayIntImmNode* op) { + Doc doc; + doc << "["; + for (size_t i = 0; i < op->data.size(); ++i) { + doc << Print(op->data[i]); + if (i < op->data.size() - 1) { + doc << ", "; + } + } + doc << "]"; + return doc; +} + Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { Doc doc; doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")"; diff --git a/src/relay/printer/tvmscript_printer.cc b/src/relay/printer/tvmscript_printer.cc index b0085b82426e..7b2e2bc927f4 100644 --- a/src/relay/printer/tvmscript_printer.cc +++ b/src/relay/printer/tvmscript_printer.cc @@ -240,6 +240,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const ArrayIntImmNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; @@ -784,6 +785,19 @@ Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_pr return Doc::StrLiteral(op->value); } +Doc TVMScriptPrinter::VisitExpr_(const ArrayIntImmNode* op, ExprPrecedence* out_precedence) { + Doc doc; + doc << "["; + for (size_t i = 0; i < op->data.size(); ++i) { + doc << Print(op->data[i]); + if (i < op->data.size() - 1) { + doc << ", "; + } + } + doc << "]"; + return doc; +} + Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; @@ -1559,7 +1573,7 @@ Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { } Doc TVMScriptPrinter::PrintBody(const Stmt& body) { - int memo_num_child, memo_current_num; + int memo_num_child = 0, memo_current_num = 0; std::swap(memo_num_child, num_child_); std::swap(memo_current_num, current_num_); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 01fb514c497e..6f6ccbd50ec7 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -270,6 +270,19 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << '\"' << 
support::StrEscape(op->value) << '\"';
     });
 
+TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
+    .set_dispatch<ArrayIntImmNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
+      auto* op = static_cast<const ArrayIntImmNode*>(node.get());
+      (*p) << '[';
+      for (size_t i = 0; i < op->data.size(); ++i) {
+        p->Print(op->data[i]);
+        if (i < op->data.size() - 1) {
+          (*p) << ", ";
+        }
+      }
+      (*p) << ']';
+    });
+
 TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
     .set_dispatch<CastNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
       auto* op = static_cast<const CastNode*>(node.get());
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index 8de142f8613e..e24709c23b67 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -115,6 +115,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       }
     });
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<tir::ArrayIntImm>("", [](tir::ArrayIntImm a, ObjectPath p, IRDocsifier d) -> Doc {
+      return TIR(d, "ArrayIntImm")
+          ->Call({
+              d->AsDoc<ExprDoc>(a->data, p->Attr("data")),
+          });
+    });
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::Cast>("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc {
       ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype"));
diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc
index 15d1699b3b59..3617a46672a9 100644
--- a/src/target/llvm/codegen_arm.cc
+++ b/src/target/llvm/codegen_arm.cc
@@ -55,7 +55,13 @@ class CodeGenARM final : public CodeGenCPU {
 llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {
   if (op->op.same_as(builtin_call_llvm_intrin_) ||
       op->op.same_as(builtin_call_llvm_pure_intrin_)) {
-    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
+    llvm::Intrinsic::ID id = 0;
+    if (op->args[0]->IsInstance<StringImmNode>()) {
+      id = llvm::Function::lookupIntrinsicID(Downcast<StringImm>(op->args[0])->value.c_str());
+    } else if (op->args[0]->IsInstance<IntImmNode>()) {
+      id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
+    }
+    ICHECK(id != 0);
     if (id == llvm::Intrinsic::ctpop) {
       PrimExpr e = ARMPopcount(op);
       return CodeGenCPU::CreateIntrinsic(e.as<CallNode>());
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 3d4d3def2411..fe773d3ac0dc 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1320,7 +1320,13 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {
 llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   if (op->op.same_as(builtin_call_llvm_intrin_) ||
       op->op.same_as(builtin_call_llvm_pure_intrin_)) {
     ICHECK_GE(op->args.size(), 2U);
-    llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
+    llvm::Intrinsic::ID id = 0;
+    if (op->args[0]->IsInstance<StringImmNode>()) {
+      id = llvm::Function::lookupIntrinsicID(Downcast<StringImm>(op->args[0])->value.c_str());
+    } else if (op->args[0]->IsInstance<IntImmNode>()) {
+      id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
+    }
+    ICHECK(id != 0);
     int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
     std::vector<llvm::Value*> arg_value;
     std::vector<llvm::Type*> arg_type;
@@ -1441,6 +1447,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   } else if (op->op.same_as(builtin::reinterpret())) {
     llvm::Type* target = DTypeToLLVMType(op->dtype);
     return builder_->CreateBitCast(MakeValue(op->args[0]), target);
+  } else if (op->op.same_as(builtin::zextend())) {
+    llvm::Type* target = DTypeToLLVMType(op->dtype);
+    return builder_->CreateZExt(MakeValue(op->args[0]), target);
+  } else if (op->op.same_as(builtin::sextend())) {
+    llvm::Type* target = DTypeToLLVMType(op->dtype);
+    return builder_->CreateSExt(MakeValue(op->args[0]), target);
+  } else if (op->op.same_as(builtin::truncate())) {
diff --git a/src/target/llvm/codegen_llvm.cc index 3d4d3def2411..fe773d3ac0dc 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1320,7 +1320,13 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { ICHECK_GE(op->args.size(), 2U); - llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value); + llvm::Intrinsic::ID id = 0; + if (op->args[0]->IsInstance<StringImmNode>()) { + id = llvm::Function::lookupIntrinsicID(Downcast<StringImm>(op->args[0])->value.c_str()); + } else if (op->args[0]->IsInstance<IntImmNode>()) { + id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value); + } + ICHECK(id != 0) << "Cannot resolve LLVM intrinsic from " << op->args[0]; int64_t num_signature = Downcast<IntImm>(op->args[1])->value; std::vector<llvm::Value*> arg_value; std::vector<llvm::Type*> arg_type; @@ -1441,6 +1447,15 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::reinterpret())) { llvm::Type* target = DTypeToLLVMType(op->dtype); return builder_->CreateBitCast(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::zextend())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateZExt(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::sextend())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateSExt(MakeValue(op->args[0]), target); + } else if (op->op.same_as(builtin::truncate())) { + llvm::Type* target = DTypeToLLVMType(op->dtype); + return builder_->CreateTrunc(MakeValue(op->args[0]), target); } else if (op->op.same_as(builtin::isnan())) { // TODO(hgt312): set fast math flag llvm::Value* a = MakeValue(op->args[0]); @@ -1466,9 +1481,39 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { indices.push_back(i); } return builder_->CreateShuffleVector(v0, v1, indices); + } else if (op->op.same_as(builtin::vectorpermute())) { + llvm::Value* vec = MakeValue(op->args[0]); + const ArrayIntImmNode* ind = op->args[1].as<ArrayIntImmNode>(); +#if TVM_LLVM_VERSION >= 110 + std::vector<int> indices; +#else + std::vector<unsigned> indices; +#endif + for (size_t i = 0; i < ind->data.size(); ++i) { + indices.push_back(ind->data[i].IntValue()); + } +#if TVM_LLVM_VERSION >= 120 + return builder_->CreateShuffleVector(vec, indices); +#else + return builder_->CreateShuffleVector(vec, vec, indices); +#endif + } else if (op->op.same_as(builtin::vectorshuffle())) { + llvm::Value* vec0 = MakeValue(op->args[0]); + llvm::Value* vec1 = MakeValue(op->args[1]); + const ArrayIntImmNode* ind = op->args[2].as<ArrayIntImmNode>(); +#if TVM_LLVM_VERSION >= 110 + std::vector<int> indices; +#else + std::vector<unsigned> indices; +#endif + for (size_t i = 0; i < ind->data.size(); ++i) { + indices.push_back(ind->data[i].IntValue()); + } + return builder_->CreateShuffleVector(vec0, vec1, indices); } else if (op->op.same_as(builtin::atomic_add())) { - // TODO(masahi): Support atomic for CPU backend - LOG(FATAL) << "CPU backend does not support atomic add yet."; + // NOTE: on the CPU backend this lowers to a plain (non-atomic) vector add. + llvm::Value* v0 = MakeValue(op->args[0]); + llvm::Value* v1 = MakeValue(op->args[1]); + return builder_->CreateAdd(v0, v1); } else if (op->op.same_as(builtin::start_profile_intrinsic()) || op->op.same_as(builtin::end_profile_intrinsic())) { LOG(INFO) << "Ignoring profile_intrinsic ...
" << op->op; @@ -1837,7 +1882,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { std::vector idx(op->indices.size()); for (int i = 0, e = op->indices.size(); i < e; ++i) { const int64_t* val = as_const_int(op->indices[i]); - ICHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indeces are suppose to be int, " + ICHECK(val && *val >= 0 && *val < total_lanes) << "Shuffled indices are suppose to be int, " << "but get " << op->indices[i] << "\n"; idx[i] = *val; } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index d590f8b2dd8b..06b6734282b4 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -187,6 +187,21 @@ TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value, Span span) TVM_REGISTER_NODE_TYPE(StringImmNode); +// ArrayIntImm +ArrayIntImm::ArrayIntImm(Array data, Span span) { + ObjectPtr node = make_object(); + node->dtype = DataType::Handle(); + node->data = std::move(data); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.ArrayIntImm").set_body_typed([](Array data, Span span) { + return ArrayIntImm(data, span); +}); + +TVM_REGISTER_NODE_TYPE(ArrayIntImmNode); + // Cast Cast::Cast(DataType t, PrimExpr value, Span span) { ICHECK(value.defined()); @@ -511,11 +526,18 @@ TVM_REGISTER_GLOBAL("tir.Call") .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { Array prim_expr_args; for (const auto& it : args) { - ICHECK(it->IsInstance() || it->IsInstance() || - it->IsInstance() || it->IsInstance()) + ICHECK(it->IsInstance() || it->IsInstance() || + it->IsInstance() || it->IsInstance() || + it->IsInstance()) << "Argument " << it << " is not a string or primexpr"; if (const auto* str = it.as()) { prim_expr_args.push_back(StringImm(str->data)); + } else if (const auto* arr = it.as()) { + Array indices; + for (size_t i = 0; i < arr->size(); ++i) { + indices.push_back(arr->at(i).as()->value); + } + prim_expr_args.push_back(ArrayIntImm(indices)); } else if (const auto* iter_var = it.as()) { prim_expr_args.push_back(iter_var->var); } else if (const auto* br = it.as()) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 8a93d9dd8242..8b787de7100f 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -78,6 +78,7 @@ DEFINE_BINOP_VISIT_(OrNode); void ExprVisitor::VisitExpr_(const IntImmNode* op) {} void ExprVisitor::VisitExpr_(const FloatImmNode* op) {} void ExprVisitor::VisitExpr_(const StringImmNode* op) {} +void ExprVisitor::VisitExpr_(const ArrayIntImmNode* op) {} void ExprVisitor::VisitExpr_(const ReduceNode* op) { VisitArray(op->axis, [this](const IterVar& r) { @@ -168,6 +169,7 @@ PrimExpr ExprMutator::VisitExpr_(const CallNode* op) { DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImmNode) DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImmNode) +DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ArrayIntImmNode) #define DEFINE_BIOP_EXPR_MUTATE_(OP) \ PrimExpr ExprMutator::VisitExpr_(const OP##Node* op) { \ diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1b80959b5705..df88ef36953c 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -44,6 +44,24 @@ TIR_DEFINE_BUILTIN_FUNC(reinterpret) Integer(ScriptDtypePrintLocation::kFirst)) .set_num_inputs(1); +TIR_DEFINE_BUILTIN_FUNC(zextend) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(sextend) + 
.set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + +TIR_DEFINE_BUILTIN_FUNC(truncate) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) + .set_num_inputs(1); + TIR_DEFINE_BUILTIN_FUNC(ret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kControlJump)) .set_num_inputs(1); @@ -338,6 +356,16 @@ TIR_DEFINE_BUILTIN_FUNC(vectorcombine) .set_attr("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); +TIR_DEFINE_BUILTIN_FUNC(vectorpermute) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + +TIR_DEFINE_BUILTIN_FUNC(vectorshuffle) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index b6c52ec1a3be..0036ae390f44 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -85,7 +85,8 @@ bool CommonSubexpressionEliminator::IsEligibleComputation(const PrimExpr& expr) return ( // In order to be eligible, the given expression should not be a constant (expr.as() == nullptr) && (expr.as() == nullptr) && - (expr.as() == nullptr) + (expr.as() == nullptr) && + (expr.as() == nullptr) // and it should not be a variable && (expr.as() == nullptr) // and it should not be a forbidden computation (function calls and loads) diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index ba101ce4e70f..3bb4957085b5 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -269,7 +269,8 @@ ComputationTable ComputationsDoneBy::GetComputationsDoneBy( // (We don't want to use a "line of cache" of that, as that would cost an empty table of // computations in memory for absolutely no gain) if (expr.as() != nullptr || expr.as() != nullptr || - expr.as() != nullptr || expr.as() != nullptr) { + expr.as() != nullptr || expr.as() != nullptr || + expr.as() != nullptr) { // Return an empty table return {}; } @@ -346,7 +347,8 @@ void ComputationsDoneBy::VisitExpr(const PrimExpr& expr) { // (We don't want to use a "line of cache" of that, as that would cost an empty table of // computations in memory for absolutely no gain) if (expr.as() != nullptr || expr.as() != nullptr || - expr.as() != nullptr || expr.as() != nullptr) { + expr.as() != nullptr || expr.as() != nullptr || + expr.as() != nullptr) { return; } diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h index 40f3e07940cf..cfd85befbbda 100644 --- a/src/tir/transforms/install_debug_spans.h +++ b/src/tir/transforms/install_debug_spans.h @@ -64,7 +64,8 @@ X(Shuffle) \ X(IntImm) \ X(FloatImm) \ - X(StringImm) + X(StringImm) \ + X(ArrayIntImm) #define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS \ X(AttrStmt) \ diff --git a/tests/python/contrib/test_gemm_acc32_simd.py b/tests/python/contrib/test_gemm_acc32_simd.py new file mode 100644 index 000000000000..223845f23696 --- /dev/null +++ b/tests/python/contrib/test_gemm_acc32_simd.py @@ -0,0 
diff --git a/tests/python/contrib/test_gemm_acc32_simd.py new file mode 100644 index 000000000000..223845f23696 --- /dev/null +++ b/tests/python/contrib/test_gemm_acc32_simd.py @@ -0,0 +1,142 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition + +import logging + +logging.basicConfig(level=logging.ERROR) + +import tvm +from tvm import relay +from tvm import transform +from tvm.relay import testing +from tvm.testing.aot import AOTTestModel, compile_and_run +from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER + + +def benchmark_dense_int8_acc32(tgt, opt): + + m = 1024 + n = 1024 + k = 1024 + + # Gops for gemm + gops_per_mm = 2.0 * (n * m * k) / 1e9 + + def verify(tgt, opt): + + target = tvm.target.Target(tgt) + + ## + ## GRAPH + ## + + # network graph + dat = relay.var("data", shape=(m, k), dtype="uint8") + weight = relay.var("weight", shape=(n, k), dtype="int8") + out = relay.nn.dense(dat, weight, out_dtype="int32") + + # convert to relay IR + f = relay.Function(relay.analysis.free_vars(out), out) + mod, params = testing.create_workload(f) + + ## + ## EVAL + ## + + # a single PassContext controls the optimization level here + with tvm.transform.PassContext(opt_level=opt): + # build relay module + lib = relay.build(mod, target=target, params=None)
+ tensorized = False + if "@llvm.x86." in lib.lib.get_source(): + tensorized = True + + import numpy as np + + np.random.seed(seed=None) + d_dtype = dat.type_annotation + w_dtype = weight.type_annotation + from tvm.topi.utils import get_const_tuple + + X = np.random.randint( + low=0, high=127, size=get_const_tuple(d_dtype.shape), dtype=d_dtype.dtype + ) + W = np.random.randint( + low=-63, high=63, size=get_const_tuple(w_dtype.shape), dtype=w_dtype.dtype + ) + + # build runtime module + dev = tvm.device(str(target), 0) + import tvm.contrib.graph_executor as runtime + + module = runtime.GraphModule(lib["default"](dev)) + module.set_input("data", tvm.nd.array(X)) + params = {"weight": tvm.nd.array(W)} + module.set_input(**params) + + # evaluate performance + ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) + result = np.array(ftimer().results) + gops_per_sec = gops_per_mm / np.mean(result) + print( + "Task tensorized: {%-5s} [%-45s], running time: %.3f ms, %.2f Gops/s" + % (tensorized, tgt, np.mean(result) * 1000, gops_per_sec) + ) + + # evaluate results + module.run() + # read the output back once after the explicit run + O = module.get_output(0).numpy() + tvm.testing.assert_allclose(O, np.dot(X.astype("int32"), W.T.astype("int32")), rtol=0) + + return + + verify(tgt, opt) + + +@tvm.testing.requires_x86_vnni +def test_fc_int8_acc32_x86_vnni(): + benchmark_dense_int8_acc32("llvm -mcpu=cascadelake", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=cascadelake", opt=2) + + +@tvm.testing.requires_x86_avx512 +def test_fc_int8_acc32_x86_avx512(): + benchmark_dense_int8_acc32("llvm -mcpu=skylake-avx512", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=skylake-avx512 -keys=cpu,fast-math", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=skylake-avx512", opt=2) + + +@tvm.testing.requires_x86 +def test_fc_int8_acc32_x86_simd(): + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge -keys=cpu,fast-math", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=haswell", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=haswell -keys=cpu,fast-math", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge", opt=2) + benchmark_dense_int8_acc32("llvm -mcpu=haswell", opt=2) + + +if __name__ == "__main__": + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge -keys=cpu,fast-math", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=haswell", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=haswell -keys=cpu,fast-math", opt=3) + benchmark_dense_int8_acc32("llvm -mcpu=ivybridge", opt=2) + benchmark_dense_int8_acc32("llvm -mcpu=haswell", opt=2)
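For scale, the Gops figure printed by the benchmark follows directly from the fixed problem size, so the reported numbers are easy to sanity-check by hand:

    # m = n = k = 1024  =>  2 * m * n * k multiply-adds per GEMM
    gops_per_mm = 2.0 * 1024 * 1024 * 1024 / 1e9  # ~2.147 Gops
    print(gops_per_mm / 1e-3)  # a 1.0 ms mean runtime reports ~2147 Gops/s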
diff --git a/tests/python/relay/test_op_level1.py index ca8ffda9ba59..a3608df7365f 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -760,7 +760,7 @@ def test_bitserial_dense(): assert yy.checked_type == relay.TensorType((m, 32), "int16") -def dense_x86_test(m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]): +def dense_x86_test(m, n, k, target, intrins): data_shape = (m, k) weight_shape = (n, k) @@ -775,33 +775,41 @@ def dense_x86_test(m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd" with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target) - # TODO(vvchernov): needs for avx512 arch, can be extended - if n % 16 == 0 and k % 4 == 0: - asm = lib.lib.get_source("asm") - for intrin in intrins: - assert intrin in asm + irllvm = lib.lib.get_source() + for intrin in intrins: + assert intrin in irllvm dev = tvm.device(target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - a = np.random.uniform(1, 10, size=data_shape).astype(data_dtype) - b = np.random.uniform(1, 10, size=weight_shape).astype("int8") - c = np.random.uniform(1, 10, size=(weight_shape[0],)).astype("int32") - - runtime.set_input("data", a) - runtime.set_input("weight", b) - runtime.set_input("bias", c) + if "fast-math" in target: + if data_dtype == "int8": + d_np = np.random.randint(low=0, high=127, size=data_shape).astype(data_dtype) + else: + d_np = np.random.randint(low=-63, high=63, size=data_shape).astype(data_dtype) + w_np = np.random.randint(low=-63, high=63, size=weight_shape).astype("int8") + else: + if data_dtype == "int8": + d_np = np.random.randint(low=-128, high=127, size=data_shape).astype(data_dtype) + else: + d_np = np.random.randint(low=0, high=255, size=data_shape).astype(data_dtype) + w_np = np.random.randint(low=-128, high=127, size=weight_shape).astype("int8") + b_np = np.random.randint(low=-65535, high=65535, size=(weight_shape[0],)).astype("int32") + + runtime.set_input("data", d_np) + runtime.set_input("weight", w_np) + runtime.set_input("bias", b_np) runtime.run() out = runtime.get_output(0).numpy() - ref = np.dot(a.astype("int32"), b.transpose().astype("int32")) + c + ref = np.dot(d_np.astype("int32"), w_np.transpose().astype("int32")) + b_np np.testing.assert_equal(out, ref) @tvm.testing.requires_llvm @pytest.mark.skip("skip due to AMX feature not avaliable yet") -def test_dense_amx_int8(): +def test_dense_int8_x86_amx(): data_shape = (32, 128) weight_shape = (32, 128) @@ -847,15 +855,37 @@ @tvm.testing.requires_x86_vnni -@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)]) -def test_dense_vnni(m, n, k): - dense_x86_test(m, n, k) +@pytest.mark.parametrize("m,n,k", [(4, 4, 4), (32, 128, 96), (32, 128, 97)]) +def test_dense_int8_x86_vnni(m, n, k): + target = "llvm -mcpu=cascadelake" + dense_x86_test(m, n, k, target, ["avx512.vpdpbusd"]) @tvm.testing.requires_x86_avx512 -@pytest.mark.parametrize("m,n,k", [(32, 128, 96), (32, 128, 97)]) -def test_dense_skylake_avx512(m, n, k): - dense_x86_test(m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]) +@pytest.mark.parametrize("m,n,k", [(4, 4, 4), (32, 128, 96), (32, 128, 97)]) +def test_dense_int8_x86_avx512(m, n, k): + target = "llvm -mcpu=skylake-avx512" + fast = " -keys=cpu,fast-math" + dense_x86_test(m, n, k, target, ["avx512.pmaddw.d"]) + dense_x86_test(m, n, k, target + fast, ["avx512.pmaddubs.w", "avx512.pmaddw.d"]) + + +@tvm.testing.requires_x86 +@pytest.mark.parametrize("m,n,k", [(4, 4, 4), (32, 128, 96), (32, 128, 97)]) +def test_dense_int8_x86_avx2(m, n, k): + target = "llvm -mcpu=haswell" + fast = " -keys=cpu,fast-math" + dense_x86_test(m, n, k, target, ["avx2.pmadd.wd", "avx2.phadd.d"]) + dense_x86_test(m, n, k, target + fast, ["avx2.pmadd.ub.sw", "avx2.pmadd.wd"]) + + +@tvm.testing.requires_x86 +@pytest.mark.parametrize("m,n,k", [(4, 4, 4), (32, 128, 96), (32, 128, 97)]) +def test_dense_int8_x86_ssse3(m, n, k): + target = "llvm -mcpu=ivybridge" + fast = " -keys=cpu,fast-math" + dense_x86_test(m, n, k, target, ["sse2.pmadd", "ssse3.phadd.d"]) + dense_x86_test(m, n, k, target + fast, ["ssse3.pmadd.ub.sw", "sse2.pmadd"]) @pytest.mark.skip("Requires GFX10 AMDGPU")
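The asymmetric input ranges above are deliberate. The -keys=cpu,fast-math variants select the saturating pmaddubsw-based kernels, which form u8*s8 + u8*s8 pair sums with signed 16-bit saturation, so the tests restrict the signed operand to keep every pair sum exact; the non-fast-math path keeps full ranges and avoids the saturating instruction. The worst-case bound:

    # pmaddubsw pair sum: (u8 * s8) + (u8 * s8), saturated to int16
    assert 2 * 255 * 127 == 64770  # full-range inputs exceed 32767 and would clip
    assert 2 * 255 * 63 == 32130   # with |signed operand| <= 63 the sum is always exact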
diff --git a/tests/python/relay/test_op_level10.py index 6036f707126b..89b79065761e 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -473,42 +473,50 @@ def test_batch_matmul(executor_kind): verify_batch_matmul_with_inputs(executor_kind, x, x, x_np, x_np, (10, 27, 27)) -def batch_matmul_x86_test(b, m, n, k, target="llvm -mcpu=cascadelake", intrins=["vpdpbusd"]): - x_shape = (b, m, k) - y_shape = (b, n, k) - z_shape = (b, m, n) - - for lhs_dtype in ["uint8", "int8"]: - x = relay.var("x", shape=x_shape, dtype=lhs_dtype) - y = relay.var("y", shape=y_shape, dtype="int8") - z = relay.var("z", shape=z_shape, dtype="int32") - bmm = relay.nn.batch_matmul(x, y, out_dtype="int32") - out = bmm + z +def batch_matmul_x86_test(b, m, n, k, target, intrins): + d_shape = (b, m, k) + w_shape = (b, n, k) + b_shape = (b, m, n) + + for data_dtype in ["uint8", "int8"]: + d = relay.var("data", shape=d_shape, dtype=data_dtype) + w = relay.var("weight", shape=w_shape, dtype="int8") + bias = relay.var("bias", shape=b_shape, dtype="int32") + bmm = relay.nn.batch_matmul(d, w, out_dtype="int32") + out = bmm + bias mod = tvm.IRModule.from_expr(out) with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target=target) - # TODO(vvchernov): needs for avx512 arch, can be extended - if n % 16 == 0 and k % 4 == 0: - asm = lib.lib.get_source("asm") - for intrin in intrins: - assert intrin in asm + irllvm = lib.lib.get_source() + for intrin in intrins: + assert intrin in irllvm dev = tvm.device(target, 0) runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype) - y_np = np.random.uniform(1, 10, size=y_shape).astype("int8") - z_np = np.random.uniform(1, 10, size=z_shape).astype("int32") + if "fast-math" in target: + if data_dtype == "int8": + d_np = np.random.randint(low=0, high=127, size=d_shape).astype(data_dtype) + else: + d_np = np.random.randint(low=-63, high=63, size=d_shape).astype(data_dtype) + w_np = np.random.randint(low=-63, high=63, size=w_shape).astype("int8") + else: + if data_dtype == "int8": + d_np = np.random.randint(low=-128, high=127, size=d_shape).astype(data_dtype) + else: + d_np = np.random.randint(low=0, high=255, size=d_shape).astype(data_dtype) + w_np = np.random.randint(low=-128, high=127, size=w_shape).astype("int8") + b_np = np.random.randint(low=-65535, high=65535, size=b_shape).astype("int32") - runtime.set_input("x", x_np) - runtime.set_input("y", y_np) - runtime.set_input("z", z_np) + runtime.set_input("data", d_np) + runtime.set_input("weight", w_np) + runtime.set_input("bias", b_np) runtime.run() out = runtime.get_output(0).numpy() - ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + z_np + ref = tvm.topi.testing.batch_matmul(d_np, w_np, out_dtype="int32") + b_np np.testing.assert_equal(out, ref) @@ -522,7 +530,7 @@ def batch_matmul_x86_test(b, m, n, k, target="llvm -mcpu=cascadelake", intrins=[ (16, 32, 31, 128), ], ) -def test_batch_matmul_amx(b, m, n, k): +def test_batch_matmul_int8_x86_amx(b, m, n, k): amx_init = tvm.get_global_func("runtime.amx_init") amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig") assert amx_init() @@ -577,8 +585,9 @@ def test_batch_matmul_amx(b, m, n, k): (16, 32, 129, 96), ], ) -def test_batch_matmul_vnni(b, m, n, k): - batch_matmul_x86_test(b, m, n, k) +def test_batch_matmul_int8_x86_vnni(b, m, n, k): + target = "llvm -mcpu=cascadelake" + batch_matmul_x86_test(b, m, n, k, target, ["avx512.vpdpbusd"]) @tvm.testing.requires_x86_avx512 @@ -590,8 +599,43 @@ def test_batch_matmul_vnni(b, m, n, k): (16, 32, 129, 96), ], ) -def
test_batch_matmul_skylake_avx512(b, m, n, k): - batch_matmul_x86_test(b, m, n, k, "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"]) +def test_batch_matmul_int8_x86_avx512(b, m, n, k): + target = "llvm -mcpu=skylake-avx512" + fast = " -keys=cpu,fast-math" + batch_matmul_x86_test(b, m, n, k, target, ["avx512.pmaddw.d"]) + batch_matmul_x86_test(b, m, n, k, target + fast, ["avx512.pmaddubs.w", "avx512.pmaddw.d"]) + + +@tvm.testing.requires_x86 +@pytest.mark.parametrize( + "b,m,n,k", + [ + (16, 32, 128, 96), + (16, 32, 128, 97), + (16, 32, 129, 96), + ], +) +def test_batch_matmul_int8_x86_avx2(b, m, n, k): + target = "llvm -mcpu=haswell" + fast = " -keys=cpu,fast-math" + batch_matmul_x86_test(b, m, n, k, target, ["avx2.pmadd.wd", "avx2.phadd.d"]) + batch_matmul_x86_test(b, m, n, k, target + fast, ["avx2.pmadd.ub.sw", "avx2.pmadd.wd"]) + + +@tvm.testing.requires_x86 +@pytest.mark.parametrize( + "b,m,n,k", + [ + (16, 32, 128, 96), + (16, 32, 128, 97), + (16, 32, 129, 96), + ], +) +def test_batch_matmul_int8_x86_ssse3(b, m, n, k): + target = "llvm -mcpu=ivybridge" + fast = " -keys=cpu,fast-math" + batch_matmul_x86_test(b, m, n, k, target, ["sse2.pmadd", "ssse3.phadd.d"]) + batch_matmul_x86_test(b, m, n, k, target + fast, ["ssse3.pmadd.ub.sw", "sse2.pmadd"]) @pytest.mark.skip("Requires GFX10 AMDGPU") diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index cb785021783d..c8708611e68f 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -29,6 +29,7 @@ from tvm.relay import transform from tvm.relay.testing import run_infer_type from tvm.topi.cuda.conv3d_winograd import _infer_tile_size +from tvm.target import codegen executor_kind = tvm.testing.parameter("graph", "vm") @@ -1672,12 +1673,15 @@ def test_upsampling3d(): @tvm.testing.requires_x86 -@pytest.mark.skipif(tvm.target.codegen.llvm_version_major() < 8, reason="Requires LLVM 8") +@pytest.mark.skipif(codegen.llvm_version_major() < 8, reason="Requires LLVM 8") class TestConv2DInt8Intrinsics: supported_targets = [ "llvm -mcpu=nehalem", + "llvm -mcpu=nehalem -keys=cpu,fast-math", "llvm -mcpu=core-avx2", + "llvm -mcpu=core-avx2 -keys=cpu,fast-math", "llvm -mcpu=skylake-avx512", + "llvm -mcpu=skylake-avx512 -keys=cpu,fast-math", "llvm -mcpu=cascadelake", ] @@ -1708,16 +1712,33 @@ class TestConv2DInt8Intrinsics: ) @tvm.testing.fixture - def fast_int8_intrinsic(self, target): - if "nehalem" in target or "core-avx2" in target or "skylake-avx512" in target: - return "pmaddubs" - elif "cascadelake" in target: - return "vpdpbusd" - else: - assert False, "Target should be Nehalem or core-avx2 or Skylake or Cascadelake" + def fast_int8_intrinsics(self, target): + code = ["throw-error"] + with tvm.target.Target(target): + if codegen.llvm_cpu_has_features("avxvnni") or codegen.llvm_cpu_has_features( + "avx512vnni" + ): + code = ["avx512.vpdpbusd"] + elif codegen.llvm_cpu_has_features(["avx512bw", "avx512f"]): + if "fast-math" in target: + code = ["avx512.pmaddubs.w", "avx512.pmaddw.d"] + else: + code = ["avx512.pmaddw.d"] + elif codegen.llvm_cpu_has_features("avx2"): + if "fast-math" in target: + code = ["avx2.pmadd.ub.sw", "avx2.pmadd.wd"] + else: + code = ["avx2.pmadd.wd", "avx2.phadd.d"] + elif codegen.llvm_cpu_has_features("ssse3"): + if "fast-math" in target: + code = ["ssse3.pmadd.ub.sw", "sse2.pmadd"] + else: + code = ["sse2.pmadd", "ssse3.phadd.d"] + + return code @tvm.testing.fixture - def assembly( + def irllvm( self, target, dtypes, @@ -1770,15 +1791,15 @@ def 
assembly( with tvm.transform.PassContext(opt_level=3): graph, lib, params = relay.build(func, target, params=parameters) - return lib.get_source("asm") + return lib.get_source() # Ensure that code uses the fast int8 instructions when available. @tvm.testing.parametrize_targets(*supported_targets) @pytest.mark.parametrize( "dtypes", [ - # compile conv2d for x86 (skylake, cascadelake) and test - # assembly contains *pmadd* instructions + # compile conv2d for x86 targets + # check llvm ir contains expected SIMD instructions ("uint8", "int8", "int32"), # Check that int8 x int8 goes through legalization so that # fast instructions can be picked up. ("int8", "int8", "int32"), ], ) def test_uses_intrinsic( self, - fast_int8_intrinsic, - assembly, + fast_int8_intrinsics, + irllvm, ): - assert fast_int8_intrinsic in assembly + is_present = True + for intrin in fast_int8_intrinsics: + is_present &= intrin in irllvm + assert is_present # For datatypes that don't have HW support, ensure that code is # generated without the fast int8 intrinsic. @tvm.testing.parametrize_targets(*supported_targets) - @pytest.mark.parametrize("dtypes", [("uint8", "uint8", "int32")]) + @pytest.mark.parametrize( + "dtypes", + [ + ("uint8", "uint8", "int32"), + ], + ) def test_no_intrinsic( self, - fast_int8_intrinsic, - assembly, + fast_int8_intrinsics, + irllvm, ): - assert fast_int8_intrinsic not in assembly + is_present = True + for intrin in fast_int8_intrinsics: + is_present &= intrin in irllvm + assert not is_present # Check that a vectorized instruction is generated for older Intel # generations, because we default to NCHWc layout. @tvm.testing.parametrize_targets(*unsupported_targets) @pytest.mark.parametrize("dtypes", [("uint8", "int8", "int32")]) - def test_uses_vectorized_instruction(self, assembly): - assert "pmulhw" in assembly or "pmaddwd" in assembly - assert "paddd" in assembly + def test_uses_vectorized_instruction(self, irllvm): + assert "mul nsw" in irllvm + assert "pmadd" not in irllvm + assert "phadd" not in irllvm + assert "vpdpbusd" not in irllvm
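The rewritten fixture keys its expectations off LLVM's view of the target CPU instead of substring-matching -mcpu names, using tvm.target.codegen.llvm_cpu_has_features. The probe also works standalone (a sketch; it assumes an LLVM build that knows these feature names):

    import tvm
    from tvm.target import codegen

    with tvm.target.Target("llvm -mcpu=skylake-avx512"):
        assert codegen.llvm_cpu_has_features(["avx512bw", "avx512f"])
        assert not codegen.llvm_cpu_has_features("avxvnni")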
@tvm.testing.uses_gpu @@ -2159,15 +2193,15 @@ def get_subgraph(dtype): np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) -def _test_conv2d_int8_alter_dtype(data_dtype, target, dot_product_instrs): +def _test_conv2d_int8(I, O, H, W, kHW, S, P, target, dot_product_instrs): def get_conv2d_nchw( d_shape, w_shape, data_dtype, ): out_dtype = "int32" - strides = (1, 1) - padding = (1, 1) + strides = S + padding = P data = relay.var("data", shape=d_shape, dtype=data_dtype) weight = relay.var("weight", shape=w_shape, dtype="int8") out_channel = w_shape[0] @@ -2181,72 +2215,97 @@ def get_conv2d_nchw( out_dtype=out_dtype, ) - I, O, H, W = 64, 64, 56, 56 - kH = kW = 3 + kH, kW = kHW data_shape = (1, I, H, W) weight_shape = (O, I, kH, kW) bias_shape = (1, weight_shape[0], 1, 1) - bias = relay.var("bias", shape=bias_shape, dtype="int32") - bias_np = np.random.randint(low=-127, high=128, size=bias_shape).astype("int32") - weight_np = np.random.uniform(-32, 32, size=weight_shape).astype("int8") + for data_dtype in ["uint8", "int8"]: - conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype) - bias_add = relay.add(conv2d, bias) - mod = tvm.IRModule.from_expr(bias_add) + bias = relay.var("bias", shape=bias_shape, dtype="int32") + conv2d = get_conv2d_nchw(data_shape, weight_shape, data_dtype) + bias_add = relay.add(conv2d, bias) + mod = tvm.IRModule.from_expr(bias_add) - if data_dtype == "uint8": - data_np = np.random.uniform(0, 64, size=data_shape).astype("uint8") - else: - data_np = np.random.uniform(-32, 32, size=data_shape).astype("int8") - - params = {"weight": weight_np, "bias": bias_np} + if "fast-math" in target: + if data_dtype == "int8": + data_np = np.random.randint(low=0, high=127, size=data_shape).astype(data_dtype) + else: + data_np = np.random.randint(low=-63, high=63, size=data_shape).astype(data_dtype) + weight_np = np.random.randint(low=-63, high=63, size=weight_shape).astype("int8") + else: + if data_dtype == "int8": + data_np = np.random.randint(low=-128, high=127, size=data_shape).astype(data_dtype) + else: + data_np = np.random.randint(low=0, high=255, size=data_shape).astype(data_dtype) + weight_np = np.random.randint(low=-128, high=127, size=weight_shape).astype("int8") + bias_np = np.random.randint(low=-65535, high=65535, size=bias_shape).astype("int32") + + params = {"weight": weight_np, "bias": bias_np} + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np, bias_np]) + .numpy() + ) - ref = ( - relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") - .evaluate()(*[data_np, weight_np, bias_np]) - .numpy() - ) + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(mod, target=target, params=params) - dev = tvm.cpu(0) + for dot_product_instr in dot_product_instrs: + assert dot_product_instr in lib.lib.get_source() - with tvm.transform.PassContext( - opt_level=3, - ): - lib = relay.build(mod, target=target, params=params) + dev = tvm.cpu(0) + rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + rt_mod.set_input("data", data_np) + rt_mod.run() + out = rt_mod.get_output(0).numpy() - for dot_product_instr in dot_product_instrs: - assert dot_product_instr in lib.lib.get_source("asm") + np.testing.assert_equal(out, ref) - rt_mod = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) - rt_mod.set_input("data", data_np) - - rt_mod.run() +@tvm.testing.requires_arm_dot +@pytest.mark.parametrize("I,O,H,W,kHW,S,P", [(64, 64, 56, 56, (3, 3), (1, 1), (1, 1))]) +def test_conv2d_int8_arm_dot(I, O, H, W, kHW, S, P): + target = "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod" + _test_conv2d_int8(I, O, H, W, kHW, S, P, target, ["sdot"]) - out = rt_mod.get_output(0).numpy() - np.testing.assert_equal(out, ref) +@tvm.testing.requires_x86_vnni +@pytest.mark.parametrize("I,O,H,W,kHW,S,P", [(64, 64, 56, 56, (3, 3), (1, 1), (1, 1))]) +def test_conv2d_int8_x86_vnni(I, O, H, W, kHW, S, P): + target = "llvm -mcpu=cascadelake" + _test_conv2d_int8(I, O, H, W, kHW, S, P, target, ["avx512.vpdpbusd"]) -@tvm.testing.requires_arm_dot -def test_conv2d_int8_alter_dtype_arm(): - _test_conv2d_int8_alter_dtype( - "uint8", "llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod", ["sdot"] +@tvm.testing.requires_x86_avx512 +@pytest.mark.parametrize("I,O,H,W,kHW,S,P", [(64, 64, 56, 56, (3, 3), (1, 1), (1, 1))]) +def test_conv2d_int8_x86_avx512(I, O, H, W, kHW, S, P): + target = "llvm -mcpu=skylake-avx512" + fast = " -keys=cpu,fast-math" + _test_conv2d_int8(I, O, H, W, kHW, S, P, target, ["avx512.pmaddw.d"]) + _test_conv2d_int8( + I, O, H, W, kHW, S, P, target + fast, ["avx512.pmaddubs.w", "avx512.pmaddw.d"] ) -@tvm.testing.requires_x86_vnni -def test_conv2d_int8_alter_dtype_vnni(): - _test_conv2d_int8_alter_dtype("int8", "llvm -mcpu=cascadelake", ["vpdpbusd"]) +@tvm.testing.requires_x86 +@pytest.mark.parametrize("I,O,H,W,kHW,S,P", [(64, 64, 56, 56, (3, 3), (1, 1), (1, 1))]) +def test_conv2d_int8_x86_avx2(I, O,
H, W, kHW, S, P): + target = "llvm -mcpu=haswell" + fast = " -keys=cpu,fast-math" + _test_conv2d_int8(I, O, H, W, kHW, S, P, target, ["avx2.pmadd.wd", "avx2.phadd.d"]) + _test_conv2d_int8(I, O, H, W, kHW, S, P, target + fast, ["avx2.pmadd.ub.sw", "avx2.pmadd.wd"]) -@tvm.testing.requires_x86_avx512 -def test_conv2d_int8_alter_dtype_avx512(): - _test_conv2d_int8_alter_dtype( - "int8", "llvm -mcpu=skylake-avx512", ["pmaddubs", "pmaddw", "vpaddd"] - ) +@tvm.testing.requires_x86 +@pytest.mark.parametrize("I,O,H,W,kHW,S,P", [(64, 64, 56, 56, (3, 3), (1, 1), (1, 1))]) +def test_conv2d_int8_x86_ssse3(I, O, H, W, kHW, S, P): + target = "llvm -mcpu=ivybridge" + fast = " -keys=cpu,fast-math" + _test_conv2d_int8(I, O, H, W, kHW, S, P, target, ["sse2.pmadd", "ssse3.phadd.d"]) + _test_conv2d_int8(I, O, H, W, kHW, S, P, target + fast, ["ssse3.pmadd.ub.sw", "sse2.pmadd"]) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_constructor.py index 2df644d7e198..6a0bb62e69a0 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -44,6 +44,13 @@ def test_expr_constructor(): assert isinstance(x, tvm.tir.StringImm) assert x.value == "xyza" + x = tvm.tir.ArrayIntImm([1, 2, 3]) + assert isinstance(x, tvm.tir.ArrayIntImm) + assert x == [1, 2, 3] + assert x.data[-1] == 3 + assert len(x.data) == 3 + assert str(x.data) == str([1, 2, 3]) + x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1)) assert isinstance(x, tvm.tir.Cast) assert x.dtype == "float32" diff --git a/tests/python/unittest/test_tir_nodes.py index 49816778f11f..d0eaef2b18af 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -338,6 +338,13 @@ def test_equality_string_imm(): x == y +def test_equality_array_imm(): + x = [1, 2, 3] + y = tvm.tir.ArrayIntImm(x) + x == y.data + x == y + + def test_prim_func(): x = te.var("x") y = te.var("y") diff --git a/tests/python/unittest/test_tir_op_types.py index 7398ee781b9e..05b276b2b83f 100644 --- a/tests/python/unittest/test_tir_op_types.py +++ b/tests/python/unittest/test_tir_op_types.py @@ -57,6 +57,27 @@ def test_tir_op_reinterpret(): assert expr.op.name == "tir.reinterpret" +def test_tir_op_zextend(): + buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint8x16") + expr = tir.zextend("uint16x16", vec) + assert expr.op.name == "tir.zextend" + + +def test_tir_op_sextend(): + buffer = tir.decl_buffer((4, 4), "uint8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint8x16") + expr = tir.sextend("int16x16", vec) + assert expr.op.name == "tir.sextend" + + +def test_tir_op_truncate(): + buffer = tir.decl_buffer((4, 4), "uint16", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint16x16") + expr = tir.truncate("uint8x16", vec) + assert expr.op.name == "tir.truncate" + + def test_tir_op_isnullptr(): x = tir.Var("x", dtype="int32") expr = tir.isnullptr(x) @@ -302,6 +323,31 @@ def test_tir_op_vectorcombine(): assert expr.op.name == "tir.vectorcombine" +def test_tir_op_vectorpermute(): + buffer = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + vec = buffer.vload([0, 0], dtype="uint32x4") + expr = tir.vectorpermute("uint32x4", vec, [2, 3, 0, 1]) + assert expr.op.name == "tir.vectorpermute" + + +def test_tir_op_vectorshuffle(): + buffer0 = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + buffer1 =
tir.decl_buffer((2, 2), "uint32", offset_factor=1) + vec0 = buffer0.vload([0, 0], dtype="uint32x4") + vec1 = buffer1.vload([0, 0], dtype="uint32x4") + expr = tir.vectorshuffle("uint32x4", vec0, vec1, [0, 1, 4, 5]) + assert expr.op.name == "tir.vectorshuffle" + + +def test_tir_op_atomic_add(): + buffer0 = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + buffer1 = tir.decl_buffer((2, 2), "uint32", offset_factor=1) + vec0 = buffer0.vload([0, 0], dtype="uint32x4") + vec1 = buffer1.vload([0, 0], dtype="uint32x4") + expr = tir.atomic_add("uint32x4", vec0, vec1) + assert expr.op.name == "tir.atomic_add" + + def test_tir_op_shift_left(): x = tir.Var("x", dtype="int32") y = tir.Var("x", dtype="int32") diff --git a/tests/python/unittest/test_tvmscript_printer_metadata.py b/tests/python/unittest/test_tvmscript_printer_metadata.py index a57f4c71f7dd..60e857c039d4 100644 --- a/tests/python/unittest/test_tvmscript_printer_metadata.py +++ b/tests/python/unittest/test_tvmscript_printer_metadata.py @@ -43,5 +43,23 @@ def foo1() -> None: ) +def test_array_metadata(): + arr_imm = T.ArrayIntImm([1, 2, 3]) + + @I.ir_module + class Module: + @T.prim_func + def foo() -> None: + A = arr_imm + B = arr_imm + + @T.prim_func + def foo1() -> None: + A = arr_imm + + printed_str = Module.script(verbose_expr=True) + assert printed_str.count("T.ArrayIntImm([1, 2, 3])") == 3 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 70d56e6903b7..a51a5b25f5ab 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -493,6 +493,11 @@ def test_string_imm(): _assert_print(s, '"str"') +def test_array_imm(): + a = tir.ArrayIntImm([1, 2, 3]) + _assert_print(a, "[1, 2, 3]") + + def test_cast(): obj = tir.Cast("float64", tir.Var("a", "float32")) _assert_print(