Merged
156 commits
fd28881
[Metax_change_ut]
duqimeng Jul 23, 2025
a9d2aa7
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 24, 2025
1695f36
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Jul 31, 2025
b931d38
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 1, 2025
bef21bf
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 8, 2025
f4e5004
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 13, 2025
55422eb
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 18, 2025
815a63a
Merge branch 'PaddlePaddle:develop' into develop
duqimeng Aug 19, 2025
1739a15
fix sum&collect_fpn_proposals op register
StareAtYou Aug 19, 2025
af0bae5
fix sum&collect_fpn_proposals op register
metax666 Aug 19, 2025
be61f06
modify profile
jxwangmetax Aug 20, 2025
0fc2dd1
modify profile
metax666 Aug 20, 2025
1ad95c5
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 20, 2025
f12b3e4
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 21, 2025
789c9fc
[Metax] fix paddle bug replace 'MoeGradDispatchKernel' to 'MoeGateDis…
StareAtYou Aug 21, 2025
a0116fb
[Metax] fix paddle bug
metax666 Aug 21, 2025
a2da5e0
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 22, 2025
f9e6d2c
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
StareAtYou Aug 22, 2025
4b4f562
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Aug 22, 2025
662e22e
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
3e8d6ce
Merge branch 'metax666:develop' into develop
StareAtYou Aug 25, 2025
9dae9b7
[Metax] register bce_loss_grad & bce_loss & index_add_grad kernels
metax666 Aug 25, 2025
47fef62
blas handle support
jxwangmetax Aug 25, 2025
266c0df
blas handle support
metax666 Aug 25, 2025
a0b340b
[Metax] register some kernels & update CMakeLists
StareAtYou Aug 25, 2025
aa9bd35
Merge branch 'metax666:develop' into develop
StareAtYou Aug 26, 2025
8c6ac05
[Metax] register some kernels & update CMakeLists
metax666 Aug 26, 2025
9510f7d
Merge branch 'metax666:develop' into develop
duqimeng Aug 26, 2025
fa7cc1a
[Metax] fix metax unittest fail
StareAtYou Aug 26, 2025
a907545
[Metax] fix metax unittest fail
metax666 Aug 26, 2025
7a6312e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
StareAtYou Aug 26, 2025
90bb94e
[Metax] add group_norm & label_smooth kernel and update matmul kernel
metax666 Aug 27, 2025
9f130fe
[Metax] fix rmsprop kernel register and add meshgrid & meshgrid_grad …
StareAtYou Aug 27, 2025
ca38fb5
Merge branch 'metax666:develop' into develop
StareAtYou Aug 27, 2025
f0cc1e0
add test
zhang-chenyi Aug 27, 2025
8e8b732
add test
zhang-chenyi Aug 27, 2025
8d7efbd
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 27, 2025
28c992b
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
d3470bb
[test] chang the logic of workspace_host in cholesky_kernel_register
zhang-chenyi Aug 27, 2025
db17ebf
Merge branch 'develop' of https://github.com/zhang-chenyi/PaddleCusto…
zhang-chenyi Aug 27, 2025
83bc87f
[Metax] fix compile fail
StareAtYou Aug 27, 2025
f1e8d0c
Revert "[Metax] fix compile fail"
StareAtYou Aug 27, 2025
a13daa8
[Metax] fix compile fail by 'conv_transpose_grad_kernel_impl.h'
StareAtYou Aug 27, 2025
95a179b
[Metax] fix bug & add some kernel register
metax666 Aug 28, 2025
4576ef4
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
ca51a1e
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
7789e9b
[Metax] con2d_grad use gpudnn
duqimeng Aug 22, 2025
afd0863
[Metax]fix bug and add qr lstsq logsoftmax
duqimeng Aug 28, 2025
6da0f0d
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
e1e07ba
[Metax] change_patch
duqimeng Aug 28, 2025
046637c
[Metax] change_patch
metax666 Aug 28, 2025
c27b492
Merge branch 'PaddlePaddle:develop' into develop
metax666 Aug 28, 2025
05ecd9d
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
b1bf7e8
[Metax] update unit test CMakeLists.txt
StareAtYou Aug 28, 2025
f90d585
Merge branch 'metax666:develop' into develop
StareAtYou Aug 28, 2025
874d9b6
Merge branch 'metax666:develop' into develop
zhang-chenyi Aug 28, 2025
0ca02b9
[feature] add unique_consecutive kernel
zhang-chenyi Aug 28, 2025
40d8f21
[metax-feature] add kernel for test_math_op_patch_var_base
metax666 Aug 28, 2025
3e9b526
[metax] add some kernel
duqimeng Aug 28, 2025
8911576
[metax] add some kernel
duqimeng Aug 28, 2025
8471597
Merge branch 'metax666:develop' into develop
duqimeng Aug 28, 2025
0758887
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
61be33d
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
2fe962e
[Metax] register baddbmm kernel & update blas api
StareAtYou Aug 29, 2025
531fedb
Merge branch 'metax666:develop' into develop
StareAtYou Aug 29, 2025
c0dcfff
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
bd65451
[feature] add add unique_consecutive kernel.cu
zhang-chenyi Aug 29, 2025
0def63d
[fix] fix some test case due to missing op register
zhang-chenyi Aug 29, 2025
e503c9e
[fix] fix some fail text
zhang-chenyi Aug 29, 2025
9844878
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
70b86e7
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
1e90757
add and fix some kernels
1184319564 Aug 30, 2025
f93307d
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
StareAtYou Aug 29, 2025
c4b0eb9
[Metax] fix conflict
StareAtYou Sep 1, 2025
06dda18
[Metax] fix conflict
StareAtYou Sep 1, 2025
dae6ce8
[Metax] adapt to paddle-cpu-20250901 & resolve the issue of 'test_ele…
StareAtYou Sep 1, 2025
b4a5c62
[Metax] update repeat_interleave kernel & ignore max op test
StareAtYou Sep 2, 2025
7cf4405
Merge branch 'metax666:develop' into develop
StareAtYou Sep 2, 2025
0015f2e
[Metax] register deformable_conv kernel & fix 'ModulatedDeformableCol…
metax666 Sep 2, 2025
fc2c0f5
Merge branch 'metax666:develop' into develop
duqimeng Sep 2, 2025
829c3b6
Merge dev
duqimeng Sep 2, 2025
3104a9c
【metax】add and fix some kernels
metax666 Sep 2, 2025
175cca6
[metax]fix lu eigvalshsqueeze rnn kernel
metax666 Sep 2, 2025
c7db810
[metax]fix lu eigvalshsqueeze rnn kernel
duqimeng Aug 29, 2025
f5813ed
[metax] chang patch fix copy
duqimeng Sep 2, 2025
6f0b705
[metax] chang patch fix copy
duqimeng Sep 2, 2025
8f47f0e
[metax] chang patch fix copy
metax666 Sep 2, 2025
b420f97
[Metax] update metax_gpu unit test
StareAtYou Sep 2, 2025
c08533e
[Metax] update metax_gpu unit test
metax666 Sep 2, 2025
414715f
[Metax] fix test CMakeList.txt
StareAtYou Sep 2, 2025
aa6b5bf
[Metax] fix test CMakeList.txt
metax666 Sep 2, 2025
0bfc6e7
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
cb93f6a
[metax]change_cupti_and_fix_softmax
duqimeng Sep 9, 2025
2e99f62
[metax]change_patch
duqimeng Sep 9, 2025
026551a
[metax]change_patch
duqimeng Sep 9, 2025
b09babb
Merge branch 'metax666:develop' into develop
duqimeng Sep 9, 2025
31594f8
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
4fb467c
[metax] updata_qr_kernel
duqimeng Sep 11, 2025
5dc60a3
Merge branch 'metax666:develop' into develop
duqimeng Sep 11, 2025
e4fd192
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
471b184
[Metax] fix cufft and fix some blas kernel apply
duqimeng Sep 15, 2025
a0d237c
Merge branch 'metax666:develop' into develop
duqimeng Sep 15, 2025
4c86266
[metax] fix bug
duqimeng Sep 15, 2025
a8b4696
[Metax] add github action
duqimeng Sep 16, 2025
8dff471
[metax]chaneg build
duqimeng Sep 16, 2025
ee4eefd
[metax]chaneg build
duqimeng Sep 16, 2025
8a36c4c
[metax]chaneg build
duqimeng Sep 16, 2025
bd5ac4d
Merge branch 'develop' into develop
duqimeng Sep 16, 2025
656d684
[metax]chaneg build
duqimeng Sep 16, 2025
2c224ad
[metax]chaneg build
duqimeng Sep 16, 2025
4c65070
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 16, 2025
a7f6ed7
[metax]chaneg build
duqimeng Sep 16, 2025
9bfec7e
Merge branch 'develop' into develop
metax666 Sep 16, 2025
00014e2
[metax]chaneg build
duqimeng Sep 16, 2025
25e76dc
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 16, 2025
e95cc2c
Merge branch 'metax666:develop' into develop
duqimeng Sep 16, 2025
a7f53dd
Merge branch 'metax666:develop' into develop
duqimeng Sep 16, 2025
6ada0e9
[metax]fix_code style and index_elementwise_put_kernel
duqimeng Sep 17, 2025
3834990
[metax]change_build
duqimeng Sep 17, 2025
77ebcb8
[metax]change_build
duqimeng Sep 17, 2025
19c9184
Merge branch 'develop' into develop
metax666 Sep 17, 2025
4339ed4
Merge branch 'metax666:develop' into develop
duqimeng Sep 17, 2025
44532ba
change_metax_work
duqimeng Sep 17, 2025
02047f9
change_metax_work
duqimeng Sep 17, 2025
bda901e
change_metax_work
duqimeng Sep 17, 2025
1c7d32a
change_metax_work
duqimeng Sep 17, 2025
ed8f128
Merge branch 'develop' into develop
metax666 Sep 17, 2025
287691f
Merge branch 'metax666:develop' into develop
duqimeng Sep 17, 2025
976ecec
change_metax_work
duqimeng Sep 17, 2025
0c6ebe2
change_warpctc.cmake
duqimeng Sep 18, 2025
5e7a84b
change warpctc.cmake
duqimeng Sep 18, 2025
542efeb
test
duqimeng Sep 18, 2025
40daeb9
change_run_ut
duqimeng Sep 18, 2025
4c21a9c
Merge branch 'develop' into develop
metax666 Sep 18, 2025
322dc15
remove_tets
duqimeng Sep 18, 2025
0e4b75d
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 18, 2025
bd106bd
Merge branch 'metax666:develop' into develop
duqimeng Sep 18, 2025
7dbab02
test
duqimeng Sep 18, 2025
27ebafe
Merge branch 'develop' of https://github.com/duqimeng/PaddleCustomDev…
duqimeng Sep 18, 2025
bd39be9
Merge branch 'metax666:develop' into develop
duqimeng Sep 19, 2025
f79b1bd
add_generate_pb
duqimeng Sep 19, 2025
f8a0cca
Merge branch 'metax666:develop' into develop
duqimeng Sep 22, 2025
37aa236
Merge branch 'metax666:develop' into develop
duqimeng Sep 22, 2025
6f925da
Merge branch 'metax666:develop' into develop
duqimeng Sep 23, 2025
e08b161
[metax]fix paddle bug
duqimeng Sep 23, 2025
9404022
Merge branch 'metax666:develop' into develop
duqimeng Sep 23, 2025
1a0a84e
change_ut
duqimeng Sep 23, 2025
ece9f09
change_ut
duqimeng Sep 23, 2025
d1d25ad
change_ut
duqimeng Sep 24, 2025
8ff82b6
Merge branch 'metax666:develop' into develop
duqimeng Sep 24, 2025
bfdf3da
Merge branch 'metax666:develop' into develop
duqimeng Sep 26, 2025
be4aeff
Merge branch 'metax666:develop' into develop
duqimeng Sep 26, 2025
d75ccc7
[metax]fix patch and fix missing kernel
duqimeng Sep 29, 2025
b6b8778
Merge branch 'metax666:develop' into develop
duqimeng Sep 30, 2025
901d3db
[metax] link mccl and fix missing kernel
duqimeng Sep 30, 2025
a561f35
[metax] rename yaml file
duqimeng Sep 30, 2025
2 changes: 1 addition & 1 deletion .github/workflows/metax_work.yaml
@@ -1,4 +1,4 @@
name: padlle metax gpu test
name: paddle metax gpu test

on:
workflow_dispatch:
7 changes: 7 additions & 0 deletions backends/metax_gpu/CMakeLists.txt
@@ -326,6 +326,8 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/cross_entropy_bwd_w_downcast.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
@@ -728,6 +730,11 @@ target_link_libraries(
${WARPCTC_LIBRARIES}
${WARPRNNT_LIBRARIES}
${PADDLE_CORE_LIB})

target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmccl.so)
target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcFlashAttn.so)
target_link_libraries(${TARGET_NAME} /opt/maca/lib/libmcpti.so)

include_directories(BEFORE ${PADDLE_SOURCE_DIR})

target_compile_definitions(
@@ -0,0 +1,291 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/cross_entropy_grad_kernel.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif

#include "kernels/gpudnn/softmax_gpudnn.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/visit_type.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/softmax.h"

namespace phi {

/*
Vectorized wrapper of softmax with cross entropy grad hard label.
Optimized with float4 vectorization for memory coalescing and improved
throughput.
*/
template <typename T, typename LabelT, typename LogitT>
__global__ void SoftmaxWithCrossEntropyGradHardLabelVectorized(
LogitT* __restrict__ logits_grad,
const T* __restrict__ loss_grad,
const T* __restrict__ softmax,
const LabelT* __restrict__ labels,
const int64_t n,
const int64_t dim,
const int64_t d,
const int ignore_index) {
// Vectorized load/store with float4 for 128-bit memory transactions
constexpr int VEC_SIZE = 4;
using VecT = typename phi::AlignedVector<LogitT, VEC_SIZE>;
using SoftmaxVecT = typename phi::AlignedVector<T, VEC_SIZE>;

int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
int64_t vec_id = tid * VEC_SIZE;
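  // Each thread owns VEC_SIZE (4) consecutive elements of the flattened
  // [n, dim, d] gradient tensor.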

// Ensure we don't exceed bounds
if (vec_id >= n * dim * d) return;

// Compute indices for vectorized access
int64_t idx_n = vec_id / (d * dim);
int64_t idx_dim_start = (vec_id / d) % dim;
int64_t idx_d = vec_id % d;
int64_t ids = idx_n * d + idx_d;

// Load label once per thread
auto lbl = static_cast<int64_t>(labels[ids]);

if (lbl == ignore_index) {
// Vectorized zero fill for ignore_index
VecT* vec_grad = reinterpret_cast<VecT*>(&logits_grad[vec_id]);
VecT zero_vec;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
zero_vec.val[i] = static_cast<LogitT>(0.0f);
}
*vec_grad = zero_vec;
return;
}

// Vectorized load of softmax values
SoftmaxVecT softmax_vec;
const SoftmaxVecT* softmax_ptr =
reinterpret_cast<const SoftmaxVecT*>(&softmax[vec_id]);
softmax_vec = *softmax_ptr;

// Load loss gradient (broadcast across vector elements)
T loss_grad_val = loss_grad[ids];

// Vectorized computation
VecT grad_vec;
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
int64_t current_dim = idx_dim_start + i;
if (current_dim < dim) { // Bounds check for partial vectors
float softmax_val = static_cast<float>(softmax_vec.val[i]);
float grad_val;

if (lbl == current_dim) {
grad_val = (softmax_val - 1.0f) * static_cast<float>(loss_grad_val);
} else {
grad_val = softmax_val * static_cast<float>(loss_grad_val);
}

grad_vec.val[i] = static_cast<LogitT>(grad_val);
} else {
grad_vec.val[i] = static_cast<LogitT>(0.0f);
}
}

// Vectorized store
VecT* grad_ptr = reinterpret_cast<VecT*>(&logits_grad[vec_id]);
*grad_ptr = grad_vec;
}

/*
Specialized kernel for dimensions not divisible by vector size
Uses warp-level primitives for better performance on irregular sizes
*/
template <typename T, typename LabelT, typename LogitT>
__global__ void SoftmaxWithCrossEntropyGradHardLabelWarp(
LogitT* __restrict__ logits_grad,
const T* __restrict__ loss_grad,
const T* __restrict__ softmax,
const LabelT* __restrict__ labels,
const int64_t n,
const int64_t dim,
const int64_t d,
const int ignore_index) {
const int warps_per_block = 4;
const int threads_per_warp = 32;
const int threads_per_block = warps_per_block * threads_per_warp;

int tid = blockIdx.x * threads_per_block + threadIdx.x;
int warp_id = threadIdx.x / threads_per_warp;
int lane_id = threadIdx.x % threads_per_warp;

// Process multiple elements per thread using warp-level parallelism
int64_t elements_per_thread =
(n * dim * d + gridDim.x * threads_per_block - 1) /
(gridDim.x * threads_per_block);

for (int e = 0; e < elements_per_thread; ++e) {
int64_t idx = tid + e * gridDim.x * threads_per_block;
if (idx >= n * dim * d) break;

int64_t idx_n = idx / (d * dim);
int64_t idx_dim = (idx / d) % dim;
int64_t idx_d = idx % d;
int64_t ids = idx_n * d + idx_d;

auto lbl = static_cast<int64_t>(labels[ids]);

if (lbl == ignore_index) {
logits_grad[idx] = static_cast<LogitT>(0.0f);
} else if (lbl == idx_dim) {
logits_grad[idx] =
static_cast<LogitT>((static_cast<float>(softmax[idx]) - 1.0f) *
static_cast<float>(loss_grad[ids]));
} else {
logits_grad[idx] =
static_cast<LogitT>(static_cast<float>(softmax[idx]) *
static_cast<float>(loss_grad[ids]));
}
}
}

/*
Optimized kernel selector based on problem size and alignment
*/
template <typename T, typename LabelT, typename LogitT>
void LaunchOptimizedCrossEntropyGradKernel(const GPUContext& dev_ctx,
LogitT* logits_grad,
const T* loss_grad,
const T* softmax,
const LabelT* labels,
const int64_t n,
const int64_t dim,
const int64_t d,
const int ignore_index) {
const int64_t total_elements = n * dim * d;
auto stream = dev_ctx.stream();

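  // Kernel selection note: the float4 path requires both tensor pointers to be
  // 16-byte aligned and the total element count to be a multiple of 4, so that
  // every 128-bit load/store stays in bounds; otherwise the scalar warp-based
  // fallback below is used.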
// Check alignment for vectorized kernel
bool is_aligned = (reinterpret_cast<uintptr_t>(logits_grad) % 16 == 0) &&
(reinterpret_cast<uintptr_t>(softmax) % 16 == 0) &&
(total_elements % 4 == 0);

if (is_aligned && total_elements >= 1024) {
// Use vectorized kernel for aligned, large problems
constexpr int VEC_SIZE = 4;
const int threads_per_block = 256;
const int vec_elements = total_elements / VEC_SIZE;
const int blocks =
(vec_elements + threads_per_block - 1) / threads_per_block;

SoftmaxWithCrossEntropyGradHardLabelVectorized<T, LabelT, LogitT>
<<<blocks, threads_per_block, 0, stream>>>(
logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);
} else {
// Use warp-specialized kernel for irregular sizes
const int warps_per_block = 4;
const int threads_per_block = warps_per_block * 32;
const int blocks =
std::min(1024,
static_cast<int>((total_elements + threads_per_block - 1) /
threads_per_block));

SoftmaxWithCrossEntropyGradHardLabelWarp<T, LabelT, LogitT>
<<<blocks, threads_per_block, 0, stream>>>(
logits_grad, loss_grad, softmax, labels, n, dim, d, ignore_index);
}
}

template <typename T, typename LabelT>
void CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel(
const GPUContext& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
int axis,
DenseTensor* logits_grad) {
// PADDLE_ENFORCE_EQ(
// dev_ctx.GetPlace().GetType(),
// phi::AllocationType::GPU,
// common::errors::Unavailable("softmax_with_cross_entropy operator's "
// "CUDA kernel only runs on GPU device."));

using LogitT = phi::bfloat16;
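  // "Downcast" in the kernel name: softmax and loss_grad are read as T, while
  // the logits gradient is written directly in bfloat16.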
const T* loss_grad_data = loss_grad.data<T>();
DenseTensor* logit_grad = logits_grad;

LogitT* logit_grad_data = nullptr;
logit_grad_data = dev_ctx.template Alloc<LogitT>(logit_grad);

const int rank = logit_grad->dims().size();
const int axis_v = phi::funcs::CanonicalAxis(axis, rank);
int axis_dim = logit_grad->dims()[axis_v];

const int64_t n = phi::funcs::SizeToAxis(axis_v, logit_grad->dims());
const int64_t d = phi::funcs::SizeFromAxis(axis_v, logit_grad->dims());
const int64_t remain = d / axis_dim;
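  // Shapes: n = product of dims before the class axis, axis_dim = number of
  // classes, remain = product of dims after the class axis (inner stride).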

const T* softmax_data = softmax.data<T>();
const auto* label_data = label.data<LabelT>();

// Launch optimized kernel with automatic selection
LaunchOptimizedCrossEntropyGradKernel<T, LabelT, LogitT>(dev_ctx,
logit_grad_data,
loss_grad_data,
softmax_data,
label_data,
n,
axis_dim,
remain,
-100);
}

template <typename T, typename Context>
void CrossEntropyWithSoftmaxBwdWithDowncastKernel(const Context& dev_ctx,
const DenseTensor& label,
const DenseTensor& softmax,
const DenseTensor& loss_grad,
DenseTensor* logits_grad) {
constexpr int axis = -1;
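  // The gradient is always taken along the last axis; the label dtype is
  // dispatched at runtime via PD_VISIT_INTEGRAL_TYPES below.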
if (logits_grad->numel() == 0) {
dev_ctx.template Alloc<phi::bfloat16>(logits_grad);
return;
}
auto dtype = label.dtype();
PD_VISIT_INTEGRAL_TYPES(
dtype, "CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel", ([&] {
CrossEntropyWithSoftmaxBwdWithDowncastGPUKernel<T, data_t>(
dev_ctx, label, softmax, loss_grad, axis, logits_grad);
}));
}

} // namespace phi

PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_bwd_w_downcast,
metax_gpu,
ALL_LAYOUT,
phi::CrossEntropyWithSoftmaxBwdWithDowncastKernel,
float,
double,
phi::float16) {}
@@ -0,0 +1,27 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
#include "paddle/phi/kernels/gpu/embedding_grad_add_to_kernel.cu" // NOLINT

PD_CUSTOM_KERNEL_REGISTER(embedding_grad_add_to,
metax_gpu,
ALL_LAYOUT,
phi::EmbeddingGradAddToAddToKernel,
float,
double,
phi::float16,
phi::bfloat16) {}
@@ -0,0 +1,25 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/legacy/gpu/moe_combine_no_weight_grad_kernel.cu" // NOLINT

PD_CUSTOM_KERNEL_REGISTER(moe_combine_no_weight_grad,
metax_gpu,
ALL_LAYOUT,
phi::MoeCombineNoWeightGradKernel,
float,
double,
phi::bfloat16,
phi::float16) {}