1515#pragma once
1616
1717#include " paddle/phi/core/tensor_utils.h"
18+ #include " paddle/phi/kernels/cast_kernel.h"
1819#include " paddle/phi/kernels/full_kernel.h"
1920#include " paddle/phi/kernels/funcs/eigen/common.h"
2021#include " paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -32,19 +33,50 @@ void ExpandBackward(const Context& dev_ctx,
3233 dev_ctx.template Alloc <T>(in_grad);
3334 in_grad->data <T>();
3435
35- auto x_grad = EigenVector<T>::Flatten (*in_grad);
36- Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
37- for (size_t i = 0 ; i < reshape_size; ++i) {
38- reshape_dims[i] = reshape_dims_vec[i];
39- }
40- Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
41- for (size_t i = 0 ; i < reduce_size; ++i) {
42- reduce_dims[i] = reduce_dims_vec[i];
36+ if constexpr (std::is_same_v<T, dtype::float16> ||
37+ std::is_same_v<T, dtype::bfloat16>) {
38+ const DenseTensor out_grad_fp32 =
39+ phi::Cast<T, Context>(dev_ctx, out_grad, DataType::FLOAT32);
40+ DenseTensor in_grad_fp32;
41+ in_grad_fp32.Resize (in_grad->dims ());
42+ dev_ctx.template Alloc <float >(&in_grad_fp32);
43+
44+ auto x_grad = EigenVector<float >::Flatten (in_grad_fp32);
45+ Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
46+ for (size_t i = 0 ; i < reshape_size; ++i) {
47+ reshape_dims[i] = reshape_dims_vec[i];
48+ }
49+ Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
50+ for (size_t i = 0 ; i < reduce_size; ++i) {
51+ reduce_dims[i] = reduce_dims_vec[i];
52+ }
53+ const auto out_grad0 = EigenVector<float >::Flatten (out_grad_fp32);
54+ auto & place = *dev_ctx.eigen_device ();
55+ phi::funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, float , Dims>::
56+ Eval (place, x_grad, out_grad0, reduce_dims, reshape_dims);
57+
58+ if constexpr (std::is_same_v<T, dtype::float16>) {
59+ phi::CastKernel<float , Context>(
60+ dev_ctx, in_grad_fp32, DataType::FLOAT16, in_grad);
61+ } else {
62+ phi::CastKernel<float , Context>(
63+ dev_ctx, in_grad_fp32, DataType::BFLOAT16, in_grad);
64+ }
65+ } else {
66+ auto x_grad = EigenVector<T>::Flatten (*in_grad);
67+ Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
68+ for (size_t i = 0 ; i < reshape_size; ++i) {
69+ reshape_dims[i] = reshape_dims_vec[i];
70+ }
71+ Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
72+ for (size_t i = 0 ; i < reduce_size; ++i) {
73+ reduce_dims[i] = reduce_dims_vec[i];
74+ }
75+ auto out_grad0 = EigenVector<T>::Flatten (out_grad);
76+ auto & place = *dev_ctx.eigen_device ();
77+ phi::funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, T, Dims>::
78+ Eval (place, x_grad, out_grad0, reduce_dims, reshape_dims);
4379 }
44- auto out_grad0 = EigenVector<T>::Flatten (out_grad);
45- auto & place = *dev_ctx.eigen_device ();
46- phi::funcs::EigenBroadcastGrad<std::decay_t <decltype (place)>, T, Dims>::Eval (
47- place, x_grad, out_grad0, reduce_dims, reshape_dims);
4880}
4981
5082template <typename T, typename Context>
0 commit comments