fix cvrq check nan #28

Merged 1 commit on Nov 28, 2023
222 changes: 204 additions & 18 deletions paddle/fluid/framework/fleet/box_wrapper_kernel.kps
@@ -30,7 +30,7 @@ limitations under the License. */
#include "xpu/kernel/xtdk_simd.h"

#ifdef TRACE_PROFILE
-// #include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_io.h"
#include <fstream>

// The producer side.
@@ -70,6 +70,15 @@ struct ExpandPushGetOp {
}
};

struct ExpandPushEmdGetOp {
__device__ float get(float* expand, const int& row,
const int& expand_id,
const int& hidden,
const int& expand_dim) const {
return expand[row * (hidden + expand_dim) + hidden + expand_id];
}
};
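
// Illustrative host-side sketch (not part of this diff's kernels), and the
// helper name here is hypothetical: ExpandPushEmdGetOp assumes each source
// row is laid out as [hidden | expand_dim] floats and skips the hidden
// prefix before reading an expand value.
static inline float expand_emd_get_host(const float* expand, int row,
int expand_id, int hidden, int expand_dim) {
// Same arithmetic as ExpandPushEmdGetOp::get: advance to the row, skip the
// hidden prefix, then index into the expand segment.
return expand[row * (hidden + expand_dim) + hidden + expand_id];
}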

template <typename T>
__device__ void set_byfloat(float* dest, const T& val) {
(*reinterpret_cast<T*>(dest)) = val;
@@ -340,6 +349,152 @@ __global__ void PullCopyNNCross(const TEmbedxOp* op,
}
}

template <typename TEmbedxOp>
__global__ void PullCopyNNCrossWithEmb(const TEmbedxOp* op,
const float scale,
const boxps::FeaturePullOffset* info,
int* total_dims,
unsigned long long* dst_vals,
const int* key2slot,
float* total_values,
const uint32_t* restore_idx,
const int total_length,
const int max_cols_num,
const int hidden_size,
const int expand_embed_dim,
const int pull_float_num,
const int skip_offset,
const int cvm_offset,
const int slot_num) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = cluster_id() * ncores + cid;
int nthreads = cluster_num() * ncores;

const int buf_length = 5;
int per_thread_len = roundup_div(total_length, nthreads);
int per_thread_loop_count = roundup_div(per_thread_len, buf_length);
int per_thread_per_loop_len = roundup_div(per_thread_len, per_thread_loop_count);
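// Rows are handed out in slices of per_thread_per_loop_len (at most
// buf_length) entries, round-robin across all cores, so each local-memory
// staging buffer below only needs buf_length rows.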

__local__ float lm_total_values[buf_length * pull_float_num];
__local__ float lm_dst_vals[buf_length * hidden_size];
__local__ float lm_dst_expand_vals[buf_length * (hidden_size + expand_embed_dim)];
__local__ int lm_key2slot[buf_length];
__local__ int lm_total_dims[buf_length];
__local__ uint32_t lm_restore_idx[buf_length];
__local__ boxps::FeaturePullOffset lm_info[1];
__local__ TEmbedxOp lm_op[1];

const int max_slot_num = 1000;
int sm_slot_len = min(max_slot_num, slot_num);
__shared__ uint64_t sm_dst_vals_ptr[max_slot_num];
__shared__ uint64_t sm_dst_expand_vals_ptr[max_slot_num];
for (int i = cid; i < sm_slot_len; i += ncores) {
GM2SM(dst_vals + i, sm_dst_vals_ptr + i, sizeof(uint64_t));
GM2SM(dst_vals + slot_num + i, sm_dst_expand_vals_ptr + i, sizeof(uint64_t));
}
mfence();
xpu_sync_all();
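// The per-slot destination pointers (embed tables at dst_vals[0..slot_num),
// expand tables at dst_vals[slot_num..2*slot_num)) are now staged in shared
// memory, so the per-element null checks below stay cheap.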

__local__ uint64_t lm_dst_vals_ptr[1];
for (int i = 0; i < slot_num; i++) {
if (sm_dst_vals_ptr[i] != 0) {
lm_dst_vals_ptr[0] = sm_dst_vals_ptr[i];
break;
}
}
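// Assumes all non-null slot outputs share one contiguous buffer: the first
// non-null slot pointer is reused below as the base address for the LM2GM
// write-back of both the embed and expand blocks.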

GM2LM(info, lm_info, sizeof(boxps::FeaturePullOffset));
GM2LM(op, lm_op, sizeof(TEmbedxOp));
for (int i = thread_id; i < per_thread_loop_count * nthreads; i += nthreads) {
int gm_offset = i * per_thread_per_loop_len;
if (gm_offset >= total_length) {
return;
}

int len = min(per_thread_per_loop_len, total_length - gm_offset);
if (restore_idx != nullptr) {
GM2LM(restore_idx + gm_offset, lm_restore_idx, len * sizeof(uint32_t));
}
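// restore_idx, when present, maps each output row to its row in the
// deduplicated total_values buffer.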
int pos = (restore_idx != nullptr) ? lm_restore_idx[0] : gm_offset;
GM2LM(total_values + pos * pull_float_num, lm_total_values, len * pull_float_num * sizeof(float));
GM2LM(total_dims + gm_offset, lm_total_dims, len * sizeof(int));
GM2LM(key2slot + gm_offset, lm_key2slot, len * sizeof(int));

for (int j = 0; j < len; j++) {
// mfence();
// cvm offset
for (int k = 0; k < cvm_offset; ++k) {
// TODO: consider xpu_value[slot_id] == nullptr?
if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
lm_dst_vals[j * hidden_size + k] = lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
}
if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
}
}

// embedx
// embedx flags + expand flags && *(keys[x] + y) != 0 && *(keys[x] + y)
int embedx_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].embedx_size]));
// int embedx_size = 0;
// TODO: expand_size = expand_embed_dim?
int expand_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].expand_size]));
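// total_dims packs two flags per row: bit 0 set when the embedx segment is
// non-empty, bit 1 set when the expand segment is non-empty.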
lm_total_dims[j] = static_cast<int>(embedx_size > 0) | static_cast<int>((expand_size > 0) << 1);

if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
lm_op[0].copy(lm_dst_vals + j * hidden_size + k,
lm_total_values + j * pull_float_num + lm_info[0].embedx,
k - cvm_offset,
scale);
}

for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
lm_dst_vals[j * hidden_size + k] = 0;
}
}

if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
lm_total_values + j * pull_float_num + lm_info[0].embedx,
k - cvm_offset,
scale);
}

for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
}
}

// expand
if (sm_dst_expand_vals_ptr[lm_key2slot[j]] == 0) {
continue;
}

for (int k = hidden_size; k < hidden_size + expand_size; ++k) {
lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
lm_total_values + j * pull_float_num + lm_info[0].expand,
k - hidden_size,
scale);
}
for (int k = hidden_size + expand_size; k < max_cols_num; ++k) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
}
}
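// Flush local stores, then write this slice back: the embed block
// ([total_length x hidden_size]) and the expand block
// ([total_length x (hidden_size + expand_embed_dim)]) sit back-to-back from
// the same base pointer.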
mfence();

LM2GM(lm_total_dims, total_dims + gm_offset, len * sizeof(int));
LM2GM(lm_dst_vals, ((__global_ptr__ float*)lm_dst_vals_ptr[0] + gm_offset * hidden_size), len * hidden_size * sizeof(float));
LM2GM(lm_dst_expand_vals, ((__global_ptr__ float*)lm_dst_vals_ptr[0] + total_length * hidden_size + gm_offset * (hidden_size + expand_embed_dim)), len * (hidden_size + expand_embed_dim) * sizeof(float));
mfence();
}
}

template <typename TEmbedxOp>
inline void FeaturePullCopyNNCross(
const paddle::platform::Place& place,
@@ -405,9 +560,22 @@ inline void FeaturePullCopyNNCross(
cvm_offset,
slot_num);
} else {
-// PullCopyNNCrossWithEmb
-// TODO:
-CHECK(false) << "PullCopyNNCrossWithEmb not implement";
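// Launch shape <<<8, 64, stream>>> matches the other kernel launches in this
// file; on KPS this reads as 8 clusters x 64 cores per cluster.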
PullCopyNNCrossWithEmb<TEmbedxOp><<<8, 64, stream>>>(d_op,
scale,
info,
total_dims,
reinterpret_cast<unsigned long long*>(d_xpu_values),
key2slot,
total_values_xpu,
xpu_restore_idx,
total_length,
(hidden_size + expand_embed_dim),
hidden_size,
expand_embed_dim,
pull_float_num,
skip_offset,
cvm_offset,
slot_num);
}
xpu_free(d_xpu_values);
xpu_wait(stream);
@@ -816,21 +984,18 @@ inline void FeaturePushCopyNNCross(
auto ctx_xpu = static_cast<platform::XPUDeviceContext*>(dev_ctx)->x_context();
auto stream = ctx_xpu->xpu_stream;

-auto d_op_tmp = memory::Alloc(place, sizeof(TExpandPushGetOp));
-TExpandPushGetOp* d_op = reinterpret_cast<TExpandPushGetOp*>(d_op_tmp->ptr());
-memory::Copy(place,
-d_op,
-platform::CPUPlace(),
-op,
-sizeof(TExpandPushGetOp));

#ifdef TRACE_PROFILE
TRACE_SCOPE_START("PushCopyNNCross", xpu_wait(stream));
#endif
if (expand_only) {
-// TODO:
-// if (d_sort_idx != nullptr){
-// }
ExpandPushGetOp op;
auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushGetOp));
ExpandPushGetOp* d_op = reinterpret_cast<ExpandPushGetOp*>(d_op_tmp->ptr());
memory::Copy(place,
d_op,
platform::CPUPlace(),
&op,
sizeof(ExpandPushGetOp));
PushCopyNNCross<ExpandPushGetOp><<<8, 64, stream>>>(d_op,
info,
reinterpret_cast<unsigned long long*>(gm_src), // src
@@ -848,9 +1013,30 @@
skip_offset,
bs);
} else {
-// PullCopyNNCrossWithEmb
-// TODO:
-CHECK(false) << "PullCopyNNCrossWithEmb not implement";
ExpandPushEmdGetOp op;
auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushEmdGetOp));
ExpandPushEmdGetOp* d_op = reinterpret_cast<ExpandPushEmdGetOp*>(d_op_tmp->ptr());
memory::Copy(place,
d_op,
platform::CPUPlace(),
&op,
sizeof(ExpandPushEmdGetOp));
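// Mirrors the expand_only branch, but ExpandPushEmdGetOp skips each row's
// hidden-size prefix, so expand gradients are read from rows that also carry
// the embedding gradients ([hidden | expand] layout).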
PushCopyNNCross<ExpandPushEmdGetOp><<<8, 64, stream>>>(d_op,
info,
reinterpret_cast<unsigned long long*>(gm_src), // src
total_dims,
key2slot,
slot_vector,
slot_inner_offset,
push_grad_values, // dst
total_length,
hidden_size,
expand_embed_dim,
slot_num,
push_float_num,
cvm_offset,
skip_offset,
bs);
}
#ifdef TRACE_PROFILE
xpu_wait(stream);