Skip to content
10 changes: 10 additions & 0 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ extern "C" {
GGML_OP_OPT_STEP_SGD,

GGML_OP_GLU,
GGML_OP_ROPE_FLUX,

GGML_OP_COUNT,
};
Expand Down Expand Up @@ -1860,6 +1861,15 @@ extern "C" {
float beta_slow),
"use ggml_rope_ext_inplace instead");

// Fused Flux-style RoPE: applies rotation using precomputed PE matrix and permutes output layout.
// a: [d_head, n_head, L, N] (Q or K tensor, may be non-contiguous); must be F32 with an even d_head
// b: [2, 2, d_head/2, L] (precomputed rotation matrix [[cos,-sin],[sin,cos]]), or NULL for permute-only; must be F32
// result: [d_head, L, N*n_head] (contiguous, layout for flash attention)
//
// Indexing (i0 is the fastest-varying dimension), as implemented by the CPU backend:
//   result[:, l, n*n_head + h] is computed from a[:, h, l, n]
//   for each pair p in [0, d_head/2):
//     result[2*p    ] = a[2*p]*b[0, 0, p, l] + a[2*p + 1]*b[1, 0, p, l]
//     result[2*p + 1] = a[2*p]*b[0, 1, p, l] + a[2*p + 1]*b[1, 1, p, l]
//   when b is NULL the values are copied unchanged (only the layout changes)
GGML_API struct ggml_tensor * ggml_rope_flux(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);

// compute correction dims for YaRN RoPE scaling
GGML_API void ggml_rope_yarn_corr_dims(
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
Expand Down
5 changes: 5 additions & 0 deletions src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rope_back(params, tensor);
} break;
case GGML_OP_ROPE_FLUX:
{
ggml_compute_forward_rope_flux(params, tensor);
} break;
case GGML_OP_CLAMP:
{
ggml_compute_forward_clamp(params, tensor);
Expand Down Expand Up @@ -2296,6 +2300,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ROPE:
case GGML_OP_ROPE_BACK:
case GGML_OP_ROPE_FLUX:
case GGML_OP_ADD_REL_POS:
{
n_tasks = n_threads;
Expand Down
5 changes: 5 additions & 0 deletions src/ggml-cpu/ggml-cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,11 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
}
case GGML_OP_IM2COL_BACK:
return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
case GGML_OP_IM2COL_3D:
return src1->type == GGML_TYPE_F32 &&
((op->type == GGML_TYPE_F32 &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
(op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F16));
case GGML_OP_GET_ROWS_BACK:
return src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16;
case GGML_OP_OUT_PROD:
Expand Down
81 changes: 81 additions & 0 deletions src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5840,6 +5840,87 @@ void ggml_compute_forward_rope_back(
}
}

// ggml_compute_forward_rope_flux

void ggml_compute_forward_rope_flux(
Comment thread
jpgaribotti marked this conversation as resolved.
const ggml_compute_params * params,
ggml_tensor * dst) {

const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];

GGML_ASSERT(src0 != NULL);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(dst));

const int64_t d_head = src0->ne[0];
const int64_t n_head = src0->ne[1];
const int64_t L = src0->ne[2];
const int64_t N = src0->ne[3];

GGML_ASSERT(d_head > 0 && n_head > 0 && L > 0 && N > 0);
GGML_ASSERT(d_head % 2 == 0);

if (src1 != NULL) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src1->ne[0] == 2);
GGML_ASSERT(src1->ne[1] == 2);
GGML_ASSERT(src1->ne[2] == d_head / 2);
GGML_ASSERT(src1->ne[3] == L);
}

const int ith = params->ith;
const int nth = params->nth;

const int64_t n_rows = L * n_head * N;
const int64_t dr = (n_rows + nth - 1) / nth;
const int64_t r0 = dr * ith;
const int64_t r1 = MIN(r0 + dr, n_rows);

float * dst_data = (float *) dst->data;
const char * src0_data = (const char *) src0->data;
const char * src1_data = src1 ? (const char *) src1->data : NULL;

