Skip to content

Commit

Permalink
gpu: jit: do not rewrite 64-bit exprs after overflow fix pass
Browse files Browse the repository at this point in the history
  • Loading branch information
echeresh committed Apr 14, 2023
1 parent 31ac0e0 commit 12d5743
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 14 deletions.
1 change: 1 addition & 0 deletions src/gpu/jit/conv/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,7 @@ void conv_ir_builder_t::build() {
cfg_.reserved_regs());
stmt_ = split_shuffle(stmt_, ir_ctx);
stmt_ = fixup_if_conditions(stmt_, ir_ctx);
stmt_ = optimize_int64_exprs(stmt_, ir_ctx);
stmt_ = fix_int32_overflow(stmt_, ir_ctx);
stmt_ = eliminate_common_subexprs(
stmt_, ir_ctx, cfg_.reserved_regs(), cfg_.slm().gmem_bufs());
Expand Down
30 changes: 30 additions & 0 deletions src/gpu/jit/pass/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "gpu/jit/ir/message.hpp"
#include "gpu/jit/ir/reorder.hpp"
#include "gpu/jit/pass/simplify.hpp"
#include "gpu/jit/utils/trace.hpp"

namespace dnnl {
Expand Down Expand Up @@ -176,6 +177,35 @@ stmt_t fixup_if_conditions(const stmt_t &s, ir_context_t &ir_ctx) {
return ret;
}

// IR mutator that post-processes addition expressions: after the default
// mutation of a node, any resulting binary `_add` is passed through
// simplify_64_bit_add(). All other nodes are returned as produced by the base
// mutator.
class int64_expr_optimizer_t : public ir_mutator_t {
public:
    // Generates one _mutate() override per type listed by
    // HANDLE_EXPR_IR_OBJECTS() (presumably the expression IR node types), each
    // forwarding to mutate_expr().
#define HANDLE_IR_OBJECT(type) \
    object_t _mutate(const type &obj) override { return mutate_expr(obj); }

    HANDLE_EXPR_IR_OBJECTS()

#undef HANDLE_IR_OBJECT

private:
    // Runs the base-class mutation first, then rewrites the result only when
    // it is a binary addition.
    template <typename T>
    object_t mutate_expr(const T &obj) {
        auto mutated = ir_mutator_t::_mutate(obj);

        auto *as_binary = mutated.template as_ptr<binary_op_t>();
        const bool is_add = (as_binary != nullptr)
                && (as_binary->op_kind == op_kind_t::_add);
        if (!is_add) return mutated;

        mutated = simplify_64_bit_add(mutated);
        return mutated;
    }
};

// Pass entry point: applies int64_expr_optimizer_t to the statement `s`,
// reports the result through the tracing helpers under the name
// "optimize_int64_exprs", and returns the mutated statement.
stmt_t optimize_int64_exprs(const stmt_t &s, ir_context_t &ir_ctx) {
    trace_start();
    int64_expr_optimizer_t optimizer;
    auto optimized = optimizer.mutate(s);
    trace_pass("optimize_int64_exprs", optimized, ir_ctx);
    return optimized;
}

} // namespace jit
} // namespace gpu
} // namespace impl
Expand Down
4 changes: 4 additions & 0 deletions src/gpu/jit/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ stmt_t split_wide_stores(const stmt_t &s, ir_context_t &ir_ctx);
// if (bcast8(cond)) { ... }
stmt_t fixup_if_conditions(const stmt_t &s, ir_context_t &ir_ctx);

// Rewrites mixed 64-bit/32-bit expressions to reduce 64-bit arithmetic.
// Potential overflow is ignored and must be checked/fixed by further passes.
stmt_t optimize_int64_exprs(const stmt_t &s, ir_context_t &ir_ctx);

} // namespace jit
} // namespace gpu
} // namespace impl
Expand Down
31 changes: 18 additions & 13 deletions src/gpu/jit/pass/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1521,18 +1521,6 @@ expr_t reorder_nary_add_args(const expr_t &e, bool x64_first) {
return nary_op_t::make(nary_op->op_kind, new_args);
}

// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
// arithmetic. Example:
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
// After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add]
class _64_bit_add_optimizer_t : public nary_op_mutator_t {
public:
    // Mutates the N-ary node with the base mutator, then reorders the
    // arguments of the result via reorder_nary_add_args().
    object_t _mutate(const nary_op_t &obj) override {
        return reorder_nary_add_args(
                nary_op_mutator_t::_mutate(obj), /*x64_first=*/false);
    }
};

// Simplifies using the N-ary form.
expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) {
auto e = _e;
Expand All @@ -1545,13 +1533,30 @@ expr_t simplify_with_nary(const expr_t &_e, const constraint_set_t &cset) {
e = int_div_mod_expander_t(cset).mutate(e);
e = common_factor_simplifier_t().mutate(e);
e = int_div_mod_range_simplifier_t(cset).mutate(e);
e = _64_bit_add_optimizer_t().mutate(e);

e = nary_op_back_transform(e);

return e;
}

// N-ary mutator used by simplify_64_bit_add(): every N-ary node is first
// mutated by the base class and the result is then handed to
// reorder_nary_add_args() with x64_first = false.
class _64_bit_add_optimizer_t : public nary_op_mutator_t {
public:
    object_t _mutate(const nary_op_t &obj) override {
        return reorder_nary_add_args(
                nary_op_mutator_t::_mutate(obj), /*x64_first=*/false);
    }
};

// Rewrites additions in `_e` in three steps: convert to the canonical N-ary
// form, apply _64_bit_add_optimizer_t to reorder the addition arguments, and
// convert the result back from the N-ary form.
expr_t simplify_64_bit_add(const expr_t &_e) {
    auto canonical = nary_op_canonicalize(_e);
    auto reordered = _64_bit_add_optimizer_t().mutate(canonical);
    return nary_op_back_transform(reordered);
}

class stmt_simplifier_t : public ir_mutator_t {
public:
stmt_simplifier_t(const constraint_set_t &cset) : cset_(cset) {}
Expand Down
8 changes: 7 additions & 1 deletion src/gpu/jit/pass/simplify.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022 Intel Corporation
* Copyright 2022-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,6 +43,12 @@ expr_t simplify_rewrite_with_ternary(const expr_t &e, bool recursive = true);
// Example: (c0 + x) op c1 -> x op (c1 - c0)
expr_t simplify_cmp_move_const_to_rhs(const expr_t &e);

// Rewrites addition with mixed 64-bit/32-bit expressions to reduce 64-bit
// arithmetic. Example:
// Before: ((x.s64 + y.s32) + z.s32) [two 64-bit add]
// After: ((y.s32 + z.s32) + x.s64) [one 32-bit add and one 64-bit add]
expr_t simplify_64_bit_add(const expr_t &e);

// Reduces left and right hand sides of an expression.
// Example: A * x < A * B -> x < B (if A > 0).
expr_t simplify_cmp_reduce_lhs_rhs(const expr_t &e);
Expand Down

0 comments on commit 12d5743

Please sign in to comment.