Skip to content

Commit

Permalink
update for
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Jul 13, 2022
1 parent 3d29736 commit 590a50a
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ __device__ __inline__ void ReadData(Ty* dst,
__local__ Tx in_temp[1];
// Each branch is added for better performance
if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1
mfence_local();
if (IsBoundary) {
if (left_size_nx > 0) {
GM2LM(src + thread_offset, in_temp, sizeof(Tx));
Expand All @@ -388,6 +389,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
dst[idy] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -399,6 +401,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
dst[idx] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -413,6 +416,7 @@ __device__ __inline__ void ReadData(Ty* dst,
}
}
int fix = thread_offset + idx * stride_nx + idy * stride_ny;
mfence_local();
GM2LM(src + fix, in_temp, sizeof(Tx));
dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
}
Expand Down Expand Up @@ -485,18 +489,16 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
const T _global_ptr_* src,
int num) {
mfence_local();
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * NX < num
mfence_local();
GM2LM(src + thread_offset, dst, NX * sizeof(T));
}
}
Expand All @@ -507,17 +509,15 @@ __device__ __inline__ void ReadData(T* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence_local();
GM2LM(src + thread_offset, dst, read_lens * sizeof(T));
}
}
Expand Down Expand Up @@ -610,8 +610,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
int stride_ny) {
uint32_t thread_offset = block_offset + core_id();
uint32_t index_src = 0;
__local__ T in_temp[1];

mfence_local();
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
Expand All @@ -624,8 +623,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
index_src = config(index_output);
GM2LM(src + index_src, in_temp, sizeof(T));
dst[nx + ny * NX] = in_temp[0];
GM2LM(src + index_src, dst + nx + ny * NX, sizeof(T));
}
}
}
Expand Down Expand Up @@ -701,8 +699,10 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[ny] = static_cast<Ty>(func(in_temp[0]));

thread_offset += stride_ny;
}
} else {
Expand All @@ -717,6 +717,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny;
Expand Down Expand Up @@ -752,35 +753,30 @@ __device__ void WriteData(T _global_ptr_* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence_local();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence_local();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
}
}

template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence_local();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * NX < num
Expand Down Expand Up @@ -1039,6 +1035,7 @@ __device__ __inline__ void ReadDataBc1NMn(
for (int i = 0; i < last_col; i++) {
dst[i] = in_temp;
}
mfence_local();
GM2LM(src + index_base + 1, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down Expand Up @@ -1093,6 +1090,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
} else {
next_part_index = 0;
}
mfence_local();
GM2LM(src + next_part_index, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down

0 comments on commit 590a50a

Please sign in to comment.