Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 274 additions & 13 deletions onnxruntime/core/mlas/lib/q4_dq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,9 @@ struct BlockwiseQDQQuantizer {
return (val >> (idx << 1)) & 0x3;
} else if constexpr (qbits == 4) {
return (val >> (idx << 2)) & 0xF;
} else if constexpr (qbits == 8) {
(void)idx;
return val;
}
}

Expand All @@ -674,6 +677,10 @@ struct BlockwiseQDQQuantizer {
} else if constexpr (qbits == 4) {
auto shift = idx << 2;
return ((val & 0xF) << shift) | (dst & (~(0xF << shift)));
} else if constexpr (qbits == 8) {
(void)idx;
(void)dst;
return val;
}
}

Expand Down Expand Up @@ -813,21 +820,185 @@ struct BlockwiseQDQQuantizer {
src_zero_points || signed_quant || dst_zero_points,
"Unsigned quant types without zero points must allocate zero points with value 0."
);
// Must avoid multiple thread write to a single byte, which means the starting index
// of a thread block must be even. To achieve that, we need to customize the thread
// block size based on the parity of columns.
if (columns & 1) {
TransposeColumnWiseQuantizedPackUnaligned(
src_weights, src_scales, src_zero_points,
dst_weights, dst_scales, dst_zero_points,
rows, columns, quant_block_size, thread_pool

if constexpr (qbits == 8) {
// 8-bit: each element is one byte, no sub-byte packing needed.
// Simple byte-level transpose from [rows, columns] to [columns, k_blocks, block_size].
auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
auto dst_bytes_per_quant_blk = quant_block_size; // 8 bits = 1 byte per element
auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;

// Transpose weights: src [rows, columns] -> dst [columns, k_blocks, block_size]
MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
[&](ptrdiff_t thread_blk_idx) {
auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
auto col = static_cast<int32_t>(thread_blk_idx % columns);

auto src_row_start = row_blk * quant_block_size;
auto src_row_end = std::min(src_row_start + quant_block_size, rows);

auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;
for (auto r = src_row_start; r < src_row_end; ++r) {
auto src_val = src_weights[r * columns + col];
if constexpr (signed_quant) {
src_val ^= 0x80; // INT8 -> UINT8: add 128
}
dst_weights[dst_base + (r - src_row_start)] = src_val;
}
// Zero-pad remaining bytes in the last block if rows % block_size != 0
for (auto r = src_row_end - src_row_start; r < quant_block_size; ++r) {
dst_weights[dst_base + r] = signed_quant ? 0x80 : 0;
}
}
);
} else {
TransposeColumnWiseQuantizedPackAligned(
src_weights, src_scales, src_zero_points,
dst_weights, dst_scales, dst_zero_points,
rows, columns, quant_block_size, thread_pool

// Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(columns),
[&](ptrdiff_t col) {
auto src_idx = static_cast<int32_t>(col);
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
dst_scales[dst_idx] = src_scales[src_idx];
}
}
);

// Transpose zero points: src [k_blocks, columns] -> dst [columns, k_blocks]
// For 8-bit, zero points are byte-aligned (1 byte each), no packing needed.
if (src_zero_points && dst_zero_points) {
MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(columns),
[&](ptrdiff_t col) {
auto src_idx = static_cast<int32_t>(col);
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
auto zp = src_zero_points[src_idx];
if constexpr (signed_quant) {
zp ^= 0x80; // INT8 -> UINT8
}
dst_zero_points[dst_idx] = zp;
}
}
);
}
} else if constexpr (qbits == 2) {
// 2-bit: 4 elements per byte. Element-by-element transpose.
constexpr int32_t kPackSize = 4;
auto row_quant_blk_num = (rows + quant_block_size - 1) / quant_block_size;
auto packed_src_cols = (columns + kPackSize - 1) / kPackSize;
auto dst_bytes_per_quant_blk = (quant_block_size + kPackSize - 1) / kPackSize;
auto dstT_num_row = row_quant_blk_num * dst_bytes_per_quant_blk;

// Transpose weights: src [rows, ceil(columns/4)] -> dst [columns, k_blocks, ceil(block_size/4)]
// Each thread handles one (row_block, column) pair writing to non-overlapping dst ranges.
MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(row_quant_blk_num * columns),
[&](ptrdiff_t thread_blk_idx) {
auto row_blk = static_cast<int32_t>(thread_blk_idx / columns);
auto col = static_cast<int32_t>(thread_blk_idx % columns);

auto src_row_start = row_blk * quant_block_size;
auto src_row_end = std::min(src_row_start + quant_block_size, rows);

auto dst_base = col * dstT_num_row + row_blk * dst_bytes_per_quant_blk;

// Zero destination bytes for this block
for (int32_t b = 0; b < dst_bytes_per_quant_blk; ++b) {
dst_weights[dst_base + b] = 0;
}

for (auto r = src_row_start; r < src_row_end; ++r) {
// Extract 2-bit value from source
auto src_byte_idx = r * packed_src_cols + col / kPackSize;
auto src_bit_shift = (col % kPackSize) * 2;
uint8_t val = (src_weights[src_byte_idx] >> src_bit_shift) & 0x3;

if constexpr (signed_quant) {
val ^= 0x2; // int2[-2,1] -> uint2[0,3]
}

// Place in destination
auto r_in_blk = r - src_row_start;
auto dst_byte_off = r_in_blk / kPackSize;
auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
dst_weights[dst_base + dst_byte_off] |= (val << dst_bit_shift);
}

// Zero-pad remaining positions (unsigned equivalent of 0)
if constexpr (signed_quant) {
for (auto r_in_blk = src_row_end - src_row_start;
r_in_blk < quant_block_size; ++r_in_blk) {
auto dst_byte_off = r_in_blk / kPackSize;
auto dst_bit_shift = (r_in_blk % kPackSize) * 2;
dst_weights[dst_base + dst_byte_off] |= (0x2 << dst_bit_shift);
}
}
}
);

