Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions candle-metal-kernels/src/metal_src/affine.metal
Original file line number Diff line number Diff line change
@@ -1,18 +1,105 @@
#include <metal_stdlib>

// Primary template, declared without a body. The index computation for a given
// dimension count D is provided by the explicit specializations that follow
// (D = 1 through 6 in this file).
template<uint D>
METAL_FUNC uint get_strided_idx(
uint idx,
constant const size_t *shape,
constant const size_t *strides
);

// One-dimensional case: wrap the linear index into shape[0] and scale it by strides[0].
template<>
METAL_FUNC uint get_strided_idx<1>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    const size_t coord0 = idx % shape[0];
    return coord0 * strides[0];
}

// Two-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<2>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord1 * strides[1]
         + coord0 * strides[0];
}

// Three-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<3>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Four-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<4>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Five-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<5>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Six-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<6>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord5 = remaining % shape[5];
    remaining /= shape[5];

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord5 * strides[5]
         + coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Converts a linear element index `idx` into an offset built from `dims` and
// `strides`: each per-dimension coordinate is multiplied by the matching stride
// and the products are summed.
// NOTE(review): this span is taken from a rendered diff and appears to interleave
// the removed and the added version of the function (two parameter lists, two
// loops, two trailing returns) — confirm against the actual file.
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
constant const size_t &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
// Dimension counts 1..6 are forwarded to the matching get_strided_idx<D> specialization.
switch (num_dims) {
case 1: return get_strided_idx<1>(idx, dims, strides);
case 2: return get_strided_idx<2>(idx, dims, strides);
case 3: return get_strided_idx<3>(idx, dims, strides);
case 4: return get_strided_idx<4>(idx, dims, strides);
case 5: return get_strided_idx<5>(idx, dims, strides);
case 6: return get_strided_idx<6>(idx, dims, strides);
// Every other dimension count uses the generic loop, which walks the
// dimensions from last to first and peels one coordinate per iteration.
default: {
uint strided_i = 0;
// NOTE(review): the loop bound num_dims is a runtime value, so the full-unroll
// request below may not be honoured by the compiler — verify.
#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
}
return strided_i;
}

using namespace metal;
Expand Down
105 changes: 96 additions & 9 deletions candle-metal-kernels/src/metal_src/binary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,106 @@
// Function-like max/min helpers. Each argument appears twice in the expansion,
// so an argument expression with side effects may be evaluated more than once.
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))

// Primary template, declared without a body. The index computation for a given
// dimension count D is provided by the explicit specializations that follow
// (D = 1 through 6 in this file).
template<uint D>
METAL_FUNC uint get_strided_idx(
uint idx,
constant const size_t *shape,
constant const size_t *strides
);

// One-dimensional case: wrap the linear index into shape[0] and scale it by strides[0].
template<>
METAL_FUNC uint get_strided_idx<1>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    const size_t coord0 = idx % shape[0];
    return coord0 * strides[0];
}

// Two-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<2>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord1 * strides[1]
         + coord0 * strides[0];
}

// Three-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<3>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Four-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<4>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Five-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<5>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Six-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<6>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord5 = remaining % shape[5];
    remaining /= shape[5];

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord5 * strides[5]
         + coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Converts a linear element index `idx` into an offset built from `dims` and
// `strides`: each per-dimension coordinate is multiplied by the matching stride
// and the products are summed.
// NOTE(review): this span is taken from a rendered diff and appears to interleave
// the removed and the added version of the function (two parameter lists, two
// loops, two trailing returns) — confirm against the actual file.
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
constant const size_t &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
// Dimension counts 1..6 are forwarded to the matching get_strided_idx<D> specialization.
switch (num_dims) {
case 1: return get_strided_idx<1>(idx, dims, strides);
case 2: return get_strided_idx<2>(idx, dims, strides);
case 3: return get_strided_idx<3>(idx, dims, strides);
case 4: return get_strided_idx<4>(idx, dims, strides);
case 5: return get_strided_idx<5>(idx, dims, strides);
case 6: return get_strided_idx<6>(idx, dims, strides);
// Every other dimension count uses the generic loop, which walks the
// dimensions from last to first and peels one coordinate per iteration.
default: {
uint strided_i = 0;
// NOTE(review): the loop bound num_dims is a runtime value, so the full-unroll
// request below may not be honoured by the compiler — verify.
#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
}
return strided_i;
}

using namespace metal;
Expand Down
107 changes: 97 additions & 10 deletions candle-metal-kernels/src/metal_src/cast.metal
Original file line number Diff line number Diff line change
@@ -1,18 +1,105 @@
#include <metal_stdlib>

// Primary template, declared without a body. The index computation for a given
// dimension count D is provided by the explicit specializations that follow
// (D = 1 through 6 in this file).
template<uint D>
METAL_FUNC uint get_strided_idx(
uint idx,
constant const size_t *shape,
constant const size_t *strides
);

// One-dimensional case: wrap the linear index into shape[0] and scale it by strides[0].
template<>
METAL_FUNC uint get_strided_idx<1>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    const size_t coord0 = idx % shape[0];
    return coord0 * strides[0];
}

// Two-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<2>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord1 * strides[1]
         + coord0 * strides[0];
}

// Three-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<3>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Four-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<4>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Five-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<5>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Six-dimensional case: peel one coordinate per dimension, starting with the
// last dimension, and sum each coordinate multiplied by its stride.
template<>
METAL_FUNC uint get_strided_idx<6>(
    uint idx,
    constant const size_t *shape,
    constant const size_t *strides
) {
    size_t remaining = idx;

    const size_t coord5 = remaining % shape[5];
    remaining /= shape[5];

    const size_t coord4 = remaining % shape[4];
    remaining /= shape[4];

    const size_t coord3 = remaining % shape[3];
    remaining /= shape[3];

    const size_t coord2 = remaining % shape[2];
    remaining /= shape[2];

    const size_t coord1 = remaining % shape[1];
    remaining /= shape[1];

    const size_t coord0 = remaining % shape[0];

    return coord5 * strides[5]
         + coord4 * strides[4]
         + coord3 * strides[3]
         + coord2 * strides[2]
         + coord1 * strides[1]
         + coord0 * strides[0];
}

// Converts a linear element index `idx` into an offset built from `dims` and
// `strides`: each per-dimension coordinate is multiplied by the matching stride
// and the products are summed.
// NOTE(review): this span is taken from a rendered diff and appears to interleave
// the removed and the added version of the function (two parameter lists, two
// loops, two trailing returns) — confirm against the actual file.
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
constant const size_t &num_dims,
constant const size_t *dims,
constant const size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
// Dimension counts 1..6 are forwarded to the matching get_strided_idx<D> specialization.
switch (num_dims) {
case 1: return get_strided_idx<1>(idx, dims, strides);
case 2: return get_strided_idx<2>(idx, dims, strides);
case 3: return get_strided_idx<3>(idx, dims, strides);
case 4: return get_strided_idx<4>(idx, dims, strides);
case 5: return get_strided_idx<5>(idx, dims, strides);
case 6: return get_strided_idx<6>(idx, dims, strides);
// Every other dimension count uses the generic loop, which walks the
// dimensions from last to first and peels one coordinate per iteration.
default: {
uint strided_i = 0;
// NOTE(review): the loop bound num_dims is a runtime value, so the full-unroll
// request below may not be honoured by the compiler — verify.
#pragma clang loop unroll(full)
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
}
return strided_i;
}


Expand Down Expand Up @@ -128,4 +215,4 @@ CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
#endif
#endif
Loading
Loading