diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc index 8ad51581c1a740..473f1c9de1b485 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc @@ -32,6 +32,7 @@ #include "paddle/fluid/framework/feed_hook.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" @@ -75,23 +76,34 @@ void VisitFeedName(const pir::Program& program, const DoEachFeadNameT& DoEachFeadName) { auto module_op = program.module_op(); const auto& block = module_op.block(); - const auto& IsDataOp = [](const pir::Operation& op) -> bool { - return op.isa(); - }; - const auto& GetDataOpName = [](const pir::Operation& op) -> std::string { + auto GetDataOpName = + [](const pir::Operation& op) -> std::optional { + if (!op.isa()) return std::nullopt; return op.attributes().at("name").dyn_cast().AsString(); }; - const auto& IsFeedOp = [](const pir::Operation& op) -> bool { - return op.isa(); + auto GetFeedOpName = + [](const pir::Operation& op) -> std::optional { + if (!op.isa()) return std::nullopt; + return op.attributes().at("name").dyn_cast().AsString(); }; - const auto& GetFeedOpName = [](const pir::Operation& op) -> std::string { + auto GetPhiFeedOpName = + [](const pir::Operation& op) -> std::optional { + if (!op.isa()) return std::nullopt; + const auto& attributes = op.attributes(); + const auto& op_name_it = attributes.find("op_name"); + if (op_name_it == attributes.end()) return std::nullopt; + const auto& op_name = + op_name_it->second.dyn_cast().AsString(); + if (op_name != "pd_op.feed") return std::nullopt; return op.attributes().at("name").dyn_cast().AsString(); }; for (const auto& op : block) { - if (IsDataOp(op)) { - DoEachFeadName(GetDataOpName(op)); - } else if (IsFeedOp(op)) { - DoEachFeadName(GetFeedOpName(op)); + if (const auto& name = GetDataOpName(op)) { + DoEachFeadName(name.value()); + } else if (const auto& name = GetFeedOpName(op)) { + DoEachFeadName(name.value()); + } else if (const auto& name = GetPhiFeedOpName(op)) { + DoEachFeadName(name.value()); } else { // Do nothing. } @@ -1431,34 +1443,48 @@ std::optional GetNullShapeAnalysis( return std::nullopt; } +void TryTruncateLogginFile(const std::string& file_path) { + if (!FLAGS_logging_trunc_pir_py_code) return; + static std::mutex mutex; + std::unique_lock lock(mutex); + static std::unordered_map once_flags; + std::call_once(once_flags[file_path], [&] { + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); + ofs.close(); + }); +} + } // namespace void PirToPyCodeConverter::SaveIfFlagEnabled() const { if (program_ == nullptr) return; if (file_name_.empty()) return; - if (FLAGS_logging_pir_py_code_dir == "") return; + if (FLAGS_logging_pir_py_code_dir.empty()) return; const std::string file_path = FLAGS_logging_pir_py_code_dir + "/" + file_name_; - ShapeAnalysisGetterT ShapeAnalysisGetter = - (dump_symbolic_shape_ ? GetShapeAnalysisFromManager - : GetNullShapeAnalysis); - PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter); - const std::string content = converter_helper.Convert(); - static std::mutex mutex; - std::unique_lock lock(mutex); - if (FLAGS_logging_trunc_pir_py_code) { - static std::unordered_map once_flags; - std::call_once(once_flags[file_path], [&] { - std::ofstream ofs; - ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc); - ofs.close(); - }); - } - std::ofstream ofs; - ofs.open(file_path.c_str(), std::ios::out | std::ios::app); - if (!ofs.is_open()) return; - ofs << content << std::endl; - ofs.close(); + TryTruncateLogginFile(file_path); + const auto MutOnceFlag = [&]() -> std::once_flag* { + static std::mutex mutex; + std::unique_lock lock(mutex); + using FileName = std::string; + using FileName2OnceFlag = std::unordered_map; + using ProgramId = int64_t; + static std::unordered_map once_flags; + return &once_flags[program_->id()][file_name_]; + }; + std::call_once(*MutOnceFlag(), [&] { + ShapeAnalysisGetterT ShapeAnalysisGetter = + (dump_symbolic_shape_ ? GetShapeAnalysisFromManager + : GetNullShapeAnalysis); + PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter); + const std::string content = converter_helper.Convert(); + std::ofstream ofs; + ofs.open(file_path.c_str(), std::ios::out | std::ios::app); + if (!ofs.is_open()) return; + ofs << content << std::endl; + ofs.close(); + }); } void DumpExecProgram(const pir::Program& program,