Skip to content

Commit a374cdd

Browse files
authored
[Runtime][Pipeline Executor] Add the map logic of global input and subgraph input. (#9751)
* [Runtime][Pipeline Executor] Add the map logic of global input and subgraph input. Users can use a "global input name" to feed input data to the pipeline runtime. A name like "data_a" will be mapped to an input interface of a subgraph. In this PR, we create the related logic to do the following things. 1. building the input map configuration 2. in the runtime C++ module, parsing the input connection configuration and then creating the related data structure to record the said connection map. 3. providing the function to return the map information for verification. * address review comments. * address review comments. * address review comments.
1 parent bd61d18 commit a374cdd

File tree

7 files changed

+306
-107
lines changed

7 files changed

+306
-107
lines changed

python/tvm/contrib/pipeline_executor.py

Lines changed: 114 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,26 @@ def build(pipe_configs):
4949
Common interface for pipeline executor factory modules.
5050
"""
5151
libs = {}
52-
mod_n_configs = pipe_configs.get_config()
52+
config = pipe_configs.get_config()
53+
if "module_connection" not in config:
54+
raise RuntimeError('"module_connection" is missing')
55+
if "input_connection" not in config:
56+
raise RuntimeError('"input_connection" is missing')
57+
58+
mod_n_configs = config["module_connection"]
5359
config_len = len(mod_n_configs)
54-
string_config = [{} for _ in range(config_len)]
60+
module_string_config = [{} for _ in range(config_len)]
61+
# Use hardware configurations to build backend modules for each subgraph.
5562
for ir_mod, mod_config in mod_n_configs.items():
56-
mconf = mod_config["pipeline"].copy()
57-
mod_idx = mconf["mod_idx"]
63+
pipe_config = mod_config["pipeline"].copy()
64+
mod_idx = pipe_config["mod_idx"]
5865
dev = mod_config["dev"]
5966
target = mod_config["target"]
6067
build_func = relay.build
61-
# Check whether there is a customized build function.
68+
# Callers may need to use a customized building function to wrap the pre-building logic
69+
# and the backend building logic. For example, in order to support a backend which only
70+
# can do "int8" computation, the caller may need to merge the "quantization" logic
71+
# into the building logic to create a customized building function.
6272
if "build" in mod_config and mod_config["build"]:
6373
build_func = mod_config["build"]
6474

@@ -70,11 +80,20 @@ def build(pipe_configs):
7080
mod_name=mod_config["mod_name"],
7181
)
7282

73-
mconf["dev"] = "{},{}".format(dev.device_type, dev.device_id)
74-
# Create a pipeline configuration.
75-
string_config[mod_idx] = mconf
83+
pipe_config["dev"] = "{},{}".format(dev.device_type, dev.device_id)
84+
# Use "mod_idx" as the key to create a "module_connection" map which is not only
85+
# for the module index but also for the module connection used to build the pipeline.
86+
module_string_config[mod_idx] = pipe_config
7687
libs[mod_idx] = {"lib": lib, "dev": dev}
7788

89+
# Creating a text form configuration to record the "input_connection" and the
90+
# "module_connection" information. The "input_connection" is used to record the
91+
# map of global input and subgraph input, and the "module_connection" is used to
92+
# record module dependency.
93+
string_config = {}
94+
string_config["input_connection"] = config["input_connection"]
95+
string_config["module_connection"] = module_string_config
96+
7897
return PipelineExecutorFactoryModule(libs, string_config)
7998

8099

@@ -94,6 +113,17 @@ def __init__(self, module):
94113
self.module = module
95114
# Get the packed functions from the pipeline executor.
96115
self._get_num_outputs = self.module["get_num_outputs"]
116+
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
117+
118+
def get_input_pipeline_map(self, name):
119+
"""Using the "name" to get the corresponding subgraph index and also get the "input name"
120+
of the corresponding subgraph interface.
121+
Returns
122+
-------
123+
input map: Array[str]
124+
Returning the index and "input name" of the subgraph.
125+
"""
126+
return self._get_input_pipeline_map(name)
97127

98128
@property
99129
def num_outputs(self):
@@ -199,12 +229,48 @@ def is_pipeline_executor_interface(self):
199229
return not isinstance(self.io_owner, PipelineConfig.ModuleWrapper)
200230

201231
def __repr__(self):
202-
# Get all binding information.
203-
ret = " |{}: ".format(self.name)
232+
# Getting the binding information in the form of text.
233+
str_format = " |{}: ".format(self.name)
204234
for binding in self.bindings:
205235
mname, dname = binding.get_name()
206-
ret += "{0}:{1} ".format(mname, dname)
207-
return ret
236+
str_format += "{0}:{1} ".format(mname, dname)
237+
238+
return str_format
239+
240+
def check_binding_dict(self, connection_dict):
241+
"""Checking the binding dictionary.
242+
Parameter
243+
---------
244+
connection_dict : Dict[str, Any]
245+
It is a dictionary of module connections.
246+
"""
247+
if "interface_name" not in connection_dict:
248+
raise RuntimeError('"inteface_name" is missing in global config!"')
249+
if "connection" not in connection_dict:
250+
raise RuntimeError(f'"connection" is missing!"')
251+
# The global interface mapping should be one-to-one.
252+
if not connection_dict["connection"]:
253+
raise RuntimeError("The global interface map is empty!")
254+
if len(connection_dict["connection"]) > 1:
255+
raise RuntimeError("A global interface maps multiple module interfaces!")
256+
if "mod_idx" not in connection_dict["connection"][0]:
257+
raise RuntimeError('"mod_idx" is missing!')
258+
259+
def get_binding_dict(self):
260+
"""Returning the binding information in the form of dictionary.
261+
Returns
262+
-------
263+
data : Dict[str, Any]
264+
The binding information is in the form of dictionary.
265+
"""
266+
dict_format = {"interface_name": self.name, "connection": []}
267+
for binding in self.bindings:
268+
_, dname = binding.get_name()
269+
midx = binding.get_owner_idx()
270+
dict_format["connection"].append({"mod_idx": midx, "interface_name": dname})
271+
272+
self.check_binding_dict(dict_format)
273+
return dict_format
208274

209275
def check_dag_acyclic(self, start, inputs):
210276
"""This is to check whether the DAG containing these input interfaces is acyclic.
@@ -243,30 +309,34 @@ def connect(self, binding):
243309

244310
# Check whether the binding setting is correct or not.
245311
if self.io_owner == binding.io_owner:
246-
raise RuntimeError(f"Can not bind itself.")
312+
raise RuntimeError("Can not bind itself.")
247313

248314
if not self.is_pipeline_executor_interface() and self.io_type == "input":
249-
raise RuntimeError(f"Module can only bind from output interface!")
315+
raise RuntimeError("Module can only bind from output interface!")
250316

251317
if (
252318
not self.is_pipeline_executor_interface()
253319
and not binding.is_pipeline_executor_interface()
254320
and binding.io_type == "output"
255321
):
256-
raise RuntimeError(f"Can not bind module output with another module output!")
322+
raise RuntimeError("Can not bind module output with another module output!")
257323

258324
if (
259325
not self.is_pipeline_executor_interface()
260326
and binding.is_pipeline_executor_interface()
261327
and binding.io_type == "input"
262328
):
263-
raise RuntimeError(f"Can not bind module output with pipeline input!")
329+
raise RuntimeError("Can not bind module output with pipeline input!")
264330

265331
if self.is_pipeline_executor_interface() and self.io_type == "output":
266-
raise RuntimeError(f"Global output can not be used as binding start point.")
332+
raise RuntimeError("Global output can not be used as binding start point.")
267333

268-
if self.is_pipeline_executor_interface() and binding.io_type != "input":
269-
raise RuntimeError(f"Global input can only bind with module input.")
334+
if (
335+
self.is_pipeline_executor_interface()
336+
and self.io_type == "input"
337+
and binding.io_type != "input"
338+
):
339+
raise RuntimeError("Global input can only bind with module input.")
270340

271341
self.bindings.append(binding)
272342
if not self.is_pipeline_executor_interface():
@@ -288,7 +358,7 @@ def connect(self, binding):
288358
if not self.check_dag_acyclic(
289359
binding.io_owner, self.io_owner.input_bindings.bindings
290360
):
291-
raise RuntimeError(f"Illegal connection: Cause a cycle!")
361+
raise RuntimeError("Illegal connection: Cause a cycle!")
292362

293363
class BindingList:
294364
"""Container for bindings(input or output interface).
@@ -357,7 +427,9 @@ def __getitem__(self, key):
357427
if key == "output":
358428
return self.output_bindings
359429

360-
raise RuntimeError(f"{key} not found!")
430+
raise RuntimeError(f"{key} not found!")
431+
432+
raise RuntimeError('The data type of "key" is not supported!')
361433

362434
def get_data_type(self, key, interface_type):
363435
"""Get the module interface data type according to the key value and interface type.
@@ -468,6 +540,8 @@ def get_config(self):
468540
# Use topological sort to get the correct order of modules.
469541
self.dag_topology_sort()
470542
mconfig = {}
543+
module_connection = {}
544+
input_connection = {}
471545
for mod in self.mod_wrapper:
472546
# Generate pipeline configuration.
473547
mconf = {}
@@ -495,7 +569,7 @@ def get_config(self):
495569
mconf["mod_idx"] = module.idx
496570
mconf["output"] = output_conf
497571

498-
mconfig[mod] = {
572+
module_connection[mod] = {
499573
"pipeline": mconf,
500574
"target_host": module.target_host,
501575
"mod_name": "default",
@@ -505,6 +579,22 @@ def get_config(self):
505579
"dev": module.dev,
506580
}
507581

582+
# Create a map of pipeline input and subgraph input.
583+
input_connection = []
584+
for input_name in self.input_bindings.bindings:
585+
input_dict = self.input_bindings.bindings[input_name].get_binding_dict()
586+
if "interface_name" not in input_dict["connection"][0]:
587+
raise RuntimeError("interface_name is missing in connection config!")
588+
# Creating the map of global interface and subgraph interface.
589+
input_map = {
590+
"global_interface_name": input_dict["interface_name"],
591+
"mod_idx": input_dict["connection"][0]["mod_idx"],
592+
"module_interface_name": input_dict["connection"][0]["interface_name"],
593+
}
594+
input_connection.append(input_map)
595+
596+
mconfig["module_connection"] = module_connection
597+
mconfig["input_connection"] = input_connection
508598
return mconfig
509599

510600
def dag_topology_sort(self):
@@ -601,11 +691,11 @@ def export_library(self, directory_path):
601691
Export the files to this directory.
602692
"""
603693
if not self.pipeline_mods:
604-
raise RuntimeError(f"The pipeline executor has not been initialized.")
694+
raise RuntimeError("The pipeline executor has not been initialized.")
605695

606696
# Check if the directory_path exists.
607697
if not os.path.exists(directory_path):
608-
raise RuntimeError(f"The directory {directory_path} does not exist.")
698+
raise RuntimeError("The directory {directory_path} does not exist.")
609699
# Create a load configuration.
610700
load_config_file_name = "{}/load_config".format(directory_path)
611701
pipeline_config_file_name = "{}/pipeline_config".format(directory_path)

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,32 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
3434
if (name == "get_num_outputs") {
3535
return PackedFunc(
3636
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
37+
} else if (name == "get_input_pipeline_map") {
38+
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
39+
if (String::CanConvertFrom(args[0])) {
40+
*rv = this->GetInputPipeplineMapping(args[0].operator String());
41+
} else {
42+
LOG(FATAL) << "Function only support the input name value in the form of string";
43+
}
44+
});
3745
} else {
3846
LOG(FATAL) << "Unknown packed function: " << name;
3947
return PackedFunc();
4048
}
4149
return nullptr;
4250
}
4351

52+
/*!
53+
* \brief Using the global input name to get the index, and also get the input interface name
54+
of corresponding subgraph from the input connection configuration.
55+
* \param The global input name.
56+
* \return Returning the index and the input interface name of corresponding subgraph.
57+
*/
58+
Array<String> PipelineExecutor::GetInputPipeplineMapping(std::string input_name) {
59+
std::pair<int, std::string> map = input_connection_config[input_name];
60+
return {std::to_string(map.first), map.second};
61+
}
62+
4463
/*!
4564
* \brief Use the mod_config information to create a graph runtime list.
4665
* \param mod_config The config information that generates by the export library function call.
@@ -108,11 +127,11 @@ void PipelineExecutor::Init(const std::vector<Module>& modules, const std::strin
108127
// Use JSONReader to load pipeline configuration.
109128
std::istringstream is(pipeline_json);
110129
dmlc::JSONReader reader(&is);
111-
PipelineConfig& pipeline_config = this->LoadPipelineConfig(&reader);
112-
ICHECK(!pipeline_config.Empty()) << "The pipeline config information is empty.";
130+
this->LoadConfig(&reader);
131+
ICHECK(!pipeline_config_.Empty()) << "The pipeline config information is empty.";
113132
// Initialize the pipeline function class used for pipeline thread pool management
114133
// and schedule etc. This function returns the number of output.
115-
num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config);
134+
num_outputs_ = pipeline_scheduler_.PipelineInit(modules, pipeline_config_);
116135
return;
117136
}
118137

src/runtime/pipeline/pipeline_executor.h

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424
#ifndef TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_
2525
#define TVM_RUNTIME_PIPELINE_PIPELINE_EXECUTOR_H_
2626

27+
#include <tvm/relay/expr.h>
2728
#include <tvm/runtime/registry.h>
2829

2930
#include <array>
3031
#include <iostream>
3132
#include <sstream>
3233
#include <string>
34+
#include <utility>
3335
#include <vector>
3436

3537
#include "pipeline_scheduler.h"
@@ -67,7 +69,13 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
6769
* \return The corresponding packed function.
6870
*/
6971
virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
70-
72+
/*!
73+
* \brief Using the global input name to get the index, and also get the input interface name
74+
of corresponding subgraph from the input connection configuration.
75+
* \param The global input name.
76+
* \return Returning the index and the input interface name of corresponding subgraph.
77+
*/
78+
Array<String> GetInputPipeplineMapping(std::string input_name);
7179
/*!
7280
* \brief Get the number of outputs.
7381
*
@@ -115,37 +123,27 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
115123
/*!\brief The class used to execute and schedule the pipeline logic.*/
116124
PipelineScheduler pipeline_scheduler_;
117125
/*!\brief The dependency information of each graph runtime module of the pipeline.*/
118-
PipelineConfig pipeline_config_;
126+
ConfigPipelineExecution pipeline_config_;
127+
/*!\brief The map of global input and subgraph input.*/
128+
InputConnectionConfig input_connection_config;
119129
/*!\brief The module information used to create the graph runtimes.*/
120130
ModuleConfig mod_config_;
121131
/*!\brief How many outputs are in this pipeline executor.*/
122132
size_t num_outputs_ = 0;
123133
/*!\brief Json loader.*/
124-
PipelineConfig& LoadPipelineConfig(dmlc::JSONReader* reader) {
125-
reader->BeginArray();
126-
while (reader->NextArrayItem()) {
127-
std::string key;
128-
reader->BeginObject();
129-
int mod_idx = -1;
130-
OutputMap output;
131-
std::string dev;
132-
while (reader->NextObjectItem(&key)) {
133-
if (key == "mod_idx") {
134-
reader->Read(&mod_idx);
135-
} else if (key == "dev") {
136-
reader->Read(&dev);
137-
} else if (key == "output") {
138-
reader->Read(&output);
139-
} else {
140-
LOG(FATAL) << "do not support key " << key;
141-
}
134+
void LoadConfig(dmlc::JSONReader* reader) {
135+
reader->BeginObject();
136+
std::string key;
137+
while (reader->NextObjectItem(&key)) {
138+
if (key == "module_connection") {
139+
reader->Read(&pipeline_config_);
140+
} else if (key == "input_connection") {
141+
reader->Read(&input_connection_config);
142+
} else {
143+
LOG(FATAL) << "do not support key " << key;
142144
}
143-
ICHECK(mod_idx >= 0) << "Invalid mod_idx value " << mod_idx;
144-
// Check if the output is successfully read.
145-
ICHECK(!output.Empty()) << "Invalid output binding result.";
146-
pipeline_config_.Insert(mod_idx, output);
147145
}
148-
return pipeline_config_;
146+
return;
149147
}
150148
};
151149
} // namespace runtime

src/runtime/pipeline/pipeline_scheduler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace runtime {
2828
* \param pipeline_conf The dependency information of each graph executor module.
2929
*/
3030
size_t PipelineScheduler::PipelineInit(const std::vector<Module>& modules,
31-
const PipelineConfig& pipeline_config) {
31+
const ConfigPipelineExecution& pipeline_config) {
3232
graph_modules_ = modules;
3333
int num_output = pipeline_config.GetGlobalOutputNum();
3434
return num_output;

src/runtime/pipeline/pipeline_scheduler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ class PipelineScheduler {
4141
* \param modules The list of graph executor module.
4242
* \param pipeline_config The dependency information of each graph executor module.
4343
*/
44-
size_t PipelineInit(const std::vector<Module>& modules, const PipelineConfig& pipeline_config);
44+
size_t PipelineInit(const std::vector<Module>& modules,
45+
const ConfigPipelineExecution& pipeline_config);
4546

4647
private:
4748
/*!\brief The list of graph executors.*/

0 commit comments

Comments
 (0)