forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph.h
179 lines (148 loc) · 5.13 KB
/
graph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
#pragma once
#include "caffe2/core/common.h"
#include "caffe2/proto/caffe2_pb.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"

#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace caffe2 {
namespace transform {
/**
 * Graph representation of an operator.
 *
 * A Node wraps one OperatorDef together with its incoming (parents) and
 * outgoing (children) edges, each edge labelled with the blobs that connect
 * the two operators.
 */
struct TORCH_API Node {
 public:
  // Default constructor; required so a container of Nodes can be resized.
  Node() = default;

  // Alternate constructor.
  // The edge maps are taken by value and moved into place, so callers that
  // pass temporaries (or std::move) pay no extra copy.
  Node(
      const OperatorDef& op,
      bool active,
      std::map<int, std::vector<string>> parents,
      std::map<int, std::vector<string>> children)
      : op(op),
        active(active),
        parents(std::move(parents)),
        children(std::move(children)) {}

  // The OperatorDef which this node represents.
  OperatorDef op;

  // Keeps track of whether this operator is still live; false once it has
  // been deleted through a transformation.
  bool active = true;

  // Edge maps, each entry being a pair (idx, blob_list):
  //   idx       = index of the neighbouring node (the parent node for
  //               `parents`, the child node for `children`)
  //   blob_list = names of the blobs that connect the two nodes
  std::map<int, std::vector<string>> parents;
  std::map<int, std::vector<string>> children;
};
/**
* Graph representation of a Netdef.
*/
struct TORCH_API Graph {
public:
/**
* Given a subgraph, gets all of the parents of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob writing into subgraph,
* index of node that writes into subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphInput(
const std::vector<int>& subgraph);
/**
* Given a subgraph, gets all of the children of the subgraph, as well as
* their associated blob names. Sorted by blob names.
*
* <string, int> := (name of blob reading from subgraph,
* index of node that reads from subgraph using that blob)
*/
const std::vector<std::pair<string, int>> GetSubgraphOutput(
const std::vector<int>& subgraph);
/**
* Graph generation.
* Given a netdef, returns a Graph.
*
* Each node represents an operator.
* An edge exists between two nodes if the parent op writes to a blob, which
* is the input of the child blob, with no other op writing to the blob in
* between the execution order.
*
* Time Complexity: O(E), where E is the number of blobs
*/
explicit Graph(const NetDef& net_def);
/**
* Generates a NetDef Representation for the current graph.
* Nodes are visited in topological order, which is proper Opdef ordering.
* TODO(benz):
* There exists conflicts with repeated blob names, where topological sorting
* is not sufficient for correct netdef representation, unless blobs are
* renamed.
* For example, if after a transformation, We have operator ancestry:
* A --> B --> C, and also A --> D --> E, where B -> C and D -> E uses the
* same blob name, then A, B, D, E, C is a correct topological ordering,
* but D will write to the blob that C reads from, instead of B.
* Currently believe that there will always be ambiguity unless blobs are
* renamed.
* This is solved by performing SSA on all transformed blob names.
*/
NetDef GetNetDef();
/**
* Deactivate a subgraph, and get rid of all edges into this subgraph.
*/
void DeactivateSubgraph(std::vector<int> subgraph);
size_t size() const {
return nodes_.size();
}
void push_node(const Node& new_node) {
return nodes_.push_back(new_node);
}
void resize_nodes(size_t new_size) {
nodes_.resize(new_size);
}
// Index safe, less verbose way to access nodes
inline const Node& node(size_t idx) const {
return nodes_.at(idx);
}
inline Node& node(size_t idx) {
return nodes_.at(idx);
}
inline bool is_node_active(size_t idx) {
return node(idx).active;
}
inline const std::set<string>& external_input() const {
return external_input_;
}
inline const std::set<string>& external_output() const {
return external_output_;
}
private:
const std::vector<std::pair<string, int>> GetSubgraphPerimeterHelper(
bool from_children,
const std::vector<int>& match);
// Stores the netdef representation. Is updated upon calls to GetNetDef.
NetDef netdef_;
// Stores which blobs the graph reads from, and writes to.
std::set<string> external_input_;
std::set<string> external_output_;
// Keeps track of all the Operators currently within graph, even if inactive.
std::vector<Node> nodes_;
};
} // namespace transform
// Adds an operator def to a netdef.
// Parameters (roles inferred from their names — confirm against the
// definition):
//   netdef_ptr: the NetDef that receives the new operator.
//   op_type:    the type() given to the new operator.
//   inputs:     input blob names of the new operator.
//   outputs:    output blob names of the new operator.
// Returns a pointer to the newly added OperatorDef, so callers can set
// anything extra on it (such as device_option).
// NOTE(review): the returned pointer presumably refers into netdef_ptr's op
// list and may be invalidated if more ops are added later — confirm.
TORCH_API OperatorDef* AddOp(
NetDef* netdef_ptr,
string op_type,
std::vector<string> inputs,
std::vector<string> outputs);
/**
 * This allows for the use of * and | to match operator types,
 * engines, or any other property that is represented by strings.
 *
 * For example, if we wanted to match an operator to Conv or FC, we can give
 * "Conv|FC" as the type() of that op.
 *
 * p: the pattern string, which may contain `*` and `|`.
 * s: the concrete string checked against the pattern.
 * Returns true if s matches p.
 *
 * NOTE(review): parameter roles are inferred from the names and the example
 * above; the exact semantics of `*` (presumably a wildcard) should be
 * confirmed against the definition.
 */
TORCH_API bool MatchStrings(string p, string s);
/**
 * This ensures that each named arg that exists in the pattern also exists in
 * g_op and is equal in value.
 *
 * p_op: the pattern operator whose named arguments must all be matched.
 * g_op: the graph operator being checked against the pattern.
 * Returns true if every named argument of p_op is present in g_op with an
 * equal value. Arguments that appear only in g_op presumably do not affect
 * the result — confirm against the definition.
 */
TORCH_API bool MatchArguments(const OperatorDef& p_op, const OperatorDef& g_op);
} // namespace caffe2