Replace the map-returning KernelDef::InputMemoryType()/OutputMemoryType() accessors
with per-index lookups that fall back to a per-kernel default memory type.

* KernelDef gains default_inputs_mem_type_/default_outputs_mem_type_ (both initialised
  to OrtMemTypeDefault in the constructor) and KernelDefBuilder gains setters for them.
* MemTypeOnCpuExplicitly() now takes the OrtMemType value directly.
* Callers in allocation_planner.cc, transformer_memcpy.cc and IOBinding.cc are updated;
  transformer_memcpy.cc now consults OutputMemoryType() for outputs (the previous code
  read InputMemoryType() for both inputs and outputs).
* KernelDef::IsConflict() compares the raw maps; both the input and the output
  comparison look entries up by argument index (it.first).
* The legacy KernelDef::MemoryType() accessor is removed.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Indentation and
the position of blank context lines were reconstructed from the hunk line counts; confirm
against the target files before applying.

diff --git a/include/onnxruntime/core/framework/kernel_def_builder.h b/include/onnxruntime/core/framework/kernel_def_builder.h
index dd30683266983..0c5c46eb19af7 100644
--- a/include/onnxruntime/core/framework/kernel_def_builder.h
+++ b/include/onnxruntime/core/framework/kernel_def_builder.h
@@ -20,13 +20,15 @@ class KernelDefBuilder;
 typedef std::map<size_t, OrtMemType> MemTypeMap;
 
 // note that input/output might be on CPU implicitly when the node is from CPU execution provider
-inline bool MemTypeOnCpuExplicitly(const MemTypeMap& mem_type_map, size_t index) {
-  auto iter = mem_type_map.find(index);
-  return iter != mem_type_map.end() && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
+inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
+  return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
 }
 
 class KernelDef {
  public:
+  explicit KernelDef() : default_inputs_mem_type_(OrtMemTypeDefault), default_outputs_mem_type_(OrtMemTypeDefault) {
+  }
+
   const std::string& OpName() const {
     return op_name_;
   }
@@ -56,17 +58,20 @@ class KernelDef {
     return alias_map_;
   }
 
-  const MemTypeMap& InputMemoryType() const {
-    return input_memory_type_args_;
-  }
-
-  const MemTypeMap& OutputMemoryType() const {
-    return output_memory_type_args_;
+  OrtMemType InputMemoryType(size_t input_index) const {
+    auto it = input_memory_type_args_.find(input_index);
+    if (it == input_memory_type_args_.end())
+      return default_inputs_mem_type_;
+    else
+      return it->second;
   }
 
-  // legacy interface for winml, should not be used in onnxruntime
-  const MemTypeMap& MemoryType() const {
-    return output_memory_type_args_;
+  OrtMemType OutputMemoryType(size_t output_index) const {
+    auto it = output_memory_type_args_.find(output_index);
+    if (it == output_memory_type_args_.end())
+      return default_outputs_mem_type_;
+    else
+      return it->second;
   }
 
   int ExecQueueId() const {
@@ -111,6 +116,10 @@ class KernelDef {
 
   // execution command queue id, 0 for default queue in execution provider
   int exec_queue_id_ = 0;
+  // Default memory type for all inputs
+  OrtMemType default_inputs_mem_type_;
+  // Default memory type for all outputs
+  OrtMemType default_outputs_mem_type_;
 };
 
 class KernelDefBuilder {
@@ -212,6 +221,22 @@ class KernelDefBuilder {
     return *this;
   }
 
+  /**
+     Specify the default memory type for all inputs; if not specified, it is OrtMemTypeDefault
+  */
+  KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
+    kernel_def_->default_inputs_mem_type_ = mem_type;
+    return *this;
+  }
+
+  /**
+     Specify the default memory type for all outputs; if not specified, it is OrtMemTypeDefault
+  */
+  KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
+    kernel_def_->default_outputs_mem_type_ = mem_type;
+    return *this;
+  }
+
   /**
      Return the kernel definition, passing ownership of the KernelDef to the caller
   */
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 277a6ff2ac1e2..9bc25e9303fda 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -380,7 +380,6 @@ class PlannerImpl {
       ORT_ENFORCE(exec_provider);
 
       auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
-      auto& mem_type_allocated_args = p_kernelDef->OutputMemoryType();
       auto& outputs = pnode->OutputDefs();
       auto num_outputs = outputs.size();
 
@@ -393,11 +392,11 @@ class PlannerImpl {
         if (strcmp(default_allocator_info.name, CPU) != 0) {
           // By default, outputs of this node are allocated on the default device allocator,
           // except for outputs marked for allocation in MemoryType:
-          auto memory_type_iter = mem_type_allocated_args.find(i);
-          if (memory_type_iter == mem_type_allocated_args.end()) {
+          auto memory_type = p_kernelDef->OutputMemoryType(i);
+          if (memory_type == OrtMemTypeDefault) {
             AllocPlan(index).location = default_allocator_info;
           } else {
-            AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type_iter->second)->Info();
+            AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type)->Info();
           }
         }
       }
@@ -438,7 +437,7 @@ class PlannerImpl {
 
           thisplan.alloc_kind = AllocKind::kAllocateStatically;
           auto p_opkernelDef = utils::GetKernelDef(kernel_registry_, node);
-          if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(), index))
+          if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(index)))
             // weights are not output from any node, so it's OK to put its location on CPU provider
             thisplan.location = execution_providers_.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault)->Info();
           else
