Skip to content

Commit ce29f02

Browse files
authored
[USMP] Adding support for U4 usecase (#10785)
* [USMP] Adding support for U4 usecase This commit adds support for placing I/O tensors within the workspace buffer. This is enabled using PassConfig option tir.usmp.use_workspace_io. Once it is enabled, it will remove the I/O tensors from the TIR main PrimFunc and replace them with Allocate nodes that are annotated to contain Input and Output tensors. The USMP will plan memory for them accordingly. (i.e. it will re-use space used by them for intermediaries depending on the liveness). This will only be supported with the C Interface API. Thus, this commit produces two functions to the metadata sources to obtain input and output structs that point to locations inside the workspace struct. Change-Id: I4c7e750ead9a880ba900602c17f53a125f97dbf9 * fixup! [USMP] Adding support for U4 usecase Change-Id: I78f03d36b12b4a5e8eae8d11701f51019489defc * fixup! [USMP] Adding support for U4 usecase Change-Id: I857f3d0ba7bc192d56d750c44b232998b2876e7a
1 parent 871d4ef commit ce29f02

File tree

19 files changed

+1086
-195
lines changed

19 files changed

+1086
-195
lines changed

include/tvm/tir/usmp/transform.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ TVM_DLL Pass ConvertPoolAllocationsToOffsets(const Map<tir::Stmt, PoolAllocation
5656
*/
5757
TVM_DLL Pass AssignPoolInfo();
5858

59+
/*!
60+
* \brief This pass creates Allocate nodes for I/O tensors
61+
*
62+
* If the user wants to place the I/O tensors in the workspace, this pass is required to be
63+
* run. In doing so, it will create Allocate nodes for the I/O tensors to be planned, and remove
64+
* them from the function arguments.
65+
*
66+
* \return the pass
67+
*/
68+
TVM_DLL Pass CreateAllocatesForIO();
69+
5970
} // namespace transform
6071
} // namespace usmp
6172
} // namespace tir

include/tvm/tir/usmp/utils.h

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,20 @@ constexpr const char* kUSMPEnableOption = "tir.usmp.enable";
4141
* \brief PassContext option to select the memory planning algorithm in USMP
4242
*/
4343
constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
44+
/*!
45+
* \brief PassContext option to enable placing I/O tensors in the workspace
46+
*/
47+
constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
4448

4549
namespace tir {
4650
namespace usmp {
4751

52+
/*!
53+
* \brief A special kind to distinguish between I/O tensors to the model
54+
* and intermediate tensors of the model
55+
*/
56+
enum class BufferInfoKind { kIntermediate = 0, kInput = 1, kOutput = 2 };
57+
4858
/*!
4959
* \brief Describes an abstract memory buffer that will get allocated inside a pool.
5060
* The actual memory buffer is represented by PoolAllocationNode after static memory planning.
@@ -65,19 +75,22 @@ struct BufferInfoNode : public Object {
6575
Integer alignment;
6676
/*! \brief The liveness conflicting other buffer info objects */
6777
Array<ObjectRef> conflicts;
78+
/*! \brief Whether BufferInfo object retains info about IO tensors or intermediaries */
79+
BufferInfoKind kind;
6880

6981
void VisitAttrs(tvm::AttrVisitor* v) {
7082
v->Visit("name_hint", &name_hint);
7183
v->Visit("size_bytes", &size_bytes);
7284
v->Visit("pool_candidates", &pool_candidates);
7385
v->Visit("alignment", &alignment);
7486
v->Visit("conflicts", &conflicts);
87+
v->Visit("kind", &kind);
7588
}
7689

7790
bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
7891
return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
7992
equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) &&
80-
equal(conflicts, other->conflicts);
93+
equal(conflicts, other->conflicts) && equal(kind, other->kind);
8194
}
8295

