Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,26 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor
return ggml_mem_ranges_add(ctx->mem_ranges, node);
}

// Check whether this node is a TQ3_1S/TQ4_1S matmul (MUL_MAT or MUL_MAT_ID)
// whose Metal kernel temporarily mutates src1 in place. Returns false for
// null nodes, nodes missing either source, non-TQ weights, or other ops.
static bool ggml_metal_op_mutates_tq_src1(const ggml_tensor * node) {
    if (!node || !node->src[0] || !node->src[1]) {
        return false;
    }

    // only TQ3_1S/TQ4_1S weight tensors use the rotating kernels
    if (node->src[0]->type != GGML_TYPE_TQ3_1S &&
        node->src[0]->type != GGML_TYPE_TQ4_1S) {
        return false;
    }

    // only the matmul ops dispatch those kernels
    return node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID;
}

static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
struct ggml_tensor * node = ctx->node(idx);

Expand Down Expand Up @@ -209,6 +229,15 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {

int n_fuse = 1;

// Rotated TQ weight kernels temporarily rotate src1 in-place before the
// matmul and restore it afterwards. The generic range tracker only sees a
// read dependency on src1, so sibling projections can be scheduled as
// concurrent even though they race on the shared activation buffer.
// Gemma4 GEGLU / MoE fan-out is especially sensitive to this hazard.
if (ggml_metal_op_mutates_tq_src1(node)) {
ggml_metal_op_concurrency_reset(ctx);
}

// check if the current node can run concurrently with other nodes before it
// the condition is that:
// - the current node cannot write to any previous src or dst ranges
Expand Down Expand Up @@ -2091,8 +2120,6 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q8_0 ||
op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_IQ4_NL ||
op->src[0]->type == GGML_TYPE_TQ3_1S ||
op->src[0]->type == GGML_TYPE_TQ4_1S ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||
(
Expand Down Expand Up @@ -2183,8 +2210,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || ((op->src[0]->type == GGML_TYPE_TQ3_1S || op->src[0]->type == GGML_TYPE_TQ4_1S) && ne11 > 1))) {
// TQ3_1S/TQ4_1S with ne11=1 uses specialized V2.1 fused mul_mv kernel
(ne11 > ne11_mm_min || op->src[0]->type == GGML_TYPE_TQ3_1S || op->src[0]->type == GGML_TYPE_TQ4_1S)) {
// Route all TQ weights through the rotated mul_mm path.
// Gemma4 decode still degrades on the fused mul_mv kernel even after the broader
// TQ backend fixes, while the rotated mul_mm path matches CPU behavior.

const bool is_tq_weight = (op->src[0]->type == GGML_TYPE_TQ3_1S || op->src[0]->type == GGML_TYPE_TQ4_1S);

Expand Down Expand Up @@ -2382,12 +2411,13 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {

const uint32_t r2 = 1;
const uint32_t r3 = 1;
const bool is_tq_weight = (op->src[0]->type == GGML_TYPE_TQ3_1S || op->src[0]->type == GGML_TYPE_TQ4_1S);

// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
// ne20 = n_used_experts
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
const int ne21_mm_id_min = is_tq_weight ? 1 : 32;

if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
// some Metal matrix data types require aligned pointers
Expand Down Expand Up @@ -2441,8 +2471,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);

{
const bool is_tq_weight = (op->src[0]->type == GGML_TYPE_TQ3_1S || op->src[0]->type == GGML_TYPE_TQ4_1S);

// TQ weight MoE: pre-rotate activations for rotated dispatch
if (is_tq_weight && ne00 % 32 == 0) {
const int64_t n_act = (int64_t)ne10 * ne11 * ne12 * ne13;
Expand Down
17 changes: 9 additions & 8 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -4823,15 +4823,15 @@ template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;

template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_tq3_1s, 2, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_tq3_1s, 2, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_tq3_1s, 2, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_tq3_1s, 2, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_tq3_1s, 32, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_tq3_1s, 32, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_tq3_1s, 32, dequantize_tq3_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq3_1s_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_tq3_1s, 32, dequantize_tq3_1s_t4>;

template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_tq4_1s, 2, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_tq4_1s, 2, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_tq4_1s, 2, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_tq4_1s, 2, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_tq4_1s, 32, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_tq4_1s, 32, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_tq4_1s, 32, dequantize_tq4_1s_t4>;
template [[host_name("kernel_mul_mv_ext_tq4_1s_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_tq4_1s, 32, dequantize_tq4_1s_t4>;

template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
Expand Down Expand Up @@ -11209,6 +11209,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_
template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>;
template [[host_name("kernel_mul_mm_id_map0_ne20_96")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<96>;

template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
kernel void kernel_mul_mm_id(
Expand Down
6 changes: 6 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3987,6 +3987,8 @@ class GGMLQuantizationType(IntEnum):
TQ2_0 = 35
MXFP4 = 39
NVFP4 = 40
TQ3_1S = 44
TQ4_1S = 45


class ExpertGatingFuncType(IntEnum):
Expand Down Expand Up @@ -4040,6 +4042,8 @@ class LlamaFileType(IntEnum):
MOSTLY_TQ2_0 = 37 # except 1d tensors
MOSTLY_MXFP4_MOE = 38 # except 1d tensors
MOSTLY_NVFP4 = 39 # except 1d tensors
MOSTLY_TQ3_1S = 43 # except 1d tensors
MOSTLY_TQ4_1S = 44 # except 1d tensors

GUESSED = 1024 # not specified in the model file

Expand Down Expand Up @@ -4151,6 +4155,8 @@ class VisionProjectorType:
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
GGMLQuantizationType.MXFP4: (32, 1 + 16),
GGMLQuantizationType.NVFP4: (64, 4 + 32),
GGMLQuantizationType.TQ3_1S: (32, 2 + 2 + 12),
GGMLQuantizationType.TQ4_1S: (32, 2 + 2 + 16),
}


Expand Down