Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mfence for XPU2 KP #44258

Merged
merged 4 commits into from
Jul 19, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src,
int num) {
if (num > 0) {
mfence_local();
LM2GM(src, dst, num * sizeof(T));
}
}
Expand Down Expand Up @@ -495,6 +496,7 @@ __device__ __inline__ void ReadData(T* dst,
}
Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

402行, in_temp在for循环中,且看代码NX应该有可能>1,那么在403行scalar read之后,下一次循环则发生GM2LM, 所以应该在402行之前是否应该mfence一下

Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

391,416, 494, 515, 571,627, 720行的in_temp看着也是同样,最好用模拟器的mfence检查工具跑一下,这样最保险

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

}
} else { // core_num() * NX < num
mfence_local();
GM2LM(src + thread_offset, dst, NX * sizeof(T));
}
}
Expand All @@ -515,6 +517,7 @@ __device__ __inline__ void ReadData(T* dst,
}
}
} else { // core_num() * read_lens < num
mfence_local();
GM2LM(src + thread_offset, dst, read_lens * sizeof(T));
}
}
Expand Down Expand Up @@ -756,12 +759,12 @@ __device__ void WriteData(T _global_ptr_* dst,
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence();
mfence_local();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence();
mfence_local();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
}
}
Expand All @@ -776,10 +779,12 @@ __device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence_local();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * NX < num
mfence_local();
LM2GM(src, dst + thread_offset, NX * sizeof(T));
}
}
Expand Down Expand Up @@ -831,10 +836,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
if (IsBoundary) {
if (left_size_nx > 0) {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else if (NX == 1) {
Expand All @@ -847,6 +854,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idy]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
}
} else if (NY == 1) { // for NY == 1 and NX != 1
Expand All @@ -859,6 +867,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idx]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
}
} else { // for NX != 1 and NY != 1
Expand All @@ -877,6 +886,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
}
in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
mfence_local();
LM2GM(in_temp,
dst + thread_offset + idx * stride_nx + idy * stride_ny,
sizeof(Ty));
Expand Down Expand Up @@ -1169,6 +1179,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
if (index_src >= index_base && index_src < index_base + cache_size) {
in_temp = src_temp[index_src - index_base];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在1040行对in_temp发生了scalar read,注意一下1042行,这里GM2LM之前需要mfence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

} else {
mfence_local();
GM2LM(src + index_src, &in_temp, sizeof(T));
}
dst[nx] = in_temp;
Expand Down