Skip to content

Commit

Permalink
Add checks to prevent people from using negative split factors (halid…
Browse files Browse the repository at this point in the history
…e#8076)

* Add checks to prevent people from using negative split factors

Our analysis passes assume that loop maxes are greater than loop mins,
so negative split factors cause sufficient havoc that not even output
bounds queries are safe. These are therefore checked on pipeline entry.

This is a new way for output bounds queries to throw errors (in addition
to the buffer pointers themselves being null, and maybe some buffer
constraints). Testing this, I realized these errors were getting thrown
twice, because the output buffer bounds query in Pipeline::realize was
built around two recursive calls to realize, and both were calling the
custom error handler. In addition to reporting errors in this class
twice, this implies several other inefficiencies, e.g. jit call args
were being prepped twice. I reworked it to be built around two calls to
call_jit_code instead.

Fixes halide#7938

* Add test to cmakelists

* Remove pointless target arg to call_jit_code

It has to be the same as the cached target in the receiving object
anyway
  • Loading branch information
abadams authored Feb 11, 2024
1 parent 22581bf commit 9c3615b
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 28 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,7 @@ SOURCE_FILES = \
AddAtomicMutex.cpp \
AddImageChecks.cpp \
AddParameterChecks.cpp \
AddSplitFactorChecks.cpp \
AlignLoads.cpp \
AllocationBoundsInference.cpp \
ApplySplit.cpp \
Expand Down Expand Up @@ -637,6 +638,7 @@ HEADER_FILES = \
AddAtomicMutex.h \
AddImageChecks.h \
AddParameterChecks.h \
AddSplitFactorChecks.h \
AlignLoads.h \
AllocationBoundsInference.h \
ApplySplit.h \
Expand Down
68 changes: 68 additions & 0 deletions src/AddSplitFactorChecks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#include "AddSplitFactorChecks.h"
#include "Definition.h"
#include "Function.h"
#include "IR.h"
#include "IROperator.h"
#include "Simplify.h"

namespace Halide {
namespace Internal {

namespace {

void check_all_split_factors(const Function &f, const Definition &def, std::vector<Stmt> *stmts) {
const StageSchedule &sched = def.schedule();
for (const Split &split : sched.splits()) {
if (split.split_type != Split::SplitVar) {
continue;
}
if (is_positive_const(split.factor)) {
// Common-case optimization
continue;
}
Expr positive = simplify(split.factor > 0);
if (is_const_one(positive)) {
// We statically proved it
continue;
}
// We need a runtime check that says: if the condition is
// entered, the split factor will be positive. We can still
// assume the pipeline preconditions, because they will be
// checked before this.
std::ostringstream factor_str;
factor_str << split.factor;
Expr error = Call::make(Int(32), "halide_error_split_factor_not_positive",
{f.name(),
split_string(split.old_var, ".").back(),
split_string(split.outer, ".").back(),
split_string(split.inner, ".").back(),
factor_str.str(), split.factor},
Call::Extern);
stmts->push_back(AssertStmt::make(positive, error));
}

for (const auto &s : def.specializations()) {
check_all_split_factors(f, s.definition, stmts);
}
}

} // namespace

// Returns a Block made of one assertion per split factor that could not be
// proven strictly positive, followed by the original statement `s`. Every
// Function in `env` contributes its initial definition first and then each
// of its update definitions, in order.
Stmt add_split_factor_checks(const Stmt &s, const std::map<std::string, Function> &env) {
    std::vector<Stmt> body;

    for (const auto &entry : env) {
        const Function &func = entry.second;

        // Initial definition first ...
        check_all_split_factors(func, func.definition(), &body);

        // ... then every update definition.
        const auto &update_defs = func.updates();
        for (const auto &update_def : update_defs) {
            check_all_split_factors(func, update_def, &body);
        }
    }

    // The incoming statement goes last, after all collected assertions.
    body.push_back(s);
    return Block::make(body);
}

} // namespace Internal
} // namespace Halide
25 changes: 25 additions & 0 deletions src/AddSplitFactorChecks.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef HALIDE_INTERNAL_ADD_SPLIT_FACTOR_CHECKS_H
#define HALIDE_INTERNAL_ADD_SPLIT_FACTOR_CHECKS_H

/** \file
*
* Defines the lowering pass that adds the assertions that all split factors are
* strictly positive.
*/
#include <map>

#include "Expr.h"

namespace Halide {
namespace Internal {

class Function;

/** Insert checks that all split factors are strictly positive.
 *
 * For every Function in \p env, the splits in the schedules of its initial
 * definition, its update definitions, and their specializations are
 * examined. A factor that cannot be proven positive at compile time (for
 * example one that depends on scalar parameters) gets a runtime assertion
 * that reports halide_error_split_factor_not_positive on failure.
 *
 * \param s   The statement to guard.
 * \param env The Functions whose schedules are examined, keyed by name.
 * \return A block containing the generated assertions followed by \p s. */
Stmt add_split_factor_checks(const Stmt &s, const std::map<std::string, Function> &env);

} // namespace Internal
} // namespace Halide

#endif
6 changes: 4 additions & 2 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(HEADER_FILES
AddAtomicMutex.h
AddImageChecks.h
AddParameterChecks.h
AddSplitFactorChecks.h
AlignLoads.h
AllocationBoundsInference.h
ApplySplit.h
Expand All @@ -22,7 +23,7 @@ set(HEADER_FILES
Bounds.h
BoundsInference.h
BoundConstantExtentLoops.h
BoundSmallAllocations.h
BoundSmallAllocations.h
Buffer.h
Callable.h
CanonicalizeGPUVars.h
Expand Down Expand Up @@ -178,6 +179,7 @@ set(SOURCE_FILES
AddAtomicMutex.cpp
AddImageChecks.cpp
AddParameterChecks.cpp
AddSplitFactorChecks.cpp
AlignLoads.cpp
AllocationBoundsInference.cpp
ApplySplit.cpp
Expand Down Expand Up @@ -546,7 +548,7 @@ set_target_properties(Halide PROPERTIES
# Note that we (deliberately) redeclare these versions here, even though the macros
# with identical versions are expected to be defined in source; this allows us to
# ensure that the versions defined between all build systems are identical.
target_compile_definitions(Halide PUBLIC
target_compile_definitions(Halide PUBLIC
HALIDE_VERSION_MAJOR=${Halide_VERSION_MAJOR}
HALIDE_VERSION_MINOR=${Halide_VERSION_MINOR}
HALIDE_VERSION_PATCH=${Halide_VERSION_PATCH})
Expand Down
2 changes: 1 addition & 1 deletion src/Callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ Callable::FailureFn Callable::check_fcci(size_t argc, const FullCallCheckInfo *a

JITFuncCallContext jit_call_context(context, contents->saved_jit_handlers);

int exit_status = contents->jit_cache.call_jit_code(contents->jit_cache.jit_target, argv);
int exit_status = contents->jit_cache.call_jit_code(argv);

// If we're profiling, report runtimes and reset profiler stats.
contents->jit_cache.finish_profiling(context);
Expand Down
4 changes: 2 additions & 2 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ Target JITCache::get_compiled_jit_target() const {
return jit_target;
}

int JITCache::call_jit_code(const Target &target, const void *const *args) {
int JITCache::call_jit_code(const void *const *args) {
#if defined(__has_feature)
#if __has_feature(memory_sanitizer)
user_warning << "MSAN does not support JIT compilers of any sort, and will report "
Expand All @@ -1122,7 +1122,7 @@ int JITCache::call_jit_code(const Target &target, const void *const *args) {
"compilation for Halide code.";
#endif
#endif
if (target.arch == Target::WebAssembly) {
if (get_compiled_jit_target().arch == Target::WebAssembly) {
internal_assert(wasm_module.contents.defined());
return wasm_module.run(args);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ struct JITCache {

Target get_compiled_jit_target() const;

int call_jit_code(const Target &target, const void *const *args);
int call_jit_code(const void *const *args);

void finish_profiling(JITUserContext *context);
};
Expand Down
5 changes: 5 additions & 0 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "AddAtomicMutex.h"
#include "AddImageChecks.h"
#include "AddParameterChecks.h"
#include "AddSplitFactorChecks.h"
#include "AllocationBoundsInference.h"
#include "AsyncProducers.h"
#include "BoundConstantExtentLoops.h"
Expand Down Expand Up @@ -182,6 +183,10 @@ void lower_impl(const vector<Function> &output_funcs,
s = bounds_inference(s, outputs, order, fused_groups, env, func_bounds, t);
log("Lowering after computation bounds inference:", s);

debug(1) << "Asserting that all split factors are positive...\n";
s = add_split_factor_checks(s, env);
log("Lowering after asserting that all split factors are positive:", s);

debug(1) << "Removing extern loops...\n";
s = remove_extern_loops(s);
log("Lowering after removing extern loops:", s);
Expand Down
62 changes: 42 additions & 20 deletions src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,18 @@ Target Pipeline::get_compiled_jit_target() const {
void Pipeline::compile_jit(const Target &target_arg) {
user_assert(defined()) << "Pipeline is undefined\n";

Target target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext);
Target target = target_arg;

if (target.has_unknowns()) {
// If we've already jit-compiled for a specific target, use that.
target = get_compiled_jit_target();
if (target.has_unknowns()) {
// Otherwise get the target from the environment
target = get_jit_target_from_environment();
}
}

target.set_features({Target::JIT, Target::UserContext});

// If we're re-jitting for the same target, we can just keep the old jit module.
if (get_compiled_jit_target() == target) {
Expand Down Expand Up @@ -751,17 +762,37 @@ Realization Pipeline::realize(JITUserContext *context,
bufs.emplace_back(t, nullptr, sizes);
}
}
Realization r(std::move(bufs));
Realization r{std::move(bufs)};

compile_jit(target);
JITUserContext empty_user_context = {};
if (!context) {
context = &empty_user_context;
}
JITFuncCallContext jit_context(context, jit_handlers());
JITCallArgs args(contents->inferred_args.size() + r.size());
RealizationArg arg{r};
prepare_jit_call_arguments(arg, contents->jit_cache.jit_target,
&context, true, args);

// Do an output bounds query if we can. Otherwise just assume the
// output size is good.
int exit_status = 0;
if (!target.has_feature(Target::NoBoundsQuery)) {
realize(context, r, target);
exit_status = call_jit_code(args);
}
for (size_t i = 0; i < r.size(); i++) {
r[i].allocate();
if (exit_status == 0) {
// Make the output allocations
for (size_t i = 0; i < r.size(); i++) {
r[i].allocate();
}
// Do the actual computation
exit_status = call_jit_code(args);
}
// Do the actual computation
realize(context, r, target);

// If we're profiling, report runtimes and reset profiler stats.
contents->jit_cache.finish_profiling(context);
jit_context.finalize(exit_status);

// Crop back to the requested size if necessary
bool needs_crop = false;
Expand Down Expand Up @@ -943,8 +974,8 @@ Pipeline::make_externs_jit_module(const Target &target,
return result;
}

int Pipeline::call_jit_code(const Target &target, const JITCallArgs &args) {
return contents->jit_cache.call_jit_code(target, args.store);
int Pipeline::call_jit_code(const JITCallArgs &args) {
return contents->jit_cache.call_jit_code(args.store);
}

void Pipeline::realize(RealizationArg outputs, const Target &t) {
Expand All @@ -959,15 +990,6 @@ void Pipeline::realize(JITUserContext *context,

debug(2) << "Realizing Pipeline for " << target << "\n";

if (target.has_unknowns()) {
// If we've already jit-compiled for a specific target, use that.
target = get_compiled_jit_target();
if (target.has_unknowns()) {
// Otherwise get the target from the environment
target = get_jit_target_from_environment();
}
}

// We need to make a context for calling the jitted function to
// carry the set of custom handlers. Here's how handlers get
// called when running jitted code:
Expand Down Expand Up @@ -1041,7 +1063,7 @@ void Pipeline::realize(JITUserContext *context,
// exception.

debug(2) << "Calling jitted function\n";
int exit_status = call_jit_code(target, args);
int exit_status = call_jit_code(args);
debug(2) << "Back from jitted function. Exit status was " << exit_status << "\n";

// If we're profiling, report runtimes and reset profiler stats.
Expand Down Expand Up @@ -1111,7 +1133,7 @@ void Pipeline::infer_input_bounds(JITUserContext *context,
}

Internal::debug(2) << "Calling jitted function\n";
int exit_status = call_jit_code(contents->jit_cache.jit_target, args);
int exit_status = call_jit_code(args);
jit_context.finalize(exit_status);
Internal::debug(2) << "Back from jitted function\n";
bool changed = false;
Expand Down
3 changes: 1 addition & 2 deletions src/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ class Pipeline {
private:
Internal::IntrusivePtr<PipelineContents> contents;

// For the three method below, precisely one of the first two args should be non-null
void prepare_jit_call_arguments(RealizationArg &output, const Target &target,
JITUserContext **user_context, bool is_bounds_inference, Internal::JITCallArgs &args_result);

Expand All @@ -160,7 +159,7 @@ class Pipeline {

static AutoSchedulerFn find_autoscheduler(const std::string &autoscheduler_name);

int call_jit_code(const Target &target, const Internal::JITCallArgs &args);
int call_jit_code(const Internal::JITCallArgs &args);

// Get the value of contents->jit_target, but reality-check that the contents
// sensibly match the value. Return Target() if not jitted.
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/HalideRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,10 @@ enum halide_error_code_t {
/** An explicit storage bound provided is too small to store
* all the values produced by the function. */
halide_error_code_storage_bound_too_small = -45,

/** A factor used to split a loop was discovered to be zero or negative at
* runtime. */
halide_error_code_split_factor_not_positive = -46,
};

/** Halide calls the functions below on various error conditions. The
Expand Down Expand Up @@ -1316,6 +1320,8 @@ extern int halide_error_device_dirty_with_no_device_support(void *user_context,
extern int halide_error_storage_bound_too_small(void *user_context, const char *func_name, const char *var_name,
int provided_size, int required_size);
extern int halide_error_device_crop_failed(void *user_context);
extern int halide_error_split_factor_not_positive(void *user_context, const char *func_name, const char *orig, const char *outer, const char *inner, const char *factor_str, int factor);

// @}

/** Optional features a compilation Target can have.
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,4 +291,13 @@ WEAK int halide_error_device_crop_failed(void *user_context) {
return halide_error_code_device_crop_failed;
}

// Error reporter for a split factor that was found to be zero or negative.
//
// Parameters:
//   user_context - passed through to error(), which receives the message.
//   func_name    - name of the func whose schedule contains the split.
//   orig         - name of the variable being split.
//   outer, inner - names of the two variables produced by the split.
//   factor_str   - the split factor expression, printed as text.
//   factor       - the value the factor evaluated to.
//
// Returns halide_error_code_split_factor_not_positive.
WEAK int halide_error_split_factor_not_positive(void *user_context, const char *func_name, const char *orig, const char *outer, const char *inner, const char *factor_str, int factor) {
// The message shows both the factor expression and its evaluated value, and
// suggests clamping the expression with max(..., 1).
error(user_context) << "In schedule for func " << func_name
<< ", the factor used to split the variable " << orig
<< " into " << outer << " and " << inner << " is " << factor_str
<< ". This evaluated to " << factor << ", which is not strictly positive. "
<< "Consider using max(" << factor_str << ", 1) instead.";
return halide_error_code_split_factor_not_positive;
}

} // extern "C"
1 change: 1 addition & 0 deletions src/runtime/runtime_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ extern "C" __attribute__((used)) void *halide_runtime_api_functions[] = {
(void *)&halide_error_param_too_small_u64,
(void *)&halide_error_requirement_failed,
(void *)&halide_error_specialize_fail,
(void *)&halide_error_split_factor_not_positive,
(void *)&halide_error_unaligned_host_ptr,
(void *)&halide_error_storage_bound_too_small,
(void *)&halide_error_device_crop_failed,
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ tests(GROUPS correctness
multiple_outputs.cpp
mux.cpp
narrow_predicates.cpp
negative_split_factors.cpp
nested_tail_strategies.cpp
newtons_method.cpp
non_nesting_extern_bounds_query.cpp
Expand Down
Loading

0 comments on commit 9c3615b

Please sign in to comment.