
Commit 6c77e6b

Merge branch 'develop' into b51
2 parents: abfc92d + 425c14d

Note: this is a large commit; only a subset of the changed files is shown below.

43 files changed: +1418 and -120 lines

cmake/external/cub.cmake

Lines changed: 7 additions & 2 deletions
@@ -26,11 +26,16 @@ set(CUB_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/cub)
 
 if(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.6)
   # cuda_11.6/11.7/11.8's own cub is 1.15.0, which will cause compiling error in windows.
-  set(CUB_TAG 1.16.0)
+  set(CUB_TAG 2.1.0)
   execute_process(COMMAND git --git-dir=${CUB_SOURCE_DIR}/.git
                           --work-tree=${CUB_SOURCE_DIR} checkout ${CUB_TAG})
-  # cub 1.16.0 is not compatible with current thrust version
+  # cub 2.1.0 is not compatible with current thrust version
   add_definitions(-DTHRUST_IGNORE_CUB_VERSION_CHECK)
+  if(${CMAKE_CUDA_COMPILER_VERSION} EQUAL 11.8)
+    set(cub_patches "${PADDLE_SOURCE_DIR}/patches/cub")
+    message(STATUS "Add cub patches: ${cub_patches}")
+    include_directories(${cub_patches})
+  endif()
 else()
   set(CUB_TAG 1.8.0)
 endif()

cmake/third_party.cmake

Lines changed: 3 additions & 1 deletion
@@ -482,7 +482,9 @@ if(WITH_ONNXRUNTIME)
 endif()
 
 if(WITH_GPU)
-  if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
+  if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0
+     OR (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.7
+         AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.9))
     include(external/cub) # download cub
     list(APPEND third_party_deps extern_cub)
   elseif(${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 12.0 AND WITH_SHARED_PHI)

paddle/fluid/distributed/collective/deep_ep/kernels/internode_ll.cu

Lines changed: 87 additions & 21 deletions
@@ -279,13 +279,33 @@ __global__ __launch_bounds__(
         rank * num_max_dispatch_tokens_per_rank * num_bytes_per_msg +
         slot_idx * num_bytes_per_msg;
     if (dst_rank != rank) {
-      nvshmemi_ibgda_put_nbi_warp(dst_ptr,
-                                  src_ptr,
-                                  num_bytes_per_msg,
-                                  dst_rank,
-                                  dst_expert_local_idx,
-                                  lane_id,
-                                  slot_idx);
+      void* peer_base_addr = reinterpret_cast<void*>(
+          __ldg(reinterpret_cast<const uint64_t*>(
+                    nvshmemi_device_state_d.peer_heap_base_p2p) +
+                dst_rank));
+      if (peer_base_addr) {
+        char* req_rptr_actual =
+            reinterpret_cast<char*>(peer_base_addr) +
+            (reinterpret_cast<char*>(dst_ptr) -
+             reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
+        const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
+        const auto* dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+        UNROLLED_WARP_COPY(8,
+                           lane_id,
+                           num_int4_per_msg,
+                           dst_int4_ptr,
+                           src_int4_ptr,
+                           ld_nc_global,
+                           st_na_global);
+      } else {
+        nvshmemi_ibgda_put_nbi_warp(dst_ptr,
+                                    src_ptr,
+                                    num_bytes_per_msg,
+                                    dst_rank,
+                                    dst_expert_local_idx,
+                                    lane_id,
+                                    slot_idx);
+      }
     } else {
       // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
       const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
@@ -367,11 +387,24 @@ __global__ __launch_bounds__(
                responsible_expert_idx) != FINISHED_SUM_TAG * 2) {
     }
     if (dst_rank != rank) {
-      nvshmemi_ibgda_amo_nonfetch_add(
-          rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
-          -num_tokens_sent - 1,
-          dst_rank,
-          dst_expert_local_idx);
+      void* peer_base_addr = reinterpret_cast<void*>(
+          __ldg(reinterpret_cast<const uint64_t*>(
+                    nvshmemi_device_state_d.peer_heap_base_p2p) +
+                dst_rank));
+      if (peer_base_addr) {  // P2P enabled
+        int* rptr_actual = reinterpret_cast<int*>(
+            reinterpret_cast<char*>(peer_base_addr) +
+            (reinterpret_cast<char*>(rdma_recv_count +
+                                     dst_expert_local_idx * num_ranks + rank) -
+             reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
+        st_na_release(rptr_actual, -num_tokens_sent - 1);
+      } else {
+        nvshmemi_ibgda_amo_nonfetch_add(
+            rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
+            -num_tokens_sent - 1,
+            dst_rank,
+            dst_expert_local_idx);
+      }
     } else {
       st_na_release(rdma_recv_count + dst_expert_local_idx * num_ranks + rank,
                     -num_tokens_sent - 1);
@@ -691,13 +724,32 @@ __global__ __launch_bounds__(
                          x_int4,
                          ld_nc_global,
                          st_na_global);
-      nvshmemi_ibgda_put_nbi_warp(dst_ptr,
-                                  buf_ptr,
-                                  hidden * sizeof(nv_bfloat16),
-                                  dst_rank,
-                                  local_expert_idx,
-                                  lane_id,
-                                  token_idx - offset);
+      void* peer_base_addr = reinterpret_cast<void*>(
+          __ldg(reinterpret_cast<const uint64_t*>(
+                    nvshmemi_device_state_d.peer_heap_base_p2p) +
+                dst_rank));
+      if (peer_base_addr) {
+        char* req_rptr_actual =
+            reinterpret_cast<char*>(peer_base_addr) +
+            (reinterpret_cast<char*>(dst_ptr) -
+             reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base));
+        const auto dst_int4_ptr = reinterpret_cast<int4*>(req_rptr_actual);
+        UNROLLED_WARP_COPY(7,
+                           lane_id,
+                           hidden_bf16_int4,
+                           dst_int4_ptr,
+                           x_int4,
+                           ld_nc_global,
+                           st_na_global);
+      } else {
+        nvshmemi_ibgda_put_nbi_warp(dst_ptr,
+                                    buf_ptr,
+                                    hidden * sizeof(nv_bfloat16),
+                                    dst_rank,
+                                    local_expert_idx,
+                                    lane_id,
+                                    token_idx - offset);
+      }
     }
   }
 
@@ -710,8 +762,22 @@ __global__ __launch_bounds__(
     while (ld_acquire_global(atomic_clean_flag) == 0) {
     }
     if (dst_rank != rank) {
-      nvshmemi_ibgda_amo_nonfetch_add(
-          rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx);
+      void* peer_base_addr = reinterpret_cast<void*>(
+          __ldg(reinterpret_cast<const uint64_t*>(
+                    nvshmemi_device_state_d.peer_heap_base_p2p) +
+                dst_rank));
+      if (peer_base_addr) {
+        int* req_rptr_actual = reinterpret_cast<int*>(
+            reinterpret_cast<char*>(peer_base_addr) +
+            (reinterpret_cast<char*>(rdma_recv_flag + global_expert_idx) -
+             reinterpret_cast<char*>(nvshmemi_device_state_d.heap_base)));
+        st_na_release(req_rptr_actual, 1);
+      } else {
+        nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx,
+                                        1,
+                                        dst_rank,
+                                        local_expert_idx);
+      }
     } else {
       st_na_release(rdma_recv_flag + global_expert_idx, 1);
     }
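
Note on the four hunks above: each replaces an unconditional RDMA operation (nvshmemi_ibgda_put_nbi_warp / nvshmemi_ibgda_amo_nonfetch_add) with a runtime check of nvshmemi_device_state_d.peer_heap_base_p2p for the destination rank. A non-null entry means the peer's NVSHMEM symmetric heap is directly mapped (P2P), so the remote address is obtained by rebasing the local heap offset onto the peer's base and the payload is written with a plain warp copy or release store; otherwise the original IBGDA path is kept as the fallback. Below is a minimal host-side C++ sketch of that address translation; SymmetricHeapView and translate are hypothetical names, not part of the commit.

#include <cstddef>

// Hypothetical view of a symmetric heap: every rank holds its objects at the
// same offset from its own heap base, so a pointer into the local heap can be
// rebased onto any peer whose heap is directly mapped (peer_base != nullptr).
struct SymmetricHeapView {
  char* local_base;    // this rank's heap base
  char* peer_base[8];  // mapped peer bases; nullptr if not P2P-reachable

  // Translate a local symmetric-heap pointer to the equivalent address on
  // dst_rank, or return nullptr to signal that the RDMA fallback is needed.
  void* translate(void* local_ptr, int dst_rank) const {
    char* base = peer_base[dst_rank];
    if (base == nullptr) return nullptr;  // no direct mapping: use RDMA path
    std::ptrdiff_t offset = static_cast<char*>(local_ptr) - local_base;
    return base + offset;  // same offset, rebased onto the peer's heap
  }
};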

paddle/phi/infermeta/backward.cc

Lines changed: 2 additions & 2 deletions
@@ -1747,8 +1747,8 @@ void UnStackGradInferMeta(const std::vector<const MetaTensor*>& out_grad,
                         rank));
   if (axis < 0) axis += (rank + 1);
 
-  auto vec = common::vectorize<int>(input_dims[0]);
-  vec.insert(vec.begin() + axis, static_cast<int>(input_dims.size()));
+  auto vec = common::vectorize<int64_t>(input_dims[0]);
+  vec.insert(vec.begin() + axis, static_cast<int64_t>(input_dims.size()));
   x_grad->set_dims(common::make_ddim(vec));
   x_grad->set_dtype(out_grad[0]->dtype());
 }

paddle/phi/infermeta/unary.cc

Lines changed: 1 addition & 1 deletion
@@ -6075,7 +6075,7 @@ void UnStackInferMeta(const MetaTensor& x,
                         x_dim[axis],
                         num));
   }
-  auto vec = common::vectorize<int>(x_dim);
+  auto vec = common::vectorize<int64_t>(x_dim);
   vec.erase(vec.begin() + axis);
   for (size_t i = 0; i < output_count; i++) {
     outs[i]->set_dims(common::make_ddim(vec));
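
Both UnStack infer-meta fixes (backward.cc above and unary.cc here) widen the vectorized shape from int to int64_t, so dimensions larger than INT_MAX survive the round trip through common::vectorize and common::make_ddim. A tiny self-contained C++ illustration of the narrowing problem, using a made-up dimension value:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustration only: why vectorize<int> can corrupt large shapes.
// 3'000'000'000 exceeds INT_MAX, so narrowing to int wraps the value,
// while int64_t keeps it intact.
int main() {
  const int64_t big_dim = 3'000'000'000LL;             // hypothetical dimension
  std::vector<int> as_int{static_cast<int>(big_dim)};  // old behaviour: wraps
  std::vector<int64_t> as_i64{big_dim};                // new behaviour: exact
  std::cout << as_int[0] << " vs " << as_i64[0] << "\n";
  return 0;
}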

paddle/phi/kernels/cpu/concat_kernel.cc

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,9 @@ void ConcatKernel(const Context& dev_ctx,
   phi::DDim out_dims = phi::funcs::ComputeAndCheckShape(true, x_dims, axis);
   out->Resize(out_dims);
   dev_ctx.template Alloc<T>(out);
-
+  if (out->numel() == 0) {
+    return;
+  }
   // If axis is 0, the lod of the output is not the same as inputs.
   if (axis == 0 && !x[0]->lod().empty()) {
     size_t lod_size_0 = x[0]->lod().size();

paddle/phi/kernels/cpu/reduce_sum_kernel.cc

Lines changed: 35 additions & 2 deletions
@@ -44,8 +44,41 @@ void SumRawKernel(const Context& dev_ctx,
         out);
     return;
   }
-  phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
-      dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
+  if constexpr (std::is_same_v<T, phi::dtype::float16> ||
+                std::is_same_v<T, phi::dtype::bfloat16>) {
+    DenseTensor x_fp32 = phi::Cast<T, Context>(dev_ctx, x, DataType::FLOAT32);
+    DataType final_out_dtype = out_dtype;
+    if (final_out_dtype == DataType::UNDEFINED) {
+      final_out_dtype = x.dtype();
+    }
+    if (final_out_dtype == DataType::FLOAT32) {
+      phi::Reduce<CPUContext, float, phi::funcs::SumFunctor>(
+          dev_ctx,
+          x_fp32,
+          reduce_all,
+          dims.GetData(),
+          keep_dim,
+          phi::DataType::UNDEFINED,
+          out);
+    } else {
+      DenseTensor intermediate_result;
+      intermediate_result.set_meta(out->meta());
+      phi::Reduce<CPUContext, float, phi::funcs::SumFunctor>(
+          dev_ctx,
+          x_fp32,
+          reduce_all,
+          dims.GetData(),
+          keep_dim,
+          phi::DataType::UNDEFINED,
+          &intermediate_result);
+
+      phi::CastKernel<float, Context>(
+          dev_ctx, intermediate_result, final_out_dtype, out);
+    }
+  } else {
+    phi::Reduce<CPUContext, T, phi::funcs::SumFunctor>(
+        dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
+  }
 }
 
 }  // namespace phi
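
The CPU SumRawKernel now special-cases float16/bfloat16 inputs: the tensor is cast up to float32, the reduction runs in float, and the result is cast to the requested output dtype (defaulting to the input dtype). Accumulating in the wider type avoids the large rounding error of summing many half-precision values directly. A self-contained C++ sketch of the same control flow; float16_t here is only an illustrative stand-in for the phi half types, not the real type.

#include <iostream>
#include <numeric>
#include <type_traits>
#include <vector>

// Illustrative narrow storage type standing in for phi::dtype::float16/bfloat16.
struct float16_t { float v; };

template <typename T>
float sum_raw(const std::vector<T>& xs) {
  if constexpr (std::is_same_v<T, float16_t>) {
    // Narrow input: promote and accumulate in float32, mirroring the new
    // kernel; the caller would then cast the result back to the output dtype.
    float acc = 0.0f;
    for (const auto& x : xs) acc += x.v;
    return acc;
  } else {
    // Wide-enough input: reduce directly in T, matching the original path.
    return static_cast<float>(std::accumulate(xs.begin(), xs.end(), T{}));
  }
}

int main() {
  std::vector<float16_t> a(4, float16_t{0.25f});
  std::vector<double> b{1.0, 2.0, 3.0};
  std::cout << sum_raw(a) << " " << sum_raw(b) << "\n";  // prints "1 6"
  return 0;
}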

paddle/phi/kernels/cpu/temporal_shift_grad_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ void TemporalShiftGradKernel(const Context& dev_ctx,
                              float shift_ratio,
                              const std::string& data_format_str,
                              DenseTensor* x_grad) {
+  if (x_grad && x_grad->numel() == 0) {
+    dev_ctx.template Alloc<T>(x_grad);
+    return;
+  }
   auto* input_grad = x_grad;
   auto* output_grad = &out_grad;
   int t = seg_num;

paddle/phi/kernels/cpu/temporal_shift_kernel.cc

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,10 @@ void TemporalShiftKernel(const Context& dev_ctx,
                          float shift_ratio,
                          const std::string& data_format_str,
                          DenseTensor* out) {
+  if (out && out->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    return;
+  }
   auto* input = &x;
   auto* output = out;
   int t = seg_num;
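
ConcatKernel and both TemporalShift kernels above apply the same guard: allocate the output so its buffer and metadata exist, then return immediately when numel() == 0 instead of running the kernel body on an empty tensor. A minimal C++ sketch of the pattern with a toy Tensor type and a hypothetical kernel name (neither is the real phi API):

#include <vector>

// Toy stand-in for phi::DenseTensor, for illustration only.
struct Tensor {
  std::vector<float> data;
  long long numel() const { return static_cast<long long>(data.size()); }
  void alloc(long long n) { data.resize(static_cast<size_t>(n)); }
};

void temporal_shift_like_kernel(const Tensor& x, Tensor* out) {
  out->alloc(x.numel());  // always materialize the output buffer/metadata
  if (out->numel() == 0) {
    return;               // nothing to compute for empty tensors
  }
  // ... real kernel body would run here ...
  for (long long i = 0; i < out->numel(); ++i) out->data[i] = x.data[i];
}

int main() {
  Tensor empty, out;
  temporal_shift_like_kernel(empty, &out);  // safe: early-returns on numel()==0
  return 0;
}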

paddle/phi/kernels/funcs/stack_and_unstack.h

Lines changed: 7 additions & 4 deletions
@@ -210,7 +210,7 @@ void LaunchUnStackKernel(const Context& ctx,
   constexpr int kWarpSize = 32;
   constexpr int kMaxOut = 16;
 
-  int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
+  int64_t tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1;
   if (split_dim < kMaxOut) {
     tid_y = split_dim;
     tid_x =
@@ -219,10 +219,13 @@ void LaunchUnStackKernel(const Context& ctx,
   } else {
     tid_y = kMaxOut;
     tid_x = kWarpSize;
-    bid_y = backends::gpu::DivUp<int>(split_dim, kMaxOut);
+    bid_y = backends::gpu::DivUp<int64_t>(split_dim, kMaxOut);
   }
-  int tile_x_num = backends::gpu::DivUp<int>(out_row, tid_x);
-  bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit);
+  int64_t tile_x_num = backends::gpu::DivUp<int64_t>(out_row, tid_x);
+  if (tile_x_num < static_cast<int64_t>(backends::gpu::kMultiDimslimit))
+    bid_x = tile_x_num;
+  else
+    bid_x = backends::gpu::kMultiDimslimit;
   dim3 blocks(tid_x, tid_y, 1);
   dim3 grids(bid_x, bid_y, 1);
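
This launch-configuration change computes tile and grid sizes in int64_t and clamps explicitly against the grid-dimension limit, so a very large out_row or split_dim can no longer overflow int before the clamp. A small C++ illustration of 64-bit ceil-division plus clamping; kGridLimit is an assumed placeholder value, not Paddle's backends::gpu::kMultiDimslimit:

#include <algorithm>
#include <cstdint>
#include <iostream>

// Overflow-safe grid sizing: do the ceil-division in 64-bit, then clamp to the
// hardware limit before constructing the launch dimensions.
constexpr int64_t kGridLimit = 65535;  // illustrative limit

int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t out_row = 5'000'000'000LL;  // hypothetical: larger than INT_MAX
  const int64_t tid_x = 32;
  int64_t tile_x_num = div_up(out_row, tid_x);       // stays exact in 64-bit
  int64_t bid_x = std::min(tile_x_num, kGridLimit);  // clamp to the grid limit
  std::cout << tile_x_num << " -> bid_x = " << bid_x << "\n";
  return 0;
}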
