Skip to content

Commit

Permalink
gpu: sycl: eltwise: Reduce size of kernel arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
sgeor255 committed May 27, 2024
1 parent e93576d commit a5a53e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
20 changes: 9 additions & 11 deletions src/gpu/sycl/eltwise_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ struct eltwise_fwd_kernel_vec_t {

auto src = load_float_value(
src_md().data_type(), src_ptr(), src_offset);
auto dst = load_float_value(
dst_md().data_type(), dst_ptr(), src_offset);
auto dst = load_float_value(dst_dt(), dst_ptr(), src_offset);

dim_t data_l_off = (((n * conf_.c + c) * conf_.d + d) * conf_.h + h)
* conf_.w
Expand All @@ -70,7 +69,7 @@ struct eltwise_fwd_kernel_vec_t {

dst = compute_alg_n(src, conf_.alpha, conf_.beta, conf_.alg_kind);
dst = conf_.post_ops.apply(dst, post_po_sr);
store_float_value(dst_md().data_type(), dst, dst_ptr(), src_offset);
store_float_value(dst_dt(), dst, dst_ptr(), src_offset);
};

for (dim_t blk_idx = 0; blk_idx < conf_.block_size; blk_idx++) {
Expand All @@ -94,7 +93,7 @@ struct eltwise_fwd_kernel_vec_t {

private:
// Read-only accessors over the kernel configuration (conf_) of the forward
// eltwise kernel.

// Returns the source memory descriptor stored in conf_.
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
// Returns the full destination memory descriptor stored in conf_.
// NOTE(review): this is the pre-change row of the diff; the same commit
// replaces sycl_eltwise_conf_t::dst_md with dst_dt/dst_dims/dst_ndims.
const xpu::sycl::md_t &dst_md() const { return conf_.dst_md; }
// Returns only the destination data type stored in conf_. It is what the
// kernel passes to load_float_value/store_float_value for dst.
const data_type_t &dst_dt() const { return conf_.dst_dt; }

// Untyped pointer to the source data, obtained from src_.get_pointer().
// The element type is supplied separately (src_md().data_type()) when the
// pointer is handed to load_float_value.
void *src_ptr() const { return src_.get_pointer(); }
// Untyped pointer to the destination data, obtained from dst_.get_pointer().
// Used for both the load and the store of dst in the kernel body.
void *dst_ptr() const { return dst_.get_pointer(); }
Expand Down Expand Up @@ -235,7 +234,7 @@ struct eltwise_fwd_kernel_vec_t {
auto src1_desc = conf_.binary_src_arr[idx];

const auto off = get_binary_src1_off(
src1_desc, offset, dst_md().dims(), dst_md().ndims());
src1_desc, offset, conf_.dst_dims, conf_.dst_ndims);

auto dst = load_float_value(
src1_desc.data_type(), bin_src_op.get_pointer(), off);
Expand Down Expand Up @@ -323,23 +322,22 @@ struct eltwise_bwd_kernel_vec_t {
for (dim_t i = 0; i < conf_.block_size; i++) {
dim_t idx = base_idx + i;
if (idx < conf_.wk_size) {
auto diff_src = load_float_value(
diff_src_md().data_type(), diff_src_ptr(), idx);
auto diff_src
= load_float_value(diff_src_dt(), diff_src_ptr(), idx);
auto src = load_float_value(
src_md().data_type(), src_ptr(), idx);

auto dst = compute_alg_n(
diff_src, src, conf_.alpha, conf_.beta, conf_.alg_kind);
store_float_value(
diff_dst_md().data_type(), dst, diff_dst_ptr(), idx);
store_float_value(diff_dst_dt(), dst, diff_dst_ptr(), idx);
}
}
}

private:
// Read-only accessors over the kernel configuration (conf_) of the backward
// eltwise kernel.

// Returns the source memory descriptor stored in conf_.
const xpu::sycl::md_t &src_md() const { return conf_.src_md; }
// Return the full diff_src / diff_dst memory descriptors stored in conf_.
// NOTE(review): these two are the pre-change rows of the diff; the same
// commit replaces the corresponding conf fields with diff_src_dt/diff_dst_dt.
const xpu::sycl::md_t &diff_src_md() const { return conf_.diff_src_md; }
const xpu::sycl::md_t &diff_dst_md() const { return conf_.diff_dst_md; }
// Return only the data types of diff_src / diff_dst stored in conf_. The
// backward loop uses them solely as the type tag for load_float_value and
// store_float_value, indexing with the flat element index.
const data_type_t &diff_src_dt() const { return conf_.diff_src_dt; }
const data_type_t &diff_dst_dt() const { return conf_.diff_dst_dt; }

// Untyped pointer to the source data, obtained from src_.get_pointer().
void *src_ptr() const { return src_.get_pointer(); }
// Untyped pointer to the diff_src data, obtained from diff_src_.get_pointer();
// paired with diff_src_dt() when read through load_float_value.
void *diff_src_ptr() const { return diff_src_.get_pointer(); }
Expand Down
10 changes: 7 additions & 3 deletions src/gpu/sycl/ref_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ using namespace impl::sycl;
status_t ref_sycl_eltwise_fwd_t::pd_t::init_conf() {
conf_ = sycl_eltwise_conf_t();
conf_.src_md = xpu::sycl::md_t(src_md());
conf_.dst_md = xpu::sycl::md_t(dst_md());
xpu::sycl::md_t sycl_dst_md(dst_md());
conf_.dst_dt = sycl_dst_md.data_type();
utils::array_copy(
conf_.dst_dims, sycl_dst_md.dims(), xpu::sycl::md_t::max_dims);
conf_.dst_ndims = sycl_dst_md.ndims();
conf_.wk_size = memory_desc_wrapper(src_md()).nelems();
conf_.alg_kind = desc()->alg_kind;
conf_.alpha = desc()->alpha;
Expand Down Expand Up @@ -90,8 +94,8 @@ status_t ref_sycl_eltwise_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
status_t ref_sycl_eltwise_bwd_t::pd_t::init_conf() {
conf_ = sycl_eltwise_conf_t();
conf_.src_md = xpu::sycl::md_t(data_md(0));
conf_.diff_src_md = xpu::sycl::md_t(diff_src_md());
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());
conf_.diff_src_dt = diff_src_md()->data_type;
conf_.diff_dst_dt = diff_dst_md()->data_type;
conf_.block_size = 16;
conf_.wg_size = 32;
conf_.wk_size = memory_desc_wrapper(data_md(0)).nelems();
Expand Down
8 changes: 5 additions & 3 deletions src/gpu/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,11 @@ struct sycl_binary_conf_t {
struct sycl_eltwise_conf_t {
prop_kind_t prop_kind;
xpu::sycl::md_t src_md;
xpu::sycl::md_t dst_md;
xpu::sycl::md_t diff_src_md;
xpu::sycl::md_t diff_dst_md;
data_type_t dst_dt;
xpu::sycl::md_t::dims32_t dst_dims;
xpu::sycl::md_t::dim32_t dst_ndims;
data_type_t diff_src_dt;
data_type_t diff_dst_dt;
alg_kind_t alg_kind;
float alpha;
float beta;
Expand Down

0 comments on commit a5a53e4

Please sign in to comment.