diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index baf21745e0e40..0b37ade63a03c 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -99,7 +99,7 @@ option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
-cmake_dependent_option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" ON "onnxruntime_USE_CUDA" OFF)
+option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index cbfc38068ac2a..42eedd5c2feb2 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3121,13 +3121,13 @@ This version of the operator has been available since version 1 of the 'com.micr
- input : T
-- 2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+- 2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
- router_probs : T
-- 2D input tensor with shape (num_rows, num_experts)
+- 2D input tensor with shape (num_tokens, num_experts)
- fc1_experts_weights : T
-- 3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu
+- 3D input tensor with shape (num_experts, fusion_size * inter_size, hidden_size), where fusion_size is 2 for fused swiglu, and 1 otherwise
- fc1_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
+- 2D optional input tensor with shape (num_experts, fusion_size * inter_size)
- fc2_experts_weights : T
- 3D input tensor with shape (num_experts, hidden_size, inter_size)
- fc2_experts_bias (optional) : T
@@ -3142,7 +3142,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- output : T
-- 2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+- 2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
#### Type Constraints
@@ -4532,7 +4532,23 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.QMoE**
- Quantized MoE
+ Quantized mixture of experts (MoE).
+
+ Only the weights are quantized, using symmetric quantization.
+ The quantized weights are stored in column-major order per expert.
+ The quantization block size can be specified. If it is not provided, column-wise quantization is used.
+
+ The SwiGLU (Swish-Gated Linear Unit) activation function is computed as:
+ g = xW + b
+ l = xV + c
+ G = clamp(g, max=limit)
+ L = clamp(l, min=-limit, max=limit)
+ swiglu = G * sigmoid(alpha * G) * (L + beta)
+ where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+ When swiglu_fusion=0, the two GEMMs are not fused; they are provided as FC1 and FC3 in the inputs.
+ When swiglu_fusion=1, the two GEMMs are fused so that g and l are computed by a single GEMM (FC1), with g and l interleaved within each row of size 2 * inter_size.
+ When swiglu_fusion=2, the two GEMMs are fused, and g and l are concatenated within each row.
+
#### Version
@@ -4547,6 +4563,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Beta parameter used in activation function.
- activation_type : string
- Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
+- block_size : int
+- Size of each quantization block along the K (input feature) dimension. Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128). If provided, both hidden_size and inter_size must be divisible by the block size. Otherwise, there is no blocking and a whole column shares one scaling factor.
- expert_weight_bits : int
- Number of bits used in quantized weights. Default is 4 bits
- k : int
@@ -4565,34 +4583,34 @@ This version of the operator has been available since version 1 of the 'com.micr
- input : T
-- 2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+- 2D tensor with shape (num_tokens, hidden_size), or 3D tensor with shape (batch_size, sequence_length, hidden_size)
- router_probs : T
-- 2D input tensor with shape (num_rows, num_experts)
+- 2D tensor with shape (num_tokens, num_experts)
- fc1_experts_weights : T1
-- 3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, inter_size, hidden_size / 2) for 4 bits. For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.
+- 3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / pack_size). The fusion_size is 2 for fused swiglu, and 1 otherwise. The pack_size is 8 / expert_weight_bits.
- fc1_scales : T2
-- 2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
+- 2D tensor with shape (num_experts, fusion_size * inter_size), or 3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / block_size) when block_size is provided.
- fc1_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
+- 2D optional tensor with shape (num_experts, fusion_size * inter_size)
- fc2_experts_weights : T1
-- 3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2) for 4 bits
+- 3D tensor with shape (num_experts, hidden_size, inter_size / pack_size)
- fc2_scales : T2
-- 2D input tensor with shape (num_experts, hidden_size)
+- 2D tensor with shape (num_experts, hidden_size), or 3D tensor with shape (num_experts, hidden_size, inter_size / block_size) when block_size is provided.
- fc2_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, hidden_size)
+- 2D optional tensor with shape (num_experts, hidden_size)
- fc3_experts_weights (optional) : T1
-- 3D optional input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
+- 3D optional tensor with shape (num_experts, inter_size, hidden_size / pack_size)
- fc3_scales (optional) : T2
-- 2D optional input tensor with shape (num_experts, inter_size)
+- 2D optional tensor with shape (num_experts, inter_size), or 3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size) when block_size is provided.
- fc3_experts_bias (optional) : T
-- 2D optional input tensor with shape (num_experts, inter_size)
+- 2D optional tensor with shape (num_experts, inter_size)
#### Outputs
- output : T
-- 2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)
+- Output tensor with the same shape as the input
#### Type Constraints
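
Note: the following minimal NumPy sketch illustrates the SwiGLU activation and the fused layouts documented in the QMoE/MoE sections above. It is illustrative only (function name, test shapes, and the g/l ordering within a row are assumptions); defaults follow the schema (alpha=1.0, beta=0.0, no limit).

```python
import numpy as np

def swiglu_from_fc1(fc1_out, inter_size, swiglu_fusion, alpha=1.0, beta=0.0, limit=None):
    """Apply the SwiGLU activation to a fused FC1 output of width 2 * inter_size."""
    if swiglu_fusion == 1:    # fused and interleaved (assumed order: g at even, l at odd columns)
        g, l = fc1_out[..., 0::2], fc1_out[..., 1::2]
    elif swiglu_fusion == 2:  # fused and concatenated (assumed order: g first, then l)
        g, l = fc1_out[..., :inter_size], fc1_out[..., inter_size:]
    else:                     # swiglu_fusion == 0: g and l come from separate FC1 and FC3 GEMMs
        raise ValueError("pass the FC1 and FC3 outputs separately when swiglu_fusion=0")
    if limit is not None:
        g = np.minimum(g, limit)       # G = clamp(g, max=limit)
        l = np.clip(l, -limit, limit)  # L = clamp(l, min=-limit, max=limit)
    return g * (1.0 / (1.0 + np.exp(-alpha * g))) * (l + beta)  # G * sigmoid(alpha * G) * (L + beta)

fc1_out = np.random.randn(4, 2 * 64).astype(np.float32)      # (num_tokens, 2 * inter_size)
print(swiglu_from_fc1(fc1_out, 64, swiglu_fusion=1).shape)    # (4, 64)
```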
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
index 4f901a550e8bf..588f37051b534 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl
@@ -60,7 +60,7 @@ namespace cutlass_kernels {
template
-#ifdef COMPILE_HOPPER_TMA_GEMMS
+#if defined(COMPILE_HOPPER_TMA_GEMMS) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900) && defined(__NV_SASS_VERSION__)
void sm90_generic_mixed_gemm_kernelLauncher(
ActivationType const* A, WeightType const* B,
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
@@ -269,6 +269,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(
}
}
#else // COMPILE_HOPPER_TMA_GEMMS
+// This stub is now used for all non-SASS or non-SM90A compilation passes, including the 90-virtual (PTX) pass.
void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const*, WeightType const*,
ScaleZeroType const*, ScaleZeroType const*, BiasType const*,
float const, OutputType*, int, int, int, int const, tkc::CutlassGemmConfig,
diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc
index 925a6913a2890..e5b15856a6c05 100644
--- a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc
+++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc
@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+#if USE_FPA_INTB_GEMM
#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h"
#include "contrib_ops/cuda/llm/common/workspace.h"
@@ -97,3 +98,4 @@ bool WeightOnlyGroupwiseQuantGemmPluginProfiler::checkTactic(int m, int /*n*/, i
}
} // namespace onnxruntime::llm::kernels::weight_only
+#endif
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index b48fe8c1e1839..c973a281e6c3f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1412,22 +1412,41 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1))
.Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0))
.Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0))
- .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
- .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
- .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size), or (num_experts, 2 * inter_size, hidden_size) for swiglu", "T")
- .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
+ .Input(0, "input", "2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+ .Input(1, "router_probs", "2D input tensor with shape (num_tokens, num_experts)", "T")
+ .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, fusion_size * inter_size, hidden_size), where fusion_size is 2 for fused swiglu, and 1 otherwise", "T")
+ .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, fusion_size * inter_size)", "T", OpSchema::Optional)
.Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T")
.Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional)
.Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, inter_size, hidden_size)", "T", OpSchema::Optional)
.Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional)
- .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
+ .Output(0, "output", "2D input tensor with shape (num_tokens, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+constexpr const char* qMoE_ver1_doc = R"DOC(
+ Quantized mixture of experts (MoE).
+
+ Only the weights are quantized, using symmetric quantization.
+ The quantized weights are stored in column-major order per expert.
+ The quantization block size can be specified. If it is not provided, column-wise quantization is used.
+
+ The SwiGLU (Swish-Gated Linear Unit) activation function is computed as:
+ g = xW + b
+ l = xV + c
+ G = clamp(g, max=limit)
+ L = clamp(l, min=-limit, max=limit)
+ swiglu = G * sigmoid(alpha * G) * (L + beta)
+ where x is the input, W and V are weight matrices, b and c are bias vectors, and alpha, beta and limit are constant float parameters.
+ When swiglu_fusion=0, the two GEMMs are not fused; they are provided as FC1 and FC3 in the inputs.
+ When swiglu_fusion=1, the two GEMMs are fused so that g and l are computed by a single GEMM (FC1), with g and l interleaved within each row of size 2 * inter_size.
+ When swiglu_fusion=2, the two GEMMs are fused, and g and l are concatenated within each row.
+ )DOC";
+
ONNX_MS_OPERATOR_SET_SCHEMA(
QMoE, 1,
OpSchema()
- .SetDoc("Quantized MoE")
+ .SetDoc(qMoE_ver1_doc)
.Attr("activation_type",
"Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu",
AttributeProto::STRING,
@@ -1440,63 +1459,90 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to normalize routing weights",
AttributeProto::INT,
            static_cast<int64_t>(0))
- .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0))
+ .Attr("use_sparse_mixer",
+ "Whether to use sparse mixer",
+ AttributeProto::INT,
+          static_cast<int64_t>(0))
.Attr("expert_weight_bits",
"Number of bits used in quantized weights. Default is 4 bits",
AttributeProto::INT,
            static_cast<int64_t>(4))
- .Attr("swiglu_fusion", "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.", AttributeProto::INT, static_cast(0))
- .Attr("swiglu_limit", "The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.", AttributeProto::FLOAT, OPTIONAL_VALUE)
- .Attr("activation_alpha", "Alpha parameter used in activation function.", AttributeProto::FLOAT, 1.0f)
- .Attr("activation_beta", "Beta parameter used in activation function.", AttributeProto::FLOAT, 0.0f)
+ .Attr("swiglu_fusion",
+ "0: not fused, 1: fused and interleaved. 2: fused and not interleaved.",
+ AttributeProto::INT,
+          static_cast<int64_t>(0))
+ .Attr("swiglu_limit",
+ "The limit used to clamp inputs in SwiGLU. It is infinite when limit is not provided.",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Attr("activation_alpha",
+ "Alpha parameter used in activation function.",
+ AttributeProto::FLOAT, 1.0f)
+ .Attr("activation_beta",
+ "Beta parameter used in activation function.",
+ AttributeProto::FLOAT, 0.0f)
+ .Attr("block_size",
+ "Size of each quantization block along the K (input feature) dimension. "
+          "Must be a power of two and ≥ 16 (e.g., 16, 32, 64, 128). "
+ "If provided, both hidden_size and inter_size must be divisible by the block size. "
+ "Otherwise, there is no blocking and a whole column shares one scaling factor. ",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Input(0,
"input",
- "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
- "(batch_size, sequence_length, hidden_size)",
+ "2D tensor with shape (num_tokens, hidden_size), or "
+ "3D tensor with shape (batch_size, sequence_length, hidden_size)",
+ "T")
+ .Input(1,
+ "router_probs",
+ "2D tensor with shape (num_tokens, num_experts)",
"T")
- .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T")
.Input(2,
"fc1_experts_weights",
- "3D input tensor with shape (num_experts, inter_size, hidden_size), "
- "or (num_experts, inter_size, hidden_size / 2) for 4 bits. "
- "For swiglu, shape can be (num_experts, 2 * inter_size, hidden_size), "
- "or (num_experts, 2 * inter_size, hidden_size / 2) for 4 bits.",
+             "3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / pack_size). "
+             "The fusion_size is 2 for fused swiglu, and 1 otherwise. The pack_size is 8 / expert_weight_bits.",
"T1")
- .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2")
+ .Input(3,
+ "fc1_scales",
+ "2D tensor with shape (num_experts, fusion_size * inter_size), or "
+ "3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / block_size) when block_size is provided.",
+ "T2")
.Input(4,
"fc1_experts_bias",
- "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional)
+ "2D optional tensor with shape (num_experts, fusion_size * inter_size)", "T", OpSchema::Optional)
.Input(5,
"fc2_experts_weights",
- "3D input tensor with shape (num_experts, hidden_size, inter_size) "
- "or (num_experts, hidden_size, inter_size / 2) for 4 bits",
+ "3D tensor with shape (num_experts, hidden_size, inter_size / pack_size)",
"T1")
- .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2")
+ .Input(6,
+ "fc2_scales",
+ "2D tensor with shape (num_experts, hidden_size), or "
+ "3D tensor with shape (num_experts, hidden_size, inter_size / block_size) when block_size is provided.",
+ "T2")
.Input(7,
"fc2_experts_bias",
- "2D optional input tensor with shape (num_experts, hidden_size)",
+ "2D optional tensor with shape (num_experts, hidden_size)",
"T",
OpSchema::Optional)
.Input(8,
"fc3_experts_weights",
- "3D optional input tensor with shape (num_experts, inter_size, hidden_size) "
- "or (num_experts, inter_size, hidden_size / 2)",
+ "3D optional tensor with shape (num_experts, inter_size, hidden_size / pack_size)",
"T1",
OpSchema::Optional)
.Input(9,
"fc3_scales",
- "2D optional input tensor with shape (num_experts, inter_size)",
+ "2D optional tensor with shape (num_experts, inter_size), or "
+ "3D optional tensor with shape (num_experts, inter_size, hidden_size / block_size) when block_size is provided.",
"T2",
OpSchema::Optional)
.Input(10,
"fc3_experts_bias",
- "2D optional input tensor with shape (num_experts, inter_size)",
+ "2D optional tensor with shape (num_experts, inter_size)",
"T",
OpSchema::Optional)
.Output(0,
"output",
- "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape "
- "(batch_size, sequence_length, hidden_size)",
+             "Output tensor with the same shape as the input",
"T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.")
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 9a97711996343..a9506f2fa1b35 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -2678,6 +2678,27 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
// only return data if it's for a constant initializer. checks for outer scope initializers
// if this is a subgraph and the name isn't found locally.
const TensorProto* initializer = graph_.GetConstantInitializer(def->Name(), true);
+ if (initializer != nullptr) {
+ // Check if this is in-memory external data (data stored in OrtValue)
+ // ONNX shape inference cannot handle external data, so we need to materialize it
+ if (utils::HasExternalDataInMemory(*initializer)) {
+ // Try to get the OrtValue for this initializer
+ OrtValue ort_value;
+ if (graph_.GetOrtValueInitializer(def->Name(), ort_value, true)) {
+ // Create a temporary TensorProto with the actual data from the OrtValue
+ // This allows ONNX shape inference to access the data
+          const Tensor& tensor = ort_value.Get<Tensor>();
+ auto temp_tensor_proto = utils::TensorToTensorProto(tensor, initializer->name(), /*use_tensor_buffer=*/false);
+          // Store the temporary proto so that it outlives this call and its address stays stable.
+          temp_tensor_protos_.push_back(std::make_unique<ONNX_NAMESPACE::TensorProto>(std::move(temp_tensor_proto)));
+ return temp_tensor_protos_.back().get();
+ } else {
+ // If we can't get the OrtValue, it is a bug
+ ORT_THROW("Initializer ", def->Name(),
+ " has in-memory external data but cannot get OrtValue during shape inference");
+ }
+ }
+ }
return initializer;
}
@@ -2717,6 +2738,11 @@ class InferenceContextImpl : public ONNX_NAMESPACE::InferenceContext {
std::vector> graph_inferencers_;
const Graph& graph_;
const Graph::ResolveOptions& options_;
+ // Temporary TensorProtos created for in-memory external data during shape inference
+ // These need to outlive the shape inference call, so we store them here
+ // Inference is per node and the instance of this context is on the stack,
+ // so this is safe.
+  mutable InlinedVector<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> temp_tensor_protos_;
};
Status Graph::InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph,
diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp
index bc1221475fd90..9518134631f2d 100644
--- a/onnxruntime/core/mlas/lib/convolve.cpp
+++ b/onnxruntime/core/mlas/lib/convolve.cpp
@@ -729,6 +729,82 @@ Return Value:
}
}
+void
+MlasConvExpandThenGemmSegmentedThreaded(
+ void* Context,
+ ptrdiff_t Index
+)
+/*++
+
+Routine Description:
+
+ This routine is invoked from a worker thread to execute a segment of a
+ convolution operation.
+    When this algorithm is used, the entire convolution is parallelized across the
+    (batch size * group count) dimension, and this routine computes a single
+    thread's shard of the overall convolution operation.
+ perform a specific thread's shard of the entire Convolution operation.
+
+Arguments:
+
+ Context - Supplies the pointer to the context for the threaded operation.
+
+ Index - Supplies the current index of the threaded operation.
+
+Return Value:
+
+ None.
+
+--*/
+
+{
+ MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
+
+ const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
+
+ const size_t GroupCount = Parameters->GroupCount;
+ const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
+
+ const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
+
+ const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
+ const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
+
+ size_t BatchGroupStart;
+ size_t BatchGroupEnd;
+
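+    // Threads with Index < BatchGroupCountExtra each take one extra (batch, group) pair so the remainder is spread evenly.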
+    if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
+ BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
+ BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
+ } else {
+ BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
+ BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
+ }
+
+ const size_t FilterCount = Parameters->FilterCount;
+ const size_t OutputSize = Parameters->OutputSize;
+ const size_t K = Parameters->K;
+
+ const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
+ const size_t OutputGroupSize = FilterCount * OutputSize;
+ const size_t FilterGroupSize = FilterCount * K;
+
+ for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
+ size_t group = bg % GroupCount;
+
+ const float* input = WorkBlock->Input + bg * InputGroupSize;
+ const float* filter = WorkBlock->Filter + group * FilterGroupSize;
+ float* output = WorkBlock->Output + bg * OutputGroupSize;
+ const float* bias = WorkBlock->Bias;
+ if (bias != nullptr) {
+ bias += group * FilterCount;
+ }
+ float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K;
+
+ MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize);
+ }
+}
+
inline
bool
MlasConvTryMultithread(
@@ -890,8 +966,8 @@ Return Value:
ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
- if (size_t(TargetThreadCount) >= BatchGroupCount) {
- TargetThreadCount = ptrdiff_t(BatchGroupCount);
+    if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+        TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
}
MLAS_CONV_WORK_BLOCK WorkBlock;
@@ -919,6 +995,30 @@ Return Value:
#endif
+ if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) {
+ const size_t BatchGroupCount = BatchCount * GroupCount;
+
+ ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+        if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
+ }
+
+ MLAS_CONV_WORK_BLOCK WorkBlock;
+
+ WorkBlock.Parameters = Parameters;
+ WorkBlock.Input = Input;
+ WorkBlock.Filter = Filter;
+ WorkBlock.Bias = Bias;
+ WorkBlock.WorkingBuffer = WorkingBuffer;
+ WorkBlock.Output = Output;
+ WorkBlock.TargetThreadCount = TargetThreadCount;
+
+ MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
+
+ return;
+ }
+
//
// Iterate over each batch and group.
//
@@ -1308,6 +1408,18 @@ Return Value:
Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN;
*WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
+
+ if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {
+
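+    // For the batch/group-parallel path, each thread needs its own working buffer, so size it by the
+    // largest per-thread requirement instead of the default segmented buffer size.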
+ size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
+ Parameters->FilterCount * Parameters->OutputSize,
+                                                 static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
+ TargetThreadCount = MaximumThreadCount;
+    if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
+        TargetThreadCount = static_cast<ptrdiff_t>(Parameters->BatchCount * Parameters->GroupCount);
+ }
+ *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread;
+ }
}
}
#if defined(_MSC_VER) && !defined(__clang__)
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 4a6545a0e6f0a..2bdbfb9c1c62e 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -1472,7 +1472,7 @@ Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_optio
}
uint32_t rpc_polling_time = 0;
- if (qnn::HtpPerformanceMode::kHtpBurst != htp_performance_mode) {
+ if (qnn::HtpPerformanceMode::kHtpBurst == htp_performance_mode) {
rpc_polling_time = 9999;
}
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 508d932459bf9..cd0c0e4bffdb5 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3976,6 +3976,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
// Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
trt_state->context->reset();
trt_state->engine->reset();
+
+ // Clear dds output allocator map since the engine and context will be recreated.
+ dds_output_allocator_map.clear();
+
      auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
if (max_workspace_size_ > 0) {
trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);
diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc
index 14f12c906f11a..5fc0b8900730b 100644
--- a/onnxruntime/core/providers/vitisai/imp/global_api.cc
+++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc
@@ -580,6 +580,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
graph.RemoveInitializedTensor(tensor_name);
};
the_global_api.graph_reverse_dfs_from_preemp = vaip::graph_reverse_dfs_from;
+ the_global_api.graph_save_string = vaip::graph_save_string;
+
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast(&(the_global_api.host_));
} else {
diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc
index c6bf29dafa184..028ee7fa8c5ce 100644
--- a/onnxruntime/core/providers/vitisai/imp/graph.cc
+++ b/onnxruntime/core/providers/vitisai/imp/graph.cc
@@ -205,6 +205,29 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
vai_assert(result, "model serialize to ostream error");
}
+vaip_core::DllSafe<std::string> graph_save_string(const Graph& graph) {
+  auto model_proto = const_cast<onnxruntime::Model&>(graph.GetModel()).ToProto();
+ auto graph_proto_subgraph = graph.ToGraphProto();
+ *model_proto->mutable_graph() = *graph_proto_subgraph;
+ auto& logger = logging::LoggingManager::DefaultLogger();
+ auto model = Model::Create(std::move(*model_proto), graph.ModelPath(), nullptr, logger);
+ model_proto = model->ToProto();
+ auto& metadata = model->MetaData();
+ if (!metadata.empty()) {
+ auto metadata_props = model_proto->mutable_metadata_props();
+ metadata_props->Clear();
+ for (auto& m : metadata) {
+ auto prop = metadata_props->Add();
+ *prop->mutable_key() = m.first;
+ *prop->mutable_value() = m.second;
+ }
+ }
+ std::string graph_string;
+  bool result = model_proto->SerializeToString(&graph_string);
+ vai_assert(result, "model serialize to string error");
+ return vaip_core::DllSafe(graph_string);
+}
+
Node& graph_fuse(Graph& graph, const std::string& name,
const std::string& op_type,
const std::vector& nodes,
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h
index bd8d0229d627c..440b8295da658 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h
@@ -14,6 +14,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o
const NodeAttributes& attributes, const std::string& domain);
void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename,
size_t initializer_size_threshold);
+vaip_core::DllSafe<std::string> graph_save_string(const Graph& graph);
Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes,
const std::vector& inputs, const std::vector& outputs,
const std::vector& constant_initializers);
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
index 63949116507e4..acb258894e11c 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
@@ -13,7 +13,7 @@ struct OrtApi;
namespace vaip_core {
-#define VAIP_ORT_API_MAJOR (17u)
+#define VAIP_ORT_API_MAJOR (18u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
@@ -252,10 +252,11 @@ struct OrtApiForVaip {
stop); // [103]
void (*graph_set_name)(Graph& graph, const char* name); // [104]
void (*graph_infer_shapes_from_filepath)(
- const std::string& m, const std::string& save_path); // [105]
- GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106]
- void (*graph_proto_delete)(GraphProto* p); // [107]
- void (*graph_infer_shapes)(ModelProto& m); // [108]
+ const std::string& m, const std::string& save_path); // [105]
+ GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106]
+ void (*graph_proto_delete)(GraphProto* p); // [107]
+ void (*graph_infer_shapes)(ModelProto& m); // [108]
+  DllSafe<std::string> (*graph_save_string)(const Graph& graph);             // [109]
};
#ifndef USE_VITISAI
diff --git a/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc
index cac6d46226ef8..1652d16f5cb66 100644
--- a/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc
+++ b/onnxruntime/test/contrib_ops/cuda_kernels/fpA_intB_gemm_kernel_test.cc
@@ -2,8 +2,8 @@
// Licensed under the MIT License.
// Test can be run like the following:
-// ./onnxruntime_test_all --gtest_filter=CUDA_EP_Unittest.*
-
+// ./onnxruntime_provider_test --gtest_filter=CUDA_EP_Unittest.*
+#if USE_FPA_INTB_GEMM
#include
#include
#include
@@ -620,3 +620,4 @@ TEST_F(Bf16Int4GroupwiseTest, BF16_Int4_Gemm_CudaKernel) {
}
}
}
+#endif
diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc
index ca1166e19037c..027a2f0019386 100644
--- a/onnxruntime/test/ir/graph_test.cc
+++ b/onnxruntime/test/ir/graph_test.cc
@@ -2,13 +2,17 @@
// Licensed under the MIT License.
#include
+#include
#include "core/common/inlined_containers.h"
#include "core/common/span_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/graph/op.h"
+#include "core/session/inference_session.h"
+#include "core/session/environment.h"
#include "test/providers/provider_test_utils.h"
+#include "test/test_environment.h"
#include "gtest/gtest.h"
#include "gmock/gmock.h"
#include "onnx/defs/function.h"
@@ -2573,5 +2577,259 @@ TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_SubgraphGener
#endif
+// Test for shape inference with in-memory external data (issue #26261)
+// This tests the fix for a regression where Constant nodes with large tensors (>127 bytes)
+// stored as in-memory external data would cause shape inference to fail
+TEST_F(GraphTest, ShapeInferenceWithInMemoryExternalData) {
+ // Create a model with a Constant node that produces a tensor larger than kSmallTensorExternalDataThreshold (127 bytes)
+ // This will trigger the in-memory externalization path
+ ModelProto model_proto;
+ model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+ auto* opset = model_proto.add_opset_import();
+ opset->set_version(17);
+
+ auto* graph_proto = model_proto.mutable_graph();
+ graph_proto->set_name("test_graph");
+
+ // Create a Constant node with a tensor of 16 INT64 values (128 bytes, just over the 127 threshold)
+ auto* constant_node = graph_proto->add_node();
+ constant_node->set_op_type("Constant");
+ constant_node->set_name("const_node");
+ constant_node->add_output("const_output");
+
+ // Add the value attribute with a tensor
+ auto* attr = constant_node->add_attribute();
+ attr->set_name("value");
+ attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
+ auto* tensor = attr->mutable_t();
+ tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+ tensor->add_dims(16); // 16 elements * 8 bytes = 128 bytes
+ // Each split will be size 1, totaling 16
+ for (int64_t i = 0; i < 16; ++i) {
+ tensor->add_int64_data(1);
+ }
+
+ // Create a Split node that uses the constant as input
+ // Split requires constant input for the 'split' parameter, which triggers shape inference
+ auto* split_node = graph_proto->add_node();
+ split_node->set_op_type("Split");
+ split_node->set_name("split_node");
+ split_node->add_input("input_data");
+ split_node->add_input("const_output"); // Use constant as split sizes
+ for (int i = 0; i < 16; ++i) {
+ split_node->add_output("split_output_" + std::to_string(i));
+ }
+
+ // Add axis attribute
+ auto* axis_attr = split_node->add_attribute();
+ axis_attr->set_name("axis");
+ axis_attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
+ axis_attr->set_i(0);
+
+ // Add graph input
+ auto* input = graph_proto->add_input();
+ input->set_name("input_data");
+ auto* input_type = input->mutable_type()->mutable_tensor_type();
+ input_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ input_type->mutable_shape()->add_dim()->set_dim_value(16);
+ input_type->mutable_shape()->add_dim()->set_dim_value(10);
+
+ // Add graph outputs
+ for (int i = 0; i < 16; ++i) {
+ auto* output = graph_proto->add_output();
+ output->set_name("split_output_" + std::to_string(i));
+ }
+
+ // Load the model - this should succeed with the fix
+ // Before the fix, this would fail with:
+ // "Cannot parse data from external tensors. Please load external data into raw data for tensor"
+  std::shared_ptr<Model> model;
+ ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, *logger_));
+
+ // Verify the graph was properly constructed
+ Graph& graph = model->MainGraph();
+ ASSERT_STATUS_OK(graph.Resolve());
+
+ // Verify the constant node was converted to an initializer
+ const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
+ ASSERT_TRUE(graph.GetInitializedTensor("const_output", initializer));
+ ASSERT_NE(initializer, nullptr);
+
+ // Verify the Split node can access the constant data during shape inference
+ const Node* split_node_ptr = nullptr;
+ for (const auto& node : graph.Nodes()) {
+ if (node.Name() == "split_node") {
+ split_node_ptr = &node;
+ break;
+ }
+ }
+ ASSERT_NE(split_node_ptr, nullptr);
+
+ // Verify outputs are properly shaped
+ ASSERT_EQ(split_node_ptr->OutputDefs().size(), 16u);
+}
+
+// Test for shape inference with in-memory external data using InferenceSession
+// This test more accurately reproduces the issue by going through the full session initialization
+// which includes graph optimizations that trigger the in-memory externalization
+TEST_F(GraphTest, ShapeInferenceWithInMemoryExternalDataViaSession) {
+ // Create the same model as above
+ ModelProto model_proto;
+ model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+ auto* opset = model_proto.add_opset_import();
+ opset->set_version(17);
+
+ auto* graph_proto = model_proto.mutable_graph();
+ graph_proto->set_name("test_graph");
+
+ // Create a Constant node with a tensor of 16 INT64 values (128 bytes)
+ auto* constant_node = graph_proto->add_node();
+ constant_node->set_op_type("Constant");
+ constant_node->set_name("const_node");
+ constant_node->add_output("const_output");
+
+ auto* attr = constant_node->add_attribute();
+ attr->set_name("value");
+ attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR);
+ auto* tensor = attr->mutable_t();
+ tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+ tensor->add_dims(16);
+ for (int64_t i = 0; i < 16; ++i) {
+ tensor->add_int64_data(1);
+ }
+
+ // Create a Split node
+ auto* split_node = graph_proto->add_node();
+ split_node->set_op_type("Split");
+ split_node->set_name("split_node");
+ split_node->add_input("input_data");
+ split_node->add_input("const_output");
+ for (int i = 0; i < 16; ++i) {
+ split_node->add_output("split_output_" + std::to_string(i));
+ }
+
+ auto* axis_attr = split_node->add_attribute();
+ axis_attr->set_name("axis");
+ axis_attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
+ axis_attr->set_i(0);
+
+ // Add graph input
+ auto* input = graph_proto->add_input();
+ input->set_name("input_data");
+ auto* input_type = input->mutable_type()->mutable_tensor_type();
+ input_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ input_type->mutable_shape()->add_dim()->set_dim_value(16);
+ input_type->mutable_shape()->add_dim()->set_dim_value(10);
+
+ // Add graph outputs
+ for (int i = 0; i < 16; ++i) {
+ auto* output = graph_proto->add_output();
+ output->set_name("split_output_" + std::to_string(i));
+ }
+
+ // Save to a temporary file
+ const std::string model_path = "test_in_memory_external_data.onnx";
+ {
+ std::ofstream file(model_path, std::ios::binary);
+ ASSERT_TRUE(file.is_open());
+ ASSERT_TRUE(model_proto.SerializeToOstream(&file));
+ }
+
+  // Use TransformerLevel::Default (equivalent to ORT_DISABLE_ALL), which triggers the bug without the fix
+ SessionOptions so;
+ so.graph_optimization_level = TransformerLevel::Default; // This triggers the issue
+ so.session_logid = "GraphTest.ShapeInferenceWithInMemoryExternalDataViaSession";
+
+ InferenceSession session_object{so, GetEnvironment()};
+
+ // This should succeed with the fix, fail without it
+ ASSERT_STATUS_OK(session_object.Load(model_path));
+ ASSERT_STATUS_OK(session_object.Initialize());
+
+ // Clean up
+ std::remove(model_path.c_str());
+}
+
+// Test that explicitly triggers the in-memory externalization and then shape inference
+// This test directly reproduces the bug scenario
+TEST_F(GraphTest, ShapeInferenceAfterInitializerExternalization) {
+ // Create a model with a Split node that depends on a constant initializer
+ ModelProto model_proto;
+ model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+ auto* opset = model_proto.add_opset_import();
+ opset->set_version(17);
+
+ auto* graph_proto = model_proto.mutable_graph();
+ graph_proto->set_name("test_graph");
+
+ // Create initializer directly (not as Constant node) with 128 bytes
+ auto* initializer = graph_proto->add_initializer();
+ initializer->set_name("split_sizes");
+ initializer->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+ initializer->add_dims(16); // 16 * 8 = 128 bytes
+ for (int64_t i = 0; i < 16; ++i) {
+ initializer->add_int64_data(1);
+ }
+
+ // Create a Split node that uses this initializer
+ auto* split_node = graph_proto->add_node();
+ split_node->set_op_type("Split");
+ split_node->set_name("split_node");
+ split_node->add_input("input_data");
+ split_node->add_input("split_sizes"); // Uses the large initializer
+ for (int i = 0; i < 16; ++i) {
+ split_node->add_output("split_output_" + std::to_string(i));
+ }
+
+ auto* axis_attr = split_node->add_attribute();
+ axis_attr->set_name("axis");
+ axis_attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT);
+ axis_attr->set_i(0);
+
+ // Add graph input
+ auto* input = graph_proto->add_input();
+ input->set_name("input_data");
+ auto* input_type = input->mutable_type()->mutable_tensor_type();
+ input_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ input_type->mutable_shape()->add_dim()->set_dim_value(16);
+ input_type->mutable_shape()->add_dim()->set_dim_value(10);
+
+ // Add graph outputs
+ for (int i = 0; i < 16; ++i) {
+ auto* output = graph_proto->add_output();
+ output->set_name("split_output_" + std::to_string(i));
+ }
+
+ // Load model
+  std::shared_ptr<Model> model;
+ ASSERT_STATUS_OK(Model::Load(std::move(model_proto), model, nullptr, *logger_));
+
+ Graph& graph = model->MainGraph();
+ // First resolve should succeed
+ ASSERT_STATUS_OK(graph.Resolve());
+
+ // Now trigger the in-memory externalization
+ // This converts initializers > 127 bytes to OrtValues with external data references
+ Status convert_status = graph.ConvertInitializersIntoOrtValues();
+ ASSERT_TRUE(convert_status.IsOK()) << "ConvertInitializersIntoOrtValues failed: " << convert_status.ErrorMessage();
+
+ // Check if the initializer was actually externalized
+ const ONNX_NAMESPACE::TensorProto* initializer_after = nullptr;
+ ASSERT_TRUE(graph.GetInitializedTensor("split_sizes", initializer_after));
+ ASSERT_NE(initializer_after, nullptr);
+  // Verify the initializer was externalized to in-memory external data
+ ASSERT_TRUE(utils::HasExternalDataInMemory(*initializer_after))
+ << "Initializer was not externalized to in-memory external data";
+
+ // Mark the graph as needing resolve to force shape inference to run again
+ graph.SetGraphResolveNeeded();
+
+ // Resolve again - this should trigger shape inference with the externalized initializer
+ // Without the fix, this will fail with "Cannot parse data from external tensors"
+ // With the fix, getInputData() materializes the external data for shape inference
+ Status second_resolve = graph.Resolve();
+ ASSERT_TRUE(second_resolve.IsOK()) << "Second resolve failed: " << second_resolve.ErrorMessage();
+}
+
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp
index 39d135236b89c..dc37980002978 100644
--- a/onnxruntime/test/mlas/bench/bench_sconv.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp
@@ -3,6 +3,7 @@
#include "mlas.h"
#include "bench_util.h"
+#include "core/util/thread_utils.h"
#include
#include
@@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
}
}
+static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) {
+  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
+ &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true);
+ return threadpool.get();
+}
+
+void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) {
+ MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark();
+
+ const int64_t rank = state.range(0); // Rank
+ const int64_t batch_size = state.range(1); // N
+ const int64_t groups = state.range(2); // G
+ const int64_t input_channels_per_group = state.range(3); // Cpg
+ const int64_t output_channels_per_group = state.range(4); // Fpg
+
+  if (rank <= 0) throw std::invalid_argument("Kernel rank must be greater than 0!");
+  if (batch_size <= 0) throw std::invalid_argument("Batch size must be greater than 0!");
+  if (groups <= 0) throw std::invalid_argument("Group count must be greater than 0!");
+  if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must be greater than 0!");
+  if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must be greater than 0!");
+
+ size_t arg_position = 5;
+ const auto input_shape = BenchArgsVector(state, arg_position, rank);
+ const auto kernel_shape = BenchArgsVector(state, arg_position, rank);
+ const auto paddings = BenchArgsVector(state, arg_position, rank * 2);
+ const auto strides = BenchArgsVector(state, arg_position, rank);
+ const auto dilations = BenchArgsVector(state, arg_position, rank);
+
+  // The sizes of these vectors are not validated here because they come directly from the benchmark arguments.
+ if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+ throw std::invalid_argument("all input image dim must > 0");
+ }
+
+ if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+ throw std::invalid_argument("all kernel dim must > 0");
+ }
+
+ if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) {
+ throw std::invalid_argument("all strides dim must > 0");
+ }
+
+ if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) {
+ throw std::invalid_argument("all dilations dim must > 0");
+ }
+
+ const int64_t GC = groups * input_channels_per_group;
+ const int64_t GF = groups * output_channels_per_group;
+  std::vector<int64_t> x_shape = {batch_size, GC};
+ x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end());
+  std::vector<int64_t> f_shape = {GF, input_channels_per_group};
+ f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end());
+
+  std::vector<int64_t> output_shape((size_t)rank);
+ for (int64_t i = 0; i < rank; ++i) {
+ auto km = 1 + dilations[i] * (kernel_shape[i] - 1);
+ output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1;
+ }
+  std::vector<int64_t> y_shape = {batch_size, GF};
+ y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end());
+
+ MLAS_ACTIVATION activation;
+ activation.ActivationKind = MlasIdentityActivation;
+ MLAS_CONV_PARAMETERS Parameters;
+ size_t WorkingBufferSize = 0;
+ MlasConvPrepare(&Parameters,
+ static_cast(rank),
+ static_cast(batch_size),
+ static_cast(groups),
+ static_cast(input_channels_per_group),
+ input_shape.data(),
+ kernel_shape.data(),
+ dilations.data(),
+ paddings.data(),
+ strides.data(),
+ output_shape.data(),
+ static_cast(output_channels_per_group),
+ &activation,
+ &WorkingBufferSize,
+ 0.0f,
+ tp);
+
+ auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
+ auto F = RandomVectorUniform(f_shape, -1.0, 1.0);
+  int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies<int64_t>());
+  std::vector<float> Y(static_cast<size_t>(y_size));
+  std::vector<float> working_buffer(WorkingBufferSize);
+
+ // warm up first round.
+ MlasConv(&Parameters,
+ X.data(),
+ F.data(),
+ nullptr,
+ working_buffer.data(),
+ Y.data(),
+ tp);
+
+ for (auto _ : state) {
+ MlasConv(&Parameters,
+ X.data(),
+ F.data(),
+ nullptr,
+ working_buffer.data(),
+ Y.data(),
+ tp);
+ }
+}
+
static void ResNet50(benchmark::internal::Benchmark* b) {
b->ArgNames(ArgNamesForConv(2));
@@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) {
}
BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
+BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
static void General_Conv2d(benchmark::internal::Benchmark* b) {
b->ArgNames(ArgNamesForConv(2));
diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
index 0b8624ad6c67f..7c84aefa1c01f 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
@@ -339,6 +339,61 @@ TEST(ConvTest, Conv2D_2) {
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}
+TEST(ConvTest, Conv2D_3) {
+ ConvOpAndTestAttributes attrs = {
+ "", // auto_pad
+      vector<int64_t>{1, 1},        // dilations
+      2,                            // group
+      vector<int64_t>{2, 2},        // kernel_shape
+      vector<int64_t>{0, 0, 0, 0},  // pads
+      vector<int64_t>{1, 1},        // strides
+ {} // excluded EPs
+ };
+
+  vector<int64_t> X_shape = {2, 2, 3, 3};
+  vector<float> X = {1.f, 2.f, 3.f,
+ 4.f, 5.f, 6.f,
+ 7.f, 8.f, 9.f,
+
+ 10.f, 11.f, 12.f,
+ 13.f, 14.f, 15.f,
+ 16.f, 17.f, 18.f,
+
+ 1.f, 2.f, 3.f,
+ 7.f, 8.f, 9.f,
+ 4.f, 5.f, 6.f,
+
+ 13.f, 14.f, 15.f,
+ 10.f, 11.f, 12.f,
+ 16.f, 17.f, 18.f};
+
+  vector<int64_t> W_shape = {2, 1, 2, 2};
+  vector<float> W = {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f};
+
+  vector<int64_t> Y_shape = {2, 2, 2, 2};
+ auto Y = {
+ 37.f,
+ 47.f,
+ 67.f,
+ 77.f,
+ 254.f,
+ 274.f,
+ 314.f,
+ 334.f,
+ 58.f,
+ 68.f,
+ 55.f,
+ 65.f,
+ 230.f,
+ 250.f,
+ 296.f,
+ 316.f,
+ };
+
+ TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
+ TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
+}
+
TEST(ConvTest, Conv2D_Bias_1) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
index 706bd3c0fce62..7851f3f8c0d35 100644
--- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
+++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -713,6 +713,52 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) {
ASSERT_TRUE(status.IsOK());
}
+TEST(TensorrtExecutionProviderTest, DDSOutputTest) {
+ PathString model_name = ORT_TSTR("testdata/ort_github_issue_26272_dds.onnx");
+ SessionOptions so;
+ so.session_logid = "TensorrtExecutionProviderRunWithDDSOutput";
+ RunOptions run_options;
+ run_options.run_tag = so.session_logid;
+ InferenceSession session_object{so, GetEnvironment()};
+ auto cuda_provider = DefaultCudaExecutionProvider();
+ auto cuda_allocator = cuda_provider->CreatePreferredAllocators()[1];
+  std::vector<int64_t> dims_op_x = {3, 4};
+  std::vector<float> values_op_x(12, 0.f);  // 12=3*4
+  OrtValue ml_value_x;
+  CreateMLValue<float>(cuda_allocator, dims_op_x, values_op_x, &ml_value_x);
+
+ NameMLValMap feeds;
+ feeds.insert(std::make_pair("data", ml_value_x));
+
+ // prepare outputs
+  std::vector<std::string> output_names;
+  output_names.push_back("output");
+  std::vector<OrtValue> fetches;
+
+ OrtTensorRTProviderOptionsV2 params;
+  std::unique_ptr<IExecutionProvider> execution_provider = TensorrtExecutionProviderWithOptions(&params);
+ EXPECT_TRUE(session_object.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
+ auto status = session_object.Load(model_name);
+ ASSERT_TRUE(status.IsOK());
+ status = session_object.Initialize();
+ ASSERT_TRUE(status.IsOK());
+
+ // First pass run
+ status = session_object.Run(run_options, feeds, output_names, &fetches);
+ ASSERT_TRUE(status.IsOK());
+
+ // Second pass run with new shape
+ dims_op_x = {6, 4};
+ values_op_x.resize(24, 0.f); // 24=6*4
+  CreateMLValue<float>(cuda_allocator, dims_op_x, values_op_x, &ml_value_x);
+ feeds.clear();
+
+ feeds.insert(std::make_pair("data", ml_value_x));
+
+ status = session_object.Run(run_options, feeds, output_names, &fetches);
+ ASSERT_TRUE(status.IsOK());
+}
+
TEST_P(TensorrtExecutionProviderCacheTest, Run) {
// GetParam() returns the parameter of following format:
// ##cache type##_##input shape type##
diff --git a/onnxruntime/test/testdata/ort_github_issue_26272.py b/onnxruntime/test/testdata/ort_github_issue_26272.py
new file mode 100644
index 0000000000000..fa381e5df1094
--- /dev/null
+++ b/onnxruntime/test/testdata/ort_github_issue_26272.py
@@ -0,0 +1,26 @@
+import onnx
+from onnx import TensorProto, helper
+
+# Create a simple ONNX model with DDS output
+input = helper.make_tensor_value_info("data", TensorProto.FLOAT, ["d1", "d2"])
+output = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["nzr"])
+
+nonzeros_node = helper.make_node("NonZero", ["data"], ["nonzeros"], "nonzeros_node")
+transpose_node = helper.make_node("Transpose", ["nonzeros"], ["nonzeros_t"], "transpose_node")
+gathernd_node = helper.make_node("GatherND", ["data", "nonzeros_t"], ["output"], "gathernd_node")
+
+value_info = [
+ helper.make_tensor_value_info("nonzeros", TensorProto.INT64, [2, "nzr"]),
+ helper.make_tensor_value_info("nonzeros_t", TensorProto.INT64, ["nzr", 2]),
+]
+
+graph = helper.make_graph(
+ [nonzeros_node, transpose_node, gathernd_node],
+ "test_graph",
+ [input],
+ [output],
+ value_info=value_info,
+)
+
+model = helper.make_model(graph)
+onnx.save(model, "ort_github_issue_26272_dds.onnx")
diff --git a/onnxruntime/test/testdata/ort_github_issue_26272_dds.onnx b/onnxruntime/test/testdata/ort_github_issue_26272_dds.onnx
new file mode 100644
index 0000000000000..371f99c537898
--- /dev/null
+++ b/onnxruntime/test/testdata/ort_github_issue_26272_dds.onnx
@@ -0,0 +1,28 @@
+
+:“
+(
+datanonzeros
nonzeros_node"NonZero
+1
+nonzeros
+nonzeros_ttranspose_node" Transpose
+3
+data
+
+nonzeros_toutput
gathernd_node"GatherND
+test_graphZ
+data
+
+d1
+d2b
+output
+
+nzrj
+nonzeros
+
+
+nzrj
+
+nonzeros_t
+
+nzr
+B
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
index 61afeba2d302b..e7e541205ba0a 100644
--- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml
@@ -60,7 +60,7 @@ stages:
msbuildPlatform: x64
packageName: x64-cuda
CudaVersion: ${{ parameters.CudaVersion }}
- buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90a-virtual"
+ buildparameter: --use_cuda --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90-virtual"
runTests: ${{ parameters.RunOnnxRuntimeTests }}
buildJava: ${{ parameters.buildJava }}
java_artifact_id: onnxruntime_gpu
@@ -80,7 +80,7 @@ stages:
msbuildPlatform: x64
CudaVersion: ${{ parameters.CudaVersion }}
packageName: x64-tensorrt
- buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90a-virtual"
+ buildparameter: --use_tensorrt --tensorrt_home=${{ parameters.win_trt_home }} --cuda_home=${{ parameters.win_cuda_home }} --enable_onnx_tests --enable_wcos --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=75-real;86-real;89-real;90-virtual"
runTests: ${{ parameters.RunOnnxRuntimeTests }}
buildJava: ${{ parameters.buildJava }}
java_artifact_id: onnxruntime_gpu
diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
index f3d3b2a8ecbf2..91910dc2e70da 100644
--- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml
@@ -39,7 +39,7 @@ stages:
PYTHON_VERSION: ${{ python_version }}
EP_NAME: gpu
CudaVersion: ${{ parameters.cuda_version }}
- EP_BUILD_FLAGS: --enable_lto --use_cuda --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90a-virtual"
+ EP_BUILD_FLAGS: --enable_lto --use_cuda --cuda_home=$(Agent.TempDirectory)\v${{ parameters.cuda_version }} --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52-real;61-real;75-real;86-real;89-real;90-virtual"
use_tensorrt: True
- template: py-linux-gpu-stage.yml
diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh
index 9cc140f41cf91..2f3ac991aee9c 100755
--- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh
+++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh
@@ -2,4 +2,4 @@
set -e -x
docker run -e SYSTEM_COLLECTIONURI --rm --volume \
$BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \
-/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' && cd /build/Release && make install DESTDIR=/build/installed"
+/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION --skip_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' && cd /build/Release && make install DESTDIR=/build/installed"
diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh
index 5684029c72049..12fed185975a7 100755
--- a/tools/ci_build/github/linux/build_linux_python_package.sh
+++ b/tools/ci_build/github/linux/build_linux_python_package.sh
@@ -70,7 +70,7 @@ fi
if [ "$BUILD_DEVICE" == "GPU" ]; then
SHORT_CUDA_VERSION=$(echo $CUDA_VERSION | sed 's/\([[:digit:]]\+\.[[:digit:]]\+\)\.[[:digit:]]\+/\1/')
#Enable CUDA and TRT EPs.
- BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90a-real;90a-virtual" "onnxruntime_USE_FPA_INTB_GEMM=OFF")
+ BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$SHORT_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$SHORT_CUDA_VERSION" "--nvcc_threads=1" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;86-real;90a-real;90-virtual" "onnxruntime_USE_FPA_INTB_GEMM=OFF")
fi
if [ "$BUILD_DEVICE" == "NPU" ]; then
diff --git a/tools/ci_build/github/linux/build_nodejs_package.sh b/tools/ci_build/github/linux/build_nodejs_package.sh
index cc6443cc7fab6..ff5c504376d1d 100755
--- a/tools/ci_build/github/linux/build_nodejs_package.sh
+++ b/tools/ci_build/github/linux/build_nodejs_package.sh
@@ -3,4 +3,4 @@ set -e -x
mkdir -p $HOME/.onnx
docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \
--volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \
-/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed"
+/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90-virtual' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed"
diff --git a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh
index b8d968c82d002..c0849bf0ace73 100755
--- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh
+++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh
@@ -3,4 +3,4 @@ set -e -x
mkdir -p $HOME/.onnx
docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \
--volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \
-/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90a-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed"
+/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90a-real;90-virtual' 'onnxruntime_USE_FPA_INTB_GEMM=OFF' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed"