Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 62 additions & 167 deletions onnxruntime/core/mlas/lib/qlutgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,45 @@ Module Name:

#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <mutex>
#include <unordered_map>

/**
 * T-MAC GEMM kernel config key - struct-based for type safety and performance.
 *
 * Identifies one cached kernel configuration. Two keys are equal only when
 * every field matches, which is what the unordered_map lookup relies on.
 *
 * NOTE(review): the packing helpers in this file call
 * MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint), so at
 * those call sites `M` carries the caller's N and `N` carries the caller's K.
 */
struct TMACConfigKey {
    size_t M;             // first GEMM dimension supplied by the caller
    size_t N;             // second GEMM dimension supplied by the caller
    size_t nbits;         // quantization bit width
    size_t block_size;    // quantization block length
    bool has_zero_point;  // whether the configuration uses zero points

    /** Field-by-field equality; required by std::unordered_map. */
    bool operator==(const TMACConfigKey& other) const {
        return M == other.M && N == other.N && nbits == other.nbits &&
               block_size == other.block_size && has_zero_point == other.has_zero_point;
    }
};

/**
 * Hash functor for TMACConfigKey, for use as the hasher of an unordered container.
 *
 * Starts from the hash of M and folds in the hashes of N, nbits, block_size and
 * has_zero_point, in that order, using a shift/xor mixing step.
 */
struct TMACConfigKeyHash {
    size_t operator()(const TMACConfigKey& k) const {
        size_t seed = std::hash<size_t>{}(k.M);

        // Folds one field hash into the running seed; the order of calls matters.
        const auto fold = [&seed](size_t field_hash) {
            seed ^= field_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        };

        fold(std::hash<size_t>{}(k.N));
        fold(std::hash<size_t>{}(k.nbits));
        fold(std::hash<size_t>{}(k.block_size));
        fold(std::hash<bool>{}(k.has_zero_point));
        return seed;
    }
};

/**
* Global cache for T-MAC kernel parameters, indexed by configuration.
* This map and its associated mutex ensure thread-safe parameter management
* across concurrent MLAS calls.
*/
static std::unordered_map<std::string, struct MlasTMACKernelParams> tmac_kernel_configs;
static std::unordered_map<TMACConfigKey, MlasTMACKernelParams, TMACConfigKeyHash> tmac_kernel_configs;
static std::mutex tmac_kernel_configs_mutex;

static std::string
Expand All @@ -47,13 +74,13 @@ GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_po
/**
 * Returns the cached T-MAC kernel parameters for the given configuration.
 *
 * The lookup uses a TMACConfigKey built from the arguments and is performed
 * while holding tmac_kernel_configs_mutex. The cached entry is returned by value.
 *
 * If no entry exists for the configuration, std::runtime_error is raised
 * through MLAS_THROW_EX.
 */
MlasTMACKernelParams
MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
    const TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
    std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
    const auto it = tmac_kernel_configs.find(key);
    if (it != tmac_kernel_configs.end()) {
        return it->second;
    }
    // The string form of the key is only needed for the diagnostic, so it is
    // built on the failure path rather than on every lookup.
    MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + GetTmacKey(M, N, nbits, block_size, has_zero_point));
}

