Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

abacus-aibox-901 fix the asq's pushcopy compatible error #29

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions paddle/fluid/framework/fleet/box_wrapper_kernel.kps
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ inline void FeaturePushCopyNNCross(
#endif
}

__global__ void PushCopy(float* src_vals,
__global__ void PushCopy(unsigned long long* src_vals,
float* dest_vals,
boxps::FeaturePushOffset* push_offset,
const int push_float_num,
Expand Down Expand Up @@ -902,9 +902,21 @@ __global__ void PushCopy(float* src_vals,
sm_slots[i] = lm_slot;
}

__shared__ uint64_t sm_src_vals_ptr[max_slot_num];
for (int i = cid; i < sm_slot_len; i += ncores) {
GM2SM(src_vals + i, sm_src_vals_ptr + i, sizeof(uint64_t));
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个slot数比较多的话,可以用多core进行并行


mfence();
xpu_sync_all();

__local__ uint64_t lm_src_vals_ptr[1];
for (int i = 0; i < slot_num; i++) {
if (sm_src_vals_ptr[i] != 0) {
lm_src_vals_ptr[0] = sm_src_vals_ptr[i];
break;
}
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个循环看着是所有core都拿到sm_src_vals_ptr[0]?符合预期吗

GM2LM(push_offset, &info, sizeof(boxps::FeaturePushOffset));

float scale = -1. * batch_size;
Expand All @@ -917,7 +929,7 @@ __global__ void PushCopy(float* src_vals,
int count_per_loop =
min(per_thread_per_loop_len, total_length - gm_offset);

GM2LM(src_vals + gm_offset * hidden_size, lm_src_vals,
GM2LM((__global_ptr__ float*)lm_src_vals_ptr[0] + gm_offset * hidden_size, lm_src_vals,
count_per_loop * hidden_size * sizeof(float));
GM2LM(total_dims + gm_offset, lm_total_dims,
count_per_loop * sizeof(int));
Expand Down Expand Up @@ -1024,14 +1036,7 @@ void BoxWrapperKernel::CopyForPush(
} else {
// FeaturePushCopy
// TODO:
float* real_gm_src_ptr;
for (int i = 0; i < slot_num; i++) {
if(gm_src_ptr[i] != 0) {
real_gm_src_ptr = const_cast<float*>(gm_src_ptr[i]);
break;
}
}
PushCopy<<<8, 64, stream>>>(real_gm_src_ptr, push_grad_values, push_offset,
PushCopy<<<8, 64, stream>>>(reinterpret_cast<unsigned long long*>(gm_src_ptr), push_grad_values, push_offset,
push_float_num_, c_total_length, hidden_size, batch_size, total_dims,
skip_offset, cvm_offset, key2slot, slots, slot_num);
}
Expand Down