Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions include/onnxruntime/core/framework/kernel_def_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class KernelDefBuilder;
typedef std::map<size_t, OrtMemType> MemTypeMap;

// note that input/output might be on CPU implicitly when the node is from CPU execution provider
inline bool MemTypeOnCpuExplicitly(const MemTypeMap& mem_type_map, size_t index) {
auto iter = mem_type_map.find(index);
return iter != mem_type_map.end() && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
// Returns true only when mem_type is one of the two explicit CPU memory types
// (OrtMemTypeCPUInput or OrtMemTypeCPUOutput).
inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
  const bool is_cpu_input = (mem_type == OrtMemTypeCPUInput);
  const bool is_cpu_output = (mem_type == OrtMemTypeCPUOutput);
  return is_cpu_input || is_cpu_output;
}

class KernelDef {
public:
// Initializes the default input and output memory types to OrtMemTypeDefault.
explicit KernelDef()
    : default_inputs_mem_type_(OrtMemTypeDefault),
      default_outputs_mem_type_(OrtMemTypeDefault) {}

const std::string& OpName() const {
return op_name_;
}
Expand Down Expand Up @@ -56,17 +58,20 @@ class KernelDef {
return alias_map_;
}

const MemTypeMap& InputMemoryType() const {
return input_memory_type_args_;
}

const MemTypeMap& OutputMemoryType() const {
return output_memory_type_args_;
// Returns the memory type registered for the input at input_index, or the
// default input memory type when no entry exists for that index.
OrtMemType InputMemoryType(size_t input_index) const {
  const auto found = input_memory_type_args_.find(input_index);
  return found != input_memory_type_args_.end() ? found->second : default_inputs_mem_type_;
}

// legacy interface for winml, should not be used in onnxruntime
const MemTypeMap& MemoryType() const {
return output_memory_type_args_;
OrtMemType OutputMemoryType(size_t output_index) const {
auto it = output_memory_type_args_.find(output_index);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not introduced by your PR, but why do we use a map to store the types instead of a vector?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess in most cases people only declare the memory type for very few inputs/outputs, so a map seems more suitable.

if (it == output_memory_type_args_.end())
return default_outputs_mem_type_;
else
return it->second;
}

int ExecQueueId() const {
Expand Down Expand Up @@ -111,6 +116,10 @@ class KernelDef {

// execution command queue id, 0 for default queue in execution provider
int exec_queue_id_ = 0;
// Default memory type for all inputs
OrtMemType default_inputs_mem_type_;
// Default memory type for all outputs
OrtMemType default_outputs_mem_type_;
};

class KernelDefBuilder {
Expand Down Expand Up @@ -212,6 +221,22 @@ class KernelDefBuilder {
return *this;
}

/**
Set the memory type used for every input that has no memory type of its own.
If this is never called, the default is OrtMemTypeDefault.
Returns this builder so calls can be chained.
*/
KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
  auto& def = *kernel_def_;
  def.default_inputs_mem_type_ = mem_type;
  return *this;
}

/**
Set the memory type used for every output that has no memory type of its own.
If this is never called, the default is OrtMemTypeDefault.
Returns this builder so calls can be chained.
*/
KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
  auto& def = *kernel_def_;
  def.default_outputs_mem_type_ = mem_type;
  return *this;
}

/**
Return the kernel definition, passing ownership of the KernelDef to the caller
*/
Expand Down
9 changes: 4 additions & 5 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ class PlannerImpl {
ORT_ENFORCE(exec_provider);

auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
auto& mem_type_allocated_args = p_kernelDef->OutputMemoryType();
auto& outputs = pnode->OutputDefs();
auto num_outputs = outputs.size();

Expand All @@ -393,11 +392,11 @@ class PlannerImpl {
if (strcmp(default_allocator_info.name, CPU) != 0) {
// By default, outputs of this node are allocated on the default device allocator,
// except for outputs marked for allocation in MemoryType:
auto memory_type_iter = mem_type_allocated_args.find(i);
if (memory_type_iter == mem_type_allocated_args.end()) {
auto memory_type = p_kernelDef->OutputMemoryType(i);
if (memory_type == OrtMemTypeDefault) {
AllocPlan(index).location = default_allocator_info;
} else {
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type_iter->second)->Info();
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type)->Info();
}
}
}
Expand Down Expand Up @@ -438,7 +437,7 @@ class PlannerImpl {

thisplan.alloc_kind = AllocKind::kAllocateStatically;
auto p_opkernelDef = utils::GetKernelDef(kernel_registry_, node);
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(), index))
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(index)))
// weights are not output from any node, so it's OK to put its location on CPU provider
thisplan.location = execution_providers_.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault)->Info();
else
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/core/framework/kernel_def_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,20 @@ bool KernelDef::IsConflict(const KernelDef& other) const {
return false;

//check memory type
auto other_input_mem_types = other.InputMemoryType();
auto& other_input_mem_types = other.input_memory_type_args_;
for (auto it : input_memory_type_args_) {
if (other_input_mem_types.count(it.first) && other_input_mem_types[it.first] == it.second)
if (other_input_mem_types.count(it.first) && other_input_mem_types.find(it.first)->second == it.second)
return false;
}
if (input_memory_type_args_.empty() && !other.InputMemoryType().empty())
if (input_memory_type_args_.empty() && !other.input_memory_type_args_.empty())
return false;

auto other_output_mem_types = other.OutputMemoryType();
auto& other_output_mem_types = other.output_memory_type_args_;
for (auto it : output_memory_type_args_) {
if (other_output_mem_types.count(it.first) && other_output_mem_types[it.first] == it.second)
// Look up by output index (it.first); it.second is the memory type value, not a map key.
if (other_output_mem_types.count(it.first) && other_output_mem_types.find(it.first)->second == it.second)
return false;
}
return !(output_memory_type_args_.empty() && !other.OutputMemoryType().empty());
return !(output_memory_type_args_.empty() && !other.output_memory_type_args_.empty());
}

