Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions android/src/main/rnllama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,12 @@ function(build_rnllama_library target_name arch cpu_flags)
gemv_moe_q5_0_f32_ns
gemm_moe_q5_1_f32_ns
gemv_moe_q5_1_f32_ns
gemm_moe_q4_k_f32_ns
gemv_moe_q4_k_f32_ns
gemm_moe_q5_k_f32_ns
gemv_moe_q5_k_f32_ns
gemm_moe_q6_k_f32_ns
gemv_moe_q6_k_f32_ns
gemm_moe_mxfp4_f32
gemv_moe_mxfp4_f32
gemm_moe_mxfp4_f32_ns
Expand Down
4 changes: 2 additions & 2 deletions cpp/common/build-info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#include <cstdio>
#include <string>

int LLAMA_BUILD_NUMBER = 9243;
char const * LLAMA_COMMIT = "17d22a3";
int LLAMA_BUILD_NUMBER = 9254;
char const * LLAMA_COMMIT = "e947228";
char const * LLAMA_COMPILER = "unknown";
char const * LLAMA_BUILD_TARGET = "unknown";

Expand Down
8 changes: 7 additions & 1 deletion cpp/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,11 @@ lm_ggml_metal_pipeline_with_params lm_ggml_metal_library_get_pipeline_pad(lm_ggm
char base[256];
char name[256];

snprintf(base, 256, "kernel_pad_%s", lm_ggml_type_name(op->src[0]->type));
// note: this is slower
//const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0;
const bool is_c4 = false;

snprintf(base, 256, "kernel_pad_%s%s", lm_ggml_type_name(op->src[0]->type), is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);

lm_ggml_metal_pipeline_with_params res = lm_ggml_metal_library_get_pipeline(lib, name);
Expand All @@ -1907,6 +1911,8 @@ lm_ggml_metal_pipeline_with_params lm_ggml_metal_library_get_pipeline_pad(lm_ggm

res = lm_ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

res.c4 = is_c4;

return res;
}

Expand Down
17 changes: 11 additions & 6 deletions cpp/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -816,9 +816,7 @@ int lm_ggml_metal_op_unary(lm_ggml_metal_op_t ctx, int idx) {
lm_ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, lm_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));

const int nth = MIN(args.ne00, nth_max);

const int nk0 = (args.ne00 + nth - 1)/nth;

lm_ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
Expand Down Expand Up @@ -1863,7 +1861,7 @@ int lm_ggml_metal_op_cpy(lm_ggml_metal_op_t ctx, int idx) {
nk0 = ne00/lm_ggml_blck_size(op->type);
}

int nth = std::min<int>(nk0, lm_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
int nth = std::min<int>(nk0*ne01, 256);

// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
Expand All @@ -1874,7 +1872,7 @@ int lm_ggml_metal_op_cpy(lm_ggml_metal_op_t ctx, int idx) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;

if (nrptg*nth > lm_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (nrptg*nth > 256) {
nrptg--;
}
}
Expand Down Expand Up @@ -4039,14 +4037,21 @@ int lm_ggml_metal_op_pad(lm_ggml_metal_op_t ctx, int idx) {

auto pipeline = lm_ggml_metal_library_get_pipeline_pad(lib, op);

const int nth = std::min(1024, ne0);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}

const int nth_max = MIN(64, lm_ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne0, nth_max);
const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel!

lm_ggml_metal_encoder_set_pipeline(enc, pipeline);
lm_ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
lm_ggml_metal_encoder_set_buffer (enc, lm_ggml_metal_get_buffer_id(op->src[0]), 1);
lm_ggml_metal_encoder_set_buffer (enc, lm_ggml_metal_get_buffer_id(op), 2);

lm_ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
lm_ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1);

return 1;
}
Expand Down
128 changes: 72 additions & 56 deletions cpp/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5712,7 +5712,7 @@ kernel void kernel_gated_delta_net_impl(
b_ptr += args.ne21;
g_ptr += args.ne21*G;

if (K > 1u) {
if (K > 1) {
const int target_slot = (int)t - shift;
if (target_slot >= 0 && target_slot < (int)K) {
device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base;
Expand All @@ -5724,7 +5724,7 @@ kernel void kernel_gated_delta_net_impl(
}
}

if (K == 1u) {
if (K == 1) {
device float * dst_state = (device float *) (dst) + attn_size + state_out_base;
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
Expand Down Expand Up @@ -8173,7 +8173,7 @@ kernel void kernel_upscale_bilinear_f32(
for (int64_t sx = x_min; sx < x_max; ++sx) {
const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0);
const float w = wx * wy;
const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00);
sum += (*src_ptr) * w;
wsum += w;
}
Expand Down Expand Up @@ -8355,7 +8355,7 @@ kernel void kernel_upscale_bicubic_f32(
const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx));
const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3;

const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00);
sum += (*src_ptr) * wx * wy;
}
}
Expand Down Expand Up @@ -8398,42 +8398,46 @@ kernel void kernel_roll_f32(
}
}

kernel void kernel_pad_f32(
template <typename T>
kernel void kernel_pad_impl(
constant lm_ggml_metal_kargs_pad & args,
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int32_t i3 = tgpig.z;
const int32_t i2 = tgpig.y;
const int32_t k0 = tgpig.x/args.ne1;
const int32_t i1 = tgpig.x - k0*args.ne1;

const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int32_t i03 = i3;
const int32_t i02 = i2;
const int32_t i01 = i1;

const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;

device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);

if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
if (i0 < args.ne00) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) {
const int32_t i0 = k0*1024 + tpitg.x + l0;
if (i0 >= args.ne0) {
break;
}

return;
}

for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
dst_ptr[i0] = 0.0f;
if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
dst_ptr[i0] = src0_ptr[i0];
} else {
dst_ptr[i0] = 0.0f;
}
}
}

typedef decltype(kernel_pad_impl<float>) kernel_pad_t;

template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>;
template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>;

// TODO: this is slow - optimize
kernel void kernel_pad_reflect_1d_f32(
constant lm_ggml_metal_kargs_pad_reflect_1d & args,
device const char * src0,
Expand Down Expand Up @@ -10397,23 +10401,27 @@ kernel void kernel_cpy_t_t(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

if (i01 >= args.ne01) {
return;
}

const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) {
device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
break;
Expand Down Expand Up @@ -10445,23 +10453,27 @@ kernel void kernel_cpy_f32_q(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

if (i01 >= args.ne01) {
return;
}

const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
const int32_t i3 = n / (args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;

device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);

quantize_func(src, dst_data[i00]);
Expand All @@ -10486,24 +10498,28 @@ kernel void kernel_cpy_q_f32(
device const char * src0,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig[2];
const int i02 = tgpig[1];
const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
const int32_t i03 = tgpig[2];
const int32_t i02 = tgpig[1];
const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y;
const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

if (i01 >= args.ne01) {
return;
}

const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
const int32_t i3 = n/(args.ne2*args.ne1*args.ne0);
const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);

device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) {
T4x4 temp;
dequantize_func(src_data + i00/nl, i00%nl, temp);
dst_data[i00] = temp;
Expand Down
Loading
Loading