diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index d96d229c942cb..89651c2d955de 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -794,7 +794,7 @@ std::string _get_type_name(std::string&) { #if !defined(DISABLE_ML_OPS) template static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string& name_input, PyObject*& value, - PyObject* item, std::map& current, + PyObject* item, bool owns_item_ref, std::map& current, KeyGetterType keyGetter, ValueGetterType valueGetter) { KeyType ckey; ValueType cvalue; @@ -806,7 +806,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string sType = spyType; Py_XDECREF(pStr); Py_XDECREF(pType); - Py_XDECREF(item); + if (owns_item_ref) { + Py_XDECREF(item); + } throw std::runtime_error(std::string("Unexpected key type ") + sType + std::string(", it cannot be linked to C type ") + _get_type_name(ckey) + std::string(" for input '") + @@ -820,7 +822,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string sType = spyType; Py_XDECREF(pStr); Py_XDECREF(pType); - Py_XDECREF(item); + if (owns_item_ref) { + Py_XDECREF(item); + } throw std::runtime_error(std::string("Unexpected value type ") + sType + std::string(", it cannot be linked to C type ") + _get_type_name(ckey) + std::string(" for input '") + @@ -836,7 +840,7 @@ static void CreateMapMLValue_Map(Py_ssize_t& pos, PyObject*& key, const std::str ValueGetterType valueGetter) { std::unique_ptr> dst; dst = std::make_unique>(); - CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, *dst, keyGetter, valueGetter); + CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, false, *dst, keyGetter, valueGetter); p_mlvalue->Init(dst.release(), DataTypeImpl::GetType>(), DataTypeImpl::GetType>()->GetDeleteFunc()); } @@ -850,7 +854,7 @@ void CreateMapMLValue_VectorMap(Py_ssize_t& pos, PyObject*& key, const std::stri int index = 0; do { dstVector->push_back(std::map()); - CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, (*dstVector)[index], keyGetter, valueGetter); + CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, true, (*dstVector)[index], keyGetter, valueGetter); Py_DECREF(item); ++index; item = iterator == NULL ? NULL : PyIter_Next(iterator);