[Prim] Optimize the use of the reshape operator in eager composite backward #69515

Merged
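A minimal, self-contained sketch of the pattern this PR applies across the composite backward rules (the Shape, ReduceShape, and ReduceGradToTarget helpers below are hypothetical stand-ins, not Paddle's Tensor/sum/reshape API): the reduction now keeps the reduced axes exactly when the gradient and the target tensor have the same rank, and reshape runs only if the reduced shape still differs from the target, so the common broadcast cases skip the extra reshape entirely.

// Hypothetical sketch (not Paddle's API): models how the updated backward
// rules choose keep_dim for the reduction and reshape only when needed.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Stand-in for Tensor::sum(axes, dtype, keep_dim): returns the reduced shape.
Shape ReduceShape(const Shape& in, const Shape& axes, bool keep_dim) {
  Shape out;
  for (int64_t i = 0; i < static_cast<int64_t>(in.size()); ++i) {
    bool reduced = false;
    for (int64_t a : axes) reduced = reduced || (a == i);
    if (!reduced) {
      out.push_back(in[i]);
    } else if (keep_dim) {
      out.push_back(1);  // a reduced axis is kept as size 1
    }
  }
  return out;
}

// The pattern used throughout the diff: keep_dim is true exactly when the
// gradient and the target already have the same rank, so the sum itself
// produces the target's broadcast shape and no reshape is needed in that case.
Shape ReduceGradToTarget(const Shape& grad, const Shape& axes,
                         const Shape& target) {
  bool keep_dim = grad.size() == target.size();
  Shape reduced = ReduceShape(grad, axes, keep_dim);
  if (reduced != target) {
    reduced = target;  // models the now-conditional reshape<T>(reduced, target)
  }
  return reduced;
}

int main() {
  // Same rank: grad [4, 3, 5] vs y [1, 3, 1]; keep_dim=true makes the sum
  // yield [1, 3, 1] directly, so the reshape branch is skipped.
  assert((ReduceGradToTarget({4, 3, 5}, {0, 2}, {1, 3, 1}) == Shape{1, 3, 1}));
  // Lower rank, shapes already match: grad [4, 3, 5] vs y [5]; keep_dim=false
  // yields [5], so the reshape branch is skipped here as well.
  assert((ReduceGradToTarget({4, 3, 5}, {0, 1}, {5}) == Shape{5}));
  // Lower rank that still needs a reshape: grad [4, 3, 5] vs y [3, 1];
  // keep_dim=false yields [3], and the conditional reshape restores [3, 1].
  assert((ReduceGradToTarget({4, 3, 5}, {0, 2}, {3, 1}) == Shape{3, 1}));
  std::cout << "reshape avoided in the first two cases" << std::endl;
  return 0;
}

Before this change, every branch called sum with keep_dim=false and then reshaped unconditionally; the new form only materializes the reshape when the reduced shape actually differs from the target shape.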
150 changes: 108 additions & 42 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -291,9 +291,14 @@ void subtract_grad(const Tensor& x,
         by_pass<T>(scale_out_grad, dy);
       } else {
         auto dy_reduce_res =
-            scale_out_grad.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
+            scale_out_grad.sum(common::vectorize(reduce_dim),
+                               y.dtype(),
+                               scale_out_grad.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, dy);
       }
     } else {
       by_pass<T>(scale_out_grad, dy);
@@ -307,9 +312,14 @@ void subtract_grad(const Tensor& x,
         by_pass<T>(out_grad, dx);
       } else {
         auto dx_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
+            out_grad.sum(common::vectorize(reduce_dim),
+                         x.dtype(),
+                         out_grad.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, dx);
       }
     } else {
       by_pass<T>(out_grad, dx);
@@ -332,9 +342,14 @@ void add_grad(const Tensor& x,
         by_pass<T>(out_grad, dy);
       } else {
         auto dy_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
+            out_grad.sum(common::vectorize(reduce_dim),
+                         y.dtype(),
+                         out_grad.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, dy);
       }
     } else {
       by_pass<T>(out_grad, dy);
@@ -348,9 +363,14 @@ void add_grad(const Tensor& x,
         by_pass<T>(out_grad, dx);
       } else {
         auto dx_reduce_res =
-            out_grad.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
+            out_grad.sum(common::vectorize(reduce_dim),
+                         x.dtype(),
+                         out_grad.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, dx);
       }
     } else {
       by_pass<T>(out_grad, dx);
@@ -424,9 +444,14 @@ void divide_grad(const Tensor& x,
         set_output<T>(dy_res, dy);
       } else {
         auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
+            dy_res.sum(common::vectorize(reduce_dim),
+                       y.dtype(),
+                       dy_res.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, dy);
       }
     } else {
       set_output<T>(dy_res, dy);
@@ -442,9 +467,14 @@ void divide_grad(const Tensor& x,
         set_output<T>(dx_res, dx);
       } else {
         auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
+            dx_res.sum(common::vectorize(reduce_dim),
+                       x.dtype(),
+                       dx_res.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, dx);
       }
     } else {
       set_output<T>(dx_res, dx);
@@ -470,9 +500,14 @@ void elementwise_pow_grad(const Tensor& x,
         set_output<T>(dy_res, dy);
       } else {
         auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, dy);
+            dy_res.sum(common::vectorize(reduce_dim),
+                       y.dtype(),
+                       dy_res.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, dy);
       }
     } else {
       set_output<T>(dy_res, dy);
@@ -490,9 +525,14 @@ void elementwise_pow_grad(const Tensor& x,
         set_output<T>(dx_res, dx);
       } else {
         auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, dx);
+            dx_res.sum(common::vectorize(reduce_dim),
+                       x.dtype(),
+                       dx_res.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, dx);
       }

     } else {
@@ -606,8 +646,10 @@ void multiply_grad(const Tensor& x,
         set_output<T>(x_grad_unreduce, x_grad);
       } else {
         auto x_grad_reduced = x_grad_unreduce.sum(
-            common::vectorize(axes), x_grad_unreduce.dtype(), false);
-        if (x_grad_reduced.dims().size() != x.dims().size()) {
+            common::vectorize(axes),
+            x_grad_unreduce.dtype(),
+            x_grad_unreduce.dims().size() == x.dims().size());
+        if (x_grad_reduced.dims() != x.dims()) {
           x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
         }
         set_output<T>(x_grad_reduced, x_grad);
@@ -624,8 +666,10 @@ void multiply_grad(const Tensor& x,
         set_output<T>(y_grad_unreduce, y_grad);
       } else {
         auto y_grad_reduced = y_grad_unreduce.sum(
-            common::vectorize(axes), y_grad_unreduce.dtype(), false);
-        if (y_grad_reduced.dims().size() != y.dims().size()) {
+            common::vectorize(axes),
+            y_grad_unreduce.dtype(),
+            y_grad_unreduce.dims().size() == y.dims().size());
+        if (y_grad_reduced.dims() != y.dims()) {
           y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
         }
         set_output<T>(y_grad_reduced, y_grad);
@@ -648,8 +692,10 @@ void expand_grad(const Tensor& x,
     if (!axes.size()) {
       by_pass<T>(out_grad, x_grad);
     } else {
-      auto reduced = out_grad.sum(common::vectorize(axes), x.dtype(), false);
-      if (reduced.dims().size() != x.dims().size()) {
+      auto reduced = out_grad.sum(common::vectorize(axes),
+                                  x.dtype(),
+                                  out_grad.dims().size() == x.dims().size());
+      if (reduced.dims() != x.dims()) {
         reduced = reshape<T>(reduced, x.shape());
       }
       set_output<T>(reduced, x_grad);
@@ -1377,9 +1423,14 @@ void maximum_grad(const Tensor& x,
         set_output<T>(dx_res, x_grad);
       } else {
         auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, x_grad);
+            dx_res.sum(common::vectorize(reduce_dim),
+                       x.dtype(),
+                       dx_res.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, x_grad);
       }
     } else {
       set_output<T>(dx_res, x_grad);
@@ -1396,9 +1447,14 @@ void maximum_grad(const Tensor& x,
         set_output<T>(dy_res, y_grad);
       } else {
         auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, y_grad);
+            dy_res.sum(common::vectorize(reduce_dim),
+                       y.dtype(),
+                       dy_res.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, y_grad);
       }
     } else {
       set_output<T>(dy_res, y_grad);
@@ -1851,9 +1907,14 @@ void minimum_grad(const Tensor& x,
         set_output<T>(dx_res, x_grad);
       } else {
         auto dx_reduce_res =
-            dx_res.sum(common::vectorize(reduce_dim), x.dtype(), false);
-        auto dx_tmp = reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
-        set_output<T>(dx_tmp, x_grad);
+            dx_res.sum(common::vectorize(reduce_dim),
+                       x.dtype(),
+                       dx_res.dims().size() == x.dims().size());
+        if (dx_reduce_res.dims() != x.dims()) {
+          dx_reduce_res =
+              reshape<T>(dx_reduce_res, common::vectorize(x.dims()));
+        }
+        set_output<T>(dx_reduce_res, x_grad);
       }
     } else {
       set_output<T>(dx_res, x_grad);
@@ -1870,9 +1931,14 @@ void minimum_grad(const Tensor& x,
         set_output<T>(dy_res, y_grad);
       } else {
         auto dy_reduce_res =
-            dy_res.sum(common::vectorize(reduce_dim), y.dtype(), false);
-        auto dy_tmp = reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
-        set_output<T>(dy_tmp, y_grad);
+            dy_res.sum(common::vectorize(reduce_dim),
+                       y.dtype(),
+                       dy_res.dims().size() == y.dims().size());
+        if (dy_reduce_res.dims() != y.dims()) {
+          dy_reduce_res =
+              reshape<T>(dy_reduce_res, common::vectorize(y.dims()));
+        }
+        set_output<T>(dy_reduce_res, y_grad);
       }
     } else {
       set_output<T>(dy_res, y_grad);
(second changed file; file header not shown)
@@ -659,8 +659,10 @@ void multiply_double_grad(const Tensor& x,
       if (!axes.size()) {
         set_output<T>(dx, x_grad);
       } else {
-        auto dx_reduce = dx.sum(common::vectorize(axes), dx.dtype(), false);
-        if (dx_reduce.dims().size() != x.dims().size()) {
+        auto dx_reduce = dx.sum(common::vectorize(axes),
+                                dx.dtype(),
+                                dx.dims().size() == x.dims().size());
+        if (dx_reduce.dims() != x.dims()) {
           dx_reduce = reshape<T>(dx_reduce, x.shape());
         }
         set_output<T>(dx_reduce, x_grad);
@@ -682,8 +684,10 @@ void multiply_double_grad(const Tensor& x,
       if (!axes.size()) {
         set_output<T>(dy, y_grad);
       } else {
-        auto dy_reduce = dy.sum(common::vectorize(axes), dy.dtype(), false);
-        if (dy_reduce.dims().size() != y.dims().size()) {
+        auto dy_reduce = dy.sum(common::vectorize(axes),
+                                dy.dtype(),
+                                dy.dims().size() == y.dims().size());
+        if (dy_reduce.dims() != y.dims()) {
           dy_reduce = reshape<T>(dy_reduce, y.shape());
         }
         set_output<T>(dy_reduce, y_grad);
@@ -754,11 +758,16 @@ void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
         if (!reduce_dim.size()) {
           by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
         } else {
-          auto dddy_reduce_res = grad_grad_out_grad.sum(
-              common::vectorize(reduce_dim), grad_grad_y.get().dtype(), false);
-          auto dddy_tmp = reshape<T>(
-              dddy_reduce_res, common::vectorize(grad_grad_y.get().dims()));
-          set_output<T>(dddy_tmp, grad_grad_y_grad);
+          auto dddy_reduce_res =
+              grad_grad_out_grad.sum(common::vectorize(reduce_dim),
+                                     grad_grad_y.get().dtype(),
+                                     grad_grad_out_grad.dims().size() ==
+                                         grad_grad_y.get().dims().size());
+          if (dddy_reduce_res.dims() != grad_grad_y.get().dims()) {
+            dddy_reduce_res = reshape<T>(
+                dddy_reduce_res, common::vectorize(grad_grad_y.get().dims()));
+          }
+          set_output<T>(dddy_reduce_res, grad_grad_y_grad);
         }
       } else {
         by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
@@ -774,11 +783,16 @@ void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
         if (!reduce_dim.size()) {
           by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
         } else {
-          auto dddx_reduce_res = grad_grad_out_grad.sum(
-              common::vectorize(reduce_dim), grad_grad_x.get().dtype(), false);
-          auto dddx_tmp = reshape<T>(
-              dddx_reduce_res, common::vectorize(grad_grad_x.get().dims()));
-          set_output<T>(dddx_tmp, grad_grad_x_grad);
+          auto dddx_reduce_res =
+              grad_grad_out_grad.sum(common::vectorize(reduce_dim),
+                                     grad_grad_x.get().dtype(),
+                                     grad_grad_out_grad.dims().size() ==
+                                         grad_grad_x.get().dims().size());
+          if (dddx_reduce_res.dims() != grad_grad_x.get().dims()) {
+            dddx_reduce_res = reshape<T>(
+                dddx_reduce_res, common::vectorize(grad_grad_x.get().dims()));
+          }
+          set_output<T>(dddx_reduce_res, grad_grad_x_grad);
         }
       } else {
         by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);