diff --git a/Paddle b/Paddle
index a8b4de5f626..b51d97ff7ff 160000
--- a/Paddle
+++ b/Paddle
@@ -1 +1 @@
-Subproject commit a8b4de5f6260e598d6426f7778364d1277b2ad76
+Subproject commit b51d97ff7ff0bdac6a16380ee90100b787979b05
diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt
index bca1ce7aad4..5930eaaebd2 100755
--- a/backends/metax_gpu/CMakeLists.txt
+++ b/backends/metax_gpu/CMakeLists.txt
@@ -326,7 +326,7 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu
-  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
+  # ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
@@ -535,6 +535,7 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/clip_by_norm_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/uniform_random_batch_size_like_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/get_tensor_from_selected_rows_kernel.cu
+  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_kernel.cc
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/empty_kernel.cc
@@ -642,6 +643,8 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu
+  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/lars_momentum_kernel.cu
+  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_sum_kernel.cu
   # ############################################################################
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
   # kernels/kps
diff --git a/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu b/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu
new file mode 100644
index 00000000000..df4105efbd2
--- /dev/null
+++ b/backends/metax_gpu/kernels/cuda_kernels/adam_kernel_selected_rows.cu
@@ -0,0 +1,41 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
+#include "paddle/phi/kernels/selected_rows/adam_kernel.h"
+
+PD_CUSTOM_KERNEL_REGISTER(adam_dense_param_sparse_grad,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::sr::AdamDenseParamSparseGradKernel,
+                          float,
+                          double,
+                          phi::float16) {
+  // Skip beta1_pow, beta2_pow, skip_update data transform
+  kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
+  kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);
+
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
+  }
+  kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
+  kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
+}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu
index 444928af78f..0f613b55e9e 100644
--- a/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/einsum_kernel_register.cu
@@ -23,10 +23,10 @@ PD_CUSTOM_KERNEL_REGISTER(einsum,
                           phi::EinsumKernel,
                           float,
                           double,
-                          phi::dtype::float16,
-                          phi::dtype::bfloat16,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::float16,
+                          phi::bfloat16,
+                          phi::complex64,
+                          phi::complex128) {}
 
 PD_CUSTOM_KERNEL_REGISTER(einsum_infer,
                           metax_gpu,
@@ -34,7 +34,7 @@ PD_CUSTOM_KERNEL_REGISTER(einsum_infer,
                           phi::EinsumInferKernel,
                           float,
                           double,
-                          phi::dtype::float16,
-                          phi::dtype::bfloat16,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::float16,
+                          phi::bfloat16,
+                          phi::complex64,
+                          phi::complex128) {}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/index_elementwise_get_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/index_elementwise_get_kernel_register.cu
index 5ab3d2a3170..a45a740fc61 100644
--- a/backends/metax_gpu/kernels/cuda_kernels/index_elementwise_get_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/index_elementwise_get_kernel_register.cu
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/index_elementwise_get_kernel.h"
+#include "paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu"  // NOLINT
 
 PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
                           metax_gpu,
@@ -27,7 +27,7 @@ PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
                           int64_t,
                           int16_t,
                           uint8_t,
-                          phi::dtype::float16,
-                          phi::dtype::bfloat16,
-                          phi::dtype::complex<float>,
-                          phi::dtype::complex<double>) {}
+                          phi::float16,
+                          phi::bfloat16,
+                          phi::complex64,
+                          phi::complex128) {}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu
new file mode 100644
index 00000000000..5647c806bfd
--- /dev/null
+++ b/backends/metax_gpu/kernels/cuda_kernels/lars_momentum_kernel_register.cu
@@ -0,0 +1,29 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/lars_momentum_kernel.h"
+
+PD_CUSTOM_KERNEL_REGISTER(lars_momentum,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::LarsMomentumKernel,
+                          float,
+                          double,
+                          phi::float16) {
+  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
+    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
+    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
+  }
+}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/multinomial_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/multinomial_kernel_register.cu
index 622e70728f1..1325fa339b0 100644
--- a/backends/metax_gpu/kernels/cuda_kernels/multinomial_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/multinomial_kernel_register.cu
@@ -21,6 +21,7 @@ PD_CUSTOM_KERNEL_REGISTER(multinomial,
                           phi::MultinomialKernel,
                           phi::dtype::float16,
                           phi::dtype::bfloat16,
-                          float) {
+                          float,
+                          double) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
diff --git a/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu
index 1f84b628e84..dc92b2c6d69 100755
--- a/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/nonzero_kernel_register.cu
@@ -23,11 +23,13 @@ PD_CUSTOM_KERNEL_REGISTER(nonzero,
                           int64_t,
                           int,
                           int16_t,
-                          phi::dtype::float16,
-                          phi::dtype::bfloat16,
+                          phi::float16,
+                          phi::bfloat16,
                           bool,
                           float,
-                          double) {
+                          double,
+                          phi::complex64,
+                          phi::complex128) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
 
