Skip to content

Commit

Permalink
gpu: ocl: switch eltwise to use dim_t data type
Browse files (browse the repository at this point in the history)
  • Loading branch information
rjoursler authored and vpirogov committed May 12, 2023
1 parent 36bf079 commit 6ce52eb
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 48 deletions.
40 changes: 20 additions & 20 deletions src/gpu/ocl/gen9_eltwise.cl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2023 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,15 +22,15 @@

KERNEL_ATTR
__kernel void gen9_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
int nelems, float alpha, float beta) {
const uint grsize = get_local_size(0);
const uint grid = get_group_id(0);
const uint sgid = get_sub_group_id();
const uint lid = get_sub_group_local_id();
dim_t nelems, float alpha, float beta) {
const dim_t grsize = get_local_size(0);
const dim_t grid = get_group_id(0);
const dim_t sgid = get_sub_group_id();
const dim_t lid = get_sub_group_local_id();

const uint gid = get_global_id(0);
const dim_t gid = get_global_id(0);

ptrdiff_t offset
dim_t offset
= (grid * grsize + sgid * get_max_sub_group_size()) * VECT_DT_N;

// grsize is a multiple of 16, SIMD is 16 -> offset mod 16 = 0
Expand All @@ -40,15 +40,15 @@ __kernel void gen9_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
__global BLOCK_DATA_T *write_pos = (__global BLOCK_DATA_T *)dst + offset;

VECT_DATA_T val;
const uint nel_per_read = SIMD * VECT_DT_N;
const int nel_per_read = SIMD * VECT_DT_N;

// READ
if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) {
val = AS_VECT_DATA_T(VECT_BLOCK_READ(read_pos));

} else {
// read data in the same access pattern block_reads would
uint pos = offset + lid;
dim_t pos = offset + lid;
for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) {
val[i] = src[pos];
pos += SIMD;
Expand All @@ -66,7 +66,7 @@ __kernel void gen9_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
VECT_BLOCK_WRITE(write_pos, AS_VECT_BLOCK_DATA_T(val));

} else {
uint pos = offset + lid;
dim_t pos = offset + lid;
for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) {
dst[pos] = val[i];
pos += SIMD;
Expand All @@ -76,13 +76,13 @@ __kernel void gen9_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,

KERNEL_ATTR
__kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src,
__global DATA_T *diff_dst, int nelems, float alpha, float beta) {
const uint grsize = get_local_size(0);
const uint grid = get_group_id(0);
const uint sgid = get_sub_group_id();
const uint lid = get_sub_group_local_id();
__global DATA_T *diff_dst, dim_t nelems, float alpha, float beta) {
const dim_t grsize = get_local_size(0);
const dim_t grid = get_group_id(0);
const dim_t sgid = get_sub_group_id();
const dim_t lid = get_sub_group_local_id();

ptrdiff_t offset = (grid * grsize + sgid * SIMD) * VECT_DT_N;
dim_t offset = (grid * grsize + sgid * SIMD) * VECT_DT_N;
//TODO: Two distinct offsets should be implemented:
//one for src and a second one for diff_src

Expand All @@ -97,7 +97,7 @@ __kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src,

VECT_DATA_T val_dd;
VECT_DATA_T val_src;
const uint nel_per_read = SIMD * VECT_DT_N;
const int nel_per_read = SIMD * VECT_DT_N;

// READ
if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) {
Expand All @@ -106,7 +106,7 @@ __kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src,

} else {
// read data in the same access pattern block_reads would
uint pos = offset + lid;
dim_t pos = offset + lid;
for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) {
val_dd[i] = diff_dst[pos];
val_src[i] = src[pos];
Expand All @@ -126,7 +126,7 @@ __kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src,

} else {
// write data in the same access pattern block_writes would
uint pos = offset + lid;
dim_t pos = offset + lid;
for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) {
diff_src[pos] = val_dd[i];
pos += SIMD;
Expand Down
12 changes: 6 additions & 6 deletions src/gpu/ocl/gen9_eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ status_t gen9_eltwise_fwd_t::execute_forward_dense(
auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);

const memory_desc_wrapper src_d(pd()->src_md());
const int nelems = src_d.nelems(pd()->conf.with_zero_padding);
const dim_t nelems = src_d.nelems(pd()->conf.with_zero_padding);
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;

Expand All @@ -114,8 +114,8 @@ status_t gen9_eltwise_fwd_t::execute_forward_dense(
arg_list.set(3, alpha);
arg_list.set(4, beta);

size_t lws = pd()->conf.work_group_size;
size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
dim_t lws = pd()->conf.work_group_size;
dim_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws});

