diff --git a/candle-metal-kernels/src/metal_src/affine.metal b/candle-metal-kernels/src/metal_src/affine.metal index e5229f55ee..16a896aaaf 100644 --- a/candle-metal-kernels/src/metal_src/affine.metal +++ b/candle-metal-kernels/src/metal_src/affine.metal @@ -1,18 +1,105 @@ #include +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } using namespace metal; diff --git a/candle-metal-kernels/src/metal_src/binary.metal b/candle-metal-kernels/src/metal_src/binary.metal index e83498e40d..73e2360d75 100644 --- a/candle-metal-kernels/src/metal_src/binary.metal +++ b/candle-metal-kernels/src/metal_src/binary.metal @@ -3,19 +3,106 @@ #define MAX(x, y) ((x) > (y) ? (x) : (y)) #define MIN(x, y) ((x) < (y) ? (x) : (y)) +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } using namespace metal; diff --git a/candle-metal-kernels/src/metal_src/cast.metal b/candle-metal-kernels/src/metal_src/cast.metal index 2af3fdceb0..367657c0ff 100644 --- a/candle-metal-kernels/src/metal_src/cast.metal +++ b/candle-metal-kernels/src/metal_src/cast.metal @@ -1,18 +1,105 @@ #include +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } @@ -128,4 +215,4 @@ CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t) CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float) CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float) CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float) -#endif \ No newline at end of file +#endif diff --git a/candle-metal-kernels/src/metal_src/indexing.metal b/candle-metal-kernels/src/metal_src/indexing.metal index 0da416cfc6..46986f213d 100644 --- a/candle-metal-kernels/src/metal_src/indexing.metal +++ b/candle-metal-kernels/src/metal_src/indexing.metal @@ -19,19 +19,106 @@ inline uint8_t max_value() { return 0xFF; } +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } template diff --git a/candle-metal-kernels/src/metal_src/reduce.metal b/candle-metal-kernels/src/metal_src/reduce.metal index 618f679892..e718aa3693 100644 --- a/candle-metal-kernels/src/metal_src/reduce.metal +++ b/candle-metal-kernels/src/metal_src/reduce.metal @@ -31,19 +31,106 @@ METAL_FUNC uint max_shared_mem(uint n) { return min(n, prev_p2(MAX_SHARED_MEM / sizeof(T))); } +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, constant const uint &num_dims, constant const size_t *dims, constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } struct Divide { diff --git a/candle-metal-kernels/src/metal_src/ternary.metal b/candle-metal-kernels/src/metal_src/ternary.metal index fe04f2378f..861a040aa0 100644 --- a/candle-metal-kernels/src/metal_src/ternary.metal +++ b/candle-metal-kernels/src/metal_src/ternary.metal @@ -1,19 +1,106 @@ #include using namespace metal; +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } template diff --git a/candle-metal-kernels/src/metal_src/unary.metal b/candle-metal-kernels/src/metal_src/unary.metal index 368b9f2077..d97b34fb64 100644 --- a/candle-metal-kernels/src/metal_src/unary.metal +++ b/candle-metal-kernels/src/metal_src/unary.metal @@ -3,19 +3,106 @@ # using namespace metal; +template +METAL_FUNC uint get_strided_idx( + uint idx, + constant const size_t *shape, + constant const size_t *strides +); + +template<> +METAL_FUNC uint get_strided_idx<1>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<2>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[1]) * strides[1] + + ((idx / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<3>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[2]) * strides[2] + + ((idx / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<4>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[3]) * strides[3] + + ((idx / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<5>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[4]) * strides[4] + + ((idx / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + +template<> +METAL_FUNC uint get_strided_idx<6>( + uint idx, + constant const size_t *shape, + constant const size_t *strides +) { + return (idx % shape[5]) * strides[5] + + ((idx / shape[5]) % shape[4]) * strides[4] + + ((idx / shape[5] / shape[4]) % shape[3]) * strides[3] + + ((idx / shape[5] / shape[4] / shape[3]) % shape[2]) * strides[2] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2]) % shape[1]) * strides[1] + + ((idx / shape[5] / shape[4] / shape[3] / shape[2] / shape[1]) % shape[0]) * strides[0]; +} + METAL_FUNC uint get_strided_index( uint idx, - constant size_t &num_dims, - constant size_t *dims, - constant size_t *strides + constant const size_t &num_dims, + constant const size_t *dims, + constant const size_t *strides ) { - uint strided_i = 0; - for (uint d = 0; d < num_dims; d++) { - uint dim_idx = num_dims - 1 - d; - strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; - idx /= dims[dim_idx]; + switch (num_dims) { + case 1: return get_strided_idx<1>(idx, dims, strides); + case 2: return get_strided_idx<2>(idx, dims, strides); + case 3: return get_strided_idx<3>(idx, dims, strides); + case 4: return get_strided_idx<4>(idx, dims, strides); + case 5: return get_strided_idx<5>(idx, dims, strides); + case 6: return get_strided_idx<6>(idx, dims, strides); + default: { + uint strided_i = 0; + #pragma clang loop unroll(full) + for (uint d = 0; d < num_dims; d++) { + uint dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; + } } - return strided_i; } template METAL_FUNC T sqr(T in){ return in * in; }