[PIR][DynamicShape] Remove redundant code for shapeAnalysis and shapedTypeInterface (#60744)

As the title says: remove redundant code for shapeAnalysis and shapedTypeInterface.
lanxianghit authored Jan 11, 2024
1 parent bcd5e37 commit a576356
Showing 12 changed files with 69 additions and 100 deletions.
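
Across the file diffs below the change is mechanical: call sites stop reaching into the analysis' internal maps (value_to_shape_or_data_, keyed by pir::Value, and value_id_to_shapeordata_, keyed by a string id) and instead read and write through GetShapeOrDataForValue / SetShapeOrDataForValue. A minimal sketch of the pattern, assuming a pir::Operation* op and a pir::ShapeConstraintIRAnalysis* shape_analysis are in scope; the helper name PropagateSameShape is hypothetical, not from this commit:

// Hypothetical helper, only to illustrate the accessor-based flow used below.
bool PropagateSameShape(pir::Operation *op,
                        pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // Before this commit, call sites poked the maps directly, e.g.
  //   shape_analysis->value_to_shape_or_data_[value];                     // keyed by pir::Value
  //   shape_analysis->value_id_to_shapeordata_[pir::GetValueId(&value)];  // keyed by string id
  // Now reads and writes go through the accessors, so the keying scheme
  // stays an implementation detail of ShapeConstraintIRAnalysis.
  symbol::ShapeOrDataDimExprs in_shape =
      shape_analysis->GetShapeOrDataForValue(op->operand_source(0));
  shape_analysis->SetShapeOrDataForValue(op->result(0), in_shape);
  return true;
}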
@@ -139,9 +139,8 @@ bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
     pir::ShapeConstraintIRAnalysis& shape_analysis =
         pir::ShapeAnalysisManager::Instance().Get(
             op.x().defining_op()->GetParentProgram());
-    CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) !=
-          shape_analysis.value_id_to_shapeordata_.end());
-    return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value));
+
+    return shape_analysis.GetShapeOrDataForValue(value);
   };
   std::optional<pir::Value> opt_generated_shape =
       GetOutOfRewritedGenerateShapeOp(
43 changes: 22 additions & 21 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
@@ -16,6 +16,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
 #include "paddle/pir/core/builtin_attribute.h"
 #include "paddle/pir/core/builtin_type.h"
+#include "paddle/pir/core/builtin_type_interfaces.h"
 #include "paddle/pir/dialect/shape/ir/shape_attribute.h"

 namespace paddle::dialect {
@@ -33,27 +34,25 @@ bool SameOperandsAndResultShape(
   pir::Value operand_source = op->operand_source(0);

   symbol::ShapeOrDataDimExprs operand_shape_or_data =
-      shape_analysis->value_to_shape_or_data_[operand_source];
+      shape_analysis->GetShapeOrDataForValue(operand_source);

   op->set_attribute("symbolic_shape",
                     pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
                                                      operand_shape_or_data));
   pir::OpResult res = op->result(0);
-  shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
+  shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
   return true;
 }

 bool InferSymbolicShapeElementWiseBinary(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
   pir::Value operand_source_0 = op->operand_source(0);
-  std::string operand_source_0_id = pir::GetValueId(&operand_source_0);
   std::vector<symbol::DimExpr> shape_0{
-      shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()};
+      shape_analysis->GetShapeOrDataForValue(operand_source_0).shape()};

   pir::Value operand_source_1 = op->operand_source(1);
-  std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
   std::vector<symbol::DimExpr> shape_1{
-      shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()};
+      shape_analysis->GetShapeOrDataForValue(operand_source_1).shape()};

   if (shape_0.size() > shape_1.size()) {
     for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) {
@@ -75,9 +74,11 @@ bool InferSymbolicShapeElementWiseBinary(
   std::vector<symbol::DimExpr> data;

   pir::OpResult res = op->result(0);
-  std::string res_id = pir::GetValueId(&res);
   symbol::ShapeOrDataDimExprs shape_data{shapes, data};
-  shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);
+  op->set_attribute(
+      "symbolic_shape",
+      pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
   return true;
 }

@@ -104,7 +105,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
   std::vector<symbol::DimExpr> sym_dims;
   for (auto dim : dims) {
     symbol::DimExpr dim_expr;
-    if (dim == -1) {
+    if (dim == pir::ShapedTypeInterface::kDynamic) {
       symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName());
       dim_expr = symbolic_dim_expr;
     } else {
@@ -120,7 +121,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
       pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

   pir::OpResult res = op->result(0);
-  shape_analysis->value_to_shape_or_data_[res] = shape_data;
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);

   return true;
 }
@@ -171,13 +172,13 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
   pir::OpResult res = op->result(0);

   symbol::ShapeOrDataDimExprs operand_shape_or_data =
-      shape_analysis->value_to_shape_or_data_[operand_source];
+      shape_analysis->GetShapeOrDataForValue(operand_source);

   symbol::ShapeOrDataDimExprs extend_shape_or_data =
       symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(
           operand_shape_or_data);

-  shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data;
+  shape_analysis->SetShapeOrDataForValue(res, extend_shape_or_data);
   op->set_attribute("symbolic_shape",
                     pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
                                                      extend_shape_or_data));
@@ -193,7 +194,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
                                pir::ShapeConstraintIRAnalysis *shape_analysis) {
   pir::Value operand_source = op->operand_source(0);
   symbol::ShapeOrDataDimExprs operand_shape_or_data =
-      shape_analysis->value_to_shape_or_data_[operand_source];
+      shape_analysis->GetShapeOrDataForValue(operand_source);

   std::vector<symbol::DimExpr> out_dims;
   if (operand_shape_or_data.data().has_value()) {
@@ -213,7 +214,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
       "symbolic_shape",
       pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
   pir::OpResult res = op->result(0);
-  shape_analysis->value_to_shape_or_data_[res] = shape_data;
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
 }

@@ -222,7 +223,7 @@ bool ReshapeOpInferSymbolicShape(
   pir::Value operand_source_shape = op->operand_source(1);

   symbol::ShapeOrDataDimExprs operand_shape_or_data =
-      shape_analysis->value_to_shape_or_data_[operand_source_shape];
+      shape_analysis->GetShapeOrDataForValue(operand_source_shape);

   std::vector<symbol::DimExpr> out_dims;
   if (operand_shape_or_data.data().has_value()) {
@@ -236,9 +237,9 @@

   pir::OpResult res0 = op->result(0);
   pir::OpResult res1 = op->result(1);
-  shape_analysis->value_to_shape_or_data_[res0] = shape_data;
-  shape_analysis->value_to_shape_or_data_[res1] =
-      shape_analysis->value_to_shape_or_data_[operand_source_shape];
+  shape_analysis->SetShapeOrDataForValue(res0, shape_data);
+  shape_analysis->SetShapeOrDataForValue(
+      res1, shape_analysis->GetShapeOrDataForValue(operand_source_shape));
   return true;
 }

@@ -267,7 +268,7 @@ bool FullIntArrayOpInferSymbolicShape(
       pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

   pir::OpResult res = op->result(0);
-  shape_analysis->value_to_shape_or_data_[res] = shape_data;
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
 }

@@ -286,7 +287,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
   // dialect.
   pir::Value operand_source = op->operand_source(0);
   symbol::ShapeOrDataDimExprs operand_shape_or_data =
-      shape_analysis->value_to_shape_or_data_[operand_source];
+      shape_analysis->GetShapeOrDataForValue(operand_source);
   pir::AttributeMap attributes = op->attributes();

   std::vector<pir::Attribute> attr_starts =
@@ -309,7 +310,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
       pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));

   pir::OpResult res = op->result(0);
-  shape_analysis->value_to_shape_or_data_[res] = shape_data;
+  shape_analysis->SetShapeOrDataForValue(res, shape_data);
   return true;
 }

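Read together, the post-commit body of SameOperandsAndResultShape reduces to a read through the analysis, the attribute write, and a write back through the analysis. The function below is reassembled from the context and added lines of the hunks above, not copied verbatim from the file:

bool SameOperandsAndResultShape(
    pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // Read the symbolic shape of the single input through the analysis.
  pir::Value operand_source = op->operand_source(0);
  symbol::ShapeOrDataDimExprs operand_shape_or_data =
      shape_analysis->GetShapeOrDataForValue(operand_source);

  // Record it on the op as the "symbolic_shape" attribute ...
  op->set_attribute("symbolic_shape",
                    pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
                                                     operand_shape_or_data));

  // ... and propagate it unchanged to the single result.
  pir::OpResult res = op->result(0);
  shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
  return true;
}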
15 changes: 4 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -3157,15 +3157,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
     pir::ShapeConstraintIRAnalysis *shape_analysis) {
   pir::Value x = operand_source(0);
   pir::Value y = operand_source(1);
-  std::string x_id = pir::GetValueId(&x);
-  std::string y_id = pir::GetValueId(&y);
-
-  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
-             "x_id does not exist.");
-  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
-             "y_id does not exist.");
-  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
-  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
+
+  const auto &x_data_shape = shape_analysis->GetShapeOrDataForValue(x);
+  const auto &y_data_shape = shape_analysis->GetShapeOrDataForValue(y);
   IR_ENFORCE(x_data_shape.data().has_value(),
              "Value x comes from ShapeOp, it must have data");
   IR_ENFORCE(y_data_shape.data().has_value(),
@@ -3180,10 +3174,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
   }

   pir::OpResult res = result(0);
-  std::string res_id = pir::GetValueId(&res);
   symbol::ShapeOrDataDimExprs output_data_shape =
       symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
-  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
+  shape_analysis->SetShapeOrDataForValue(res, output_data_shape);
   return true;
 }

6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
@@ -62,11 +62,11 @@ struct CombineOpInferSymbolicShapeInterfaceModel
     }

     auto operand_source_1st_data =
-        shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data();
+        shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).data();
     if (operand_source_1st_data.has_value()) {
       for (auto operand_source : op->operands_source()) {
         auto source_data =
-            shape_analysis->value_to_shape_or_data_[operand_source]
+            shape_analysis->GetShapeOrDataForValue(operand_source)
                 .data()
                 .value();
         out_dims.push_back(source_data[0]);
@@ -83,7 +83,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel
                      pir::shape::SymbolAttribute::get(
                          pir::IrContext::Instance(), shape_data));
     auto res = op->result(0);
-    shape_analysis->value_to_shape_or_data_[res] = shape_data;
+    shape_analysis->SetShapeOrDataForValue(res, shape_data);
     return true;
   }

6 changes: 4 additions & 2 deletions paddle/fluid/pir/transforms/shape_optimization_pass.cc
@@ -51,7 +51,7 @@ void DebugPrintOpInfo(
                  << "ShapeOrData: ";

     if (shape_analysis != nullptr) {
-      auto shape_data = shape_analysis->value_to_shape_or_data_[res];
+      auto shape_data = shape_analysis->GetShapeOrDataForValue(res);
       print_stream << "shape: [";

       for (size_t i = 0; i < shape_data.shape().size(); ++i) {
@@ -94,7 +94,9 @@ void InferSymExprForAllValues(ModuleOp module_op) {
     if (infer_symbolic_shape_interface) {
       VLOG(3) << op.name() << " has InferSymbolicShapeInterface.";
       PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape(
-          &shape_analysis));
+                         &shape_analysis),
+                     "InferSymbolicShape for %s failed.",
+                     op.name());
     }
     DebugPrintOpInfo(&op, &shape_analysis);
   }
12 changes: 2 additions & 10 deletions paddle/pir/core/builtin_type_interfaces.cc
@@ -21,16 +21,8 @@ Type ShapedTypeInterface::GetElementType() const {
   return impl_->get_element_type(*this);
 }

-std::vector<int64_t> ShapedTypeInterface::GetDyShape() const {
-  if (dy_shape_.size() == 0) {
-    auto ddim_vec = common::vectorize(impl_->get_shape(*this));
-    dy_shape_ = ddim_vec;
-    std::replace(dy_shape_.begin(),
-                 dy_shape_.end(),
-                 (int64_t)-1,
-                 ShapedTypeInterface::kDynamic);
-  }
-  return dy_shape_;
+pir::DDim ShapedTypeInterface::GetShape() const {
+  return impl_->get_shape(*this);
 }

 }  // namespace pir
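With the cached GetDyShape() removed, the interface now hands back the pir::DDim straight from the underlying concept, and callers flatten it only where a vector is actually needed. A hedged sketch of a call site after this change; tensor_type is an illustrative pir::Type assumed to implement ShapedTypeInterface, and common::vectorize is used the same way it is elsewhere in this diff:

  auto shaped = tensor_type.dyn_cast<pir::ShapedTypeInterface>();
  pir::DDim ddim = shaped.GetShape();                    // no per-interface cache anymore
  std::vector<int64_t> dims = common::vectorize(ddim);   // flatten only when a vector is needed
  bool any_dynamic = shaped.IsDynamicShape();            // true if any extent equals kDynamic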
22 changes: 10 additions & 12 deletions paddle/pir/core/builtin_type_interfaces.h
@@ -56,7 +56,7 @@ class IR_API ShapedTypeInterface
   ///
   /// \brief kDynamic
   ///
-  static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
+  static constexpr int64_t kDynamic = std::int64_t(-1);

   ShapedTypeInterface(Type type, Concept *impl)
       : TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}
@@ -69,7 +69,7 @@ class IR_API ShapedTypeInterface
   ///
   /// \brief Get the shape of this type.
   ///
-  std::vector<int64_t> GetDyShape() const;
+  pir::DDim GetShape() const;

   ///
   /// \brief Check whether this type is ranked, currently return true.
@@ -81,7 +81,7 @@ class IR_API ShapedTypeInterface
   ///
   int64_t GetRank() const {
     IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type.");
-    return (*this).GetDyShape().size();
+    return (*this).GetShape().size();
   }

   ///
@@ -94,11 +94,10 @@ class IR_API ShapedTypeInterface
   /// dimension.
   ///
   bool IsDynamicShape() const {
-    auto size_vec = (*this).GetDyShape();
-    return std::any_of(
-        size_vec.begin(), size_vec.end(), [](int64_t size_value) {
-          return IsDynamic(size_value);
-        });
+    auto size_vec = common::vectorize(impl_->get_shape(*this));
+    return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_val) {
+      return IsDynamic(size_val);
+    });
   }

   ///
@@ -112,15 +111,15 @@ class IR_API ShapedTypeInterface
   ///
   bool IsDynamicDim(unsigned idx) const {
     IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
-    return ShapedTypeInterface::IsDynamic((*this).GetDyShape()[idx]);
+    return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]);
   }

   ///
   /// \brief Get the number of dimensions with dynamic size for a ranked type.
   /// Aborts for unranked types.
   ///
   int64_t GetNumDynamicDims() const {
-    auto shape_vec = (*this).GetDyShape();
+    auto shape_vec = vectorize((*this).GetShape());
     return std::count_if(
         shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic);
   }
@@ -131,12 +130,11 @@ class IR_API ShapedTypeInterface
   ///
   int64_t GetDimSize(unsigned idx) const {
     IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
-    return (*this).GetDyShape()[idx];
+    return (*this).GetShape()[idx];
   }

  private:
   Concept *impl_;
-  mutable std::vector<int64_t> dy_shape_;
 };

 }  // namespace pir
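Redefining kDynamic from std::numeric_limits<int64_t>::min() to -1 gives dynamic extents a single sentinel again, which is what makes the double check removed in shape_optimization_utils.cc below redundant. A small sketch, assuming IsDynamic is simply a comparison against kDynamic:

  int64_t dim = -1;  // an illustrative dynamic extent, as stored in a DenseTensorType
  // Before: kDynamic was INT64_MIN, so defensive code compared against both sentinels.
  bool was_dynamic = (dim == pir::ShapedTypeInterface::kDynamic) || (dim == -1);
  // After: kDynamic == -1, so one comparison (or IsDynamic) is enough.
  bool is_dynamic = pir::ShapedTypeInterface::IsDynamic(dim);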
12 changes: 6 additions & 6 deletions paddle/pir/core/type_util.cc
@@ -23,12 +23,12 @@ Type GetElementTypeOrSelf(Type type) {
   return type;
 }

-bool VerifyCompatibleShape(const std::vector<int64_t> &lhs_shape,
-                           const std::vector<int64_t> &rhs_shape) {
+bool VerifyCompatibleShape(const pir::DDim &lhs_shape,
+                           const pir::DDim &rhs_shape) {
   if (lhs_shape.size() != rhs_shape.size()) return false;

-  for (auto dim1 : lhs_shape) {
-    for (auto dim2 : rhs_shape) {
+  for (auto dim1 : common::vectorize(lhs_shape)) {
+    for (auto dim2 : common::vectorize(rhs_shape)) {
       if (!ShapedTypeInterface::IsDynamic(dim1) &&
           !ShapedTypeInterface::IsDynamic(dim2) && dim1 != dim2)
         return false;
@@ -47,8 +47,8 @@ bool VerifyCompatibleShape(Type lhs_type, Type rhs_type) {

   if (!lhs_shaped_type.HasRank() || !rhs_shaped_type.HasRank()) return true;

-  return VerifyCompatibleShape(lhs_shaped_type.GetDyShape(),
-                               rhs_shaped_type.GetDyShape());
+  return VerifyCompatibleShape(lhs_shaped_type.GetShape(),
+                               rhs_shaped_type.GetShape());
 }

 bool VerifyCompatibleDims(const std::vector<int64_t> &dims) {
7 changes: 3 additions & 4 deletions paddle/pir/dialect/shape/utils/shape_optimization_utils.cc
@@ -201,10 +201,9 @@ std::vector<SymbolicDimOp> SymbolicDimMgr::CreateSymbolicDimsForRankedValue(
   std::vector<SymbolicDimOp> symbols;
   auto dims = value.type().dyn_cast<pir::DenseTensorType>().dims();
   for (int idx = 0; idx < dims.size(); ++idx) {
-    symbols.push_back(
-        (dims[idx] == ShapedTypeInterface::kDynamic || dims[idx] == -1)
-            ? NewSymbolicDim()
-            : NewConstantSymbolicDim(dims[idx]));
+    symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic
+                          ? NewSymbolicDim()
+                          : NewConstantSymbolicDim(dims[idx]));
   }
   return symbols;
 }