Skip to content

Commit 826ea06

Browse files
metax666zhang-chenyiduqimengStareAtYoujxwangmetax
authored
[Metax]fix patch and fix missing kernel (#2024)
Co-authored-by: chezhang <[email protected]> Co-authored-by: duqimeng <[email protected]> Co-authored-by: Mingkun.Zhang <[email protected]> Co-authored-by: jiaxinWang-metax <[email protected]> Co-authored-by: MingkunZhang <[email protected]> Co-authored-by: zhang-chenyi <[email protected]> Co-authored-by: ZhouDuan <[email protected]> Co-authored-by: Theendlessofhell <[email protected]> Co-authored-by: root <[email protected]>
1 parent 6595afc commit 826ea06

34 files changed

+2127
-203
lines changed

.github/workflows/metax_work.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: padlle metax gpu test
1+
name: paddle metax gpu test
22

33
on:
44
workflow_dispatch:

backends/metax_gpu/CMakeLists.txt

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ file(
326326
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu
327327
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu
328328
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu
329-
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
329+
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu
330+
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu
331+
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
330332
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
331333
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
332334
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
@@ -535,6 +537,7 @@ file(
535537
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/clip_by_norm_kernel.cu
536538
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/uniform_random_batch_size_like_kernel.cu
537539
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/get_tensor_from_selected_rows_kernel.cu
540+
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
538541
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
539542
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_kernel.cc
540543
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/empty_kernel.cc
@@ -643,6 +646,8 @@ file(
643646
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
644647
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
645648
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu
649+
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/lars_momentum_kernel.cu
650+
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_sum_kernel.cu
646651
# ############################################################################
647652
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
648653
# kernels/kps
@@ -726,6 +731,11 @@ target_link_libraries(
726731
${WARPCTC_LIBRARIES}
727732
${WARPRNNT_LIBRARIES}
728733
${PADDLE_CORE_LIB})
734+
735+
target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so)
736+
target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so)
737+
target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcpti.so)
738+
729739
include_directories(BEFORE ${PADDLE_SOURCE_DIR})
730740

731741
target_compile_definitions(
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/core/kernel_registry.h"
16+
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
17+
#include "paddle/phi/kernels/selected_rows/adam_kernel.h"
18+
19+
PD_CUSTOM_KERNEL_REGISTER(adam_dense_param_sparse_grad,
20+
metax_gpu,
21+
ALL_LAYOUT,
22+
phi::sr::AdamDenseParamSparseGradKernel,
23+
float,
24+
double,
25+
phi::float16) {
26+
// Skip beta1_pow, beta2_pow, skip_update data transform
27+
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
28+
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
29+
kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
30+
31+
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
32+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
33+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
34+
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
35+
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
36+
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
37+
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
38+
}
39+
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
40+
kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
41+
}
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"
16+
17+
#ifdef __NVCC__
18+
#include "cub/cub.cuh"
19+
#endif
20+
#ifdef __HIPCC__
21+
#include <hipcub/hipcub.hpp>
22+
namespace cub = hipcub;
23+
#endif
24+
25+
#include "kernels/gpudnn/softmax_gpudnn.h"
26+
#include "paddle/phi/backends/gpu/gpu_device_function.h"
27+
#include "paddle/phi/backends/gpu/gpu_dnn.h"
28+
#include "paddle/phi/common/amp_type_traits.h"
29+
#include "paddle/phi/core/kernel_registry.h"
30+
#include "paddle/phi/core/tensor_utils.h"
31+
#include "paddle/phi/core/visit_type.h"
32+
#include "paddle/phi/kernels/funcs/axis_utils.h"
33+
#include "paddle/phi/kernels/funcs/for_range.h"
34+
#include "paddle/phi/kernels/funcs/math_function.h"
35+
#include "paddle/phi/kernels/funcs/softmax.h"
36+
37+
namespace phi {
38+
39+
/*
40+
Vectorized wrapper of softmax with cross entropy grad hard label.
41+
Optimized with float4 vectorization for memory coalescing and improved
42+
throughput.
43+
*/
44+
template <typename T, typename LabelT, typename LogitT>
45+
__global__ void SoftmaxWithCrossEntropyGradHardLabelVectorized(
46+
LogitT* __restrict__ logits_grad,
47+
const T* __restrict__ loss_grad,
48+
const T* __restrict__ softmax,
49+
const LabelT* __restrict__ labels,
50+
const int64_t n,
51+
const int64_t dim,
52+
const int64_t d,
53+
const int ignore_index) {
54+
// Vectorized load/store with float4 for 128-bit memory transactions
55+
constexpr int VEC_SIZE = 4;
56+
using VecT = typename phi::AlignedVector<LogitT, VEC_SIZE>;
57+
using SoftmaxVecT = typename phi::AlignedVector<T, VEC_SIZE>;
58+
59+
int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
60+
int64_t vec_id = tid * VEC_SIZE;
61+
62+
// Ensure we don't exceed bounds
63+
if (vec_id >= n * dim * d) return;
64+
65+
// Compute indices for vectorized access
66+
int64_t idx_n = vec_id / (d * dim);
67+
int64_t idx_dim_start = (vec_id / d) % dim;
68+
int64_t idx_d = vec_id % d;
69+
int64_t ids = idx_n * d + idx_d;
70+
71+
// Load label once per thread
72+
auto lbl = static_cast<int64_t>(labels[ids]);
73+
74+
if (lbl == ignore_index) {
75+
// Vectorized zero fill for ignore_index
76+
VecT* vec_grad = reinterpret_cast<VecT*>(&logits_grad[vec_id]);
77+
VecT zero_vec;
78+
#pragma unroll
79+
for (int i = 0; i < VEC_SIZE; ++i) {
80+
zero_vec.val[i] = static_cast<LogitT>(0.0f);
81+
}
82+
*vec_grad = zero_vec;
83+
return;
84+
}
85+
86+
// Vectorized load of softmax values
87+
SoftmaxVecT softmax_vec;
88+
const SoftmaxVecT* softmax_ptr =
89+
reinterpret_cast<const SoftmaxVecT*>(&softmax[vec_id]);
90+
softmax_vec = *softmax_ptr;
91+
92+
// Load loss gradient (broadcast across vector elements)
93+
T loss_grad_val = loss_grad[ids];
94+
95+
// Vectorized computation
96+
VecT grad_vec;
97+
#pragma unroll
98+
for (int i = 0; i < VEC_SIZE; ++i) {
99+
int64_t current_dim = idx_dim_start + i;
100+
if (current_dim < dim) { // Bounds check for partial vectors
101+
float softmax_val = static_cast<float>(softmax_vec.val[i]);
102+
float grad_val;
103+
104+
if (lbl == current_dim) {
105+
grad_val = (softmax_val - 1.0f) * static_cast<float>(loss_grad_val);
106+
} else {
107+
grad_val = softmax_val * static_cast<float>(loss_grad_val);
108+
}
109+
110+
grad_vec.val[i] = static_cast<LogitT>(grad_val);
111+
} else {
112+
grad_vec.val[i] = static_cast<LogitT>(0.0f);
113+
}
114+
}
115+
116+
// Vectorized store
117+
VecT* grad_ptr = reinterpret_cast<VecT*>(&logits_grad[vec_id]);
118+
*grad_ptr = grad_vec;
119+
}
120+
121+
/*
122+
Specialized kernel for dimensions not divisible by vector size
123+
Uses warp-level primitives for better performance on irregular sizes
124+
*/
125+
template <typename T, typename LabelT, typename LogitT>
126+
__global__ void SoftmaxWithCrossEntropyGradHardLabelWarp(
127+
LogitT* __restrict__ logits_grad,
128+
const T* __restrict__ loss_grad,
129+
const T* __restrict__ softmax,
130+
const LabelT* __restrict__ labels,
131+
const int64_t n,
132+
const int64_t dim,
133+
const int64_t d,
134+
const int ignore_index) {
135+
const int warps_per_block = 4;
136+
const int threads_per_warp = 32;
137+
const int threads_per_block = warps_per_block * threads_per_warp;
138+
139+
int tid = blockIdx.x * threads_per_block + threadIdx.x;
140+
int warp_id = threadIdx.x / threads_per_warp;
141+
int lane_id = threadIdx.x % threads_per_warp;
142+
143+
// Process multiple elements per thread using warp-level parallelism
144+
int64_t elements_per_thread =
145+
(n * dim * d + gridDim.x * threads_per_block - 1) /
146+
(gridDim.x * threads_per_block);
147+
148+
for (int e = 0; e < elements_per_thread; ++e) {
149+
int64_t idx = tid + e * gridDim.x * threads_per_block;
150+
if (idx >= n * dim * d) break;
151+
152+
int64_t idx_n = idx / (d * dim);
153+
int64_t idx_dim = (idx / d) % dim;
154+
int64_t idx_d = idx % d;
155+
int64_t ids = idx_n * d + idx_d;
156+
157+
auto lbl = static_cast<int64_t>(labels[ids]);
158+
159+
if (lbl == ignore_index) {
160+
logits_grad[idx] = static_cast<LogitT>(0.0f);
161+
} else if (lbl == idx_dim) {
162+
logits_grad[idx] =
163+
static_cast<LogitT>((static_cast<float>(softmax[idx]) - 1.0f) *
164+
static_cast<float>(loss_grad[ids]));
165+
} else {
166+
logits_grad[idx] =
167+
static_cast<LogitT>(static_cast<float>(softmax[idx]) *
168+
static_cast<float>(loss_grad[ids]));
169+
}
170+
}
171+
}
172+
173+
/*
174+
Optimized kernel selector based on problem size and alignment
175+
*/
176+
template <typename T, typename LabelT, typename LogitT>
177+
void LaunchOptimizedCrossEntropyGradKernel(const GPUContext& dev_ctx,
178+
LogitT* logits_grad,
179+
const T* loss_grad,
180+
const T* softmax,
181+
const LabelT* labels,
182+
const int64_t n,
183+
const int64_t dim,
184+
const int64_t d,
185+
const int ignore_index) {
186+
const int64_t total_elements = n * dim * d;
187+
auto stream = dev_ctx.stream();
188+
189+
// Check alignment for vectorized kernel
190+
bool is_aligned = (reinterpret_cast<uintptr_t>(logits_grad) % 16 == 0) &&
191+
(reinterpret_cast<uintptr_t>(softmax) % 16 == 0) &&
192+
(total_elements % 4 == 0);
193+
194+
if (is_aligned && total_elements >= 1024) {
195+
// Use vectorized kernel for aligned, large problems
196+
constexpr int VEC_SIZE = 4;
197+
const int threads_per_block = 256;
198+
const int vec_elements = total_elements / VEC_SIZE;
199+
const int blocks =
200+
(vec_elements + threads_per_block - 1) / threads_per_block;
201+
202+
SoftmaxWithCrossEntropyGradHardLabelVectorized<T, LabelT, LogitT>
203+
<<<blocks, threads_per_block, 0, stream>>>(
204+
logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);
205+
} else {
206+
// Use warp-specialized kernel for irregular sizes
207+
const int warps_per_block = 4;
208+
const int threads_per_block = warps_per_block * 32;
209+
const int blocks =
210+
std::min(1024,
211+
static_cast<int>((total_elements + threads_per_block - 1) /
212+
threads_per_block));
213+
214+
SoftmaxWithCrossEntropyGradHardLabelWarp<T, LabelT, LogitT>
215+
<<<blocks, threads_per_block, 0, stream>>>(
216+
logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);
217+
}
218+
}
219+
220+
template <typename T, typename LabelT>
221+
void CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel(
222+
const GPUContext& dev_ctx,
223+
const DenseTensor& label,
224+
const DenseTensor& softmax,
225+
const DenseTensor& loss_grad,
226+
int axis,
227+
DenseTensor* logits_grad) {
228+
// PADDLE_ENFORCE_EQ(
229+
// dev_ctx.GetPlace().GetType(),
230+
// phi::AllocationType::GPU,
231+
// common::errors::Unavailable("softmax_with_cross_entropy operator's "
232+
// "CUDA kernel only runs on GPU device."));
233+
234+
using LogitT = phi::bfloat16;
235+
const T* loss_grad_data = loss_grad.data<T>();
236+
DenseTensor* logit_grad = logits_grad;
237+
238+
LogitT* logit_grad_data = nullptr;
239+
logit_grad_data = dev_ctx.template Alloc<LogitT>(logit_grad);
240+
241+
const int rank = logit_grad->dims().size();
242+
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
243+
int axis_dim = logit_grad->dims()[axis_v];
244+
245+
const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
246+
const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
247+
const int64_t remain = d / axis_dim;
248+
249+
const T* softmax_data = softmax.data<T>();
250+
const auto* label_data = label.data<LabelT>();
251+
252+
// Launch optimized kernel with automatic selection
253+
LaunchOptimizedCrossEntropyGradKernel<T, LabelT, LogitT>(dev_ctx,
254+
logit_grad_data,
255+
loss_grad_data,
256+
softmax_data,
257+
label_data,
258+
n,
259+
axis_dim,
260+
remain,
261+
-100);
262+
}
263+
264+
template <typename T, typename Context>
265+
void CrossEntropyWithSoftmaxBwdWithDowncastKernel(const Context& dev_ctx,
266+
const DenseTensor& label,
267+
const DenseTensor& softmax,
268+
const DenseTensor& loss_grad,
269+
DenseTensor* logits_grad) {
270+
constexpr int axis = -1;
271+
if (logits_grad->numel() == 0) {
272+
dev_ctx.template Alloc<phi::bfloat16>(logits_grad);
273+
return;
274+
}
275+
auto dtype = label.dtype();
276+
PD_VISIT_INTEGRAL_TYPES(
277+
dtype, "CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel", ([&] {
278+
CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel<T, data_t>(
279+
dev_ctx, label, softmax, loss_grad, axis, logits_grad);
280+
}));
281+
}
282+
283+
} // namespace phi
284+
285+
PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_bwd_w_downcast,
286+
metax_gpu,
287+
ALL_LAYOUT,
288+
phi::CrossEntropyWithSoftmaxBwdWithDowncastKernel,
289+
float,
290+
double,
291+
phi::float16) {}

0 commit comments

Comments
 (0)