Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ggml/src/ggml-opencl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ set(GGML_OPENCL_KERNELS
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
mul_mv_q8_0_f32_flat
mul_mv_iq4_nl_f32
mul_mv_iq4_nl_f32_flat
mul_mv_mxfp4_f32
mul_mv_mxfp4_f32_flat
mul_mv_id_q4_0_f32_8x_flat
Expand All @@ -110,12 +112,15 @@ set(GGML_OPENCL_KERNELS
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_iq4_nl_f32_l4_lm
mul_mm_q4_k_f32_l4_lm
mul_mm_q5_k_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_q4_1_f32
gemm_noshuffle_q4_1_f32
gemv_noshuffle_iq4_nl_f32
gemm_noshuffle_iq4_nl_f32
gemv_noshuffle_general_q8_0_f32
gemv_noshuffle_q4_k_f32
gemm_noshuffle_q4_k_f32
Expand Down
594 changes: 594 additions & 0 deletions ggml/src/ggml-opencl/ggml-opencl.cpp

Large diffs are not rendered by default.

107 changes: 107 additions & 0 deletions ggml/src/ggml-opencl/kernels/cvt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ struct block_q6_K {
half d; // super-block scale
};

//------------------------------------------------------------------------------
// block_iq4_nl
//------------------------------------------------------------------------------
#define QK4_NL 32

struct block_iq4_nl
{
half d;
uint8_t qs[QK4_NL / 2];
};

//------------------------------------------------------------------------------
// kernel_convert_block_q4_0
// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
Expand Down Expand Up @@ -895,3 +906,99 @@ kernel void kernel_restore_block_q6_K_noshuffle(
b->scales[i] = s[i];
}
}

//------------------------------------------------------------------------------
// kernel_convert_block_iq4_nl
// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA).
//------------------------------------------------------------------------------
kernel void kernel_convert_block_iq4_nl(
global struct block_iq4_nl * src0,
global uchar * dst_q,
global half * dst_d,
uchar mask_0F,
uchar mask_F0,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);

*d = b->d;

for (int i = 0; i < QK4_NL/2; ++i) {
q[i] = b->qs[i];
}
}

kernel void kernel_restore_block_iq4_nl(
global uchar * src_q,
global half * src_d,
global struct block_iq4_nl * dst,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);

b->d = *d;

for (int i = 0; i < QK4_NL/2; ++i) {
b->qs[i] = q[i];
}
}

kernel void kernel_convert_block_iq4_nl_noshuffle(
global struct block_iq4_nl * src0,
global uchar * dst_q,
global half * dst_d,
uchar mask_0F,
uchar mask_F0,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);

*d = b->d;
for (int i = 0; i < QK4_NL/4; ++i) {
uchar x0 = b->qs[2*i + 0];
uchar x1 = b->qs[2*i + 1];

q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4);
q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0);
}
}

kernel void kernel_restore_block_iq4_nl_noshuffle(
global uchar * src_q,
global half * src_d,
global struct block_iq4_nl * dst,
uchar mask_0F,
uchar mask_F0,
ulong n_blk
) {
if (get_global_id(0) >= n_blk) {
return;
}
global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);

b->d = *d;
for (int i = 0; i < QK4_NL/4; ++i) {
uchar x0 = q[i + 0 ];
uchar x1 = q[i + QK4_NL/4];

b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4));
b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0));
}
}
150 changes: 150 additions & 0 deletions ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable

#ifdef cl_qcom_reqd_sub_group_size
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

constant half kvalues_iq4nl[16] = {
(half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f,
(half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f,
(half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f,
(half) 53.f, (half) 69.f, (half) 89.f, (half)113.f
};

// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16
constant uint iq4nl_packed[8] = {
0xD680D7F0u, // idx 0,1: -127, -104
0xD410D530u, // idx 2,3: -83, -65
0xD060D220u, // idx 4,5: -49, -35
0xC900CD80u, // idx 6,7: -22, -10
0x4A803C00u, // idx 8,9: 1, 13
0x50C04E40u, // idx 10,11: 25, 38
0x545052A0u, // idx 12,13: 53, 69
0x57105590u // idx 14,15: 89, 113
};

// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half
#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4)))

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

kernel void kernel_gemm_noshuffle_iq4_nl_f32(
global const ushort * src0_q,
global const half * src0_d,
read_only image1d_buffer_t src1,
global float * dst,
ulong offsetd,
int m,
int n,
int k,
int n_no_padding
) {
dst = (global float *)((global char *)dst + offsetd);

int m_4 = m >> 2;
int n_4 = n >> 2;

int gy = get_global_id(0);
int gx = get_global_id(1);
int gx_2 = gx << 2;

half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
half8 B;
half4 dequantized_weights;

global const ushort * weight_ptr = src0_q + gx_2;
global const half * scale_ptr = src0_d + gx_2;

for (int i = 0; i < k; i += 4) {
B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);

ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));

half4 scale = vload4(0, scale_ptr + (i/32)*(m));

// j=0
dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0;
dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1;
dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2;
dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;

// j=1
B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0;
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1;
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2;
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;

// j=2
B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0;
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1;
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2;
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;

// j=3
B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0;
dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1;
dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2;
dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3;
c0 += B * dequantized_weights.s0;
c1 += B * dequantized_weights.s1;
c2 += B * dequantized_weights.s2;
c3 += B * dequantized_weights.s3;
}

int idx = (gy<<3)*m + (gx<<2);

if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
idx += m;
}
if(idx+3 < m*n_no_padding){
vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
}
}
Loading
Loading