From 99b75cec0f9d5086b7c64c812ebe517980cb0845 Mon Sep 17 00:00:00 2001 From: shuaihehe <2728551637@qq.com> Date: Sun, 7 Apr 2024 12:28:03 +0000 Subject: [PATCH 1/2] fix1 --- .../convert_dynamic_to_static_dim_pass.cc | 17 +++++++-- .../convert_static_dim_to_dynamic_pass.cc | 10 +++++- .../group_with_group_merge_pass.cc | 29 +++++++++++++-- .../lowering_pass/broadcast_with_cf.cc | 14 ++++++-- ...plit_generate_shape_into_shape_ops_pass.cc | 35 ++++++++++++++++--- 5 files changed, 92 insertions(+), 13 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc index d1550a2bdf257..72219287fe3e3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc @@ -181,10 +181,23 @@ class DynamicToStaticConverter { CHECK(shape_analysis_->HasShapeOrDataForValue(value)); const auto& origin_shape = GetOriginValueShape(value); const auto& target_shape = GetTargetValueShape(value); - CHECK_EQ(origin_shape.size(), target_shape.size()); + PADDLE_ENFORCE_EQ( + origin_shape.size(), + target_shape.size(), + phi::errors::InvalidArgument( + "The size of origin shape and target shape is not equal," + "where the size of origin shape:%d but the size of target " + "shape:%d.", + origin_shape.size(), + target_shape.size())); for (std::size_t i = 0; i < origin_shape.size(); ++i) { if (origin_shape.at(i) == -1) { - CHECK_GT(target_shape.at(i), 0); + PADDLE_ENFORCE_GT(target_shape.at(i), + 0, + phi::errors::InvalidArgument( + "The size of target shape is incorrect." + "Expected size is larger than 0, but receive %d.", + target_shape.at(i))); update = true; } else { CHECK(origin_shape.at(i) == target_shape.at(i)); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc index e67cb5aacabfa..e20cab270cdd3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc @@ -154,7 +154,15 @@ struct StaticDimToDynamicConverter { const auto& origin_shape = GetOriginValueShape(value); const auto& target_shape = GetTargetValueShape( shape_analysis->GetShapeOrDataForValue(value).shape()); - CHECK_EQ(origin_shape.size(), target_shape.size()); + PADDLE_ENFORCE_EQ( + origin_shape.size(), + target_shape.size(), + phi::errors::InvalidArgument( + "The size of origin shape and target shape is not equal," + "where the size of origin shape:%d but the size of target " + "shape:%d.", + origin_shape.size(), + target_shape.size())); const auto& origin_type = value.type().dyn_cast<::pir::DenseTensorType>(); pir::DenseTensorType target_type = pir::DenseTensorType::get(pir::IrContext::Instance(), diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc index 79b8a70d28acc..1b0519938c933 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc @@ -1941,7 +1941,14 @@ class GeneralFusionMergePassHelper { } } - CHECK_GE(producer->consumer_groups().size(), candidates.size()); + PADDLE_ENFORCE_GE( + producer->consumer_groups().size(), + candidates.size(), + phi::errors::InvalidArgument( + "The size of producer consumer groups is incorrect." + "Expected size is greater than or equal to %d, but receive %d.", + candidates.size(), + producer->consumer_groups().size())); if (producer->consumer_groups().size() == 0 && candidates.size() == 0 && output_ops_set_.count(producer->CollectOps()[0]) == 0) { producer->belong_groups.insert(*fusionable_consumers->begin()); @@ -2204,8 +2211,24 @@ class GeneralFusionMergePassHelper { CHECK(consumer->belong_groups.size()); consumers.insert(*consumer->belong_groups.begin()); } - CHECK_EQ(group->producer_groups().size(), producers.size()); - CHECK_EQ(group->consumer_groups().size(), consumers.size()); + PADDLE_ENFORCE_EQ( + group->producer_groups().size(), + producers.size(), + phi::errors::InvalidArgument( + "The size of group's producer groups and producers is not equal," + "where the size of group's producer groups:%d but the size of " + "producers:%d.", + group->producer_groups().size(), + producers.size())); + PADDLE_ENFORCE_EQ( + group->consumer_groups().size(), + consumers.size(), + phi::errors::InvalidArgument( + "The size of group's consumer groups and consumers is not equal," + "where the size of group's consumer groups:%d but the size of " + "consumers:%d.", + group->consumer_groups().size(), + consumers.size())); (*group->mut_producer_groups()) = producers; (*group->mut_consumer_groups()) = consumers; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc index 7068221d77fe5..c9ea6732fb072 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc @@ -134,7 +134,9 @@ bool EraseOneExpand( if (!SameInputOutputShape(expand, ShapeOrDataDimExprs4Value)) continue; auto generate_shape_op = expand.shape().defining_op(); - CHECK_NOTNULL(generate_shape_op); + PADDLE_ENFORCE_NOT_NULL(generate_shape_op, + phi::errors::PreconditionNotMet( + "The generate shape op must not be null.")); rewriter.ReplaceAllUsesWith(expand.out(), expand.x()); rewriter.EraseOp(expand); if (generate_shape_op->use_empty()) { @@ -280,7 +282,15 @@ void SetLeafBlockByGroupView( } auto new_group = CloneGroup(origin_group, block, &ir_mapping); - CHECK_EQ(origin_group->ops().size(), new_group->ops().size()); + PADDLE_ENFORCE_EQ( + origin_group->ops().size(), + new_group->ops().size(), + phi::errors::InvalidArgument( + "The size of origin group ops and new group ops is not equal," + "where the size of origin group ops:%d but the size of new group " + "ops:%d.", + origin_group->ops().size(), + new_group->ops().size())); UpdateGroupShapeExprs(new_group, origin_group, ir_mapping, diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc index 19e7f5060eb96..696449b471b3d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc @@ -143,7 +143,12 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl(const symbol::Add& dim_expr) { const auto& [operands] = dim_expr; - CHECK_GT(operands->size(), 0); + PADDLE_ENFORCE_GT(operands->size(), + 0, + phi::errors::InvalidArgument( + "The size of operands is incorrect." + "Expected size is larger than 0, but receive %d.", + operands->size())); pir::Value acc = ConvertToValue(operands->at(0)); for (int i = 1; i < operands->size(); ++i) { if (operands->at(i).isa>()) { @@ -162,7 +167,12 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl(const symbol::Mul& dim_expr) { const auto& [operands] = dim_expr; - CHECK_GT(operands->size(), 0); + PADDLE_ENFORCE_GT(operands->size(), + 0, + phi::errors::InvalidArgument( + "The size of operands is incorrect." + "Expected size is larger than 0, but receive %d.", + operands->size())); pir::Value prod = ConvertToValue(operands->at(0)); for (int i = 1; i < operands->size(); ++i) { if (operands->at(i).isa>()) { @@ -182,7 +192,12 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl(const symbol::Max& dim_expr) { const auto& [operands] = dim_expr; - CHECK_GT(operands->size(), 0); + PADDLE_ENFORCE_GT(operands->size(), + 0, + phi::errors::InvalidArgument( + "The size of operands is incorrect." + "Expected size is larger than 0, but receive %d.", + operands->size())); pir::Value max = ConvertToValue(operands->at(0)); for (int i = 1; i < operands->size(); ++i) { pir::Value operand_value = ConvertToValue(operands->at(i)); @@ -193,7 +208,12 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl(const symbol::Min& dim_expr) { const auto& [operands] = dim_expr; - CHECK_GT(operands->size(), 0); + PADDLE_ENFORCE_GT(operands->size(), + 0, + phi::errors::InvalidArgument( + "The size of operands is incorrect." + "Expected size is larger than 0, but receive %d.", + operands->size())); pir::Value min = ConvertToValue(operands->at(0)); for (int i = 1; i < operands->size(); ++i) { pir::Value operand_value = ConvertToValue(operands->at(i)); @@ -205,7 +225,12 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl( const symbol::Broadcast& dim_expr) { const auto& [operands] = dim_expr; - CHECK_GT(operands->size(), 0); + PADDLE_ENFORCE_GT(operands->size(), + 0, + phi::errors::InvalidArgument( + "The size of operands is incorrect." + "Expected size is larger than 0, but receive %d.", + operands->size())); pir::Value broadcasted = ConvertToValue(operands->at(0)); for (int i = 1; i < operands->size(); ++i) { pir::Value operand_value = ConvertToValue(operands->at(i)); From 0dbc31ea750a9557d65879d34cd8e28150bdb194 Mon Sep 17 00:00:00 2001 From: shuaihehe <2728551637@qq.com> Date: Tue, 9 Apr 2024 11:23:01 +0000 Subject: [PATCH 2/2] push again