Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ TVM_DLL const Op& call_spirv_pure_glsl450();
// TODO(tvm-team) revisit the builtins below
// some of them can simply become ops with special codegen attr.
/*!
* \brief Prefetch a cacheline
* \brief same signature as llvm.prefetch
*/
TVM_DLL const Op& prefetch();

Expand Down
5 changes: 0 additions & 5 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1106,11 +1106,6 @@ constexpr const char* pragma_import_c = "pragma_import_c";
constexpr const char* pragma_import_llvm = "pragma_import_llvm";
/*! \brief Try to modify the AST to support Tensor Core */
constexpr const char* pragma_tensor_core = "pragma_tensor_core";
/*!
* \brief Mark of prefetch scope, value=offset,
* run prefetch of Tensor on the current loop scope
*/
constexpr const char* prefetch_scope = "prefetch_scope";
/*!
* \brief Marks the layout transforms to be used for a tensor.
*
Expand Down
18 changes: 1 addition & 17 deletions python/tvm/tir/tensor_intrin/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,13 @@ def neon_4x4_i8i8i32_impl(

multiply_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_low,
dtype="int16x8",
)

pairwise_reduction_low = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_low,
dtype="int32x4",
)
Expand All @@ -91,22 +89,19 @@ def neon_4x4_i8i8i32_impl(

multiply_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
T.uint32(2),
vec_a,
vec_b_high,
dtype="int16x8",
)

pairwise_reduction_high = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
T.uint32(1),
multiply_high,
dtype="int32x4",
)

C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
T.uint32(2),
pairwise_reduction_low,
pairwise_reduction_high,
dtype="int32x4",
Expand Down Expand Up @@ -159,7 +154,6 @@ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:

C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id(f"llvm.aarch64.neon.{instr}"),
T.uint32(3),
vec_c,
vec_a,
vec_b,
Expand Down Expand Up @@ -311,7 +305,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1w.horiz",
T.uint32(4),
predicate,
input_ptr,
sub_tile,
Expand All @@ -335,7 +328,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
predicate,
output_ptr,
sub_tile,
Expand Down Expand Up @@ -438,7 +430,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1h.horiz",
T.uint32(4),
ptrue_fp16,
input_ptr,
sub_tile_idx,
Expand All @@ -450,7 +441,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.ld1h.horiz",
T.uint32(4),
ptrue_fp16,
input_ptr,
sub_tile_idx,
Expand All @@ -467,7 +457,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
ptrue_fp32,
output_ptr,
sub_tile_idx,
Expand All @@ -479,7 +468,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.vert",
T.uint32(4),
ptrue_fp32,
output_ptr,
sub_tile_idx + 2,
Expand Down Expand Up @@ -692,7 +680,6 @@ def impl():
T.call_llvm_intrin(
"void",
fmopa_intrin,
T.uint32(5),
sub_tile,
input_1[1],
input_2[1],
Expand All @@ -713,7 +700,6 @@ def impl():
T.call_llvm_intrin(
"void",
"llvm.aarch64.sme.st1w.horiz",
T.uint32(4),
_create_active_lane_mask(
C, (vert_offset + slice_idx, horiz_offset), M
),
Expand Down Expand Up @@ -752,9 +738,7 @@ def impl(c: T.handle) -> None:
T.reads()
T.writes(C[0:SVF2, 0:SVF2])
clear_all_tiles = T.int32(255)
T.evaluate(
T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", T.uint32(1), clear_all_tiles)
)
T.evaluate(T.call_llvm_intrin("void", "llvm.aarch64.sme.zero", clear_all_tiles))

return desc, impl

Expand Down
3 changes: 0 additions & 3 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyub.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
B_i32x32,
A_i32,
Expand Down Expand Up @@ -149,7 +148,6 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpybusv.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
T.broadcast(A_i32, 32),
B_i32x32,
Expand Down Expand Up @@ -191,7 +189,6 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N

C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"),
T.uint32(3),
C[T.ramp(T.int32(0), 1, 32)],
T.Broadcast(A_i32, 32),
B_i32x32,
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/tir/tensor_intrin/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def sdot4(

C[0] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.amdgcn.sdot4"),
T.uint32(4),
T.reinterpret(A.vload([0], "int8x4"), dtype="int32"),
T.reinterpret(B.vload([0], "int8x4"), dtype="int32"),
T.int32(0),
Expand Down Expand Up @@ -337,7 +336,6 @@ def mfma_sync_impl_float(a: T.handle, b: T.handle, c: T.handle) -> None:
T.launch_thread(tx, WARP_SIZE)
C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id(mfma_intrin),
T.uint32(6),
A[tx, 0:local_size],
B[tx, 0:local_size],
C[tx, 0:local_size_out],
Expand Down Expand Up @@ -365,7 +363,6 @@ def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None:

C[tx, 0:local_size_out] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id(mfma_intrin),
T.uint32(6),
T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]),
C[tx, 0:local_size_out],
Expand Down
3 changes: 0 additions & 3 deletions python/tvm/tir/tensor_intrin/x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def dot_product_16x4_u8i8i32_vnni(

C[T.ramp(T.int32(0), 1, 16)] = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
T.uint32(3),
C_i32x16,
T.broadcast(A_i32, 16),
B_i32x16,
Expand All @@ -86,15 +85,13 @@ def dot_product_16x4_u8i8i32_avx512(

Red = T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddubs.w.512"),
T.uint32(2),
A_u8x64,
B_i8x64,
dtype="int16x32",
)

C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.pmaddw.d.512"),
T.uint32(2),
Red,
T.int16x32(1),
dtype="int32x16",
Expand Down
7 changes: 1 addition & 6 deletions src/target/llvm/codegen_arm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) {

PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
using namespace tir;
const PrimExpr& e = call->args[2];
const PrimExpr& e = call->args[1];
llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop;
llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu;

Expand All @@ -77,7 +77,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
(total_size != 128 && total_size != 64)) {
Array<PrimExpr> vcnt_args;
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt_args.push_back(e);
return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args);
}
Expand All @@ -101,14 +100,12 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
ICHECK(c0 != nullptr);
Array<PrimExpr> vcnt8_args;
vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
vcnt8_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt8_args.push_back(input8);
PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args);

// Accumulation 8->16bit
Array<PrimExpr> vcnt16_args;
vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt16_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt16_args.push_back(vcnt8);
PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args);
if (call->dtype.bits() == 16) {
Expand All @@ -118,7 +115,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
// Accumulation 16->32bit
Array<PrimExpr> vcnt32_args;
vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt32_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt32_args.push_back(vcnt16);
PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args);
if (call->dtype.bits() == 32) {
Expand All @@ -128,7 +124,6 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
// Accumulation 32->64bit
Array<PrimExpr> vcnt64_args;
vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id));
vcnt64_args.push_back(IntImm(DataType::UInt(32), 1));
vcnt64_args.push_back(vcnt32);
return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args);
}
Expand Down
27 changes: 5 additions & 22 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1351,34 +1351,18 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) {

llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) {
ICHECK_GE(op->args.size(), 2U);
ICHECK_GE(op->args.size(), 1U);
llvm::Intrinsic::ID id = static_cast<llvm::Intrinsic::ID>(Downcast<IntImm>(op->args[0])->value);
int64_t num_signature = Downcast<IntImm>(op->args[1])->value;
std::vector<llvm::Value*> arg_value;
std::vector<llvm::Type*> arg_type;
for (size_t i = 2; i < op->args.size(); ++i) {
for (size_t i = 1; i < op->args.size(); ++i) {
arg_value.push_back(MakeValue(op->args[i]));
if (i - 2 < static_cast<size_t>(num_signature)) {
arg_type.push_back(arg_value.back()->getType());
}
arg_type.push_back(arg_value.back()->getType());
}
// LLVM's prefetch intrinsic returns "void", while TVM's prefetch
// returns int32. This causes problems because prefetch is one of
// those intrinsics that is generated automatically via the
// tvm.intrin.rule mechanism. Any other intrinsic with a type
// mismatch will have to be treated specially here.
// TODO(kparzysz-quic): fix this once TVM prefetch uses the same
// type as LLVM.
llvm::Type* return_type =
(id != llvm::Intrinsic::prefetch) ? GetLLVMType(GetRef<PrimExpr>(op)) : t_void_;
llvm::Type* return_type = GetLLVMType(GetRef<PrimExpr>(op));
llvm::Function* f = GetIntrinsicDecl(id, return_type, arg_type);
ICHECK(f) << "Cannot find intrinsic declaration, possible type mismatch: "
#if TVM_LLVM_VERSION >= 130
<< llvm::Intrinsic::getBaseName(id).str();
#else
<< llvm::Intrinsic::getName(id, {});
#endif

<< llvmGetIntrinName(id);
// In earlier versions of LLVM's, the prefetch intrinsic is not
// overloaded, and always takes the first argument as i8*. If
// this is the case, this argument should insert a cast to i8*.
Expand All @@ -1391,7 +1375,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
builder_->CreatePointerCast(arg_value[0], llvmGetPointerTo(t_char_, addrspace));
}
}

return builder_->CreateCall(f, arg_value);
} else if (op->op.same_as(builtin::bitwise_and())) {
return builder_->CreateAnd(MakeValue(op->args[0]), MakeValue(op->args[1]));
Expand Down
1 change: 0 additions & 1 deletion src/target/llvm/intrin_rule_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,6 @@ TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const PrimEx
ICHECK_EQ(call->args.size(), 1);
Array<PrimExpr> cargs;
cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz));
cargs.push_back(IntImm(DataType::UInt(32), 2));
cargs.push_back(call->args[0]);
cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef
// LLVM requires that the return type must match the first argument type
Expand Down
11 changes: 9 additions & 2 deletions src/target/llvm/intrin_rule_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@

#ifdef TVM_LLVM_VERSION

#include <llvm/IR/Intrinsics.h>
#include <tvm/ffi/function.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>

#include "llvm_instance.h"

namespace tvm {
namespace codegen {
// num_signature means number of arguments used to query signature
Expand All @@ -41,7 +44,9 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) {
Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(IntImm(DataType::UInt(32), id));
cargs.push_back(IntImm(DataType::UInt(32), num_signature));
ICHECK_EQ(call->args.size(), num_signature)
<< "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature
<< " arguments, but got " << call->args.size();

for (PrimExpr arg : call->args) {
cargs.push_back(arg);
Expand All @@ -56,7 +61,9 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) {
Array<PrimExpr> cargs;
// intrin id.
cargs.push_back(IntImm(DataType::UInt(32), id));
cargs.push_back(IntImm(DataType::UInt(32), num_signature));
ICHECK_EQ(call->args.size(), num_signature)
<< "llvm.call_llvm_intrin" << llvmGetIntrinName(id) << "expects " << num_signature
<< " arguments, but got " << call->args.size();
for (PrimExpr arg : call->args) {
cargs.push_back(arg);
}
Expand Down
15 changes: 15 additions & 0 deletions src/target/llvm/llvm_instance.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,21 @@
#define llvmGetPointerTo(arg, offset) (arg->getPointerTo(offset))
#endif

#if TVM_LLVM_VERSION >= 130
#define llvmGetIntrinName(id) \
std::string(llvm::Intrinsic::getBaseName(static_cast<llvm::Intrinsic::ID>(id)))
#elif TVM_LLVM_VERSION >= 40
// This is the version of Intrinsic::getName that works for overloaded
// intrinsics. Helpfully, if we provide no types to this function, it
// will give us the overloaded name without the types appended. This
// should be enough information for most uses.
#define llvmGetIntrinName(id) \
std::string(llvm::Intrinsic::getName(static_cast<llvm::Intrinsic::ID>(id), {}))
#else
// Nothing to do, just return the intrinsic id number
#define llvmGetIntrinName(id) std::to_string(id)
#endif

namespace llvm {
class LLVMContext;
class MemoryBuffer;
Expand Down
15 changes: 1 addition & 14 deletions src/target/llvm/llvm_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -653,20 +653,7 @@ static void LLVMReflectionRegister() {
#endif
})
.def("target.llvm_get_intrinsic_name",
[](int64_t id) -> String {
#if TVM_LLVM_VERSION >= 130
return std::string(llvm::Intrinsic::getBaseName(static_cast<llvm::Intrinsic::ID>(id)));
#elif TVM_LLVM_VERSION >= 40
// This is the version of Intrinsic::getName that works for overloaded
// intrinsics. Helpfully, if we provide no types to this function, it
// will give us the overloaded name without the types appended. This
// should be enough information for most uses.
return std::string(llvm::Intrinsic::getName(static_cast<llvm::Intrinsic::ID>(id), {}));
#else
// Nothing to do, just return the intrinsic id number
return std::to_string(id);
#endif
})
[](int64_t id) -> String { return llvmGetIntrinName(id); })
.def("target.llvm_get_system_x86_vendor",
[]() -> String {
#if TVM_LLVM_VERSION >= 120
Expand Down
Loading
Loading