Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ class CodeGen_ARM : public CodeGen_Posix {
void visit(const Call *) override;
void visit(const LT *) override;
void visit(const LE *) override;
Value *interleave_vectors(const std::vector<Value *> &) override;
Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &);
bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &);
Expand Down Expand Up @@ -1942,12 +1944,100 @@ void CodeGen_ARM::visit(const Shuffle *op) {
load->type.lanes() == stride * op->type.lanes()) {

value = codegen_dense_vector_load(load, nullptr, /* slice_to_native */ false);
value = shuffle_vectors(value, op->indices);
value = CodeGen_Posix::shuffle_vectors(value, op->indices);
} else {
CodeGen_Posix::visit(op);
}
}

/** Interleave the given vectors.
 *
 * When SIMD intrinsics are enabled, the target has a non-zero vscale, at
 * least two vectors are supplied and the first one is a scalable vector,
 * this tries to emit a single call to the llvm.vector.interleave<N>
 * intrinsic. In every other case the request is forwarded to
 * CodeGen_Posix::interleave_vectors.
 *
 * @param vecs The vectors to interleave. All arguments of the intrinsic are
 *             declared with the type of vecs[0].
 * @return The interleaved vector value.
 */
Value *CodeGen_ARM::interleave_vectors(const std::vector<Value *> &vecs) {
    if (simd_intrinsics_disabled() || target_vscale() == 0 ||
        vecs.size() < 2 ||
        !is_scalable_vector(vecs[0])) {
        return CodeGen_Posix::interleave_vectors(vecs);
    }

    // Lower into llvm.vector.interleave intrinsic.
    // Interleave factors for which an intrinsic call is emitted. A plain
    // comparison avoids building a heap-allocated set on every call.
    const auto is_supported_stride = [](int s) {
        return s == 2 || s == 3 || s == 4 || s == 8;
    };
    const int stride = static_cast<int>(vecs.size());
    const int src_lanes = get_vector_num_elements(vecs[0]->getType());

    // The intrinsic is only used when the per-source lane count is a power
    // of two that is a multiple of vscale and spans more than one vscale unit.
    if (is_supported_stride(stride) &&
        is_power_of_two(src_lanes) &&
        src_lanes % target_vscale() == 0 &&
        src_lanes / target_vscale() > 1) {

        llvm::Type *elt = get_vector_element_type(vecs[0]->getType());
        llvm::Type *ret_type = get_vector_type(elt, src_lanes * stride);

        // Intrinsic name: "llvm.vector.interleave" + factor + mangled result type.
        std::string instr = concat_strings("llvm.vector.interleave", stride, mangle_llvm_type(ret_type));

        std::vector<llvm::Type *> arg_types(stride, vecs[0]->getType());
        llvm::FunctionType *fn_type = FunctionType::get(ret_type, arg_types, false);
        FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);

        return builder->CreateCall(fn, vecs);
    }

    return CodeGen_Posix::interleave_vectors(vecs);
}

/** Shuffle lanes of two vectors according to `indices`.
 *
 * When SIMD intrinsics are enabled, the target has a non-zero vscale and `a`
 * is a scalable vector, a shuffle that is a strided slice covering all of
 * `a` (i.e. one field of a deinterleave) is lowered to the
 * llvm.vector.deinterleave<N> intrinsic. Everything else is forwarded to
 * CodeGen_Posix::shuffle_vectors.
 *
 * @param a       First source vector.
 * @param b       Second source vector; must have the same LLVM type as `a`
 *                on the scalable path (asserted). It is not read by the
 *                deinterleave lowering.
 * @param indices Source lane for each destination lane.
 * @return The shuffled vector value.
 */
Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) {
    // An empty index list has no start lane to inspect; let the generic
    // implementation deal with it instead of reading indices[0].
    if (simd_intrinsics_disabled() || target_vscale() == 0 ||
        indices.empty() ||
        !is_scalable_vector(a)) {
        return CodeGen_Posix::shuffle_vectors(a, b, indices);
    }

    internal_assert(a->getType() == b->getType());

    llvm::Type *elt = get_vector_element_type(a->getType());
    const int src_lanes = get_vector_num_elements(a->getType());
    const int dst_lanes = static_cast<int>(indices.size());

    // Check if deinterleaved slice
    {
        // Get the stride of slice: non-zero only when every index follows
        // start_index + i * stride.
        int slice_stride = 0;
        const int start_index = indices.front();
        if (dst_lanes > 1) {
            const int stride = indices[1] - start_index;
            bool stride_equal = true;
            for (int i = 2; i < dst_lanes && stride_equal; ++i) {
                stride_equal = (indices[i] == start_index + i * stride);
            }
            slice_stride = stride_equal ? stride : 0;
        }

        // Deinterleave factors for which an intrinsic call is emitted.
        const bool supported_stride =
            slice_stride == 2 || slice_stride == 3 || slice_stride == 4 || slice_stride == 8;

        // Lower slice with stride into llvm.vector.deinterleave intrinsic
        if (supported_stride &&
            dst_lanes * slice_stride == src_lanes &&
            // The start position selects a field of the returned struct, so it
            // must lie in [0, slice_stride). A negative start would otherwise
            // be converted to a huge unsigned struct index below.
            start_index >= 0 && start_index < slice_stride &&
            is_power_of_two(dst_lanes) &&
            dst_lanes % target_vscale() == 0 &&
            dst_lanes / target_vscale() > 1) {

            std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));

            // We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
            llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
            // The intrinsic returns a struct with slice_stride fields of dst_type.
            StructType *sret_type = StructType::get(*context, std::vector<llvm::Type *>(slice_stride, dst_type));
            std::vector<llvm::Type *> arg_types{a->getType()};
            llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
            FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);

            CallInst *deinterleave = builder->CreateCall(fn, {a});
            // extract one element out of the returned struct
            return builder->CreateExtractValue(deinterleave, static_cast<unsigned>(start_index));
        }
    }

    return CodeGen_Posix::shuffle_vectors(a, b, indices);
}

