Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,39 @@ class IExecutionProvider {
return std::nullopt;
}

/**
Query the preferred format descriptor for an initializer without performing the transformation.
This is a lightweight query called during session initialization (before kernel creation) to
determine what format transformations are needed; it must not perform any data movement itself.

@param node The node that consumes the initializer.
@param input_index The input index of the initializer in the node's InputDefs.
@param[out] format_descriptor A string that uniquely identifies the preferred format.
The string must encode everything TransformInitializerFormat needs
(e.g. block sizes). Empty string means no transformation is needed.
Examples: "ABcd16a4b", "hwio".
@return Status::OK() if query succeeded (format_descriptor will be set).
Failed status indicates no transformation is needed.
NOTE(review): a failed Status is used here as the normal "not interested" answer, not as an
error — callers must treat failure as "skip", never propagate it. The default
implementation below is that "not interested" answer for EPs that do not override.
*/
virtual Status GetPreferredInitializerFormat(const Node& /*node*/, int /*input_index*/,
std::string& /*format_descriptor*/) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed");
}

/**
Transform an initializer to the specified format.
This performs the actual data transformation. The caller deduplicates by format string, so it
is only called once per unique format even if multiple nodes need the same format.

@param original_tensor The original initializer tensor (resident on CPU when called from
session initialization).
@param format_descriptor The target format (a value previously returned by
GetPreferredInitializerFormat).
@param[out] transformed_tensor The EP should allocate and fill this with the transformed data.
@return Status::OK() if transformation succeeded. The default implementation fails, meaning the
EP does not support format transformation; callers treat that as "leave the
initializer untouched".
*/
virtual Status TransformInitializerFormat(const Tensor& /*original_tensor*/, const std::string& /*format_descriptor*/,
std::unique_ptr<Tensor>& /*transformed_tensor*/) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Format transformation not supported");
}

virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {}

/** Does the EP support concurrent calls to InferenceSession::Run to execute the model.
Expand Down
16 changes: 16 additions & 0 deletions include/onnxruntime/core/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,19 @@ class Tensor final {
byte_offset_ = byte_offset;
}

/**
 * Get the memory format descriptor for this tensor.
 * Returns empty string if the tensor is in standard (contiguous, row-major) format.
 */
inline const std::string& GetFormatDescriptor() const { return format_descriptor_; }

/**
 * Set the memory format descriptor for this tensor.
 * Used for EP-specific memory layouts (e.g., "ABcd16a4b" for blocked format).
 * The format string encodes all necessary information including block sizes.
 *
 * Taken by value so callers passing an rvalue/temporary avoid a copy; the
 * argument is moved into the member (sink-parameter idiom).
 */
inline void SetFormatDescriptor(std::string format) { format_descriptor_ = std::move(format); }

/// <summary>
/// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for
/// sub-byte data types (e.g., int4/float4).
Expand Down Expand Up @@ -349,6 +362,9 @@ class Tensor final {
const PrimitiveDataTypeBase* dtype_;
OrtMemoryInfo alloc_info_;
ptrdiff_t byte_offset_;

// Memory format descriptor for EP-specific layouts (e.g., "ABcd16a4b")
std::string format_descriptor_;
};
#ifdef __GNUC__
#pragma GCC diagnostic pop
Expand Down
160 changes: 160 additions & 0 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "core/framework/ort_value_pattern_planner.h"
#include "core/framework/prepacked_weights_container.h"
#include "core/framework/session_state_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
Expand Down Expand Up @@ -1332,6 +1333,9 @@
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState());

ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEp(graph_, logger_, execution_providers_));

ORT_RETURN_IF_ERROR(TransformInitializersToPreferredFormat());

ORT_RETURN_IF_ERROR(PopulateKernelCreateInfo(kernel_registry_manager, saving_ort_format));

InlinedHashMap<std::string, size_t> constant_initializers_use_count;
Expand Down Expand Up @@ -1501,6 +1505,14 @@
CreateGraphInfo(save_prepacked_initializers);
}

// Index all initializers including those that may have become unreferenced after transformation.
// This runs after CreateGraphInfo() to ensure consistent ordering - CreateGraphInfo indexes based on
// graph structure, then we add any remaining initializers (e.g., original weights before transformation).
for (const auto& [init_name, tensor_proto] : graph_.GetAllInitializedTensors()) {
ORT_UNUSED_PARAMETER(tensor_proto);
ort_value_name_idx_map_.Add(init_name);
}

