Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_mlvalue.cc
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ std::string _get_type_name(std::string&) {
#if !defined(DISABLE_ML_OPS)
template <typename KeyType, typename ValueType, typename KeyGetterType, typename ValueGetterType>
static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string& name_input, PyObject*& value,
PyObject* item, std::map<KeyType, ValueType>& current,
PyObject* item, bool owns_item_ref, std::map<KeyType, ValueType>& current,
KeyGetterType keyGetter, ValueGetterType valueGetter) {
KeyType ckey;
ValueType cvalue;
Expand All @@ -806,7 +806,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const
std::string sType = spyType;
Py_XDECREF(pStr);
Py_XDECREF(pType);
Py_XDECREF(item);
if (owns_item_ref) {
Py_XDECREF(item);
}
throw std::runtime_error(std::string("Unexpected key type ") + sType +
std::string(", it cannot be linked to C type ") +
_get_type_name(ckey) + std::string(" for input '") +
Expand All @@ -820,7 +822,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const
std::string sType = spyType;
Py_XDECREF(pStr);
Py_XDECREF(pType);
Py_XDECREF(item);
if (owns_item_ref) {
Py_XDECREF(item);
}
throw std::runtime_error(std::string("Unexpected value type ") + sType +
std::string(", it cannot be linked to C type ") +
_get_type_name(ckey) + std::string(" for input '") +
Expand All @@ -836,7 +840,7 @@ static void CreateMapMLValue_Map(Py_ssize_t& pos, PyObject*& key, const std::str
ValueGetterType valueGetter) {
std::unique_ptr<std::map<KeyType, ValueType>> dst;
dst = std::make_unique<std::map<KeyType, ValueType>>();
CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, *dst, keyGetter, valueGetter);
CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, false, *dst, keyGetter, valueGetter);
p_mlvalue->Init(dst.release(), DataTypeImpl::GetType<std::map<KeyType, ValueType>>(),
DataTypeImpl::GetType<std::map<KeyType, ValueType>>()->GetDeleteFunc());
}
Expand All @@ -850,7 +854,7 @@ void CreateMapMLValue_VectorMap(Py_ssize_t& pos, PyObject*& key, const std::stri
int index = 0;
do {
dstVector->push_back(std::map<KeyType, ValueType>());
CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, (*dstVector)[index], keyGetter, valueGetter);
CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, true, (*dstVector)[index], keyGetter, valueGetter);
Py_DECREF(item);
++index;
item = iterator == NULL ? NULL : PyIter_Next(iterator);
Expand Down
Loading