KernelDefBuilder& KernelDefBuilder::SetName(const std::string& op_name) {
Expand Down
23 changes: 11 additions & 12 deletions onnxruntime/core/framework/transformer_memcpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,24 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
kernel_registries.SearchKernelRegistry(node, &kci);
const auto* input_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
const auto* output_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;

ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(
node.InputDefs(),
[this, &input_mem_types](const onnxruntime::NodeArg& arg, size_t index) {
if (input_mem_types && MemTypeOnCpuExplicitly(*input_mem_types, index))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
})
.IsOK());
node.InputDefs(),
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index)))
non_provider_input_defs_.insert(&arg);
else
provider_input_defs_.insert(&arg);
return Status::OK();
})
.IsOK());
auto& output_defs = node.MutableOutputDefs();
for (size_t i = 0; i < output_defs.size(); ++i) {
auto arg = output_defs[i];
if (!arg->Exists())
continue;

if (output_mem_types && MemTypeOnCpuExplicitly(*output_mem_types, i))
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->OutputMemoryType(i)))
non_provider_output_defs_.insert(arg);
else
provider_output_defs_.insert(arg);
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/session/IOBinding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ common::Status IOBinding::CopyOneInputAcrossDevices(const SessionState& session_
size_t index = node_info.index;
auto& node = *node_info.p_node;
const KernelCreateInfo* kci = node_info.kci;
const auto* node_input_mem_types = (kci != nullptr) ? &kci->kernel_def->InputMemoryType() : nullptr;

// node may declare input_mem_type to be on CPU explicitly
bool node_input_on_cpu = node_input_mem_types && MemTypeOnCpuExplicitly(*node_input_mem_types, index);
bool node_input_on_cpu = kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index));
auto& required_provider_type = node_input_on_cpu ? onnxruntime::kCpuExecutionProvider : node.GetExecutionProviderType();
if (!orig_mlvalue.IsTensor()) {
// copying not supported for non-tensor types
Expand Down