// Transpose scales: src [k_blocks, columns] -> dst [columns, k_blocks]
MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(columns),
[&](ptrdiff_t col) {
auto src_idx = static_cast<int32_t>(col);
auto dst_idx = static_cast<int32_t>(col) * row_quant_blk_num;
for (int32_t i = 0; i < row_quant_blk_num; ++i, ++dst_idx, src_idx += columns) {
dst_scales[dst_idx] = src_scales[src_idx];
}
}
);

// Transpose zero points: src [k_blocks, ceil(columns/4)] -> dst [columns, ceil(k_blocks/4)]
if (src_zero_points && dst_zero_points) {
auto packed_src_zp_cols = (columns + kPackSize - 1) / kPackSize;
auto zp_dst_bytes_per_col = (row_quant_blk_num + kPackSize - 1) / kPackSize;

MlasTryBatchParallel(
thread_pool, static_cast<ptrdiff_t>(columns),
[&](ptrdiff_t col_idx) {
auto col = static_cast<int32_t>(col_idx);
auto dst_base = col * zp_dst_bytes_per_col;

for (int32_t b = 0; b < zp_dst_bytes_per_col; ++b) {
dst_zero_points[dst_base + b] = 0;
}

for (int32_t blk = 0; blk < row_quant_blk_num; ++blk) {
auto src_byte_idx = blk * packed_src_zp_cols + col / kPackSize;
auto src_bit_shift = (col % kPackSize) * 2;
uint8_t val = (src_zero_points[src_byte_idx] >> src_bit_shift) & 0x3;

if constexpr (signed_quant) {
val ^= 0x2;
}

auto dst_byte_off = blk / kPackSize;
auto dst_bit_shift = (blk % kPackSize) * 2;
dst_zero_points[dst_base + dst_byte_off] |= (val << dst_bit_shift);
}
}
);
}
} else {
// 4-bit sub-byte types: use packing-aware transpose paths.
// Must avoid multiple thread write to a single byte, which means the starting index
// of a thread block must be even. To achieve that, we need to customize the thread
// block size based on the parity of columns.
if (columns & 1) {
TransposeColumnWiseQuantizedPackUnaligned(
src_weights, src_scales, src_zero_points,
dst_weights, dst_scales, dst_zero_points,
rows, columns, quant_block_size, thread_pool
);
} else {
TransposeColumnWiseQuantizedPackAligned(
src_weights, src_scales, src_zero_points,
dst_weights, dst_scales, dst_zero_points,
rows, columns, quant_block_size, thread_pool
);
}
}
}

Expand Down Expand Up @@ -2184,3 +2355,93 @@ MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 4, false>(
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<float, 8, true>(
const uint8_t* src_weights,
const float* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
float* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<float, 8, false>(
const uint8_t* src_weights,
const float* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
float* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, true>(
const uint8_t* src_weights,
const MLAS_FP16* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
MLAS_FP16* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 8, false>(
const uint8_t* src_weights,
const MLAS_FP16* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
MLAS_FP16* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, true>(
const uint8_t* src_weights,
const MLAS_FP16* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
MLAS_FP16* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);

template void
MlasQDQTransposeBlockwiseQuantized<MLAS_FP16, 2, false>(
const uint8_t* src_weights,
const MLAS_FP16* src_scales,
const uint8_t* src_zero_points,
uint8_t* dst_weights,
MLAS_FP16* dst_scales,
uint8_t* dst_zero_points,
bool columnwise,
int rows,
int columns,
int quant_block_size,
MLAS_THREADPOOL* thread_pool
);
Loading
Loading