for (int64_t row = r0; row < r1; ++row) {
const int64_t l = row % L;
const int64_t bh = row / L;
const int64_t h = bh % n_head;
const int64_t n = bh / n_head;

float * dst_row = dst_data + row * d_head;
const char * src0_row = src0_data +
n * src0->nb[3] +
l * src0->nb[2] +
h * src0->nb[1];

if (src1_data == NULL) {
for (int64_t d = 0; d < d_head; ++d) {
dst_row[d] = *(const float *) (src0_row + d * src0->nb[0]);
}
continue;
}

for (int64_t pair = 0; pair < d_head / 2; ++pair) {
const char * src0_pair = src0_row + (2 * pair) * src0->nb[0];
const char * src1_pair = src1_data +
l * src1->nb[3] +
pair * src1->nb[2];

const float x_even = *(const float *) src0_pair;
const float x_odd = *(const float *) (src0_pair + src0->nb[0]);

const float pe_00 = *(const float *) src1_pair;
const float pe_10 = *(const float *) (src1_pair + src1->nb[0]);
const float pe_01 = *(const float *) (src1_pair + src1->nb[1]);
const float pe_11 = *(const float *) (src1_pair + src1->nb[1] + src1->nb[0]);

dst_row[2 * pair] = x_even * pe_00 + x_odd * pe_10;
dst_row[2 * pair + 1] = x_even * pe_01 + x_odd * pe_11;
}
}
}

// ggml_compute_forward_conv_transpose_1d

static void ggml_compute_forward_conv_transpose_1d_f16_f32(
Expand Down
1 change: 1 addition & 0 deletions src/ggml-cpu/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, st
void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_rope_flux(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
Expand Down
19 changes: 18 additions & 1 deletion src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <Metal/Metal.h>

#include <stdatomic.h>
#include <stdint.h>

#ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0
Expand Down Expand Up @@ -1041,10 +1042,26 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
return true;
case GGML_OP_ROPE_FLUX:
if (op->src[0] == nil || op->src[0]->type != GGML_TYPE_F32 ||
op->src[0]->ne[0] <= 0 || op->src[0]->ne[1] <= 0 || op->src[0]->ne[2] <= 0 || op->src[0]->ne[3] <= 0 ||
op->src[0]->ne[0] % 2 != 0 ||
ggml_nelements(op) > INT32_MAX) {
return false;
}
if (op->src[1] == nil) {
return true;
}
return op->src[1]->type == GGML_TYPE_F32 &&
op->src[1]->ne[0] == 2 &&
op->src[1]->ne[1] == 2 &&
op->src[0]->ne[0] == 2 * op->src[1]->ne[2] &&
op->src[0]->ne[2] == op->src[1]->ne[3];
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
case GGML_OP_CONV_2D:
return ggml_is_contiguous(op->src[0]) &&
return has_simdgroup_mm &&
ggml_is_contiguous(op->src[0]) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32 &&
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
Expand Down
15 changes: 15 additions & 0 deletions src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,21 @@ typedef struct {
bool src2;
} ggml_metal_kargs_rope;

// Kernel arguments for the fused Flux-style RoPE op.
// The host side fills the four dimensions from src0->ne[0..3], the x strides from
// src0->nb[0..3] and the pe strides from src1->nb[0..3] (all zero when src1 is absent).
typedef struct {
int32_t d_head; // src0->ne[0]
int32_t n_head; // src0->ne[1]
int32_t L; // src0->ne[2]
int32_t N; // src0->ne[3]
uint64_t nb00; // x strides (may be non-contiguous)
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t pe_nb0; // pe strides
uint64_t pe_nb1;
uint64_t pe_nb2;
uint64_t pe_nb3;
} ggml_metal_kargs_rope_flux;

typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
Expand Down
146 changes: 111 additions & 35 deletions src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_rope(ctx, idx);
} break;
case GGML_OP_ROPE_FLUX:
{
n_fuse = ggml_metal_op_rope_flux(ctx, idx);
} break;
case GGML_OP_IM2COL:
{
n_fuse = ggml_metal_op_im2col(ctx, idx);
Expand Down Expand Up @@ -3284,6 +3288,74 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
return 1;
}

// Encodes the fused Flux-style RoPE op for node idx.
//
// With a PE tensor (src[1]) the "kernel_rope_flux" pipeline is used with buffers
// (args, src0, src1, dst); without it the "kernel_permute_cont_021" pipeline is used with
// buffers (args, src0, dst). Both receive the same ggml_metal_kargs_rope_flux block, with the
// pe strides set to zero when there is no PE tensor.
//
// The dispatch is 1-D: ceil(nelements(dst) / 256) threadgroups of 256 threads.
// Returns 1 (the number of graph nodes consumed).
int ggml_metal_op_rope_flux(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_ASSERT(op->src[0] != nullptr);
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);

    const bool has_pe = op->src[1] != nullptr;
    if (has_pe) {
        GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
        GGML_ASSERT(op->src[1]->ne[0] == 2);
        GGML_ASSERT(op->src[1]->ne[1] == 2);
        GGML_ASSERT(op->src[0]->ne[0] == 2 * op->src[1]->ne[2]);
        GGML_ASSERT(op->src[0]->ne[2] == op->src[1]->ne[3]);
    }

    // the kernel arguments carry the dimensions as int32, so reject anything that does not fit
    auto to_i32_dim = [](int64_t dim) {
        GGML_ASSERT(dim > 0);
        GGML_ASSERT(dim <= std::numeric_limits<int32_t>::max());
        return (int32_t) dim;
    };

    const int32_t d_head = to_i32_dim(op->src[0]->ne[0]);
    const int32_t n_head = to_i32_dim(op->src[0]->ne[1]);
    const int32_t L      = to_i32_dim(op->src[0]->ne[2]);
    const int32_t N      = to_i32_dim(op->src[0]->ne[3]);

    const int64_t total = ggml_nelements(op);
    GGML_ASSERT(total > 0);
    GGML_ASSERT(total <= std::numeric_limits<int32_t>::max());

    ggml_metal_kargs_rope_flux args = {
        /*.d_head =*/ d_head,
        /*.n_head =*/ n_head,
        /*.L      =*/ L,
        /*.N      =*/ N,
        /*.nb00   =*/ op->src[0]->nb[0],
        /*.nb01   =*/ op->src[0]->nb[1],
        /*.nb02   =*/ op->src[0]->nb[2],
        /*.nb03   =*/ op->src[0]->nb[3],
        /*.pe_nb0 =*/ has_pe ? op->src[1]->nb[0] : 0,
        /*.pe_nb1 =*/ has_pe ? op->src[1]->nb[1] : 0,
        /*.pe_nb2 =*/ has_pe ? op->src[1]->nb[2] : 0,
        /*.pe_nb3 =*/ has_pe ? op->src[1]->nb[3] : 0,
    };

    if (has_pe) {
        auto pipeline = ggml_metal_library_compile_pipeline(lib, "kernel_rope_flux", "kernel_rope_flux", NULL);
        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
    } else {
        // permute-only variant: no PE buffer, so dst moves up to binding index 2
        auto pipeline = ggml_metal_library_compile_pipeline(lib, "kernel_permute_cont_021", "kernel_permute_cont_021", NULL);
        ggml_metal_encoder_set_pipeline(enc, pipeline);
        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
        ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
    }

    const int nth = 256;

    // ceil-divide in 64-bit: total may be as large as INT32_MAX, so adding nth - 1 in int
    // arithmetic would overflow; the quotient itself always fits in an int
    const int64_t n_tg = (total + nth - 1) / nth;

    ggml_metal_encoder_dispatch_threadgroups(enc, (int) n_tg, 1, 1, nth, 1, 1);

    return 1;
}

