diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt index 3b74ae39c18..5930eaaebd2 100755 --- a/backends/metax_gpu/CMakeLists.txt +++ b/backends/metax_gpu/CMakeLists.txt @@ -535,6 +535,7 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/clip_by_norm_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/uniform_random_batch_size_like_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/get_tensor_from_selected_rows_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_kernel.cc ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/empty_kernel.cc @@ -642,6 +643,8 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/lars_momentum_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_sum_kernel.cu # ############################################################################ ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu # kernels/kps diff --git a/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu b/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu new file mode 100644 index 00000000000..df4105efbd2 --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu @@ -0,0 +1,41 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/selected_rows_functor.h" +#include "paddle/phi/kernels/selected_rows/adam_kernel.h" + +PD_CUSTOM_KERNEL_REGISTER(adam_dense_param_sparse_grad, + metax_gpu, + ALL_LAYOUT, + phi::sr::AdamDenseParamSparseGradKernel, + float, + double, + phi::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND); + + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32); + } + kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED); + kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED); +} diff --git a/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu index 444928af78f..0f613b55e9e 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu @@ -23,10 +23,10 @@ PD_CUSTOM_KERNEL_REGISTER(einsum, phi::EinsumKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} PD_CUSTOM_KERNEL_REGISTER(einsum_infer, metax_gpu, @@ -34,7 +34,7 @@ PD_CUSTOM_KERNEL_REGISTER(einsum_infer, phi::EinsumInferKernel, float, double, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} diff --git a/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu new file mode 100644 index 00000000000..5647c806bfd --- /dev/null +++ b/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/lars_momentum_kernel.h" + +PD_CUSTOM_KERNEL_REGISTER(lars_momentum, + metax_gpu, + ALL_LAYOUT, + phi::LarsMomentumKernel, + float, + double, + phi::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + } +} diff --git a/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu index 1f84b628e84..dc92b2c6d69 100755 --- a/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu @@ -23,11 +23,13 @@ PD_CUSTOM_KERNEL_REGISTER(nonzero, int64_t, int, int16_t, - phi::dtype::float16, - phi::dtype::bfloat16, + phi::float16, + phi::bfloat16, bool, float, - double) { + double, + phi::complex64, + phi::complex128) { kernel->OutputAt(0).SetDataType(phi::DataType::INT64); } diff --git a/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu index 8ff1f5959ab..ca93a8ca079 100644 --- a/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu +++ b/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu @@ -23,6 +23,8 @@ PD_CUSTOM_KERNEL_REGISTER(put_along_axis, float, double, int64_t, + uint8_t, + int16_t, int, - phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::float16, + phi::bfloat16) {} diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch index beefb730bf7..4c06609338c 100755 --- a/backends/metax_gpu/patch/paddle.patch +++ b/backends/metax_gpu/patch/paddle.patch @@ -869,19 +869,6 @@ index e838778952..83e805e75a 100644 namespace phi { namespace fusion { -diff --git a/paddle/phi/kernels/gpu/correlation_kernel.cu b/paddle/phi/kernels/gpu/correlation_kernel.cu -index 4c93778bde..c7bdf8a2cc 100644 ---- a/paddle/phi/kernels/gpu/correlation_kernel.cu -+++ b/paddle/phi/kernels/gpu/correlation_kernel.cu -@@ -103,7 +103,7 @@ void CorrelationCUDAKernel(const Context &dev_ctx, - int stride2, - int corr_type_multiply, - DenseTensor *out) { -- bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU; -+ bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM; - PADDLE_ENFORCE_EQ( - is_gpu_place, - true, diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index f0cca0f701..02ea957240 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h @@ -897,19 +884,6 @@ index f0cca0f701..02ea957240 100644 namespace phi { // To determine use cudnn or not. -diff --git a/paddle/phi/kernels/gpu/dgc_kernel.cu b/paddle/phi/kernels/gpu/dgc_kernel.cu -index c2ddfa1347..c6adf5a6de 100644 ---- a/paddle/phi/kernels/gpu/dgc_kernel.cu -+++ b/paddle/phi/kernels/gpu/dgc_kernel.cu -@@ -188,7 +188,7 @@ void DGCKernel(const Context& dev_ctx, - int buf_size = paddle::communication::dgc::get_buffer_size(k); - phi::Allocator::AllocationPtr tmp_ious_data; - #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -- if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { -+ if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { - tmp_ious_data = phi::memory_utils::Alloc( - dev_ctx.GetPlace(), - buf_size, diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h index 29fa252e96..4ae72b0935 100644 --- a/paddle/phi/kernels/gpu/gelu_funcs.h @@ -974,19 +948,6 @@ index 1bdbe1564c..f753b54bc6 100644 #include "paddle/phi/kernels/impl/qr_kernel_impl.h" #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h" #include "paddle/phi/kernels/lstsq_kernel.h" -diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu -index 05a977828f..5136608c41 100644 ---- a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu -+++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu -@@ -58,7 +58,7 @@ void ShuffleBatchKernel(const Context& dev_ctx, - int64_t seed_int = 0; - if (seed.initialized()) { - const auto& seed_place = seed.place().GetType(); -- bool is_gpu_place = seed_place == phi::AllocationType::GPU; -+ bool is_gpu_place = seed_place == phi::AllocationType::GPU || seed_place == phi::AllocationType::CUSTOM; - if (is_gpu_place) { - // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would - // not be CUDAPlace in practice. This case would only happen in Python diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h index 9bc5326c90..79b57a8203 100644 --- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h @@ -1144,32 +1105,6 @@ index 6f03f76eeb..5fe2c3e7dc 100644 #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h" -diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h -index 7b85903776..3f4b298807 100644 ---- a/paddle/phi/kernels/impl/merged_momentum_impl.h -+++ b/paddle/phi/kernels/impl/merged_momentum_impl.h -@@ -297,7 +297,7 @@ void MergedMomentumInnerCompute( - params_out[idx], - velocities_out[idx]); - VLOG(10) << "Launch MergedMomentum cpu kernel."; -- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { -+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { - phi::funcs::ForRange for_range( - static_cast(dev_ctx), params[idx]->numel()); - const auto grad_type = grads[idx]->dtype(); -diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h -index de5bcfc30b..eb2a9714f5 100644 ---- a/paddle/phi/kernels/impl/momentum_kernel_impl.h -+++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h -@@ -457,7 +457,7 @@ void MomentumDenseImpl(const Context& dev_ctx, - regularization_coeff, - param_out, - velocity_out); -- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) { -+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) { - funcs::ForRange for_range(dev_ctx, param.numel()); - const auto grad_type = grad.dtype(); - #define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \ diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h index 4099d8b506..baef2cd643 100644 --- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h