Skip to content

Commit

Permalink
add axis check for elementwise op while the dimension of x is equal t…
Browse files Browse the repository at this point in the history
…o the dimension of tensor (#35340)
  • Loading branch information
wangxinxin08 authored Sep 2, 2021
1 parent a622b70 commit 25871e0
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions paddle/fluid/operators/elementwise/elementwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ class ElementwiseOp : public framework::OperatorWithKernel {
auto y_dims = ctx->GetInputDim("Y");
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx->Attrs().Get<int>("axis");
if (x_dims.size() == y_dims.size()) {
PADDLE_ENFORCE_EQ((axis == -1) || (axis == 0), true,
platform::errors::InvalidArgument(
"axis should be -1 or 0 while the dimension of "
"tensor X (%s) is equal to the dimension of "
"tensor Y (%s), but received axis: %s",
x_dims.size(), y_dims.size(), axis));
}
PADDLE_ENFORCE_EQ((axis >= (-1 * max_dim)) && (axis < max_dim), true,
platform::errors::InvalidArgument(
"The axis range must be [%s, %s), but axis is %s. "
Expand Down

0 comments on commit 25871e0

Please sign in to comment.