status = parallel_for(ctx, nd_range, kernel_, arg_list);
Expand Down Expand Up @@ -156,7 +156,7 @@ status_t gen9_eltwise_bwd_t::execute_backward_dense(
auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);

const memory_desc_wrapper data_d(pd()->data_md());
const int nelems = data_d.nelems(pd()->conf.with_zero_padding);
const dim_t nelems = data_d.nelems(pd()->conf.with_zero_padding);
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;

Expand All @@ -168,8 +168,8 @@ status_t gen9_eltwise_bwd_t::execute_backward_dense(
arg_list.set(4, alpha);
arg_list.set(5, beta);

size_t lws = pd()->conf.work_group_size;
size_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
dim_t lws = pd()->conf.work_group_size;
dim_t total_wi = utils::div_up(nelems, pd()->conf.vector_size);
compute::nd_range_t nd_range({utils::rnd_up(total_wi, lws)}, {lws});

status = parallel_for(ctx, nd_range, kernel_, arg_list);
Expand Down
44 changes: 22 additions & 22 deletions src/gpu/ocl/ref_eltwise.cl
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@
__kernel void ref_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
float alpha, float beta POST_OP_ARGS) {
#if USE_GWS_GET
int d0 = GWS_GET_D0();
int d1 = GWS_GET_D1();
int d2 = GWS_GET_D2();
int d3 = GWS_GET_D3();
int d4 = GWS_GET_D4();
int d5 = GWS_GET_D5();
dim_t d0 = GWS_GET_D0();
dim_t d1 = GWS_GET_D1();
dim_t d2 = GWS_GET_D2();
dim_t d3 = GWS_GET_D3();
dim_t d4 = GWS_GET_D4();
dim_t d5 = GWS_GET_D5();

const size_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5);
const dim_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5);

if (d0 >= DATA_D0 || d1 >= DATA_D1 || d2 >= DATA_D2 || d3 >= DATA_D3
|| d4 >= DATA_D4 || d5 >= DATA_D5) {
dst[data_off] = CONVERT_DATA_T(0.f);
return;
}
#else
const size_t data_off = get_global_id(0)
const dim_t data_off = get_global_id(0)
#if GWS1 > 1
+ get_global_id(1) * GWS0
#endif
Expand All @@ -51,12 +51,12 @@ __kernel void ref_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
#endif
;

const int d0 = 0;
const int d1 = 0;
const int d2 = 0;
const int d3 = 0;
const int d4 = 0;
const int d5 = 0;
const dim_t d0 = 0;
const dim_t d1 = 0;
const dim_t d2 = 0;
const dim_t d3 = 0;
const dim_t d4 = 0;
const dim_t d5 = 0;
#endif

#if DT_F16 == 1
Expand All @@ -83,15 +83,15 @@ __kernel void ref_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst,
__kernel void ref_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src,
__global DATA_T *diff_dst, float alpha, float beta) {

int d0 = GWS_GET_D0();
int d1 = GWS_GET_D1();
int d2 = GWS_GET_D2();
int d3 = GWS_GET_D3();
int d4 = GWS_GET_D4();
int d5 = GWS_GET_D5();
dim_t d0 = GWS_GET_D0();
dim_t d1 = GWS_GET_D1();
dim_t d2 = GWS_GET_D2();
dim_t d3 = GWS_GET_D3();
dim_t d4 = GWS_GET_D4();
dim_t d5 = GWS_GET_D5();

const size_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5);
const size_t diff_data_off = DIFF_DATA_OFF(d0, d1, d2, d3, d4, d5);
const dim_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5);
const dim_t diff_data_off = DIFF_DATA_OFF(d0, d1, d2, d3, d4, d5);

if (d0 >= DATA_D0 || d1 >= DATA_D1 || d2 >= DATA_D2 || d3 >= DATA_D3
|| d4 >= DATA_D4 || d5 >= DATA_D5) {
Expand Down

0 comments on commit 6ce52eb

Please sign in to comment.