8396
void SHashReduce(SHashReducer hash_reduce) const {
@@ -86,6 +99,7 @@ struct BufferInfoNode : public Object {
8699
hash_reduce(alignment);
87100
hash_reduce(conflicts);
88101
hash_reduce(pool_candidates);
102+
hash_reduce(kind);
89103
}
90104
/*!
91105
* \brief Set the liveness conflicts of this BufferInfo
@@ -101,7 +115,8 @@ struct BufferInfoNode : public Object {
101115
class BufferInfo : public ObjectRef {
102116
public:
103117
TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
104-
Integer alignment = runtime::kDefaultWorkspaceAlignment);
118+
Integer alignment = runtime::kDefaultWorkspaceAlignment,
119+
BufferInfoKind kind = BufferInfoKind::kIntermediate);
105120
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
106121
};
107122

@@ -237,6 +252,18 @@ Integer CalculateModuleWorkspaceSize(const IRModule& mod);
237252
*/
238253
static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
239254

255+
/*!
256+
* \brief The allocate node attribute to indicate it is being used to hold
257+
* an input tensor that needs to be initialized with input data.
258+
*/
259+
static constexpr const char* kInputTensorAllocate = "input_tensor";
260+
261+
/*!
262+
* \brief The allocate node attribute to indicate it is being used to hold
263+
* an output tensor.
264+
*/
265+
static constexpr const char* kOutputTensorAllocate = "output_tensor";
266+
240267
/*!
241268
* \brief Calculate the size of the extents in bytes
242269
*
@@ -254,6 +281,16 @@ Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
254281
const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
255282
const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
256283

284+
/*!
285+
* \brief Obtains I/O tensor names to their PoolAllocation objects
286+
*
287+
* \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects
288+
*
289+
* This function will obtain pool allocations for I/O tensors if that had been planned
290+
*/
291+
Map<String, PoolAllocation> GetIOPoolAllocations(
292+
const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
293+
257294
} // namespace usmp
258295
} // namespace tir
259296

@@ -265,10 +302,10 @@ namespace attr {
265302
static constexpr const char* kPoolArgs = "pool_args";
266303

267304
/*!
268-
* \brief This is a IRModule attribute that contains all the PoolInfo objects
269-
* as an Array.
305+
* \brief This is a IRModule attribute that contains I/O Tensor names to pool
306+
* allocations.
270307
*/
271-
static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos";
308+
static constexpr const char* kIOTensorPoolAllocations = "io_tensor_pool_allocations";
272309

273310
} // namespace attr
274311

python/tvm/micro/model_library_format.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,15 @@ class UnsupportedInModelLibraryFormatError(Exception):
4747

4848

4949
def generate_c_interface_header(
50-
module_name, inputs, outputs, pools, devices, workspace_size, include_path
50+
module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size, include_path
5151
):
5252
"""Generate C Interface header to be included in MLF"""
5353
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
5454
metadata_header = os.path.join(include_path, f"{mangled_name}.h")
5555

5656
interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
5757
interface_c_module = interface_c_create(
58-
module_name, inputs, outputs, pools, devices, workspace_size
58+
module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size
5959
)
6060

6161
with open(metadata_header, "w") as header_file:
@@ -281,24 +281,19 @@ def _convert_tuple_to_outputs(ret_type, offset=0):
281281

282282

283283
def _get_inputs_and_outputs_from_module(mod):
284-
main_func = _get_main_relay_func(mod)
285-
inputs = [argument.name_hint for argument in main_func.params]
286-
287-
if "output_tensor_names" in main_func.attrs:
288-
outputs = main_func.attrs["output_tensor_names"]
289-
else:
290-
if isinstance(main_func.ret_type, TupleType):
291-
outputs = _convert_tuple_to_outputs(main_func.ret_type)
292-
else:
293-
outputs = ["output"]
294-
284+
inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs]
285+
outputs = list(mod.executor_codegen_metadata.outputs)
295286
return inputs, outputs
296287

297288

298289
def _get_pools_from_module(mod):
299290
return list(dict(mod.executor_codegen_metadata.pool_inputs).values())
300291

301292

293+
def _get_io_pool_allocation_from_module(mod):
294+
return dict(mod.executor_codegen_metadata.io_pool_allocations)
295+
296+
302297
def _should_generate_interface_header(mod):
303298
return "interface-api" in mod.executor and mod.executor["interface-api"] == "c"
304299

