
Commit 53d2431

zhiicsyzhliu authored and committed
Separate fusion and compilation (#1564)
* Separate fusion and compilation
* fix description of graph_fuse.h
* fix lint
* fix @masahi's comments, move fusion out of target
* fix graph passing and make fused_entries singular in graph attr
* fix typo
* fix some comments
* run tests again
* remove rvalue for GraphFuse and GraphFindFusibleGroups
1 parent c9f9a3f commit 53d2431

File tree

4 files changed: +388 −288 lines


nnvm/python/nnvm/compiler/build_module.py

Lines changed: 3 additions & 1 deletion
@@ -298,8 +298,10 @@ def build(graph, target=None, shape=None, dtype="float32",
     else:
         graph._set_json_attr("opt_level", 0, "int")
     graph = graph.apply("InferShape").apply("InferType")
+    graph = graph.apply("GraphFindFusibleGroups")
+    graph = graph.apply("GraphFuse")
     with target:
-        graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
+        graph = graph.apply("GraphCompile")
     libmod = graph_attr._move_out_module(graph, "module")
     # Write variable initial values into params
     if init_var:
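The effect of the split is visible in the hunk above: GraphFindFusibleGroups and GraphFuse now run outside the target scope, and only GraphCompile remains target-dependent. Below is a minimal sketch, not part of this commit, of driving the three passes by hand; the toy symbol, shapes, and opt_level value are illustrative assumptions.

# Hedged sketch: apply the now-separate passes manually.
# The symbol, shapes, and attribute values are illustrative.
import tvm
import nnvm.graph
import nnvm.symbol as sym
from nnvm.compiler import graph_attr

x = sym.Variable("data")
y = sym.exp(sym.elemwise_add(x, x))   # exp(x + x): a fusible elementwise chain

g = nnvm.graph.create(y)
g = graph_attr.set_shape_inputs(g, {"data": (1, 16)})
g = graph_attr.set_dtype_inputs(g, "float32")
g._set_json_attr("opt_level", 1, "int")
g = g.apply("InferShape").apply("InferType")

# Target-independent half: decide fusion groups, then stitch fused subgraphs.
g = g.apply("GraphFindFusibleGroups").apply("GraphFuse")

# Target-dependent half: lower each fused subgraph and build the module.
g._set_json_attr("target", "llvm", "str")
with tvm.target.create("llvm"):
    g = g.apply("GraphCompile")
libmod = graph_attr._move_out_module(g, "module")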

nnvm/src/compiler/graph_compile.cc

Lines changed: 259 additions & 0 deletions
/*!
 * Copyright (c) 2018 by Contributors
 * \file graph_compile.cc
 * \brief Compile a graph. It lowers the graph nodes into low level IR.
 */

#include <dmlc/parameter.h>
#include <nnvm/compiler/packed_func_ext.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/pass_functions.h>
#include <nnvm/tuple.h>
#include <tvm/lowered_func.h>
#include <tvm/runtime/packed_func.h>

#include "compile_engine.h"
#include "graph_fuse.h"
#include "graph_runtime.h"
#include "pattern_util.h"

namespace nnvm {
namespace compiler {

using namespace tvm;

// Decorate the result of PlanMemory.
// This function does two things:
// - Give separate memory to each variable.
// - Tie the memory of output/lhs in an assign node together properly
//   so the execution of assign can have a side effect.
nnvm::Graph DecorateMemoryPlan(
    nnvm::Graph g,
    const std::vector<int>& assign_flag) {
  const IndexedGraph& idx = g.indexed_graph();
  StorageVector storage_vec = g.MoveCopyAttr<StorageVector>("storage_id");
  g.attrs.erase("storage_allocated_bytes");
  g.attrs.erase("storage_inplace_index");
  size_t num_not_allocated = g.MoveCopyAttr<size_t>(
      "storage_num_not_allocated");
  CHECK_EQ(num_not_allocated, 0U)
      << "Can only build inference graph with all statically allocated memory";

  // Reassign variable ids so that they are different.
  int max_id = 0;
  for (size_t i = 0; i < storage_vec.size(); ++i) {
    max_id = std::max(storage_vec[i] + 1, max_id);
  }
  for (uint32_t nid : idx.input_nodes()) {
    storage_vec[idx.entry_id(nid, 0)] = max_id++;
  }
  // Tie up the assign node storage properly.
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    if (assign_flag[nid] == 0) continue;
    const auto& inode = idx[nid];
    int var_storage_id = storage_vec[idx.entry_id(inode.inputs[0])];
    storage_vec[idx.entry_id(nid, 0)] = var_storage_id;

    if (assign_flag[nid] == 2) {
      storage_vec[idx.entry_id(inode.inputs[1])] = var_storage_id;
    }
  }
  g.attrs["storage_id"] = std::make_shared<any>(std::move(storage_vec));
  return g;
}

nnvm::Graph GraphCompile(const nnvm::Graph& g) {
  // Get attributes from the graph.
  const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
  const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
  const GroupVec& group_vec = g.GetAttr<GroupVec>("group_root");
  const MasterVec& master_vec = g.GetAttr<MasterVec>("group_master");
  const PatternVec& pattern_vec = g.GetAttr<PatternVec>("pattern");

  CHECK(g.HasAttr("fused_entry")) << "Fusion hasn't been applied yet.";
  FuseEntryVec fuse_entries = g.GetAttr<FuseEntryVec>("fused_entry");

  std::string target = g.GetAttr<std::string>("target");
  std::string target_host;

  if (g.HasAttr("target_host")) {
    target_host = g.GetAttr<std::string>("target_host");
  }
  // Specially handle assign.
  const nnvm::Op* assign_op = nnvm::Op::Get("_assign");

  // Start lowering: generate one lowered function per fusion group root.
  Array<tvm::LoweredFunc> func_list;
  std::unordered_set<const tvm::Node*> func_set;
  const IndexedGraph& idx = g.indexed_graph();

  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    int root_id = group_vec[nid];
    if (static_cast<int>(nid) != root_id) continue;
    int master = master_vec[root_id];
    FuseEntry& fe = fuse_entries[root_id];

    const IndexedGraph& subidx = fe.subgraph.indexed_graph();
    CHECK_EQ(subidx.input_nodes().size(), fe.imap.size());
    CHECK_EQ(subidx.input_nodes().size(), fe.input_info.size());

    Array<Tensor> inputs;
    for (uint32_t sub_input_id : subidx.input_nodes()) {
      auto it = fe.input_info.find(subidx[sub_input_id].source);
      inputs.push_back(it->second);
    }
    // Find the master node index in the subgraph.
    int sub_master_idx = 0;
    for (uint32_t i = 0; i < subidx.num_nodes(); i++) {
      if (subidx[i].source->op() == idx[master].source->op()) {
        sub_master_idx = i;
        break;
      }
    }
    fe.compiled_func = GraphLower(fe.subgraph, inputs, target, sub_master_idx);
    for (LoweredFunc f : fe.compiled_func->funcs) {
      if (!func_set.count(f.get())) {
        func_set.insert(f.get());
        func_list.push_back(f);
      }
    }
  }

  const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");

  // Rewrite the graph: each fusion group collapses into a single tvm_op node.
  std::unordered_map<uint32_t, nnvm::NodePtr> old_new;
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) {
      // Only copy the name since that is sufficient.
      nnvm::NodePtr np = nnvm::Node::Create();
      np->attrs.name = inode.source->attrs.name;
      old_new[nid] = np;
      continue;
    }
    int root_id = group_vec[nid];
    if (static_cast<int>(nid) != root_id) continue;

    // Handle normal ops.
    FuseEntry& fe = fuse_entries[root_id];
    const IndexedGraph& subidx = fe.subgraph.indexed_graph();
    nnvm::NodePtr np = nnvm::Node::Create();
    np->attrs.op = tvm_op;
    np->attrs.name = inode.source->attrs.name;
    TVMOpParam param;
    param.func_name = fe.compiled_func->func_name;
    param.num_inputs = static_cast<uint32_t>(fe.imap.size());
    param.num_outputs = static_cast<uint32_t>(fe.subgraph.outputs.size());
    param.flatten_data = fe.flatten_data;
    param.UpdateDict(&(np->attrs.dict));
    np->attrs.parsed = std::move(param);

    for (uint32_t sub_input_id : subidx.input_nodes()) {
      // Need to make sure the subgraph input order is consistent with the
      // order of the main graph inputs.
      auto rit = fe.reverse_imap.find(subidx[sub_input_id].source);
      CHECK(rit != fe.reverse_imap.end());
      const IndexedGraph::NodeEntry& e = rit->second;
      auto it = old_new.find(e.node_id);
      CHECK(it != old_new.end())
          << "cannot find node_id=" << e.node_id;
      np->inputs.emplace_back(
          nnvm::NodeEntry{it->second, e.index, e.version});
    }
    for (const uint32_t node_id : inode.control_deps) {
      auto it = old_new.find(node_id);
      CHECK(it != old_new.end());
      np->control_deps.emplace_back(it->second);
    }
    old_new[nid] = np;
  }
  nnvm::Graph ret;
  for (const auto& e : idx.outputs()) {
    auto it = old_new.find(group_vec[e.node_id]);
    CHECK(it != old_new.end())
        << "cannot find node_id=" << e.node_id;
    ret.outputs.emplace_back(
        nnvm::NodeEntry{it->second, e.index, e.version});
  }

  // Reference counter of each op node.
  // For now, always store the result when an op is referred to more than once.
  std::vector<uint32_t> ref_count = GetNodeRefCounts(idx);
  for (const auto& e : idx.outputs()) {
    // This line will realize all the outputs.
    ref_count[e.node_id] += 1;
  }

  const IndexedGraph& new_idx = ret.indexed_graph();

  // Handling assign:
  //
  // assign is a special operator that mutates the variable.
  // Currently assign is implemented as output = copy(input[1]).
  // Then we run DecorateMemoryPlan to force
  // output.storage = input[0].storage.
  //
  std::vector<int> assign_flag(new_idx.num_nodes(), 0);
  ShapeVector new_shape_vec = ShapeVector(new_idx.num_node_entries(), TShape());
  DTypeVector new_dtype_vec = DTypeVector(new_idx.num_node_entries());
  std::vector<std::string> new_dltype_vec(new_idx.num_node_entries());

  for (const auto& kv : old_new) {
    uint32_t nid = kv.first;
    const auto& inode = idx[nid];
    uint32_t new_nid = new_idx.node_id(kv.second.get());
    if (inode.source->op() == assign_op) {
      // Check if the rhs of assign can be computed in place.
      // If yes, we can simply set that memory to be the assign target
      // and change assign to a nop.
      const IndexedGraph::NodeEntry& rhs = inode.inputs[1];
      if (ref_count[rhs.node_id] <= 1 &&
          !(idx[rhs.node_id].source->is_variable()) &&
          pattern_vec[group_vec[rhs.node_id]] <= kBroadcast) {
        assign_flag[new_nid] = 2;
        TVMOpParam& param = dmlc::get<TVMOpParam>(kv.second->attrs.parsed);
        param.func_name = "__nop";
        param.UpdateDict(&(kv.second->attrs.dict));
      } else {
        assign_flag[new_nid] = 1;
      }
    }
    // Copy shape/dtype entries from the old graph to the new one.
    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
      uint32_t new_eid = new_idx.entry_id(new_idx.node_id(kv.second.get()), i);
      uint32_t old_eid = idx.entry_id(nid, i);
      new_shape_vec[new_eid] = shape_vec[old_eid];
      new_dtype_vec[new_eid] = dtype_vec[old_eid];
      new_dltype_vec[new_eid] = tvm::runtime::TVMType2String(
          GetDLType(dtype_vec[old_eid]));
    }
  }
  ret.attrs["shape"] = std::make_shared<any>(std::move(new_shape_vec));
  ret.attrs["dtype"] = std::make_shared<any>(std::move(new_dtype_vec));
  ret.attrs["dltype"] = std::make_shared<any>(std::move(new_dltype_vec));

  // Set up the module.
  static const PackedFunc& fbuild = GetPackedFunc("nnvm.compiler.build_target");
  tvm::runtime::Module module = fbuild(func_list, target, target_host);
  ret.attrs["module"] = std::make_shared<any>(std::move(module));
  ret = nnvm::ApplyPass(ret, "PlanMemory");
  ret = DecorateMemoryPlan(ret, assign_flag);
  return ret;
}

NNVM_REGISTER_PASS(GraphCompile)
    .set_body(GraphCompile)
    .depend_graph_attr("shape")
    .depend_graph_attr("dtype")
    .depend_graph_attr("fused_entry")
    .depend_graph_attr("group_root")
    .depend_graph_attr("pattern")
    .depend_graph_attr("group_master");

}  // namespace compiler
}  // namespace nnvm
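GraphCompile leaves the built module in the graph's "module" attribute (moved out in build_module.py above) and returns a graph of tvm_op nodes with memory already planned. Below is a hedged continuation of the earlier sketch; the CPU context, input values, and expected output are illustrative assumptions, not part of this commit.

# Hedged continuation: execute the compiled graph with TVM's graph runtime.
# `g` and `libmod` come from the earlier sketch; the context and inputs are
# illustrative assumptions.
import numpy as np
import tvm
from tvm.contrib import graph_runtime

rt = graph_runtime.create(g, libmod, tvm.cpu(0))
rt.set_input("data", tvm.nd.array(np.ones((1, 16), dtype="float32")))
rt.run()
out = rt.get_output(0, tvm.nd.empty((1, 16), dtype="float32"))
print(out.asnumpy())   # each element should be exp(1 + 1) ~= 7.389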