void MLASCALL
Expand All @@ -66,7 +93,7 @@ MlasClearLutGemmKernelConfig()
void MLASCALL
MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point)
{
std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point);
TMACConfigKey key{M, N, nbits, block_size, has_zero_point};
{
std::lock_guard<std::mutex> lock(tmac_kernel_configs_mutex);
if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) {
Expand Down Expand Up @@ -186,111 +213,17 @@ LutGemmPackQuantBData(
const size_t bm = tmac_params.bm;
const size_t kfactor = tmac_params.kfactor;

assert(BlkLen % g == 0);
assert((BlkLen / g) % kfactor == 0);

const size_t mgroup = ngroups_per_elem * simd_n_in; // 32
assert(bm % mgroup == 0);
assert(bm % bits == 0);

std::unique_ptr<uint8_t[]> buf(new uint8_t[N * bits * (K / g)]);
memset(buf.get(), 0, N * bits * (K / g));

const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed
// LUT GEMM requires a valid LUT dispatch implementation, so dispatch must be available
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
if (Dispatch == nullptr || Dispatch->PackQuantBData == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackQuantBData requires LUT GEMM dispatch support");
}

MlasTrySimpleParallel(
ThreadPool, Iterations,
[&](ptrdiff_t tid) {
size_t im = static_cast<size_t>(tid);
for (size_t ik = 0; ik < K; ++ik) {
size_t idx = (im * K + ik);
size_t num_elem_per_byte = 8 / bits;
size_t elem_idx = idx % num_elem_per_byte;

uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits);

for (size_t ib = 0; ib < bits; ++ib) {
size_t new_ik = ik / g;
size_t shft_left = ik % g;
buf[im * bits * K / g + ib * K / g + new_ik] += static_cast<uint8_t>(((v >> ib) & 1) << shft_left);
}
}
}
Dispatch->PackQuantBData(
N, K, bits, g, ngroups_per_elem,
simd_n_in, simd_n_out, bm, kfactor,
QuantBDataBegin, PackedQuantBDataBegin, ThreadPool
);

// Now buf contains the bit planes grouped by g along K
// Next, we need to do a multi-reshape/transpose into the final layout

const size_t c0_fac2 = K / g;
const size_t c0_fac1 = simd_n_out * c0_fac2;
const size_t c0_fac0 = bits * c0_fac1;

const size_t c1_nb2 = K / g;
const size_t c1_nb1 = simd_n_in * c1_nb2;
const size_t c1_nb0 = ngroups_per_elem * c1_nb1;
const size_t c1_fac2 = K / g;
const size_t c1_fac1 = ngroups_per_elem * c1_fac2;
const size_t c1_fac0 = simd_n_in * c1_fac1;

const size_t c2_nb4 = kfactor;
const size_t c2_nb3 = K / g / kfactor * c2_nb4;
const size_t c2_nb2 = ngroups_per_elem * c2_nb3;
const size_t c2_nb1 = simd_n_in * c2_nb2;
const size_t c2_nb0 = bm / mgroup * c2_nb1;
const size_t c2_fac3 = simd_n_in * ngroups_per_elem;
const size_t c2_fac2 = kfactor * c2_fac3;
const size_t c2_fac1 = bm / mgroup * c2_fac2;
const size_t c2_fac0 = K / g / kfactor * c2_fac1;

const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem);
memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed?

// NOTE: The second packing loop is intentionally serialized to avoid data races.
// T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1.
// Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location.
for (size_t im = 0; im < Iterations; im++) {
for (size_t ib = 0; ib < bits; ib++) {
for (size_t ik = 0; ik < K / g; ik++) {
// w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3)
size_t new_im = im / simd_n_out;
size_t new_isno = im % simd_n_out;
size_t new_ib = ib;
size_t new_ik = ik;
size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik;

// w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3)
new_im = new_idx / c1_nb0;
size_t new_ing = (new_idx % c1_nb0) / c1_nb1;
size_t new_isni = (new_idx % c1_nb1) / c1_nb2;
new_ik = (new_idx % c1_nb2);
new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik;

// # 0 1 2 3 4 5
// w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3)
new_im = new_idx / c2_nb0;
size_t new_ibm = (new_idx % c2_nb0) / c2_nb1;
new_isni = (new_idx % c2_nb1) / c2_nb2;
new_ing = (new_idx % c2_nb2) / c2_nb3;
new_ik = (new_idx % c2_nb3) / c2_nb4;
size_t new_ikf = (new_idx % c2_nb4);
new_idx = new_im * c2_fac0 +
new_ik * c2_fac1 +
new_ibm * c2_fac2 +
new_ikf * c2_fac3 +
new_isni * ngroups_per_elem +
new_ing;
new_idx = new_idx / ngroups_per_elem;
size_t buf_idx = im * bits * K / g + ib * K / g + ik;
uint8_t buf_val = buf[buf_idx];

// w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)])
PackedQuantBDataBegin[new_idx] = static_cast<std::byte>(
static_cast<unsigned>(PackedQuantBDataBegin[new_idx]) +
(buf_val << (new_ing * g))
);
}
}
}
}

