Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add graph_key to specific graph's varmap #60567

Merged
merged 4 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions paddle/fluid/framework/io/save_paddle2cinn_varmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ namespace framework {

void save_paddle2cinn_varmap(
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
int64_t graph_compilation_key,
std::string save_path) {
std::stringstream ss;
ss << "graph_compilation_key:" << std::to_string(graph_compilation_key)
<< "\n";
for (const auto& kv : paddle2cinn_var_map) {
ss << kv.first << ":" << kv.second << "\n";
}
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/io/save_paddle2cinn_varmap.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace framework {

void save_paddle2cinn_varmap(
std::unordered_map<std::string, std::string> paddle2cinn_var_map,
int64_t graph_compilation_key,
std::string save_path);

}
Expand Down
20 changes: 19 additions & 1 deletion paddle/fluid/framework/io/save_runtime_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ void save_string(std::string content,
fout.close();
}

void save_graph_compilation_key(int64_t graph_compilation_key,
std::string type,
std::string saved_path) {
VLOG(6) << type << " will be saved to " << saved_path;
MkDirRecursively(DirName(saved_path).c_str());

std::ofstream fout(saved_path);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fout),
true,
phi::errors::Unavailable("Cannot open %s to save ", saved_path));
fout << std::to_string(graph_compilation_key);
fout.close();
}

std::string node_format(const ir::Node& node, int number) {
return "node_" + std::to_string(number) + " : " + "[" + node.Name() + ", " +
(node.IsOp() ? "op" : "var") + "]";
Expand Down Expand Up @@ -78,6 +93,7 @@ void save_graph(const ir::Graph& graph,
}

void save_runtime_cinn_graph(const ir::Graph& graph,
int64_t graph_compilation_key,
std::string clusters_ops,
std::string clusters_inputs,
std::string cluster_outputs,
Expand All @@ -91,7 +107,9 @@ void save_runtime_cinn_graph(const ir::Graph& graph,
save_string(cluster_intervals,
"cluster_intervals",
saved_path + "/cluster_intervals.txt");

save_graph_compilation_key(graph_compilation_key,
"graph_compilation_key",
saved_path + "/graph_compilation_key.txt");
save_graph(graph, "graph", saved_path + "/subgraph.txt");
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/io/save_runtime_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
void save_runtime_cinn_graph(const ir::Graph& graph,
int64_t graph_compilation_key,
std::string clusters_ops,
std::string clusters_inputs,
std::string cluster_outputs,
Expand Down
11 changes: 6 additions & 5 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -753,20 +753,21 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);

if (FLAGS_save_static_runtime_data) {
paddle::framework::save_runtime_cinn_graph(
*subgraph,
cinn_compiler->FindGraph(compilation_key),
compilation_key,
cluster_debug_info(cluster_set),
cluster_debug_info(cluster_inputs),
cluster_debug_info(cluster_outputs),
cluster_debug_info(cluster_internals),
FLAGS_static_runtime_data_save_path + "/cluster_" +
std::to_string(++i));
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key);

// Replace the found cluster to a new cinn op node
ReplaceSubGraphWithCinnOpNode(cluster_set,
cluster_inputs,
Expand Down
15 changes: 9 additions & 6 deletions paddle/fluid/operators/cinn/cinn_launch_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
[](const auto& name_view) { return std::string(name_view.data()); });
// build name map between the original variables and compiled ones
BuildVarNameMap(compiled_obj.paddle2cinn_varmap, cinn_argument_names_);
if (FLAGS_save_static_runtime_data) {
auto graph_compilation_key =
std::hash<const framework::ir::Graph*>()((&graph));
paddle::framework::save_paddle2cinn_varmap(
paddle2cinn_varmap_,
graph_compilation_key,
FLAGS_static_runtime_data_save_path +
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
}

const auto& input_var_names =
graph.Get<std::vector<std::string>>(framework::paddle2cinn::kInputVars);
Expand Down Expand Up @@ -193,12 +202,6 @@ void CinnLaunchContext::BuildVarNameMap(
"Size of variables is not euqal, paddle[%ld] vs cinn[%ld]",
paddle2cinn_varmap_.size(),
cinn2paddle_varmap_.size()));
if (FLAGS_save_static_runtime_data) {
paddle::framework::save_paddle2cinn_varmap(
paddle2cinn_varmap_,
FLAGS_static_runtime_data_save_path +
"/paddle2cinn_varmap/paddle2cinn_varmap.txt");
}
}

std::unordered_set<std::string> CinnLaunchContext::GetVisibleVarNames() const {
Expand Down