diff --git a/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu
index 8ff1f5959ab..ca93a8ca079 100644
--- a/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/put_along_axis_kernel_register.cu
@@ -23,6 +23,8 @@ PD_CUSTOM_KERNEL_REGISTER(put_along_axis,
                           float,
                           double,
                           int64_t,
+                          uint8_t,
+                          int16_t,
                           int,
-                          phi::dtype::float16,
-                          phi::dtype::bfloat16) {}
+                          phi::float16,
+                          phi::bfloat16) {}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/take_along_axis_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/take_along_axis_kernel_register.cu
index 4b23b0820fc..b628552aaaf 100644
--- a/backends/metax_gpu/kernels/cuda_kernels/take_along_axis_kernel_register.cu
+++ b/backends/metax_gpu/kernels/cuda_kernels/take_along_axis_kernel_register.cu
@@ -25,4 +25,7 @@ PD_CUSTOM_KERNEL_REGISTER(take_along_axis,
                           int64_t,
                           int,
                           phi::dtype::float16,
-                          phi::dtype::bfloat16) {}
+                          phi::dtype::bfloat16,
+                          uint8_t,  // support uint8
+                          int16_t   // support int16
+) {}
diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_conv2d_add_act_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_conv2d_add_act_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_conv2d_add_act_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_conv2d_add_act_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_rope_grad_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_rope_grad_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_rope_grad_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_rope_grad_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/fused_rope_kernel_register.cu b/backends/metax_gpu/kernels/fusion/fused_rope_kernel_register.cu
similarity index 100%
rename from backends/metax_gpu/kernels/metax_kernel/fused_rope_kernel_register.cu
rename to backends/metax_gpu/kernels/fusion/fused_rope_kernel_register.cu
diff --git a/backends/metax_gpu/kernels/metax_kernel/addmm_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/addmm_kernel_register.cu
index 287fa8de41a..ead21b1eb7e 100644
--- a/backends/metax_gpu/kernels/metax_kernel/addmm_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/addmm_kernel_register.cu
@@ -22,5 +22,6 @@ PD_REGISTER_PLUGIN_KERNEL(addmm,
                           ALL_LAYOUT,
                           phi::AddmmKernel,
                           float,
+                          double,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {}
diff --git a/backends/metax_gpu/kernels/metax_kernel/layer_norm_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/layer_norm_grad_kernel_register.cu
index 87c06dab2a4..857dcb6d522 100644
--- a/backends/metax_gpu/kernels/metax_kernel/layer_norm_grad_kernel_register.cu
+++ b/backends/metax_gpu/kernels/metax_kernel/layer_norm_grad_kernel_register.cu
@@ -115,6 +115,7 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
                           ALL_LAYOUT,
                           phi::LayerNormGradKernel,
                           float,
+                          double,
                           phi::dtype::float16,
                           phi::dtype::bfloat16) {
   if (kernel_key.dtype() == phi::DataType::FLOAT16) {
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
index efddba5f00b..0712fb75bbe 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.cc
@@ -15,24 +15,6 @@
 #include "kernels/metax_kernel/metax_context.h"
 
 namespace phi {
-const bool allow_tf32_cublas = []() -> bool {
-  const char* v = std::getenv("ALLOW_TF32_CUBLAS");
-  if (v) {
-    return std::atoi(v);
-  }
-  return true;
-}();
-
-const bool allow_tf32_cudnn = []() -> bool {
-  const char* v = std::getenv("ALLOW_TF32_CUDNN");
-  if (v) {
-    return std::atoi(v);
-  }
-  return false;
-}();
-
-bool AllowTF32Cublas() { return allow_tf32_cublas; }
-bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
 void DnnWorkspaceHandle::RunFuncSync(
     const std::function<void(void*)>& cudnn_func,
     size_t required_workspace_bytes,
@@ -42,19 +24,11 @@ void DnnWorkspaceHandle::RunFuncSync(
     void* workspace_ptr = nullptr;
     size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
     std::lock_guard<std::mutex> guard(*mtx_);
-#ifdef PADDLE_WITH_HIP
-    auto status = hipMalloc(&workspace_ptr, size);
-#else
     auto status = cudaMalloc(&workspace_ptr, size);
-#endif
     if (status == gpuSuccess) {
       cudnn_func(workspace_ptr);
       phi::backends::gpu::GpuStreamSync(stream_);
-#ifdef PADDLE_WITH_HIP
-      PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
-#else
       PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
-#endif
       return;
     }
   }
diff --git a/backends/metax_gpu/kernels/metax_kernel/metax_context.h b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
index 2d761439089..7386811a236 100644
--- a/backends/metax_gpu/kernels/metax_kernel/metax_context.h
+++ b/backends/metax_gpu/kernels/metax_kernel/metax_context.h
@@ -18,6 +18,7 @@
 #include
 
 #include "kernels/funcs/blas/cublasLt.h"
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/backends/custom/custom_context.h"
 #include "paddle/phi/backends/gpu/forwards.h"
 #include "paddle/phi/backends/gpu/gpu_decls.h"
@@ -30,8 +31,6 @@ cublasLtHandle_t GetBlasLtHandle();
 
 namespace phi {
 
-bool AllowTF32Cublas();
-bool AllowTF32Cudnn();
 class DnnWorkspaceHandle {
  public:
   inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
diff --git a/backends/metax_gpu/patch/paddle.patch b/backends/metax_gpu/patch/paddle.patch
index beefb730bf7..4c06609338c 100755
--- a/backends/metax_gpu/patch/paddle.patch
+++ b/backends/metax_gpu/patch/paddle.patch
@@ -869,19 +869,6 @@ index e838778952..83e805e75a 100644
 
 namespace phi {
 namespace fusion {
-diff --git a/paddle/phi/kernels/gpu/correlation_kernel.cu b/paddle/phi/kernels/gpu/correlation_kernel.cu
-index 4c93778bde..c7bdf8a2cc 100644
---- a/paddle/phi/kernels/gpu/correlation_kernel.cu
-+++ b/paddle/phi/kernels/gpu/correlation_kernel.cu
-@@ -103,7 +103,7 @@ void CorrelationCUDAKernel(const Context &dev_ctx,
-                            int stride2,
-                            int corr_type_multiply,
-                            DenseTensor *out) {
--  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
-+  bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM;
-   PADDLE_ENFORCE_EQ(
-       is_gpu_place,
-       true,
 diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h
 index f0cca0f701..02ea957240 100644
 --- a/paddle/phi/kernels/gpu/depthwise_conv.h
@@ -897,19 +884,6 @@ index f0cca0f701..02ea957240 100644
 
 namespace phi {
 // To determine use cudnn or not.
-diff --git a/paddle/phi/kernels/gpu/dgc_kernel.cu b/paddle/phi/kernels/gpu/dgc_kernel.cu
-index c2ddfa1347..c6adf5a6de 100644
---- a/paddle/phi/kernels/gpu/dgc_kernel.cu
-+++ b/paddle/phi/kernels/gpu/dgc_kernel.cu
-@@ -188,7 +188,7 @@ void DGCKernel(const Context& dev_ctx,
-   int buf_size = paddle::communication::dgc::get_buffer_size(k);
-   phi::Allocator::AllocationPtr tmp_ious_data;
- #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
--  if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
-+  if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
-     tmp_ious_data = phi::memory_utils::Alloc(
-         dev_ctx.GetPlace(),
-         buf_size,
 diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h
 index 29fa252e96..4ae72b0935 100644
 --- a/paddle/phi/kernels/gpu/gelu_funcs.h
@@ -974,19 +948,6 @@ index 1bdbe1564c..f753b54bc6 100644
 #include "paddle/phi/kernels/impl/qr_kernel_impl.h"
 #include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
 #include "paddle/phi/kernels/lstsq_kernel.h"
-diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
-index 05a977828f..5136608c41 100644
---- a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
-+++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
-@@ -58,7 +58,7 @@ void ShuffleBatchKernel(const Context& dev_ctx,
-   int64_t seed_int = 0;
-   if (seed.initialized()) {
-     const auto& seed_place = seed.place().GetType();
--    bool is_gpu_place = seed_place == phi::AllocationType::GPU;
-+    bool is_gpu_place = seed_place == phi::AllocationType::GPU || seed_place == phi::AllocationType::CUSTOM;
-     if (is_gpu_place) {
-       // NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
-       // not be CUDAPlace in practice. This case would only happen in Python
 diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
 index 9bc5326c90..79b57a8203 100644
 --- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -1144,32 +1105,6 @@ index 6f03f76eeb..5fe2c3e7dc 100644
 #include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/matrix_inverse.h"
 
-diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h
-index 7b85903776..3f4b298807 100644
---- a/paddle/phi/kernels/impl/merged_momentum_impl.h
-+++ b/paddle/phi/kernels/impl/merged_momentum_impl.h
-@@ -297,7 +297,7 @@ void MergedMomentumInnerCompute(
-           params_out[idx],
-           velocities_out[idx]);
-       VLOG(10) << "Launch MergedMomentum cpu kernel.";
--    } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
-+    } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
-       phi::funcs::ForRange<Context> for_range(
-           static_cast<const Context&>(dev_ctx), params[idx]->numel());
-       const auto grad_type = grads[idx]->dtype();
-diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h
-index de5bcfc30b..eb2a9714f5 100644
---- a/paddle/phi/kernels/impl/momentum_kernel_impl.h
-+++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h
-@@ -457,7 +457,7 @@ void MomentumDenseImpl(const Context& dev_ctx,
-                        regularization_coeff,
-                        param_out,
-                        velocity_out);
--  } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
-+  } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
-     funcs::ForRange<Context> for_range(dev_ctx, param.numel());
-     const auto grad_type = grad.dtype();
- #define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
 diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
 index 4099d8b506..baef2cd643 100644
 --- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
diff --git a/backends/metax_gpu/tests/CMakeLists.txt b/backends/metax_gpu/tests/CMakeLists.txt
index e8b11d347d9..0c84ada4b65 100755
--- a/backends/metax_gpu/tests/CMakeLists.txt
+++ b/backends/metax_gpu/tests/CMakeLists.txt
@@ -9,6 +9,8 @@ set(PADDLE_LEGACY_TEST_PATH
     ${CMAKE_CURRENT_LIST_DIR}/../../../Paddle/test/legacy_test)
 set(METAX_UNIT_TEST_PATH ${CMAKE_CURRENT_LIST_DIR}/unit_test)
 
+set(NEED_REMOVE_KEYWORDS "attention")
+
 file(GLOB_RECURSE PYTHON_TEST_SCRIPTS "${METAX_UNIT_TEST_PATH}/*.py")
 
 if(NOT TEST_LIST_FILE)
@@ -33,6 +35,20 @@ else()
 endif()
 
 foreach(test_name ${TEST_PROGRAMS})
+  set(IS_REMOVE FALSE)
+
+  foreach(keyword ${NEED_REMOVE_KEYWORDS})
+    string(FIND "${test_name}" "${keyword}" RES)
+    if(NOT RES EQUAL -1)
+      set(IS_REMOVE TRUE)
+      break()
+    endif()
+  endforeach()
+
+  if(IS_REMOVE)
+    continue()
+  endif()
+
   set(CURRENT_TEST_PROGRAM ${PADDLE_LEGACY_TEST_PATH}/${test_name}.py)
   if(NOT EXISTS ${CURRENT_TEST_PROGRAM})
     message(WARNING "${CURRENT_TEST_PROGRAM} is not exist, skip it.")
@@ -44,39 +60,19 @@ endforeach()
 list(REMOVE_DUPLICATES PYTHON_TEST_SCRIPTS)
 
 if(NOT TEST_LIST_FILE)
-  list(
-    REMOVE_ITEM
-    PYTHON_TEST_SCRIPTS
-    # Metax unit test
-    ${METAX_UNIT_TEST_PATH}/test_matmul_op_metax.py
-    # 精度问题
-    ${PADDLE_LEGACY_TEST_PATH}/test_sum_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_max_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_cumsum_op.py
-    # core.cudnnversion
-    ${PADDLE_LEGACY_TEST_PATH}/test_softmax_with_cross_entropy_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_softmax_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_elementwise_add_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_gather_op.py
-    # op_test.py 里 self._get_places()接口的适配问题
-    ${PADDLE_LEGACY_TEST_PATH}/test_elementwise_pow_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_layer_norm_op.py
-    # device == "gpu" 适配问题
-    ${PADDLE_LEGACY_TEST_PATH}/test_index_add_op.py
-    # paddle-gpu 报错一致
-    ${PADDLE_LEGACY_TEST_PATH}/test_elementwise_div_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_stack_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_logical_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_mean_op.py
-    # paddle.device.cuda.get_device_properties
-    ${PADDLE_LEGACY_TEST_PATH}/test_transpose_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_randint_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_uniform_random_op.py
-    # needs check_grad with fp64 precision
-    ${PADDLE_LEGACY_TEST_PATH}/test_c_embedding_op.py
-    # CUDAPinnedPlace 问题
-    ${PADDLE_LEGACY_TEST_PATH}/test_slice_op.py
-    ${PADDLE_LEGACY_TEST_PATH}/test_compare_op.py)
+  set(NEED_IGNORE_FILE ${CMAKE_CURRENT_LIST_DIR}/ignore.txt)
+  if(EXISTS ${NEED_IGNORE_FILE})
+    file(STRINGS ${NEED_IGNORE_FILE} NEED_IGNORE_TEST_PROGRAMS)
+    foreach(test_name ${NEED_IGNORE_TEST_PROGRAMS})
+      if(EXISTS ${PADDLE_LEGACY_TEST_PATH}/${test_name}.py)
+        list(REMOVE_ITEM PYTHON_TEST_SCRIPTS
+             ${PADDLE_LEGACY_TEST_PATH}/${test_name}.py)
+      else()
+        list(REMOVE_ITEM PYTHON_TEST_SCRIPTS
+             ${METAX_UNIT_TEST_PATH}/${test_name}.py)
+      endif()
+    endforeach()
+  endif()
 endif()
 
 if(LOG_OUTPUT_DIR AND NOT EXISTS ${LOG_OUTPUT_DIR})
diff --git a/backends/metax_gpu/tests/ignore.txt b/backends/metax_gpu/tests/ignore.txt
new file mode 100644
index 00000000000..b4f1afbe5b0
--- /dev/null
+++ b/backends/metax_gpu/tests/ignore.txt
@@ -0,0 +1,21 @@
+test_matmul_op_metax
+test_sum_op
+test_max_op
+test_cumsum_op
+test_softmax_with_cross_entropy_op
+test_softmax_op
+test_elementwise_add_op
+test_gather_op
+test_elementwise_pow_op
+test_layer_norm_op
+test_index_add_op
+test_elementwise_div_op
+test_stack_op
+test_logical_op
+test_mean_op
+test_transpose_op
+test_randint_op
+test_uniform_random_op
+test_c_embedding_op
+test_slice_op
+test_compare_op