Skip to content

Commit

Permalink
gpu: jit: optimize tanh post-op, add compatibility mode
Browse files Browse the repository at this point in the history
  • Loading branch information
hidefromkgb authored and karturov committed Oct 4, 2022
1 parent 10f0d0a commit 6224dc6
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 29 deletions.
8 changes: 5 additions & 3 deletions src/gpu/jit/conv/conv_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3625,7 +3625,8 @@ class ir_to_ngen_t : public ir_visitor_t {
ir_to_ngen_t(ir_kernel_t<hw> *host, const expr_binding_t &expr_binding)
: host_(host)
, expr_binding_(expr_binding)
, simd_size_(host->getSIMD()) {}
, simd_size_(host->getSIMD())
, eu_count_(host->hw_cfg_.eu_count()) {}

~ir_to_ngen_t() {
#ifdef GEN_CONV_DEBUG
Expand Down Expand Up @@ -4170,8 +4171,8 @@ class ir_to_ngen_t : public ir_visitor_t {
auto &data_op = eltwise_t::arg_data(args);
auto data_rd = data_op.reg_buf_data();

jit_eltwise_injector_f32<hw> inj(
host_, func.alg_kind, func.alpha, func.beta, func.scale);
jit_eltwise_injector_f32<hw> inj(host_, func.alg_kind, func.alpha,
func.beta, func.scale, eu_count_);
auto scratch = scope.alloc_range(inj.preferred_scratch_regs());
inj.set_scratch(scratch);
inj.prepare();
Expand Down Expand Up @@ -4233,6 +4234,7 @@ class ir_to_ngen_t : public ir_visitor_t {
ir_kernel_t<hw> *host_;
expr_binding_t expr_binding_;
int simd_size_;
int eu_count_;

std::vector<ngen::Label> loop_end_labels_;

Expand Down
4 changes: 3 additions & 1 deletion src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12429,8 +12429,10 @@ bool gemm_kernel_generator_t<hw>::gemmUpdateC(
// Prepare postop injector if configured.
GRFRange postOpScratch;
if (problem.hasPostOp()) {
// No EU count information available.
const int eu_count = 0;
postOpInjector.reset(new Injector(this, problem.Ts.get_dnnl_type(),
problem.post_ops, GRFRange(), problem.postOpFwd));
problem.post_ops, eu_count, GRFRange(), problem.postOpFwd));
if (!postOpInjector) stub();

postOpScratch = state.ra.try_alloc_range(
Expand Down
3 changes: 3 additions & 0 deletions src/gpu/jit/gemm/xe_hp_systolic_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ status_t xe_hp_systolic_gemm_t::init_compute_old(engine_t *engine) {
cfg.tile_n = pd()->unroll_n();
cfg.global_3x_buf = (cfg.tile_n == 32);

auto compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
cfg.eu_count = compute_engine->device_info()->eu_count();

if (pd()->with_c_zero_points())
cfg.co_type = cfg.c_type;
else if (pd()->with_bias()) {
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,7 @@ void xehp_systolic_gemm_kernel_t<hw>::generate() {

if (cfg.have_post_op()) {
auto inj_ptr = new injector_t(this, data_type::f32, cfg.post_ops,
upost_op_scratch, cfg.post_op_is_fwd);
cfg.eu_count, upost_op_scratch, cfg.post_op_is_fwd);
assert(inj_ptr);
post_op_injector.reset(inj_ptr);
}
Expand Down
1 change: 1 addition & 0 deletions src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class xehp_systolic_gemm_kernel_t : public jit_generator<hw> {
bool c_packed = false;
bool batch = false;
bool emulate64 = (hw == ngen::HW::XeHPG);
int eu_count = 0;

int tile_m = 32;
int tile_n = 48;
Expand Down
101 changes: 82 additions & 19 deletions src/gpu/jit/jit_eltwise_injector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ int jit_eltwise_injector_f32<hw>::min_scratch_regs() {
case eltwise_hardswish: return 1;
case eltwise_log: return 0;
case eltwise_logsigmoid: return 1;
case eltwise_mish: return 2;
case eltwise_mish: return 4;
case eltwise_pow: return 1;
case eltwise_relu:
case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 1;
Expand All @@ -51,7 +51,7 @@ int jit_eltwise_injector_f32<hw>::min_scratch_regs() {
case eltwise_square: return 0;
case eltwise_swish: return 1;
case eltwise_tanh:
case eltwise_tanh_use_dst_for_bwd: return 1;
case eltwise_tanh_use_dst_for_bwd: return 2;
case eltwise_round: return 0;
case eltwise_linear: return 0;
case eltwise_bounded_relu:
Expand Down Expand Up @@ -124,8 +124,8 @@ int jit_eltwise_injector_f32<hw>::max_batch_size() {
case eltwise_logsigmoid:
case eltwise_pow:
case eltwise_soft_relu:
case eltwise_tanh:
case eltwise_swish: return ss;
case eltwise_tanh:
case eltwise_mish:
case eltwise_gelu_erf: return ss / min_scratch_regs();
case eltwise_gelu_tanh: return ss & ~1;
Expand Down Expand Up @@ -166,7 +166,8 @@ int jit_eltwise_injector_f32<hw>::phase_count(alg_kind_t alg) {
case eltwise_soft_relu: return 9;
case eltwise_swish: return 5;
case eltwise_tanh:
case eltwise_tanh_use_dst_for_bwd: return 7;
case eltwise_tanh_use_dst_for_bwd:
return (use_tanh_compat()) ? 9 : 6;
case eltwise_linear: return 2;
case eltwise_bounded_relu:
case eltwise_clip:
Expand Down Expand Up @@ -255,17 +256,43 @@ void jit_eltwise_injector_f32<hw>::square_compute_fwd(

template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::tanh_compute_fwd(
int simd, const ngen::GRF &r, int phase, int off) {
const float log2e = 1.442695f; // log_2(e)
auto a = scratch_[off].f();
int simd, const ngen::GRF &r, int phase, int off, int batch) {
const float log2e = 1.44269502162933349609375f; // log_2(e)
auto one_half = scratch_[0].f(7);
auto a = scratch_[off + batch].f();
switch (phase) {
case 0: h->mul(simd, a, abs(r), 2 * log2e); break;
case 0: h->mul(simd, a, abs(r), 2.f * log2e); break;
case 1: h->exp(simd, a, a); break;
case 2: h->add(simd, a, a, 1.f); break;
case 2: h->mad(simd, a, one_half, a, one_half); break;
case 3: h->inv(simd, a, a); break;
case 4: h->mul(simd, a, a, 2.f); break;
case 5: h->add(simd, a, -a, 1.f); break;
case 6: h->csel(simd | ge | f0[0], r, a, -a, r); break;
case 4: h->add(simd, a, -a, 1.f); break;
case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break;
default: assert(!"invalid phase");
}
}

template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::tanh_compute_fwd_compat(
        int simd, const ngen::GRF &r, int phase, int off, int batch) {
    // This approximation of tanh(x) avoids the math.exp instruction, which
    // appears to be faulty on DG2-128; the exact formula is:
    //   R = max(min(0.0519867 * x * ((x^2 + k)^2 + l)
    //                              / ((x^2 + m)^2 + n), 1), -1)
    // Both absolute and relative errors are < 7e-5 for all real x.
    //
    // The coefficients k, l, m, n are preloaded into fixed scratch slots
    // f(4)..f(7) by tanh_prepare_fwd_compat(); `a` is this batch lane's
    // private temporary.
    auto k = scratch_[0].f(4);
    auto l = scratch_[0].f(5);
    auto m = scratch_[0].f(6);
    auto n = scratch_[0].f(7);
    auto a = scratch_[off + batch].f();
    switch (phase) {
        case 0: h->mad(simd, a, m, r, r); break; // a = x^2 + m
        case 1: h->mad(simd, a, n, a, a); break; // a = (x^2 + m)^2 + n
        case 2: h->inv(simd, a, a); break; // a = 1 / ((x^2 + m)^2 + n)
        // a = x / ((x^2 + m)^2 + n) — also carries the sign of x for phase 8
        case 3: h->mul(simd, a, a, r); break;
        case 4: h->mad(simd, r, k, r, r); break; // r = x^2 + k
        case 5: h->mad(simd, r, l, r, r); break; // r = (x^2 + k)^2 + l
        // rounded from 0.051986694f; float literal keeps it single-precision
        case 6: h->mul(simd, r, r, 0.0519867f); break; // 0.051986694f
        // |tanh(x)|, saturated into [0, 1] (the min/max clamp of the formula)
        case 7: h->mul(simd | sat, r, r, abs(a)); break;
        // restore the sign of x from a: r = (a >= 0) ? r : -r
        case 8: h->csel(simd | ge | f0[0], r, r, -r, a); break;
        default: assert(!"invalid phase");
    }
}
Expand Down Expand Up @@ -385,6 +412,24 @@ void jit_eltwise_injector_f32<hw>::clip_prepare_bwd() {
h->mov(1, pos_inf, pos_inf_imm);
}

template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd() {
    // Stage the constant 0.5f once in scratch slot f(7); the tanh forward
    // phases read that slot back as their `one_half` operand.
    h->mov(1, scratch_[0].f(7), 0.5f);
}

template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd_compat() {
    // Preload the four coefficients of the rational tanh approximation into
    // fixed scratch slots f(4)..f(7), where tanh_compute_fwd_compat reads
    // them back as k, l, m, n. Each trailing comment records the
    // full-precision value the float literal was rounded from.
    auto k = scratch_[0].f(4);
    auto l = scratch_[0].f(5);
    auto m = scratch_[0].f(6);
    auto n = scratch_[0].f(7);
    h->mov(1, k, 77.0954f); // 77.095392909578f
    h->mov(1, l, -4435.55f); // -4435.54623970169f
    h->mov(1, m, 17.06396f); // 17.06396485f
    h->mov(1, n, -212.7724f); // -212.772646402036f
}

template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::abs_compute_bwd(
int simd, const ngen::GRF &r, int phase) {
Expand Down Expand Up @@ -580,15 +625,20 @@ void jit_eltwise_injector_f32<hw>::logsigmoid_compute_fwd(
template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::mish_compute_fwd(
int simd, const ngen::GRF &r, int phase, int off, int batch) {
auto temp = scratch_[off].f();
auto temp2 = scratch_[off + batch].f();
auto temp = scratch_[off + batch].f();
auto temp2 = scratch_[off + 2 * batch].f();
const int srelu_phases = phase_count(alg_kind::eltwise_soft_relu);
const int tanh_phases = phase_count(alg_kind::eltwise_tanh);
// note tanh_compute_fwd will trash temp
// note tanh_compute_fwd_* clobbers scratch_[off] and scratch_[off + batch]
if (phase < srelu_phases)
soft_relu_compute_fwd_inner(simd, r, temp, temp2, phase, off);
if (phase >= srelu_phases && phase < srelu_phases + tanh_phases)
tanh_compute_fwd(simd, temp2, phase - srelu_phases, off);
if (phase >= srelu_phases && phase < srelu_phases + tanh_phases) {
if (use_tanh_compat())
tanh_compute_fwd_compat(
simd, temp2, phase - srelu_phases, off, batch);
else
tanh_compute_fwd(simd, temp2, phase - srelu_phases, off, batch);
}
if (phase == srelu_phases + tanh_phases) h->mul(simd, r, r, temp2);
if (phase > srelu_phases + tanh_phases) assert(!"invalid phase");
}
Expand Down Expand Up @@ -687,7 +737,11 @@ void jit_eltwise_injector_f32<hw>::compute(const ngen::GRFRange &regs) {
break;
case eltwise_tanh:
case eltwise_tanh_use_dst_for_bwd:
tanh_compute_fwd(simd, base, phase, ii);
if (use_tanh_compat())
tanh_compute_fwd_compat(
simd, base, phase, ii, batch);
else
tanh_compute_fwd(simd, base, phase, ii, batch);
break;
case eltwise_round:
round_compute_fwd(simd, base);
Expand Down Expand Up @@ -752,7 +806,16 @@ void jit_eltwise_injector_f32<hw>::prepare() {
assert(scratch_.getLen() >= min_scratch_regs());

if (is_fwd_) {
/* nothing to do */
switch (alg_) {
case eltwise_mish:
case eltwise_tanh:
if (use_tanh_compat())
tanh_prepare_fwd_compat();
else
tanh_prepare_fwd();
break;
default: break;
}
} else {
switch (alg_) {
case eltwise_relu: relu_prepare_bwd(); break;
Expand Down
19 changes: 17 additions & 2 deletions src/gpu/jit/jit_eltwise_injector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ inline bool jit_eltwise_injector_f32_is_supported(alg_kind_t alg) {
template <gpu_gen_t hw>
struct jit_eltwise_injector_f32 {
jit_eltwise_injector_f32(jit_generator<hw> *host, alg_kind_t alg,
float alpha, float beta, float scale,
float alpha, float beta, float scale, int eu_count,
const ngen::GRFRange &scratch = ngen::GRFRange(),
bool is_fwd = true)
: alg_(alg)
, alpha_(alpha)
, beta_(beta)
, scale_(scale)
, is_fwd_(is_fwd)
, eu_count_(eu_count)
, h(host)
, scratch_(scratch) {

Expand All @@ -76,16 +77,27 @@ struct jit_eltwise_injector_f32 {
const float scale_;
const bool is_fwd_;

const int eu_count_;

jit_generator<hw> *h;

ngen::GRFRange scratch_;

// True iff the injector was built for exactly this GPU generation and
// EU count.
bool is_gpu(ngen::HW arg_hw, int arg_eu_count) const {
    if (hw != arg_hw) return false;
    return eu_count_ == arg_eu_count;
}
bool use_tanh_compat() const {
return is_gpu(ngen::HW::XeHPG, 96) || is_gpu(ngen::HW::XeHPG, 128);
}

int max_batch_size();
int phase_count(alg_kind_t alg);

void relu_prepare_bwd();
void abs_prepare_bwd();
void clip_prepare_bwd();
void tanh_prepare_fwd();
void tanh_prepare_fwd_compat();

void relu_zero_ns_compute_fwd(int simd, const ngen::GRF &r);
void relu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
Expand All @@ -112,7 +124,10 @@ struct jit_eltwise_injector_f32 {
void square_compute_fwd(int simd, const ngen::GRF &r);
void round_compute_fwd(int simd, const ngen::GRF &r);
void swish_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
void tanh_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
void tanh_compute_fwd(
int simd, const ngen::GRF &r, int phase, int off, int batch);
void tanh_compute_fwd_compat(
int simd, const ngen::GRF &r, int phase, int off, int batch);
void linear_compute_fwd(int simd, const ngen::GRF &r, int phase);
void clip_compute_fwd(
int simd, const ngen::GRF &r, int phase, float alpha, float beta);
Expand Down
7 changes: 4 additions & 3 deletions src/gpu/jit/jit_post_op_injector.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -47,7 +47,7 @@ inline bool jit_post_op_injector_is_supported(
template <gpu_gen_t hw>
struct jit_post_op_injector {
jit_post_op_injector(jit_generator<hw> *host, data_type_t accumulator_type,
const post_ops_t &post_ops,
const post_ops_t &post_ops, int eu_count,
const ngen::GRFRange &scratch = ngen::GRFRange(),
bool is_fwd = true)
: post_ops_(post_ops), is_fwd_(is_fwd), scratch_(scratch) {
Expand All @@ -57,7 +57,8 @@ struct jit_post_op_injector {
const auto &po = post_ops.entry_[idx];
if (po.is_eltwise())
workers_.emplace_back(host, po.eltwise.alg, po.eltwise.alpha,
po.eltwise.beta, po.eltwise.scale, scratch, is_fwd);
po.eltwise.beta, po.eltwise.scale, eu_count, scratch,
is_fwd);
}
}

Expand Down

0 comments on commit 6224dc6

Please sign in to comment.