@@ -369,9 +364,17 @@ def _export_graph_model_library_format(
369364
inputs, outputs = _get_inputs_and_outputs_from_module(mod)
370365
devices = mod.get_devices()
371366
pools = _get_pools_from_module(mod)
367+
io_pool_allocations = _get_io_pool_allocation_from_module(mod)
372368
workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"])
373369
generate_c_interface_header(
374-
mod.libmod_name, inputs, outputs, pools, devices, workspace_size, include_path
370+
mod.libmod_name,
371+
inputs,
372+
outputs,
373+
pools,
374+
io_pool_allocations,
375+
devices,
376+
workspace_size,
377+
include_path,
375378
)
376379

377380
parameters_dir = tempdir / "parameters"

src/relay/backend/aot_executor_codegen.cc

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
784784
* brief Create tir::Var for input/output while updating
785785
* the buffer_maps.
786786
*/
787-
void CreateIOVar(const Expr& expr, std::string name) {
787+
void CreateIOVar(const Expr& expr, const std::string& original_name,
788+
bool use_unique_name = true) {
788789
if (expr->IsInstance<TupleNode>()) {
789790
Tuple tuple = Downcast<Tuple>(expr);
790791
for (unsigned i = 0; i < tuple->fields.size(); i++) {
791-
CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
792+
CreateIOVar(tuple->fields[i], original_name);
792793
}
793794
} else {
795+
std::string name = original_name;
796+
if (use_unique_name) {
797+
name = GetUniqueIOVarName(original_name);
798+
}
794799
tir::Var var = tir::Var(name, DataType::Handle());
795800
main_signature_.push_back(var);
796801
auto tensor_type = expr->checked_type().as<TensorTypeNode>();
@@ -804,6 +809,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
804809
}
805810
}
806811

812+
/*!
813+
* brief Create a unique name for I/O Var
814+
*/
815+
std::string GetUniqueIOVarName(std::string name) {
816+
if (io_var_names_.find(name) == io_var_names_.end()) {
817+
io_var_names_[name] = 1;
818+
return name;
819+
} else {
820+
io_var_names_[name] = io_var_names_[name] + 1;
821+
return name + std::to_string(io_var_names_[name]);
822+
}
823+
}
824+
807825
/*!
808826
* brief Calculate workspace sizes for PrimFuncs in the IRModule
809827
*/
@@ -945,6 +963,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
945963
std::vector<tir::Stmt> stmts_;
946964
/*! \brief the list of return sids (note that the function might return more then one output */
947965
std::vector<int> return_sid_;
966+
/*! \brief This is per IO var name counter to aid the generating unique names */
967+
std::unordered_map<std::string, int> io_var_names_;
948968

949969
public:
950970
AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
@@ -1032,7 +1052,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10321052
for (auto input : lowered_main_func->params) {
10331053
input_vars_.push_back(input);
10341054
std::string input_name = SanitizeName(input->name_hint());
1035-
CreateIOVar(input, input_name);
1055+
// We don't want the compiler changing input names in the
1056+
// event of a sanitization collision. Therefore, enforcing
1057+
// the var created to use the input_name strictly.
1058+
CreateIOVar(input, input_name, /*use_unique_name = */ false);
10361059
}
10371060

10381061
// Define the storage allocator ids
@@ -1052,7 +1075,27 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10521075
// Retrieve the return sids
10531076
return_sid_ = final_aot_allocator.GetReturnIds();
10541077
// Insert outputs to main func signature
1055-
CreateIOVar(lowered_main_func->body, "output");
1078+
// If output tensor names were provided use them
1079+
if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
1080+
Array<String> output_tensor_names = opt.value();
1081+
if (lowered_main_func->body->IsInstance<TupleNode>()) {
1082+
Tuple output_tuple = Downcast<Tuple>(lowered_main_func->body);
1083+
for (unsigned i = 0; i < output_tuple->fields.size(); i++) {
1084+
// AoT Executor Codegen does not create these names,
1085+
// thus should be used as they are provided.
1086+
CreateIOVar(output_tuple->fields[i], output_tensor_names[i],
1087+
/*use_unique_name = */ false);
1088+
}
1089+
} else {
1090+
// AoT Executor Codegen does not create these names,
1091+
// thus should be used as they are provided.
1092+
CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
1093+
}
1094+
} else {
1095+
// If output tensor names are not provided we will generate output(x)
1096+
// where x is a counter to create unique names.
1097+
CreateIOVar(lowered_main_func->body, "output");
1098+
}
10561099

10571100
CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
10581101
VisitExpr(lowered_main_func->body);
@@ -1071,8 +1114,27 @@ class AOTExecutorCodegen : public MixedModeVisitor {
10711114
// AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
10721115
// function and replacing it with its TIR version. We should try to make this a Pass.
10731116
lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
1074-
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
1075-
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), prim_func);
1117+
auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
1118+
// Extract additional information around main TIR PrimFunc arguments
1119+
Array<String> devices = ListDevices();
1120+
const auto main_func_params_end_iterator =
1121+
tir_main_func->params.begin() + tir_main_func->params.size();
1122+
const auto outputs_begin_iterator =
1123+
main_func_params_end_iterator - return_sid_.size() - devices.size();
1124+
Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
1125+
Array<TensorType> input_tensor_types;
1126+
for (auto i : inputs) {
1127+
input_tensor_types.push_back(io_tensor_types_[i]);
1128+
}
1129+
Array<tir::Var> outputs =
1130+
Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());
1131+
std::vector<String> output_var_names;
1132+
for (const tir::Var& output : outputs) {
1133+
output_var_names.push_back(output->name_hint);
1134+
}
1135+
1136+
Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};
1137+
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func);
10761138
// Parallel for loops are not supported in AoT codegen.
10771139
lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);
10781140

