Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "paddle/fluid/framework/feed_hook.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
Expand Down Expand Up @@ -75,23 +76,34 @@ void VisitFeedName(const pir::Program& program,
const DoEachFeadNameT& DoEachFeadName) {
auto module_op = program.module_op();
const auto& block = module_op.block();
const auto& IsDataOp = [](const pir::Operation& op) -> bool {
return op.isa<paddle::dialect::DataOp>();
};
const auto& GetDataOpName = [](const pir::Operation& op) -> std::string {
auto GetDataOpName =
[](const pir::Operation& op) -> std::optional<std::string> {
if (!op.isa<paddle::dialect::DataOp>()) return std::nullopt;
return op.attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
};
const auto& IsFeedOp = [](const pir::Operation& op) -> bool {
return op.isa<paddle::dialect::FeedOp>();
auto GetFeedOpName =
[](const pir::Operation& op) -> std::optional<std::string> {
if (!op.isa<paddle::dialect::FeedOp>()) return std::nullopt;
return op.attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
};
const auto& GetFeedOpName = [](const pir::Operation& op) -> std::string {
auto GetPhiFeedOpName =
[](const pir::Operation& op) -> std::optional<std::string> {
if (!op.isa<paddle::dialect::PhiKernelOp>()) return std::nullopt;
const auto& attributes = op.attributes();
const auto& op_name_it = attributes.find("op_name");
if (op_name_it == attributes.end()) return std::nullopt;
const auto& op_name =
op_name_it->second.dyn_cast<pir::StrAttribute>().AsString();
if (op_name != "pd_op.feed") return std::nullopt;
return op.attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
};
for (const auto& op : block) {
if (IsDataOp(op)) {
DoEachFeadName(GetDataOpName(op));
} else if (IsFeedOp(op)) {
DoEachFeadName(GetFeedOpName(op));
if (const auto& name = GetDataOpName(op)) {
DoEachFeadName(name.value());
} else if (const auto& name = GetFeedOpName(op)) {
DoEachFeadName(name.value());
} else if (const auto& name = GetPhiFeedOpName(op)) {
DoEachFeadName(name.value());
} else {
// Do nothing.
}
Expand Down Expand Up @@ -1431,34 +1443,48 @@ std::optional<pir::ShapeConstraintIRAnalysis*> GetNullShapeAnalysis(
return std::nullopt;
}

void TryTruncateLogginFile(const std::string& file_path) {
if (!FLAGS_logging_trunc_pir_py_code) return;
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
}

} // namespace

void PirToPyCodeConverter::SaveIfFlagEnabled() const {
if (program_ == nullptr) return;
if (file_name_.empty()) return;
if (FLAGS_logging_pir_py_code_dir == "") return;
if (FLAGS_logging_pir_py_code_dir.empty()) return;
const std::string file_path =
FLAGS_logging_pir_py_code_dir + "/" + file_name_;
ShapeAnalysisGetterT ShapeAnalysisGetter =
(dump_symbolic_shape_ ? GetShapeAnalysisFromManager
: GetNullShapeAnalysis);
PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter);
const std::string content = converter_helper.Convert();
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
if (FLAGS_logging_trunc_pir_py_code) {
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
}
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::app);
if (!ofs.is_open()) return;
ofs << content << std::endl;
ofs.close();
TryTruncateLogginFile(file_path);
const auto MutOnceFlag = [&]() -> std::once_flag* {
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
using FileName = std::string;
using FileName2OnceFlag = std::unordered_map<FileName, std::once_flag>;
using ProgramId = int64_t;
static std::unordered_map<ProgramId, FileName2OnceFlag> once_flags;
return &once_flags[program_->id()][file_name_];
};
std::call_once(*MutOnceFlag(), [&] {
ShapeAnalysisGetterT ShapeAnalysisGetter =
(dump_symbolic_shape_ ? GetShapeAnalysisFromManager
: GetNullShapeAnalysis);
PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter);
const std::string content = converter_helper.Convert();
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::app);
if (!ofs.is_open()) return;
ofs << content << std::endl;
ofs.close();
});
}

void DumpExecProgram(const pir::Program& program,
Expand Down