void CodeGen_ARM::visit(const Ramp *op) {
if (target_vscale() != 0 && op->type.is_int_or_uint()) {
if (is_const_zero(op->base) && is_const_one(op->stride)) {
Expand Down
30 changes: 22 additions & 8 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1362,14 +1362,11 @@ void CodeGen_LLVM::codegen(const Stmt &s) {
value = nullptr;
s.accept(this);
}
namespace {

bool is_power_of_two(int x) {
/** Returns true if and only if x is a positive power of two (1, 2, 4, ...). */
bool CodeGen_LLVM::is_power_of_two(int x) const {
    // The bit trick alone accepts 0, and evaluating `x - 1` for the most
    // negative int overflows; rejecting non-positive values first avoids both.
    return x > 0 && (x & (x - 1)) == 0;
}

} // namespace

Type CodeGen_LLVM::upgrade_type_for_arithmetic(const Type &t) const {
if (t.is_bfloat() || (t.is_float() && t.bits() < 32)) {
return Float(32, t.lanes());
Expand Down Expand Up @@ -4139,10 +4136,22 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
} else if (op->is_concat()) {
value = concat_vectors(vecs);
} else {
// Decides whether a vector of `lanes` lanes may be split into `factor`
// equal parts for the shuffle strategies below.
// - Returns false when `lanes` is not a multiple of `factor`.
// - When effective_vscale is 0, divisibility alone is sufficient.
// - Otherwise each part (lanes / factor) must additionally be a multiple of
//   effective_vscale and contain more than one effective_vscale unit.
// NOTE(review): presumably effective_vscale == 0 means fixed-width vectors
// are in use; confirm against its declaration.
auto can_divide_lanes = [this](int lanes, int factor) -> bool {
if (lanes % factor) {
return false;
}
if (effective_vscale == 0) {
return true;
} else {
// Lane count of each part after the split.
int divided = lanes / factor;
return (divided % effective_vscale) == 0 && (divided / effective_vscale) > 1;
}
};

// If the even-numbered indices equal the odd-numbered
// indices, only generate one and then do a self-interleave.
for (int f : {4, 3, 2}) {
bool self_interleave = (op->indices.size() % f) == 0;
bool self_interleave = can_divide_lanes(op->indices.size(), f);
for (size_t i = 0; i < op->indices.size(); i++) {
self_interleave &= (op->indices[i] == op->indices[i - (i % f)]);
}
Expand All @@ -4158,8 +4167,8 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
}

// Check for an interleave of slices (i.e. an in-vector transpose)
bool interleave_of_slices = op->vectors.size() == 1 && (op->indices.size() % f) == 0;
int step = op->type.lanes() / f;
bool interleave_of_slices = can_divide_lanes(op->indices.size(), f) && step != 1 && op->vectors.size() == 1;
for (int i = 0; i < step; i++) {
for (int j = 0; j < f; j++) {
interleave_of_slices &= (op->indices[i * f + j] == j * step + i);
Expand All @@ -4173,13 +4182,14 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
slices.push_back(slice_vector(value, i * step, step));
}
value = interleave_vectors(slices);
return;
}
}
// If the indices form contiguous aligned runs, do the shuffle
// on entire sub-vectors by reinterpreting them as a wider
// type.
for (int f : {8, 4, 2}) {
if (op->type.lanes() % f != 0) {
if (!can_divide_lanes(op->type.lanes(), f)) {
continue;
}

Expand All @@ -4188,7 +4198,7 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
}
bool contiguous = true;
for (const Expr &vec : op->vectors) {
contiguous &= ((vec.type().lanes() % f) == 0);
contiguous &= can_divide_lanes(vec.type().lanes(), f);
}
for (size_t i = 0; i < op->indices.size(); i += f) {
contiguous &= (op->indices[i] % f) == 0;
Expand Down Expand Up @@ -5414,6 +5424,10 @@ llvm::Constant *CodeGen_LLVM::get_splat(int lanes, llvm::Constant *value,
return ConstantVector::getSplat(ec, value);
}

/** Reports whether the LLVM type of `v` is a ScalableVectorType. */
bool CodeGen_LLVM::is_scalable_vector(llvm::Value *v) const {
    llvm::Type *const value_type = v->getType();
    return isa<ScalableVectorType>(value_type);
}

std::string CodeGen_LLVM::mangle_llvm_type(llvm::Type *type) {
std::string type_string = ".";
if (isa<llvm::ScalableVectorType>(type)) {
Expand Down
4 changes: 4 additions & 0 deletions src/CodeGen_LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ class CodeGen_LLVM : public IRVisitor {
/** Shorthand for shuffling a single vector. */
llvm::Value *shuffle_vectors(llvm::Value *v, const std::vector<int> &indices);

bool is_power_of_two(int x) const;

bool is_scalable_vector(llvm::Value *v) const;

/** Go looking for a vector version of a runtime function. Will
* return the best match. Matches in the following order:
*
Expand Down
Loading