Skip to content
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d122d5c
q6_k scale caching
netrunnereve Jan 4, 2025
6b06d16
16 bit unpack
netrunnereve Jan 4, 2025
21c6b80
q4_k test (slow)
netrunnereve Jan 4, 2025
b0e4ccb
revert it
netrunnereve Jan 4, 2025
07d0d58
q3_k
netrunnereve Jan 4, 2025
d70a731
q2_k
netrunnereve Jan 5, 2025
c01ccf8
little stuff
netrunnereve Jan 5, 2025
bdd98c7
try precalculating products of a and q2_k scales
netrunnereve Jan 5, 2025
1730771
Revert "try precalculating products of a and q2_k scales"
netrunnereve Jan 5, 2025
b4ae700
unpack should be u16, add vim swap to gitignore (about time)
netrunnereve Jan 6, 2025
cdf70cf
better q4_k scales
netrunnereve Jan 6, 2025
6f5d62b
q5_k
netrunnereve Jan 6, 2025
91f1d9c
better q6_k with separate paths for all threads and partial threads i…
netrunnereve Jan 8, 2025
cc28742
q2_k better dequant
netrunnereve Jan 8, 2025
fe71a8c
q3_k optimizations
netrunnereve Jan 8, 2025
923e9a8
q3_k use hmask simd from cpu avx version
netrunnereve Jan 9, 2025
c946364
Merge https://github.com/ggerganov/llama.cpp into vulkan
netrunnereve Jan 9, 2025
51b5ac5
make the caches happy
netrunnereve Jan 9, 2025
973bc40
q3_k separate out calculation
netrunnereve Jan 10, 2025
6145fc7
q2_k separate out
netrunnereve Jan 10, 2025
845d572
little stuff
netrunnereve Jan 10, 2025
d63497b
merge master
netrunnereve Jan 11, 2025
30eacad
use calc_superblock everywhere
netrunnereve Jan 11, 2025
ed1ad94
q2_k optimize scale calculation
netrunnereve Jan 12, 2025
4ae3fc0
more barriers
netrunnereve Jan 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*.metallib
*.o
*.so
*.swp
*.tmp

# IDE / OS
Expand Down
153 changes: 83 additions & 70 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,79 @@

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;

if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();

if (i >= num_blocks_per_row)
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}

const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));

vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
Expand All @@ -14,88 +87,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...16
const uint ix = tid/16;

const uint step = 8;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;

const uint v_im = itid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - step*v_im; // 0...15 or 0...7
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7

const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0;

FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}

[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
const uint y_idx = i * QUANT_K + y_offset;

[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];

uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F;
uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F;
uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F;
uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F;

uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32));
uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32));
uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32));
uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32));

uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0];
uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8];
uvec2 qs0 = uvec2(unpack8(qs0_u16));
uvec2 qs16 = uvec2(unpack8(qs16_u16));

[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);

FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);

reduce_result(temp, d_offset, first_row, num_rows, tid);
}
Expand Down
Loading
Loading