// Internal helper: calculates packed scales and zero points size in floats
Expand Down Expand Up @@ -320,67 +253,25 @@ LutPackScalesAndZeroPoints(
bool HasZeroPoint,
float* PackedQuantBZPBegin,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint
const uint8_t* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
)
{
const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
const size_t bits = tmac_params.bits;
const size_t simd_n_out = tmac_params.simd_n_out;
const size_t bm = tmac_params.bm;
const size_t num_elem_per_byte = 8 / bits;

// ZP array is column-major packed, with per-column alignment to byte boundary
const size_t row_blks = K / BlkLen; // number of blocks per column
const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte;

for (size_t im = 0; im < N; im += 1) {
for (size_t ik = 0; ik < K; ik += BlkLen) {
size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed)
float scale = QuantBScale[idx];
float zp = 0.0f;
if (HasZeroPoint) {
size_t blk_in_col = ik / BlkLen; // block index within column
size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte;
size_t elem_idx = blk_in_col % num_elem_per_byte;
uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1);

// The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit).
// Thus, need to correct for the actual ZP relative to the midpoint.

int midpoint = 1 << (bits - 1); // 2 for 2-bit
zp = static_cast<float>(static_cast<int>(v) - midpoint) * scale;
}

// TODO(vraspar): fix when k < BlkLen and nb1 is 0
size_t nb1 = K / BlkLen;
size_t nb0 = bm / bits * nb1;

size_t new_im, new_ibm, new_ik;
if (nb1 == 0) {
new_im = 0;
new_ibm = 0;
new_ik = 0;

} else {
new_im = idx / nb0;
new_ibm = (idx % nb0) / nb1;
new_ik = (idx % nb1);
}

if (HasZeroPoint) {
size_t new_isimd = new_ibm % simd_n_out;
size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out;
size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd;
size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd;

PackedQuantBZPBegin[new_idx_scale] = scale;
PackedQuantBZPBegin[new_idx_zero] = zp;
} else {
size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm;
PackedQuantBZPBegin[new_idx] = scale;
}
}
// LUT GEMM is only available for AVX2, so dispatch must be available
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
if (Dispatch == nullptr || Dispatch->PackScalesAndZeroPoints == nullptr) {
MLAS_THROW_EX(std::runtime_error, "PackScalesAndZeroPoints requires AVX2 dispatch");
}

Dispatch->PackScalesAndZeroPoints(
N, K, bits, BlkLen, simd_n_out, bm, HasZeroPoint,
PackedQuantBZPBegin, QuantBScale, QuantBZeroPoint, ThreadPool
);
}

// Internal helper: calculates the offset to scales in the packed buffer
Expand Down Expand Up @@ -440,7 +331,7 @@ MlasLutGemmPack(
if (QuantBScale != nullptr) {
size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint);
float* scales_dest = reinterpret_cast<float*>(PackedBuf + scales_offset);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint);
LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint, ThreadPool);
}
}

