diff --git a/src/gpu/jit/conv/conv_kernel.hpp b/src/gpu/jit/conv/conv_kernel.hpp index 3c212494936..e9f8f6cead3 100644 --- a/src/gpu/jit/conv/conv_kernel.hpp +++ b/src/gpu/jit/conv/conv_kernel.hpp @@ -3625,7 +3625,8 @@ class ir_to_ngen_t : public ir_visitor_t { ir_to_ngen_t(ir_kernel_t *host, const expr_binding_t &expr_binding) : host_(host) , expr_binding_(expr_binding) - , simd_size_(host->getSIMD()) {} + , simd_size_(host->getSIMD()) + , eu_count_(host->hw_cfg_.eu_count()) {} ~ir_to_ngen_t() { #ifdef GEN_CONV_DEBUG @@ -4170,8 +4171,8 @@ class ir_to_ngen_t : public ir_visitor_t { auto &data_op = eltwise_t::arg_data(args); auto data_rd = data_op.reg_buf_data(); - jit_eltwise_injector_f32 inj( - host_, func.alg_kind, func.alpha, func.beta, func.scale); + jit_eltwise_injector_f32 inj(host_, func.alg_kind, func.alpha, + func.beta, func.scale, eu_count_); auto scratch = scope.alloc_range(inj.preferred_scratch_regs()); inj.set_scratch(scratch); inj.prepare(); @@ -4233,6 +4234,7 @@ class ir_to_ngen_t : public ir_visitor_t { ir_kernel_t *host_; expr_binding_t expr_binding_; int simd_size_; + int eu_count_; std::vector loop_end_labels_; diff --git a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp index 89879a0bbf6..0dfc39a4372 100644 --- a/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp +++ b/src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp @@ -12429,8 +12429,10 @@ bool gemm_kernel_generator_t::gemmUpdateC( // Prepare postop injector if configured. GRFRange postOpScratch; if (problem.hasPostOp()) { + // No EU count information available. + const int eu_count = 0; postOpInjector.reset(new Injector(this, problem.Ts.get_dnnl_type(), - problem.post_ops, GRFRange(), problem.postOpFwd)); + problem.post_ops, eu_count, GRFRange(), problem.postOpFwd)); if (!postOpInjector) stub(); postOpScratch = state.ra.try_alloc_range( diff --git a/src/gpu/jit/gemm/xe_hp_systolic_gemm.cpp b/src/gpu/jit/gemm/xe_hp_systolic_gemm.cpp index 7e8691b9b9d..cc980223e71 100644 --- a/src/gpu/jit/gemm/xe_hp_systolic_gemm.cpp +++ b/src/gpu/jit/gemm/xe_hp_systolic_gemm.cpp @@ -486,6 +486,9 @@ status_t xe_hp_systolic_gemm_t::init_compute_old(engine_t *engine) { cfg.tile_n = pd()->unroll_n(); cfg.global_3x_buf = (cfg.tile_n == 32); + auto compute_engine = utils::downcast(engine); + cfg.eu_count = compute_engine->device_info()->eu_count(); + if (pd()->with_c_zero_points()) cfg.co_type = cfg.c_type; else if (pd()->with_bias()) { diff --git a/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.cpp b/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.cpp index a6408c6c386..3773cac5604 100644 --- a/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.cpp +++ b/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.cpp @@ -1449,7 +1449,7 @@ void xehp_systolic_gemm_kernel_t::generate() { if (cfg.have_post_op()) { auto inj_ptr = new injector_t(this, data_type::f32, cfg.post_ops, - upost_op_scratch, cfg.post_op_is_fwd); + cfg.eu_count, upost_op_scratch, cfg.post_op_is_fwd); assert(inj_ptr); post_op_injector.reset(inj_ptr); } diff --git a/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.hpp b/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.hpp index 0b1c8d44286..848c0cf4f0d 100644 --- a/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.hpp +++ b/src/gpu/jit/gemm/xe_hp_systolic_gemm_kernel.hpp @@ -55,6 +55,7 @@ class xehp_systolic_gemm_kernel_t : public jit_generator { bool c_packed = false; bool batch = false; bool emulate64 = (hw == ngen::HW::XeHPG); + int eu_count = 0; int tile_m = 32; int tile_n = 48; diff --git a/src/gpu/jit/jit_eltwise_injector.cpp b/src/gpu/jit/jit_eltwise_injector.cpp index 7ca5d1d3199..faa5523513c 100644 --- a/src/gpu/jit/jit_eltwise_injector.cpp +++ b/src/gpu/jit/jit_eltwise_injector.cpp @@ -40,7 +40,7 @@ int jit_eltwise_injector_f32::min_scratch_regs() { case eltwise_hardswish: return 1; case eltwise_log: return 0; case eltwise_logsigmoid: return 1; - case eltwise_mish: return 2; + case eltwise_mish: return 4; case eltwise_pow: return 1; case eltwise_relu: case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 1; @@ -51,7 +51,7 @@ int jit_eltwise_injector_f32::min_scratch_regs() { case eltwise_square: return 0; case eltwise_swish: return 1; case eltwise_tanh: - case eltwise_tanh_use_dst_for_bwd: return 1; + case eltwise_tanh_use_dst_for_bwd: return 2; case eltwise_round: return 0; case eltwise_linear: return 0; case eltwise_bounded_relu: @@ -124,8 +124,8 @@ int jit_eltwise_injector_f32::max_batch_size() { case eltwise_logsigmoid: case eltwise_pow: case eltwise_soft_relu: - case eltwise_tanh: case eltwise_swish: return ss; + case eltwise_tanh: case eltwise_mish: case eltwise_gelu_erf: return ss / min_scratch_regs(); case eltwise_gelu_tanh: return ss & ~1; @@ -166,7 +166,8 @@ int jit_eltwise_injector_f32::phase_count(alg_kind_t alg) { case eltwise_soft_relu: return 9; case eltwise_swish: return 5; case eltwise_tanh: - case eltwise_tanh_use_dst_for_bwd: return 7; + case eltwise_tanh_use_dst_for_bwd: + return (use_tanh_compat()) ? 9 : 6; case eltwise_linear: return 2; case eltwise_bounded_relu: case eltwise_clip: @@ -255,17 +256,43 @@ void jit_eltwise_injector_f32::square_compute_fwd( template void jit_eltwise_injector_f32::tanh_compute_fwd( - int simd, const ngen::GRF &r, int phase, int off) { - const float log2e = 1.442695f; // log_2(e) - auto a = scratch_[off].f(); + int simd, const ngen::GRF &r, int phase, int off, int batch) { + const float log2e = 1.44269502162933349609375f; // log_2(e) + auto one_half = scratch_[0].f(7); + auto a = scratch_[off + batch].f(); switch (phase) { - case 0: h->mul(simd, a, abs(r), 2 * log2e); break; + case 0: h->mul(simd, a, abs(r), 2.f * log2e); break; case 1: h->exp(simd, a, a); break; - case 2: h->add(simd, a, a, 1.f); break; + case 2: h->mad(simd, a, one_half, a, one_half); break; case 3: h->inv(simd, a, a); break; - case 4: h->mul(simd, a, a, 2.f); break; - case 5: h->add(simd, a, -a, 1.f); break; - case 6: h->csel(simd | ge | f0[0], r, a, -a, r); break; + case 4: h->add(simd, a, -a, 1.f); break; + case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break; + default: assert(!"invalid phase"); + } +} + +template +void jit_eltwise_injector_f32::tanh_compute_fwd_compat( + int simd, const ngen::GRF &r, int phase, int off, int batch) { + // This approximation of tanh(x) does not use the math.exp instruction + // that seems to be faulty on DG2-128; the exact formula is as follows: + // R = max(min(0.0519867*x*((x^2 + k)^2 + l)/((x^2 + m)^2 + n), 1), -1) + // Both absolute and relative errors are <7*10^-5 \forall x \in \mathbb R + auto k = scratch_[0].f(4); + auto l = scratch_[0].f(5); + auto m = scratch_[0].f(6); + auto n = scratch_[0].f(7); + auto a = scratch_[off + batch].f(); + switch (phase) { + case 0: h->mad(simd, a, m, r, r); break; + case 1: h->mad(simd, a, n, a, a); break; + case 2: h->inv(simd, a, a); break; + case 3: h->mul(simd, a, a, r); break; + case 4: h->mad(simd, r, k, r, r); break; + case 5: h->mad(simd, r, l, r, r); break; + case 6: h->mul(simd, r, r, 0.0519867f); break; // 0.051986694f + case 7: h->mul(simd | sat, r, r, abs(a)); break; + case 8: h->csel(simd | ge | f0[0], r, r, -r, a); break; default: assert(!"invalid phase"); } } @@ -385,6 +412,24 @@ void jit_eltwise_injector_f32::clip_prepare_bwd() { h->mov(1, pos_inf, pos_inf_imm); } +template +void jit_eltwise_injector_f32::tanh_prepare_fwd() { + auto one_half = scratch_[0].f(7); + h->mov(1, one_half, 0.5f); +} + +template +void jit_eltwise_injector_f32::tanh_prepare_fwd_compat() { + auto k = scratch_[0].f(4); + auto l = scratch_[0].f(5); + auto m = scratch_[0].f(6); + auto n = scratch_[0].f(7); + h->mov(1, k, 77.0954f); // 77.095392909578f + h->mov(1, l, -4435.55f); // -4435.54623970169f + h->mov(1, m, 17.06396f); // 17.06396485f + h->mov(1, n, -212.7724f); // -212.772646402036f +} + template void jit_eltwise_injector_f32::abs_compute_bwd( int simd, const ngen::GRF &r, int phase) { @@ -580,15 +625,20 @@ void jit_eltwise_injector_f32::logsigmoid_compute_fwd( template void jit_eltwise_injector_f32::mish_compute_fwd( int simd, const ngen::GRF &r, int phase, int off, int batch) { - auto temp = scratch_[off].f(); - auto temp2 = scratch_[off + batch].f(); + auto temp = scratch_[off + batch].f(); + auto temp2 = scratch_[off + 2 * batch].f(); const int srelu_phases = phase_count(alg_kind::eltwise_soft_relu); const int tanh_phases = phase_count(alg_kind::eltwise_tanh); - // note tanh_compute_fwd will trash temp + // note tanh_compute_fwd_* clobbers scratch_[off] and scratch_[off + batch] if (phase < srelu_phases) soft_relu_compute_fwd_inner(simd, r, temp, temp2, phase, off); - if (phase >= srelu_phases && phase < srelu_phases + tanh_phases) - tanh_compute_fwd(simd, temp2, phase - srelu_phases, off); + if (phase >= srelu_phases && phase < srelu_phases + tanh_phases) { + if (use_tanh_compat()) + tanh_compute_fwd_compat( + simd, temp2, phase - srelu_phases, off, batch); + else + tanh_compute_fwd(simd, temp2, phase - srelu_phases, off, batch); + } if (phase == srelu_phases + tanh_phases) h->mul(simd, r, r, temp2); if (phase > srelu_phases + tanh_phases) assert(!"invalid phase"); } @@ -687,7 +737,11 @@ void jit_eltwise_injector_f32::compute(const ngen::GRFRange ®s) { break; case eltwise_tanh: case eltwise_tanh_use_dst_for_bwd: - tanh_compute_fwd(simd, base, phase, ii); + if (use_tanh_compat()) + tanh_compute_fwd_compat( + simd, base, phase, ii, batch); + else + tanh_compute_fwd(simd, base, phase, ii, batch); break; case eltwise_round: round_compute_fwd(simd, base); @@ -752,7 +806,16 @@ void jit_eltwise_injector_f32::prepare() { assert(scratch_.getLen() >= min_scratch_regs()); if (is_fwd_) { - /* nothing to do */ + switch (alg_) { + case eltwise_mish: + case eltwise_tanh: + if (use_tanh_compat()) + tanh_prepare_fwd_compat(); + else + tanh_prepare_fwd(); + break; + default: break; + } } else { switch (alg_) { case eltwise_relu: relu_prepare_bwd(); break; diff --git a/src/gpu/jit/jit_eltwise_injector.hpp b/src/gpu/jit/jit_eltwise_injector.hpp index f3040266219..08c4a3905bc 100644 --- a/src/gpu/jit/jit_eltwise_injector.hpp +++ b/src/gpu/jit/jit_eltwise_injector.hpp @@ -46,7 +46,7 @@ inline bool jit_eltwise_injector_f32_is_supported(alg_kind_t alg) { template struct jit_eltwise_injector_f32 { jit_eltwise_injector_f32(jit_generator *host, alg_kind_t alg, - float alpha, float beta, float scale, + float alpha, float beta, float scale, int eu_count, const ngen::GRFRange &scratch = ngen::GRFRange(), bool is_fwd = true) : alg_(alg) @@ -54,6 +54,7 @@ struct jit_eltwise_injector_f32 { , beta_(beta) , scale_(scale) , is_fwd_(is_fwd) + , eu_count_(eu_count) , h(host) , scratch_(scratch) { @@ -76,16 +77,27 @@ struct jit_eltwise_injector_f32 { const float scale_; const bool is_fwd_; + const int eu_count_; + jit_generator *h; ngen::GRFRange scratch_; + bool is_gpu(ngen::HW arg_hw, int arg_eu_count) const { + return (hw == arg_hw) && (eu_count_ == arg_eu_count); + } + bool use_tanh_compat() const { + return is_gpu(ngen::HW::XeHPG, 96) || is_gpu(ngen::HW::XeHPG, 128); + } + int max_batch_size(); int phase_count(alg_kind_t alg); void relu_prepare_bwd(); void abs_prepare_bwd(); void clip_prepare_bwd(); + void tanh_prepare_fwd(); + void tanh_prepare_fwd_compat(); void relu_zero_ns_compute_fwd(int simd, const ngen::GRF &r); void relu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); @@ -112,7 +124,10 @@ struct jit_eltwise_injector_f32 { void square_compute_fwd(int simd, const ngen::GRF &r); void round_compute_fwd(int simd, const ngen::GRF &r); void swish_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); - void tanh_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); + void tanh_compute_fwd( + int simd, const ngen::GRF &r, int phase, int off, int batch); + void tanh_compute_fwd_compat( + int simd, const ngen::GRF &r, int phase, int off, int batch); void linear_compute_fwd(int simd, const ngen::GRF &r, int phase); void clip_compute_fwd( int simd, const ngen::GRF &r, int phase, float alpha, float beta); diff --git a/src/gpu/jit/jit_post_op_injector.hpp b/src/gpu/jit/jit_post_op_injector.hpp index 6331c8e5dcd..7ede10885cb 100644 --- a/src/gpu/jit/jit_post_op_injector.hpp +++ b/src/gpu/jit/jit_post_op_injector.hpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright 2021 Intel Corporation + * Copyright 2021-2022 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ inline bool jit_post_op_injector_is_supported( template struct jit_post_op_injector { jit_post_op_injector(jit_generator *host, data_type_t accumulator_type, - const post_ops_t &post_ops, + const post_ops_t &post_ops, int eu_count, const ngen::GRFRange &scratch = ngen::GRFRange(), bool is_fwd = true) : post_ops_(post_ops), is_fwd_(is_fwd), scratch_(scratch) { @@ -57,7 +57,8 @@ struct jit_post_op_injector { const auto &po = post_ops.entry_[idx]; if (po.is_eltwise()) workers_.emplace_back(host, po.eltwise.alg, po.eltwise.alpha, - po.eltwise.beta, po.eltwise.scale, scratch, is_fwd); + po.eltwise.beta, po.eltwise.scale, eu_count, scratch, + is_fwd); } }