23 changes: 21 additions & 2 deletions ggml/src/ggml-sycl/convert.cpp
@@ -265,6 +265,17 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
#endif
}

template <typename dst_t>
static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
const int64_t nb = k / QK_K;

dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });

stream->parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
}

template <typename dst_t>
static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
dpct::queue_ptr stream) {
@@ -530,7 +541,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_sycl;
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
} else {
return dequantize_row_q6_K_sycl;
}
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ1_M:
@@ -587,7 +602,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_sycl;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_sycl;
if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
return dequantize_row_q6_K_sycl_reorder;
} else {
return dequantize_row_q6_K_sycl;
}
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_sycl;
case GGML_TYPE_IQ1_M:
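Note: both dispatch tables above select the reorder kernel only when the weight tensor carries a ggml_tensor_extra_gpu whose optimized_feature.reorder flag is set. A minimal sketch of that guard as a shared helper (hypothetical, not part of this PR):

// Hypothetical helper (not in the PR): the reorder check that
// ggml_get_to_fp16_sycl and ggml_get_to_fp32_sycl each perform inline.
static bool tensor_is_reordered(const ggml_tensor * t) {
    const auto * extra = static_cast<const ggml_tensor_extra_gpu *>(t->extra);
    return extra != nullptr && extra->optimized_feature.reorder;
}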
32 changes: 32 additions & 0 deletions ggml/src/ggml-sycl/dequantize.hpp
@@ -538,6 +538,38 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri
#endif
}

template <typename dst_t>
static void dequantize_block_q6_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> & item_ct1, int64_t n_blocks) {
const int64_t ib = item_ct1.get_group(2);

const int64_t tid = item_ct1.get_local_id(2);
const int64_t ip = tid / 32; // ip is 0 or 1
const int64_t il = tid - 32 * ip; // 0...31
const int64_t is = 8 * ip + il / 16;

const uint8_t * base_ptr = static_cast<const uint8_t *>(vx);
const auto ql_offset = ib * (QK_K / 2);
const auto qh_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * ib;
const auto base_scales_offset = (QK_K / 2) * n_blocks + (QK_K / 4) * n_blocks + (QK_K / 16) * ib;
const auto base_d_offset = ((QK_K / 2) + (QK_K / 4) + (QK_K / 16)) * n_blocks;
const uint8_t * ql_ptr = base_ptr + ql_offset;
const uint8_t * qh_ptr = base_ptr + qh_offset;
const uint8_t * scales_ptr = base_ptr + base_scales_offset;
const ggml_half * d = (const ggml_half *) (base_ptr + base_d_offset) + ib;

dst_t * y = yy + ib * QK_K + 128 * ip + il;

const uint8_t * ql = ql_ptr + 64 * ip + il;
const uint8_t qh = *(qh_ptr + 32 * ip + il);
const int8_t * sc = reinterpret_cast<const int8_t *>(scales_ptr + is);

y[0] = *d * sc[0] * ((int8_t) ((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = *d * sc[2] * ((int8_t) ((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = *d * sc[4] * ((int8_t) ((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = *d * sc[6] * ((int8_t) ((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}

template<typename dst_t>
static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1,
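Note: dequantize_block_q6_K_reorder assumes the split layout produced by reorder_qw_q6_k below: the low-bit bytes (ql, QK_K/2 per block) of every block come first, then all high-bit bytes (qh, QK_K/4 per block), then all scales (QK_K/16 per block), then one ggml_half d per block. Each block is handled by a 64-thread work-group, and each thread writes four outputs (y[0], y[32], y[64], y[96]), covering 64 * 4 = 256 = QK_K values. A minimal sketch of the per-block byte offsets, assuming QK_K == 256 from ggml-common.h:

// Byte offsets of block ib in the reordered Q6_K buffer (sketch; assumes
// QK_K == 256): [ ql (all blocks) | qh (all blocks) | scales (all blocks) | d ]
struct q6_k_reorder_offsets {
    size_t ql, qh, scales, d;
};

static q6_k_reorder_offsets q6_k_offsets(int64_t ib, int64_t n_blocks) {
    return {
        /*ql     =*/ size_t(ib) * (QK_K / 2),
        /*qh     =*/ size_t(n_blocks) * (QK_K / 2) + size_t(ib) * (QK_K / 4),
        /*scales =*/ size_t(n_blocks) * (QK_K / 2 + QK_K / 4) + size_t(ib) * (QK_K / 16),
        /*d      =*/ size_t(n_blocks) * (QK_K / 2 + QK_K / 4 + QK_K / 16) + size_t(ib) * sizeof(ggml_half),
    };
}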
52 changes: 51 additions & 1 deletion ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -354,7 +354,8 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
assert(tensor->view_src->buffer->buft == buffer->buft);
return GGML_STATUS_SUCCESS;
}
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K) && !g_ggml_sycl_disable_optimize) {
if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) &&
!g_ggml_sycl_disable_optimize) {
ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
tensor->extra = extra;
ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx.
@@ -2989,6 +2990,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
case GGML_TYPE_Q4_0:
return true;
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return !g_ggml_sycl_prioritize_dmmv;
default:
return false;
@@ -3008,6 +3010,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q6_K:
return true;
default:
return false;
@@ -3092,6 +3095,50 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d
sycl::free(tmp_buf, *stream);
}

static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) {
GGML_ASSERT(size % sizeof(block_q6_K) == 0);
GGML_ASSERT(offset % sizeof(block_q6_K) == 0);

const int nblocks = size / sizeof(block_q6_K);

auto * tmp_buf = sycl::malloc_shared<uint8_t>(size, *stream);
SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(tmp_buf, data_device, size).wait()));

auto * ql_ptr = data_device;
auto * qh_ptr = ql_ptr + (QK_K / 2) * nblocks;
auto * scales_ptr = qh_ptr + (QK_K / 4) * nblocks;
sycl::half * dm_ptr = (sycl::half *) (scales_ptr + (QK_K / 16) * nblocks);

stream
->parallel_for(nblocks,
[=](auto i) {
const block_q6_K * x = (const block_q6_K *) tmp_buf;
const int ib = i;

const uint8_t * ql = x[ib].ql;
const uint8_t * qh = x[ib].qh;
uint8_t * base_ql_ptr = ql_ptr + (QK_K / 2) * ib;
uint8_t * base_qh_ptr = qh_ptr + (QK_K / 4) * ib;
uint8_t * base_scales_ptr = scales_ptr + (QK_K / 16) * ib;

for (int j = 0; j < QK_K / 2; ++j) {
base_ql_ptr[j] = ql[j];
}
for (int j = 0; j < QK_K / 4; ++j) {
base_qh_ptr[j] = qh[j];
}

for (int j = 0; j < QK_K / 16; ++j) {
base_scales_ptr[j] = x[ib].scales[j];
}

dm_ptr[ib] = x[ib].d;
})
.wait_and_throw();

sycl::free(tmp_buf, *stream);
}

static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
uint8_t * data_device = (uint8_t *) src0->data;
size_t ncols = src0->ne[0];
@@ -3105,6 +3152,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) {
case GGML_TYPE_Q4_K:
reorder_qw_q4_k(data_device, size, 0, stream);
break;
case GGML_TYPE_Q6_K:
reorder_qw_q6_k(data_device, size, 0, stream);
break;
default:
GGML_ABORT("reorder_qw() called with unsupported type");
break;
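Note: reorder_qw_q6_k converts the array-of-structs block_q6_K data in place into the four contiguous regions described above. A host-side reference of the same transform (a sketch, not part of the PR; useful for validating the SYCL kernel against plain C++ on a copy of the buffer):

#include <cstring>
#include <vector>
#include "ggml-common.h" // block_q6_K, QK_K, ggml_half

// Plain-C++ reference of the in-place AoS -> SoA reorder (sketch).
static void reorder_q6_k_host(uint8_t * data, size_t size) {
    const size_t nblocks = size / sizeof(block_q6_K);
    std::vector<uint8_t> tmp(data, data + size); // snapshot of the AoS input
    const auto * x = (const block_q6_K *) tmp.data();

    uint8_t *   ql     = data;                      // low 4 bits of each quant
    uint8_t *   qh     = ql + (QK_K / 2) * nblocks; // high 2 bits, 4 per byte
    uint8_t *   scales = qh + (QK_K / 4) * nblocks; // per-sub-block scales
    ggml_half * d      = (ggml_half *) (scales + (QK_K / 16) * nblocks);

    for (size_t ib = 0; ib < nblocks; ++ib) {
        std::memcpy(ql     + (QK_K / 2)  * ib, x[ib].ql,     QK_K / 2);
        std::memcpy(qh     + (QK_K / 4)  * ib, x[ib].qh,     QK_K / 4);
        std::memcpy(scales + (QK_K / 16) * ib, x[ib].scales, QK_K / 16);
        d[ib] = x[ib].d;
    }
}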
36 changes: 30 additions & 6 deletions ggml/src/ggml-sycl/mmvq.cpp
@@ -31,11 +31,10 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r

float partial_sum = 0.0f;
for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) {
const int ibx = row * blocks_per_row + i; // x block index
// TODO: Generalize offsets, right now only works for quantizations that don't split high and low bits
const int bx_offset = block_type::get_block_offset(ibx);
const int d_offset = block_type::get_d_offset(nrows, ncols, ibx);
const int ibx = row * blocks_per_row + i; // x block index

const auto bx_offset = block_type::get_block_offset(ibx, nblocks);
const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx);
// Y block index that aligns with ibx
const int iby = i * block_type::block_to_q8_1_ratio();
const int8_t* q8_1_quant_ptr = (const int8_t*)vy + iby * QK8_1;
@@ -46,7 +45,7 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r
// x block quant index when casting the quants to int
const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup);

partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs, nblocks);
partial_sum += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs);
}
}

@@ -785,6 +784,24 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
}
}

static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols,
const int nrows, dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y);
constexpr size_t num_subgroups = 16;
GGML_ASSERT(block_num_y % num_subgroups == 0);

const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE);
const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE);

stream->submit([&](sycl::handler & cgh) {
cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size),
[=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>>(vx, vy, dst, ncols, nrows,
nd_item);
});
});
}
static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
const int nrows,
@@ -1070,7 +1087,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
if ((ggml_tensor_extra_gpu *) dst->src[0]->extra &&
((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) {
GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n");
reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
} else {
GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n");
mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
}
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
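Note: the hunk at the top of mul_mat_vec_q_reorder resolves the old TODO: get_block_offset and get_d_offset now return std::pair so quantizations that split low and high bits, like Q6_K, can address both regions, and nblocks is passed to get_block_offset instead of to the vec-dot call. A sketch of how the consuming side can unpack the pairs (an illustrative assumption; the actual reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> lives in the vecdotq header and may differ):

#include <utility>
#include "ggml-common.h" // ggml_half

// Illustrative unpacking of the paired offsets (assumption, not the PR's
// actual vec-dot). For Q6_K, d_offset.second is the base of the d region;
// dereferencing it here just keeps the sketch runnable.
static float q6_k_unpack_offsets_example(const void * vx,
                                         std::pair<int, int> bx_offset,
                                         std::pair<int, int> d_offset) {
    const auto [ql_off, qh_off] = bx_offset; // low-bit bytes / high-bit bytes
    const auto [sc_off, d_base] = d_offset;  // this block's scales / base of the d region

    const uint8_t * base = static_cast<const uint8_t *>(vx);
    const uint8_t *   ql = base + ql_off;
    const uint8_t *   qh = base + qh_off;
    const int8_t  *   sc = reinterpret_cast<const int8_t *>(base + sc_off);
    const ggml_half *  d = reinterpret_cast<const ggml_half *>(base + d_base);
    (void) ql; (void) qh; (void) sc; // a real dot product would combine these with the q8_1 operand
    return static_cast<float>(*d);
}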
48 changes: 38 additions & 10 deletions ggml/src/ggml-sycl/quants.hpp
@@ -14,12 +14,13 @@
#ifndef GGML_SYCL_QUANTS_HPP
#define GGML_SYCL_QUANTS_HPP

#include <utility>

#include "ggml-common.h"
#include "ggml.h"

namespace ggml_sycl_reordered {


// The reordered block moves the quants (qs) and scales (d) into two
// uniform regions of memory that are contiguous within the same tensor.
// What this means is that instead of having:
@@ -32,7 +33,6 @@ namespace ggml_sycl_reordered {

template <ggml_type type> struct block_q_t;


// qk number of weights / quants in a block
// qr number of weights in a byte (described as 'before dequantization')
// for quantization types that have low and high bits split, qr is calculated with
@@ -47,10 +47,12 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> {
static constexpr uint32_t vdr_mmvq = 2;
};

static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
return { block_index * (traits::qk / traits::qr), 0 };
}

static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
return (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half);
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 };
}

static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
@@ -64,20 +66,46 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> {
static constexpr uint32_t vdr_mmvq = 2;
};

static constexpr int get_block_offset(const int block_index) { return block_index * (traits::qk / traits::qr); }
static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) {
return { block_index * (traits::qk / traits::qr), 0 };
}

static constexpr int get_d_offset(int nrows, int ncols, const int block_index) {
static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
auto nblocks = (nrows * (ncols / traits::qk));
return (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2));
return { nblocks * (QK_K / 2),
(nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) };
}

static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }

constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; }

constexpr size_t get_dm_offset(int nblocks) { return get_total_qs_bytes(nblocks) + nblocks * K_SCALE_SIZE; }
};

template <> struct block_q_t<GGML_TYPE_Q6_K> {
struct traits {
static constexpr uint32_t qk = QK_K;
static constexpr uint32_t qi = QI6_K;
static constexpr uint32_t qr = QR6_K;
static constexpr uint32_t vdr_mmvq = 1;
};

static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) {
auto low_bits_index = block_index * (traits::qk / traits::qr);
// the high bits start after the low bits of all blocks
auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4));
return { low_bits_index, high_bits_index };
}

static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) {
[inline review thread]
Contributor: Is there a reason this function, and the one above, are marked as constexpr?
Contributor (author): No particular reason, just being consistent with the other quants implementations.

auto nblocks = (nrows * (ncols / traits::qk));
auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4);
auto block_scales = total_qs_bytes + block_index * (QK_K / 16);
auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16);
return { block_scales, sb_scale };
}

static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; }
};
} // namespace ggml_sycl_reordered

#endif // GGML_SYCL_QUANTS_HPP
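Note: one practical upside of the constexpr qualifiers questioned in the review thread above is that the offset math can be verified at compile time. A worked example (a sketch; assumes QK_K == 256, QR6_K == 2, and a 2-byte ggml_half, so one row of 512 columns gives two blocks):

#include <utility>

// Compile-time checks of the Q6_K offset helpers (sketch).
using q6 = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>;

// Block 1's low bits start at 1 * 256/2 = 128 bytes; its high bits start
// after all low bits: 2 * 128 + 1 * 64 = 320.
static_assert(q6::get_block_offset(1, 2) == std::pair<int, int>(128, 320));

// Block 1's scales: total qs bytes (2 * 128 + 2 * 64 = 384) + 1 * 16 = 400;
// the d region starts after all scales: 384 + 2 * 16 = 416.
static_assert(q6::get_d_offset(/*nrows=*/1, /*ncols=*/512, /*block_index=*/1) ==
              std::pair<int, int>(400, 416));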