int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down Expand Up @@ -3380,54 +3452,58 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
const int32_t d1 = ((const int32_t *) op->op_params)[5];

ggml_metal_kargs_conv_2d args = {
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.IW =*/ ne10,
/*.IH =*/ ne11,
/*.KW =*/ ne00,
/*.KH =*/ ne01,
/*.IC =*/ ne02,
/*.OC =*/ ne03,
/*.OW =*/ ne0,
/*.OH =*/ ne1,
/*.N =*/ ne3,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.p0 =*/ p0,
/*.p1 =*/ p1,
/*.d0 =*/ d0,
/*.d1 =*/ d1,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.IW =*/ ne10,
/*.IH =*/ ne11,
/*.KW =*/ ne00,
/*.KH =*/ ne01,
/*.IC =*/ ne02,
/*.OC =*/ ne03,
/*.OW =*/ ne0,
/*.OH =*/ ne1,
/*.N =*/ ne3,
/*.s0 =*/ s0,
/*.s1 =*/ s1,
/*.p0 =*/ p0,
/*.p1 =*/ p1,
/*.d0 =*/ d0,
/*.d1 =*/ d1,
};

auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);

int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
nth = std::min(nth, 256);
nth = std::max(nth, 1);
const int M_TILE = 64;
const int N_TILE = 64;
const int K_TILE = 32;

const uint64_t n_out = ggml_nelements(op);
const int M = ne0 * ne1;
const int tg_x = ((int) ne03 + N_TILE - 1) / N_TILE;
const int tg_y = (M + M_TILE - 1) / M_TILE;
const int tg_z = ne3;

uint64_t tg = (n_out + nth - 1)/nth;
tg = std::max<uint64_t>(tg, 1);
tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
const size_t smem = GGML_PAD(std::max(
(size_t)(M_TILE * K_TILE + K_TILE * N_TILE) * sizeof(uint16_t),
(size_t)(M_TILE * N_TILE) * sizeof(float)), 16);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);

ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, tg_x, tg_y, tg_z, 256, 1, 1);

return 1;
}
Expand Down
1 change: 1 addition & 0 deletions src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope_flux (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
Expand Down
Loading