Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-zendnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF")
ExternalProject_Add(
zendnn
GIT_REPOSITORY https://github.com/amd/ZenDNN.git
GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17
GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19
PREFIX ${ZENDNN_PREFIX}
SOURCE_DIR ${ZENDNN_SOURCE_DIR}
BINARY_DIR ${ZENDNN_BUILD_DIR}
Expand Down
56 changes: 45 additions & 11 deletions ggml/src/ggml-zendnn/ggml-zendnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

#include "ggml-backend-impl.h"
#include "ggml-impl.h"

#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"

#include "zendnnl.hpp"

#include <cstring>
Expand All @@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() {
return zendnnl::common::data_type_t::f32;
} else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
return zendnnl::common::data_type_t::bf16;
} else if constexpr (std::is_same_v<T, block_q8_0>) {
return zendnnl::common::data_type_t::s8;
} else {
return zendnnl::common::data_type_t::none;
}
Expand Down Expand Up @@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int
params.num_threads = ctx->n_threads;

zendnnl::lowoha::matmul::matmul_batch_params_t batch_params;

if constexpr (std::is_same_v<TA, block_q8_0>) {
params.dtypes.compute = zendnnl::common::data_type_t::s8;
const int64_t num_groups = k / QK8_0;
params.dynamic_quant = true;
params.quant_params.src_scale.buff = nullptr;
params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16;
params.quant_params.src_scale.dims = {n, num_groups};
params.packing.pack_format_b = 1;
}

zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct(
'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
n, // M: rows of B and C
Expand Down Expand Up @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6
(const ggml_bf16_t *)B, ldb,
(float *)C, ldc);
return false;
case GGML_TYPE_Q8_0:
if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
return false;
return ggml_zendnn_matmul<block_q8_0, float, float>(
ctx, m, n, k,
(const block_q8_0 *)A, lda,
(const float *)B, ldb,
(float *)C, ldc);
default:
return false; // unsupported type
}
Expand Down Expand Up @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat(
const int64_t r3 = ne13/ne03;

void * work_data = ctx->work_data.get();
if (src1->type != vec_dot_type) {

// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
Comment thread
z-sachin marked this conversation as resolved.
const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
const size_t nbw2 = nbw1 * ne11;
const size_t nbw3 = nbw2 * ne12;
Expand All @@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat(

for (int64_t i13 = 0; i13 < ne13; i13++) {
for (int64_t i12 = 0; i12 < ne12; i12++) {
const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
if (!ggml_zendnn_sgemm(ctx,
ne01, // m
Expand All @@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat(
static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
ne01, // ldc
src0->type,
vec_dot_type,
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
dst->type))
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
}
Expand Down Expand Up @@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
const size_t nbw1 = row_size;
const size_t nbw2 = nbw1 * ne11;
const size_t nbw3 = nbw2 * ne12;
const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0;
const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0;

// For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes),
// not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x.
const size_t f32_row_size = (size_t)ne10 * sizeof(float);
const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size;

// size for MoE gather/scatter buffers
const size_t wdata_cur_size = max_rows * row_size;
const size_t wdata_cur_size = max_rows * gather_row_size;
const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01);

// allocate single buffer for all needs
Expand All @@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
char * wdata_cur = work_data + src1_conv_size;
char * dst_cur = wdata_cur + wdata_cur_size;

if (src1->type != vec_dot_type) {
// ZenDNN requires FP32 for dynamic quantization, so conversion is skipped
if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);

#pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
Expand All @@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
}
}

const void * wdata = src1->type == vec_dot_type ? src1->data : work_data;
const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data;

// process each expert with gather -> gemm -> scatter pattern
for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) {
Expand All @@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
const int64_t i12 = row_mapping.i2;

std::memcpy(
wdata_cur + ir1 * row_size,
(const char *) wdata + (i11 + i12*ne11) * row_size,
row_size
wdata_cur + ir1 * gather_row_size,
(const char *) wdata + (i11 + i12*ne11) * gather_row_size,
gather_row_size
);
}

Expand All @@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id(
dst_cur,
ne01, // ldc
src0->type,
vec_dot_type,
src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type,
dst->type)) {
GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
}
Expand Down Expand Up @@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const
switch (weights->type) {
case GGML_TYPE_F32:
case GGML_TYPE_BF16:
case GGML_TYPE_Q8_0:
return true;
default:
return false;
Expand Down
Loading