@@ -1109,9 +1171,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11091171

11101172
ret.external_mods = external_modules.value();
11111173

1174+
// Extract USMP metadata to pass onto metadata sources
11121175
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
11131176
std::vector<tir::Var> pool_vars;
1114-
tir::PrimFunc tir_main_func =
1177+
tir_main_func =
11151178
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
11161179
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
11171180
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
@@ -1122,41 +1185,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11221185
pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
11231186
}
11241187
}
1125-
Array<String> devices = ListDevices();
1126-
Array<tir::Var> inputs =
1127-
Array<tir::Var>(tir_main_func->params.begin(),
1128-
tir_main_func->params.begin() + tir_main_func->params.size() -
1129-
return_sid_.size() - pool_vars.size() - devices.size());
1188+
Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
1189+
lowered_mod
1190+
->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations)
1191+
.value_or({});
11301192

1131-
Array<TensorType> input_tensor_types;
1132-
for (auto i : inputs) {
1133-
input_tensor_types.push_back(io_tensor_types_[i]);
1134-
}
1135-
1136-
std::vector<String> output_var_names;
1137-
if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
1138-
Array<String> output_tensor_names = opt.value();
1139-
for (size_t i = 0; i < output_tensor_names.size(); ++i) {
1140-
output_var_names.push_back(output_tensor_names[i]);
1141-
}
1142-
}
1143-
1144-
// If output names have not been specified then generate default output names
1145-
if (output_var_names.size() == 0) {
1146-
if (return_sid_.size() == 1) {
1147-
output_var_names.push_back(String("output"));
1148-
} else {
1149-
for (size_t i = 0; i < return_sid_.size(); ++i) {
1150-
output_var_names.push_back(String("output" + std::to_string(i)));
1151-
}
1152-
}
1153-
}
1154-
1155-
Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};
1193+
ret.metadata =
1194+
ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types,
1195+
pool_vars, devices, runtime::kTvmExecutorAot, mod_name,
1196+
interface_api, unpacked_api, pool_var_info, io_pool_allocations);
11561197

1157-
ret.metadata = ExecutorCodegenMetadata(
1158-
inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices,
1159-
runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, pool_var_info);
11601198
return ret;
11611199
}
11621200

src/relay/backend/utils.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,8 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata(
185185
Array<tir::Var> inputs, Array<TensorType> input_tensor_types, Array<String> outputs,
186186
Array<TensorType> output_tensor_types, Array<tir::Var> pools, Array<String> devices,
187187
String executor, String mod_name, String interface_api, bool unpacked_api,
188-
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs) {
188+
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs,
189+
Map<String, tir::usmp::PoolAllocation> io_pool_allocations) {
189190
auto n = make_object<ExecutorCodegenMetadataNode>();
190191
n->inputs = inputs;
191192
n->input_tensor_types = input_tensor_types;
@@ -198,6 +199,7 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata(
198199
n->unpacked_api = unpacked_api;
199200
n->mod_name = mod_name;
200201
n->pool_inputs = pool_inputs;
202+
n->io_pool_allocations = io_pool_allocations;
201203
data_ = std::move(n);
202204
}
203205

0 commit comments

Comments
 (0)