Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Phi] Migrate infermeta and add yaml for solve op #44379

Merged
merged 17 commits into from
Jul 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 0 additions & 222 deletions paddle/fluid/operators/solve_op.cc

This file was deleted.

10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@
func : poisson
backward : poisson_grad

- api : solve
args : (Tensor x, Tensor y)
output : Tensor
infer_meta :
func : SolveInferMeta
kernel :
func : solve
data_type : x
backward : solve_grad

- api : trace
args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
output : Tensor
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/api_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@
outputs :
out : Out

- api : solve
inputs :
{x : X, y : Y}
outputs :
out : Out

- api : trace
inputs :
x : Input
Expand Down
10 changes: 10 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,16 @@
kernel :
func : poisson_grad

- backward_api : solve_grad
forward : solve (Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [x, y]
kernel :
func : solve_grad

- backward_api : trace_grad
forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
Expand Down
87 changes: 87 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2082,6 +2082,93 @@ void ValueCompareInferMeta(const MetaTensor& x,
out->set_dtype(DataType::BOOL);
}

void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
auto x_dims = x.dims();
auto y_dims = y.dims();

std::vector<int64_t> x_dims_vec = phi::vectorize(x.dims());
std::vector<int64_t> y_dims_vec = phi::vectorize(y.dims());

auto x_dims_n = x_dims_vec.size();
auto y_dims_n = y_dims_vec.size();

PADDLE_ENFORCE_GT(
x_dims_n,
1,
phi::errors::InvalidArgument("The input tensor X's dimensions of SolveOp "
"should be larger than 1. But received X's "
"dimensions = %d, X's shape = [%s]",
x_dims_n,
x_dims));

PADDLE_ENFORCE_GE(y_dims_n,
1,
phi::errors::InvalidArgument(
"The input tensor Y's dimensions of SolveOp "
"should be larger than or equal 1. But received Y's "
"dimensions = %d, Y's shape = [%s]",
y_dims_n,
y_dims));

PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1],
phi::errors::InvalidArgument(
"The inner-most 2 dimensions of Input(X) all should "
"be square matrices "
"But received X's shape[-2] = %d and shape[-1] = %d.",
x_dims[x_dims_n - 2],
x_dims[x_dims_n - 1]));

bool x_broadcasted = false, y_broadcasted = false;
bool trans_x = false, trans_y = false;
if (x_dims_n == 1) {
x_dims_vec.insert(x_dims_vec.begin(), 1);
x_dims_n = 2;
x_broadcasted = true;
}

if (y_dims_n == 1) {
y_dims_vec.push_back(1);
y_dims_n = 2;
y_broadcasted = true;
}

size_t M, N;
if (trans_x) {
M = x_dims_vec[x_dims_n - 1];
} else {
M = x_dims_vec[x_dims_n - 2];
}
if (trans_y) {
N = y_dims_vec[y_dims_n - 2];
} else {
N = y_dims_vec[y_dims_n - 1];
}

std::vector<int64_t> new_dims;
if (x_dims_n >= y_dims_n) {
new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2);
} else {
new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2);
}
if (!x_broadcasted) {
new_dims.push_back(M);
}
if (!y_broadcasted) {
new_dims.push_back(N);
}
if (x_broadcasted && y_broadcasted) {
new_dims.push_back(1);
}

auto out_dims = phi::make_ddim(new_dims);

out->set_dims(out_dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out->share_lod(x);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,6 @@ void ValueCompareInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void SolveInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

} // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/solve_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy) {
bool is_vector = false;
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/solve_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ template <typename T, typename Context>
void SolveGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& dout,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy);

Expand Down
Loading