Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,20 +1480,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);

GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));

char base[256];
char name[256];

snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;

const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);

snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

res.c4 = is_c4;
res.smem = 32*sizeof(float);

return res;
Expand Down
3 changes: 1 addition & 2 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
Expand Down
15 changes: 14 additions & 1 deletion ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,21 @@ typedef struct {

typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;

Expand Down
46 changes: 33 additions & 13 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));

ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);

float eps;
memcpy(&eps, op->op_params, sizeof(float));

int nth = 32; // SIMD width

ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};

auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);

while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}

int nth = 32; // SIMD width

while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}

nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);

const size_t smem = pipeline.smem;

const int64_t nrows = ggml_nrows(op->src[0]);

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);

ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);

return 1;
}
Expand Down
30 changes: 20 additions & 10 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f
template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;

kernel void kernel_l2_norm_f32(
template <typename T0, typename T>
kernel void kernel_l2_norm_impl(
constant ggml_metal_kargs_l2_norm & args,
device const char * src0,
device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]],
uint tgpig[[threadgroup_position_in_grid]],
ushort tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort ntg[[threads_per_threadgroup]]) {
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i03 = tgpig.z;
const int i02 = tgpig.y;
const int i01 = tgpig.x;

if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f;
}

device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);

float sumf = 0.0f;

// parallel sum
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
sumf += dot(x[i00], x[i00]);
}
sumf = simd_sum(sumf);
Expand All @@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32(

const float scale = 1.0f/sqrt(max(sumf, args.eps));

device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
y[i00] = x[i00] * scale;
}
}

typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;

template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;

kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
Expand Down
Loading