Add EP-specific weight layout transformation framework #26554
jchen10 wants to merge 1 commit into microsoft:main
Conversation
# Weight Layout Transformation Design

## Overview

Design Goal: Support EP-specific blocked and optimized weight layouts that match hardware characteristics, enabling significant performance gains without runtime overhead.

Current Implementation:
## Motivation

### The Challenge: Hardware-Optimized Memory Layouts

Modern compute accelerators (GPUs, NPUs, specialized AI chips) achieve peak performance when data is laid out in memory patterns that match their hardware architecture. However, ONNX models use standardized layouts (e.g., OIHW for Conv weights) that may not be optimal for specific hardware.

The Problem:
### The Solution: EP-Specific Weight Layout Transformation

This framework allows each EP to transform weights once during session initialization to hardware-optimized layouts, then use them directly at inference time.

## Architecture

### Core Design Principle

Separation of Concerns: The framework separates EP-specific format decisions from the core transformation infrastructure:
This design allows each EP to implement custom transformations without modifying core ONNX Runtime code.

### Two-Phase EP API

Execution providers implement two virtual methods to participate in weight transformation:

```cpp
class IExecutionProvider {
 public:
  // Phase 1: Query (lightweight) - "Do you want this weight transformed?"
  // Called during graph optimization/partitioning.
  virtual Status GetPreferredInitializerFormat(
      const Node& node,                       // which node needs the initializer
      int input_index,                        // which input (e.g., Conv weight is index 1)
      std::string& format_descriptor) const;  // OUT: format name (e.g., "hwio")

  // Phase 2: Transform (heavyweight) - "Transform this weight to the requested format."
  // Called once during session initialization per unique (initializer, format) pair.
  virtual Status TransformInitializerFormat(
      const Tensor& original_tensor,          // input: original weight
      const std::string& format_descriptor,   // which format to transform to
      std::unique_ptr<Tensor>& transformed_tensor) const;  // OUT: transformed result
};
```

Design Benefits:
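To make the two-phase API concrete, here is a minimal sketch of the query phase for a hypothetical EP. `MyGpuEp`, `SimpleNode`, and the simplified `Status` are stand-ins invented for illustration; the real interfaces live on ONNX Runtime's `IExecutionProvider`, `Node`, and `Status` types.

```cpp
#include <string>

// Simplified stand-ins for ONNX Runtime types (hypothetical, illustration only).
struct Status {
  static Status OK() { return {}; }
};

struct SimpleNode {
  std::string op_type;  // e.g. "Conv"
};

// Sketch of the query phase: a hypothetical GPU EP that wants Conv weights
// (input index 1) delivered in HWIO layout and leaves everything else alone.
struct MyGpuEp {
  Status GetPreferredInitializerFormat(const SimpleNode& node,
                                       int input_index,
                                       std::string& format_descriptor) const {
    if (node.op_type == "Conv" && input_index == 1) {
      format_descriptor = "hwio";  // ask the framework for a transposed copy
    } else {
      format_descriptor.clear();   // empty descriptor == keep original layout
    }
    return Status::OK();
  }
};
```

Because the query phase is cheap and side-effect free, the framework can call it during partitioning without committing any memory; only descriptors that come back non-empty trigger the heavyweight transform phase later.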
### Format Descriptor String

Formats are identified by string descriptors that encode transformation details:

Benefits:
Block Size Encoding: The "ABcd16a4b" notation from OneDNN encodes:
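As a sketch of what the blocked layout implies for addressing: in the OneDNN tag, uppercase `A`/`B` mark the blocked outer dimensions (O and I for a Conv weight), lowercase `c`/`d` are the plain spatial dimensions, and `16a`/`4b` are the inner block sizes. The helper below computes the element offset under that interpretation; it is illustrative, not ONNX Runtime code, and assumes O is divisible by 16 and I by 4 (real implementations pad to the block size).

```cpp
#include <cstddef>

// Element offset within an ABcd16a4b-blocked 4-D weight of shape (O, I, H, W).
// Physical order: [O/16][I/4][H][W][16][4] -- each 16x4 tile is contiguous, so
// a kernel consuming 16 output channels x 4 input channels reads one dense tile.
// Assumes O % 16 == 0 and I % 4 == 0 (real code pads to the block size).
std::size_t BlockedOffset(std::size_t o, std::size_t i,
                          std::size_t h, std::size_t w,
                          std::size_t O, std::size_t I,
                          std::size_t H, std::size_t W) {
  const std::size_t i_blocks = I / 4;
  const std::size_t tile = ((o / 16) * i_blocks + i / 4) * H * W + h * W + w;
  return tile * (16 * 4) + (o % 16) * 4 + (i % 4);
}
```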
### Complete Data Flow

Key Points:
## Key Design Decisions and Rationale

### 1. Why EP-Specific Instead of Global Transformations?

Decision: Let each EP decide its own transformations via virtual methods.

Rationale:
Example: WebGPU needs HWIO for channels-last Conv, but another EP might prefer a different layout or no transformation at all.

### 2. Why CPU-Based Transformation?

Decision: Transform weights on CPU during session init, before device loading.

Rationale:
Trade-off: Slightly slower session initialization for much faster inference.

### 3. Why Tensor Member Instead of External Map?

Decision: Store the transformed tensor as a member of the original Tensor.

Rationale:
Alternative rejected: An external map would require synchronization and lifetime management.

### 4. Why Both Original and Transformed Initializers?

Decision: Keep both versions in the graph when nodes have different needs.

Rationale:
Memory trade-off: Small duplication for a large inference speedup.
The perf improvement of this PR on LNL comes from eliminating the kernel transpose on every inference run in Conv:
I assume it helps the SD models so much simply because of the regression that #26501 is trying to fix. @xhcao @JianhuiD @Jiawei-Shao PTAL
This infrastructure enables execution providers to optimize operator weights with custom memory layouts (such as blocked formats) during session initialization, dramatically improving inference performance through better cache utilization and memory access patterns.

Current Implementation:

- HWIO Transpose (WebGPU EP): Transposes Conv weights from OIHW to HWIO layout as the first application of the framework.
- ABcd16a4b Blocking (Proof-of-Concept): OneDNN-style blocked format with 16×4 tiles demonstrates the framework's primary purpose.

The framework is generic and extensible, allowing any EP to implement custom weight transformations optimized for their target hardware.
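The OIHW-to-HWIO reorder mentioned above might look like the sketch below. This is not the PR's actual implementation, just a self-contained illustration of the one-time transform that replaces the per-inference transpose; `TransposeOIHWToHWIO` is a hypothetical helper name.

```cpp
#include <cstddef>
#include <vector>

// Reorder a Conv weight from ONNX's standard OIHW layout to HWIO
// (channels-last friendly). Done once at session initialization, so
// inference-time kernels can read the weight directly.
std::vector<float> TransposeOIHWToHWIO(const std::vector<float>& src,
                                       std::size_t O, std::size_t I,
                                       std::size_t H, std::size_t W) {
  std::vector<float> dst(src.size());
  for (std::size_t o = 0; o < O; ++o)
    for (std::size_t i = 0; i < I; ++i)
      for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w)
          // dst index: ((h, w), i, o) fastest-to-slowest reversed
          dst[((h * W + w) * I + i) * O + o] =
              src[((o * I + i) * H + h) * W + w];
  return dst;
}
```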
General question: does this transformation change the shape of the weight? Also, have you considered using the PrePack mechanism?
@yuslepukhin Thanks for your attention.
First, I think this is a regression for the WebGPU EP. JSEP has code to do the kernel transpose only once if the weight is an initializer, but the same logic is missing in the WebGPU EP. Then the
So the summary:
Thank you @fs-eire for the insightful comments.
Currently, graph capture is recorded when the regular run count > min_num_runs_before_graph_capture_ (1). So it won't impact correctness if the first run is different from the second run. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc#L1041-L1043.
There are many concerns with this PR.
I have not looked in depth yet, so I do not have a good proposal either. One clue may be taken from compiling EPs that perform weight transformations internally and retain them while removing the reference to the original weight from the model.
fs-eire left a comment:
According to the discussion, let's redo this optimization:

- use the JSEP approach and use a new PR.
- do we already have a model in which multiple `Conv` nodes share the same kernel?
  - if no, then let's start with a simpler version that uses a `map<input_shape, tensor>` inside the Conv node for the weight cache.
  - if yes, then we need to think about how to put a `map<initializer_tensor, map<runtime_data, tensor>>` inside the EP object for the weight cache.
- only cache when the weight is an initializer, and stop when the cache reaches a size limit (for example, you may want to cache at most N different input shapes, otherwise you may run into OOM. This is unlikely to happen with real model usage; I just mention it for information).
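The simpler `map<input_shape, tensor>` variant suggested above might be sketched as follows. All names here (`ConvWeightCache`, `kMaxEntries`, the `Tensor` alias) are hypothetical stand-ins for whatever the new PR actually uses; the point is the shape-keyed lookup with a hard cap on entries.

```cpp
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

// Hypothetical per-node weight cache along the lines suggested above:
// keyed by input shape, capped at kMaxEntries to bound memory use.
class ConvWeightCache {
 public:
  using Shape = std::vector<int64_t>;
  using Tensor = std::vector<float>;  // simplified stand-in for a real tensor

  static constexpr std::size_t kMaxEntries = 4;

  // Returns the cached transformed weight for `shape`, invoking `transform`
  // only on a cache miss; returns nullptr once the cache is full, in which
  // case the caller falls back to transforming on every run.
  template <typename TransformFn>
  std::shared_ptr<Tensor> GetOrCreate(const Shape& shape, TransformFn transform) {
    auto it = cache_.find(shape);
    if (it != cache_.end()) return it->second;
    if (cache_.size() >= kMaxEntries) return nullptr;
    auto tensor = std::make_shared<Tensor>(transform());
    cache_.emplace(shape, tensor);
    return tensor;
  }

 private:
  std::map<Shape, std::shared_ptr<Tensor>> cache_;
};
```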
Thank you all for the comments. I will work on a new PR as you proposed.