Skip to content

Commit 8a94b66

Browse files
[Runtime][PipelineExecutor] Added Interface to Track Number of Global Inputs (#11315)
* [Runtime][PipleineExecutor] Added Interface to Track Number of Global Inputs Added a feature to PipelineExecutor to track number of Global Inputs. * Fixed CI Error * Fixed remaining CI Error
1 parent 648154d commit 8a94b66

File tree

5 files changed

+24
-1
lines changed

5 files changed

+24
-1
lines changed

python/tvm/contrib/pipeline_executor.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self, module):
5555
self._get_input = self.module["get_input"]
5656
self._get_output = self.module["get_output"]
5757
self._get_num_outputs = self.module["get_num_outputs"]
58+
self._get_num_inputs = self.module["get_num_inputs"]
5859
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
5960
self._get_pipe_execute_count = self.module["get_execute_count"]
6061

@@ -159,6 +160,16 @@ def num_outputs(self):
159160
"""
160161
return self._get_num_outputs()
161162

163+
@property
164+
def num_inputs(self):
165+
"""Get the number of inputs
166+
Returns
167+
-------
168+
count : int
169+
The number of inputs
170+
"""
171+
return self._get_num_inputs()
172+
162173
@staticmethod
163174
def load_library(config_file_name):
164175
"""Import files to create a pipeline executor.

src/runtime/pipeline/pipeline_executor.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
3434
if (name == "get_num_outputs") {
3535
return PackedFunc(
3636
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
37+
} else if (name == "get_num_inputs") {
38+
return PackedFunc(
39+
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
3740
} else if (name == "get_input_pipeline_map") {
3841
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
3942
if (String::CanConvertFrom(args[0])) {
@@ -87,7 +90,10 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
8790
return PackedFunc();
8891
}
8992
}
90-
93+
/*!
94+
* brief Returns number of global inputs.
95+
*/
96+
int PipelineExecutor::NumInputs(void) { return input_connection_config_.GetInputNum(); }
9197
/*!
9298
* \brief set input to the runtime module.
9399
* \param input_name The input name.

src/runtime/pipeline/pipeline_executor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
115115
int NumOutputs() const { return num_outputs_; }
116116
/*!\brief Run the pipeline executor.*/
117117
void Run();
118+
int NumInputs();
118119
/*!
119120
* \brief Get a list output data.
120121
* \return A list of output data.

src/runtime/pipeline/pipeline_struct.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,9 @@ struct InputConnectionConfig {
560560
}
561561
return input_connection[key];
562562
}
563+
/*!\brief Returns the number of global inputs through the input_runtime_map list size.*/
564+
int GetInputNum() { return input_runtime_map.size(); }
565+
563566
/*!
564567
* \brief Getting the global input index through the input name.
565568
* \param input_name The global input name.

tests/python/relay/test_pipeline_executor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,8 @@ def test_pipeline():
595595
if input_map[0] == "0":
596596
input_data = pipeline_module_test.get_input("data_a")
597597
tvm.testing.assert_allclose(data, input_data.numpy())
598+
599+
assert pipeline_module_test.num_inputs == 2
598600
# Running the pipeline executor in the pipeline mode.
599601
pipeline_module_test.run()
600602

0 commit comments

Comments
 (0)