Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,8 @@ void ApplyCinnPass(::pir::Program* program,
.file_name("original_programs.py")
.dump_symbolic_shape(FLAGS_logging_pir_py_code_dump_symbolic_dims)
.SaveIfFlagEnabled();
ApplyPdToCinnPass(program, CreatePassManager);
// TODO(Hongqing-work): move ApplyShapeOptimizationPass before
// ApplyPdToCinnPass after fixing infer shape bug.
ApplyShapeOptimizationPass(program, CreatePassManager);
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
PirToPyCodeConverter(program)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,79 @@ bool EinsumOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool FlattenOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &attributes = op->attributes();
int start_axis =
attributes.at("start_axis").dyn_cast<pir::Int32Attribute>().data();
int stop_axis =
attributes.at("stop_axis").dyn_cast<pir::Int32Attribute>().data();

const auto &x_shape =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
int in_dims_size = x_shape.size();

if (in_dims_size == 0) {
PADDLE_ENFORCE_EQ(
start_axis == 0 || start_axis == -1,
true,
phi::errors::InvalidArgument("The start_axis should be 0 or -1 when "
"the input tensor is a 0D-Tensor"));
PADDLE_ENFORCE_EQ(
stop_axis == 0 || stop_axis == -1,
true,
phi::errors::InvalidArgument("The stop_axis should be 0 or -1 when the "
"input tensor is a 0D-Tensor"));
// this can ensure out shape {1}
start_axis = 0;
stop_axis = -1;
}

if (start_axis < 0) {
start_axis = start_axis + in_dims_size;
}
if (stop_axis < 0) {
stop_axis = stop_axis + in_dims_size;
}
if (in_dims_size > 0) {
PADDLE_ENFORCE_GE(
stop_axis,
start_axis,
phi::errors::InvalidArgument("The stop_axis should be greater"
"than or equal to start_axis."));
}

symbol::DimExpr outer{1};
std::vector<symbol::DimExpr> out_shape;
out_shape.reserve(in_dims_size - stop_axis + start_axis + 1);
for (int i = 0; i < start_axis; ++i) {
out_shape.push_back(x_shape[i]);
}
for (int i = start_axis; i <= stop_axis; i++) {
outer = outer * x_shape[i];
}
Comment on lines +353 to +361
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里 * 1 的表达式会自动化简掉吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

会自动化简的

out_shape.push_back(outer);
for (int i = stop_axis + 1; i < in_dims_size; i++) {
out_shape.push_back(x_shape[i]);
}

symbol::ShapeOrDataDimExprs out_shape_data{
symbol::TensorShapeOrDataDimExprs(out_shape)};
infer_context->SetShapeOrDataForValue(op->result(0), out_shape_data);

std::vector<symbol::DimExpr> xshape_shape = x_shape;
xshape_shape.insert(xshape_shape.begin(), symbol::DimExpr{0});
symbol::ShapeOrDataDimExprs xshape_shape_data{
symbol::TensorShapeOrDataDimExprs(xshape_shape)};
infer_context->SetShapeOrDataForValue(op->result(1), xshape_shape_data);
return true;
}

bool Flatten_OpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
return FlattenOpInferSymbolicShape(op, infer_context);
}

bool KthvalueOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(DiagEmbed)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diagonal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DistributeFpnProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Einsum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flatten_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,7 @@
view : (x -> out)
intermediate : xshape
backward : flatten_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : flip
args : (Tensor x, int[] axis)
Expand Down