Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions onnxruntime/core/framework/session_state_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,38 +398,67 @@ common::Status SaveInitializedTensors(
ort_value = *(session_options.initializers_to_share_map.at(name));
LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ").";

} else if (graph.GetOrtValueInitializer(name, ort_value)) {
// populated OrtValue from the Graph instance
} else {
const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second);

std::optional<MemBuffer> m;
AllocatorPtr alloc;
// TODO: if the tensor need be copied, does it have enough room?
ORT_RETURN_IF_ERROR(planner.GetPreallocatedBuffer(ort_value_index, name, m, alloc));
bool use_device_allocator_for_initializers =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsUseDeviceAllocatorForInitializers, "0") == "1";

Tensor* p_tensor = nullptr;
auto buffered_tensors_iter = buffered_tensors.find(name);
if (buffered_tensors_iter != buffered_tensors.end()) {
p_tensor = buffered_tensors_iter->second.get();
}
if (graph.GetOrtValueInitializer(name, ort_value)) {
// populated OrtValue from the Graph instance

auto& memory_info = (alloc != nullptr) ? alloc->Info() : m->GetAllocInfo();
auto device_type = memory_info.device.Type();

if (device_type != OrtDevice::CPU) {
// if the initializer is on a non-CPU device, copy it from CPU to the device.
const auto& initializer_tensor = ort_value.Get<Tensor>();
if (initializer_tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators");
Copy link

Copilot AI Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message for string tensors ('string tensor is not supported for copying between allocators') does not tell users how to proceed. Consider a more descriptive message that adds context or, where applicable, suggests an alternative approach.

Suggested change
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "string tensor is not supported for copying between allocators");
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"String tensor copying between allocators is not supported due to limitations in ONNX Runtime. "
"Consider preprocessing string tensors on the CPU or avoiding operations that require copying them between devices.");

Copilot uses AI. Check for mistakes.
}

std::unique_ptr<Tensor> device_tensor;
TensorShape device_tensor_shape{initializer_tensor.Shape()};
ORT_RETURN_IF_ERROR(AllocateTensor((m.has_value()) ? &*m : nullptr,
device_tensor,
initializer_tensor.DataType(),
device_tensor_shape,
use_device_allocator_for_initializers,
alloc));

ORT_RETURN_IF_ERROR(data_transfer_mgr.CopyTensor(initializer_tensor, *device_tensor));

ort_value = {device_tensor.release(),
DataTypeImpl::GetType<Tensor>(),
DataTypeImpl::GetType<Tensor>()->GetDeleteFunc()};
}

Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc,
default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr,
prepacked_for_graph,
use_device_allocator_for_initializers, p_tensor);
if (!st.IsOK()) {
std::ostringstream oss;
oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();
return Status(st.Category(), st.Code(), oss.str());
}
} else {
const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second);

if (p_tensor != nullptr) {
// p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here.
ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release());
buffered_tensors.erase(buffered_tensors_iter);
Tensor* p_tensor = nullptr;
auto buffered_tensors_iter = buffered_tensors.find(name);
if (buffered_tensors_iter != buffered_tensors.end()) {
p_tensor = buffered_tensors_iter->second.get();
}

Status st = DeserializeTensorProto(env, graph_loc, tensor_proto, (m.has_value()) ? &*m : nullptr, alloc,
default_cpu_alloc, ort_value, data_transfer_mgr, external_data_loader_mgr,
prepacked_for_graph,
use_device_allocator_for_initializers, p_tensor);
if (!st.IsOK()) {
std::ostringstream oss;
oss << "Deserialize tensor " << name << " failed." << st.ErrorMessage();
return Status(st.Category(), st.Code(), oss.str());
}

if (p_tensor != nullptr) {
// p_tensor was wrapped in a deleter by DeserializeTensorProto so we can simply release it here.
ORT_IGNORE_RETURN_VALUE(buffered_tensors_iter->second.release());
buffered_tensors.erase(buffered_tensors_iter);
}
}
}

Expand Down
Loading