#if defined(ORT_EXTENDED_MINIMAL_BUILD)
// Remove any unused initializers.
// Not needed in a full build because unused initializers should have been removed earlier by Graph::Resolve().
Expand Down Expand Up @@ -1793,4 +1805,152 @@
}
#endif

// Rewrites graph initializers into EP-preferred memory formats.
//
// For every initializer, each consuming node's EP is queried for a preferred format descriptor
// (GetPreferredInitializerFormat). Requests are deduplicated by descriptor so the actual data
// transformation (TransformInitializerFormat) runs at most once per unique format. A transformed
// copy is added to the graph under a derived name and the consuming nodes are rewired to it; the
// original initializer is left in place (later cleanup may drop it if unreferenced).
//
// Failures from either EP hook are non-fatal: the initializer is simply left untransformed.
Status SessionState::TransformInitializersToPreferredFormat() {
  // Map each initializer name to every (node index, input index) pair that consumes it.
  std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>> initializer_to_consumers;

  const auto& initialized_tensors_map = graph_.GetAllInitializedTensors();

  // Scan nodes to find which initializers they use. The input index is advanced for every
  // input def, including missing/optional ones, so it matches InputDefs() positions exactly.
  for (const auto& node : graph_.Nodes()) {
    int input_index = 0;
    for (const auto* input_def : node.InputDefs()) {
      if (input_def && input_def->Exists()) {
        const auto& input_name = input_def->Name();
        if (initialized_tensors_map.count(input_name) > 0) {
          initializer_to_consumers[input_name].emplace_back(node.Index(), input_index);
        }
      }
      ++input_index;
    }
  }

  // Transformations happen on CPU; the EP produces the transformed buffer via this allocator's
  // device scope. Device placement of the result is handled later by the usual copy machinery.
  auto cpu_allocator = GetAllocator(OrtDevice());

  for (const auto& [init_name, consumers] : initializer_to_consumers) {
    if (consumers.empty()) {
      continue;
    }

    // check_outer_scope=true: subgraphs may reference initializers owned by an ancestor graph.
    const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_.GetInitializer(init_name, true);
    if (!tensor_proto) {
      continue;
    }

    // Skip if this initializer was already transformed (when loading a saved ORT format model).
    // Transformed initializers carry an "onnxruntime_format:<desc>" marker in string_data.
    bool already_transformed = false;
    for (const auto& attr_str : tensor_proto->string_data()) {
      if (attr_str.find("onnxruntime_format:") == 0) {
        already_transformed = true;
        break;
      }
    }
    if (already_transformed) {
      continue;
    }

    // Phase 1: Query all consumers to discover what formats are needed.
    // Multiple nodes may request the same format, so we deduplicate by format descriptor.
    std::unordered_map<std::string, std::vector<std::pair<NodeIndex, int>>> format_to_consumers;

    for (const auto& [node_idx, input_idx] : consumers) {
      const Node* node = graph_.GetNode(node_idx);
      if (!node) {
        continue;
      }

      const auto& ep_type = node->GetExecutionProviderType();
      if (ep_type.empty()) {
        continue;
      }

      const auto* ep = execution_providers_.Get(ep_type);
      if (!ep) {
        continue;
      }

      // Ask the EP if it wants this initializer in a different format. A failed status or an
      // empty descriptor both mean "no transformation needed" and are not errors.
      std::string format_descriptor;
      Status query_status = ep->GetPreferredInitializerFormat(*node, input_idx, format_descriptor);

      if (!query_status.IsOK() || format_descriptor.empty()) {
        continue;
      }

      format_to_consumers[format_descriptor].emplace_back(node_idx, input_idx);
    }

    if (format_to_consumers.empty()) {
      continue;
    }

    // Load the original initializer to CPU once; it is reused for every requested format.
    TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(*tensor_proto);
    const auto* tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType();

    Tensor original_tensor(tensor_type, tensor_shape, cpu_allocator);
    ORT_RETURN_IF_ERROR(
        utils::TensorProtoToTensor(Env::Default(), std::filesystem::path(), *tensor_proto, original_tensor));

    // Phase 2: Transform once per unique format requested.
    for (const auto& [format_descriptor, nodes_needing_format] : format_to_consumers) {
      // NOTE(review): the EP of the first requesting node performs the transformation. If two
      // different EPs ever emit the same descriptor string for incompatible layouts, they would
      // collide here — descriptors are assumed to be globally unambiguous; confirm.
      const Node* first_node = graph_.GetNode(nodes_needing_format[0].first);
      if (!first_node) {
        continue;
      }

      const auto& ep_type = first_node->GetExecutionProviderType();
      const auto* ep = execution_providers_.Get(ep_type);
      if (!ep) {
        continue;
      }

      // Perform the actual transformation. Failure is logged and the consumers keep using the
      // original initializer (best-effort, never fatal).
      std::unique_ptr<Tensor> transformed_tensor;
      Status transform_status = ep->TransformInitializerFormat(original_tensor, format_descriptor, transformed_tensor);

      if (!transform_status.IsOK() || !transformed_tensor) {
        LOGS(logger_, WARNING) << "Failed to transform initializer '" << init_name << "' to format '"
                               << format_descriptor << "': " << transform_status.ErrorMessage();
        continue;
      }

      // Tag the tensor so downstream code (e.g. device copy) can preserve the layout info.
      transformed_tensor->SetFormatDescriptor(format_descriptor);

      // Add the transformed initializer under a derived, format-qualified name.
      std::string transformed_name = init_name + "_fmt_" + format_descriptor;

      ONNX_NAMESPACE::TensorProto transformed_proto = utils::TensorToTensorProto(*transformed_tensor, transformed_name);

      // Persist the format marker in string_data so a reloaded ORT-format model skips re-transforming.
      auto* format_attr = transformed_proto.add_string_data();
      *format_attr = "onnxruntime_format:" + format_descriptor;

      graph_.AddInitializedTensor(transformed_proto);

      // Rewire every node that requested this format to consume the transformed initializer.
      for (const auto& [node_idx, input_idx] : nodes_needing_format) {
        Node* node = graph_.GetNode(node_idx);
        if (!node) {
          continue;
        }

        // NOTE(review): the new NodeArg reuses the original arg's TypeAsProto, i.e. the logical
        // type/shape, even though the physical layout differs — assumed intentional; confirm
        // shape inference downstream tolerates this.
        const auto* original_node_arg = node->InputDefs()[input_idx];
        auto* transformed_node_arg = &graph_.GetOrCreateNodeArg(transformed_name, original_node_arg->TypeAsProto());

        node->MutableInputDefs()[input_idx] = transformed_node_arg;
      }
    }
  }

  return Status::OK();
}

} // namespace onnxruntime
6 changes: 6 additions & 0 deletions onnxruntime/core/framework/session_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,12 @@ class SessionState {
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map = {},
bool graph_info_already_created = false);

