diff --git a/onnxruntime/core/mlas/lib/qnbitgemm.h b/onnxruntime/core/mlas/lib/qnbitgemm.h index 06e8e49b59e2e..7ec80c6d67f15 100644 --- a/onnxruntime/core/mlas/lib/qnbitgemm.h +++ b/onnxruntime/core/mlas/lib/qnbitgemm.h @@ -53,16 +53,25 @@ struct PackedQuantBDataStruct { { const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); - if constexpr (BlkBitWidth == 8) { - PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - } else { #if defined(MLAS_TARGET_AMD64_IX86) // avx512 requires alignment on a 64-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#elif defined (MLAS_TARGET_ARM64) + // Only for 8-bit Gemms is the `PackedQuantBData` to be 32-byte aligned and + // there is enough memory allocated to support this alignment. + // See QNBitGemmPackQuantBDataSize(). + // When bit width is 4, there is no alignment guarantee. + // TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to + // simplify this logic and make code here cleaner ? 
+ if constexpr (BlkBitWidth == 8) { + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + } + else { + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + } #else PackedQuantBData = (std::byte*)PackedQuantBWorkspace; #endif - } QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); diff --git a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp index 3ed283d54f41d..1be05d88849cd 100644 --- a/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sq8bitgemm.cpp @@ -773,7 +773,8 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, HasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, true); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); auto* C = C_.GetBuffer(M * ldc, true); auto* ref = ref_.GetBuffer(M * ldc, true); @@ -825,7 +826,9 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { void ExecuteShort(void) override { Execute<1, 16, 1, 16>(); + Execute<1, 1, 1, 16>(); Execute<7, 2, 4, 16>(); + Execute<7, 128, 4, 16>(); Execute<8, 497, 5, 16>(); Execute<1, 3072, 128, 16>(); Execute<2, 3072, 128, 16>();