Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions python/tvm/contrib/pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, module):
self._get_input = self.module["get_input"]
self._get_output = self.module["get_output"]
self._get_num_outputs = self.module["get_num_outputs"]
self._get_num_inputs = self.module["get_num_inputs"]
self._get_input_pipeline_map = self.module["get_input_pipeline_map"]
self._get_pipe_execute_count = self.module["get_execute_count"]

Expand Down Expand Up @@ -159,6 +160,16 @@ def num_outputs(self):
"""
return self._get_num_outputs()

@property
def num_inputs(self):
"""Get the number of inputs
Returns
-------
count : int
The number of inputs
"""
return self._get_num_inputs()

@staticmethod
def load_library(config_file_name):
"""Import files to create a pipeline executor.
Expand Down
8 changes: 7 additions & 1 deletion src/runtime/pipeline/pipeline_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
if (name == "get_num_outputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
} else if (name == "get_num_inputs") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
} else if (name == "get_input_pipeline_map") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (String::CanConvertFrom(args[0])) {
Expand Down Expand Up @@ -87,7 +90,10 @@ PackedFunc PipelineExecutor::GetFunction(const std::string& name,
return PackedFunc();
}
}

/*!
* brief Returns number of global inputs.
*/
int PipelineExecutor::NumInputs(void) { return input_connection_config_.GetInputNum(); }
/*!
* \brief set input to the runtime module.
* \param input_name The input name.
Expand Down
1 change: 1 addition & 0 deletions src/runtime/pipeline/pipeline_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class TVM_DLL PipelineExecutor : public ModuleNode {
int NumOutputs() const { return num_outputs_; }
/*!\brief Run the pipeline executor.*/
void Run();
int NumInputs();
/*!
* \brief Get a list output data.
* \return A list of output data.
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/pipeline/pipeline_struct.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ struct InputConnectionConfig {
}
return input_connection[key];
}
/*!\brief Returns the number of global inputs through the input_runtime_map list size.*/
int GetInputNum() { return input_runtime_map.size(); }

/*!
* \brief Getting the global input index through the input name.
* \param input_name The global input name.
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_pipeline_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,8 @@ def test_pipeline():
if input_map[0] == "0":
input_data = pipeline_module_test.get_input("data_a")
tvm.testing.assert_allclose(data, input_data.numpy())

assert pipeline_module_test.num_inputs == 2
# Running the pipeline executor in the pipeline mode.
pipeline_module_test.run()

Expand Down