7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -391,6 +391,13 @@ static const char* const kOrtSessionOptionsMlasDisableKleidiAi = "mlas.disable_k
 // If not provided, default is 4.
 static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
 
+// Block size used when converting per-tensor or per-axis DQ + MatMul to MatMulNBits.
+// Only applies to DQ nodes without an existing block_size attribute (i.e., per-tensor or per-axis quantization).
+// Positive value: explicit block_size (must be a power of 2 and >= 16, e.g., 16, 32, 64, 128).
+// "0" or not provided: use the default block_size of 32.
+// "-1": heuristic - the largest power of 2 <= min(K, 256) that minimizes padding.
+static const char* const kOrtSessionOptionsQDQMatMulNBitsBlockSize = "session.qdq_matmulnbits_block_size";
+
 // Enable the DQ->MatMulNBits fusion graph transformer.
 // "0": disabled (default). "1": enabled.
 // This is typically set automatically by InferenceSession when the NvTensorRTRTX EP is registered.
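The "-1" heuristic is only specified by the comment above. As a rough illustration, here is a minimal standalone sketch of such a selection; the helper name is hypothetical and the exact padding/tie-breaking rule is an assumption read off the comment, not ORT's implementation:

#include <algorithm>
#include <cstdint>
#include <limits>

// Hypothetical helper (not part of ORT): choose the largest power of 2 in
// [16, min(K, 256)] that minimizes the padding needed to round K (the length
// of the quantized axis) up to a whole number of blocks.
int64_t PickBlockSizeHeuristic(int64_t K) {
  int64_t best = 32;  // the documented default, used here as a fallback for tiny K
  int64_t best_padding = std::numeric_limits<int64_t>::max();
  for (int64_t bs = 16; bs <= std::min<int64_t>(K, 256); bs *= 2) {
    const int64_t blocks = (K + bs - 1) / bs;  // ceil(K / bs)
    const int64_t padding = blocks * bs - K;   // zero-filled weight elements
    if (padding <= best_padding) {             // '<=' keeps the largest bs on ties
      best_padding = padding;
      best = bs;
    }
  }
  return best;
}

For K = 4096 every candidate pads by zero, so the scan ends at 256; for K = 96 the largest zero-padding candidate is 32. Larger blocks store fewer scales, while smaller blocks track the weight distribution more closely.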
14 changes: 12 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -347,6 +347,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       ParseStringWithClassicLocale<int64_t>(
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
                                                             "4"));
+  const int64_t qdq_matmulnbits_block_size =
+      ParseStringWithClassicLocale<int64_t>(
+          session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
+                                                            "0"));
 #ifdef MLAS_TARGET_AMD64_IX86
   const bool avx2_precision_mode =
       session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -363,7 +367,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
     transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
                                                                              SatApplyContextVariant{},
                                                                              qdq_matmulnbits_accuracy_level,
-                                                                             intra_op_thread_pool));
+                                                                             intra_op_thread_pool,
+                                                                             qdq_matmulnbits_block_size));
   }
 
   transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
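GenerateTransformers (and its minimal-build counterpart below) read the new option from the session configs, so callers opt in through the standard config-entry API rather than a new function argument. A minimal usage sketch against the public C++ API, with a placeholder model path and error handling omitted:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "qdq-blocksize-demo"};
  Ort::SessionOptions so;
  // Force a 64-element block size for the DQ -> MatMulNBits conversion.
  // "0" (or omitting the entry) keeps the default of 32; "-1" picks the heuristic.
  so.AddConfigEntry("session.qdq_matmulnbits_block_size", "64");
  Ort::Session session{env, ORT_TSTR("model.onnx"), so};  // placeholder path
  return 0;
}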
@@ -504,14 +509,19 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
       ParseStringWithClassicLocale<int64_t>(
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
                                                             "4"));
+  const int64_t qdq_matmulnbits_block_size =
+      ParseStringWithClassicLocale<int64_t>(
+          session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsBlockSize,
+                                                            "0"));
   // runtime optimizations only support the CPU EP for now
   const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
 
   if (!disable_quant_qdq) {
     transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
                                                                              apply_context,
                                                                              qdq_matmulnbits_accuracy_level,
-                                                                             intra_op_thread_pool));
+                                                                             intra_op_thread_pool,
+                                                                             qdq_matmulnbits_block_size));
   }
 
   transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -86,7 +86,8 @@ struct MatMulReplaceWithQLinear : public Action {
 // used together with DQMatMulNodeGroupSelector, which does the sanity check
 struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
   DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
-                              concurrency::ThreadPool* intra_op_thread_pool);
+                              concurrency::ThreadPool* intra_op_thread_pool,
+                              int64_t block_size_for_non_blockwise = 0);
 
  private:
   std::string OpType(const RuntimeState&) const override { return op_type_; }
@@ -105,6 +106,7 @@ struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
   const std::string op_type_;
   const std::vector<NodeAndMoveInfo> value_moves_;
   concurrency::ThreadPool* intra_op_thread_pool_;
+  const int64_t block_size_for_non_blockwise_;
 };
 
 struct GemmReplaceWithQuant : public Action {
23 changes: 16 additions & 7 deletions onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -296,15 +296,19 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i
 
 void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry,
                                 int64_t qdq_matmulnbits_accuracy_level,
-                                concurrency::ThreadPool* intra_op_thread_pool) {
+                                concurrency::ThreadPool* intra_op_thread_pool,
+                                int64_t qdq_matmulnbits_block_size) {
   // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul.
   // DQ's weight is a 2/4/8-bit int (int2/uint2, int4/uint4, int8/uint8). DQ's scale is float/float16.
   // DQ is block-quantized along axis 0, with a power-of-2 block_size >= 16.
+  // Also supports per-tensor and per-channel (axis=1) quantized DQ weights by expanding
+  // scales/zero-points to the blockwise format using qdq_matmulnbits_block_size.
   const std::string action_name{"DQMatMulToMatMulNBits"};
 
   std::unique_ptr<Action> action =
       std::make_unique<QDQ::DQMatMulToMatMulNBitsAction>(qdq_matmulnbits_accuracy_level,
-                                                         intra_op_thread_pool);
+                                                         intra_op_thread_pool,
+                                                         qdq_matmulnbits_block_size);
 
 #if !defined(ORT_MINIMAL_BUILD)
   // Include "" (empty string) to match nodes not yet assigned to an EP.
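The expansion mentioned in the new comment is what lets per-tensor and per-channel weights reuse the blockwise MatMulNBits path: scales (and zero points) are broadcast so that each of the ceil(K / block_size) blocks in every output column carries its own copy. Below is a minimal sketch of that broadcast for the per-tensor case; the helper and the flat [N * blocks_per_col] layout are illustrative assumptions, not ORT's code:

#include <cstdint>
#include <vector>

// Hypothetical helper: broadcast a single per-tensor scale into a blockwise
// scale tensor holding one entry per block per output column.
std::vector<float> ExpandPerTensorScale(float scale, int64_t K, int64_t N,
                                        int64_t block_size) {
  const int64_t blocks_per_col = (K + block_size - 1) / block_size;  // ceil
  return std::vector<float>(static_cast<size_t>(N * blocks_per_col), scale);
}

Per-channel (axis=1) weights work the same way, except each of the N columns broadcasts its own scale instead of one shared value.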
@@ -315,7 +319,8 @@ void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_regi
   std::vector<const char*> providers = {kCpuExecutionProvider, kCudaExecutionProvider, kDmlExecutionProvider, ""};
   std::unique_ptr<NodeSelector> selector = std::make_unique<QDQ::DQMatMulToMatMulNBitsSelector>(providers);
   qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
-                                                         {{"MatMul", {}}},
+                                                         {{"MatMul", {}},
+                                                          {"Gemm", {}}},
                                                          std::move(selector),
                                                          std::move(action));
 
@@ -370,7 +375,8 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
 SelectorActionRegistry CreateSelectorActionRegistry(
     bool is_int8_allowed,
     int64_t qdq_matmulnbits_accuracy_level,
-    concurrency::ThreadPool* intra_op_thread_pool) {
+    concurrency::ThreadPool* intra_op_thread_pool,
+    int64_t qdq_matmulnbits_block_size) {
   SelectorActionRegistry qdq_selector_action_registry;
   SplitQDQRules(qdq_selector_action_registry);
   DropQDQNodesRules(qdq_selector_action_registry);
@@ -384,7 +390,8 @@ SelectorActionRegistry CreateSelectorActionRegistry(
   WhereQDQRules(qdq_selector_action_registry);
   DQMatMulToMatMulNBitsRules(qdq_selector_action_registry,
                              qdq_matmulnbits_accuracy_level,
-                             intra_op_thread_pool);
+                             intra_op_thread_pool,
+                             qdq_matmulnbits_block_size);
 
   return qdq_selector_action_registry;
 }
@@ -395,11 +402,13 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer(
     bool is_int8_allowed,
     const SatApplyContextVariant& apply_context,
     int64_t qdq_matmulnbits_accuracy_level,
-    concurrency::ThreadPool* intra_op_thread_pool)
+    concurrency::ThreadPool* intra_op_thread_pool,
+    int64_t qdq_matmulnbits_block_size)
     : SelectorActionTransformer{
           "QDQSelectorActionTransformer",
           CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level,
-                                       intra_op_thread_pool),
+                                       intra_op_thread_pool,
+                                       qdq_matmulnbits_block_size),
           apply_context,
           // this transformer is compatible with the CPU, DML, ACL and CUDA EPs.
           // There is further EP control at the rule level.
3 changes: 2 additions & 1 deletion onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h
@@ -29,7 +29,8 @@ class QDQSelectorActionTransformer : public SelectorActionTransformer {
   QDQSelectorActionTransformer(bool is_int8_allowed,
                                const SatApplyContextVariant& apply_context = {},
                                int64_t qdq_matmulnbits_accuracy_level = 4,
-                               concurrency::ThreadPool* intra_op_thread_pool = nullptr);
+                               concurrency::ThreadPool* intra_op_thread_pool = nullptr,
+                               int64_t qdq_matmulnbits_block_size = 0);
 };
 
 }  // namespace onnxruntime