Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k
// If not provided, default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";

// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits.
// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization).
// Positive value: explicit block_size (must be power-of-2 and >= 16, e.g., 16, 32, 64, 128).
// "0" or not provided: use default block_size of 32.
// "-1": heuristic - largest power-of-2 <= min(K, 256) that minimizes padding.
static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size";

// Enable the DQ->MatMulNBits fusion graph transformer.
// "0": disabled (default). "1": enabled.
// This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered.
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
const int64_t qdq_matmulnbits_block_size =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
"0"));
#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
Expand All @@ -363,7 +367,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
SatApplyContextVariant{},
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
intra_op_thread_pool,
qdq_matmulnbits_block_size));
}

transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
Expand Down Expand Up @@ -504,14 +509,19 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
const int64_t qdq_matmulnbits_block_size =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
"0"));
// runtime optimizations only support CPU EP now
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
apply_context,
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
intra_op_thread_pool,
qdq_matmulnbits_block_size));
}

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,67 @@
dt_weight == TensorProto::INT8;
}

// Compute the effective block_size for per-tensor/per-channel DQ nodes that lack a block_size attribute.
// session_block_size: 0 = default (32), positive = explicit, -1 = min-padding heuristic.
int64_t ComputeEffectiveBlockSize(int64_t session_block_size, int64_t K) {
// MatMulNBits CPU kernel currently only supports block_size in [16, 256] correctly.
constexpr int64_t kMinBlockSize = 16;
constexpr int64_t kMaxBlockSize = 256;

if (session_block_size > 0) {
// Explicit block_size — must be power-of-2 and within [kMinBlockSize, kMaxBlockSize].
ORT_ENFORCE(session_block_size >= kMinBlockSize &&
((session_block_size & (session_block_size - 1)) == 0),
"Explicit qdq_matmulnbits_block_size must be a power-of-2 and >= ",
kMinBlockSize, ", got: ", session_block_size);
ORT_ENFORCE(session_block_size <= kMaxBlockSize,
"Explicit qdq_matmulnbits_block_size must be <= ",
kMaxBlockSize, ", got: ", session_block_size);
return session_block_size;
}

if (session_block_size == -1) {
// Heuristic: largest power-of-2 <= min(K, kMaxBlockSize) that minimizes padding.
// Capped at kMaxBlockSize because CPU EP only supports block_size up to kMaxBlockSize correctly.
// We want ceil(K / B) * B - K to be minimized (least wasted padding).
int64_t best_bs = kMinBlockSize;
int64_t best_padding = (((K + (kMinBlockSize - 1)) / kMinBlockSize) * kMinBlockSize) - K;
for (int64_t bs = kMinBlockSize * 2; bs <= std::min(K, kMaxBlockSize); bs *= 2) {

Check warning on line 70 in onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for min [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc:70: Add #include <algorithm> for min [build/include_what_you_use] [4]
Comment thread
vraspar marked this conversation as resolved.
int64_t padding = (((K + bs - 1) / bs) * bs) - K;
if (padding <= best_padding) {
best_padding = padding;
best_bs = bs;
}
}
return best_bs;
}

// Default (session_block_size == 0): use 32
return 32;
}

// Get the DQ block_size: from the attribute if blockwise, or computed for per-tensor/per-channel.
int64_t GetEffectiveBlockSize(const Node& dq_node, int64_t block_size_for_non_blockwise) {
const auto& dq_attrs = dq_node.GetAttributes();
const auto bs_iter = dq_attrs.find("block_size");
if (bs_iter != dq_attrs.end() && bs_iter->second.i() > 0) {
return bs_iter->second.i();
}

// Derive K from the weight input shape if available. Shape information may be missing even
// when the weight is a constant initializer, so guard against nullptrs / unknown dims.
int64_t K = 32; // reasonable default consistent with ComputeEffectiveBlockSize default
const auto* weight_arg = dq_node.InputDefs()[0];
if (weight_arg != nullptr) {
const auto* shape = weight_arg->Shape();
if (shape != nullptr && shape->dim_size() > 0 && shape->dim(0).has_dim_value()) {
K = static_cast<int64_t>(shape->dim(0).dim_value());
}
}

return ComputeEffectiveBlockSize(block_size_for_non_blockwise, K);
}

// Holds transposed weight/scale/zp tensors and their TensorProtos for MatMulNBits.
// Used by DQMatMulToMatMulNBitsAction.
struct TransposedQuantizedTensors {
Expand All @@ -56,16 +117,17 @@

// Transpose DQ weight/scale/zp tensors from column-wise layout to MatMulNBits layout via MLAS.
// default_zp_name_prefix: prefix for auto-generated zero-point name when unsigned type has no explicit zp.
// effective_block_size: the block_size to use for MatMulNBits (may differ from DQ's block_size for per-tensor/per-channel).
Status TransposeDQWeightsForMatMulNBits(
Graph& graph,
const Node& dq_node,
const std::string& default_zp_name_prefix,
concurrency::ThreadPool* intra_op_thread_pool,
int64_t effective_block_size,
TransposedQuantizedTensors& result) {
const auto* weight_arg = dq_node.InputDefs()[0];
const auto* scale_arg = dq_node.InputDefs()[1];
const auto* zp_arg = dq_node.InputDefs().size() > 2 ? dq_node.InputDefs()[2] : nullptr;
const auto& attrs = dq_node.GetAttributes();

const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr;
ORT_RETURN_IF_NOT(graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto),
Expand All @@ -80,7 +142,7 @@

auto K = weight_arg->Shape()->dim(0).dim_value();
auto N = weight_arg->Shape()->dim(1).dim_value();
Comment thread
jambayk marked this conversation as resolved.
Outdated
auto block_size = attrs.at("block_size").i();
auto block_size = effective_block_size;
int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type();
auto bits = DQWeightBits(dt_weight);
auto quant_num = (K + block_size - 1) / block_size;
Expand All @@ -94,8 +156,100 @@
std::optional<Initializer> zp_src;
auto cpu_allocator = CPUAllocator::DefaultInstance();

// Determine if scale/zp need expansion from per-tensor/per-channel to blockwise [quant_num, N].
const bool is_blockwise = (scale_tensor_proto->dims_size() == 2);
std::optional<Tensor> expanded_scale;
std::optional<Tensor> expanded_zp;

if (!is_blockwise) {
// Expand scale to [quant_num, N]
expanded_scale.emplace(scale_type, TensorShape{quant_num, N}, cpu_allocator);
bool is_per_tensor = (scale_tensor_proto->dims_size() == 0);

auto expand_scale = [&](auto* src_data, auto* dst_data) {
if (is_per_tensor) {
auto val = src_data[0];
for (int64_t i = 0; i < quant_num * N; ++i) {
dst_data[i] = val;
}
} else {
// Per-channel: scale shape [N], replicate across quant_num blocks
for (int64_t b = 0; b < quant_num; ++b) {
for (int64_t n = 0; n < N; ++n) {
dst_data[b * N + n] = src_data[n];
}
}
}
};

if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
expand_scale(scale_src.data<float>(), expanded_scale->MutableData<float>());
} else {
expand_scale(scale_src.data<MLFloat16>(), expanded_scale->MutableData<MLFloat16>());
}

// Expand zp if present
if (zp_tensor_proto) {
zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath());
// Allocate as uint8 with enough bytes to hold quant_num*N packed sub-byte elements.
int64_t expanded_zp_bytes = (quant_num * N * bits + 7) / 8;
expanded_zp.emplace(uint8_type, TensorShape{expanded_zp_bytes}, cpu_allocator);

// For sub-byte types, the zp is packed in bytes. We need to expand element-wise.
// For 8-bit, each byte is one element. For 4-bit, 2 elements per byte. For 2-bit, 4 elements per byte.
const uint8_t* zp_bytes = zp_src->DataAsByteSpan().data();
uint8_t* dst_zp_bytes = expanded_zp->MutableData<uint8_t>();

auto get_element = [bits](const uint8_t* data, int64_t idx) -> uint8_t {
if (bits == 8) return data[idx];
if (bits == 4) {
uint8_t byte = data[idx / 2];
return (idx % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
}
// bits == 2
uint8_t byte = data[idx / 4];
int shift = static_cast<int>((idx % 4) * 2);
return (byte >> shift) & 0x03;
};

auto set_element = [bits](uint8_t* data, int64_t idx, uint8_t val) {
if (bits == 8) {
data[idx] = val;
return;
}
if (bits == 4) {
int64_t byte_idx = idx / 2;
if (idx % 2 == 0) {
data[byte_idx] = (data[byte_idx] & 0xF0) | (val & 0x0F);
} else {
data[byte_idx] = (data[byte_idx] & 0x0F) | ((val & 0x0F) << 4);
}
return;
}
// bits == 2
int64_t byte_idx = idx / 4;
int shift = static_cast<int>((idx % 4) * 2);
data[byte_idx] = (data[byte_idx] & ~(0x03 << shift)) | ((val & 0x03) << shift);
};

// Initialize expanded zp to 0
memset(dst_zp_bytes, 0, expanded_zp->SizeInBytes());

for (int64_t b = 0; b < quant_num; ++b) {
for (int64_t n = 0; n < N; ++n) {
int64_t src_idx = is_per_tensor ? 0 : n;
uint8_t val = get_element(zp_bytes, src_idx);
set_element(dst_zp_bytes, b * N + n, val);
}
}
}
}

auto weight_dst_name = graph.GenerateNodeArgName(weight_arg->Name() + "_T");
result.weight = Tensor(uint8_type, TensorShape{N, quant_num, blob_bytes}, cpu_allocator);
// Zero-initialize: MLAS 4-bit transpose does not zero-pad when K < block_size,
// leaving uninitialized bytes in the last block's padding region.
memset(result.weight.MutableDataRaw(), 0, result.weight.SizeInBytes());

auto scale_dst_name = graph.GenerateNodeArgName(scale_arg->Name() + "_T");
auto scale_size = (TensorShape{N, quant_num}).Size();
Expand All @@ -104,7 +258,13 @@
std::string zp_dst_name;
auto zp_size = (TensorShape{N, (quant_num * bits + 7) / 8}).Size();

if (zp_tensor_proto) {
if (!is_blockwise && expanded_zp.has_value()) {
// Per-tensor/per-channel path with expanded zero-point
zp_dst_name = graph.GenerateNodeArgName(
(zp_arg ? zp_arg->Name() : default_zp_name_prefix + "_zero_point") + "_T");
result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator);
} else if (zp_tensor_proto) {
// Blockwise path with explicit zero-point
zp_src.emplace(graph, *zp_tensor_proto, graph.ModelPath());
zp_dst_name = graph.GenerateNodeArgName(zp_arg->Name() + "_T");
result.zero_point = Tensor(uint8_type, TensorShape{zp_size}, cpu_allocator);
Expand All @@ -116,10 +276,15 @@

// Dispatch MLAS transpose based on scale type, bits, and signedness.
auto transpose = [&](auto* scale_data, auto* scale_dst_data) {
using ScaleType = std::remove_pointer_t<decltype(scale_data)>;
using ScaleType = std::remove_const_t<std::remove_pointer_t<decltype(scale_data)>>;
bool is_signed = IsDQWeightSigned(dt_weight);
const uint8_t* src_w = weight_src.DataAsByteSpan().data();
const uint8_t* src_zp = zp_src ? zp_src->DataAsByteSpan().data() : nullptr;
const uint8_t* src_zp = nullptr;
if (expanded_zp.has_value()) {
src_zp = expanded_zp->Data<uint8_t>();
} else if (zp_src.has_value()) {
src_zp = zp_src->DataAsByteSpan().data();
}
uint8_t* dst_w = result.weight.MutableData<uint8_t>();
uint8_t* dst_zp = result.zero_point ? result.zero_point->MutableData<uint8_t>() : nullptr;
int K_int = static_cast<int>(K);
Expand Down Expand Up @@ -148,9 +313,11 @@
};

if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
transpose(scale_src.data<float>(), result.scale.MutableData<float>());
const float* s_data = expanded_scale.has_value() ? expanded_scale->Data<float>() : scale_src.data<float>();
transpose(s_data, result.scale.MutableData<float>());
} else {
transpose(scale_src.data<MLFloat16>(), result.scale.MutableData<MLFloat16>());
const MLFloat16* s_data = expanded_scale.has_value() ? expanded_scale->Data<MLFloat16>() : scale_src.data<MLFloat16>();
transpose(s_data, result.scale.MutableData<MLFloat16>());
}

result.weight_proto = utils::TensorToTensorProto(result.weight, weight_dst_name, true);
Expand Down Expand Up @@ -430,7 +597,8 @@

DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(
int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool)
concurrency::ThreadPool* intra_op_thread_pool,
int64_t block_size_for_non_blockwise)
: accuracy_level_{accuracy_level},
domain_{kMSDomain},
op_type_{"MatMulNBits"},
Expand All @@ -440,7 +608,8 @@
MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput),
MoveAll(target, ArgType::kOutput)};
}()},
intra_op_thread_pool_{intra_op_thread_pool} {
intra_op_thread_pool_{intra_op_thread_pool},
block_size_for_non_blockwise_{block_size_for_non_blockwise} {
ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
}

Expand All @@ -449,15 +618,15 @@
NodeAttributes extra_attributes;

const auto* dq_node = runtime_state.selected_nodes.Input(0);
auto& attrs = dq_node->GetAttributes();
const auto* weight_shape = dq_node->InputDefs()[0]->Shape();
Comment thread
jambayk marked this conversation as resolved.

utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes);
int32_t dt_weight = dq_node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
utils::SetNodeAttribute(utils::MakeAttribute("bits", DQWeightBits(dt_weight)), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes);
int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_);
utils::SetNodeAttribute(utils::MakeAttribute("block_size", effective_bs), extra_attributes);

