Skip to content

Commit

Permalink
gpu: sycl: resampling: Reduce argument size (#1939)
Browse files Browse the repository at this point in the history
Co-authored-by: Denis Samoilov <[email protected]>
  • Loading branch information
sgeor255 and densamoilov committed Jun 3, 2024
1 parent 21a14cd commit 1ea366e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 61 deletions.
26 changes: 3 additions & 23 deletions src/gpu/sycl/ref_resampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,9 @@ namespace sycl {
status_t ref_resampling_fwd_t::pd_t::init_conf() {
conf_ = sycl_resampling_conf_t();

conf_.src_dt = src_md(0)->data_type;
conf_.dst_dt = dst_md()->data_type;

conf_.block_size = 16;
conf_.wg_size = 32;

conf_.MB = MB();
conf_.C = C();
conf_.ID = ID();
conf_.IH = IH();
conf_.IW = IW();
conf_.OD = OD();
conf_.OH = OH();
conf_.OW = OW();

for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
conf_.dst_dims[i] = dst_md()->dims[i];
}
Expand All @@ -56,6 +44,9 @@ status_t ref_resampling_fwd_t::pd_t::init_conf() {
conf_.alg = desc()->alg_kind;
const auto *att = attr();
const auto &attr_po = att->post_ops_;
if (attr_po.len() > sycl_post_ops_t::max_post_ops) {
return dnnl_unimplemented;
}
conf_.po_len = attr_po.len();

for (auto i = 0; i < attr_po.len(); ++i) {
Expand Down Expand Up @@ -113,8 +104,6 @@ status_t ref_resampling_bwd_t::pd_t::init_conf() {
conf_.diff_src_md = xpu::sycl::md_t(diff_src_md(0));
conf_.diff_dst_md = xpu::sycl::md_t(diff_dst_md());

conf_.src_dt = src_md(0)->data_type;
conf_.dst_dt = dst_md()->data_type;
conf_.block_size = 16;
conf_.wg_size = 32;
conf_.dst_ndims = dst_md()->ndims;
Expand All @@ -125,15 +114,6 @@ status_t ref_resampling_bwd_t::pd_t::init_conf() {
conf_.n_thr = n_wgs * conf_.wg_size;
conf_.alg = desc()->alg_kind;

conf_.MB = MB();
conf_.C = C();
conf_.ID = ID();
conf_.IH = IH();
conf_.IW = IW();
conf_.OD = OD();
conf_.OH = OH();
conf_.OW = OW();

for (int i = 0; i < DNNL_MAX_NDIMS; i++) {
conf_.dst_dims[i] = dst_md()->dims[i];
}
Expand Down
62 changes: 38 additions & 24 deletions src/gpu/sycl/resampling_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,20 @@ struct resampling_kernel_fwd_vec_t {

void operator()(::sycl::nd_item<1> item) const {
size_t ithr = item.get_group(0) * conf_.wg_size + item.get_local_id();
dim_t MB = conf_.MB;
dim_t C = conf_.C;

dim_t ID = conf_.ID;
dim_t IH = conf_.IH;
dim_t IW = conf_.IW;
const auto &src_ndims = conf_.src_md.ndims();
const auto &src_dims = conf_.src_md.dims();
dim_t MB = src_dims[0];
dim_t C = src_dims[1];
dim_t ID = src_ndims >= 5 ? src_dims[src_ndims - 3] : 1;
dim_t IH = src_ndims >= 4 ? src_dims[src_ndims - 2] : 1;
dim_t IW = src_ndims >= 3 ? src_dims[src_ndims - 1] : 1;

dim_t OD = conf_.OD;
dim_t OH = conf_.OH;
dim_t OW = conf_.OW;
const auto &dst_ndims = conf_.dst_md.ndims();
const auto &dst_dims = conf_.dst_md.dims();
dim_t OD = dst_ndims >= 5 ? dst_dims[dst_ndims - 3] : 1;
dim_t OH = dst_ndims >= 4 ? dst_dims[dst_ndims - 2] : 1;
dim_t OW = dst_ndims >= 3 ? dst_dims[dst_ndims - 1] : 1;

auto lin_interp = [&](float c0, float c1, float w) {
return c0 * w + c1 * (1 - w);
Expand Down Expand Up @@ -229,16 +233,21 @@ struct resampling_kernel_bwd_vec_t {
void operator()(::sycl::nd_item<1> item) const {

size_t ithr = item.get_group(0) * conf_.wg_size + item.get_local_id();
dim_t MB = conf_.MB;
dim_t C = conf_.C;

dim_t ID = conf_.ID;
dim_t IH = conf_.IH;
dim_t IW = conf_.IW;
const auto &diff_src_ndims = conf_.diff_src_md.ndims();
const auto &diff_src_dims = conf_.diff_src_md.dims();
dim_t MB = diff_src_dims[0];
dim_t C = diff_src_dims[1];

dim_t OD = conf_.OD;
dim_t OH = conf_.OH;
dim_t OW = conf_.OW;
dim_t ID = diff_src_ndims >= 5 ? diff_src_dims[diff_src_ndims - 3] : 1;
dim_t IH = diff_src_ndims >= 4 ? diff_src_dims[diff_src_ndims - 2] : 1;
dim_t IW = diff_src_ndims >= 3 ? diff_src_dims[diff_src_ndims - 1] : 1;

const auto &diff_dst_ndims = conf_.diff_dst_md.ndims();
const auto &diff_dst_dims = conf_.diff_dst_md.dims();
dim_t OD = diff_dst_ndims >= 5 ? diff_dst_dims[diff_dst_ndims - 3] : 1;
dim_t OH = diff_dst_ndims >= 4 ? diff_dst_dims[diff_dst_ndims - 2] : 1;
dim_t OW = diff_dst_ndims >= 3 ? diff_dst_dims[diff_dst_ndims - 1] : 1;

const dim_t work_amount = MB * C * ID * IH * IW;
if (work_amount == 0) return;
Expand Down Expand Up @@ -308,16 +317,21 @@ struct resampling_kernel_bwd_vec1_t {

void operator()(::sycl::nd_item<1> item) const {
size_t ithr = item.get_group(0) * conf_.wg_size + item.get_local_id();
dim_t MB = conf_.MB;
dim_t C = conf_.C;

dim_t ID = conf_.ID;
dim_t IH = conf_.IH;
dim_t IW = conf_.IW;
const auto &diff_src_ndims = conf_.diff_src_md.ndims();
const auto &diff_src_dims = conf_.diff_src_md.dims();
dim_t MB = diff_src_dims[0];
dim_t C = diff_src_dims[1];

dim_t ID = diff_src_ndims >= 5 ? diff_src_dims[diff_src_ndims - 3] : 1;
dim_t IH = diff_src_ndims >= 4 ? diff_src_dims[diff_src_ndims - 2] : 1;
dim_t IW = diff_src_ndims >= 3 ? diff_src_dims[diff_src_ndims - 1] : 1;

dim_t OD = conf_.OD;
dim_t OH = conf_.OH;
dim_t OW = conf_.OW;
const auto &diff_dst_ndims = conf_.diff_dst_md.ndims();
const auto &diff_dst_dims = conf_.diff_dst_md.dims();
dim_t OD = diff_dst_ndims >= 5 ? diff_dst_dims[diff_dst_ndims - 3] : 1;
dim_t OH = diff_dst_ndims >= 4 ? diff_dst_dims[diff_dst_ndims - 2] : 1;
dim_t OW = diff_dst_ndims >= 3 ? diff_dst_dims[diff_dst_ndims - 1] : 1;

const dim_t work_amount = MB * C * ID * IH * IW;
if (work_amount == 0) return;
Expand Down
17 changes: 3 additions & 14 deletions src/gpu/sycl/sycl_primitive_conf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,24 +148,13 @@ struct sycl_reorder_conf_t {
};

struct sycl_resampling_conf_t {
dim_t MB;
dim_t C;
dim_t ID;
dim_t IH;
dim_t IW;
dim_t OD;
dim_t OH;
dim_t OW;
dims_t dst_dims;
int dst_ndims;
int po_len;
size_t work_amount;

data_type_t src_dt;
data_type_t dst_dt;

xpu::sycl::md_t src_md;
xpu::sycl::md_t src1_md[8];
xpu::sycl::md_t src1_md[sycl_post_ops_t::max_post_ops];
xpu::sycl::md_t dst_md;
xpu::sycl::md_t diff_src_md;
xpu::sycl::md_t diff_dst_md;
Expand All @@ -174,7 +163,6 @@ struct sycl_resampling_conf_t {
float src_scale;
bool do_scale_src;
int broadcast_dims[xpu::sycl::md_t::max_dims];
int ndims;
bool is_tensor_op;

int block_size;
Expand Down Expand Up @@ -314,7 +302,8 @@ struct sycl_lrn_conf_t {

struct sycl_pooling_conf_t {
xpu::sycl::md_t src_md;
xpu::sycl::md_t src1_md[8];
// The size "5" is lower than DNNL_MAX_NDIMS because only 5 dimension formats are supported.
xpu::sycl::md_t src1_md[5];
xpu::sycl::md_t dst_md;
xpu::sycl::md_t ws_md;
xpu::sycl::md_t diff_src_md;
Expand Down

0 comments on commit 1ea366e

Please sign in to comment.