diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 2baa6c05bf..7a0a167d7c 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -120,6 +120,14 @@ if(NOT MLX_METAL_PATH) set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) endif() +if((MLX_METAL_VERSION GREATER_EQUAL 400) AND (MACOS_SDK_VERSION GREATER_EQUAL + 26.2)) + set(MLX_ENABLE_NAX TRUE) + target_compile_definitions(mlx PRIVATE MLX_ENABLE_NAX) +else() + set(MLX_ENABLE_NAX FALSE) +endif() + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) target_compile_definitions(mlx diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h index fefb7cdc0c..746fbc0888 100644 --- a/mlx/backend/metal/device.h +++ b/mlx/backend/metal/device.h @@ -265,4 +265,14 @@ Device& device(mlx::core::Device); std::unique_ptr> new_scoped_memory_pool(); +#ifdef MLX_ENABLE_NAX + +inline bool is_nax_available() { + static bool is_nax_available_ = + metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; + return is_nax_available_; +} + +#endif // MLX_ENABLE_NAX + } // namespace mlx::core::metal diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index c2842d5343..5215fb3460 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -9,10 +9,13 @@ set(BASE_HEADERS utils.h) function(build_kernel_base TARGET SRCFILE DEPS) - set(METAL_FLAGS -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) + set(METAL_FLAGS -x metal -Wall -Wextra -fno-fast-math -Wno-c++17-extensions) if(MLX_METAL_DEBUG) set(METAL_FLAGS ${METAL_FLAGS} -gline-tables-only -frecord-sources) endif() + if(MLX_ENABLE_NAX) + set(METAL_FLAGS ${METAL_FLAGS} -Wno-c++20-extensions -std=metal4.0) + endif() if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "") set(METAL_FLAGS ${METAL_FLAGS} "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") @@ -120,6 +123,30 @@ if(NOT MLX_METAL_JIT) build_kernel(gemv_masked steel/utils.h) endif() +if(MLX_ENABLE_NAX) + + set(STEEL_NAX_HEADERS + steel/defines.h + steel/utils.h + steel/gemm/transforms.h + steel/gemm/nax.h + steel/gemm/gemm_nax.h + steel/utils/type_traits.h + steel/utils/integral_constant.h) + + build_kernel(steel/gemm/kernels/steel_gemm_fused_nax ${STEEL_NAX_HEADERS}) + build_kernel(steel/gemm/kernels/steel_gemm_gather_nax ${STEEL_NAX_HEADERS}) + + build_kernel(quantized_nax quantized_nax.h ${STEEL_NAX_HEADERS}) + build_kernel(fp_quantized_nax fp_quantized_nax.h ${STEEL_NAX_HEADERS}) + + set(STEEL_NAX_ATTN_HEADERS + steel/defines.h steel/utils.h steel/attn/nax.h steel/utils/type_traits.h + steel/utils/integral_constant.h) + + build_kernel(steel/attn/kernels/steel_attention_nax ${STEEL_NAX_ATTN_HEADERS}) +endif() + add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.h b/mlx/backend/metal/kernels/fp_quantized_nax.h new file mode 100644 index 0000000000..abd90834ba --- /dev/null +++ b/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -0,0 +1,1066 @@ +// Copyright © 2025 Apple Inc. 
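For reference while reading the new fp_quantized_nax.h that starts here: its kernels decode MXFP4 weights by looking each 4-bit e2m1 code up in a 16-entry LUT (FP4_LUT from fp4.h) and multiplying by a shared fp8 e8m0 (power-of-two) group scale, low nibble first. The host-side sketch below is illustrative only; it assumes the OCP MX e2m1/e8m0 encodings, and the helper names are hypothetical and do not appear in this patch.

#include <cmath>
#include <cstdint>
#include <vector>

// Reference decode of one MXFP4 group (illustrative sketch, not kernel code).
inline float fp4_e2m1_to_float(uint8_t code) {
  // e2m1 magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6; bit 3 is the sign.
  static const float lut[16] = {
      0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
      -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  return lut[code & 0xf];
}

inline float fp8_e8m0_to_float(uint8_t e) {
  return std::ldexp(1.0f, int(e) - 127); // 2^(e - 127); ignores the 0xff NaN code
}

// Two values per byte, low nibble first, matching dequantize() in this header.
std::vector<float> dequantize_mxfp4_group(
    const std::vector<uint8_t>& packed,
    uint8_t scale_bits) {
  const float scale = fp8_e8m0_to_float(scale_bits);
  std::vector<float> out;
  out.reserve(2 * packed.size());
  for (uint8_t byte : packed) {
    out.push_back(scale * fp4_e2m1_to_float(byte & 0xf));
    out.push_back(scale * fp4_e2m1_to_float((byte >> 4) & 0xf));
  }
  return out;
}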
+ +#include +#include + +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e8m0*)(&s)); +} + +template +struct Quantize { + uint8_t operator()(float x) { + if constexpr (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(thread fp8_e4m3*)(&x)); + } else { + return float(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size> +struct QuantizedBlockLoader { + static_assert( + BCOLS % group_size == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + static_assert( + (BCOLS_PACKED / n_reads) == n_groups, + "Other configurations are not yet supported"); + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device uint8_t* scales; + threadgroup T* lut; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device uint8_t* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id), + lut(lut_) { + if (simd_group_id == 0 && simd_lane_id < 16) { + lut[simd_lane_id] = static_cast(FP4_LUT[simd_lane_id]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + // if (group_steps > 1) { + // group_step_cnt++; + // if (group_step_cnt == group_steps) { + // group_step_cnt = 0; + // scales++; + // } + // } else { + scales += n_groups; + // } + } else { + scales += n_groups * group_stride; + } + } +}; + +using namespace mlx::steel; + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_t_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup Wtype* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup Wtype* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + // Instantiate Loader + using loader_w_t = QuantizedBlockLoader< + Wtype, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / 
WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_n_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + 
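  // In this non-transposed (qmm_n) layout, w is row-major K x N with the
  // mxfp4 packing of two 4-bit codes per byte (pack_factor == 2) and one
  // e8m0 scale per group_size consecutive values along N. The packed-weight
  // pointer above therefore advances by y_col / pack_factor entries to reach
  // this output column block, and the scale pointer advances by
  // y_col / group_size groups.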
scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* 
lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + threadgroup Wtype lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, 
+ y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + threadgroup Wtype lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + int group_size, + const int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool 
transpose, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_rhs_nax( + const device T* x, + const device uint32_t* w, + const device uint8_t* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + constexpr int BN_padded = (BN + 16 / sizeof(Wtype)); + + threadgroup Wtype lut[16]; + + using loader_w_t = QuantizedBlockLoader< + Wtype, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size>; + + threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? 
TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} diff --git a/mlx/backend/metal/kernels/fp_quantized_nax.metal b/mlx/backend/metal/kernels/fp_quantized_nax.metal new file mode 100644 index 0000000000..bd2df2b71e --- /dev/null +++ b/mlx/backend/metal/kernels/fp_quantized_nax.metal @@ -0,0 +1,74 @@ +// Copyright © 2025 Apple Inc. 
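For orientation before the instantiation macros in fp_quantized_nax.metal below: each macro stringifies a kernel name that encodes the mode (mxfp4), op name, element type, group size 32, 4 bits, the 64/64/64 block and 2/2 warp tiling, plus the aligned/batched flags, and instantiates the matching fp_* template. The sketch below is a hypothetical host-side helper (not the actual MLX dispatch code) showing how such a name could be assembled for the batched qmm_t_nax variants declared by instantiate_quantized_aligned_batched.

#include <string>

// Hypothetical helper mirroring instantiate_quantized_aligned_batched below;
// the real kernel lookup in the MLX host code may be structured differently.
std::string nax_qmm_t_kernel_name(
    const std::string& type, // e.g. "bfloat16_t"
    bool aligned_n,
    bool batched) {
  std::string name =
      "mxfp4_qmm_t_nax_" + type + "_gs_32_b_4_bm64_bn64_bk64_wm2_wn2";
  name += std::string("_alN_") + (aligned_n ? "true" : "false");
  name += std::string("_batch_") + (batched ? "1" : "0");
  return name;
}
// nax_qmm_t_kernel_name("bfloat16_t", true, true) ==
//   "mxfp4_qmm_t_nax_bfloat16_t_gs_32_b_4_bm64_bn64_bk64_wm2_wn2_alN_true_batch_1"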
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/quantized_utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/fp_quantized_nax.h" + + +#define instantiate_quantized_batched(mode, name, type, bm, bn, bk, wm, wn, batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + batched) + +#define instantiate_quantized_aligned(mode, name, type, bm, bn, bk, wm, wn, aligned) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + aligned) + +#define instantiate_quantized_aligned_batched(mode, name, type, bm, bn, bk, wm, wn, aligned, batched) \ + instantiate_kernel( \ + #mode "_" #name "_" #type "_gs_32_b_4_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \ + fp_ ## name, \ + type, \ + 32, \ + 4, \ + aligned, \ + batched) + +#define instantiate_gather_qmm_rhs(func, name, type, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_32_b_4_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + 32, \ + 4, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + + +#define instantiate_quantized_all_aligned(type) \ + instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, true) \ + instantiate_quantized_aligned(mxfp4, gather_qmm_t_nax, type, 64, 64, 64, 2, 2, false) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, true, 0) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 1) \ + instantiate_quantized_aligned_batched(mxfp4, qmm_t_nax, type, 64, 64, 64, 2, 2, false, 0) + + +#define instantiate_quantized_all_rhs(type) \ + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nt, type, 64, 64, 64, 2, 2, true) \ + instantiate_gather_qmm_rhs(fp_gather_qmm_rhs_nax, mxfp4_gather_qmm_rhs_nax_nn, type, 64, 64, 64, 2, 2, false) + +#define instantiate_quantized_types(type) \ + instantiate_quantized_all_aligned(type) \ + instantiate_quantized_all_rhs(type) + +instantiate_quantized_types(float) +instantiate_quantized_types(bfloat16_t) +instantiate_quantized_types(float16_t) + // clang-format on diff --git a/mlx/backend/metal/kernels/quantized_nax.h b/mlx/backend/metal/kernels/quantized_nax.h new file mode 100644 index 0000000000..c26ff646bb --- /dev/null +++ b/mlx/backend/metal/kernels/quantized_nax.h @@ -0,0 +1,1705 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? 
(wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + 
x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + 
"Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 
0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = 
(((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short bits> +struct QuantizedBlockLoader< + T, + BROWS, + BCOLS, + dst_ld, + reduction_dim, + tgp_size, + 32, + bits> { + MLX_MTL_CONST short group_size = 32; + + static_assert( + BCOLS % group_size == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + static_assert( + (BCOLS_PACKED / n_reads) == n_groups, + "Other configurations are not yet supported"); + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id), + biases(biases_ + bi * src_ld / group_size + group_id) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + // if (group_steps > 1) { + // group_step_cnt++; + // if (group_step_cnt == group_steps) { + // group_step_cnt = 0; + // scales++; + // biases++; + // } + // } else { + scales += n_groups; + biases += n_groups; + // } + } else { + scales += n_groups * group_stride; + biases += n_groups * group_stride; + } + } +}; + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const 
constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_t_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? 
false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_n_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + (void)M; + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + // const short num_els = min(BM, M - y_row); + // const short num_outs = min(BN, N - y_col); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using 
AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 32, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + 
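+  // Threadgroup tile sizing for the dequantized weights, shown as a worked
+  // example assuming the default template arguments T = float16_t and
+  // BK = BN = 64 (a sketch of the arithmetic, not an extra constraint):
+  //   BN_padded = BN + 16 / sizeof(T) = 64 + 8 = 72
+  //   Ws size   = BK * BN_padded * sizeof(T) = 64 * 72 * 2 = 9216 bytes
+  // The extra 16 bytes per row stagger row starts in threadgroup memory,
+  // which is the usual way these tiles avoid bank conflicts when a simdgroup
+  // walks a column of Ws.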
constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + 
uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs_nax( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? 
SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/quantized_nax.metal b/mlx/backend/metal/kernels/quantized_nax.metal new file mode 100644 index 
0000000000..5a9c9fb874 --- /dev/null +++ b/mlx/backend/metal/kernels/quantized_nax.metal @@ -0,0 +1,106 @@ +// Copyright © 2023-2024 Apple Inc. + +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/steel/gemm/gemm.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/quantized_nax.h" + +#define instantiate_quantized(name, type, group_size, bits, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits, \ + name, \ + type, \ + group_size, \ + bits, bm, bk, bn, wm, wn) + +#define instantiate_quantized_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + batched, bm, bk, bn, wm, wn) + +#define instantiate_quantized_aligned(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned, bm, bk, bn, wm, wn) + +#define instantiate_quantized_aligned_batched(name, type, group_size, bits, bm, bn, bk, wm, wn, aligned, batched) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_alN_" #aligned "_batch_" #batched, \ + name, \ + type, \ + group_size, \ + bits, \ + aligned, \ + batched, bm, bk, bn, wm, wn) + +#define instantiate_gather_qmm_rhs(func, name, type, group_size, bits, bm, bn, bk, wm, wn, transpose) \ + instantiate_kernel( \ + #name "_" #type "_gs_" #group_size "_b_" #bits "_bm_" #bm "_bn_" #bn "_bk_" #bk "_wm_" #wm "_wn_" #wn, \ + func, \ + type, \ + group_size, \ + bits, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + transpose) + +#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ + instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 1) \ + instantiate_quantized_batched(name, type, group_size, bits, 64, 64, 64, 2, 2, 0) + +#define instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_batched_wrap(affine_qmm_n_nax, type, group_size, bits) + + +#define instantiate_quantized_all_single(type, group_size, bits) \ + instantiate_quantized(affine_gather_qmm_n_nax, type, group_size, bits, 64, 64, 64, 2, 2) + +#define instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true) \ + instantiate_quantized_aligned(affine_gather_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, true, 0) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 1) \ + instantiate_quantized_aligned_batched(affine_qmm_t_nax, type, group_size, bits, 64, 64, 64, 2, 2, false, 0) + +#define instantiate_quantized_all_rhs(type, group_size, bits) \ + instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nt, type, group_size, bits, 64, 64, 64, 2, 2, true) \ + 
instantiate_gather_qmm_rhs(affine_gather_qmm_rhs_nax, affine_gather_qmm_rhs_nax_nn, type, group_size, bits, 64, 64, 64, 2, 2, false) + +#define instantiate_quantized_funcs(type, group_size, bits) \ + instantiate_quantized_all_batched(type, group_size, bits) \ + instantiate_quantized_all_aligned(type, group_size, bits) \ + instantiate_quantized_all_rhs(type, group_size, bits) + +#define instantiate_quantized_types(group_size, bits) \ + instantiate_quantized_funcs(float, group_size, bits) \ + instantiate_quantized_funcs(float16_t, group_size, bits) \ + instantiate_quantized_funcs(bfloat16_t, group_size, bits) + +#define instantiate_quantized_groups(bits) \ + instantiate_quantized_types(128, bits) \ + instantiate_quantized_types(64, bits) \ + instantiate_quantized_types(32, bits) + +#define instantiate_quantized_all() \ + instantiate_quantized_groups(2) \ + instantiate_quantized_groups(3) \ + instantiate_quantized_groups(4) \ + instantiate_quantized_groups(5) \ + instantiate_quantized_groups(6) \ + instantiate_quantized_groups(8) + +instantiate_quantized_all() // clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h new file mode 100644 index 0000000000..3a2136f4b7 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -0,0 +1,476 @@ +// Copyright © 2024-25 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid 
[[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + (void)simd_lane_id; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + const metal::uniform scale2 = + make_uniform(params->scale) * make_uniform(1.44269504089f); + + // Prepare MMA tiles + constexpr short UQ = 16; + constexpr short UD = 32; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * UQ); + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / UD; + + static_assert(TQ == 1, "Check TQ"); + + using OSubTile = NAXSubTile; + NAXTile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = UQ * TQ * simd_group_id; + + Q += (tm + sm) * int(params->Q_strides[2]) + sn; + K += sm * int(params->K_strides[2]) + sn; + V += sm * int(params->V_strides[2]) + sn; + + // Init row reduction variables + constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; + + metal::vec max_score; + metal::vec sum_score{0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + const bool is_last_bq = int(tid.x) == (params->NQ_aligned); + // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); + const bool is_last_q = is_last_bq; + + const short lim_rows_q = params->qL_rem - (tm + sm); + const short lim_rows_k = params->kL_rem - sm; + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + const int is_last_k = (kb == (params->NK_aligned)); + + // Do S = Q @ K.T + constexpr short UDs = 16; + constexpr short UKs = 32; + + constexpr short TDs = BD / UDs; + constexpr short TKs = BK / UKs; + + using SSubTile = NAXSubTile; + using QSubTile = NAXSubTile; + using KSubTile = NAXSubTile; + + NAXTile Stile; + + Stile.clear(); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TKs; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TDs; id++) { + NAXTile Qtile; + NAXTile Ktile; + + const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; + const int K_load_off = + ik * UKs * int(params->K_strides[2]) + id * UDs; + + if (!align_Q && is_last_q) { + // Qtile.load_rows( + // Q + Q_load_off, + // 
int(params->Q_strides[2]), + // lim_rows_q - iq * UQ); + Qtile.load_safe( + Q + Q_load_off, + int(params->Q_strides[2]), + short2(BD, lim_rows_q - iq * UQ)); + } else { + Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); + } + + if (!align_K && is_last_k) { + // Ktile.load_rows( + // K + K_load_off, + // int(params->K_strides[2]), + // lim_rows_k - ik * UKs); + Ktile.load_safe( + K + K_load_off, + int(params->K_strides[2]), + short2(BD, lim_rows_k - ik * UKs)); + } else { + Ktile.load(K + K_load_off, int(params->K_strides[2])); + } + + subtile_matmad_nax( + Stile.subtile_at(iq, ik), + Qtile.subtile_at(0, 0), + metal::false_type{}, + Ktile.subtile_at(0, 0), + metal::true_type{}); + } + } + } + + // Scale S + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= float(scale2); + } + + // Scale and Retile S + constexpr short UK = 16; + constexpr short TK = BK / UK; + using PSubTile = NAXSubTile; + + NAXTile Ptile; + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Ptile.elems()[ii] = Stile.elems()[ii]; + } + + // Mask out length sequence + if (!align_K && is_last_k) { + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short col_pos = sn + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + params->qL_off + tm; + const int base_col = kb * BK; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ; + const short col_pos = base_col + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; + const auto c = col_pos + jj + sn; + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = (r < c) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + tm; + const int base_col = kb * BK; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + using MSubTile = NAXSubTile; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ + sm; + const short col_pos = base_col + ik * UK + sn; + + MSubTile mfrag; + mfrag.load_safe( + mask, + int(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + fg[jj] = mfrag.elems()[jj] ? 
fg[jj] : neg_inf; + } else { + fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); + } + } + } + } + } + + // Do softmax + + // Temp variables + metal::vec new_max; + metal::vec factor; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Ptile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Ptile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + max_score[i] = new_max[i]; + } + + // Row Sum + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i]; + } + + Ptile.template row_reduce(sum_score); + + // Update O + Otile.template row_bin_op(factor); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = P @ V + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + if constexpr (BD == 128) { + if (id == 2) { + threadgroup_barrier(mem_flags::mem_none); + } + } + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + using VSubTile = NAXSubTile; + NAXTile Vtile; + + const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; + + if (!align_K && is_last_k) { + // Vtile.load_rows( + // V + V_load_off, + // int(params->V_strides[2]), + // lim_rows_k - ik * UK); + Vtile.load_safe( + V + V_load_off, + int(params->V_strides[2]), + short2(BD, lim_rows_k - ik * UK)); + } else { + Vtile.load(V + V_load_off, int(params->V_strides[2])); + } + + subtile_matmad_nax( + Otile.subtile_at(iq, id), + Ptile.subtile_at(iq, ik), + metal::bool_constant{}, + Vtile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + + // Prepare for next iteration + K += BK * int(params->K_strides[2]); + V += BK * int(params->V_strides[2]); + } + + // Normalize output + + threadgroup_barrier(mem_flags::mem_none); + + metal::vec rcp; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + rcp[i] = (1.f / sum_score[i]); + } + + Otile.template row_bin_op(rcp); + + // Store results + O += (tm + sm) * int(params->O_strides[2]) + sn; + + if (!align_Q && is_last_q) { + if (lim_rows_q <= 0) + return; + + // Otile.store_rows(O, params->O_strides[2], lim_rows_q); + Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); + } else { + Otile.store(O, int(params->O_strides[2])); + } +} diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal new file mode 100644 index 0000000000..1fba9af617 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal @@ -0,0 +1,33 @@ +// Copyright © 2024-25 Apple Inc. 
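+
+// The macros below stamp out one [[kernel]] instantiation per combination of
+// dtype, (BQ, BK, BD) tile shape, (WM, WN) simdgroup layout, and mask type,
+// and encode those choices in the kernel name. Worked example of the naming
+// scheme, derived from instantiate_attn below (illustration only, not an
+// additional instantiation):
+//
+//   instantiate_attn(float16, half, 64, 32, 128, 4, 1, float16, half)
+//     -> "steel_attention_float16_bq64_bk32_bd128_wm4_wn1_maskfloat16"
+//
+// The host side presumably reconstructs the same string from its launch
+// parameters when it fetches the kernel from the metallib.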
+ +// clang-format off +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/nax.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h" + +#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ + instantiate_kernel( \ + "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ + "_wm" #wm "_wn" #wn "_mask" #mname, \ + attention_nax, dtype, bq, bk, bd, wm, wn, mtype, float) + +#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 64, 32, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 32, 64, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 64, 128, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 64, 64, 64, 4, 1, mname, mtype) + +#define instantiate_attn_mask_helper(iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, iname, itype) \ + instantiate_attn_shapes_helper(iname, itype, bool_, bool) + +instantiate_attn_mask_helper(float16, half); +instantiate_attn_mask_helper(bfloat16, bfloat); + +instantiate_attn_mask_helper(float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/attn/nax.h b/mlx/backend/metal/kernels/steel/attn/nax.h new file mode 100644 index 0000000000..c8f3ea5ef1 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/attn/nax.h @@ -0,0 +1,1076 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + 
StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const 
short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_ = BaseNAXFrag> +struct NAXSubTile { + using NAXFrag_t = NAXFrag_; + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols 
= kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX 
str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? 
RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + mpp::tensor_ops::matmul2d gemm_op; + + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + gemm_op.run(ct_a, ct_b, ct_c); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread 
metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + 
stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/defines.h b/mlx/backend/metal/kernels/steel/defines.h index 6c3bfcf4ed..f5657ee363 100644 --- a/mlx/backend/metal/kernels/steel/defines.h +++ b/mlx/backend/metal/kernels/steel/defines.h @@ -1,4 +1,7 @@ // Copyright © 2024 Apple Inc. +#pragma once + #define STEEL_CONST static constant constexpr const #define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") +#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") diff --git a/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h new file mode 100644 index 0000000000..e9b69a2006 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -0,0 +1,154 @@ +// Copyright © 2025 Apple Inc. 
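
For reference, the subtile walk in tile_matmad_nax above can be replayed with a small host-side C++ sketch (illustrative only, not part of this diff); the tile counts used below are arbitrary:

#include <cstdio>

// Host-side sketch: verifies the transpose-aware subtile indexing that
// tile_matmad_nax uses when accumulating C(i,j) += A(i,k) * B(k,j).
template <bool transpose_a, bool transpose_b>
void walk_subtiles(int TM, int TN, int TK) {
  for (int i = 0; i < TM; ++i) {
    for (int j = 0; j < TN; ++j) {
      for (int k = 0; k < TK; ++k) {
        // Same mapping as in the kernel: a transposed operand swaps the
        // row/column roles of its subtile coordinates.
        const int ra = transpose_a ? k : i;
        const int ca = transpose_a ? i : k;
        const int rb = transpose_b ? j : k;
        const int cb = transpose_b ? k : j;
        std::printf("C(%d,%d) += A(%d,%d) * B(%d,%d)\n", i, j, ra, ca, rb, cb);
      }
    }
  }
}

int main() {
  walk_subtiles<false, true>(2, 2, 2); // e.g. the "nt" instantiation
}

Flipping either transpose flag only swaps that operand's subtile coordinates; the accumulation order over k is unchanged.
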
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" + +using namespace metal; + +namespace mlx::steel { + +template < + typename T, + short SM, + short SN, + short SK, + short BK, + bool transpose_a, + bool transpose_b, + bool kAlignedM, + bool kAlignedN, + bool kAlignedK, + short UM, + short UN, + short UK, + typename AccumType = float> +auto gemm_loop( + const device T* A, + const device T* B, + const constant GEMMParams* params [[buffer(4)]], + const short sgp_sm, + const short sgp_sn) { + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + constexpr int RA = transpose_a ? TK : TM; + constexpr int CA = transpose_a ? TM : TK; + + constexpr int RB = transpose_b ? TN : TK; + constexpr int CB = transpose_b ? TK : TN; + + using DSubTile = NAXSubTile; + using ASubTile = + NAXSubTile; + using BSubTile = + NAXSubTile; + + NAXTile Dtile; + Dtile.clear(); + + int gemm_k_iterations_ = params->gemm_k_iterations_aligned; + + STEEL_PRAGMA_NO_UNROLL + for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) { + threadgroup_barrier(mem_flags::mem_none); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + const int k = kk1; + + volatile int compiler_barrier; + + const int A_offset = transpose_a ? k * params->lda : k; + const int B_offset = transpose_b ? k : k * params->ldb; + + if constexpr (kAlignedM) { + Atile.load(A + A_offset, params->lda); + } else { + const short rmax = transpose_a ? SK : sgp_sm; + const short cmax = transpose_a ? sgp_sm : SK; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + if constexpr (kAlignedN) { + Btile.load(B + B_offset, params->ldb); + } else { + const short rmax = transpose_b ? sgp_sn : SK; + const short cmax = transpose_b ? SK : sgp_sn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + A += transpose_a ? (BK * params->lda) : BK; + B += transpose_b ? BK : (BK * params->ldb); + } + + if constexpr (!kAlignedK) { + simdgroup_barrier(mem_flags::mem_none); + + const short rem_bk = params->K - gemm_k_iterations_ * BK; + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + STEEL_PRAGMA_UNROLL + for (int mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (int nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (int kk = 0; kk < TK; kk++) { + const int m = mm * UM; + const int n = nn * UN; + const int k = kk1 + kk * UK; + const short psk = max(0, rem_bk - k); + + const int A_offset = + transpose_a ? (m + k * params->lda) : (m * params->lda + k); + const int B_offset = + transpose_b ? (k + n * params->ldb) : (k * params->ldb + n); + + { + const short psm = kAlignedM ? SM : max(0, sgp_sm - m); + const short rmax = transpose_a ? psk : psm; + const short cmax = transpose_a ? psm : psk; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + { + const short psn = kAlignedN ? SN : max(0, sgp_sn - n); + const short rmax = transpose_b ? psn : psk; + const short cmax = transpose_b ? 
psk : psn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + subtile_matmad_nax( + Dtile.subtile_at(mm, nn), + Atile.subtile_at(0, 0), + metal::bool_constant{}, + Btile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + } + } + + return Dtile; +} + +} // namespace mlx::steel diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h new file mode 100644 index 0000000000..44328ed0b9 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + bool kAlignedM, + bool kAlignedN, + typename NAXTile_t, + typename T> +void gemm_epilogue( + thread NAXTile_t& Dtile, + const device T* C, + const constant GEMMParams* params, + const constant GEMMAddMMParams* addmm_params, + const short sgp_sm, + const short sgp_sn) { // clang-format on + + (void)params; + + constexpr short UM = NAXTile_t::kSubTileRows; + constexpr short UN = NAXTile_t::kSubTileCols; + using CSubTile = NAXSubTile; + + using V = typename NAXTile_t::elem_type; + + constexpr short TM = NAXTile_t::kTileRows; + constexpr short TN = NAXTile_t::kTileCols; + constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + const short m = mm * UM; + const short n = nn * UN; + + CSubTile CTile; + + if constexpr (kAlignedM && kAlignedN) { + CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); + } else { + CTile.load_safe( + C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); + } + + auto delems = Dtile.subtile_at(mm, nn).elems(); + auto celems = CTile.elems(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemsPerSubTile; i++) { + if (do_axpby) { + delems[i] = addmm_params->alpha * delems[i] + + addmm_params->beta * static_cast(celems[i]); + } else { + delems[i] += static_cast(celems[i]); + } + } + } + } +} + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust 
for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + D += tm * params->ldd + tn; + + if (use_out_source) { + C += tm * addmm_params->ldc + tn * addmm_params->fdc; + } + + using DSubTile = NAXSubTile; + NAXTile Dtile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>(A, B, params, sgp_sm, sgp_sn); + if (use_out_source) { + gemm_epilogue( + Dtile, C, params, addmm_params, sgp_sm, sgp_sn); + } + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(D, int(params->ldd)); + } else { + Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm)); + } + }); + }); + }); +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal new file mode 100644 index 0000000000..e6cb0b64c6 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal @@ -0,0 +1,35 @@ +// Copyright © 2025 Apple Inc. 
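
The tid swizzle at the top of the fused kernel above pairs with the grid shaping done on the host (tile = 1 << swizzle_log, further down in matmul.cpp). A host-side C++ sketch, with made-up grid sizes, that replays both halves of the mapping:

#include <cstdio>

int main() {
  const int swizzle_log = 1;          // host picks 0 or 1 based on tiles_m
  const int tiles_m = 4, tiles_n = 6; // illustrative tile counts

  // Host-side grid shaping, mirroring the launch code in matmul.cpp.
  const int tile = 1 << swizzle_log;
  const int grid_y = (tiles_m + tile - 1) / tile;
  const int grid_x = tiles_n * tile;

  for (int gy = 0; gy < grid_y; ++gy) {
    for (int gx = 0; gx < grid_x; ++gx) {
      // Kernel-side unswizzle, copied from the gemm kernel above.
      const int tid_y = (gy << swizzle_log) + (gx & ((1 << swizzle_log) - 1));
      const int tid_x = gx >> swizzle_log;
      if (tid_y < tiles_m && tid_x < tiles_n) {
        std::printf("threadgroup (%d,%d) -> tile (m=%d, n=%d)\n", gx, gy, tid_y, tid_x);
      }
    }
  }
}

Consecutive threadgroups along x thus cover small column groups of output tiles before moving on, which is the point of the swizzle; out-of-range (tid_y, tid_x) pairs are dropped by the early-exit check in the kernel.
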
+ +#include + +#include "mlx/backend/metal/kernels/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h" + +// clang-format off +#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gemm_fused_nax_" #tname "_" #iname "_" #oname \ + "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn, \ + gemm, itype, bm, bn, bk, wm, wn, trans_a, trans_b, float) + +#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \ + instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4) + +instantiate_gemm_shapes_helper(float16, half, float16, half); +instantiate_gemm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); +instantiate_gemm_shapes_helper(float32, float, float32, float); +// clang-format on diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h new file mode 100644 index 0000000000..29285833af --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h @@ -0,0 +1,132 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +gather_mm_rhs_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + rhs_indices += c_row; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldd + tn; + rhs_indices += tm; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[0]; + short offset_next = 0; + int n = 0; + while (n < sgp_sm) { + n++; + offset = offset_next; + index = index_next; + offset_next = sgp_sm; + for (; n < sgp_sm; n++) { + if (rhs_indices[n] != index) { + offset_next = n; + index_next = rhs_indices[n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + using DSubTile = NAXSubTile; + NAXTile Ctile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + auto do_gemm = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>; + Ctile = do_gemm( + A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn); + + if constexpr (kAlignedN.value) { + if (offset_next - offset == SM) { + Ctile.store(C, int(params->ldd)); + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(SN, offset_next)); + } + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(sgp_sn, offset_next)); + } + }); + }); + }); + } +} diff --git a/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal new file mode 100644 index 0000000000..5b8589f547 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal @@ -0,0 +1,39 @@ +// Copyright © 2024 Apple Inc. 
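
The while/for pair in gather_mm_rhs_nax above walks the sorted rhs_indices and emits one [offset, offset_next) row slice per distinct index. A host-side C++ sketch with a hypothetical index vector shows the grouping it produces:

#include <cstdio>
#include <vector>

int main() {
  // Hypothetical sorted indices for the rows handled by one simdgroup.
  std::vector<unsigned> rhs_indices = {0, 0, 0, 2, 2, 5};
  const int sgp_sm = static_cast<int>(rhs_indices.size());

  int n = 0;
  unsigned index_next = rhs_indices[0];
  int offset_next = 0;
  while (n < sgp_sm) {
    n++;
    const int offset = offset_next;      // first row of this group
    const unsigned index = index_next;   // expert index shared by the group
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;                 // first row of the next group
        index_next = rhs_indices[n];
        break;
      }
    }
    std::printf("rows [%d, %d) use B matrix %u\n", offset, offset_next, index);
  }
}

Each group runs one gemm_loop against B + index * batch_stride_b and stores only its row slice of the output, which is why the host only dispatches this kernel when the indices are sorted (right_sorted_ in GatherMM::eval_gpu).
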
+ +#include + +#include "mlx/backend/metal/kernels/steel/gemm/gemm_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" +#include "mlx/backend/metal/kernels/utils.h" + +// clang-format off +#define instantiate_gather_mm_rhs(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_kernel( \ + "steel_gather_mm_rhs_nax_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn \ + "_bk" #bk "_wm" #wm "_wn" #wn, \ + gather_mm_rhs_nax, \ + itype, \ + bm, \ + bn, \ + bk, \ + wm, \ + wn, \ + trans_a, \ + trans_b, \ + float) + +#define instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \ + instantiate_gather_mm_rhs(nt, false, true, iname, itype, oname, otype, bm, bn, bk, wm, wn) + +#define instantiate_gather_mm_shapes_helper(iname, itype, oname, otype) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 16, 128, 128, 1, 4) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 32, 128, 128, 1, 4) \ + instantiate_gather_mm_rhs_transpose_helper(iname, itype, oname, otype, 64, 128, 128, 2, 4) +// clang-format on + +instantiate_gather_mm_shapes_helper(float16, half, float16, half); +instantiate_gather_mm_shapes_helper(bfloat16, bfloat, bfloat16, bfloat); diff --git a/mlx/backend/metal/kernels/steel/gemm/nax.h b/mlx/backend/metal/kernels/steel/gemm/nax.h new file mode 100644 index 0000000000..5839176c28 --- /dev/null +++ b/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -0,0 +1,1084 @@ +// Copyright © 2025 Apple Inc. 
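
The instantiation macros above encode the block configuration into the kernel name, and the host side (gather_mm_rhs_nax in matmul.cpp below) rebuilds the same string when looking up the pipeline. One concrete expansion, sketched in plain C++ with values taken from instantiate_gather_mm_shapes_helper:

#include <cstdio>
#include <string>

int main() {
  const bool transpose_b = true;         // the "nt" instantiation
  const std::string itype = "bfloat16";  // input and output type name
  const int bm = 32, bn = 128, bk = 128, wm = 1, wn = 4;

  std::string name = "steel_gather_mm_rhs_nax_n";
  name += transpose_b ? 't' : 'n';
  name += "_" + itype + "_" + itype;
  name += "_bm" + std::to_string(bm) + "_bn" + std::to_string(bn) +
      "_bk" + std::to_string(bk) + "_wm" + std::to_string(wm) +
      "_wn" + std::to_string(wn);

  // Prints: steel_gather_mm_rhs_nax_nt_bfloat16_bfloat16_bm32_bn128_bk128_wm1_wn4
  std::printf("%s\n", name.c_str());
}

Only the nn and nt layouts are instantiated for the gather kernel, matching the host dispatch, which always passes transpose_a as false.
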
+ +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename 
SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread 
dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_t = BaseNAXFrag> +struct NAXSubTile { + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); 
+ } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + 
StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + // Load B into right operand registers + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? 
kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + // Load C into output registers (op handles accumulation) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i 
< kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? 
BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/mlx/backend/metal/kernels/steel/utils/integral_constant.h b/mlx/backend/metal/kernels/steel/utils/integral_constant.h index b616acc676..526f561ee5 100644 --- a/mlx/backend/metal/kernels/steel/utils/integral_constant.h +++ b/mlx/backend/metal/kernels/steel/utils/integral_constant.h @@ -74,6 +74,44 @@ integral_const_binop(>=, operator>=); integral_const_binop(&&, operator&&); integral_const_binop(||, operator||); +template >> +METAL_FUNC constexpr auto operator||(true_type, T) { + return true_type{}; +} +template >> +METAL_FUNC constexpr auto operator||(T, true_type) { + return true_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(false_type, T) { + return false_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(T, false_type) { + return false_type{}; +} + +// Dispatch utilities +template +void dispatch_bool(bool v, F f) { + if (v) { + f(true_type{}); + } else { + f(false_type{}); + } +} + +template +constexpr void const_for_loop(F f) { + if constexpr (start < stop) { + constexpr auto idx = Int{}; + f(idx); + const_for_loop(f); + } +} + #undef integral_const_binop /////////////////////////////////////////////////////////////////////////////// diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp index abc45575a4..e4f6253837 100644 --- a/mlx/backend/metal/matmul.cpp +++ b/mlx/backend/metal/matmul.cpp @@ -172,6 +172,165 @@ ensure_batch_contiguous(const array& x, metal::Device& d, const Stream& s) { // Regular steel matmul dispatch /////////////////////////////////////////////////////////////////////////////// +#ifdef MLX_ENABLE_NAX + +template +void steel_matmul_regular_axpby_nax( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride /* = 0*/, + float alpha /* = 1.0f */, + float beta /* = 0.0f */) { + using namespace mlx::steel; + + // Determine dispatch kernel + int bm = 128, bn = 128, bk = 512; + int wm = 4, wn = 4; + + // Prepare kernel name + std::ostringstream kname; + + // clang-format off + kname << "steel_gemm_fused_nax_" + << (transpose_a ? 't' : 'n') + << (transpose_b ? 
't' : 'n') + << "_" << type_to_name(a) + << "_" << type_to_name(out) + << "_bm" << bm << "_bn" << bn << "_bk" << bk + << "_wm" << wm << "_wn" << wn; // clang-format on + + std::string base_name = kname.str(); + + const bool has_batch = (batch_shape.size() > 1); + const bool use_out_source = CHECK_AB && (alpha != 0.0f || beta != 1.0f); + const bool do_axpby = use_out_source && (alpha != 1.0f || beta != 1.0f); + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + metal::MTLFCList func_consts = { + {&has_batch, MTL::DataType::DataTypeBool, 10}, + {&use_out_source, MTL::DataType::DataTypeBool, 100}, + {&do_axpby, MTL::DataType::DataTypeBool, 110}, + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // clang-format off + kname << "_has_batch_" << (has_batch ? 't' : 'n') + << "_use_out_source_" << (use_out_source ? 't' : 'n') + << "_do_axpby_" << (do_axpby ? 't' : 'n') + << "_align_M_" << (align_M ? 't' : 'n') + << "_align_N_" << (align_N ? 't' : 'n') + << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on + + std::string hash_name = kname.str(); + + // Encode and dispatch kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_fused_kernel( + /* metal::Device& d = */ d, + /* const std::string& kernel_name = */ base_name, + /* const std::string& hash_name = */ hash_name, + /* const metal::MTLFCList& func_consts = */ func_consts, + /* const array& out = */ out, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* int bm = */ bm, + /* int bn = */ bn, + /* int bk = */ bk, + /* int wm = */ wm, + /* int wn = */ wn); + + compute_encoder.set_compute_pipeline_state(kernel); + + // Use problem size to determine threadblock swizzle + int tn = (N + bn - 1) / bn; + int tm = (M + bm - 1) / bm; + + // TODO: Explore device-based tuning for swizzle + int swizzle_log = tm <= 3 ? 
0 : 1; + + // Prepare steel matmul params + GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ ldb, + /* const int ldd = */ ldd, + /* const int tiles_n = */ tn, + /* const int tiles_m = */ tm, + /* const int64_t batch_stride_a = */ A_batch_stride, + /* const int64_t batch_stride_b = */ B_batch_stride, + /* const int64_t batch_stride_d = */ matrix_stride_out, + /* const int swizzle_log = */ swizzle_log, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ int(batch_shape.size())}; + + // Prepare launch grid params + int tile = 1 << swizzle_log; + tm = (tm + tile - 1) / tile; + tn = tn * tile; + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(tn, tm, batch_size_out); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_output_array(out, 3); + + compute_encoder.set_bytes(params, 4); + + if (has_batch) { + compute_encoder.set_vector_bytes(batch_shape, 6); + compute_encoder.set_vector_bytes(batch_strides, 7); + } + + if (use_out_source) { + int ldc = c.strides()[c.ndim() - 2]; + int fdc = c.strides()[c.ndim() - 1]; + + GEMMAddMMParams params{ + /* const int ldc = */ ldc, + /* const int fdc = */ fdc, + /* const int64_t batch_stride_c = */ C_batch_stride, + /* const float alpha = */ alpha, + /* const float beta = */ beta}; + + compute_encoder.set_input_array(c, 2); + compute_encoder.set_bytes(params, 5); + } + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); + + // Record copies + d.add_temporaries(std::move(copies), s.index); +} + +#endif // MLX_ENABLE_NAX + template void steel_matmul_regular_axpby( const Stream& s, @@ -198,6 +357,41 @@ void steel_matmul_regular_axpby( int64_t C_batch_stride /* = 0*/, float alpha /* = 1.0f */, float beta /* = 0.0f */) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && !issubdtype(a.dtype(), complexfloating) && + (env::enable_tf32() || a.dtype() != float32)) { + return steel_matmul_regular_axpby_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ c, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out, + /* int64_t C_batch_stride = */ C_batch_stride, + /* float alpha = */ alpha, + /* float beta = */ beta); + } + } + +#endif // MLX_ENABLE_NAX + using namespace mlx::steel; // Determine dispatch kernel @@ -1572,6 +1766,153 @@ void gather_mm_rhs( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void gather_mm_rhs_nax( + const array& a_, + const array& b_, + const array& indices_, + array& out, + metal::Device& d, + const Stream& s) { + array indices = ensure_row_contiguous(indices_, d, s); + auto [transpose_b, ldb, b] = ensure_batch_contiguous(b_, d, s); + + // Broadcast a with indices. 
If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of a broadcasted + // with rhs_indices. We need only broadcast a and copy it as if applying the + // lhs_indices. + auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + array a = broadcast_with_indices(a_); + + // Extract the matmul shapes + int K = a.shape(-1); + int M = a.size() / K; + int N = b.shape(-1); + int lda = a.strides()[a.ndim() - 2]; // should be K + int E = b.shape(0); + + // Define the dispatch blocks + int bm, bn = 128, bk = 128, wm, wn = 4; + if (M / E > 48) { + bm = 64; + wm = 2; + } else if (M / E > 24) { + bm = 32l; + wm = 1; + } else { + bm = 16; + wm = 1; + } + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Define the kernel name + std::string base_name; + base_name.reserve(64); + concatenate( + base_name, + "steel_gather_mm_rhs_nax_n", + transpose_b ? 't' : 'n', + '_', + type_to_name(a), + '_', + type_to_name(out), + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + base_name, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_steel_gemm_gather_kernel( + d, + base_name, + hash_name, + func_consts, + out, + false, + transpose_b, + bm, + bn, + bk, + wm, + wn, + true); + compute_encoder.set_compute_pipeline_state(kernel); + + // Prepare the matmul params + auto batch_stride_b = b.ndim() > 2 ? 
b.strides()[b.ndim() - 3] : b.size(); + steel::GEMMParams params{ + /* const int M = */ M, + /* const int N = */ N, + /* const int K = */ K, + /* const int lda = */ lda, + /* const int ldb = */ static_cast(ldb), + /* const int ldd = */ N, + /* const int tiles_n = */ (N + bn - 1) / bn, + /* const int tiles_m = */ (M + bm - 1) / bm, + /* const int64_t batch_stride_a = */ 0, + /* const int64_t batch_stride_b = */ static_cast(batch_stride_b), + /* const int64_t batch_stride_d = */ 0, + /* const int swizzle_log = */ 0, + /* const int gemm_k_iterations_aligned = */ (K / bk), + /* const int batch_ndim = */ 0}; + + // Prepare the grid + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size(params.tiles_n, params.tiles_m, 1); + + // Launch kernel + compute_encoder.set_input_array(a, 0); + compute_encoder.set_input_array(b, 1); + compute_encoder.set_input_array(indices, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(params, 4); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void gather_mv( const array& mat_, const array& vec_, @@ -1855,6 +2196,19 @@ void GatherMM::eval_gpu(const std::vector& inputs, array& out) { // We are walking a in order and b is also in order so we can batch up the // matmuls and reuse reading a and b. if (M == 1 && right_sorted_ == true) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && + !issubdtype(a.dtype(), complexfloating) && + (env::enable_tf32() || a.dtype() != float32)) { + return gather_mm_rhs_nax(a, b, rhs_indices, out, d, s); + } + } + +#endif // MLX_ENABLE_NAX + gather_mm_rhs(a, b, rhs_indices, out, d, s); return; } diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index e03e5dca20..55b69b9cac 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -451,6 +451,210 @@ void qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void qmm_nax( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s, + const std::string& mode) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 64; + int bn = 64; + int bk = 64; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 64 == 0; + bool batched = B > 1; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + mode + (transpose ? "_qmm_t_nax_" : "_qmm_n_nax_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + transpose ? (aligned ? "_alN_true" : "_alN_false") : "", + batched ? 
"_batch_1" : "_batch_0"); + std::string template_def; + MTL::ComputePipelineState* kernel; + if (transpose) { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "qmm_t_nax", + mode, + type_string, + group_size, + bits, + aligned, + batched); + } else { + kernel = get_quantized_kernel_wrapped( + d, kname, "qmm_n_nax", mode, type_string, group_size, bits, batched); + } + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, c); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void gather_qmm_nax( + const array& x, + const array& w, + const array& scales, + const std::optional& biases, + const array& lhs_indices, + const array& rhs_indices, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + metal::Device& d, + const Stream& s, + const std::string& mode) { + int B = out.size() / M / N; + + int wm = 2; + int wn = 2; + int bm = 64; + int bn = 64; + int bk = 32; + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, B); + + std::string kname; + kname.reserve(64); + bool aligned = N % 64 == 0; + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + mode + (transpose ? "_gather_qmm_t_nax_" : "_gather_qmm_n_nax_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + transpose ? (aligned ? 
"_alN_true" : "_alN_false") : ""); + MTL::ComputePipelineState* kernel; + if (transpose) { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "gather_qmm_t_nax_", + mode, + type_string, + group_size, + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn, + aligned); + } else { + kernel = get_quantized_kernel_wrapped( + d, + kname, + "gather_qmm_n_nax_", + mode, + type_string, + group_size, + bits, + "_bm", + bm, + "_bn", + bn, + "_bk", + bk, + "_wm", + wm, + "_wn", + wn); + } + + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + int c = 0; + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases) { + compute_encoder.set_input_array(*biases, c++); + } + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(lhs_indices, c++); + compute_encoder.set_input_array(rhs_indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(K, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(M, c++); + c = add_strides_and_shapes(compute_encoder, false, x, w, scales, biases, c); + add_gather_strides_and_shapes(compute_encoder, lhs_indices, rhs_indices, c); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void qmm( const array& x, const array& w, @@ -466,6 +670,31 @@ void qmm( metal::Device& d, const Stream& s, const std::string& mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && transpose && (K % 64 == 0) && + (env::enable_tf32() || x.dtype() != float32)) { + return qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + int B = out.size() / M / N; int wm = 2; @@ -543,6 +772,33 @@ void gather_qmm( metal::Device& d, const Stream& s, const std::string& mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && transpose && (K % 64 == 0) && + (env::enable_tf32() || x.dtype() != float32)) { + return gather_qmm_nax( + /* const array& x = */ x, + /* const array& w = */ w, + /* const array& scales = */ scales, + /* const std::optional& biases = */ biases, + /* const array& lhs_indices = */ lhs_indices, + /* const array& rhs_indices = */ rhs_indices, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string& mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + int B = out.size() / M / N; int wm = 2; @@ -719,6 +975,141 @@ void gather_qvm( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +#ifdef MLX_ENABLE_NAX + +void gather_qmm_rhs_nax( + const array& x_, + const array& w_, + const array& scales_, + const std::optional& biases_, + const array& indices_, + array& out, + bool transpose, + int group_size, + int bits, + int M, + int N, + int K, + 
metal::Device& d, + const Stream& s, + const std::string mode) { + // Start by normalizing the indices + array indices = ensure_row_contiguous(indices_, d, s); + + // Broadcast x with indices. If we are here that means lhs_indices were not + // provided so the lhs_indices are implied to be the shape of x broadcasted + // with rhs_indices. We need only broadcast x and copy it as if applying the + // lhs_indices. + auto broadcast_with_indices = [&d, &s, &indices](const array& x) { + if (x.size() / x.shape(-2) / x.shape(-1) == indices.size()) { + return ensure_row_contiguous(x, d, s); + } + + auto x_shape = indices.shape(); + x_shape.push_back(x.shape(-2)); + x_shape.push_back(x.shape(-1)); + array new_x(std::move(x_shape), x.dtype(), nullptr, {}); + broadcast(x, new_x); + return ensure_row_contiguous(new_x, d, s); + }; + + // Normalize the input arrays + array x = broadcast_with_indices(x_); + array w = ensure_row_contiguous(w_, d, s); + array scales = ensure_row_contiguous(scales_, d, s); + + // TODO: Tune the block sizes + int bm = 64, bn = 64, bk = 64; + int wm = 2, wn = 2; + + const bool align_M = (M % bm) == 0; + const bool align_N = (N % bn) == 0; + const bool align_K = (K % bk) == 0; + + // Make the kernel name + std::string kname; + kname.reserve(64); + std::string type_string = get_type_string(x.dtype()); + concatenate( + kname, + mode + + (transpose ? "_gather_qmm_rhs_nax_nt_" : "_gather_qmm_rhs_nax_nn_"), + type_string, + "_gs_", + group_size, + "_b_", + bits, + "_bm_", + bm, + "_bn_", + bn, + "_bk_", + bk, + "_wm_", + wm, + "_wn_", + wn); + + metal::MTLFCList func_consts = { + {&align_M, MTL::DataType::DataTypeBool, 200}, + {&align_N, MTL::DataType::DataTypeBool, 201}, + {&align_K, MTL::DataType::DataTypeBool, 202}, + }; + + // And the kernel hash that includes the function constants + std::string hash_name; + hash_name.reserve(128); + concatenate( + hash_name, + kname, + "_align_M_", + align_M ? 't' : 'n', + "_align_N_", + align_N ? 't' : 'n', + "_align_K_", + align_K ? 
't' : 'n'); + + // Get and set the kernel + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = get_gather_qmm_kernel( + d, + kname, + hash_name, + func_consts, + x, + group_size, + bits, + mode, + bm, + bn, + bk, + wm, + wn, + transpose); + compute_encoder.set_compute_pipeline_state(kernel); + + MTL::Size group_dims(32, wn, wm); + MTL::Size grid_dims((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + int c = 0; + compute_encoder.set_input_array(x, c++); + compute_encoder.set_input_array(w, c++); + compute_encoder.set_input_array(scales, c++); + if (biases_) { + array biases = ensure_row_contiguous(*biases_, d, s); + compute_encoder.set_input_array(biases, c++); + } + compute_encoder.set_input_array(indices, c++); + compute_encoder.set_output_array(out, c++); + compute_encoder.set_bytes(M, c++); + compute_encoder.set_bytes(N, c++); + compute_encoder.set_bytes(K, c++); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void gather_qmm_rhs( const array& x_, const array& w_, @@ -735,6 +1126,32 @@ void gather_qmm_rhs( metal::Device& d, const Stream& s, const std::string mode) { +#ifdef MLX_ENABLE_NAX + + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && transpose && + (env::enable_tf32() || x_.dtype() != float32)) { + return gather_qmm_rhs_nax( + /* const array& x_ = */ x_, + /* const array& w_ = */ w_, + /* const array& scales_ = */ scales_, + /* const std::optional& biases_ = */ biases_, + /* const array& indices_ = */ indices_, + /* array& out = */ out, + /* bool transpose = */ transpose, + /* int group_size = */ group_size, + /* int bits = */ bits, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* metal::Device& d = */ d, + /* const Stream& s = */ s, + /* const std::string mode = */ mode); + } + } + +#endif // MLX_ENABLE_NAX + // Start by normalizing the indices array indices = ensure_row_contiguous(indices_, d, s); diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d8adf81996..d3920b55d9 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -12,6 +12,146 @@ namespace mlx::core::fast { namespace { + +#ifdef MLX_ENABLE_NAX + +void sdpa_full_self_attention_nax( + const Stream& s, + metal::Device& d, + const array& q, + const array& k, + const array& v, + const float scale, + array& o, + bool do_causal_, + const std::optional& mask, + const std::optional& sinks) { + using namespace mlx::steel; + + int wm = 4; + int wn = 1; + + int bd = q.shape(-1); + int bq = 64; + int bk = 32; + + int B = q.shape(0); + int H = q.shape(1); + int D = q.shape(3); + int gqa_factor = q.shape(1) / k.shape(1); + + int qL = q.shape(2); + int kL = k.shape(2); + + const bool align_Q = (qL % bq) == 0; + const bool align_K = (kL % bk) == 0; + const bool has_mask = mask.has_value(); + const bool do_causal = do_causal_; + const bool has_sinks = sinks.has_value(); + + metal::MTLFCList func_consts = { + {&align_Q, MTL::DataType::DataTypeBool, 200}, + {&align_K, MTL::DataType::DataTypeBool, 201}, + {&has_mask, MTL::DataType::DataTypeBool, 300}, + {&do_causal, MTL::DataType::DataTypeBool, 301}, + {&has_sinks, MTL::DataType::DataTypeBool, 302}}; + + std::string base_name; + concatenate( + base_name, + "steel_attention_", + type_to_name(q), + "_bq", + bq, + "_bk", + bk, + "_bd", + bd, + "_wm", + wm, + "_wn", + wn, + "_mask", + type_to_name(has_mask ? 
*mask : q)); + + std::string hash_name; + concatenate( + hash_name, + base_name, + "_align_Q_", + (align_Q ? 't' : 'n'), + "_align_K_", + (align_K ? 't' : 'n'), + "_has_mask_", + (has_mask ? 't' : 'n'), + "_do_causal_", + (do_causal ? 't' : 'n'), + "_has_sinks_", + (has_sinks ? 't' : 'n')); + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(base_name, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + const int NQ = (qL + bq - 1) / bq; + const int NK = (kL + bk - 1) / bk; + + const int NQ_aligned = qL / bq; + const int NK_aligned = kL / bk; + + AttnParams params{ + /* int B = */ B, + /* int H = */ H, + /* int D = */ D, + + /* int qL = */ qL, + /* int kL = */ kL, + + /* int gqa_factor = */ gqa_factor, + /* float scale = */ scale, + + /* int NQ = */ NQ, + /* int NK = */ NK, + + /* int NQ_aligned = */ NQ_aligned, + /* int NK_aligned = */ NK_aligned, + + /* int qL_rem = */ (qL - NQ_aligned * bq), + /* int kL_rem = */ (kL - NK_aligned * bk), + /* int qL_off = */ (kL - qL), + + /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)}, + /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)}, + /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)}, + /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}}; + + compute_encoder.set_input_array(q, 0); + compute_encoder.set_input_array(k, 1); + compute_encoder.set_input_array(v, 2); + compute_encoder.set_output_array(o, 3); + compute_encoder.set_bytes(params, 4); + + if (has_mask) { + auto& m = *mask; + + AttnMaskParams mask_params{/* int64_t M_strides[3] = */ { + m.strides(0), m.strides(1), m.strides(2)}}; + + compute_encoder.set_bytes(mask_params, 5); + compute_encoder.set_input_array(m, 6); + } + if (has_sinks) { + compute_encoder.set_input_array(*sinks, 7); + } + + MTL::Size grid_dims = MTL::Size(NQ, H, B); + MTL::Size group_dims = MTL::Size(32, wm, wn); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +#endif // MLX_ENABLE_NAX + void sdpa_full_self_attention_metal( const Stream& s, metal::Device& d, @@ -23,6 +163,25 @@ void sdpa_full_self_attention_metal( bool do_causal_, const std::optional& mask, const std::optional& sinks) { +#ifdef MLX_ENABLE_NAX + if (__builtin_available(macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + if (metal::is_nax_available() && q.shape(3) != 80 && + (env::enable_tf32() || q.dtype() != float32)) { + return sdpa_full_self_attention_nax( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& q = */ q, + /* const array& k = */ k, + /* const array& v = */ v, + /* const float scale = */ scale, + /* array& o = */ o, + /* bool do_causal_ = */ do_causal_, + /* const std::optional& mask = */ mask, + /* const std::optional& sinks = */ sinks); + } + } +#endif // MLX_ENABLE_NAX + using namespace mlx::steel; int wm = 4; diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index e751063036..ce63544b47 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -163,6 +163,7 @@ def test_nvfp4_quantize_dequantize(self): def test_qmm(self): key = mx.random.key(0) k1, k2 = mx.random.split(key) + dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32 tests = product( [128, 64, 32], # group_size [2, 4, 8], # bits @@ -178,8 +179,13 @@ def test_qmm(self): bits=bits, transposed=transposed, ): - x = mx.random.normal(shape=(M, K), key=k1) - w = mx.random.normal(shape=(N, K) if transposed else (K, 
N), key=k2) + x = mx.random.normal(shape=(M, K), key=k1) / K**0.5 + w = ( + mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2) + / K**0.5 + ) + x = x.astype(dtype) + w = w.astype(dtype) w_q, scales, biases = mx.quantize(w, group_size, bits) w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) y_q = mx.quantized_matmul( @@ -187,7 +193,9 @@ def test_qmm(self): ) y_hat = (x @ w_hat.T) if transposed else (x @ w_hat) self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + + tol = 1e-3 if dtype == mx.float32 else 1.5e-3 + self.assertLess((y_q - y_hat).abs().max(), tol) def test_qmm_vjp(self): key = mx.random.key(0) @@ -833,48 +841,75 @@ def scatter_unsort(x, inv_order, shape=None): (133, 512, 555, 4, 2, False, "affine"), (64, 512, 512, 4, 2, False, "affine"), ] + + key = mx.random.key(0) + k1, k2, k3 = mx.random.split(key, 3) + dtype = mx.float16 if (mx.default_device() == mx.gpu) else mx.float32 + for L, K, D, E, I, transpose, mode in parameters: - if mode == "mxfp4": - group_size = 32 - else: - group_size = 64 - K, D = (K, D) if transpose else (D, K) - ishape = (L, I) - xshape = (L, 1, 1, K) - wshape = (E, D, K) if transpose else (E, K, D) - - indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32) - x = mx.random.normal(xshape) / K**0.5 - w = mx.random.normal(wshape) / K**0.5 - w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose) - - y1 = mx.gather_mm(x, w, rhs_indices=indices) - y2 = mx.gather_qmm( - x, - *wq, - group_size=group_size, - mode=mode, - transpose=transpose, - rhs_indices=indices - ) - xs, idx, inv_order = gather_sort(x, indices) - y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + with self.subTest(L=L, K=K, D=D, E=E, I=I, transpose=transpose, mode=mode): + if mode == "mxfp4": + group_size = 32 + dtype = ( + mx.bfloat16 if (mx.default_device() == mx.gpu) else mx.float32 + ) + else: + group_size = 64 + dtype = ( + mx.float16 if (mx.default_device() == mx.gpu) else mx.float32 + ) + + K, D = (K, D) if transpose else (D, K) + ishape = (L, I) + xshape = (L, 1, 1, K) + wshape = (E, D, K) if transpose else (E, K, D) + + indices = (mx.random.uniform(shape=ishape, key=k1) * E).astype( + mx.uint32 + ) + x = mx.random.normal(xshape, key=k2) / K**0.5 + w = mx.random.normal(wshape, key=k3) / K**0.5 - y4 = mx.gather_qmm( - xs, - *wq, - group_size=group_size, - mode=mode, - rhs_indices=idx, - transpose=transpose, - sorted_indices=True - ) - y3 = scatter_unsort(y3, inv_order, indices.shape) - y4 = scatter_unsort(y4, inv_order, indices.shape) + x = x.astype(dtype) + w = w.astype(dtype) - self.assertTrue(mx.allclose(y1, y2, atol=1e-5)) - self.assertTrue(mx.allclose(y1, y3, atol=1e-5)) - self.assertTrue(mx.allclose(y1, y4, atol=1e-5)) + w, *wq = quantize( + w, group_size=group_size, mode=mode, transpose=transpose + ) + + y1 = mx.gather_mm(x, w, rhs_indices=indices) + y2 = mx.gather_qmm( + x, + *wq, + group_size=group_size, + mode=mode, + transpose=transpose, + rhs_indices=indices + ) + xs, idx, inv_order = gather_sort(x, indices) + y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True) + + y4 = mx.gather_qmm( + xs, + *wq, + group_size=group_size, + mode=mode, + rhs_indices=idx, + transpose=transpose, + sorted_indices=True + ) + y3 = scatter_unsort(y3, inv_order, indices.shape) + y4 = scatter_unsort(y4, inv_order, indices.shape) + + tol = 1.5e-5 if (dtype == mx.float32) else 2.5e-4 + + self.assertLess((y1 - y2).abs().max(), tol) + self.assertLess((y1 - y3).abs().max(), tol) + 
self.assertLess((y1 - y4).abs().max(), tol) + + self.assertTrue(mx.allclose(y1, y2, atol=tol)) + self.assertTrue(mx.allclose(y1, y3, atol=tol)) + self.assertTrue(mx.allclose(y1, y4, atol=tol)) def test_gather_qmm_grad(self): def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort): @@ -898,10 +933,14 @@ def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): sorted_indices=sort, ) - x = mx.random.normal((16, 1, 256)) - w, s, b = mx.quantize(mx.random.normal((4, 256, 256))) - indices = mx.sort(mx.random.randint(0, 4, shape=(16,))) - cotan = mx.random.normal((16, 1, 256)) + key = mx.random.key(0) + k1, k2, k3, k4 = mx.random.split(key, 4) + dtype = mx.float32 + + x = mx.random.normal((16, 1, 256), key=k1).astype(dtype) + w, s, b = mx.quantize(mx.random.normal((4, 256, 256), key=k2).astype(dtype)) + indices = mx.sort(mx.random.randint(0, 4, shape=(16,), key=k3)) + cotan = mx.random.normal((16, 1, 256), key=k4).astype(dtype) (o1,), (dx1, ds1, db1) = mx.vjp( lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True), @@ -914,6 +953,7 @@ def gather_qmm(x, w, s, b, lhs, rhs, trans, sort): [cotan], ) + self.assertLess((o1 - o2).abs().max(), 1e-4) self.assertTrue(mx.allclose(o1, o2, atol=1e-4)) self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4)) self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
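
Note on the run-time gating used throughout this patch: compiling the NAX kernels in (MLX_ENABLE_NAX) is necessary but not sufficient. Every call site also checks the OS version via __builtin_available, the hardware via metal::is_nax_available(), and the dtype before dispatching. The stand-alone C++ sketch below restates that shared predicate; os_at_least_26_2, nax_hardware_present and tf32_enabled are illustrative placeholders, not MLX APIs, and the hard-coded return values exist only to keep the sketch self-contained.

// Sketch only: mirrors the guard pattern used by qmm(), gather_qmm(),
// gather_qmm_rhs(), GatherMM::eval_gpu() and the SDPA dispatch in this patch.
#include <cstdint>

enum class Dtype : std::uint8_t { float32, float16, bfloat16, complex64 };

// Placeholder queries standing in for __builtin_available(macOS 26.2, ...),
// metal::is_nax_available() and env::enable_tf32().
inline bool os_at_least_26_2() { return true; }
inline bool nax_hardware_present() { return true; }  // GPU architecture gen >= 17
inline bool tf32_enabled() { return false; }

// All three gates must pass before a NAX kernel is dispatched; the extra
// per-call-site conditions (e.g. transpose && K % 64 == 0 for qmm, or a
// non-complex dtype for gather_mm) come on top of this.
inline bool should_use_nax(Dtype dtype) {
  if (!os_at_least_26_2()) return false;      // OS/runtime gate
  if (!nax_hardware_present()) return false;  // hardware gate
  // float32 only goes through NAX when reduced-precision accumulation is
  // allowed, mirroring (env::enable_tf32() || dtype != float32).
  return tf32_enabled() || dtype != Dtype::float32;
}

int main() {
  return should_use_nax(Dtype::float16) ? 0 : 1;
}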
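
Note on broadcast_with_indices, used by both gather_mm_rhs_nax and gather_qmm_rhs_nax: when lhs_indices are implicit, the left operand is broadcast so there is one (M x K) matrix per entry of rhs_indices, and the size check decides whether a broadcast is needed at all. Below is a minimal sketch of just that shape logic over plain std::vector shapes; it is not the MLX array API, only an illustration of the target shape the copy produces.

// Sketch only: shape arithmetic behind broadcast_with_indices().
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

using Shape = std::vector<int64_t>;

int64_t num_elements(const Shape& s) {
  return std::accumulate(s.begin(), s.end(), int64_t{1}, std::multiplies<>{});
}

// Returns the shape the left operand must take when lhs_indices are implicit:
// the index array's shape followed by the trailing matrix dimensions.
Shape implicit_lhs_target(const Shape& x, const Shape& indices) {
  const int64_t m = x[x.size() - 2];
  const int64_t k = x[x.size() - 1];
  // Already one matrix per index: no broadcast, just a contiguous copy.
  if (num_elements(x) / m / k == num_elements(indices)) {
    return x;
  }
  Shape target = indices;  // batch dims follow the index array
  target.push_back(m);
  target.push_back(k);
  return target;
}

int main() {
  // A (1, 1, 4, 8) operand shared across a (3, 2) index array is broadcast
  // to (3, 2, 4, 8) before the row-contiguous copy.
  Shape out = implicit_lhs_target({1, 1, 4, 8}, {3, 2});
  assert((out == Shape{3, 2, 4, 8}));
  return 0;
}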
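
Note on kernel specialization names: in gather_mm_rhs_nax, gather_qmm_rhs_nax and the NAX attention path, the base name fixes the template instantiation (type, bm/bn/bk, wm/wn) while the hash name additionally encodes the function-constant values set at indices 200-202, so pipelines specialized for different alignment cases get distinct cache entries. A small illustrative helper (not the MLX API) showing the suffix construction:

// Sketch only: builds the "_align_M_t/n..." suffix used as a pipeline cache key.
#include <string>

std::string specialization_hash(
    const std::string& base_name,
    bool align_M,
    bool align_N,
    bool align_K) {
  std::string hash_name;
  hash_name.reserve(128);
  hash_name += base_name;
  hash_name += "_align_M_";
  hash_name += align_M ? 't' : 'n';
  hash_name += "_align_N_";
  hash_name += align_N ? 't' : 'n';
  hash_name += "_align_K_";
  hash_name += align_K ? 't' : 'n';
  return hash_name;
}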
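
Note on launch geometry: every NAX launch in this patch pairs a (32, wn, wm) threadgroup with a grid of ceil(N / bn) x ceil(M / bm) output tiles, with the batch in the z dimension; the align_M/N/K function constants then tell the kernel whether edge tiles need partial handling. The plain C++ stand-in below shows the arithmetic, with a Size3 struct in place of MTL::Size; it is illustrative only.

// Sketch only: grid/group sizing used by the NAX dispatches.
#include <cstdio>

struct Size3 { int x, y, z; };  // stand-in for MTL::Size

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// One threadgroup per (bm x bn) output tile, batch in the z dimension.
Size3 tile_grid(int M, int N, int B, int bm, int bn) {
  return {ceil_div(N, bn), ceil_div(M, bm), B};
}

// 32 lanes per simdgroup, wn x wm simdgroups per threadgroup.
Size3 tile_group(int wm, int wn) { return {32, wn, wm}; }

int main() {
  // qmm_nax defaults from the patch: bm = bn = 64, wm = wn = 2.
  Size3 grid = tile_grid(/*M=*/200, /*N=*/512, /*B=*/1, /*bm=*/64, /*bn=*/64);
  Size3 group = tile_group(/*wm=*/2, /*wn=*/2);
  std::printf(
      "grid %d x %d x %d, group %d x %d x %d\n",
      grid.x, grid.y, grid.z, group.x, group.y, group.z);
  return 0;
}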