diff --git a/onnxruntime/core/framework/kernel_def_builder.cc b/onnxruntime/core/framework/kernel_def_builder.cc
index 8555adc7563de..12d0fd936897c 100644
--- a/onnxruntime/core/framework/kernel_def_builder.cc
+++ b/onnxruntime/core/framework/kernel_def_builder.cc
@@ -66,20 +66,20 @@ bool KernelDef::IsConflict(const KernelDef& other) const {
     return false;
 
   //check memory type
-  auto other_input_mem_types = other.InputMemoryType();
+  auto& other_input_mem_types = other.input_memory_type_args_;
   for (auto it : input_memory_type_args_) {
-    if (other_input_mem_types.count(it.first) && other_input_mem_types[it.first] == it.second)
+    if (other_input_mem_types.count(it.first) && other_input_mem_types.find(it.first)->second == it.second)
       return false;
   }
-  if (input_memory_type_args_.empty() && !other.InputMemoryType().empty())
+  if (input_memory_type_args_.empty() && !other.input_memory_type_args_.empty())
     return false;
 
-  auto other_output_mem_types = other.OutputMemoryType();
+  auto& other_output_mem_types = other.output_memory_type_args_;
   for (auto it : output_memory_type_args_) {
-    if (other_output_mem_types.count(it.first) && other_output_mem_types[it.first] == it.second)
+    if (other_output_mem_types.count(it.first) && other_output_mem_types.find(it.first)->second == it.second)
       return false;
   }
-  return !(output_memory_type_args_.empty() && !other.OutputMemoryType().empty());
+  return !(output_memory_type_args_.empty() && !other.output_memory_type_args_.empty());
 }
 
 KernelDefBuilder& KernelDefBuilder::SetName(const std::string& op_name) {
diff --git a/onnxruntime/core/framework/transformer_memcpy.cc b/onnxruntime/core/framework/transformer_memcpy.cc
index a4faef010d5c0..b04f278599324 100644
--- a/onnxruntime/core/framework/transformer_memcpy.cc
+++ b/onnxruntime/core/framework/transformer_memcpy.cc
@@ -68,25 +68,24 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
     // note KernelCreateInfo might be nullptr for custom kernel
     const KernelCreateInfo* kci = nullptr;
     kernel_registries.SearchKernelRegistry(node, &kci);
-    const auto* input_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
-    const auto* output_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
+
     ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(
-                    node.InputDefs(),
-                    [this, &input_mem_types](const onnxruntime::NodeArg& arg, size_t index) {
-                      if (input_mem_types && MemTypeOnCpuExplicitly(*input_mem_types, index))
-                        non_provider_input_defs_.insert(&arg);
-                      else
-                        provider_input_defs_.insert(&arg);
-                      return Status::OK();
-                    })
-                    .IsOK());
+        node.InputDefs(),
+        [this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
+          if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index)))
+            non_provider_input_defs_.insert(&arg);
+          else
+            provider_input_defs_.insert(&arg);
+          return Status::OK();
+        })
+        .IsOK());
     auto& output_defs = node.MutableOutputDefs();
     for (size_t i = 0; i < output_defs.size(); ++i) {
       auto arg = output_defs[i];
       if (!arg->Exists())
         continue;
 
-      if (output_mem_types && MemTypeOnCpuExplicitly(*output_mem_types, i))
+      if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->OutputMemoryType(i)))
         non_provider_output_defs_.insert(arg);
       else
         provider_output_defs_.insert(arg);
diff --git a/onnxruntime/core/session/IOBinding.cc b/onnxruntime/core/session/IOBinding.cc
index 67a33be805612..cd24146d71105 100644
--- a/onnxruntime/core/session/IOBinding.cc
+++ b/onnxruntime/core/session/IOBinding.cc
@@ -60,10 +60,9 @@ common::Status IOBinding::CopyOneInputAcrossDevices(const SessionState& session_
   size_t index = node_info.index;
   auto& node = *node_info.p_node;
   const KernelCreateInfo* kci = node_info.kci;
-  const auto* node_input_mem_types = (kci != nullptr) ? &kci->kernel_def->InputMemoryType() : nullptr;
 
   // node may declare input_mem_type to be on CPU explicitly
-  bool node_input_on_cpu = node_input_mem_types && MemTypeOnCpuExplicitly(*node_input_mem_types, index);
+  bool node_input_on_cpu = kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index));
   auto& required_provider_type = node_input_on_cpu ? onnxruntime::kCpuExecutionProvider : node.GetExecutionProviderType();
   if (!orig_mlvalue.IsTensor()) {
     // copying not supported for non-tensor types