Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dockerfiles/Dockerfile.source
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sy
FROM mcr.microsoft.com/azurelinux/base/python:3
COPY --from=0 /code/build/Linux/Release/dist /root
COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt
RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl
RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl
1 change: 0 additions & 1 deletion docs/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ furo
pyquickhelper
pandas
pydot
coloredlogs
flatbuffers
numpy<2.0.0
packaging
Expand Down
25 changes: 25 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7195,6 +7195,31 @@ struct OrtApi {
* \since 1.24
*/
ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream);

/** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse).
*
* \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for
* the shape data. The OrtValue instance's internal shape data is returned directly.
*
* \note Returns an error if the underlying OrtValue is not a Tensor.
*
* \param[in] value The OrtValue instance.
* \param[out] elem_type Output parameter set to the tensor element data type.
* \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array.
* For a scalar, `shape_data` is NULL and `shape_data_count` is 0.
* Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid
* when the OrtValue is released or if the underlying shape data is updated or reallocated.
* \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`.
* `shape_data_count` is 0 for a scalar.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_result_maybenull_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
};

/*
Expand Down
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2220,6 +2220,19 @@ struct ConstValueImpl : Base<T> {
const R* GetSparseTensorValues() const;

#endif

/// <summary>
/// Returns the tensor's element type and a reference to the tensor's internal shape data. The shape data is owned
/// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is
/// updated or reallocated.
///
/// For a scalar, shape.shape is nullptr and shape.shape_len is 0.
///
/// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference.
/// </summary>
/// <param name="elem_type">Output parameter set to the element's data type.</param>
/// <param name="shape">Output parameter set to the OrtValue instance's shape data and number of elements.</param>
void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const;
};

template <typename T>
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2377,6 +2377,13 @@ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {

#endif

template <typename T>
void ConstValueImpl<T>::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type,
Shape& shape) const {
ThrowOnError(GetApi().GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type, &shape.shape,
&shape.shape_len));
}

template <typename T>
void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
Expand Down
58 changes: 58 additions & 0 deletions onnxruntime/core/framework/tensor_type_and_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,64 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
return GetTensorShapeAndTypeHelper(type, shape, dim_params);
}

ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
                    _Out_ ONNXTensorElementDataType* elem_type,
                    _Outptr_result_maybenull_ const int64_t** shape_data,
                    _Out_ size_t* shape_data_count) {
  API_IMPL_BEGIN
  // The OrtValue must hold a fully constructed dense or sparse tensor before we can
  // hand out a reference to its internal shape buffer.
  if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Input parameter `value` must contain a constructed tensor or sparse tensor");
  }

  // Reject null output parameters individually so each failure gets a precise message.
  if (elem_type == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output parameter `elem_type` must not be NULL");
  }

  if (shape_data == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output parameter `shape_data` must not be NULL");
  }

  if (shape_data_count == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Output parameter `shape_data_count` must not be NULL");
  }

  // Pull the element type and the dims span straight out of the underlying tensor; the
  // span aliases the tensor's own storage, so nothing is copied or allocated here.
  gsl::span<const int64_t> dims;
  onnxruntime::MLDataType dtype = nullptr;

  if (value->IsTensor()) {
    const auto& tensor = value->Get<onnxruntime::Tensor>();
    dtype = tensor.DataType();
    dims = tensor.Shape().GetDims();
  } else {
#if !defined(DISABLE_SPARSE_TENSORS)
    // A sparse tensor reports its logical (dense) shape.
    const auto& sparse = value->Get<onnxruntime::SparseTensor>();
    dtype = sparse.DataType();
    dims = sparse.DenseShape().GetDims();
#else
    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build.");
#endif
  }

  ONNXTensorElementDataType onnx_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  if (dtype != nullptr) {
    onnx_type = MLDataTypeToOnnxRuntimeTensorElementDataType(dtype);
  }

  if (onnx_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
    return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type");
  }

  // Scalars have an empty dims span: report NULL data with a count of 0 per the API contract.
  *elem_type = onnx_type;
  if (dims.empty()) {
    *shape_data = nullptr;
  } else {
    *shape_data = dims.data();
  }
  *shape_data_count = dims.size();
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape,
_In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) {
API_IMPL_BEGIN
Expand Down
48 changes: 24 additions & 24 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,32 +548,32 @@ MlasLutGemm(

// const int num_groups = static_cast<int>(K / BlkLen);

// Parallelize over M (batch dimension)
// Each iteration processes one row of the activation matrix
// Iterate over M (batch dimension)
// Each iteration processes one row of the activation matrix.
// NOTE: This loop is intentionally serialized. Previous attempts to parallelize
// using MlasTrySimpleParallel caused flaky test failures (race conditions)
// when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight,
// serial execution ensures correctness with negligible performance impact.
// TODO(vraspar): Ideally we have to do block parallelism here

MlasTrySimpleParallel(
threadpool,
static_cast<size_t>(M),
[&](ptrdiff_t ine11) {
const size_t row_offset = static_cast<size_t>(ine11) * K;
const size_t lut_offset = static_cast<size_t>(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT
const size_t scale_bias_offset = static_cast<size_t>(ine11) * lut_scales_size;

// Call the dispatch function for this row
// ggml_tmac_mul_mat_task_init
Dispatch->GenerateLUT(
const_cast<float*>(a_float + row_offset), // Input activation for this row
qlut + lut_offset, // Output LUT for this row
lut_scales + scale_bias_offset, // Scales for this row
lut_biases + scale_bias_offset, // Biases for this row
M,
K,
N,
tmac_params.act_group_size
);
}
);
for (size_t ine11 = 0; ine11 < static_cast<size_t>(M); ine11++) {
const size_t row_offset = ine11 * K;
const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT
const size_t scale_bias_offset = ine11 * lut_scales_size;

// Call the dispatch function for this row
// ggml_tmac_mul_mat_task_init
Dispatch->GenerateLUT(
const_cast<float*>(a_float + row_offset), // Input activation for this row
qlut + lut_offset, // Output LUT for this row
lut_scales + scale_bias_offset, // Scales for this row
lut_biases + scale_bias_offset, // Biases for this row
M,
K,
N,
tmac_params.act_group_size
);
}

// all relevant LUT's have been generated
// equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line
Expand Down
54 changes: 42 additions & 12 deletions onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,21 +187,53 @@ get_bias_scale()
return 3;
}

//
// Loads 32 consecutive floats from `src` and de-interleaves them with stride 4 so that:
//   v0 = { src[0], src[4], src[8], ..., src[28] }
//   v1 = { src[1], src[5], src[9], ..., src[29] }
//   v2 = { src[2], src[6], src[10], ..., src[30] }
//   v3 = { src[3], src[7], src[11], ..., src[31] }
// i.e. vN holds src[4*i + N] for i = 0..7, the same lane layout previously produced by
// four _mm256_i32gather_ps calls with byte offsets {0, 16, ..., 112}.
//
static inline void
MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
{
    // Process 32 activations contiguously using loadu + shuffle.
    // This allows us to mix neighbors (src[4i], src[4i+1], src[4i+2], src[4i+3]) across lanes,
    // which matches the T-MAC weight packing.
    // We use loadu + shuffle instead of gather to avoid potential issues with gather
    // on some hardware and ensure deterministic behavior.

    // Four contiguous 8-float loads covering src[0..31].
    __m256 vec_b0 = _mm256_loadu_ps(src + 0);
    __m256 vec_b1 = _mm256_loadu_ps(src + 8);
    __m256 vec_b2 = _mm256_loadu_ps(src + 16);
    __m256 vec_b3 = _mm256_loadu_ps(src + 24);

    // Stage 1: 32-bit interleave within each 128-bit lane.
    // e.g. t0 lane 0 = { s0, s8, s1, s9 }, t0 lane 1 = { s4, s12, s5, s13 }.
    __m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1);
    __m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1);
    __m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3);
    __m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3);

    // Stage 2: 64-bit interleave (pairs of floats) to gather same-residue elements.
    // e.g. u0 = { s0, s8, s16, s24 | s4, s12, s20, s28 } -- all src[4i] values, out of order.
    __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
    __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2)));
    __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));
    __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3)));

    // Stage 3: cross-lane permute to put elements in ascending index order,
    // yielding vN[i] = src[4*i + N].
    const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
    v0 = _mm256_permutevar8x32_ps(u0, perm_idx);
    v1 = _mm256_permutevar8x32_ps(u1, perm_idx);
    v2 = _mm256_permutevar8x32_ps(u2, perm_idx);
    v3 = _mm256_permutevar8x32_ps(u3, perm_idx);
}

void
partial_max_g4_int8_k8(float* lut_scales, const float* b)
{
// TODO(vraspar): add support for arm neon
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
__m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1);
__m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1);
__m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1);
__m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1);
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3);

const __m256 vec_sign = _mm256_set1_ps(-0.0f);
__m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0);
__m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1);
__m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2);
__m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3);

// The upper bound for the LUT values (mixtures of 4 activations) is the sum
// of their absolute values.
__m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3));

// Reduce max across lanes to find the global maximum sum in this chunk.
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
Expand All @@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl(
)
{
__m256 vec_lut[16];
float biases = 0.0;
const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0);
float biases = 0.0f;
float scales = *lut_scales;
float t_scales = scales ? 1.0f / scales : 0.0f;

for (int k = 0; k < act_k / 32; ++k) {
__m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1);
__m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1);
__m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1);
__m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1);
const float* b_chunk = b + k * 32;
__m256 vec_b0, vec_b1, vec_b2, vec_b3;
MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3);

PRAGMA_UNROLL
for (int g = 1; g < 16; g += 2) {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4802,6 +4802,8 @@ static constexpr OrtApi ort_api_1_to_24 = {
&OrtApis::EpAssignedNode_GetDomain,
&OrtApis::EpAssignedNode_GetOperatorType,
&OrtApis::RunOptionsSetSyncStream,
&OrtApis::GetTensorElementTypeAndShapeDataReference,
// End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information)
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down Expand Up @@ -4838,6 +4840,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz

static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.24.0",
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -808,4 +808,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap
ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);
ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out);

ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
_Out_ ONNXTensorElementDataType* elem_type,
_Outptr_result_maybenull_ const int64_t** shape_data,
_Out_ size_t* shape_data_count);
} // namespace OrtApis
10 changes: 5 additions & 5 deletions onnxruntime/python/tools/tensorrt/perf/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import timeit
from datetime import datetime

import coloredlogs
import numpy as np
from perf_utils import (
acl,
Expand Down Expand Up @@ -2259,12 +2258,13 @@ def parse_arguments():

def setup_logger(verbose):
if verbose:
coloredlogs.install(
level="DEBUG",
fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
logging.basicConfig(
level=logging.DEBUG,
format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
force=True,
)
else:
coloredlogs.install(fmt="%(message)s")
logging.basicConfig(format="%(message)s", force=True)
logging.getLogger("transformers").setLevel(logging.WARNING)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pprint
import re

import coloredlogs # noqa: F401
from benchmark import * # noqa: F403
from perf_utils import * # noqa: F403

Expand Down
2 changes: 0 additions & 2 deletions onnxruntime/python/tools/tensorrt/perf/perf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import subprocess
import sys

import coloredlogs # noqa: F401

debug = False
debug_verbose = False

Expand Down
1 change: 0 additions & 1 deletion onnxruntime/python/tools/tensorrt/perf/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
onnxconverter-common
onnxmltools
pandas
coloredlogs
9 changes: 4 additions & 5 deletions onnxruntime/python/tools/transformers/benchmark_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from time import sleep
from typing import Any

import coloredlogs
import numpy
import torch
import transformers
Expand Down Expand Up @@ -147,12 +146,12 @@ def create_onnxruntime_session(

def setup_logger(verbose=True):
if verbose:
coloredlogs.install(
level="DEBUG",
fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
logging.basicConfig(
format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
level=logging.DEBUG,
)
else:
coloredlogs.install(fmt="%(message)s")
logging.basicConfig(format="%(message)s", level=logging.INFO)
logging.getLogger("transformers").setLevel(logging.WARNING)


Expand Down
Loading
Loading