Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
69f3721
[fix] fix fail test when backend is mack
zhang-chenyi Sep 4, 2025
e45d324
[Metax] fix fail test when backend is mack
metax666 Sep 4, 2025
ef9d554
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 4, 2025
a1530d2
[metax]change_cupti_and_fix_softmax (#7)
duqimeng Sep 9, 2025
352f02e
[Metax] fix dgc & mklml compile product path problem (#8)
StareAtYou Sep 9, 2025
8f13fae
[Metax] fix accuracy kernel & add test_accuracy_op_metax.py unit test…
StareAtYou Sep 11, 2025
8938293
[Metax] update metax_gpu CMakeLists.txt (#10)
StareAtYou Sep 11, 2025
f54187f
[metax] updata_qr_kernel (#11)
duqimeng Sep 11, 2025
7964c35
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 12, 2025
1e04216
[Metax] fix illegal address access error in test_momentum_op (#12)
StareAtYou Sep 15, 2025
aca80a4
[Metax] fix cufft and fix some blas kernel apply (#13)
duqimeng Sep 15, 2025
1c54010
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 15, 2025
fb547db
[metax] add warpctc_warprnn (#14)
duqimeng Sep 15, 2025
8e98198
[Metax] update metax CI (#15)
StareAtYou Sep 15, 2025
528ec55
[Metax] update metax CI CMakeLists (#16)
StareAtYou Sep 16, 2025
5b31405
[Metax] add github action (#18)
duqimeng Sep 16, 2025
b93c971
[metax] chang build (#19)
duqimeng Sep 16, 2025
6dbbe84
change_build (#20)
duqimeng Sep 16, 2025
ef1b28e
change_build (#21)
duqimeng Sep 16, 2025
3737e48
change_build (#22)
duqimeng Sep 16, 2025
16f3584
【metax】modify cmake for warpctc and warprnnt (#17)
jxwangmetax Sep 16, 2025
ce54693
[metax]modify library to static library (#24)
jxwangmetax Sep 16, 2025
4cda637
[Metax] organize documents (#25)
StareAtYou Sep 16, 2025
23fca59
[metax]fix_code style and index_elementwise_put_kernel (#27)
duqimeng Sep 17, 2025
a513aae
change_build_917 (#29)
duqimeng Sep 17, 2025
4eb455e
chang_build (#30)
duqimeng Sep 17, 2025
1773978
[metax]modify kernel (#31)
jxwangmetax Sep 17, 2025
69af381
change_metax_work (#32)
duqimeng Sep 17, 2025
7fe6f2d
change_build (#33)
duqimeng Sep 17, 2025
b22fc13
[metax] modify fused_bias_dropout_residual_layer_norm (#34)
jxwangmetax Sep 17, 2025
c3d1444
change_build (#35)
duqimeng Sep 17, 2025
569a867
change_build (#36)
duqimeng Sep 17, 2025
0edc6f6
change_warpctc.cmake (#38)
duqimeng Sep 18, 2025
2688c86
change_warpctc.cmake (#39)
duqimeng Sep 18, 2025
6f031fe
test (#40)
duqimeng Sep 18, 2025
e84d399
test_ut (#41)
duqimeng Sep 18, 2025
b5f2feb
tets (#43)
duqimeng Sep 18, 2025
e20eca7
test (#44)
duqimeng Sep 18, 2025
e37f633
[metax] modify compile (#42)
jxwangmetax Sep 19, 2025
1af5148
[Metax] add log analysis script (#46)
StareAtYou Sep 19, 2025
518bee8
add_generate_pb (#47)
duqimeng Sep 19, 2025
bc02549
modify blas (#51)
jxwangmetax Sep 22, 2025
1977ca8
[metax] modify tf32 (#52)
jxwangmetax Sep 22, 2025
1ae2618
[Metax] update metax backend CI test (#53)
StareAtYou Sep 22, 2025
76d5eb0
[Metax] fix log_analysis.py bug (#54)
StareAtYou Sep 23, 2025
9c17b6e
[Metax] update metax CI CMakeLists & scripts (#56)
StareAtYou Sep 23, 2025
51c98a2
[Metax] fix MatmulKernel problem (#57)
StareAtYou Sep 23, 2025
d113018
[metax]fix paddle bug" (#58)
duqimeng Sep 23, 2025
8991299
change—ut (#59)
duqimeng Sep 23, 2025
a770e6f
change_ut (#60)
duqimeng Sep 23, 2025
902112b
change_ut (#63)
duqimeng Sep 24, 2025
9a88a09
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 24, 2025
4ae65f7
Merge branch 'PaddlePaddle:develop' into develop
metax666 Sep 24, 2025
61c32ba
[Metax] add keyword filter in CI CMakeLists.txt
StareAtYou Sep 24, 2025
642eb37
Merge branch 'metax666:develop' into develop
StareAtYou Sep 25, 2025
b2ddc81
[Metax] add ignore case list
StareAtYou Sep 25, 2025
cfe44ce
[Metax] add keyword filter in CI CMakeLists.txt (#64)
StareAtYou Sep 25, 2025
041e585
Merge branch 'metax666:develop' into develop
StareAtYou Sep 25, 2025
087a9c1
[Metax] fix phi::backends::gpu::DnnVersion() symbol not found
StareAtYou Sep 26, 2025
73710c5
Revert "[Metax] fix phi::backends::gpu::DnnVersion() symbol not found"
StareAtYou Sep 26, 2025
78946fd
[metax] modify kernels (#67)
jxwangmetax Sep 26, 2025
ac78af2
Fix part of the missing kernel issues (#66)
Theendlessofhell Sep 26, 2025
404ff3d
[Metax] fix index_elementwise_get kernel
StareAtYou Sep 26, 2025
4ce9fe6
[Metax] fix index_elementwise_get kernel (#68)
StareAtYou Sep 26, 2025
739c5c7
Merge branch 'metax666:develop' into develop
StareAtYou Sep 28, 2025
3c8d017
[metax]fix patch and fix missing kernel (#72)
duqimeng Sep 29, 2025
35a4e49
Merge branch 'metax666:develop' into develop
StareAtYou Sep 29, 2025
453fda5
Update Paddle submodule to latest develop
tianshuo78520a Sep 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Paddle
Submodule Paddle updated 516 files
5 changes: 4 additions & 1 deletion backends/metax_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/im2sequence_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/increment_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
# ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_get_grad_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu
Expand Down Expand Up @@ -535,6 +535,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/clip_by_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/uniform_random_batch_size_like_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/get_tensor_from_selected_rows_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adam_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_grad_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/batch_norm_kernel.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/sparse/empty_kernel.cc
Expand Down Expand Up @@ -642,6 +643,8 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gumbel_softmax_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/top_p_sampling_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/rms_norm_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/lars_momentum_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_sum_kernel.cu
# ############################################################################
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu
# kernels/kps
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "paddle/phi/kernels/selected_rows/adam_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(adam_dense_param_sparse_grad,
metax_gpu,
ALL_LAYOUT,
phi::sr::AdamDenseParamSparseGradKernel,
float,
double,
phi::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(7).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(9).SetBackend(phi::Backend::ALL_BACKEND);

if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(5).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(6).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(4).SetBackend(phi::Backend::UNDEFINED);
kernel->OutputAt(5).SetBackend(phi::Backend::UNDEFINED);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@ PD_CUSTOM_KERNEL_REGISTER(einsum,
phi::EinsumKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}

PD_CUSTOM_KERNEL_REGISTER(einsum_infer,
metax_gpu,
ALL_LAYOUT,
phi::EinsumInferKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/index_elementwise_get_kernel.h"
#include "paddle/phi/kernels/gpu/index_elementwise_get_kernel.cu" // NOLINT

PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
metax_gpu,
Expand All @@ -27,7 +27,7 @@ PD_CUSTOM_KERNEL_REGISTER(index_elementwise_get,
int64_t,
int16_t,
uint8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::float16,
phi::bfloat16,
phi::complex64,
phi::complex128) {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/lars_momentum_kernel.h"

PD_CUSTOM_KERNEL_REGISTER(lars_momentum,
metax_gpu,
ALL_LAYOUT,
phi::LarsMomentumKernel,
float,
double,
phi::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ PD_CUSTOM_KERNEL_REGISTER(multinomial,
phi::MultinomialKernel,
phi::dtype::float16,
phi::dtype::bfloat16,
float) {
float,
double) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ PD_CUSTOM_KERNEL_REGISTER(nonzero,
int64_t,
int,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::float16,
phi::bfloat16,
bool,
float,
double) {
double,
phi::complex64,
phi::complex128) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ PD_CUSTOM_KERNEL_REGISTER(put_along_axis,
float,
double,
int64_t,
uint8_t,
int16_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::float16,
phi::bfloat16) {}
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,7 @@ PD_CUSTOM_KERNEL_REGISTER(take_along_axis,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
uint8_t, // 支持 uint8
int16_t // 支持 int16
) {}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ PD_REGISTER_PLUGIN_KERNEL(addmm,
ALL_LAYOUT,
phi::AddmmKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ PD_REGISTER_PLUGIN_KERNEL(layer_norm_grad,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
Expand Down
26 changes: 0 additions & 26 deletions backends/metax_gpu/kernels/metax_kernel/metax_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,6 @@
#include "kernels/metax_kernel/metax_context.h"

namespace phi {
const bool allow_tf32_cublas = []() -> bool {
const char* v = std::getenv("ALLOW_TF32_CUBLAS");
if (v) {
return std::atoi(v);
}
return true;
}();

const bool allow_tf32_cudnn = []() -> bool {
const char* v = std::getenv("ALLOW_TF32_CUDNN");
if (v) {
return std::atoi(v);
}
return false;
}();

bool AllowTF32Cublas() { return allow_tf32_cublas; }
bool AllowTF32Cudnn() { return allow_tf32_cudnn; }
void DnnWorkspaceHandle::RunFuncSync(
const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
Expand All @@ -42,19 +24,11 @@ void DnnWorkspaceHandle::RunFuncSync(
void* workspace_ptr = nullptr;
size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
std::lock_guard<std::mutex> guard(*mtx_);
#ifdef PADDLE_WITH_HIP
auto status = hipMalloc(&workspace_ptr, size);
#else
auto status = cudaMalloc(&workspace_ptr, size);
#endif
if (status == gpuSuccess) {
cudnn_func(workspace_ptr);
phi::backends::gpu::GpuStreamSync(stream_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
#endif
return;
}
}
Expand Down
3 changes: 1 addition & 2 deletions backends/metax_gpu/kernels/metax_kernel/metax_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <mutex>

#include "kernels/funcs/blas/cublasLt.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
Expand All @@ -30,8 +31,6 @@
cublasLtHandle_t GetBlasLtHandle();

namespace phi {
bool AllowTF32Cublas();
bool AllowTF32Cudnn();
class DnnWorkspaceHandle {
public:
inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
Expand Down
65 changes: 0 additions & 65 deletions backends/metax_gpu/patch/paddle.patch
Original file line number Diff line number Diff line change
Expand Up @@ -869,19 +869,6 @@ index e838778952..83e805e75a 100644

namespace phi {
namespace fusion {
diff --git a/paddle/phi/kernels/gpu/correlation_kernel.cu b/paddle/phi/kernels/gpu/correlation_kernel.cu
index 4c93778bde..c7bdf8a2cc 100644
--- a/paddle/phi/kernels/gpu/correlation_kernel.cu
+++ b/paddle/phi/kernels/gpu/correlation_kernel.cu
@@ -103,7 +103,7 @@ void CorrelationCUDAKernel(const Context &dev_ctx,
int stride2,
int corr_type_multiply,
DenseTensor *out) {
- bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU;
+ bool is_gpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM;
PADDLE_ENFORCE_EQ(
is_gpu_place,
true,
diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h
index f0cca0f701..02ea957240 100644
--- a/paddle/phi/kernels/gpu/depthwise_conv.h
Expand All @@ -897,19 +884,6 @@ index f0cca0f701..02ea957240 100644

namespace phi {
// To determine use cudnn or not.
diff --git a/paddle/phi/kernels/gpu/dgc_kernel.cu b/paddle/phi/kernels/gpu/dgc_kernel.cu
index c2ddfa1347..c6adf5a6de 100644
--- a/paddle/phi/kernels/gpu/dgc_kernel.cu
+++ b/paddle/phi/kernels/gpu/dgc_kernel.cu
@@ -188,7 +188,7 @@ void DGCKernel(const Context& dev_ctx,
int buf_size = paddle::communication::dgc::get_buffer_size(k);
phi::Allocator::AllocationPtr tmp_ious_data;
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
- if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
tmp_ious_data = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
buf_size,
diff --git a/paddle/phi/kernels/gpu/gelu_funcs.h b/paddle/phi/kernels/gpu/gelu_funcs.h
index 29fa252e96..4ae72b0935 100644
--- a/paddle/phi/kernels/gpu/gelu_funcs.h
Expand Down Expand Up @@ -974,19 +948,6 @@ index 1bdbe1564c..f753b54bc6 100644
#include "paddle/phi/kernels/impl/qr_kernel_impl.h"
#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
#include "paddle/phi/kernels/lstsq_kernel.h"
diff --git a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
index 05a977828f..5136608c41 100644
--- a/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
+++ b/paddle/phi/kernels/gpu/shuffle_batch_kernel.cu
@@ -58,7 +58,7 @@ void ShuffleBatchKernel(const Context& dev_ctx,
int64_t seed_int = 0;
if (seed.initialized()) {
const auto& seed_place = seed.place().GetType();
- bool is_gpu_place = seed_place == phi::AllocationType::GPU;
+ bool is_gpu_place = seed_place == phi::AllocationType::GPU || seed_place == phi::AllocationType::CUSTOM;
if (is_gpu_place) {
// NOTE: We have overwritten GetKernelTypeForVar, so seed_place would
// not be CUDAPlace in practice. This case would only happen in Python
diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
index 9bc5326c90..79b57a8203 100644
--- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
Expand Down Expand Up @@ -1144,32 +1105,6 @@ index 6f03f76eeb..5fe2c3e7dc 100644
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h"

diff --git a/paddle/phi/kernels/impl/merged_momentum_impl.h b/paddle/phi/kernels/impl/merged_momentum_impl.h
index 7b85903776..3f4b298807 100644
--- a/paddle/phi/kernels/impl/merged_momentum_impl.h
+++ b/paddle/phi/kernels/impl/merged_momentum_impl.h
@@ -297,7 +297,7 @@ void MergedMomentumInnerCompute(
params_out[idx],
velocities_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
phi::funcs::ForRange<Context> for_range(
static_cast<const Context &>(dev_ctx), params[idx]->numel());
const auto grad_type = grads[idx]->dtype();
diff --git a/paddle/phi/kernels/impl/momentum_kernel_impl.h b/paddle/phi/kernels/impl/momentum_kernel_impl.h
index de5bcfc30b..eb2a9714f5 100644
--- a/paddle/phi/kernels/impl/momentum_kernel_impl.h
+++ b/paddle/phi/kernels/impl/momentum_kernel_impl.h
@@ -457,7 +457,7 @@ void MomentumDenseImpl(const Context& dev_ctx,
regularization_coeff,
param_out,
velocity_out);
- } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU) {
+ } else if (dev_ctx.GetPlace().GetType() == phi::AllocationType::GPU || dev_ctx.GetPlace().GetType() == phi::AllocationType::CUSTOM) {
funcs::ForRange<Context> for_range(dev_ctx, param.numel());
const auto grad_type = grad.dtype();
#define PADDLE_LAUNCH_DENSE_MOMENTUM_KERNEL(__nesterov, __reg_type) \
diff --git a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h b/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
index 4099d8b506..baef2cd643 100644
--- a/paddle/phi/kernels/impl/spectral_norm_kernel_impl.h
Expand Down
Loading