/**
 * Transform initializer tensors to EP-preferred memory formats.
 * Called during session initialization before kernel creation: queries each node's EP for a
 * preferred format per initializer, transforms once per unique format, and rewires consumers
 * to the transformed copies. EP refusal/failure is non-fatal (initializer left unchanged).
 */
Status TransformInitializersToPreferredFormat();

#ifdef ENABLE_TRAINING
Status GeneratePatternGroupCache(
gsl::span<const OrtValue> inputs,
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/framework/session_state_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ common::Status CopyTensorFromCPUToDevice(
}
return copy_status;
} else {
// Preserve format descriptor when copying from CPU to device
const std::string& format = deserialized_tensor.GetFormatDescriptor();
if (!format.empty()) {
tensor.SetFormatDescriptor(format);
}
Tensor::InitOrtValue(std::move(tensor), ort_value);
return common::Status::OK();
}
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/framework/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ Tensor::Tensor(Tensor&& other) noexcept
#endif
dtype_(other.dtype_),
alloc_info_(other.alloc_info_),
byte_offset_(other.byte_offset_) {
byte_offset_(other.byte_offset_),
format_descriptor_(std::move(other.format_descriptor_)) {
other.p_data_ = nullptr;
other.buffer_deleter_ = nullptr;
other.dtype_ = DataTypeImpl::GetType<float>()->AsPrimitiveDataType();
Expand All @@ -221,6 +222,7 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept {
dtype_ = other.dtype_;
alloc_info_ = other.alloc_info_;
byte_offset_ = other.byte_offset_;
format_descriptor_ = std::move(other.format_descriptor_);

other.p_data_ = nullptr;
other.buffer_deleter_ = nullptr;
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/framework/tensorprotoutils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1424,6 +1424,15 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa
}
}

// Read format metadata from TensorProto string_data
for (const auto& attr_str : tensor_proto.string_data()) {
if (attr_str.find("onnxruntime_format:") == 0) {
std::string format = attr_str.substr(19); // Skip "onnxruntime_format:"
tensor.SetFormatDescriptor(format);
break; // Only one format descriptor expected
}
}

return Status::OK();
}

Expand Down
Loading
Loading