diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index f54f4a5a6f1ef..a9a1ff1933473 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -378,6 +378,39 @@ class IExecutionProvider { return std::nullopt; } + /** + Query the preferred format descriptor for an initializer without performing the transformation. + This is a lightweight query called during session initialization to determine what format + transformations are needed. + + @param node The node that consumes the initializer + @param input_index The input index of the initializer in the node + @param[out] format_descriptor A string that uniquely identifies the preferred format. + Empty string means no transformation is needed. + Examples: "ABcd16a4b", "hwio". + @return Status::OK() if query succeeded (format_descriptor will be set). + Failed status indicates no transformation is needed. + */ + virtual Status GetPreferredInitializerFormat(const Node& /*node*/, int /*input_index*/, + std::string& /*format_descriptor*/) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + + /** + Transform an initializer to the specified format. + This performs the actual data transformation. It is only called once per unique format + even if multiple nodes need the same format. + + @param original_tensor The original initializer tensor + @param format_descriptor The target format (from GetPreferredInitializerFormat) + @param[out] transformed_tensor The EP should allocate and fill this with the transformed data. + @return Status::OK() if transformation succeeded. + */ + virtual Status TransformInitializerFormat(const Tensor& /*original_tensor*/, const std::string& /*format_descriptor*/, + std::unique_ptr& /*transformed_tensor*/) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Format transformation not supported"); + } + virtual void RegisterStreamHandlers(IStreamCommandHandleRegistry& /*stream_handle_registry*/, AllocatorMap&) const {} /** Does the EP support concurrent calls to InferenceSession::Run to execute the model. diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index c7f7f23f70334..aebfda9c04878 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -282,6 +282,19 @@ class Tensor final { byte_offset_ = byte_offset; } + /** + * Get the memory format descriptor for this tensor. + * Returns empty string if the tensor is in standard format. + */ + inline const std::string& GetFormatDescriptor() const { return format_descriptor_; } + + /** + * Set the memory format descriptor for this tensor. + * Used for EP-specific memory layouts (e.g., "ABcd16a4b" for blocked format). + * The format string encodes all necessary information including block sizes. + */ + inline void SetFormatDescriptor(const std::string& format) { format_descriptor_ = format; } + /// /// The number of Tensor "storage" elements. A single storage element may contain multiple sub-elements for /// sub-byte data types (e.g., int4/float4). @@ -349,6 +362,9 @@ class Tensor final { const PrimitiveDataTypeBase* dtype_; OrtMemoryInfo alloc_info_; ptrdiff_t byte_offset_; + + // Memory format descriptor for EP-specific layouts (e.g., "ABcd16a4b") + std::string format_descriptor_; }; #ifdef __GNUC__ #pragma GCC diagnostic pop diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 01ba492eb166e..2001a904f207c 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -15,6 +15,7 @@ #include "core/framework/ort_value_pattern_planner.h" #include "core/framework/prepacked_weights_container.h" #include "core/framework/session_state_utils.h" +#include "core/framework/tensorprotoutils.h" #include "core/framework/utils.h" #include "core/providers/cpu/controlflow/utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -1332,6 +1333,9 @@ Status SessionState::FinalizeSessionState(const std::basic_string constant_initializers_use_count; @@ -1501,6 +1505,14 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_string>> initializer_to_consumers; + + const auto& initialized_tensors_map = graph_.GetAllInitializedTensors(); + std::unordered_set initializer_names; + for (const auto& [name, tensor_proto] : initialized_tensors_map) { + ORT_UNUSED_PARAMETER(tensor_proto); + initializer_names.insert(name); + } + + // Scan nodes to find which initializers they use + for (const auto& node : graph_.Nodes()) { + int input_index = 0; + for (const auto* input_def : node.InputDefs()) { + if (input_def && input_def->Exists()) { + const auto& input_name = input_def->Name(); + if (initializer_names.count(input_name) > 0) { + initializer_to_consumers[input_name].emplace_back(node.Index(), input_index); + } + } + ++input_index; + } + } + + auto cpu_allocator = GetAllocator(OrtDevice()); + + for (const auto& [init_name, consumers] : initializer_to_consumers) { + if (consumers.empty()) { + continue; + } + + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_.GetInitializer(init_name, true); + if (!tensor_proto) { + continue; + } + + // Skip if this initializer was already transformed (when loading a saved ORT format model) + // Transformed initializers have format metadata in string_data + bool already_transformed = false; + for (const auto& attr_str : tensor_proto->string_data()) { + if (attr_str.find("onnxruntime_format:") == 0) { + already_transformed = true; + break; + } + } + if (already_transformed) { + continue; + } + + // Phase 1: Query all consumers to discover what formats are needed + // Multiple nodes may request the same format, so we deduplicate by format + std::unordered_map>> format_to_consumers; + + for (const auto& [node_idx, input_idx] : consumers) { + const Node* node = graph_.GetNode(node_idx); + if (!node) { + continue; + } + + const auto& ep_type = node->GetExecutionProviderType(); + if (ep_type.empty()) { + continue; + } + + const auto* ep = execution_providers_.Get(ep_type); + if (!ep) { + continue; + } + + // Ask EP if it wants this initializer in a different format + std::string format_descriptor; + Status query_status = ep->GetPreferredInitializerFormat(*node, input_idx, format_descriptor); + + if (!query_status.IsOK() || format_descriptor.empty()) { + continue; + } + + format_to_consumers[format_descriptor].emplace_back(node_idx, input_idx); + } + + if (format_to_consumers.empty()) { + continue; + } + + // Load the original initializer to CPU for transformation + TensorShape tensor_shape = utils::GetTensorShapeFromTensorProto(*tensor_proto); + const auto* tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto->data_type())->GetElementType(); + + Tensor original_tensor(tensor_type, tensor_shape, cpu_allocator); + ORT_RETURN_IF_ERROR( + utils::TensorProtoToTensor(Env::Default(), std::filesystem::path(), *tensor_proto, original_tensor)); + + // Phase 2: Transform once per unique format requested + for (const auto& [format_descriptor, nodes_needing_format] : format_to_consumers) { + const Node* first_node = graph_.GetNode(nodes_needing_format[0].first); + if (!first_node) { + continue; + } + + const auto& ep_type = first_node->GetExecutionProviderType(); + const auto* ep = execution_providers_.Get(ep_type); + if (!ep) { + continue; + } + + // Perform the actual transformation + std::unique_ptr transformed_tensor; + Status transform_status = ep->TransformInitializerFormat(original_tensor, format_descriptor, transformed_tensor); + + if (!transform_status.IsOK() || !transformed_tensor) { + LOGS(logger_, WARNING) << "Failed to transform initializer '" << init_name << "' to format '" + << format_descriptor << "': " << transform_status.ErrorMessage(); + continue; + } + + // Set format metadata on the transformed tensor + transformed_tensor->SetFormatDescriptor(format_descriptor); + + // Add the transformed initializer with a new name + std::string transformed_name = init_name + "_fmt_" + format_descriptor; + + ONNX_NAMESPACE::TensorProto transformed_proto = utils::TensorToTensorProto(*transformed_tensor, transformed_name); + + // Add format metadata as TensorProto attribute + auto* format_attr = transformed_proto.add_string_data(); + *format_attr = "onnxruntime_format:" + format_descriptor; + + graph_.AddInitializedTensor(transformed_proto); + + // Update all nodes that need this format to use the transformed version + for (const auto& [node_idx, input_idx] : nodes_needing_format) { + Node* node = graph_.GetNode(node_idx); + if (!node) { + continue; + } + + const auto* original_node_arg = node->InputDefs()[input_idx]; + auto* transformed_node_arg = &graph_.GetOrCreateNodeArg(transformed_name, original_node_arg->TypeAsProto()); + + node->MutableInputDefs()[input_idx] = transformed_node_arg; + } + } + } + + return Status::OK(); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.h b/onnxruntime/core/framework/session_state.h index e2102d95e1f17..6264ae651cf2b 100644 --- a/onnxruntime/core/framework/session_state.h +++ b/onnxruntime/core/framework/session_state.h @@ -431,6 +431,12 @@ class SessionState { const InlinedHashMap& outer_scope_node_arg_to_location_map = {}, bool graph_info_already_created = false); + /** + * Transform initializer tensors to EP-preferred memory formats. + * This is called during session initialization before kernel creation. + */ + Status TransformInitializersToPreferredFormat(); + #ifdef ENABLE_TRAINING Status GeneratePatternGroupCache( gsl::span inputs, diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 254c520b4e54a..79cfa3b738891 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -227,6 +227,11 @@ common::Status CopyTensorFromCPUToDevice( } return copy_status; } else { + // Preserve format descriptor when copying from CPU to device + const std::string& format = deserialized_tensor.GetFormatDescriptor(); + if (!format.empty()) { + tensor.SetFormatDescriptor(format); + } Tensor::InitOrtValue(std::move(tensor), ort_value); return common::Status::OK(); } diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index eefd7825eca5b..05ff569a9e51c 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -195,7 +195,8 @@ Tensor::Tensor(Tensor&& other) noexcept #endif dtype_(other.dtype_), alloc_info_(other.alloc_info_), - byte_offset_(other.byte_offset_) { + byte_offset_(other.byte_offset_), + format_descriptor_(std::move(other.format_descriptor_)) { other.p_data_ = nullptr; other.buffer_deleter_ = nullptr; other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); @@ -221,6 +222,7 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept { dtype_ = other.dtype_; alloc_info_ = other.alloc_info_; byte_offset_ = other.byte_offset_; + format_descriptor_ = std::move(other.format_descriptor_); other.p_data_ = nullptr; other.buffer_deleter_ = nullptr; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index f66966b335454..a563832f74b6e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1424,6 +1424,15 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa } } + // Read format metadata from TensorProto string_data + for (const auto& attr_str : tensor_proto.string_data()) { + if (attr_str.find("onnxruntime_format:") == 0) { + std::string format = attr_str.substr(19); // Skip "onnxruntime_format:" + tensor.SetFormatDescriptor(format); + break; // Only one format descriptor expected + } + } + return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index a2777979ae983..0e8c3c7ef14b3 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -12,6 +12,194 @@ namespace onnxruntime { namespace webgpu { +// Get the preferred kernel format for Conv operator +// Returns format descriptor string (e.g., "hwio", "ABcd16a4b"), or empty string if no transformation needed +Status ConvGetPreferredKernelFormat(const Node& node, int input_index, std::string& format_descriptor) { + // Conv operator - kernel is input index 1 + // Conv signature: [X, W, B?] where X=activations, W=kernel, B=optional bias + if (input_index != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + + // Get kernel shape and dtype from NodeArg + const auto& input_defs = node.InputDefs(); + if (input_index >= static_cast(input_defs.size())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Invalid input index"); + } + + const auto* kernel_arg = input_defs[input_index]; + if (!kernel_arg || !kernel_arg->Exists()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel input does not exist"); + } + + // Get shape + const auto* shape_proto = kernel_arg->Shape(); + if (!shape_proto) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel shape is unknown"); + } + + TensorShapeVector dims; + for (const auto& dim : shape_proto->dim()) { + if (dim.has_dim_value()) { + dims.push_back(dim.dim_value()); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel has dynamic shape"); + } + } + + // Conv kernels must be 4D: [O, I, H, W] (or 3D for Conv1D which gets expanded to 4D) + if (dims.size() != 4 && dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + + // Get data type + const auto* type_proto = kernel_arg->TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Kernel type is unknown"); + } + + auto elem_type_enum = static_cast(type_proto->tensor_type().elem_type()); + const auto* kernel_dtype = DataTypeImpl::TensorTypeFromONNXEnum(elem_type_enum)->GetElementType(); + + // Only support float32 and float16 + const bool is_float32 = (kernel_dtype == DataTypeImpl::GetType()); + const bool is_float16 = (kernel_dtype == DataTypeImpl::GetType()); + + if (!is_float32 && !is_float16) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + + // Check if this is channels_last (NHWC) layout by checking the domain + const bool is_channels_last = (node.Domain() == kMSInternalNHWCDomain); + + // Get group attribute to match Conv execution logic + int64_t group = 1; + const auto& attributes = node.GetAttributes(); + auto group_attr = attributes.find("group"); + if (group_attr != attributes.end()) { + group = group_attr->second.i(); + } + + // Get kernel spatial dimensions for MatMul optimization path check + const int64_t kernel_height = dims.size() == 4 ? dims[2] : 1; // Conv1D has no H dim + const int64_t kernel_width = dims.size() == 4 ? dims[3] : dims[2]; + + // Get input shape to check same_size condition + const auto* input_arg = input_defs[0]; + int64_t input_height = -1; + int64_t input_width = -1; + if (input_arg && input_arg->Exists()) { + const auto* input_shape_proto = input_arg->Shape(); + if (input_shape_proto && input_shape_proto->dim_size() >= 3) { + if (is_channels_last) { + // For channels_last: [N, H, W, C] or [N, W, C] for Conv1D + if (input_shape_proto->dim_size() == 4) { + if (input_shape_proto->dim(1).has_dim_value()) { + input_height = input_shape_proto->dim(1).dim_value(); + } + if (input_shape_proto->dim(2).has_dim_value()) { + input_width = input_shape_proto->dim(2).dim_value(); + } + } else if (input_shape_proto->dim_size() == 3) { + // Conv1D + input_height = 1; + if (input_shape_proto->dim(1).has_dim_value()) { + input_width = input_shape_proto->dim(1).dim_value(); + } + } + } else { + // For channels_first: [N, C, H, W] or [N, C, W] for Conv1D + if (input_shape_proto->dim_size() == 4) { + if (input_shape_proto->dim(2).has_dim_value()) { + input_height = input_shape_proto->dim(2).dim_value(); + } + if (input_shape_proto->dim(3).has_dim_value()) { + input_width = input_shape_proto->dim(3).dim_value(); + } + } else if (input_shape_proto->dim_size() == 3) { + // Conv1D + input_height = 1; + if (input_shape_proto->dim(2).has_dim_value()) { + input_width = input_shape_proto->dim(2).dim_value(); + } + } + } + } + } + + // Get pads and strides attributes + std::vector pads; + auto pads_attr = attributes.find("pads"); + if (pads_attr != attributes.end()) { + pads.assign(pads_attr->second.ints().begin(), pads_attr->second.ints().end()); + } + + std::vector strides_vec; + auto strides_attr = attributes.find("strides"); + if (strides_attr != attributes.end()) { + strides_vec.assign(strides_attr->second.ints().begin(), strides_attr->second.ints().end()); + } + + // Default pads and strides if not specified + if (pads.empty()) { + pads.resize(dims.size() == 4 ? 4 : 2, 0); // 4 for 2D conv, 2 for 1D conv + } + if (strides_vec.empty()) { + strides_vec.resize(dims.size() == 4 ? 2 : 1, 1); + } + + // Analyze execution paths to determine if kernel needs pre-transformation: + + // Path 1: Grouped convolution (group > 1) + // - Only transposes when is_channels_last + // - channels_first: no transpose + if (group > 1) { + if (is_channels_last) { + format_descriptor = "hwio"; + return Status::OK(); + } else { + // channels_first grouped conv doesn't transpose + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + } + + // Path 2: MatMul optimization (same_size or 1x1 conv conditions) + // - channels_last: transposes + // - channels_first: does NOT transpose + + const bool same_size = (input_height > 0 && input_width > 0 && input_height == kernel_height && + input_width == kernel_width && pads[0] == 0 && pads[1] == 0); + + const bool is_1x1_conv = + (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides_vec.size() > 0 && + strides_vec[0] == 1 && (strides_vec.size() == 1 || strides_vec[1] == 1)); + + if (same_size || is_1x1_conv) { + if (is_channels_last) { + // MatMul optimization transposes for channels_last + format_descriptor = "hwio"; + return Status::OK(); + } else { + // MatMul optimization does NOT transpose for channels_first + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); + } + } + + // Path 3: General convolution (fallback path) + // - ALWAYS transposes regardless of is_channels_last + // - Both channels_last AND channels_first transpose here + format_descriptor = "hwio"; + return Status::OK(); + + // TODO: Add shape-based heuristics for blocked format in the future: + // const int64_t O = dims[0]; // output channels + // const int64_t I = dims[1]; // input channels + // if (O >= 16 && I >= 4 && minimal_padding_overhead) { + // format_descriptor = "ABcd16a4b"; // 16x4 blocks on O and I dims + // return Status::OK(); + // } +} + Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm) { // Transpose weights auto rank = kernel_shape.NumDimensions(); @@ -33,6 +221,22 @@ Status Conv::ComputeInternal(ComputeContext& context const auto* bias = has_bias ? context.Input(2) : nullptr; TensorShape input_shape = input->Shape(); TensorShape kernel_shape = kernel->Shape(); + + // Check if kernel is pre-transformed to hwio format + const bool is_kernel_hwio = (kernel->GetFormatDescriptor() == "hwio"); + + // If kernel is pre-transformed to hwio format, we need to get the logical oihw shape + // for computing kernel spatial dimensions and output channels + // hwio format: [H, W, I, O] -> oihw format: [O, I, H, W] + if (is_kernel_hwio) { + // Convert hwio shape back to oihw for dimension calculations + const auto& hwio_shape = kernel_shape.GetDims(); + if (hwio_shape.size() == 4) { + // hwio -> oihw: permutation is {3, 2, 0, 1} + kernel_shape = TensorShape({hwio_shape[3], hwio_shape[2], hwio_shape[0], hwio_shape[1]}); + } + } + ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end()); TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end()); TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end()); @@ -106,9 +310,17 @@ Status Conv::ComputeInternal(ComputeContext& context if (conv_attrs_.group > 1) { Tensor transposed_kernel; if (is_channels_last) { - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); - inputs[1] = &transposed_kernel; - modified_input_output_shapes[1] = transposed_kernel.Shape(); + // Check if kernel is already in hwio format (pre-transformed) + if (is_kernel_hwio) { + // Kernel is already in hwio format, use it directly + inputs[1] = kernel; + modified_input_output_shapes[1] = kernel->Shape(); // Use actual tensor shape (hwio) + } else { + // Need to transpose kernel from oihw to hwio at runtime + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + inputs[1] = &transposed_kernel; + modified_input_output_shapes[1] = transposed_kernel.Shape(); + } } auto output_channels_per_group = output_channels / conv_attrs_.group; auto components = static_cast(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1); @@ -145,10 +357,17 @@ Status Conv::ComputeInternal(ComputeContext& context std::vector matmul_inputs; std::vector matmul_input_reshapes; if (is_channels_last) { - // Transpose weights - - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); - inputs[1] = &transposed_kernel; + // Check if kernel is already in hwio format (pre-transformed) + const Tensor* kernel_to_use = kernel; + if (is_kernel_hwio) { + // Kernel is already in hwio format, use it directly + kernel_to_use = kernel; + } else { + // Need to transpose kernel from oihw to hwio at runtime + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + kernel_to_use = &transposed_kernel; + } + inputs[1] = kernel_to_use; if (same_size) { const auto shared_dim = input_height * input_width * input_channels; input_reshape = TensorShape({1, batch, shared_dim}); @@ -160,7 +379,7 @@ Status Conv::ComputeInternal(ComputeContext& context matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels}); } matmul_inputs.push_back(input); - matmul_inputs.push_back(&transposed_kernel); + matmul_inputs.push_back(kernel_to_use); matmul_input_reshapes.push_back(input_reshape); matmul_input_reshapes.push_back(kernel_reshape); } else { @@ -204,16 +423,31 @@ Status Conv::ComputeInternal(ComputeContext& context return context.RunProgram(program); } } - // Transpose weights + // General Conv path - transpose weights if needed Tensor transposed_kernel; - ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + const Tensor* kernel_to_use = kernel; + TensorShape kernel_to_use_shape; + + // Check if kernel is already in hwio format (pre-transformed) + if (is_kernel_hwio) { + // Kernel is already in hwio format, use it directly + kernel_to_use = kernel; + kernel_to_use_shape = kernel->Shape(); // Use actual tensor shape (hwio) + } else { + // Need to transpose kernel from oihw to hwio at runtime + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + kernel_to_use = &transposed_kernel; + kernel_to_use_shape = transposed_kernel.Shape(); + } + auto dim_a_outer = static_cast(is_channels_last ? output_height * output_width : output_channels); auto dim_b_outer = static_cast(is_channels_last ? output_channels : output_height * output_width); auto dim_inner = static_cast(kernel_height * kernel_width * input_channels); - inputs[1] = &transposed_kernel; - TensorShape transposed_kernel_shape = transposed_kernel.Shape(); - modified_input_output_shapes[1] = transposed_kernel.Shape(); - Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, modified_input_output_shapes); + inputs[1] = kernel_to_use; + modified_input_output_shapes[1] = kernel_to_use_shape; + Conv2dMMProgram conv2d_mm_program = + CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, + is_channels_last, modified_input_output_shapes); return context.RunProgram(conv2d_mm_program); } diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h index cafaa272c0613..08f2d2da0ae37 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.h +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -28,6 +28,10 @@ class Conv : public WebGpuKernel { Activation activation_; }; +// Get the preferred kernel format for Conv operator +// Returns format descriptor string (e.g., "hwio", "ABcd16a4b"), or empty string if no transformation needed +Status ConvGetPreferredKernelFormat(const Node& node, int input_index, std::string& format_descriptor); + Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 3df194217933e..b71df21904d7d 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -19,6 +19,7 @@ #include "core/framework/fallback_cpu_capability.h" #include "core/framework/kernel_registry.h" #include "core/framework/run_options.h" +#include "core/providers/webgpu/weight_layout_transformer.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" #include "core/session/onnxruntime_run_options_config_keys.h" @@ -29,6 +30,7 @@ #include "core/providers/webgpu/external_data_loader.h" #include "core/providers/webgpu/webgpu_profiler.h" #include "core/providers/webgpu/tensor/cast.h" +#include "core/providers/webgpu/nn/conv.h" namespace onnxruntime { @@ -947,6 +949,28 @@ std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s return std::nullopt; } +Status WebGpuExecutionProvider::GetPreferredInitializerFormat(const Node& node, int input_index, + std::string& format_descriptor) const { + // Delegate to operator-specific functions based on operator type + if (node.OpType() == "Conv") { + return ConvGetPreferredKernelFormat(node, input_index, format_descriptor); + } + + // Add more operators here as needed: + // - ConvTranspose: similar to Conv + // - Gemm/MatMul: may use different blocking schemes based on input/output shapes + // etc. + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "No format transformation needed"); +} + +Status WebGpuExecutionProvider::TransformInitializerFormat(const Tensor& original_tensor, + const std::string& format_descriptor, + std::unique_ptr& transformed_tensor) const { + // Delegate to WeightLayoutTransformer + return WeightLayoutTransformer::TransformLayout(original_tensor, format_descriptor, transformed_tensor); +} + WebGpuExecutionProvider::~WebGpuExecutionProvider() { // Release all resources associated with the captured graph if (!captured_commands_.empty()) { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index a9282a028c803..268fa17968c97 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -64,6 +64,12 @@ class WebGpuExecutionProvider : public IExecutionProvider { std::string_view node_op_type, DataLayout target_data_layout) const override; + Status GetPreferredInitializerFormat(const Node& node, int input_index, + std::string& format_descriptor) const override; + + Status TransformInitializerFormat(const Tensor& original_tensor, const std::string& format_descriptor, + std::unique_ptr& transformed_tensor) const override; + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transformer.cc b/onnxruntime/core/providers/webgpu/weight_layout_transformer.cc new file mode 100644 index 0000000000000..6561d18cf60e9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transformer.cc @@ -0,0 +1,181 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/weight_layout_transformer.h" +#include "core/framework/allocator.h" +#include "core/framework/data_types.h" +#include "core/framework/tensorprotoutils.h" +#include + +namespace onnxruntime { +namespace webgpu { + +// Template helper function to transpose weights from oihw to hwio layout +template +void WeightLayoutTransformer::TransposeOIHWToHWIO(const T* src, T* dst, + int64_t O, int64_t I, int64_t H, int64_t W) { + // Transpose from oihw to hwio + // Source layout: [O][I][H][W] + // Destination layout: [H][W][I][O] + // Permutation: {2, 3, 1, 0} + + for (int64_t o = 0; o < O; ++o) { + for (int64_t i = 0; i < I; ++i) { + for (int64_t h = 0; h < H; ++h) { + for (int64_t w = 0; w < W; ++w) { + // Source index: oihw + const size_t src_idx = ((o * I + i) * H + h) * W + w; + + // Destination index: hwio + const size_t dst_idx = ((h * W + w) * I + i) * O + o; + + dst[dst_idx] = src[src_idx]; + } + } + } + } +} + +// Template helper function to reorder weights from oihw to ABcd16a4b blocked format +template +void WeightLayoutTransformer::ReorderToBlockedFormat(const T* src, T* dst, + int64_t O, int64_t I, int64_t H, int64_t W, + int64_t O_blocks, int64_t I_blocks, + int64_t block_o, int64_t block_i) { + // Reorder from oihw to ABcd16a4b + // Source layout: [O][I][H][W] + // Destination layout: [O_blocks][I_blocks][H][W][block_o][block_i] + // + // Destination strides: + // - O_blocks: I_blocks * H * W * block_o * block_i + // - I_blocks: H * W * block_o * block_i + // - H: W * block_o * block_i + // - W: block_o * block_i + // - block_o: block_i + // - block_i: 1 + + for (int64_t ob = 0; ob < O_blocks; ++ob) { + for (int64_t ib = 0; ib < I_blocks; ++ib) { + for (int64_t h = 0; h < H; ++h) { + for (int64_t w = 0; w < W; ++w) { + for (int64_t o_in_block = 0; o_in_block < block_o; ++o_in_block) { + for (int64_t i_in_block = 0; i_in_block < block_i; ++i_in_block) { + const int64_t o = ob * block_o + o_in_block; + const int64_t i = ib * block_i + i_in_block; + + // Calculate destination index for ABcd16a4b layout + const size_t dst_idx = + ob * (I_blocks * H * W * block_o * block_i) + + ib * (H * W * block_o * block_i) + + h * (W * block_o * block_i) + + w * (block_o * block_i) + + o_in_block * block_i + + i_in_block; + + // Only copy if within original dimensions (handle padding) + if (o < O && i < I) { + // Source index: oihw format + const size_t src_idx = ((o * I + i) * H + h) * W + w; + dst[dst_idx] = src[src_idx]; + } + // For padding (o >= O or i >= I), dst is already zero-initialized + } + } + } + } + } + } +} + +Status WeightLayoutTransformer::TransformLayout(const Tensor& original_tensor, + const std::string& format_descriptor, + std::unique_ptr& transformed_tensor) { + const auto& orig_shape = original_tensor.Shape(); + const auto* elem_type = original_tensor.DataType(); + + // Only support 4D tensors + if (orig_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported format transformation: ", format_descriptor); + } + + // Validate tensor location (common for all formats) + if (original_tensor.Location().device.Type() != OrtDevice::CPU) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Tensor is not on CPU, device type: ", original_tensor.Location().device.Type()); + } + + const int64_t O = orig_shape[0]; + const int64_t I = orig_shape[1]; + const int64_t H = orig_shape[2]; + const int64_t W = orig_shape[3]; + + // Helper lambda to execute transformation for a specific data type + auto execute_transform = [&](auto&& transform_func, const TensorShape& new_shape, + size_t buffer_size = 0) -> Status { + const T* src = original_tensor.Data(); + if (!src) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Source tensor data pointer is null"); + } + + auto cpu_allocator = std::make_shared(); + transformed_tensor = std::make_unique(elem_type, new_shape, cpu_allocator); + T* dst = transformed_tensor->MutableData(); + + // Zero-initialize if buffer size is specified (for blocked formats with padding) + if (buffer_size > 0) { + std::memset(dst, 0, buffer_size * sizeof(T)); + } + + transform_func(src, dst); + return Status::OK(); + }; + + // Helper lambda to dispatch based on data type + auto dispatch_by_type = [&](auto&& transform_func, const TensorShape& new_shape, + size_t buffer_size = 0, const char* error_msg = "Unsupported data type") -> Status { + if (elem_type == DataTypeImpl::GetType()) { + return execute_transform.template operator()(transform_func, new_shape, buffer_size); + } else if (elem_type == DataTypeImpl::GetType()) { + return execute_transform.template operator()(transform_func, new_shape, buffer_size); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, error_msg); + } + }; + + if (format_descriptor == "hwio") { + // Transpose from oihw to hwio + // Transposed shape: [H, W, I, O] + TensorShape new_shape({H, W, I, O}); + + auto transform_func = [&](auto* src, auto* dst) { + TransposeOIHWToHWIO(src, dst, O, I, H, W); + }; + + return dispatch_by_type(transform_func, new_shape, 0, "Unsupported data type for hwio transpose"); + } else if (format_descriptor == "ABcd16a4b") { + // Reorder from oihw to blocked format + constexpr int64_t block_o = 16; + constexpr int64_t block_i = 4; + + const int64_t O_padded = ((O + block_o - 1) / block_o) * block_o; + const int64_t I_padded = ((I + block_i - 1) / block_i) * block_i; + const int64_t O_blocks = O_padded / block_o; + const int64_t I_blocks = I_padded / block_i; + + // Keep 4D shape for kernel compatibility, but data is in blocked format + // Shape: [O_padded, I_padded, H, W] with data internally blocked as ABcd16a4b + TensorShape new_shape({O_padded, I_padded, H, W}); + const size_t buffer_size = O_blocks * I_blocks * H * W * block_o * block_i; + + auto transform_func = [&](auto* src, auto* dst) { + ReorderToBlockedFormat(src, dst, O, I, H, W, O_blocks, I_blocks, block_o, block_i); + }; + + return dispatch_by_type(transform_func, new_shape, buffer_size, "Unsupported data type for blocked format"); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported format transformation: ", format_descriptor); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/weight_layout_transformer.h b/onnxruntime/core/providers/webgpu/weight_layout_transformer.h new file mode 100644 index 0000000000000..4409a238dbace --- /dev/null +++ b/onnxruntime/core/providers/webgpu/weight_layout_transformer.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/status.h" +#include "core/framework/tensor.h" +#include +#include + +namespace onnxruntime { +namespace webgpu { + +// Weight layout transformer for optimized weight formats +// Handles transformations like oihw->hwio transpose and blocked formats +class WeightLayoutTransformer { + public: + // Transform a tensor to a different layout format + // format_descriptor: Format string (e.g., "hwio", "ABcd16a4b") + // Returns Status::OK() on success, error status otherwise + static Status TransformLayout(const Tensor& original_tensor, + const std::string& format_descriptor, + std::unique_ptr& transformed_tensor); + + private: + // Transpose weights from oihw to hwio layout + template + static void TransposeOIHWToHWIO(const T* src, T* dst, + int64_t O, int64_t I, int64_t H, int64_t W); + + // Reorder weights from oihw to ABcd16a4b blocked format + template + static void ReorderToBlockedFormat(const T* src, T* dst, + int64_t O, int64_t I, int64_t H, int64_t W, + int64_t O_blocks, int64_t I_blocks, + int64_t block_o, int64_t block_i); +}; + +} // namespace webgpu +} // namespace onnxruntime