Commit e976127

Merge pull request #33 from zoq/vulkan_tq2_0_type
Integrate TQ2_0 into Vulkan

2 parents 9d22adc + 6d0777e
File tree: 13 files changed, +391 −13 lines
convert_hf_to_gguf.py

Lines changed: 60 additions & 3 deletions
@@ -2641,18 +2641,47 @@ def prepare_tensors(self):
         super().prepare_tensors()


-@ModelBase.register("BitnetForCausalLM")
+@ModelBase.register("BitnetForCausalLM", "BitNetForCausalLM")
 class BitnetModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BITNET

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._bitnet_weight_scales: dict[str, torch.Tensor] = {}
+
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        else:
+            self._set_vocab_gpt2()

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)

+    @staticmethod
+    def _unpack_bitnet_weights(packed: torch.Tensor) -> torch.Tensor:
+        if packed.dtype != torch.uint8:
+            raise ValueError(f"Expected packed BitNet weights to be torch.uint8, got {packed.dtype}")
+
+        values_per_item = 4
+        rows = packed.shape[0]
+        rest = packed.shape[1:]
+
+        unpacked_chunks: list[torch.Tensor] = []
+        mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=packed.device)
+
+        for i in range(values_per_item):
+            chunk = (packed >> (2 * i)) & 0x03
+            chunk = mapping[chunk.long()].reshape((rows, *rest))
+            unpacked_chunks.append(chunk)
+
+        if not unpacked_chunks:
+            raise ValueError("Failed to unpack BitNet weights: no chunks produced")
+
+        return torch.cat(unpacked_chunks, dim=0)
+
     def weight_quant(self, weight: Tensor) -> Tensor:
         dtype = weight.dtype
         weight = weight.float()
@@ -2665,8 +2694,36 @@ def weight_quant(self, weight: Tensor) -> Tensor:
         return result.type(dtype)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith(".weight_scale"):
+            weight_name = name[:-13] + ".weight"
+            mapped_weight_name = self.map_tensor_name(weight_name)
+            if isinstance(data_torch, LazyTorchTensor):
+                data_torch = LazyTorchTensor.to_eager(data_torch)
+
+            scale_tensor = data_torch.to(torch.float32)
+            self._bitnet_weight_scales[mapped_weight_name] = scale_tensor
+            return []
+
         new_name = self.map_tensor_name(name)

+        ternary_weight = False
+
+        if name.endswith(".weight"):
+            scale_tensor = self._bitnet_weight_scales.pop(new_name, None)
+            if scale_tensor is not None:
+                scale_tensor = scale_tensor.to(torch.float32)
+                if scale_tensor.numel() != 1:
+                    raise ValueError(f"Expected scalar weight_scale for '{name}', got shape {tuple(scale_tensor.shape)}")
+
+                if isinstance(data_torch, LazyTorchTensor):
+                    data_torch = LazyTorchTensor.to_eager(data_torch)
+
+                packed = data_torch.to(torch.uint8)
+                unpacked = self._unpack_bitnet_weights(packed)
+                scale_value = scale_tensor.reshape(-1)[0].item()
+                data_torch = unpacked * scale_value
+                ternary_weight = True
+
         if any(self.match_model_tensor_name(new_name, key, bid) for key in [
             gguf.MODEL_TENSOR.ATTN_Q,
             gguf.MODEL_TENSOR.ATTN_K,
@@ -2675,7 +2732,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             gguf.MODEL_TENSOR.FFN_UP,
             gguf.MODEL_TENSOR.FFN_DOWN,
             gguf.MODEL_TENSOR.FFN_GATE,
-        ]):
+        ]) and not ternary_weight:
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
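
For context, here is a minimal round-trip sketch in Python (not part of the PR) of the 2-bit packing that _unpack_bitnet_weights reverses. It assumes the layout implied by the converter's torch.cat along dim 0: each 2-bit field of a packed uint8 holds one quarter of the output rows, with codes 0 -> -1, 1 -> 0, 2 -> +1 (code 3 is unused and decodes to 0); pack_ternary is a hypothetical helper for illustration, not something the checkpoint format is guaranteed to use.

import torch

def pack_ternary(w: torch.Tensor) -> torch.Tensor:
    # w: float tensor of {-1, 0, +1} values whose first dimension is divisible by 4
    codes = (w + 1).to(torch.uint8)                      # -1/0/+1 -> 0/1/2
    rows = w.shape[0] // 4
    packed = torch.zeros((rows, *w.shape[1:]), dtype=torch.uint8)
    for i in range(4):
        packed |= codes[i * rows:(i + 1) * rows] << (2 * i)   # quarter i goes into bit field i
    return packed

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    # mirrors BitnetModel._unpack_bitnet_weights above
    mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0])
    chunks = [mapping[((packed >> (2 * i)) & 0x03).long()] for i in range(4)]
    return torch.cat(chunks, dim=0)

w = torch.randint(-1, 2, (8, 16)).float()                # toy ternary weight matrix
assert torch.equal(unpack_ternary(pack_ternary(w)), w)

The converter then multiplies the unpacked {-1, 0, +1} tensor by the scalar weight_scale and, via the ternary_weight flag, skips the usual weight_quant pass for those tensors.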

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 71 additions & 2 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 25 additions & 1 deletion
@@ -434,6 +434,30 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif

+#if defined(DATA_A_TQ2_0)
+// TQ2_0 ternary dequantization: {0,1,2} -> {-1,0,+1} via (q-1) mapping
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+    const uint c0 = (vui >> 0) & 3;
+    const uint c1 = (vui >> 2) & 3;
+    const float q0 = float(c0) - 1.0f;
+    const float q1 = float(c1) - 1.0f;
+    return vec2(q0, q1);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+    const uint c0 = (vui >> 0) & 3;
+    const uint c1 = (vui >> 2) & 3;
+    const uint c2 = (vui >> 4) & 3;
+    const uint c3 = (vui >> 6) & 3;
+    const float q0 = float(c0) - 1.0f;
+    const float q1 = float(c1) - 1.0f;
+    const float q2 = float(c2) - 1.0f;
+    const float q3 = float(c3) - 1.0f;
+    return vec4(q0, q1, q2, q3);
+}
+#endif
+
 #if defined(DATA_A_MXFP4)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -461,7 +485,7 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif

-#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), 0);
 }
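
As a worked example of the mapping above: a packed byte 0x99 (binary 10011001) carries the 2-bit codes 1, 2, 1, 2 at shifts 0, 2, 4 and 6, so dequantize4 returns (0, +1, 0, +1), which are then scaled by the block scale d that get_dm returns.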

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 20 additions & 0 deletions
@@ -654,6 +654,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 }
 #endif

+#if defined(DATA_A_TQ2_0)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 {
+    block_tq2_0 block;
+};
+
+float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const float16_t d = bl.block.d;
+    const uint idx = coordInBlock[1];
+
+    const uint byte_idx = ((idx >> 7) << 5) + (idx & 31u);
+    const uint qsshift = (((idx & 127u) >> 5) << 1);
+
+    const uint c = (uint(bl.block.qs[byte_idx]) >> qsshift) & 3u;
+    return d * float16_t(float(c) - 1.0f);
+}
+#endif
+
 #if defined(DATA_A_MXFP4)
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
     block_mxfp4 block;
@@ -715,6 +733,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
 #define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_TQ2_0)
+#define dequantFuncA dequantFuncTQ2_0
 #elif defined(DATA_A_MXFP4)
 #define dequantFuncA dequantFuncMXFP4
 #endif
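
For reference, a plain-Python sketch (not part of the PR) of the per-block decode that dequantFuncTQ2_0 implements, and that matches the CPU access pattern noted in the matrix-vector shader further below: element idx selects byte ((idx >> 7) << 5) + (idx & 31) and shift ((idx & 127) >> 5) << 1, and the 2-bit code q maps to (q - 1) * d. The function name is illustrative only.

def dequant_tq2_0_block(qs, d):
    # qs: the 64 packed bytes of one block_tq2_0; d: its float16 scale as a Python float
    out = []
    for idx in range(256):                         # QUANT_K_TQ2_0 values per block
        byte_idx = ((idx >> 7) << 5) + (idx & 31)  # same byte indexing as dequantFuncTQ2_0
        shift = ((idx & 127) >> 5) << 1            # 0, 2, 4 or 6
        q = (qs[byte_idx] >> shift) & 3            # 2-bit code in {0, 1, 2}
        out.append(d * (q - 1))                    # ternary value in {-d, 0, +d}
    return out
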
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+layout (push_constant) uniform parameter {
+    uint ne;
+} p;
+
+layout (local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint i = gl_GlobalInvocationID.x * 4;
+
+    if (i >= p.ne) {
+        return;
+    }
+
+    const uint ib = i / QUANT_K;            // block index
+    const uint iqs = (i % QUANT_K) / 4;     // quant index within block (byte index)
+    const uint bit_pos_base = (i % 4) * 2;  // bit position within byte
+
+    const float d = float(data_a[ib].d);
+
+    for (uint j = 0; j < 4 && (i + j) < p.ne; ++j) {
+        const uint local_iqs = ((i + j) % QUANT_K) / 4;  // byte index for this element
+        const uint bit_pos = ((i + j) % 4) * 2;          // bit position for this element
+        const uint vui = uint(data_a[ib].qs[local_iqs]);
+        const uint q = (vui >> bit_pos) & 3;
+        data_b[i + j] = D_TYPE(d * (float(q) - 1.0f));
+    }
+}
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+#version 450
+#extension GL_EXT_shader_explicit_arithmetic_types : require
+
+#include "mul_mat_vec_base.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+
+void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
+    uint a_offset, b_offset, d_offset;
+    get_offsets(a_offset, b_offset, d_offset);
+
+    const uint num_blocks_per_row = p.ncols / QUANT_K;
+
+    const uint tid = gl_LocalInvocationID.x;
+
+    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
+        [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
+            temp[j][i] = FLOAT_TYPE(0);
+        }
+    }
+
+    [[unroll]] for (uint i = tid; i < num_blocks_per_row; i += gl_WorkGroupSize.x) {
+
+        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
+            const uint ib0 = a_offset / QUANT_K + (first_row + n) * num_blocks_per_row;
+            const float d = float(data_a[ib0 + i].d);
+
+            [[unroll]] for (uint j = 0; j < 64; j += 32) {
+                [[unroll]] for (uint l = 0; l < 4; ++l) {
+                    [[unroll]] for (uint k = 0; k < 32; ++k) {
+                        // Extract quantized value: ((x[i].qs[j + k] >> (l*2)) & 3) - 1
+                        const uint q_byte = uint(data_a[ib0 + i].qs[j + k]);
+                        const uint shift = l * 2;
+                        const uint q = (q_byte >> shift) & 3;
+                        const FLOAT_TYPE dequant_val = FLOAT_TYPE(d * (float(q) - 1.0f)); // CPU kernel: (q-1)*d
+
+                        // y-data access pattern: y[i].qs[j*4 + l*32 + k]
+                        const uint b_idx = i * QUANT_K + j * 4 + l * 32 + k;
+                        if (b_idx < p.ncols) {
+                            [[unroll]] for (uint jcol = 0; jcol < NUM_COLS; ++jcol) {
+                                temp[jcol][n] += dequant_val * FLOAT_TYPE(data_b[jcol * p.batch_stride_b + b_offset + b_idx]);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    reduce_result(temp, d_offset, first_row, num_rows, tid);
+}
+
+void main() {
+    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
+
+    if (first_row + NUM_ROWS <= p.stride_d) {
+        compute_outputs(first_row, NUM_ROWS);
+    } else {
+        if (first_row >= p.stride_d) {
+            return;
+        }
+        compute_outputs(first_row, p.stride_d - first_row);
+    }
+}
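
For reference, a plain-Python sketch (not part of the PR) of the per-row accumulation this shader performs, following the same j/l/k loop order and the (q - 1) * d mapping noted in its comments; the function and argument names are illustrative only.

def tq2_0_row_dot(blocks, y):
    # blocks: one (qs, d) pair per 256-column block of the weight row, where qs is a
    #         64-entry sequence of packed bytes and d is the float16 scale as a float
    # y: dense activation vector with len(y) == 256 * len(blocks)
    acc = 0.0
    for i, (qs, d) in enumerate(blocks):
        for j in range(0, 64, 32):      # two 32-byte halves of qs
            for l in range(4):          # 2-bit field within each byte (shift 0, 2, 4, 6)
                for k in range(32):
                    q = (qs[j + k] >> (2 * l)) & 3                        # code in {0, 1, 2}
                    acc += d * (q - 1) * y[i * 256 + j * 4 + l * 32 + k]  # (q-1)*d times activation
    return acc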

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 16 additions & 0 deletions
@@ -450,6 +450,22 @@ void main() {
         buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
         buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
         buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
+#elif defined(DATA_A_TQ2_0)
+        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+        const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
+
+        const uint ib = idx / 128;                         // 2 values per idx (like Q2_K)
+        const uint iqs = idx % 128;                        // 0..127
+        const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // Q2_K indexing pattern
+        const uint qsshift = ((iqs % 64) / 16) * 2;        // Q2_K shift: 0,2,4,6
+
+        const float d = float(data_a[ib].d);
+
+        const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
+        const vec2 v = d * (vec2((qs >> qsshift) & 3) - 1.0f); // (q-1)*d
+
+        buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
+        buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q2_K)
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
         const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+#version 450
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+#include "dequant_funcs.comp"
+
+const uint num_threads = 256;
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void get_dst_indices(uint idx, out uint i20, out uint i21, out uint i22, out uint i23) {
+    i23 = fastdiv(idx, (p.ne22*p.ne21*p.ne20));
+    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
+    i22 = fastdiv((idx - i23_offset), (p.ne21*p.ne20));
+    const uint i22_offset = i22*p.ne21*p.ne20;
+    i21 = (idx - i23_offset - i22_offset) / p.ne20;
+    i20 = idx - i23_offset - i22_offset - i21*p.ne20;
+}
+
+void main() {
+    // num_threads * num_iter must equal 512 to match the wg_denoms and get_idx
+    const uint num_iter = 2;
+
+    const uint broadcast2 = uint(p.param2);
+    const uint broadcast3 = p.param3;
+
+    uint idx = get_idx();
+
+    [[unroll]] for (uint it = 0; it < num_iter; ++it) {
+        if (idx < p.ne) {
+            uint i0, i1, i2, i3;
+            get_dst_indices(idx, i0, i1, i2, i3);
+
+            float acc = 0.0f;
+
+            for (uint k = 0; k < p.ne01; k += 1) {
+                const uint a_block_base = get_aoffset() + (i3 / broadcast3) * p.nb03 + (i2 / broadcast2) * p.nb02 + k * p.nb01;
+                const uint ib = a_block_base + (i0 / QUANT_K);
+                const uint r = (i0 % QUANT_K);
+                const uint iqs = (r % 32u) + 32u * (r / 128u);
+                const uint sub = (r % 128u) / 32u;
+
+                const vec4 v = dequantize4(ib, iqs, 0);
+                const vec2 dm = get_dm(ib, 0);
+
+                float qv = (sub == 0u) ? v.x : (sub == 1u) ? v.y : (sub == 2u) ? v.z : v.w;
+                const float a_val = qv * dm.x + dm.y;
+
+                const uint b_idx = src1_idx(i1, k, i2, i3);
+                const float b = data_b[get_boffset() + b_idx];
+                acc += a_val * b;
+            }
+
+            uint d_idx = dst_idx(i0, i1, i2, i3);
+            data_d[get_doffset() + d_idx] = acc;
+        }
+        idx += num_threads;
+    }
+}

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 16 additions & 0 deletions
@@ -1355,6 +1355,22 @@ struct block_iq4_nl_packed16
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif

+// TQ2_0
+#define QUANT_K_TQ2_0 256
+#define QUANT_R_TQ2_0 4
+
+struct block_tq2_0
+{
+    uint8_t qs[QUANT_K_TQ2_0/QUANT_R_TQ2_0]; // 256/4 = 64 bytes
+    float16_t d;
+};
+
+#if defined(DATA_A_TQ2_0)
+#define QUANT_K QUANT_K_TQ2_0
+#define QUANT_R QUANT_R_TQ2_0
+#define A_TYPE block_tq2_0
+#endif
+
 #define QUANT_K_MXFP4 32
 #define QUANT_R_MXFP4 2
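
For scale: each block_tq2_0 packs 256 ternary values into 64 bytes of 2-bit codes plus one float16 scale, i.e. 66 bytes per 256 weights, or 66 * 8 / 256 = 2.0625 bits per weight.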
