
Commit e14d2aa

Revert "Merge pull request #33 from zoq/vulkan_tq2_0_type"
This reverts commit e976127, reversing changes made to 9d22adc.
1 parent 3b10875

File tree: 13 files changed, +13 / -391 lines


convert_hf_to_gguf.py

Lines changed: 3 additions & 60 deletions
@@ -2641,47 +2641,18 @@ def prepare_tensors(self):
         super().prepare_tensors()
 
 
-@ModelBase.register("BitnetForCausalLM", "BitNetForCausalLM")
+@ModelBase.register("BitnetForCausalLM")
 class BitnetModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BITNET
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._bitnet_weight_scales: dict[str, torch.Tensor] = {}
-
     def set_vocab(self):
-        if (self.dir_model / "tokenizer.model").is_file():
-            self._set_vocab_sentencepiece()
-        else:
-            self._set_vocab_gpt2()
+        self._set_vocab_sentencepiece()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)
 
-    @staticmethod
-    def _unpack_bitnet_weights(packed: torch.Tensor) -> torch.Tensor:
-        if packed.dtype != torch.uint8:
-            raise ValueError(f"Expected packed BitNet weights to be torch.uint8, got {packed.dtype}")
-
-        values_per_item = 4
-        rows = packed.shape[0]
-        rest = packed.shape[1:]
-
-        unpacked_chunks: list[torch.Tensor] = []
-        mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=packed.device)
-
-        for i in range(values_per_item):
-            chunk = (packed >> (2 * i)) & 0x03
-            chunk = mapping[chunk.long()].reshape((rows, *rest))
-            unpacked_chunks.append(chunk)
-
-        if not unpacked_chunks:
-            raise ValueError("Failed to unpack BitNet weights: no chunks produced")
-
-        return torch.cat(unpacked_chunks, dim=0)
-
     def weight_quant(self, weight: Tensor) -> Tensor:
         dtype = weight.dtype
         weight = weight.float()
@@ -2694,36 +2665,8 @@ def weight_quant(self, weight: Tensor) -> Tensor:
         return result.type(dtype)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if name.endswith(".weight_scale"):
-            weight_name = name[:-13] + ".weight"
-            mapped_weight_name = self.map_tensor_name(weight_name)
-            if isinstance(data_torch, LazyTorchTensor):
-                data_torch = LazyTorchTensor.to_eager(data_torch)
-
-            scale_tensor = data_torch.to(torch.float32)
-            self._bitnet_weight_scales[mapped_weight_name] = scale_tensor
-            return []
-
         new_name = self.map_tensor_name(name)
 
-        ternary_weight = False
-
-        if name.endswith(".weight"):
-            scale_tensor = self._bitnet_weight_scales.pop(new_name, None)
-            if scale_tensor is not None:
-                scale_tensor = scale_tensor.to(torch.float32)
-                if scale_tensor.numel() != 1:
-                    raise ValueError(f"Expected scalar weight_scale for '{name}', got shape {tuple(scale_tensor.shape)}")
-
-                if isinstance(data_torch, LazyTorchTensor):
-                    data_torch = LazyTorchTensor.to_eager(data_torch)
-
-                packed = data_torch.to(torch.uint8)
-                unpacked = self._unpack_bitnet_weights(packed)
-                scale_value = scale_tensor.reshape(-1)[0].item()
-                data_torch = unpacked * scale_value
-                ternary_weight = True
-
         if any(self.match_model_tensor_name(new_name, key, bid) for key in [
             gguf.MODEL_TENSOR.ATTN_Q,
             gguf.MODEL_TENSOR.ATTN_K,
@@ -2732,7 +2675,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             gguf.MODEL_TENSOR.FFN_UP,
             gguf.MODEL_TENSOR.FFN_DOWN,
             gguf.MODEL_TENSOR.FFN_GATE,
-        ]) and not ternary_weight:
+        ]):
             # transform weight into 1/0/-1 (in fp32)
             data_torch = self.weight_quant(data_torch)
 
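For context (not part of the diff): the reverted converter path decoded uint8-packed ternary weights, four 2-bit codes per byte mapped through the table [-1, 0, +1, 0] and concatenated along dim 0. A minimal standalone sketch of that round trip, assuming plain PyTorch; the names here are illustrative, not from the codebase:

import torch

def unpack_ternary(packed: torch.Tensor) -> torch.Tensor:
    # Each uint8 holds four 2-bit codes (LSB first); codes 0/1/2 mean
    # -1/0/+1 and code 3 is padding, matching the removed helper's table.
    assert packed.dtype == torch.uint8
    mapping = torch.tensor([-1.0, 0.0, 1.0, 0.0])
    chunks = [mapping[((packed >> (2 * i)) & 0x03).long()] for i in range(4)]
    return torch.cat(chunks, dim=0)  # 4x the packed row count

codes = torch.randint(0, 3, (32, 16), dtype=torch.uint8)  # ternary codes 0/1/2
quarters = codes.chunk(4, dim=0)                          # rows packed 4-to-1
packed = sum(q.long() << (2 * i) for i, q in enumerate(quarters)).to(torch.uint8)
assert torch.equal(unpack_ternary(packed), codes.float() - 1.0)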

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 71 deletions
Large diffs are not rendered by default.

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp

Lines changed: 1 addition & 25 deletions
@@ -434,30 +434,6 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
-#if defined(DATA_A_TQ2_0)
-// TQ2_0 ternary dequantization: {0,1,2} -> {-1,0,+1} via (q-1) mapping
-vec2 dequantize(uint ib, uint iqs, uint a_offset) {
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    const uint c0 = (vui >> 0) & 3;
-    const uint c1 = (vui >> 2) & 3;
-    const float q0 = float(c0) - 1.0f;
-    const float q1 = float(c1) - 1.0f;
-    return vec2(q0, q1);
-}
-vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
-    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
-    const uint c0 = (vui >> 0) & 3;
-    const uint c1 = (vui >> 2) & 3;
-    const uint c2 = (vui >> 4) & 3;
-    const uint c3 = (vui >> 6) & 3;
-    const float q0 = float(c0) - 1.0f;
-    const float q1 = float(c1) - 1.0f;
-    const float q2 = float(c2) - 1.0f;
-    const float q3 = float(c3) - 1.0f;
-    return vec4(q0, q1, q2, q3);
-}
-#endif
-
 #if defined(DATA_A_MXFP4)
 vec2 dequantize(uint ib, uint iqs, uint a_offset) {
     const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@@ -485,7 +461,7 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif
 
-#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_TQ2_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
+#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), 0);
 }
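As a plain-Python aside (not from the commit), the removed dequantize4 just pulled four 2-bit codes out of one qs byte and mapped each code q in {0, 1, 2} to q - 1:

def dequantize4_byte(vui: int) -> tuple:
    # Mirrors the removed GLSL: c_i = (vui >> 2*i) & 3, value = c_i - 1.
    return tuple(float((vui >> (2 * i)) & 3) - 1.0 for i in range(4))

# fields are LSB-first: 0b10_01_00_10 decodes as c0=2, c1=0, c2=1, c3=2
assert dequantize4_byte(0b10_01_00_10) == (1.0, -1.0, 0.0, 1.0)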

ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp

Lines changed: 0 additions & 20 deletions
@@ -654,24 +654,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 }
 #endif
 
-#if defined(DATA_A_TQ2_0)
-layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufTQ2_0 {
-    block_tq2_0 block;
-};
-
-float16_t dequantFuncTQ2_0(const in decodeBufTQ2_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
-{
-    const float16_t d = bl.block.d;
-    const uint idx = coordInBlock[1];
-
-    const uint byte_idx = ((idx >> 7) << 5) + (idx & 31u);
-    const uint qsshift = (((idx & 127u) >> 5) << 1);
-
-    const uint c = (uint(bl.block.qs[byte_idx]) >> qsshift) & 3u;
-    return d * float16_t(float(c) - 1.0f);
-}
-#endif
-
 #if defined(DATA_A_MXFP4)
 layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
     block_mxfp4 block;
@@ -733,8 +715,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
 #define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
-#elif defined(DATA_A_TQ2_0)
-#define dequantFuncA dequantFuncTQ2_0
 #elif defined(DATA_A_MXFP4)
 #define dequantFuncA dequantFuncMXFP4
 #endif
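For reference (not part of the diff): the removed dequantFuncTQ2_0 located element idx of a 256-weight block with the byte_idx/qsshift arithmetic above. A quick Python check that this mapping hits each 2-bit slot exactly once:

def tq2_0_slot(idx: int) -> tuple:
    # idx bit 7 picks the 32-byte half, bits 5-6 the 2-bit field, bits 0-4 the byte
    byte_idx = ((idx >> 7) << 5) + (idx & 31)  # which of the 64 qs bytes
    qsshift = ((idx & 127) >> 5) << 1          # shift: 0, 2, 4 or 6
    return byte_idx, qsshift

slots = {tq2_0_slot(i) for i in range(256)}
assert len(slots) == 256  # 64 bytes x 4 fields, no collisions
assert all(0 <= b < 64 and s in (0, 2, 4, 6) for b, s in slots)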

ggml/src/ggml-vulkan/vulkan-shaders/dequant_tq2_0.comp

Lines changed: 0 additions & 36 deletions
This file was deleted.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_tq2_0.comp

Lines changed: 0 additions & 66 deletions
This file was deleted.

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 0 additions & 16 deletions
@@ -450,22 +450,6 @@ void main() {
         buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
         buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
         buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
-#elif defined(DATA_A_TQ2_0)
-        const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
-        const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
-
-        const uint ib = idx / 128;                         // 2 values per idx (like Q2_K)
-        const uint iqs = idx % 128;                        // 0..127
-        const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // Q2_K indexing pattern
-        const uint qsshift = ((iqs % 64) / 16) * 2;        // Q2_K shift: 0,2,4,6
-
-        const float d = float(data_a[ib].d);
-
-        const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
-        const vec2 v = d * (vec2((qs >> qsshift) & 3) - 1.0f); // (q-1)*d
-
-        buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-        buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
 #elif defined(DATA_A_Q2_K)
         const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
         const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
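Likewise an aside (not from the commit): the removed mul_mm path loaded two values per index with the Q2_K-style qsi/qsshift pattern. A small Python check that pair iqs lands on the same slots as elements 2*iqs and 2*iqs + 1 under the per-element mapping used by the removed cm2 decoder:

def elem_slot(idx: int) -> tuple:
    # per-element mapping from the removed dequantFuncTQ2_0
    return ((idx >> 7) << 5) + (idx & 31), ((idx & 127) >> 5) << 1

for iqs in range(128):                       # 128 pairs per 256-element block
    qsi = (iqs // 64) * 32 + (iqs % 16) * 2  # first byte of the pair
    qsshift = ((iqs % 64) // 16) * 2         # shared 2-bit field shift
    assert elem_slot(2 * iqs) == (qsi, qsshift)
    assert elem_slot(2 * iqs + 1) == (qsi + 1, qsshift)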

ggml/src/ggml-vulkan/vulkan-shaders/out_prod_tq2_0.comp

Lines changed: 0 additions & 58 deletions
This file was deleted.

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 0 additions & 16 deletions
@@ -1355,22 +1355,6 @@ struct block_iq4_nl_packed16
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif
 
-// TQ2_0
-#define QUANT_K_TQ2_0 256
-#define QUANT_R_TQ2_0 4
-
-struct block_tq2_0
-{
-    uint8_t qs[QUANT_K_TQ2_0/QUANT_R_TQ2_0]; // 256/4 = 64 bytes
-    float16_t d;
-};
-
-#if defined(DATA_A_TQ2_0)
-#define QUANT_K QUANT_K_TQ2_0
-#define QUANT_R QUANT_R_TQ2_0
-#define A_TYPE block_tq2_0
-#endif
-
 #define QUANT_K_MXFP4 32
 #define QUANT_R_MXFP4 2
 
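For the record (not part of the diff): block_tq2_0 stored 256 weights as 64 qs bytes plus one fp16 scale, 66 bytes per block, i.e. 66 * 8 / 256 = 2.0625 bits per weight. A minimal Python sketch of that byte layout:

import struct

QUANT_K_TQ2_0 = 256                        # weights per block
QUANT_R_TQ2_0 = 4                          # 2-bit codes per qs byte
QS_BYTES = QUANT_K_TQ2_0 // QUANT_R_TQ2_0  # 64

qs = bytes(QS_BYTES)            # 64 packed bytes (all codes 0, i.e. value -1)
d = struct.pack("<e", 0.0123)   # float16 scale; "e" is IEEE 754 half precision
block = qs + d

assert len(block) == 66         # 66 * 8 / 256 = 2.0625 bits per weight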

0 commit comments