Expand All @@ -453,7 +344,11 @@ MlasIsLutGemmAvailable(
)
{
const auto* lut_kernel = GetMlasPlatform().LutGenKernel;
if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) {
if (lut_kernel == nullptr ||
lut_kernel->GenerateLUT == nullptr ||
lut_kernel->ComputeGemm == nullptr ||
lut_kernel->PackQuantBData == nullptr ||
lut_kernel->PackScalesAndZeroPoints == nullptr) {
Comment thread
vraspar marked this conversation as resolved.
return false;
}

Expand Down Expand Up @@ -521,7 +416,9 @@ MlasLutGemm(
// adapted from ggml_backend_tmac_mul_mat
const auto* Dispatch = GetMlasPlatform().LutGenKernel;
// This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm()
assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration.");
if (Dispatch == nullptr || Dispatch->GenerateLUT == nullptr || Dispatch->ComputeGemm == nullptr) {
MLAS_THROW_EX(std::runtime_error, "TMAC not supported in this configuration");
}

// Calculate scales offset from packed buffer
// TODO(vraspar): support other bitwidths
Expand Down Expand Up @@ -649,10 +546,8 @@ MlasLutGemm(
size_t scales_size_per_tile = 0;

if (scales_size_total % n_tiles_num != 0) {
// Sanity: scales should partition evenly across tiles. If they don't, choose floor division
// and document that callers must layout scales accordingly.
// Prefer to error loudly in debug builds.
fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num);
// Scales must partition evenly across tiles. Callers must ensure proper layout.
MLAS_THROW_EX(std::runtime_error, "scales_size_total must be divisible by n_tiles_num");
}
scales_size_per_tile = scales_size_total / n_tiles_num;

Expand Down
39 changes: 39 additions & 0 deletions onnxruntime/core/mlas/lib/qlutgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,41 @@ typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)(
bool HasZeroPoint
);

//
// Function signature for packing quantized B data
//
// An implementation is published through
// MLAS_QNBIT_LUT_GEMM_DISPATCH::PackQuantBData and is invoked by the
// quantized-B packing path in qlutgemm.cpp, which throws if the pointer is null.
//
// N, K                  - problem dimensions forwarded unchanged by the caller
//                         (presumably the dimensions of the quantized B matrix;
//                         TODO confirm).
// bits, g, ngroups_per_elem, simd_n_in, simd_n_out, bm, kfactor
//                       - layout/tiling values chosen by the caller; bm and
//                         kfactor are read from MlasTMACKernelParams, the others
//                         presumably are too (verify against caller).
// QuantBDataBegin       - input buffer holding the quantized B data.
// PackedQuantBDataBegin - output buffer receiving the packed data; assumed to be
//                         sized by the caller (TODO confirm).
// ThreadPool            - thread pool handle forwarded from the caller; how it
//                         is used is up to the implementation.
//
typedef void(MLAS_QNBIT_LUT_PACK_QUANTB_DATA)(
size_t N,
size_t K,
size_t bits,
size_t g,
size_t ngroups_per_elem,
size_t simd_n_in,
size_t simd_n_out,
size_t bm,
size_t kfactor,
const std::byte* QuantBDataBegin,
std::byte* PackedQuantBDataBegin,
MLAS_THREADPOOL* ThreadPool
);

//
// Function signature for packing scales and zero points
//
// An implementation is published through
// MLAS_QNBIT_LUT_GEMM_DISPATCH::PackScalesAndZeroPoints and is invoked by the
// scale/zero-point packing helper in qlutgemm.cpp, which throws if the pointer
// is null.
//
// N, K              - problem dimensions forwarded unchanged by the caller.
// bits, simd_n_out, bm
//                   - values the caller reads from MlasTMACKernelParams.
// BlkLen            - quantization block length forwarded by the caller.
// HasZeroPoint      - whether the configuration uses zero points.
// PackedScalesBegin - output buffer of floats receiving the packed result; the
//                     caller passes its packed scales/zero-point destination here.
// QuantBScale       - input scales.
// QuantBZeroPoint   - input zero points; presumably only meaningful when
//                     HasZeroPoint is true (TODO confirm).
// ThreadPool        - thread pool handle forwarded from the caller; how it is
//                     used is up to the implementation.
//
typedef void(MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP)(
size_t N,
size_t K,
size_t bits,
size_t BlkLen,
size_t simd_n_out,
size_t bm,
bool HasZeroPoint,
float* PackedScalesBegin,
const float* QuantBScale,
const uint8_t* QuantBZeroPoint,
MLAS_THREADPOOL* ThreadPool
);

//
// Kernel dispatch structure.
//
Expand All @@ -87,4 +122,8 @@ struct MLAS_QNBIT_LUT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr;

MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr;

MLAS_QNBIT_LUT_PACK_QUANTB_DATA* PackQuantBData = nullptr;

MLAS_QNBIT_LUT_PACK_SCALES_AND_ZP* PackScalesAndZeroPoints = nullptr;
};
Loading
Loading