Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ xcuserdata
# NeoVim + clangd
.cache

# CCLS
.ccls-cache

# Emacs
tags
TAGS
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ SOURCE_FILES = \
IRVisitor.cpp \
JITModule.cpp \
Lambda.cpp \
LegalizeVectors.cpp \
Lerp.cpp \
LICM.cpp \
LLVM_Output.cpp \
Expand Down Expand Up @@ -737,6 +738,7 @@ HEADER_FILES = \
WasmExecutor.h \
JITModule.h \
Lambda.h \
LegalizeVectors.h \
Lerp.h \
LICM.h \
LLVM_Output.h \
Expand Down
32 changes: 17 additions & 15 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ target_sources(
Associativity.h
AsyncProducers.h
AutoScheduleUtils.h
BoundConstantExtentLoops.h
BoundSmallAllocations.h
BoundaryConditions.h
Bounds.h
BoundsInference.h
BoundConstantExtentLoops.h
BoundSmallAllocations.h
Buffer.h
CPlusPlusMangle.h
CSE.h
Callable.h
CanonicalizeGPUVars.h
ClampUnsafeAccesses.h
Expand All @@ -79,18 +81,16 @@ target_sources(
CodeGen_LLVM.h
CodeGen_Metal_Dev.h
CodeGen_OpenCL_Dev.h
CodeGen_Posix.h
CodeGen_PTX_Dev.h
CodeGen_Posix.h
CodeGen_PyTorch.h
CodeGen_Targets.h
CodeGen_Vulkan_Dev.h
CodeGen_WebGPU_Dev.h
CompilerLogger.h
ConciseCasts.h
CPlusPlusMangle.h
ConstantBounds.h
ConstantInterval.h
CSE.h
Debug.h
DebugArguments.h
DebugToFile.h
Expand Down Expand Up @@ -127,6 +127,13 @@ target_sources(
Generator.h
HexagonOffload.h
HexagonOptimize.h
IR.h
IREquality.h
IRMatch.h
IRMutator.h
IROperator.h
IRPrinter.h
IRVisitor.h
ImageParam.h
InferArguments.h
InjectHostDevBufferCopies.h
Expand All @@ -135,19 +142,13 @@ target_sources(
IntegerDivisionTable.h
Interval.h
IntrusivePtr.h
IR.h
IREquality.h
IRMatch.h
IRMutator.h
IROperator.h
IRPrinter.h
IRVisitor.h
JITModule.h
Lambda.h
Lerp.h
LICM.h
LLVM_Output.h
LLVM_Runtime_Linker.h
Lambda.h
LegalizeVectors.h
Lerp.h
LoopCarry.h
LoopPartitioningDirective.h
Lower.h
Expand All @@ -173,8 +174,8 @@ target_sources(
PurifyIndexMath.h
PythonExtensionGen.h
Qualify.h
Random.h
RDom.h
Random.h
Realization.h
RealizationOrder.h
RebaseLoopsToZero.h
Expand Down Expand Up @@ -320,6 +321,7 @@ target_sources(
IRVisitor.cpp
JITModule.cpp
Lambda.cpp
LegalizeVectors.cpp
Lerp.cpp
LICM.cpp
LLVM_Output.cpp
Expand Down
5 changes: 5 additions & 0 deletions src/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ bool should_extract(const Expr &e, bool lift_all) {
return false;
}

if (const Call *c = e.as<Call>()) {
// Calls with side effects should not be moved.
return c->is_pure() || c->call_type == Call::Halide;
}

if (lift_all) {
return true;
}
Expand Down
24 changes: 15 additions & 9 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1188,15 +1188,16 @@ Value *CodeGen_Hexagon::shuffle_vectors(Value *a, Value *b,
create_bitcast(a_call->getArgOperand(1), native_ty),
create_bitcast(a_call->getArgOperand(0), native_ty), indices);
} else if (ShuffleVectorInst *a_shuffle = dyn_cast<ShuffleVectorInst>(a)) {
bool is_identity = true;
for (int i = 0; i < a_elements; i++) {
int mask_i = a_shuffle->getMaskValue(i);
is_identity = is_identity && (mask_i == i || mask_i == -1);
}
if (is_identity) {
return shuffle_vectors(a_shuffle->getOperand(0),
a_shuffle->getOperand(1), indices);
std::vector<int> new_indices(indices.size());
for (size_t i = 0; i < indices.size(); i++) {
if (indices[i] != -1) {
new_indices[i] = a_shuffle->getMaskValue(indices[i]);
} else {
new_indices[i] = -1;
}
}
return shuffle_vectors(a_shuffle->getOperand(0),
a_shuffle->getOperand(1), new_indices);
}
}

Expand Down Expand Up @@ -1518,7 +1519,11 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector<int> &indices) {
vector<int> i8_indices(indices.size() * replicate);
for (size_t i = 0; i < indices.size(); i++) {
for (int j = 0; j < replicate; j++) {
i8_indices[i * replicate + j] = indices[i] * replicate + j;
if (indices[i] == -1) {
i8_indices[i * replicate + j] = -1; // Replicate the don't-care.
} else {
i8_indices[i * replicate + j] = indices[i] * replicate + j;
}
}
}
Value *result = vdelta(i8_lut, i8_indices);
Expand Down Expand Up @@ -1558,6 +1563,7 @@ Value *CodeGen_Hexagon::vdelta(Value *lut, const vector<int> &indices) {
Value *ret = nullptr;
for (int i = 0; i < lut_elements; i += native_elements) {
Value *lut_i = slice_vector(lut, i, native_elements);
internal_assert(get_vector_num_elements(lut_i->getType()) == native_elements);
vector<int> indices_i(native_elements);
vector<Constant *> mask(native_elements);
bool all_used = true;
Expand Down
3 changes: 2 additions & 1 deletion src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4968,10 +4968,11 @@ Value *CodeGen_LLVM::shuffle_vectors(Value *a, Value *b,
}
// Check for type identity *after* normalizing to fixed vectors
internal_assert(a->getType() == b->getType());
int elements_a = get_vector_num_elements(a->getType());
vector<Constant *> llvm_indices(indices.size());
for (size_t i = 0; i < llvm_indices.size(); i++) {
if (indices[i] >= 0) {
internal_assert(indices[i] < get_vector_num_elements(a->getType()) * 2);
internal_assert(indices[i] < elements_a * 2) << indices[i] << " " << elements_a * 2;
llvm_indices[i] = ConstantInt::get(i32_t, indices[i]);
} else {
// Only let -1 be undef.
Expand Down
24 changes: 7 additions & 17 deletions src/CodeGen_Vulkan_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2055,31 +2055,21 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
debug(3) << "\n";

if (arg_ids.size() == 1) {

// 1 argument, just do a simple assignment via a cast
SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]);
builder.update_id(result_id);

} else if (arg_ids.size() == 2) {

// 2 arguments, use a composite insert to update even and odd indices
uint32_t even_idx = 0;
uint32_t odd_idx = 1;
SpvFactory::Indices even_indices;
SpvFactory::Indices odd_indices;
for (int i = 0; i < op_lanes; ++i) {
even_indices.push_back(even_idx);
odd_indices.push_back(odd_idx);
even_idx += 2;
odd_idx += 2;
// 2 arguments, use vector-shuffle with logical indices indexing into (vec1[0], vec1[1], ..., vec2[0], vec2[1], ...)
SpvFactory::Indices logical_indices;
for (int i = 0; i < arg_lanes; ++i) {
logical_indices.push_back(uint32_t(i));
logical_indices.push_back(uint32_t(i + arg_lanes));
}

SpvId type_id = builder.declare_type(op->type);
SpvId value_id = builder.declare_null_constant(op->type);
SpvId partial_id = builder.reserve_id(SpvResultId);
SpvId result_id = builder.reserve_id(SpvResultId);
builder.append(SpvFactory::composite_insert(type_id, partial_id, arg_ids[0], value_id, even_indices));
builder.append(SpvFactory::composite_insert(type_id, result_id, arg_ids[1], partial_id, odd_indices));
builder.append(SpvFactory::vector_shuffle(type_id, result_id, arg_ids[0], arg_ids[1], logical_indices));
builder.update_id(result_id);

} else {
Expand Down Expand Up @@ -2109,7 +2099,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) {
} else if (op->is_extract_element()) {
int idx = op->indices[0];
internal_assert(idx >= 0);
internal_assert(idx <= op->vectors[0].type().lanes());
internal_assert(idx < op->vectors[0].type().lanes());
if (op->vectors[0].type().is_vector()) {
SpvFactory::Indices indices = {(uint32_t)idx};
SpvId type_id = builder.declare_type(op->type);
Expand Down
24 changes: 20 additions & 4 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ class Deinterleaver : public IRGraphMutator {
} else {

Type t = op->type.with_lanes(new_lanes);
internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes)
<< "Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane
<< " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly."
<< " Deinterleaver probably recursed too deep into types of different lane count.";
if (external_lets.contains(op->name) &&
starting_lane == 0 &&
lane_stride == 2) {
Expand Down Expand Up @@ -393,8 +397,12 @@ class Deinterleaver : public IRGraphMutator {
int index = indices.front();
for (const auto &i : op->vectors) {
if (index < i.type().lanes()) {
ScopedValue<int> lane(starting_lane, index);
return mutate(i);
if (i.type().lanes() == op->type.lanes()) {
ScopedValue<int> scoped_starting_lane(starting_lane, index);
return mutate(i);
} else {
return Shuffle::make(op->vectors, indices);
}
}
index -= i.type().lanes();
}
Expand All @@ -406,10 +414,18 @@ class Deinterleaver : public IRGraphMutator {
};

/** Run a Deinterleaver configured with (starting_lane, lane_stride,
 * new_lanes) over the given expression and return the simplified result.
 *
 * All lets are substituted into the expression before mutation, and
 * common subexpressions are re-extracted afterwards. `lets` is forwarded
 * unchanged to the Deinterleaver.
 *
 * Internal assertions verify that the mutation preserved the type code
 * and that the result has
 * ceil((original lanes - starting_lane) / lane_stride) lanes. */
Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
    debug(3) << "Deinterleave "
             << "(start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << "): "
             << e << " of Type: " << e.type() << "\n";

    // Remember the incoming type so the result can be validated below.
    const Type original_type = e.type();

    e = substitute_in_all_lets(e);
    Deinterleaver d(starting_lane, lane_stride, new_lanes, lets);
    e = d.mutate(e);
    e = common_subexpression_elimination(e);

    // Sanity-check the outcome: same type code, and a lane count equal to
    // the rounded-up quotient of the remaining lanes by the stride.
    const Type final_type = e.type();
    const int expected_lanes = (original_type.lanes() + lane_stride - starting_lane - 1) / lane_stride;
    internal_assert(original_type.code() == final_type.code())
        << "Underlying types not identical after deinterleaving.";
    internal_assert(expected_lanes == final_type.lanes())
        << "Number of lanes incorrect after deinterleaving: " << final_type.lanes()
        << " while expected was " << expected_lanes << ".";

    return simplify(e);
}

Expand All @@ -420,12 +436,12 @@ Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {

Expr extract_even_lanes(const Expr &e, const Scope<> &lets) {
internal_assert(e.type().lanes() % 2 == 0);
return deinterleave(e, 0, 2, (e.type().lanes() + 1) / 2, lets);
return deinterleave(e, 0, 2, e.type().lanes() / 2, lets);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the starting lane is zero, and there are 3 lanes total, shouldn't there be two lanes in the result?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, there's an assert that the number of lanes is even, never mind.

}

Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) {
internal_assert(e.type().lanes() % 3 == 0);
return deinterleave(e, lane, 3, (e.type().lanes() + 2) / 3, lets);
return deinterleave(e, lane, 3, e.type().lanes() / 3, lets);
}

} // namespace
Expand Down
4 changes: 2 additions & 2 deletions src/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Internal {

Expr Cast::make(Type t, Expr v) {
internal_assert(v.defined()) << "Cast of undefined\n";
internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths\n";
internal_assert(t.lanes() == v.type().lanes()) << "Cast may not change vector widths: " << v << " of type " << v.type() << " cannot be cast to " << t << "\n";

Cast *node = new Cast;
node->type = t;
Expand Down Expand Up @@ -281,7 +281,7 @@ Expr Ramp::make(Expr base, Expr stride, int lanes) {

Expr Broadcast::make(Expr value, int lanes) {
internal_assert(value.defined()) << "Broadcast of undefined\n";
internal_assert(lanes != 1) << "Broadcast of lanes 1\n";
internal_assert(lanes != 1) << "Broadcast over 1 lane is not a broadcast\n";

Broadcast *node = new Broadcast;
node->type = value.type().with_lanes(lanes * value.type().lanes());
Expand Down
3 changes: 2 additions & 1 deletion src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,8 @@ Expr random_int(Expr seed = Expr());

/** Create an Expr that prints out its value whenever it is
* evaluated. It also prints out everything else in the arguments
* list, separated by spaces. This can include string literals. */
* list, separated by spaces. This can include string literals.
* Evaluates to the first argument passed. */
//@{
Expr print(const std::vector<Expr> &values);

Expand Down
Loading
Loading