diff --git a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps index b00ac3031745f..24a8fa84468f4 100644 --- a/paddle/fluid/framework/fleet/box_wrapper_kernel.kps +++ b/paddle/fluid/framework/fleet/box_wrapper_kernel.kps @@ -858,7 +858,7 @@ inline void FeaturePushCopyNNCross( #endif } -__global__ void PushCopy(float* src_vals, +__global__ void PushCopy(unsigned long long* src_vals, float* dest_vals, boxps::FeaturePushOffset* push_offset, const int push_float_num, @@ -902,9 +902,21 @@ __global__ void PushCopy(float* src_vals, sm_slots[i] = lm_slot; } + __shared__ uint64_t sm_src_vals_ptr[max_slot_num]; + for (int i = cid; i < sm_slot_len; i += ncores) { + GM2SM(src_vals + i, sm_src_vals_ptr + i, sizeof(uint64_t)); + } + mfence(); xpu_sync_all(); + __local__ uint64_t lm_src_vals_ptr[1]; + for (int i = 0; i < slot_num; i++) { + if (sm_src_vals_ptr[i] != 0) { + lm_src_vals_ptr[0] = sm_src_vals_ptr[i]; + break; + } + } GM2LM(push_offset, &info, sizeof(boxps::FeaturePushOffset)); float scale = -1. * batch_size; @@ -917,7 +929,7 @@ __global__ void PushCopy(float* src_vals, int count_per_loop = min(per_thread_per_loop_len, total_length - gm_offset); - GM2LM(src_vals + gm_offset * hidden_size, lm_src_vals, + GM2LM((__global_ptr__ float*)lm_src_vals_ptr[0] + gm_offset * hidden_size, lm_src_vals, count_per_loop * hidden_size * sizeof(float)); GM2LM(total_dims + gm_offset, lm_total_dims, count_per_loop * sizeof(int)); @@ -1024,14 +1036,7 @@ void BoxWrapperKernel::CopyForPush( } else { // FeaturePushCopy // TODO: - float* real_gm_src_ptr; - for (int i = 0; i < slot_num; i++) { - if(gm_src_ptr[i] != 0) { - real_gm_src_ptr = const_cast(gm_src_ptr[i]); - break; - } - } - PushCopy<<<8, 64, stream>>>(real_gm_src_ptr, push_grad_values, push_offset, + PushCopy<<<8, 64, stream>>>(reinterpret_cast(gm_src_ptr), push_grad_values, push_offset, push_float_num_, c_total_length, hidden_size, batch_size, total_dims, skip_offset, cvm_offset, key2slot, slots, slot_num); }