forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
net_async_task_graph.h
78 lines (54 loc) · 2.2 KB
/
net_async_task_graph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#ifndef CAFFE2_NET_ASYNC_TASK_GRAPH_H
#define CAFFE2_NET_ASYNC_TASK_GRAPH_H
#include "caffe2/core/net_async_base.h"
#include "caffe2/core/net_async_task.h"
#include "caffe2/core/net_async_task_future.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
// AsyncTaskGraph represents an execution of a net. It owns the tasks and
// their associated futures, sets up future callbacks, and propagates errors.
// Usage steps:
// - Add graph nodes and edges through CreateNode/AddDependency;
// - Freeze the graph (FreezeGraph); once the graph is frozen, a future
//   can be obtained using GetFuture;
// - Schedule execution of the graph through ExecuteGraph; after each
//   execution, Reset must be called to prepare the graph for the next run.
class AsyncTaskGraphBase {
 public:
  // Adds a graph node with id `node_id` built from the operators in `ops`.
  // NOTE(review): the meaning of the bool result is defined by the
  // implementation (presumably success/failure) -- confirm.
  virtual bool CreateNode(
      int node_id,
      const std::vector<OperatorBase*>& ops) = 0;

  // Adds graph edges between `child_node_id` and each id listed in
  // `parent_node_ids`.
  // NOTE(review): the meaning of the bool result is defined by the
  // implementation (presumably success/failure) -- confirm.
  virtual bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) = 0;

  // Freezes the graph. After freezing, a future can be obtained using
  // GetFuture.
  virtual void FreezeGraph() = 0;

  // Schedules an execution of the graph. Reset must be called after each
  // execution to prepare the graph for the next run.
  // NOTE(review): the returned pointer is presumably non-owning -- confirm.
  virtual AsyncTaskFuture* ExecuteGraph() = 0;

  // Returns the future of the graph; available once the graph is frozen.
  // NOTE(review): the returned pointer is presumably non-owning -- confirm.
  virtual AsyncTaskFuture* GetFuture() = 0;

  // Prepares the graph for the next run; call after each execution.
  virtual void Reset() = 0;

  // Virtual so that implementations can be destroyed through this interface.
  virtual ~AsyncTaskGraphBase() = default;
};
class AsyncTaskGraph : public AsyncTaskGraphBase {
 public:
  // Builds a graph bound to the executor `helper` and the given `options`.
  AsyncTaskGraph(ExecutorHelper* helper, const ExecutionOptions& options);

  // Graph construction: nodes and edges.
  bool CreateNode(int node_id, const std::vector<OperatorBase*>& ops) override;
  bool AddDependency(
      int child_node_id,
      const std::vector<int>& parent_node_ids) override;

  // Freezing; a future can be obtained through GetFuture afterwards.
  void FreezeGraph() override;

  // Execution; Reset must be called after each run before the next one.
  AsyncTaskFuture* ExecuteGraph() override;
  AsyncTaskFuture* GetFuture() override;
  void Reset() override;

 private:
  // Used to, e.g., get access to the executor's thread pools.
  // TODO: pass tracer and counters through ExecutorHelper
  ExecutorHelper* helper_;

  // Copy of the execution options.
  // NOTE(review): presumably the options passed to the constructor -- confirm.
  ExecutionOptions options_;

  // NOTE(review): presumably set once FreezeGraph has been called -- confirm.
  bool frozen_;

  // Tasks owned by the graph, keyed by an int (presumably the node id).
  std::unordered_map<int, std::unique_ptr<AsyncTask>> nodes_;

  // Int-keyed sets of ints.
  // NOTE(review): presumably node id -> ids of its parent nodes -- confirm.
  std::unordered_map<int, std::unordered_set<int>> parents_;

  // Int-keyed sets of ints.
  // NOTE(review): presumably node id -> ids of its child nodes -- confirm.
  std::unordered_map<int, std::unordered_set<int>> children_;

  // Futures owned by the graph.
  // NOTE(review): presumably the futures linking dependent tasks -- confirm.
  std::vector<std::unique_ptr<AsyncTaskFuture>> edge_futures_;

  // Non-owning task pointers.
  // NOTE(review): presumably the tasks without parents -- confirm.
  std::vector<AsyncTask*> root_tasks_;

  // Future owned by the graph.
  // NOTE(review): presumably the future handed out by GetFuture/ExecuteGraph
  // -- confirm.
  std::unique_ptr<AsyncTaskFuture> run_future_;
};
} // namespace caffe2
#endif // CAFFE2_NET_ASYNC_TASK_GRAPH_H