return extra_attributes;
}
Expand All @@ -467,9 +636,11 @@
Node& replacement_node) const {
const auto* dq_node = selected_nodes.Input(0);

int64_t effective_bs = GetEffectiveBlockSize(*dq_node, block_size_for_non_blockwise_);

TransposedQuantizedTensors transposed;
ORT_RETURN_IF_ERROR(TransposeDQWeightsForMatMulNBits(
graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, transposed));
graph, *dq_node, "fused_DQ_MatMul", intra_op_thread_pool_, effective_bs, transposed));

auto& input_defs = replacement_node.MutableInputDefs();
input_defs.push_back(&graph_utils::AddInitializerWithOrtValue(graph, transposed.weight_proto, std::move(transposed.weight)));
Expand All @@ -483,6 +654,31 @@
replacement_node.MutableInputArgsCount().push_back(1);
}

// If the target was Gemm, strip Gemm-specific attributes from the replacement MatMulNBits node
// and wire the bias (if present) to MatMulNBits input 5.
const auto& target = selected_nodes.Target();
if (target.OpType() == "Gemm") {
replacement_node.ClearAttribute("alpha");
replacement_node.ClearAttribute("beta");
replacement_node.ClearAttribute("transA");
replacement_node.ClearAttribute("transB");

// Wire Gemm bias to MatMulNBits input 5 (bias slot).
// The bias can be a direct float tensor or the output of a DQ node.
const auto& target_inputs = target.InputDefs();
if (target_inputs.size() > 2 && target_inputs[2] && target_inputs[2]->Exists()) {
// MatMulNBits input layout: 0:A, 1:B, 2:scales, 3:zp(opt), 4:g_idx(opt), 5:bias(opt)
// Pad with empty NodeArgs up to position 5.
NodeArg& empty_arg = graph.GetOrCreateNodeArg("", nullptr);
while (input_defs.size() < 5) {
input_defs.push_back(&empty_arg);
replacement_node.MutableInputArgsCount().push_back(1);
}
input_defs.push_back(const_cast<NodeArg*>(target_inputs[2]));
replacement_node.MutableInputArgsCount().push_back(1);
}
}

return Status::OK();
}

Expand Down
Loading
Loading