// Fused Flux-style RoPE: applies rotation using precomputed PE matrix and permutes output layout.
// a: [d_head, n_head, L, N] (Q or K tensor, may be non-contiguous)
// b: [2, 2, d_head/2, L] (precomputed rotation matrix [[cos,-sin],[sin,cos]]), or NULL for permute-only
// result: [d_head, L, N*n_head] (contiguous, layout for flash attention)
//
// Details, as implemented by the CPU and Metal backends:
//   - a must be F32 with an even d_head; b (when not NULL) must be F32
//   - result[d, l, n*n_head + h] is computed from a[., h, l, n]
//   - b is indexed by (pair, l) only, so the same matrices are shared by all heads and batch entries
//   - for each pair p (indices in ggml order i0, i1, i2, i3):
//       result[2p]   = a[2p]*b[0, 0, p, l] + a[2p+1]*b[1, 0, p, l]
//       result[2p+1] = a[2p]*b[0, 1, p, l] + a[2p+1]*b[1, 1, p, l]
//   - with b == NULL the values are copied unchanged into the permuted layout
GGML_API struct ggml_tensor * ggml_rope_flux(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);
GGML_OP_IM2COL_BACK: return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + case GGML_OP_IM2COL_3D: + return src1->type == GGML_TYPE_F32 && + ((op->type == GGML_TYPE_F32 && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || + (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F16)); case GGML_OP_GET_ROWS_BACK: return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16; case GGML_OP_OUT_PROD: diff --git a/src/ggml-cpu/ops.cpp b/src/ggml-cpu/ops.cpp index 48c8964361..443a72bda1 100644 --- a/src/ggml-cpu/ops.cpp +++ b/src/ggml-cpu/ops.cpp @@ -5840,6 +5840,87 @@ void ggml_compute_forward_rope_back( } } +// ggml_compute_forward_rope_flux + +void ggml_compute_forward_rope_flux( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0 != NULL); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int64_t d_head = src0->ne[0]; + const int64_t n_head = src0->ne[1]; + const int64_t L = src0->ne[2]; + const int64_t N = src0->ne[3]; + + GGML_ASSERT(d_head > 0 && n_head > 0 && L > 0 && N > 0); + GGML_ASSERT(d_head % 2 == 0); + + if (src1 != NULL) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->ne[0] == 2); + GGML_ASSERT(src1->ne[1] == 2); + GGML_ASSERT(src1->ne[2] == d_head / 2); + GGML_ASSERT(src1->ne[3] == L); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t n_rows = L * n_head * N; + const int64_t dr = (n_rows + nth - 1) / nth; + const int64_t r0 = dr * ith; + const int64_t r1 = MIN(r0 + dr, n_rows); + + float * dst_data = (float *) dst->data; + const char * src0_data = (const char *) src0->data; + const char * src1_data = src1 ? 
(const char *) src1->data : NULL; + + for (int64_t row = r0; row < r1; ++row) { + const int64_t l = row % L; + const int64_t bh = row / L; + const int64_t h = bh % n_head; + const int64_t n = bh / n_head; + + float * dst_row = dst_data + row * d_head; + const char * src0_row = src0_data + + n * src0->nb[3] + + l * src0->nb[2] + + h * src0->nb[1]; + + if (src1_data == NULL) { + for (int64_t d = 0; d < d_head; ++d) { + dst_row[d] = *(const float *) (src0_row + d * src0->nb[0]); + } + continue; + } + + for (int64_t pair = 0; pair < d_head / 2; ++pair) { + const char * src0_pair = src0_row + (2 * pair) * src0->nb[0]; + const char * src1_pair = src1_data + + l * src1->nb[3] + + pair * src1->nb[2]; + + const float x_even = *(const float *) src0_pair; + const float x_odd = *(const float *) (src0_pair + src0->nb[0]); + + const float pe_00 = *(const float *) src1_pair; + const float pe_10 = *(const float *) (src1_pair + src1->nb[0]); + const float pe_01 = *(const float *) (src1_pair + src1->nb[1]); + const float pe_11 = *(const float *) (src1_pair + src1->nb[1] + src1->nb[0]); + + dst_row[2 * pair] = x_even * pe_00 + x_odd * pe_10; + dst_row[2 * pair + 1] = x_even * pe_01 + x_odd * pe_11; + } + } +} + // ggml_compute_forward_conv_transpose_1d static void ggml_compute_forward_conv_transpose_1d_f16_f32( diff --git a/src/ggml-cpu/ops.h b/src/ggml-cpu/ops.h index 0fdfee7976..d08d1294dc 100644 --- a/src/ggml-cpu/ops.h +++ b/src/ggml-cpu/ops.h @@ -62,6 +62,7 @@ void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, st void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_rope_flux(const struct ggml_compute_params * params, struct ggml_tensor * dst); void 
ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/src/ggml-metal/ggml-metal-device.m b/src/ggml-metal/ggml-metal-device.m index f4dd568d94..784051052e 100644 --- a/src/ggml-metal/ggml-metal-device.m +++ b/src/ggml-metal/ggml-metal-device.m @@ -7,6 +7,7 @@ #include #include +#include #ifndef TARGET_OS_VISION #define TARGET_OS_VISION 0 @@ -1041,10 +1042,26 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); case GGML_OP_ROPE: return true; + case GGML_OP_ROPE_FLUX: + if (op->src[0] == nil || op->src[0]->type != GGML_TYPE_F32 || + op->src[0]->ne[0] <= 0 || op->src[0]->ne[1] <= 0 || op->src[0]->ne[2] <= 0 || op->src[0]->ne[3] <= 0 || + op->src[0]->ne[0] % 2 != 0 || + ggml_nelements(op) > INT32_MAX) { + return false; + } + if (op->src[1] == nil) { + return true; + } + return op->src[1]->type == GGML_TYPE_F32 && + op->src[1]->ne[0] == 2 && + op->src[1]->ne[1] == 2 && + op->src[0]->ne[0] == 2 * op->src[1]->ne[2] && + op->src[0]->ne[2] == op->src[1]->ne[3]; case GGML_OP_IM2COL: return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); case GGML_OP_CONV_2D: - return ggml_is_contiguous(op->src[0]) && + return has_simdgroup_mm && + ggml_is_contiguous(op->src[0]) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); diff --git a/src/ggml-metal/ggml-metal-impl.h b/src/ggml-metal/ggml-metal-impl.h index 9de3cc758a..d222f45f98 100644 --- a/src/ggml-metal/ggml-metal-impl.h +++ b/src/ggml-metal/ggml-metal-impl.h @@ -261,6 +261,21 @@ typedef 
struct { bool src2; } ggml_metal_kargs_rope; +typedef struct { + int32_t d_head; + int32_t n_head; + int32_t L; + int32_t N; + uint64_t nb00; // x strides (may be non-contiguous) + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t pe_nb0; // pe strides + uint64_t pe_nb1; + uint64_t pe_nb2; + uint64_t pe_nb3; +} ggml_metal_kargs_rope_flux; + typedef struct { int32_t ne11; int32_t ne_12_2; // assume K and V are same shape diff --git a/src/ggml-metal/ggml-metal-ops.cpp b/src/ggml-metal/ggml-metal-ops.cpp index 11fab9c19a..5e008090a8 100644 --- a/src/ggml-metal/ggml-metal-ops.cpp +++ b/src/ggml-metal/ggml-metal-ops.cpp @@ -374,6 +374,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rope(ctx, idx); } break; + case GGML_OP_ROPE_FLUX: + { + n_fuse = ggml_metal_op_rope_flux(ctx, idx); + } break; case GGML_OP_IM2COL: { n_fuse = ggml_metal_op_im2col(ctx, idx); @@ -3284,6 +3288,74 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_rope_flux(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(op->src[0] != nullptr); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); + + const bool has_pe = op->src[1] != nullptr; + if (has_pe) { + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->ne[0] == 2); + GGML_ASSERT(op->src[1]->ne[1] == 2); + GGML_ASSERT(op->src[0]->ne[0] == 2 * op->src[1]->ne[2]); + GGML_ASSERT(op->src[0]->ne[2] == op->src[1]->ne[3]); + } + + auto to_i32_dim = [](int64_t dim) { + GGML_ASSERT(dim > 0); + GGML_ASSERT(dim <= std::numeric_limits::max()); + return (int32_t) dim; + }; + + const int32_t d_head = to_i32_dim(op->src[0]->ne[0]); + const int32_t n_head = to_i32_dim(op->src[0]->ne[1]); + const int32_t L = to_i32_dim(op->src[0]->ne[2]); + const int32_t N = to_i32_dim(op->src[0]->ne[3]); + const int64_t total = ggml_nelements(op); + 
GGML_ASSERT(total > 0); + GGML_ASSERT(total <= std::numeric_limits::max()); + + ggml_metal_kargs_rope_flux args = { + /*.d_head =*/ d_head, + /*.n_head =*/ n_head, + /*.L =*/ L, + /*.N =*/ N, + /*.nb00 =*/ op->src[0]->nb[0], + /*.nb01 =*/ op->src[0]->nb[1], + /*.nb02 =*/ op->src[0]->nb[2], + /*.nb03 =*/ op->src[0]->nb[3], + /*.pe_nb0 =*/ has_pe ? op->src[1]->nb[0] : 0, + /*.pe_nb1 =*/ has_pe ? op->src[1]->nb[1] : 0, + /*.pe_nb2 =*/ has_pe ? op->src[1]->nb[2] : 0, + /*.pe_nb3 =*/ has_pe ? op->src[1]->nb[3] : 0, + }; + + if (has_pe) { + auto pipeline = ggml_metal_library_compile_pipeline(lib, "kernel_rope_flux", "kernel_rope_flux", NULL); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + } else { + auto pipeline = ggml_metal_library_compile_pipeline(lib, "kernel_permute_cont_021", "kernel_permute_cont_021", NULL); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); + } + + const int nth = 256; + ggml_metal_encoder_dispatch_threadgroups(enc, ((int) total + nth - 1) / nth, 1, 1, nth, 1, 1); + + return 1; +} + int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3380,46 +3452,49 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { const int32_t d1 = ((const int32_t *) op->op_params)[5]; ggml_metal_kargs_conv_2d args = { - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ 
nb2, - /*.nb3 =*/ nb3, - /*.IW =*/ ne10, - /*.IH =*/ ne11, - /*.KW =*/ ne00, - /*.KH =*/ ne01, - /*.IC =*/ ne02, - /*.OC =*/ ne03, - /*.OW =*/ ne0, - /*.OH =*/ ne1, - /*.N =*/ ne3, - /*.s0 =*/ s0, - /*.s1 =*/ s1, - /*.p0 =*/ p0, - /*.p1 =*/ p1, - /*.d0 =*/ d0, - /*.d1 =*/ d1, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.IW =*/ ne10, + /*.IH =*/ ne11, + /*.KW =*/ ne00, + /*.KH =*/ ne01, + /*.IC =*/ ne02, + /*.OC =*/ ne03, + /*.OW =*/ ne0, + /*.OH =*/ ne1, + /*.N =*/ ne3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.p0 =*/ p0, + /*.p1 =*/ p1, + /*.d0 =*/ d0, + /*.d1 =*/ d1, }; auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); - int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); - nth = std::min(nth, 256); - nth = std::max(nth, 1); + const int M_TILE = 64; + const int N_TILE = 64; + const int K_TILE = 32; - const uint64_t n_out = ggml_nelements(op); + const int M = ne0 * ne1; + const int tg_x = ((int) ne03 + N_TILE - 1) / N_TILE; + const int tg_y = (M + M_TILE - 1) / M_TILE; + const int tg_z = ne3; - uint64_t tg = (n_out + nth - 1)/nth; - tg = std::max(tg, 1); - tg = std::min(tg, (uint64_t) std::numeric_limits::max()); + const size_t smem = GGML_PAD(std::max( + (size_t)(M_TILE * K_TILE + K_TILE * N_TILE) * sizeof(uint16_t), + (size_t)(M_TILE * N_TILE) * sizeof(float)), 16); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3427,7 +3502,8 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + 
ggml_metal_encoder_dispatch_threadgroups(enc, tg_x, tg_y, tg_z, 256, 1, 1); return 1; } diff --git a/src/ggml-metal/ggml-metal-ops.h b/src/ggml-metal/ggml-metal-ops.h index a261a520c1..267fcd6e64 100644 --- a/src/ggml-metal/ggml-metal-ops.h +++ b/src/ggml-metal/ggml-metal-ops.h @@ -72,6 +72,7 @@ int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_rope_flux (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index e74dde5c5b..0ddaed7ca6 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -4491,95 +4491,250 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext; //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext; +// Fused Flux RoPE: applies interleaved rotary embedding and permutes output layout in one pass. 
// Fused Flux RoPE: rotates interleaved (even, odd) pairs with a precomputed 2x2 matrix
// and writes the result in the permuted layout, one thread per output element.
// Input x:  [d_head, n_head, L, N]  (byte strides nb00..nb03, may be non-contiguous)
// Input pe: [2, 2, d_head/2, L]     (byte strides pe_nb0..pe_nb3)
// Output:   [d_head, L, N*n_head]   (contiguous)
kernel void kernel_rope_flux(
        constant ggml_metal_kargs_rope_flux & args,
        device const char * src0,
        device const char * src1,
        device      float * dst,
        uint tid [[thread_position_in_grid]]) {

    const int d_head = args.d_head;
    const int n_head = args.n_head;
    const int L      = args.L;
    const int N      = args.N;

    const bool degenerate = d_head <= 0 || n_head <= 0 || L <= 0 || N <= 0;
    if (degenerate) {
        return;
    }

    // threads past the last output element (grid is rounded up) have nothing to do
    const uint n_out = (uint)d_head * (uint)L * (uint)N * (uint)n_head;
    if (tid >= n_out) {
        return;
    }

    // tid -> (d, l, bh) with d fastest; bh = n*n_head + h
    const int d  = tid % d_head;
    const int l  = (tid / d_head) % L;
    const int bh = tid / (d_head * L);
    const int h  = bh % n_head;
    const int n  = bh / n_head;

    const int ip  = d / 2; // which (even, odd) pair
    const int sel = d % 2; // 0 -> even output of the pair, 1 -> odd output

    // both inputs of the pair live in the same source row
    device const char * x_row = src0 + n*args.nb03 + l*args.nb02 + h*args.nb01;

    const float x_even = *(device const float *)(x_row + (2*ip)  *args.nb00);
    const float x_odd  = *(device const float *)(x_row + (2*ip+1)*args.nb00);

    // the two coefficients for this output: pe[0, sel, ip, l] and pe[1, sel, ip, l]
    device const char * pe_sel = src1 + l*args.pe_nb3 + ip*args.pe_nb2 + sel*args.pe_nb1;

    const float pe_col0 = *(device const float *)(pe_sel);
    const float pe_col1 = *(device const float *)(pe_sel + args.pe_nb0);

    dst[bh * L * d_head + l * d_head + d] = x_even * pe_col0 + x_odd * pe_col1;
}
// Fused permute(0,2,1,3)+cont: swaps dims 1 and 2 while gathering into a contiguous buffer,
// one thread per element.
// Input:  [ne0, ne1, ne2, ne3]  (byte strides nb00..nb03, may be non-contiguous)
// Output: [ne0, ne2, ne1*ne3]   (contiguous)
// The rope_flux argument block is reused: d_head, n_head, L, N carry ne0..ne3.
kernel void kernel_permute_cont_021(
        constant ggml_metal_kargs_rope_flux & args,
        device const char * src0,
        device      float * dst,
        uint tid [[thread_position_in_grid]]) {

    const int ne0 = args.d_head;
    const int ne1 = args.n_head;
    const int ne2 = args.L;
    const int ne3 = args.N;

    const bool degenerate = ne0 <= 0 || ne1 <= 0 || ne2 <= 0 || ne3 <= 0;
    if (degenerate) {
        return;
    }

    // threads past the last element (grid is rounded up) have nothing to do
    const uint n_elems = (uint)ne0 * (uint)ne1 * (uint)ne2 * (uint)ne3;
    if (tid >= n_elems) {
        return;
    }

    // tid -> (i0, i2, i13) with i0 fastest; i13 = i3*ne1 + i1
    const int i0  = tid % ne0;
    const int i2  = (tid / ne0) % ne2;
    const int i13 = tid / (ne0 * ne2);
    const int i1  = i13 % ne1;
    const int i3  = i13 / ne1;

    device const char * src_elem = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00;

    dst[i13 * ne2 * ne0 + i2 * ne0 + i0] = *(device const float *) src_elem;
}
// Tile sizes of the implicit-GEMM conv2d kernel below.
// The host side must launch one threadgroup per (N tile, M tile, batch) with matching values.
#define CONV2D_GEMM_M 64
#define CONV2D_GEMM_N 64
#define CONV2D_GEMM_K 32

// conv2d as an implicit GEMM on simdgroup matrices:
//
//   C[M,N] = A[M,K] * B[K,N]   with M = OH*OW, N = OC, K = IC*KH*KW
//
// A is never materialized: its elements are fetched from src (read as float) with the
// im2col index computed on the fly. B is the weight tensor (element type TK).
//
// Geometry as implemented here:
//   - one threadgroup computes a 64 (M) x 64 (N) output tile for one batch entry
//     (tgpig.x = N tile, tgpig.y = M tile, tgpig.z = batch)
//   - the code assumes 256 threads = 8 simdgroups of 32; simdgroup s owns output rows
//     [8*s, 8*s + 8) of the tile, held in 8 float 8x8 accumulators (one per 8 columns)
//   - K is consumed in steps of 32; A and B tiles are staged in threadgroup memory as half,
//     accumulation is in float
//   - shared_mem must hold max((64*32 + 32*64) halves, 64*64 floats): the float
//     output staging area reuses the memory of the A/B tiles
//
// Weights are addressed as oc*nb03 + k*nb00, i.e. the KW*KH*IC elements of one filter are
// assumed to be densely packed (the backend's supports_op check requires contiguous weights).
template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        threadgroup  char * shared_mem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint    tiisg[[thread_index_in_simdgroup]],
        uint    sgitg[[simdgroup_index_in_threadgroup]]) {

    const int M   = args.OH * args.OW;
    const int N   = args.OC;
    const int KHW = args.KH * args.KW;
    const int K   = args.IC * KHW;

    const int n_start = (int)tgpig.x * CONV2D_GEMM_N;
    const int m_start = (int)tgpig.y * CONV2D_GEMM_M;
    const int batch   = (int)tgpig.z;

    // uniform across the threadgroup, so returning before the barriers below is safe
    if (m_start >= M || n_start >= N) return;

    const uint64_t src_base = (uint64_t)batch * args.nb13;
    const int lid = sgitg * 32 + tiisg; // linear thread index in the threadgroup

    // Each of 8 simdgroups owns one 8-row strip across all 64 N columns (8 sub-tiles).
    const int sg_m = sgitg * 8;

    simdgroup_float8x8 C[8];
    for (int i = 0; i < 8; i++) {
        C[i] = make_filled_simdgroup_matrix<float, 8>(0.0f);
    }

    threadgroup half * sa = (threadgroup half *)shared_mem;       // A tile: [64][32]
    threadgroup half * sb = sa + CONV2D_GEMM_M * CONV2D_GEMM_K;   // B tile: [32][64]

    // A-tile loading: 256 threads cover 64 rows x 32 cols = 2048 elements (8 per thread).
    // Assign threads to rows so (oh,ow) is computed once and reused across k elements.
    // 256 threads / 64 rows = 4 threads per row, each handles 8 consecutive k elements.
    const int a_row    = lid / 4;        // 0..63
    const int a_k_base = (lid % 4) * 8;  // 0, 8, 16, 24
    const int a_m      = m_start + a_row;
    const int a_oh     = a_m < M ? (a_m / args.OW) : 0;
    const int a_ow     = a_m < M ? (a_m - a_oh * args.OW) : 0;
    const int a_by     = a_oh * args.s1 - args.p1; // input y of kernel tap ky = 0
    const int a_bx     = a_ow * args.s0 - args.p0; // input x of kernel tap kx = 0

    const bool fast_1x1 = KHW == 1 && args.s0 == 1 && args.s1 == 1 && args.p0 == 0 && args.p1 == 0;

    // Precompute src offset for the strict 1x1/no-padding/no-stride fast path.
    const uint64_t a_src_spatial = src_base
        + (uint64_t)a_oh * args.nb11
        + (uint64_t)a_ow * args.nb10;

    for (int k_start = 0; k_start < K; k_start += CONV2D_GEMM_K) {
        // --- Load A tile: implicit im2col ---
        if (fast_1x1) {
            // Fast path for 1x1 convolutions where output coordinates map directly to input coordinates.
            for (int dk = 0; dk < 8; ++dk) {
                const int ic = k_start + a_k_base + dk;
                half val = 0;
                if (a_m < M && ic < args.IC) {
                    val = (half)(*(device const float *)(src + a_src_spatial
                        + (uint64_t)ic * args.nb12));
                }
                sa[a_row * CONV2D_GEMM_K + a_k_base + dk] = val;
            }
        } else {
            // General path with incremental (ic, ky, kx) decomposition
            const int k0 = k_start + a_k_base;
            int ic  = k0 / KHW;
            int rem = k0 - ic * KHW;
            int ky  = rem / args.KW;
            int kx  = rem - ky * args.KW;

            for (int dk = 0; dk < 8; ++dk) {
                const int k = k0 + dk;
                half val = 0; // out-of-range rows/k and padded taps contribute zero

                if (a_m < M && k < K) {
                    const int iy = a_by + ky * args.d1;
                    const int ix = a_bx + kx * args.d0;

                    if (iy >= 0 && iy < args.IH && ix >= 0 && ix < args.IW) {
                        val = (half)(*(device const float *)(src + src_base
                            + (uint64_t)ic * args.nb12
                            + (uint64_t)iy * args.nb11
                            + (uint64_t)ix * args.nb10));
                    }
                }
                sa[a_row * CONV2D_GEMM_K + a_k_base + dk] = val;

                // advance (kx, ky, ic) to the next k without dividing again
                if (++kx >= args.KW) {
                    kx = 0;
                    if (++ky >= args.KH) {
                        ky = 0;
                        ++ic;
                    }
                }
            }
        }

        // --- Load B tile: weights are contiguous, no index decomposition needed ---
        for (int i = lid; i < CONV2D_GEMM_K * CONV2D_GEMM_N; i += 256) {
            const int kl = i / CONV2D_GEMM_N;
            const int nl = i % CONV2D_GEMM_N;
            const int k  = k_start + kl;
            const int oc = n_start + nl;

            half val = 0;
            if (k < K && oc < N) {
                val = (half)(*(device const TK *)(weights
                    + (uint64_t)oc * args.nb03
                    + (uint64_t)k  * args.nb00));
            }
            sb[i] = val;
        }

        // tiles must be fully written before any simdgroup reads them
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // --- GEMM: each simdgroup loads A once, multiplies with 8 B sub-tiles ---
        for (int kk = 0; kk < CONV2D_GEMM_K; kk += 8) {
            simdgroup_half8x8 A;
            simdgroup_load(A, sa + sg_m * CONV2D_GEMM_K + kk, CONV2D_GEMM_K);

            for (int ni = 0; ni < 8; ni++) {
                simdgroup_half8x8 B;
                simdgroup_load(B, sb + kk * CONV2D_GEMM_N + ni * 8, CONV2D_GEMM_N);
                simdgroup_multiply_accumulate(C[ni], A, B, C[ni]);
            }
        }

        // all reads of sa/sb must finish before the next iteration (or the output
        // staging below, which aliases the same memory) overwrites them
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Store accumulators and write to global ---
    // Each simdgroup writes to non-overlapping rows [sg_m..sg_m+7].
    threadgroup float * so = (threadgroup float *)shared_mem;

    for (int ni = 0; ni < 8; ni++) {
        simdgroup_store(C[ni], so + sg_m * CONV2D_GEMM_N + ni * 8, CONV2D_GEMM_N);
    }
    // a simdgroup only reads back the rows it stored itself, so a simdgroup-level barrier suffices
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Each simdgroup's 32 threads write their own 8x64 block (16 elements per thread)
    for (int i = tiisg; i < 8 * CONV2D_GEMM_N; i += 32) {
        const int ml = sg_m + i / CONV2D_GEMM_N;
        const int nl = i % CONV2D_GEMM_N;
        const int m  = m_start + ml;
        const int oc = n_start + nl;

        if (m < M && oc < N) {
            const int oh = m / args.OW;
            const int ow = m - oh * args.OW;

            *(device float *)(dst
                + (uint64_t)batch * args.nb3
                + (uint64_t)oc    * args.nb2
                + (uint64_t)oh    * args.nb1
                + (uint64_t)ow    * args.nb0) = so[ml * CONV2D_GEMM_N + nl];
        }
    }
}
if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) (k + ic*args.nb11); - threadgroup const q_t * pq = sq; - threadgroup s_t * ps = ss; - - pk += sgitg*(8*NS10); - ps += sgitg*(8*1); - static_assert((C/8) % NSG == 0, ""); - constexpr short NC = (C/8)/NSG; - FOR_UNROLL (short cc = 0; cc < NC; ++cc) { - qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); + for (short qb = 0; qb < Q; qb += 8) { + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); + threadgroup const q_t * pq = sq + qb*DK; + threadgroup s_t * ps = ss + qb*SH; - if (DK % 16 != 0) { - k8x8_t mk; - q8x8_t mq; + pk += sgitg*(8*NS10); + ps += sgitg*(8*1); - FOR_UNROLL (short i = 0; i < DK8; ++i) { - simdgroup_barrier(mem_flags::mem_none); + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { + qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - simdgroup_load(mk, pk + 8*i, NS10, 0, true); - simdgroup_load(mq, pq + 8*i, DK); + if (DK % 16 != 0) { + k8x8_t mk; + q8x8_t mq; - simdgroup_barrier(mem_flags::mem_none); + FOR_UNROLL (short i = 0; i < DK8; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - } - } else { - k8x8_t mk[2]; - q8x8_t mq[2]; + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); + + simdgroup_barrier(mem_flags::mem_none); + + simdgroup_multiply_accumulate(mqk, mq, mk, mqk); + } + } else { + k8x8_t mk[2]; + q8x8_t mq[2]; - // note: too much unroll can tank the performance for large heads - #pragma unroll (MIN(DK8/2, 4*NSG)) - for (short i = 0; i < DK8/2; ++i) { - simdgroup_barrier(mem_flags::mem_none); + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { + simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); - simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mk[0], pk + 0*8 + 16*i, 
NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); - simdgroup_barrier(mem_flags::mem_none); + simdgroup_barrier(mem_flags::mem_none); - simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); - simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); + simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); + simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); + } } - } - simdgroup_store(mqk, ps, SH, 0, false); + simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10); - ps += 8*(NSG); + pk += 8*(NSG*NS10); + ps += 8*(NSG); + } } } else { // TODO: this is the quantized K cache branch - not optimized yet @@ -5715,77 +5870,79 @@ void kernel_flash_attn_ext_impl( constexpr short NO = PV8/NSG; - o8x8_t lo[NO]; + for (short qb = 0; qb < Q; qb += 8) { + o8x8_t lo[NO]; - { - auto sot = so + 8*sgitg; + { + auto sot = so + qb*PV + 8*sgitg; - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - simdgroup_load(lo[ii], sot, PV, 0, false); + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_load(lo[ii], sot, PV, 0, false); - sot += 8*NSG; + sot += 8*NSG; + } } - } - { - device const v_t * pv = (device const v_t *) (v + ic*args.nb21); + { + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); - pv += 8*sgitg; + pv += 8*sgitg; - if (DV <= 64) { - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, ss + 8*cc, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + qb*SH + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { - v8x8_t mv[2]; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); - simdgroup_load(mv[1], pv + 8*NSG + 
16*ii*NSG, NS20, 0, false); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } - simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); - simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + pv += 8*NS20; } + } else { + constexpr short NC = (C/8)/2; - pv += 8*NS20; - } - } else { - constexpr short NC = (C/8)/2; + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { + s8x8_t vs[2]; - FOR_UNROLL (short cc = 0; cc < NC; ++cc) { - s8x8_t vs[2]; + simdgroup_load(vs[0], ss + qb*SH + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + qb*SH + 16*cc + 8, SH, 0, false); - simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); - simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; - FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { - v8x8_t mv[4]; + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); - simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); - simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); - simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); - simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } - simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); - simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); - 
simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); - simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + pv += 2*8*NS20; } - - pv += 2*8*NS20; } } - } - { - auto sot = so + 8*sgitg; + { + auto sot = so + qb*PV + 8*sgitg; - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - simdgroup_store(lo[ii], sot, PV, 0, false); + FOR_UNROLL (short ii = 0; ii < NO; ++ii) { + simdgroup_store(lo[ii], sot, PV, 0, false); - sot += 8*NSG; - } + sot += 8*NSG; + } } + } // qb loop } else { // TODO: this is the quantized V cache branch - not optimized yet @@ -8864,17 +9021,14 @@ kernel void kernel_mul_mm( for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { #ifndef GGML_METAL_HAS_TENSOR - // load data and store to threadgroup memory + // load A-tile to threadgroup memory (same as before, 256 threads cover 64 rows with redundancy) if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); - // no need for dequantization for (short i = 0; i < 16; i++) { const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; - //const short lx = i%8; - //const short ly = (tiitg/NL0)%8; const short lx = (tiitg/NL0)%8; const short ly = i%8; @@ -8892,16 +9046,11 @@ kernel void kernel_mul_mm( const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; - //const short lx = i%8; - //const short ly = (tiitg/NL0)%8; const short lx = (tiitg/NL0)%8; const short ly = i%8; const short ib = 8*sx + sy; - // NOTE: this is massively slower.. WTF? 
- //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4]; - *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; } } @@ -8934,19 +9083,16 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } #else - // load data and store to threadgroup memory + // load A-tile if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); - // no need for dequantization for (short i = 0; i < 16; i++) { const short sx = 2*il0 + i/8; const short sy = (tiitg/NL0)/8; const short lx = i%8; const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; } @@ -8962,13 +9108,12 @@ kernel void kernel_mul_mm( const short lx = i%8; const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; } } + // load B-tile if (FC_mul_mm_bc_inp) { for (short i = 0; i < 8; ++i) { const short sx = (tiitg%NL1); @@ -8976,8 +9121,6 @@ kernel void kernel_mul_mm( const short lx = i; const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? 
(S1) *((device T1 *) y + i) : 0; } @@ -8985,10 +9128,7 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - //const short lx = i; const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); } @@ -9002,7 +9142,7 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); #ifndef GGML_METAL_HAS_TENSOR - // load matrices from threadgroup memory and conduct outer products + // 8 simdgroups in 2×4 layout: sgitg%2 → M (32 rows), sgitg/2 → N (16 cols) threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -9037,7 +9177,6 @@ kernel void kernel_mul_mm( } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { - // if no bounds checks on the output are needed, we can directly write to device memory #ifdef GGML_METAL_HAS_TENSOR device float * C = (device float *) dst + r0 + \ @@ -9055,7 +9194,6 @@ kernel void kernel_mul_mm( } #endif } else { - // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; @@ -9076,7 +9214,7 @@ kernel void kernel_mul_mm( device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; - threadgroup float * C = temp_str + (j*NR0); + threadgroup float * C = sc + (j*NR0); threadgroup float4 * C4 = (threadgroup float4 *) C; int i = 0; diff --git a/src/ggml-vulkan/ggml-vulkan.cpp b/src/ggml-vulkan/ggml-vulkan.cpp index 3925e5e6fc..44d019adb5 100644 --- a/src/ggml-vulkan/ggml-vulkan.cpp +++ b/src/ggml-vulkan/ggml-vulkan.cpp @@ -14912,7 +14912,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm && (op->type == GGML_TYPE_F32 || op->type 
== GGML_TYPE_F16); case GGML_OP_IM2COL_3D: return op->src[1]->type == GGML_TYPE_F32 - && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + && ((op->type == GGML_TYPE_F32 && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)) || + (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F16)); case GGML_OP_TIMESTEP_EMBEDDING: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D_DW: diff --git a/src/ggml.c b/src/ggml.c index 1725ad1654..4b3c4801d1 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -1045,9 +1045,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "OPT_STEP_SGD", "GLU", + "ROPE_FLUX", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1154,9 +1155,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "sgd(x)", "glu(x)", + "rope_flux(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4258,6 +4260,42 @@ static float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, fl return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); } +struct ggml_tensor * ggml_rope_flux( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b) { + // a: [d_head, n_head, L, N] + // b: [2, 2, d_head/2, L] (precomputed PE), or NULL for permute-only (no rotation) + // result: [d_head, L, N*n_head] + GGML_ASSERT(a != NULL); + GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->ne[0] > 0 && a->ne[1] > 0 && a->ne[2] > 0 && a->ne[3] > 0); + GGML_ASSERT(a->ne[0] % 2 == 0); + if (b != NULL) { + GGML_ASSERT(b->type == GGML_TYPE_F32); + GGML_ASSERT(b->ne[0] == 2); + GGML_ASSERT(b->ne[1] == 2); + GGML_ASSERT(a->ne[0] == 2 * b->ne[2]); // d_head == 2 * (d_head/2) + GGML_ASSERT(a->ne[2] == 
b->ne[3]); // L matches + } + + const int64_t d_head = a->ne[0]; + const int64_t n_head = a->ne[1]; + const int64_t L = a->ne[2]; + const int64_t N = a->ne[3]; + + GGML_ASSERT(n_head <= INT64_MAX / N); + + const int64_t ne[4] = { d_head, L, N * n_head, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne); + + result->op = GGML_OP_ROPE_FLUX; + result->src[0] = a; + result->src[1] = b; + + return result; +} + void ggml_rope_yarn_corr_dims( int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] ) { @@ -6612,6 +6650,10 @@ static void ggml_compute_backward( } GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); } break; + case GGML_OP_ROPE_FLUX: { + GGML_ASSERT(!src0_needs_grads && "backward pass for rope_flux not implemented"); + GGML_ASSERT((!src1 || !src1_needs_grads) && "gradients for rope_flux positional encoding not implemented"); + } break; case GGML_OP_IM2COL: { if (src1_needs_grads) { const int32_t s0 = ggml_get_op_params_i32(tensor, 0); @@ -6885,6 +6927,7 @@ void ggml_build_backward_expand( case GGML_OP_GET_ROWS: // row indices not differentiable case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS case GGML_OP_ROPE: // positions not differentiable + case GGML_OP_ROPE_FLUX: // positional encoding not differentiable ignore_src[1] = true; break; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7c28e344c5..4b497ebc4b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -336,6 +336,24 @@ if (NOT GGML_BACKEND_DL) add_test(NAME ${TEST_TARGET} COMMAND $) set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + # + # test-conv2d-direct + + set(TEST_TARGET test-conv2d-direct) + add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) + target_link_libraries(${TEST_TARGET} PRIVATE ggml) + add_test(NAME ${TEST_TARGET} COMMAND $) + set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT 
"LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + + # + # test-rope-flux + + set(TEST_TARGET test-rope-flux) + add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp) + target_link_libraries(${TEST_TARGET} PRIVATE ggml) + add_test(NAME ${TEST_TARGET} COMMAND $) + set_property(TEST ${TEST_TARGET} PROPERTY ENVIRONMENT "LLVM_PROFILE_FILE=${TEST_TARGET}.profraw") + # # test-cont diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3b31f1fb69..eb7b2ba166 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4640,6 +4640,44 @@ struct test_rope : public test_case { } }; +// GGML_OP_ROPE_FLUX +struct test_rope_flux : public test_case { + const std::array ne_a; + const bool use_pe; + + std::string vars() override { + return VARS_TO_STR2(ne_a, use_pe); + } + + test_rope_flux(std::array ne_a = {16, 4, 32, 1}, bool use_pe = true) + : ne_a(ne_a), use_pe(use_pe) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_a.data()); + if (mode != MODE_GRAD) { + // rope_flux is inference-only today; keep support/test coverage without + // requesting an unsupported backward pass in grad mode. 
+ ggml_set_param(a); + } + ggml_set_name(a, "a"); + + ggml_tensor * pe = nullptr; + if (use_pe) { + pe = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 2, 2, ne_a[0] / 2, ne_a[2]); + ggml_set_name(pe, "pe"); + } + + ggml_tensor * out = ggml_rope_flux(ctx, a, pe); + ggml_set_name(out, "out"); + + return out; + } + + double max_nmse_err() override { + return 1e-6; + } +}; + // GGML_OP_POOL2D struct test_pool2d : public test_case { enum ggml_op_pool pool_type; @@ -7083,6 +7121,12 @@ static std::vector> make_test_cases_eval() { } } + for (bool use_pe : {true, false}) { + test_cases.emplace_back(new test_rope_flux({16, 4, 32, 1}, use_pe)); + test_cases.emplace_back(new test_rope_flux({64, 8, 128, 1}, use_pe)); + test_cases.emplace_back(new test_rope_flux({128, 24, 256, 2}, use_pe)); + } + for (ggml_type type_input : {GGML_TYPE_F32}) { for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { for (int k0 : {1, 3}) { @@ -7172,10 +7216,6 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)); test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32)); test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16)); - // negative case: F16 input must be rejected by backends that require F32 input (e.g. 
Metal) - test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F16, GGML_TYPE_F32, GGML_TYPE_F32)); - // negative case: F16 dst with F32 kernel must be rejected; CPU F16 path asserts src0->type==F16 - test_cases.emplace_back(new test_im2col_3d(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F16)); for (int s0 : {1, 3}) { for (int s1 : {1, 3}) { for (int s2 : {1, 3}) { diff --git a/tests/test-conv2d-direct.cpp b/tests/test-conv2d-direct.cpp new file mode 100644 index 0000000000..d68d23ce3e --- /dev/null +++ b/tests/test-conv2d-direct.cpp @@ -0,0 +1,265 @@ +// Test that ggml_conv_2d_direct (GGML_OP_CONV_2D, implicit GEMM kernel) +// produces the same results as ggml_conv_2d (im2col + matmul) for +// configurations representative of Stable Diffusion U-Net and VAE layers. +// +// The implicit GEMM kernel uses half-precision intermediates with float +// accumulators, so we allow a small per-element tolerance. + +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void *) { + (void) level; + fputs(text, stderr); + fflush(stderr); +} + +struct conv2d_test_case { + const char * name; + int KW, KH, IC, OC; + int IW, IH, N; + int s0, s1, p0, p1, d0, d1; + ggml_type weight_type = GGML_TYPE_F16; +}; + +static bool run_test(const conv2d_test_case & tc) { + printf(" %-40s ", tc.name); + fflush(stdout); + + const int OW = (tc.IW + 2*tc.p0 - tc.d0*(tc.KW - 1) - 1) / tc.s0 + 1; + const int OH = (tc.IH + 2*tc.p1 - tc.d1*(tc.KH - 1) - 1) / tc.s1 + 1; + const int n_out = OW * OH * tc.OC * tc.N; + + const int n_weights = tc.KW * tc.KH * tc.IC * tc.OC; + const int n_input = tc.IW * tc.IH * tc.IC * tc.N; + + std::vector weight_f32(n_weights); + std::vector weight_f16(n_weights); + std::vector input_f32(n_input); + + srand(42); + for (int i = 0; i < 
n_weights; i++) { + weight_f32[i] = ((float)rand() / RAND_MAX) * 2.0f - 1.0f; + } + ggml_fp32_to_fp16_row(weight_f32.data(), weight_f16.data(), n_weights); + + for (int i = 0; i < n_input; i++) { + input_f32[i] = ((float)rand() / RAND_MAX) * 2.0f - 1.0f; + } + + ggml_log_set(ggml_log_callback_default, nullptr); + + ggml_backend_t backend = nullptr; +#ifdef GGML_USE_METAL + backend = ggml_backend_metal_init(); +#endif + if (!backend) { + backend = ggml_backend_cpu_init(); + } + + size_t buf_size = n_weights * ggml_type_size(tc.weight_type) + n_input * sizeof(float) + 4096; + ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * 4, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx_data = ggml_init(params); + + struct ggml_tensor * w = ggml_new_tensor_4d(ctx_data, tc.weight_type, tc.KW, tc.KH, tc.IC, tc.OC); + struct ggml_tensor * x = ggml_new_tensor_4d(ctx_data, GGML_TYPE_F32, tc.IW, tc.IH, tc.IC, tc.N); + + struct ggml_tallocr alloc = ggml_tallocr_new(buffer); + ggml_tallocr_alloc(&alloc, w); + ggml_tallocr_alloc(&alloc, x); + + if (tc.weight_type == GGML_TYPE_F16) { + ggml_backend_tensor_set(w, weight_f16.data(), 0, ggml_nbytes(w)); + } else { + assert(tc.weight_type == GGML_TYPE_F32); + ggml_backend_tensor_set(w, weight_f32.data(), 0, ggml_nbytes(w)); + } + ggml_backend_tensor_set(x, input_f32.data(), 0, ggml_nbytes(x)); + + // --- Build graph with BOTH paths --- + size_t graph_buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + std::vector graph_buf(graph_buf_size); + + struct graph_build { + ggml_context * ctx; + ggml_cgraph * gf; + }; + + auto build_graph = [&](bool direct) -> graph_build { + struct ggml_init_params gp = { + /*.mem_size =*/ graph_buf_size, + /*.mem_buffer =*/ graph_buf.data(), + /*.no_alloc =*/ true, + }; + struct ggml_context * ctx0 = ggml_init(gp); + struct ggml_cgraph 
* gf = ggml_new_graph(ctx0); + + struct ggml_tensor * result; + if (direct) { + result = ggml_conv_2d_direct(ctx0, w, x, tc.s0, tc.s1, tc.p0, tc.p1, tc.d0, tc.d1); + ggml_set_name(result, "direct"); + } else { + result = ggml_conv_2d(ctx0, w, x, tc.s0, tc.s1, tc.p0, tc.p1, tc.d0, tc.d1); + ggml_set_name(result, "im2col"); + } + ggml_build_forward_expand(gf, result); + return { ctx0, gf }; + }; + + auto run_graph = [&](bool direct) -> std::vector { + graph_build reserved = build_graph(direct); + ggml_gallocr_t gallocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_reserve(gallocr, reserved.gf); + ggml_free(reserved.ctx); + + graph_build compute = build_graph(direct); + ggml_gallocr_alloc_graph(gallocr, compute.gf); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, 1); + } + ggml_backend_graph_compute(backend, compute.gf); + + struct ggml_tensor * result = nullptr; + for (int i = 0; i < ggml_graph_n_nodes(compute.gf); i++) { + struct ggml_tensor * node = ggml_graph_node(compute.gf, i); + if (strcmp(ggml_get_name(node), direct ? "direct" : "im2col") == 0) { + result = node; + break; + } + } + assert(result); + + std::vector data(ggml_nelements(result)); + ggml_backend_tensor_get(result, data.data(), 0, ggml_nbytes(result)); + ggml_gallocr_free(gallocr); + ggml_free(compute.ctx); + return data; + }; + + std::vector ref_data = run_graph(false); + std::vector direct_data = run_graph(true); + + ggml_free(ctx_data); + ggml_backend_buffer_free(buffer); + ggml_backend_free(backend); + + if ((int)ref_data.size() != n_out || (int)direct_data.size() != n_out) { + printf("\033[31mFAIL\033[0m (size mismatch: ref=%d direct=%d expected=%d)\n", + (int)ref_data.size(), (int)direct_data.size(), n_out); + return false; + } + + float max_abs_err = 0.0f; + float max_rel_err = 0.0f; + int worst_idx = 0; + int n_bad = 0; + + // Half intermediates are expected to differ from the im2col reference. 
+ // Flag only elements that fail both an absolute and relative tolerance. + // NaN compares false against every threshold, so it is detected explicitly; + // otherwise a NaN in either output would never be counted as bad. + const float rel_tol = 0.005f; + const float abs_tol = 0.05f; + + for (int i = 0; i < n_out; i++) { + float abs_err = fabsf(ref_data[i] - direct_data[i]); + float denom = fmaxf(fabsf(ref_data[i]), 1e-6f); + float rel_err = abs_err / denom; + const bool is_nan = abs_err != abs_err; // true only for NaN + + if (is_nan || abs_err > max_abs_err) { + max_abs_err = abs_err; + worst_idx = i; + } + if (rel_err > max_rel_err) { + max_rel_err = rel_err; + } + if (is_nan || (abs_err > abs_tol && rel_err > rel_tol)) { + n_bad++; + } + } + + bool pass = n_bad == 0; + + if (pass) { + printf("\033[32mPASS\033[0m (max_abs=%.4f max_rel=%.4f%% bad=%d/%d)\n", + max_abs_err, max_rel_err * 100.0f, n_bad, n_out); + } else { + printf("\033[31mFAIL\033[0m (max_abs=%.4f max_rel=%.4f%% bad=%d/%d at [%d] ref=%.4f got=%.4f)\n", + max_abs_err, max_rel_err * 100.0f, n_bad, n_out, worst_idx, + ref_data[worst_idx], direct_data[worst_idx]); + } + + return pass; +} + +int main(void) { + ggml_time_init(); + + conv2d_test_case tests[] = { + // SD U-Net typical layers (3×3, stride 1, pad 1) + { "3x3 s1p1 IC=10 OC=10 8x6", 3,3, 10, 10, 8, 6, 1, 1,1, 1,1, 1,1 }, + { "3x3 s1p1 IC=32 OC=32 16x16", 3,3, 32, 32, 16, 16, 1, 1,1, 1,1, 1,1 }, + { "3x3 s1p1 IC=64 OC=64 32x32", 3,3, 64, 64, 32, 32, 1, 1,1, 1,1, 1,1 }, + { "3x3 s1p1 IC=128 OC=128 32x32", 3,3,128,128, 32, 32, 1, 1,1, 1,1, 1,1 }, + { "3x3 s1p1 IC=320 OC=320 64x64", 3,3,320,320, 64, 64, 1, 1,1, 1,1, 1,1 }, + { "3x3 s1p1 IC=640 OC=640 32x32", 3,3,640,640, 32, 32, 1, 1,1, 1,1, 1,1 }, + + // 1×1 convolution (channel projection in attention blocks) + { "1x1 s1p0 IC=320 OC=320 64x64", 1,1,320,320, 64, 64, 1, 1,1, 0,0, 1,1 }, + { "1x1 s1p0 IC=640 OC=640 32x32", 1,1,640,640, 32, 32, 1, 1,1, 0,0, 1,1 }, + { "1x1 s1p1 IC=32 OC=32 8x8", 1,1, 32, 32, 8, 8, 1, 1,1, 1,1, 1,1 }, + { "1x1 s2p1 IC=32 OC=32 8x8", 1,1, 32, 32, 8, 8, 1, 2,2, 1,1, 1,1 }, + + // Stride-2 downsampling + { "3x3 s2p1 IC=128 OC=256 64x64", 3,3,128,256, 64, 64, 1, 2,2, 1,1, 1,1 }, + + // No padding 
(edge case) + { "3x3 s1p0 IC=32 OC=32 16x16", 3,3, 32, 32, 16, 16, 1, 1,1, 0,0, 1,1 }, + + // Non-square spatial + { "3x3 s1p1 IC=64 OC=64 48x32", 3,3, 64, 64, 48, 32, 1, 1,1, 1,1, 1,1 }, + + // Small kernel edge: OC not multiple of tile (32) + { "3x3 s1p1 IC=64 OC=48 32x32", 3,3, 64, 48, 32, 32, 1, 1,1, 1,1, 1,1 }, + + // Minimal size + { "3x3 s1p1 IC=3 OC=16 8x8", 3,3, 3, 16, 8, 8, 1, 1,1, 1,1, 1,1 }, + { "1x1 s1p0 IC=16 OC=3 8x8", 1,1, 16, 3, 8, 8, 1, 1,1, 0,0, 1,1 }, + + // supports_op also advertises F32 weights with F32 input. + { "3x3 s1p1 F32W IC=16 OC=16 8x8", 3,3, 16, 16, 8, 8, 1, 1,1, 1,1, 1,1, GGML_TYPE_F32 }, + }; + + int n_tests = sizeof(tests) / sizeof(tests[0]); + int n_pass = 0; + + printf("conv2d_direct vs conv2d (im2col+matmul) comparison test\n"); + printf("========================================================\n"); + + for (int i = 0; i < n_tests; i++) { + if (run_test(tests[i])) { + n_pass++; + } + } + + printf("\n%d/%d tests passed\n", n_pass, n_tests); + return n_pass == n_tests ? 
0 : 1; +} diff --git a/tests/test-rope-flux.cpp b/tests/test-rope-flux.cpp new file mode 100644 index 0000000000..e37b6a3c56 --- /dev/null +++ b/tests/test-rope-flux.cpp @@ -0,0 +1,449 @@ +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_config { + const char * name; + int d_head; + int n_head; + int L; + int N; +}; + +// Build the OLD multi-op apply_rope graph (interleaved) +static ggml_tensor * build_rope_old(ggml_context * ctx, + ggml_tensor * x, // [d_head, n_head, L, N] + ggml_tensor * pe) { // [2, 2, d_head/2, L] + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + x = ggml_reshape_4d(ctx, x, 2, d_head / 2, L, n_head * N); + x = ggml_cont(ctx, ggml_permute(ctx, x, 3, 0, 1, 2)); + + int64_t offset = x->nb[2] * x->ne[2]; + auto x_0 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 0); + auto x_1 = ggml_view_3d(ctx, x, x->ne[0], x->ne[1], x->ne[2], x->nb[1], x->nb[2], offset * 1); + x_0 = ggml_reshape_4d(ctx, x_0, 1, x_0->ne[0], x_0->ne[1], x_0->ne[2]); + x_1 = ggml_reshape_4d(ctx, x_1, 1, x_1->ne[0], x_1->ne[1], x_1->ne[2]); + auto temp_x = ggml_new_tensor_4d(ctx, x_0->type, 2, x_0->ne[1], x_0->ne[2], x_0->ne[3]); + x_0 = ggml_repeat(ctx, x_0, temp_x); + x_1 = ggml_repeat(ctx, x_1, temp_x); + + pe = ggml_cont(ctx, ggml_permute(ctx, pe, 3, 0, 1, 2)); + offset = pe->nb[2] * pe->ne[2]; + auto pe_0 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], pe->nb[2], offset * 0); + auto pe_1 = ggml_view_3d(ctx, pe, pe->ne[0], pe->ne[1], pe->ne[2], pe->nb[1], 
pe->nb[2], offset * 1); + + auto x_out = ggml_add(ctx, ggml_mul(ctx, x_0, pe_0), ggml_mul(ctx, x_1, pe_1)); + x_out = ggml_reshape_3d(ctx, x_out, d_head, L, n_head * N); + return x_out; +} + +static ggml_tensor * build_permute_old(ggml_context * ctx, + ggml_tensor * x, + ggml_tensor *) { + int64_t d_head = x->ne[0]; + int64_t n_head = x->ne[1]; + int64_t L = x->ne[2]; + int64_t N = x->ne[3]; + + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + return ggml_reshape_3d(ctx, x, d_head, L, n_head * N); +} + +// Build the NEW fused graph +static ggml_tensor * build_rope_fused(ggml_context * ctx, + ggml_tensor * x, + ggml_tensor * pe) { + return ggml_rope_flux(ctx, x, pe); +} + +static ggml_tensor * build_permute_fused(ggml_context * ctx, + ggml_tensor * x, + ggml_tensor *) { + return ggml_rope_flux(ctx, x, nullptr); +} + +static ggml_tensor * view_packed_q(ggml_context * ctx, + ggml_tensor * packed_qkv, + ggml_tensor * pe) { + GGML_ASSERT(pe != nullptr); + const int64_t d_head = pe->ne[2] * 2; + GGML_ASSERT(d_head > 0); + GGML_ASSERT(packed_qkv->ne[0] % (3 * d_head) == 0); + const int64_t n_head = packed_qkv->ne[0] / (3 * d_head); + + return ggml_view_4d(ctx, packed_qkv, d_head, n_head, packed_qkv->ne[1], packed_qkv->ne[2], + packed_qkv->nb[0] * d_head, packed_qkv->nb[1], packed_qkv->nb[2], 0); +} + +static ggml_tensor * build_rope_old_packed_q(ggml_context * ctx, + ggml_tensor * packed_qkv, + ggml_tensor * pe) { + return build_rope_old(ctx, view_packed_q(ctx, packed_qkv, pe), pe); +} + +static ggml_tensor * build_rope_fused_packed_q(ggml_context * ctx, + ggml_tensor * packed_qkv, + ggml_tensor * pe) { + return build_rope_fused(ctx, view_packed_q(ctx, packed_qkv, pe), pe); +} + +static float * run_graph(ggml_backend_t backend, + ggml_tensor * x_param, + ggml_tensor * pe_param, + const float * x_data, + const float * pe_data, + ggml_tensor * (*builder)(ggml_context *, ggml_tensor *, ggml_tensor *), + int * out_nodes) { + size_t buf_size = ggml_tensor_overhead() * 
256 + ggml_graph_overhead_custom(256, false); + std::vector buf(buf_size); + + ggml_init_params p = { buf_size, buf.data(), true }; + ggml_context * ctx = ggml_init(p); + ggml_cgraph * gf = ggml_new_graph_custom(ctx, 256, false); + + ggml_tensor * result = builder(ctx, x_param, pe_param); + ggml_set_name(result, "result"); + ggml_build_forward_expand(gf, result); + + ggml_gallocr_t allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + ggml_gallocr_reserve(allocr, gf); + + // re-build for actual compute + ggml_free(ctx); + ctx = ggml_init(p); + gf = ggml_new_graph_custom(ctx, 256, false); + result = builder(ctx, x_param, pe_param); + ggml_set_name(result, "result"); + ggml_build_forward_expand(gf, result); + ggml_gallocr_alloc_graph(allocr, gf); + + // set input data + ggml_backend_tensor_set(x_param, x_data, 0, ggml_nbytes(x_param)); + if (pe_param != nullptr) { + ggml_backend_tensor_set(pe_param, pe_data, 0, ggml_nbytes(pe_param)); + } + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, 1); + } + + ggml_backend_synchronize(backend); + int64_t t0 = ggml_time_us(); + ggml_backend_graph_compute(backend, gf); + ggml_backend_synchronize(backend); + int64_t t1 = ggml_time_us(); + + *out_nodes = ggml_graph_n_nodes(gf); + + result = ggml_graph_get_tensor(gf, "result"); + size_t nbytes = ggml_nbytes(result); + float * out = (float *)malloc(nbytes); + assert(out != nullptr); + ggml_backend_tensor_get(result, out, 0, nbytes); + + printf(" time: %.2f ms, nodes: %d, output shape: [%lld, %lld, %lld]\n", + (t1 - t0) / 1000.0, *out_nodes, + (long long)result->ne[0], (long long)result->ne[1], (long long)result->ne[2]); + + ggml_gallocr_free(allocr); + ggml_free(ctx); + return out; +} + +static bool run_packed_q_view_case(ggml_backend_t backend) { + test_config cfg = { "packed_q_view", 64, 8, 128, 1 }; + printf("=== %s: d_head=%d, n_head=%d, L=%d, N=%d ===\n", + cfg.name, cfg.d_head, cfg.n_head, cfg.L, cfg.N); + + const int 
q_elems = cfg.d_head * cfg.n_head * cfg.L * cfg.N; + const int packed_elems = 3 * q_elems; + const int pe_elems = 2 * 2 * (cfg.d_head / 2) * cfg.L; + + std::vector q_data(q_elems); + std::vector packed_data(packed_elems); + std::vector pe_data(pe_elems); + + srand(1729); + for (int i = 0; i < packed_elems; i++) { + packed_data[i] = -100.0f + 0.001f * i; + } + for (int n = 0; n < cfg.N; n++) { + for (int l = 0; l < cfg.L; l++) { + for (int h = 0; h < cfg.n_head; h++) { + for (int d = 0; d < cfg.d_head; d++) { + const int q_idx = d + cfg.d_head * (h + cfg.n_head * (l + cfg.L * n)); + const int packed_idx = d + cfg.d_head * h + 3 * cfg.d_head * cfg.n_head * (l + cfg.L * n); + q_data[q_idx] = ((float)(rand() % 2000) - 1000.0f) / 1000.0f; + packed_data[packed_idx] = q_data[q_idx]; + } + } + } + } + for (int i = 0; i < cfg.L; i++) { + for (int p = 0; p < cfg.d_head / 2; p++) { + float theta = (float)i / powf(10000.0f, 2.0f * p / cfg.d_head); + float cos_v = cosf(theta); + float sin_v = sinf(theta); + int base = i * (cfg.d_head / 2) * 4 + p * 4; + pe_data[base + 0] = cos_v; + pe_data[base + 1] = -sin_v; + pe_data[base + 2] = sin_v; + pe_data[base + 3] = cos_v; + } + } + + size_t buf_size = packed_elems * sizeof(float) + pe_elems * sizeof(float) + 4096; + ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, buf_size); + + ggml_init_params mp = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context * mctx = ggml_init(mp); + + ggml_tensor * packed_param = ggml_new_tensor_3d(mctx, GGML_TYPE_F32, 3 * cfg.d_head * cfg.n_head, cfg.L, cfg.N); + ggml_tensor * pe_param = ggml_new_tensor_4d(mctx, GGML_TYPE_F32, 2, 2, cfg.d_head / 2, cfg.L); + + ggml_tallocr alloc = ggml_tallocr_new(buffer); + ggml_tallocr_alloc(&alloc, packed_param); + ggml_tallocr_alloc(&alloc, pe_param); + + printf(" OLD non-contiguous Q view:\n"); + int old_nodes = 0; + float * old_out = run_graph(backend, packed_param, pe_param, packed_data.data(), pe_data.data(), + build_rope_old_packed_q, 
&old_nodes); + + printf(" NEW non-contiguous Q view (fused):\n"); + int new_nodes = 0; + float * new_out = run_graph(backend, packed_param, pe_param, packed_data.data(), pe_data.data(), + build_rope_fused_packed_q, &new_nodes); + + float max_abs = 0.0f; + int max_abs_idx = 0; + for (int i = 0; i < q_elems; i++) { + float diff = fabsf(old_out[i] - new_out[i]); + if (diff > max_abs) { + max_abs = diff; + max_abs_idx = i; + } + } + + bool ok = max_abs < 1e-4f; + printf(" NON-CONTIGUOUS COMPARE: max_abs=%.6f (at idx %d: old=%.6f new=%.6f) => %s\n", + max_abs, max_abs_idx, old_out[max_abs_idx], new_out[max_abs_idx], ok ? "PASS" : "FAIL"); + printf(" nodes: old=%d, new=%d (-%d)\n\n", old_nodes, new_nodes, old_nodes - new_nodes); + + free(old_out); + free(new_out); + ggml_free(mctx); + ggml_backend_buffer_free(buffer); + return ok; +} + +int main(void) { + ggml_time_init(); + ggml_log_set(ggml_log_callback_default, nullptr); + + test_config configs[] = { + { "small", 16, 4, 32, 1 }, + { "medium", 64, 8, 128, 1 }, + { "flux_klein", 128, 24, 4352, 1 }, + { "flux_batch", 128, 24, 256, 2 }, + }; + int n_configs = sizeof(configs) / sizeof(configs[0]); + + ggml_backend_t backend = nullptr; +#ifdef GGML_USE_METAL + backend = ggml_backend_metal_init(); + if (backend) { + printf("Using Metal backend\n\n"); + } +#endif + if (!backend) { + backend = ggml_backend_cpu_init(); + if (!backend) { + fprintf(stderr, "Backend init failed\n"); + return 1; + } + printf("Using CPU backend\n\n"); + } + + int pass = 0, fail = 0; + + for (int c = 0; c < n_configs; c++) { + auto & cfg = configs[c]; + printf("=== %s: d_head=%d, n_head=%d, L=%d, N=%d ===\n", + cfg.name, cfg.d_head, cfg.n_head, cfg.L, cfg.N); + + int x_elems = cfg.d_head * cfg.n_head * cfg.L * cfg.N; + int pe_elems = 2 * 2 * (cfg.d_head / 2) * cfg.L; + + // generate deterministic test data + std::vector x_data(x_elems); + std::vector pe_data(pe_elems); + srand(42 + c); + for (int i = 0; i < x_elems; i++) { + x_data[i] = 
((float)(rand() % 2000) - 1000.0f) / 1000.0f; + } + for (int i = 0; i < cfg.L; i++) { + for (int p = 0; p < cfg.d_head / 2; p++) { + float theta = (float)i / powf(10000.0f, 2.0f * p / cfg.d_head); + float cos_v = cosf(theta); + float sin_v = sinf(theta); + int base = i * (cfg.d_head / 2) * 4 + p * 4; + pe_data[base + 0] = cos_v; // [0,0] = cos + pe_data[base + 1] = -sin_v; // [1,0] = -sin + pe_data[base + 2] = sin_v; // [0,1] = sin + pe_data[base + 3] = cos_v; // [1,1] = cos + } + } + + // allocate parameter tensors on backend + size_t buf_size = x_elems * sizeof(float) + pe_elems * sizeof(float) + 4096; + ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, buf_size); + + ggml_init_params mp = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context * mctx = ggml_init(mp); + + ggml_tensor * x_param = ggml_new_tensor_4d(mctx, GGML_TYPE_F32, cfg.d_head, cfg.n_head, cfg.L, cfg.N); + ggml_tensor * pe_param = ggml_new_tensor_4d(mctx, GGML_TYPE_F32, 2, 2, cfg.d_head / 2, cfg.L); + + ggml_tallocr alloc = ggml_tallocr_new(buffer); + ggml_tallocr_alloc(&alloc, x_param); + ggml_tallocr_alloc(&alloc, pe_param); + + // run old pipeline + printf(" OLD (multi-op):\n"); + int old_nodes = 0; + float * old_out = run_graph(backend, x_param, pe_param, x_data.data(), pe_data.data(), + build_rope_old, &old_nodes); + + // run new fused kernel + printf(" NEW (fused):\n"); + int new_nodes = 0; + float * new_out = run_graph(backend, x_param, pe_param, x_data.data(), pe_data.data(), + build_rope_fused, &new_nodes); + + // compare outputs + int out_elems = cfg.d_head * cfg.L * cfg.N * cfg.n_head; + float max_abs = 0.0f; + float max_rel = 0.0f; + int max_abs_idx = 0; + double old_sum = 0, new_sum = 0, old_sqsum = 0, new_sqsum = 0; + int n_nonzero_old = 0, n_nonzero_new = 0; + int n_exact = 0; + + for (int i = 0; i < out_elems; i++) { + float diff = fabsf(old_out[i] - new_out[i]); + float rel = (fabsf(old_out[i]) > 1e-6f) ? 
diff / fabsf(old_out[i]) : 0.0f; + if (diff > max_abs) { + max_abs = diff; + max_abs_idx = i; + } + if (rel > max_rel) max_rel = rel; + if (old_out[i] == new_out[i]) n_exact++; + if (fabsf(old_out[i]) > 1e-8f) n_nonzero_old++; + if (fabsf(new_out[i]) > 1e-8f) n_nonzero_new++; + old_sum += old_out[i]; + new_sum += new_out[i]; + old_sqsum += (double)old_out[i] * old_out[i]; + new_sqsum += (double)new_out[i] * new_out[i]; + } + + double old_mean = old_sum / out_elems; + double new_mean = new_sum / out_elems; + double old_std = sqrt(old_sqsum / out_elems - old_mean * old_mean); + double new_std = sqrt(new_sqsum / out_elems - new_mean * new_mean); + + printf(" OLD stats: mean=%.6f, std=%.6f, nonzero=%d/%d\n", + old_mean, old_std, n_nonzero_old, out_elems); + printf(" NEW stats: mean=%.6f, std=%.6f, nonzero=%d/%d\n", + new_mean, new_std, n_nonzero_new, out_elems); + printf(" Sample values [0..4]: old=[%.4f, %.4f, %.4f, %.4f, %.4f]\n", + old_out[0], old_out[1], old_out[2], old_out[3], old_out[4]); + printf(" new=[%.4f, %.4f, %.4f, %.4f, %.4f]\n", + new_out[0], new_out[1], new_out[2], new_out[3], new_out[4]); + printf(" Exact matches: %d/%d (%.1f%%)\n", n_exact, out_elems, 100.0f * n_exact / out_elems); + + bool ok = max_abs < 1e-4f && n_nonzero_old > out_elems / 2 && n_nonzero_new > out_elems / 2; + printf(" COMPARE: max_abs=%.6f (at idx %d: old=%.6f new=%.6f), max_rel=%.4f%% => %s\n", + max_abs, max_abs_idx, + old_out[max_abs_idx], new_out[max_abs_idx], + max_rel * 100.0f, + ok ? 
"PASS" : "FAIL"); + printf(" nodes: old=%d, new=%d (-%d)\n\n", + old_nodes, new_nodes, old_nodes - new_nodes); + + if (ok) pass++; else fail++; + + free(old_out); + free(new_out); + + // compare permute-only path (b == NULL) + printf(" OLD permute-only:\n"); + int old_perm_nodes = 0; + float * old_perm_out = run_graph(backend, x_param, nullptr, x_data.data(), nullptr, + build_permute_old, &old_perm_nodes); + + printf(" NEW permute-only (fused):\n"); + int new_perm_nodes = 0; + float * new_perm_out = run_graph(backend, x_param, nullptr, x_data.data(), nullptr, + build_permute_fused, &new_perm_nodes); + + float perm_max_abs = 0.0f; + int perm_max_abs_idx = 0; + for (int i = 0; i < out_elems; i++) { + float diff = fabsf(old_perm_out[i] - new_perm_out[i]); + if (diff > perm_max_abs) { + perm_max_abs = diff; + perm_max_abs_idx = i; + } + } + + bool perm_ok = perm_max_abs < 1e-6f; + printf(" PERMUTE COMPARE: max_abs=%.6f (at idx %d: old=%.6f new=%.6f) => %s\n", + perm_max_abs, perm_max_abs_idx, + old_perm_out[perm_max_abs_idx], new_perm_out[perm_max_abs_idx], + perm_ok ? "PASS" : "FAIL"); + printf(" nodes: old=%d, new=%d (-%d)\n\n", + old_perm_nodes, new_perm_nodes, old_perm_nodes - new_perm_nodes); + + if (perm_ok) pass++; else fail++; + + free(old_perm_out); + free(new_perm_out); + ggml_free(mctx); + ggml_backend_buffer_free(buffer); + } + + if (run_packed_q_view_case(backend)) pass++; else fail++; + + ggml_backend_free(backend); + + printf("=== Results: %d/%d passed", pass, pass + fail); + if (fail > 0) printf(", %d FAILED", fail); + printf(" ===\n"); + + return fail > 0 ? 1 : 0; +}