Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 17 additions & 26 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -547,40 +547,31 @@ class PartitionLoops : public IRMutator {
(op->partition_policy == Partition::Auto && in_tail)) {
return IRMutator::visit(op);
}
const auto [loop, mutated] = visit_for(op);
user_assert(op->partition_policy != Partition::Always || mutated)
<< "Loop Partition Policy is set to " << op->partition_policy
<< " for " << op->name << ", but no loop partitioning was performed.";
return loop;
}

std::tuple<Stmt, bool> visit_for(const For *op) {
bool mutated = false;
Stmt body = op->body;

// A struct that upon destruction will check if the current For was partitioned
// and error out if it wasn't when the schedule demanded it.
struct ErrorIfNotMutated {
const For *op;
bool must_mutate;
bool mutated{false};
ErrorIfNotMutated(const For *op, bool must_mutate)
: op(op), must_mutate(must_mutate) {
}
~ErrorIfNotMutated() {
if (must_mutate && !mutated) {
user_error << "Loop Partition Policy is set to " << op->partition_policy
<< " for " << op->name << ", but no loop partitioning was performed.";
}
}
} mutation_checker{op, op->partition_policy == Partition::Always};

ScopedValue<bool> old_in_gpu_loop(in_gpu_loop, in_gpu_loop || is_gpu(op->for_type));

// If we're inside a GPU kernel and the body contains thread
// barriers or warp shuffles, it's not safe to partition loops.
if (in_gpu_loop && contains_warp_synchronous_logic(op)) {
return IRMutator::visit(op);
return {IRMutator::visit(op), mutated};
}

// Find simplifications in this loop body
FindSimplifications finder(op->name);
body.accept(&finder);

if (finder.simplifications.empty()) {
return IRMutator::visit(op);
return {IRMutator::visit(op), mutated};
}

debug(3) << "\n\n**** Partitioning loop over " << op->name << "\n";
Expand Down Expand Up @@ -776,13 +767,13 @@ class PartitionLoops : public IRMutator {
prologue = For::make(op->name, op->min, min_steady - op->min,
op->for_type, op->partition_policy, op->device_api, prologue);
stmt = Block::make(prologue, stmt);
mutation_checker.mutated = true;
mutated = true;
}
if (make_epilogue) {
epilogue = For::make(op->name, max_steady, op->min + op->extent - max_steady,
op->for_type, op->partition_policy, op->device_api, epilogue);
stmt = Block::make(stmt, epilogue);
mutation_checker.mutated = true;
mutated = true;
}
} else {
// For parallel for loops we could use a Fork node here,
Expand All @@ -801,15 +792,15 @@ class PartitionLoops : public IRMutator {
stmt = simpler_body;
if (make_epilogue && make_prologue && equal(prologue, epilogue)) {
stmt = IfThenElse::make(min_steady <= loop_var && loop_var < max_steady, stmt, prologue);
mutation_checker.mutated = true;
mutated = true;
} else {
if (make_epilogue) {
stmt = IfThenElse::make(loop_var < max_steady, stmt, epilogue);
mutation_checker.mutated = true;
mutated = true;
}
if (make_prologue) {
stmt = IfThenElse::make(loop_var < min_steady, prologue, stmt);
mutation_checker.mutated = true;
mutated = true;
}
}
stmt = For::make(op->name, op->min, op->extent, op->for_type, op->partition_policy, op->device_api, stmt);
Expand All @@ -833,14 +824,14 @@ class PartitionLoops : public IRMutator {
if (can_prove(epilogue_val <= prologue_val)) {
// The steady state is empty. I've made a huge
// mistake. Try to partition a loop further in.
return IRMutator::visit(op);
return {IRMutator::visit(op), mutated};
}

debug(3) << "Partition loop.\n"
<< "Old: " << Stmt(op) << "\n"
<< "New: " << stmt << "\n";

return stmt;
return {stmt, mutated};
}
};

Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tests(GROUPS correctness
async_order.cpp
autodiff.cpp
bad_likely.cpp
bad_partition_always_throws.cpp
bit_counting.cpp
bits_known.cpp
bitwise_ops.cpp
Expand Down
31 changes: 31 additions & 0 deletions test/correctness/bad_partition_always_throws.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "Halide.h"
using namespace Halide;

// Schedules a trivial Func with Partition::Always and expects realization to
// raise a CompileError whose message reports that no loop partitioning was
// performed. Returns 0 on success (or when the test is skipped because
// exceptions are unavailable) and 1 otherwise.
int main(int argc, char **argv) {
#ifdef HALIDE_WITH_EXCEPTIONS
    // The substring that must appear somewhere in the reported error.
    constexpr std::string_view expected_msg =
        "Loop Partition Policy is set to Always for f.s0.x, "
        "but no loop partitioning was performed.";

    try {
        Func f("f");
        Var x("x");
        f(x) = 0;
        f.partition(x, Partition::Always);
        f.realize({10});
    } catch (const CompileError &e) {
        const std::string_view msg = e.what();
        if (msg.find(expected_msg) != std::string_view::npos) {
            printf("Success!\n");
            return 0;
        }
        // An error was raised, but not the one this test is looking for.
        std::cerr << "Expected error containing (" << expected_msg << "), but got (" << msg << ")\n";
        return 1;
    }

    // Reaching this point means realize() completed without throwing.
    printf("Did not see any exception!\n");
    return 1;
#else
    printf("[SKIP] bad_partition_always_throws requires exceptions\n");
    return 0;
#endif
}
4 changes: 3 additions & 1 deletion test/correctness/bounds_of_pure_intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ int main(int argc, char **argv) {
Scope<Interval> scope;
scope.push(p2.name(), Interval{p2_min, p2_max});

for (int limit = 1; limit < 500; limit++) {
// This test uses a lot of stack space, especially on ASAN, where we don't
// do any stack switching (see Util.cpp). Don't push this number too far.
for (int limit = 1; limit < 100; limit++) {
Expr e1 = p1, e2 = p2;
for (int i = 0; i < limit; i++) {
e1 = e1 * p1 + (i + 1);
Expand Down
Loading