diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 3e81ae4cfe69f5..0f08f02fa2a6e9 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -252,7 +252,6 @@ bool AddNOpInferSymbolicShape(pir::Operation *op, "should be larger than 0. But received X's dimensions %d.", inputs_shape.size())); symbol::TensorShapeOrDataDimExprs candidate_shape = inputs_shape.front(); - size_t candidate_idx = 0; for (size_t i = 1; i < inputs_shape.size(); ++i) { // 0D tensor if (inputs_shape[i].shape().size() == 0) { @@ -260,19 +259,12 @@ bool AddNOpInferSymbolicShape(pir::Operation *op, } if (candidate_shape.shape().size() == 0) { candidate_shape = inputs_shape[i]; - candidate_idx = i; continue; } - PADDLE_ENFORCE_EQ(candidate_shape, - inputs_shape[i], - common::errors::InvalidArgument( - "The input tensor X of AddNOp must" - " have same shape. But received X[%d]'s shape = " - "[%s], X[%d]'s shape = [%s].", - candidate_idx, - candidate_shape, - i, - inputs_shape[i])); + for (size_t j = 0; j < candidate_shape.shape().size(); ++j) { + infer_context->AddEqualCstr(candidate_shape.shape()[j], + inputs_shape[i].shape()[j]); + } } infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{candidate_shape});