diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 54f56f2bd93613..e27587c8596f02 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -668,7 +668,9 @@ XPUOpMap& get_kl3_ops() { phi::DataType::BFLOAT16, phi::DataType::FLOAT16})}, {"mean_grad", - XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + XPUKernelSet({phi::DataType::FLOAT32, + phi::DataType::FLOAT16, + phi::DataType::BFLOAT16})}, {"mean", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, diff --git a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc index 37ace904b2b807..de5b4718e98603 100644 --- a/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_mean_grad_kernel.cc @@ -89,4 +89,5 @@ PD_REGISTER_KERNEL(mean_grad, ALL_LAYOUT, phi::ReduceMeanGradKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {}