diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h index 31c6295a011a..48f25e21c286 100644 --- a/sgl-kernel/csrc/cpu/common.h +++ b/sgl-kernel/csrc/cpu/common.h @@ -45,7 +45,7 @@ namespace { } \ }() -// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +// dispatch: bfloat16, float16, int8_t, fp8_e4m3, uint8_t(mxfp4/int4) #define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ [&] { \ switch (TYPE) { \ @@ -65,6 +65,10 @@ namespace { using packed_t = at::Float8_e4m3fn; \ return __VA_ARGS__(); \ } \ + case at::ScalarType::Byte: { \ + using packed_t = uint8_t; \ + return __VA_ARGS__(); \ + } \ default: \ TORCH_CHECK(false, "Unsupported floating data type.\n"); \ } \ diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp index e2fdc8951f23..13d329be84f9 100644 --- a/sgl-kernel/csrc/cpu/gemm.cpp +++ b/sgl-kernel/csrc/cpu/gemm.cpp @@ -65,6 +65,43 @@ inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restr s8s8_compensation(packed, K); } +// uint8_t: mxfp4 or int4 +// pack to vnni2 format as they are computed with bfloat16 +// +// from [N, K'/2, 2] to [K'/2, N, 2], view 2x int4 as unit8: +// from [N, K ] to [K, N ] where K = K'/2 +// +template <> +inline void pack_vnni(uint8_t* __restrict__ packed, const uint8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + + uint8_t unpacked[2 * BLOCK_N]; + + // 32-way pack (align with BLOCK_N), faster for avx512 unpacking + // + // for a range of (64): + // {0, 1, 2, ..., 63} + // + // original format: + // { 1|0, 3|2, ..., 63|62} + // + // packed format: + // {32|0, 31|1, ..., 63|31} + // + for (int k = 0; k < K; ++k) { + // unpack first + for (int n = 0; n < N; ++n) { + uint8_t value = weight[n * K + k]; + unpacked[n * 2 + 0] = value & 0xF; // lower 4 bits + unpacked[n * 2 + 1] = value >> 4; // higher 4 bits + } + // re-pack to 32-way + for (int n = 0; n < N; ++n) { + packed[k * N + n] = (unpacked[n + BLOCK_N] << 4) | unpacked[n]; + } + } +} + template 
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { using bVec = at::vec::Vectorized; @@ -600,9 +637,12 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + // mxfp4 or int4 are packed with uint8 + const int64_t actual_IC = st == at::kByte ? IC * 2 : IC; + // we handle 2 TILE_N at a time. TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); - TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + TORCH_CHECK(actual_IC % TILE_K == 0, "invalid weight input features ", actual_IC); constexpr int64_t BLOCK_N = block_size_n(); const int64_t NB = div_up(OC, BLOCK_N); @@ -611,13 +651,14 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { auto packed_weight = at::empty({}, weight.options()); const int64_t stride = OC * IC; + // Note: for `kByte` (uint8), it represents either `mxfp4` or `int4`. 
TORCH_CHECK( - st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, - "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn || st == at::kByte, + "expect weight to be bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4)."); CPU_DISPATCH_PACKED_TYPES(st, [&] { // adjust most inner dimension size - const int packed_row_size = get_row_size(IC); + const int packed_row_size = get_row_size(actual_IC); auto sizes = weight.sizes().vec(); sizes[ndim - 1] = packed_row_size; packed_weight.resize_(sizes); @@ -646,6 +687,41 @@ at::Tensor convert_weight_packed(at::Tensor& weight) { return packed_weight; } +at::Tensor convert_scale_packed(at::Tensor& scale) { + CHECK_INPUT(scale); + + const int64_t ndim = scale.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect scale to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = scale.scalar_type(); + const int64_t E = ndim == 3 ? scale.size(0) : 1; + const int64_t N = ndim == 3 ? scale.size(1) : scale.size(0); + // number of groups, e.g. K/32 + const int64_t G = ndim == 3 ? 
scale.size(2) : scale.size(1); + + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(N % BLOCK_N == 0, "invalid weight out features ", N); + const int64_t NB = N / BLOCK_N; + + auto packed_scale = at::empty_like(scale); + TORCH_CHECK(st == at::kByte, "expect scale to be uint8."); + + const uint8_t* s_data = scale.data_ptr(); + uint8_t* packed_data = packed_scale.data_ptr(); + + // parallel on src {E, NB, BLOCK_N, G}, dst {E, NB, G, BLOCK_N} + at::parallel_for(0, E * NB * BLOCK_N * G, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}, n{0}, g{0}; + data_index_init(begin, e, E, nb, NB, n, BLOCK_N, g, G); + + for (int64_t i = begin; i < end; ++i) { + packed_data[e * N * G + nb * G * BLOCK_N + g * BLOCK_N + n] = s_data[i]; + // move to the next index + data_index_step(e, E, nb, NB, n, BLOCK_N, g, G); + } + }); + return packed_scale; +} + // mat1 : [M, K] // mat2 : [N, K] ([K, N] if use_fma_gemm) // bias : [N] diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index e11b224fe193..fc2b199bfea6 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -33,6 +33,11 @@ inline bool can_use_brgemm(int M) { return M > 4; } +template <> +inline bool can_use_brgemm(int M) { + return M > 4; +} + template <> inline bool can_use_brgemm(int M) { return M > 4; @@ -52,6 +57,12 @@ inline int64_t get_row_size(int64_t K) { return K + sizeof(int32_t); } +// uint8: mxfp4 or int4 +template <> +inline int64_t get_row_size(int64_t K) { + return K >> 1; +} + inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { return use_int8_w8a8 ? 
K + sizeof(int32_t) : K; } @@ -287,3 +298,22 @@ void tinygemm_kernel( int64_t ldc_s, bool store_out, bool use_brgemm); + +// mxfp4 +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const uint8_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const uint8_t* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K, + bool do_unpack = true); diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 15bd44434e0e..245fd3a075d7 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -65,15 +65,15 @@ inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ inline void unpack_B( at::BFloat16* __restrict__ Btmp, const at::Float8_e4m3fn* __restrict__ packed_B, - int N, - int K, - int ldb, - int ldb_tmp, + int64_t N, + int64_t K, + int64_t ldb, + int64_t ldb_tmp, float scale) { #if defined(CPU_CAPABILITY_AVX512) // [K/2, N, 2] - const int K2 = K >> 1; - const int ldb2 = ldb; // ldb * 2 >> 1; + const int64_t K2 = K >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; const uint16_t* b_ptr = reinterpret_cast(packed_B); const __m512 vexp = _mm512_castsi512_ps(_mm512_set1_epi32(kFP8_BIAS)); const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(scale), vexp); @@ -85,7 +85,7 @@ inline void unpack_B( constexpr int PREFETCH_SIZE_K = 64; #pragma GCC unroll 4 - for (int k = 0; k < K2; ++k) { + for (int64_t k = 0; k < K2; ++k) { __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); if constexpr (PREFETCH_SIZE_K > 0) { _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); @@ -154,18 +154,67 @@ inline void unpack_B( #endif } -template +// mxfp4 +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const uint8_t* __restrict__ packed_B, + int64_t N, + int64_t K, + int64_t ldb, + int64_t ldb_tmp, + const uint8_t* __restrict__ scale) { +#if 
defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int64_t K2 = K >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const uint8_t* b_ptr = reinterpret_cast(packed_B); // 2 * 4 bit = 8 bit + + constexpr int BLOCK_N = block_size_n(); + static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + + // exponent bias 127 + const __m512i off = _mm512_set1_epi16(0x7F); + + // load 32 bytes only once for each block + __m256i s8 = _mm256_loadu_si256(reinterpret_cast(scale)); + __m512i s16 = _mm512_slli_epi16(_mm512_sub_epi16(_mm512_cvtepu8_epi16(s8), off), 0x7); + + // holds Nx2(64) scales, interleaved as 2 belongs to K dimension + // e.g. vs0: { s0, s0, s1, s1, ..., s15, s15} + // vs1: {s16, s16, s17, s17, ..., s31, s31} + auto [vscale0, vscale1] = transpose_2x32_16bit(s16, s16); + +#pragma GCC unroll 4 + for (int64_t k = 0; k < K2; ++k) { + __m256i b4 = _mm256_loadu_si256(reinterpret_cast(b_ptr + k * ldb2)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + auto [vb0, vb1] = CVT_MXFP4_TO_BF16(b4, vscale0, vscale1); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)vb0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)vb1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template struct tinygemm_kernel_nn { static inline void apply( const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, const float* __restrict__ bias, - const float* __restrict__ scale, - int K, - int lda, - int ldb, - int ldc, + const param_t* __restrict__ scale, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, int64_t block_size_K) { TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); } @@ -187,22 +236,22 @@ struct tinygemm_kernel_nn2 { }; #if defined(CPU_CAPABILITY_AVX512) template -struct tinygemm_kernel_nn { +struct tinygemm_kernel_nn { static inline void apply( const 
at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, const float* __restrict__ bias, const float* __restrict__ scale, - int K, - int lda, - int ldb, - int ldc, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, int64_t block_size_K) { constexpr int ROWS = BLOCK_M; constexpr int COLS = BLOCK_N / 16; - const int KB = div_up(K, BLOCK_K); + const int64_t KB = div_up(K, (int64_t)BLOCK_K); // prefetch distance constexpr int PREFETCH_SIZE_K = 64; @@ -228,8 +277,8 @@ struct tinygemm_kernel_nn{}(loadc); - const int lda2 = lda >> 1; - const int ldb2 = ldb; // ldb * 2 >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; const float* a_ptr = reinterpret_cast(A); const uint16_t* b_ptr = reinterpret_cast(B); @@ -256,10 +305,10 @@ struct tinygemm_kernel_nn> 1; - for (int kb = 0; kb < KB; ++kb) { - int kb_start = kb * BLOCK_K2; - int kb_end = std::min(K >> 1, kb_start + BLOCK_K2); + constexpr int64_t BLOCK_K2 = BLOCK_K >> 1; + for (int64_t kb = 0; kb < KB; ++kb) { + int64_t kb_start = kb * BLOCK_K2; + int64_t kb_end = std::min(K >> 1, kb_start + BLOCK_K2); // 1. 
load scale vector vscale = _mm512_set1_ps(scale[kb]); vscale = _mm512_mul_ps(vscale, vexp); @@ -359,10 +408,110 @@ struct tinygemm_kernel_nn2 { Unroll{}(storec); } }; + +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, + const uint8_t* __restrict__ B, + at::BFloat16* __restrict__ C, + const float* __restrict__ bias, + const uint8_t* __restrict__ scale, + int K, + int lda, + int ldb, + int ldc, + int64_t block_size_K) { + // mxfp4 supports only group size of 32 + // expect weight packed in 32-way, vnni2 format Nx2(64) + assert(block_size_K == 32); + assert(BLOCK_N == 32); + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + // holds Nx2(64) scales, interleaved as 2 belongs to K dimension + // e.g. vs0: { s0, s0, s1, s1, ..., s15, s15} + // vs1: {s16, s16, s17, s17, ..., s31, s31} + __m512i vscale[COLS]; + + // exponent bias 127 + const __m512i off = _mm512_set1_epi16(0x7F); + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint8_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + // load 32 * 2 (64) int4 at a time + if constexpr (col % 2 == 0) { + __m256i b4 = _mm256_loadu_si256(reinterpret_cast(b_ptr + k 
* ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + std::tie(vb[col + 0], vb[col + 1]) = CVT_MXFP4_TO_BF16(b4, vscale[col + 0], vscale[col + 1]); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + + for (int64_t k = 0; k < K2; ++k) { + // update scales every 16x2 K + if ((k & 15) == 0) { + __m256i s8 = _mm256_loadu_si256(reinterpret_cast(scale + (k >> 4) * 32)); + __m512i s16 = _mm512_slli_epi16(_mm512_sub_epi16(_mm512_cvtepu8_epi16(s8), off), 0x7); + std::tie(vscale[0], vscale[1]) = transpose_2x32_16bit(s16, s16); + } + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; #endif #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ - tinygemm_kernel_nn::apply( \ + tinygemm_kernel_nn::apply( \ A + mb_start * lda, \ B + nb_start * 2, \ C + mb_start * ldc + nb_start, \ @@ -378,7 +527,7 @@ struct tinygemm_kernel_nn2 { tinygemm_kernel_nn2::apply( \ A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, scale, K, lda, ldb, ldc); -template +template struct brgemm { static inline void apply( const scalar_t* __restrict__ A, @@ -387,7 +536,7 @@ struct brgemm { scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, const float* __restrict__ bias, - const float* __restrict__ scale, + const param_t* __restrict__ scale, int M, int N, int K, @@ -402,7 +551,7 @@ template struct brgemm2 {}; template -struct brgemm { +struct brgemm { static inline void apply( const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, @@ -481,14 +630,56 @@ struct brgemm2 { } }; -template +template +struct brgemm { + static inline 
void apply( + const at::BFloat16* __restrict__ A, + const uint8_t* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const uint8_t* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc, + bool do_unpack = true) { + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + if (do_unpack) { + // group size 32 for mxfp4 + for (int k = 0; k < K; k += 32) { + unpack_B(Btmp + k * ldb_tmp, B + k * (ldb >> 1), N, 32, ldb, ldb_tmp, scale + (k >> 5) * BLOCK_N); + } + } + + at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template void tinygemm_kernel( const scalar_t* __restrict__ A, - const at::Float8_e4m3fn* __restrict__ B, + const packed_t* __restrict__ B, scalar_t* __restrict__ C, scalar_t* __restrict__ Btmp, float* __restrict__ Ctmp, - const float* __restrict__ scale, + const param_t* __restrict__ scale, const float* __restrict__ bias, int64_t M, int64_t N, @@ -500,7 +691,7 @@ void tinygemm_kernel( int64_t block_size_K, bool do_unpack = true) { if (brg) { - brgemm::apply( + brgemm::apply( A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack); return; } @@ -536,6 +727,7 @@ void tinygemm_kernel( } } } + template void tinygemm_kernel2( const scalar_t* __restrict__ A, @@ -633,12 +825,19 @@ void tinygemm_kernel2( } } } -template -void fp8_scaled_mm_kernel_impl( + +// NB: fp8/fp4 scaled mm kernel implementation +// +// scalar_t packed_t param_t +// FP8 BF16 FP8 FP32 +// MXFP4 BF16 U8 U8 +// +template +void fp_scaled_mm_kernel_impl( scalar_t* __restrict__ out, const scalar_t* __restrict__ mat1, - const 
at::Float8_e4m3fn* __restrict__ mat2, - const float* __restrict__ scales2, + const packed_t* __restrict__ mat2, + const param_t* __restrict__ scales2, const float* __restrict__ bias, scalar_t* __restrict__ buffer, int64_t M, @@ -648,16 +847,17 @@ void fp8_scaled_mm_kernel_impl( int64_t out_strideM, int64_t block_size_N, int64_t block_size_K, - int64_t buffer_size_per_thread) { + int64_t buffer_size_per_thread, + const func_t& scale_offset_per_block) { constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); - const int64_t scale_size_K = div_up(K, block_size_K); - const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + const bool use_brgemm = can_use_brgemm(M); - const bool use_brgemm = can_use_brgemm(M); + // use K/2 for mxfp4 and K for fp8 + const int64_t packed_K = get_row_size(K); // parallel on [MB, NB] AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { @@ -666,8 +866,8 @@ void fp8_scaled_mm_kernel_impl( scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K)); - loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { - const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + loop_2d(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) { + const param_t* scale_ptr = scales2 + scale_offset_per_block(nb); int64_t mb_start = mb * BLOCK_M; int64_t mb_size = std::min(M - mb_start, BLOCK_M); @@ -677,9 +877,9 @@ void fp8_scaled_mm_kernel_impl( // only do unpacking for the first row bool do_unpack = (mb == mb0); - tinygemm_kernel( + tinygemm_kernel( /* A */ mat1 + mb_start * mat1_strideM, - /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* B */ mat2 + nb_start * packed_K, // nb * BLOCK_N * K /* C */ out + mb_start * out_strideM + nb_start, /* Btmp */ Btmp + nb_offset * 
BLOCK_N * K, /* Ctmp */ Ctmp, @@ -723,9 +923,10 @@ void tinygemm_kernel( bool brg, int64_t block_size_K, bool do_unpack) { - tinygemm_kernel( + tinygemm_kernel( A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack); } + template void tinygemm_kernel( const scalar_t* __restrict__ A, @@ -743,24 +944,51 @@ void tinygemm_kernel( bool brg) { tinygemm_kernel2(A, B, C, Btmp, Ctmp, scale, M, N, K, lda, ldb, ldc, brg); } -#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ - template void tinygemm_kernel( \ - const TYPE* __restrict__ A, \ - const at::Float8_e4m3fn* __restrict__ B, \ - TYPE* __restrict__ C, \ - TYPE* __restrict__ Btmp, \ - float* __restrict__ Ctmp, \ - const float* __restrict__ scale, \ - int64_t M, \ - int64_t N, \ - int64_t K, \ - int64_t lda, \ - int64_t ldb, \ - int64_t ldc, \ - bool brg, \ - int64_t block_size_K, \ + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const uint8_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const uint8_t* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K, + bool do_unpack) { + tinygemm_kernel( + A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE_A, TYPE_B, TYPE_S) \ + template void tinygemm_kernel( \ + const TYPE_A* __restrict__ A, \ + const TYPE_B* __restrict__ B, \ + TYPE_A* __restrict__ C, \ + TYPE_A* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const TYPE_S* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K, \ bool do_unpack) +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16, at::Float8_e4m3fn, float); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half, at::Float8_e4m3fn, float); +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16, uint8_t, 
uint8_t); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half, uint8_t, uint8_t); + #define INSTANTIATE_TINYGEMM_TEMPLATE2(TYPE) \ template void tinygemm_kernel( \ const TYPE* __restrict__ A, \ @@ -777,10 +1005,28 @@ void tinygemm_kernel( int64_t ldc, \ bool brg) -INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); -INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); INSTANTIATE_TINYGEMM_TEMPLATE2(at::BFloat16); +inline const float* get_bias_data(const std::optional& bias, int64_t N) { + if (bias.has_value()) { + const auto& bias_ref = bias.value(); + CHECK_EQ(bias_ref.size(0), N); + return bias_ref.data_ptr(); + } + return nullptr; +} + +// FP8 and MXFP4 WoQ uses the same pattern: +// Btmp : [T, BLOCK_N * K] +// Ctmp : [T, BLOCK_M * BLOCK_N] +inline at::Tensor alloc_thread_buffer(const at::TensorOptions& options, int64_t K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + int num_threads = at::get_num_threads(); + int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + return at::empty({num_threads, size_per_thread}, options); +} + at::Tensor fp8_scaled_mm_cpu( at::Tensor& mat1, at::Tensor& mat2, @@ -807,11 +1053,9 @@ at::Tensor fp8_scaled_mm_cpu( CHECK_DIM(2, mat2); TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); - int64_t block_size_N = block_size[0]; int64_t block_size_K = block_size[1]; - constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); @@ -825,39 +1069,90 @@ at::Tensor fp8_scaled_mm_cpu( TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales to be float32."); auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); - // strides - int64_t mat1_strideM = mat1.stride(0); - int64_t out_strideM 
= out.stride(0); - - const bool has_bias = bias.has_value(); - const float* bias_data = nullptr; - if (has_bias) { - CHECK_EQ(bias.value().size(0), N); - bias_data = bias.value().data_ptr(); - } - - // Btmp : [T, BLOCK_N * K] - // Ctmp : [T, BLOCK_M * BLOCK_N] - int num_threads = at::get_num_threads(); - int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2; - auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + auto buffer = alloc_thread_buffer(mat1.options(), K); AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { - fp8_scaled_mm_kernel_impl( + // used for lambda computing scale offset for each block + // fp8 block gemm scale shape: [N/128, K/128] + // for each block: [1, K/128] + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + fp_scaled_mm_kernel_impl( out.data_ptr(), mat1.data_ptr(), packed_w.data_ptr(), scales2.data_ptr(), - bias_data, + get_bias_data(bias, N), buffer.data_ptr(), M, N, K, - mat1_strideM, - out_strideM, + mat1.stride(0), + out.stride(0), block_size_N, block_size_K, - size_per_thread); + buffer.size(-1), + [&](int64_t nb) { return (nb / blocks_n_per_group) * scale_size_K; }); + }); + + return out; +} + +// mat1 : [M, K] bfloat16 +// mat2 : [N, K / 2] uint8, actual layout: [N / BLOCK_N, K / 2, BLOCK_N, 2] +// scales2: [N, K / G], actual layout: [N / BLOCK_N, K / G, BLOCK_N] +at::Tensor mxfp4_scaled_mm_cpu( + at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::mxfp4_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ?
mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1) * 2; + + // mxfp4 supports only group size of 32 (2^5) + constexpr int64_t group_size = 32; + constexpr int64_t BLOCK_N = block_size_n(); + + CHECK_EQ(mat1.size(1), K); + CHECK_EQ(scales2.numel(), N * K >> 5); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "mxfp4_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(mat2.scalar_type() == at::kByte, "mxfp4_scaled_mm_cpu: expect mat2 to be uint8."); + TORCH_CHECK(scales2.scalar_type() == at::kByte, "mxfp4_scaled_mm_cpu: expect scales to be uint8."); + auto out = at::empty({M, N}, mat1.options()); + + auto buffer = alloc_thread_buffer(mat1.options(), K); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "mxfp4_scaled_mm_kernel_impl", [&] { + // used for lambda computing scale offset for each block + // mxfp4 block gemm scale shape: [N/BLOCK_N, K/32, BLOCK_N] + // for each block: [K/32, BLOCK_N] + const int64_t s_strideN = (K >> 5) * BLOCK_N; + + fp_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + get_bias_data(bias, N), + buffer.data_ptr(), + M, + N, + K, + mat1.stride(0), + out.stride(0), + /* block_size_N */ 1, + /* block_size_K */ group_size, + buffer.size(-1), + [&](int64_t nb) { return nb * s_strideN; }); }); return out; diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 88f4228a50fc..abd6abde7717 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -143,6 +143,9 @@ std::tuple chunk_gated_delta_rule_cpu( // weight prepack at::Tensor convert_weight_packed(at::Tensor& weight); +// scale prepack for mxfp4 +at::Tensor convert_scale_packed(at::Tensor& scale); + // quant std::tuple
per_token_quant_int8_cpu(at::Tensor& A); @@ -178,6 +181,10 @@ at::Tensor fp8_scaled_mm_cpu( at::ScalarType out_dtype, bool is_vnni); +// mxfp4 gemm +at::Tensor mxfp4_scaled_mm_cpu( + at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, const std::optional& bias, bool is_vnni); + // quant + igemm at::Tensor int8_scaled_mm_with_quant( at::Tensor& mat1, @@ -463,6 +470,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("convert_weight_packed(Tensor weight) -> Tensor"); m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + // scale prepack for mxfp4 + m.def("convert_scale_packed(Tensor scale) -> Tensor"); + m.impl("convert_scale_packed", torch::kCPU, &convert_scale_packed); + // quant m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)"); m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu); @@ -488,6 +499,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "out_dtype, bool is_vnni) -> Tensor"); m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu); + // mxfp4 gemm + m.def("mxfp4_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, bool is_vnni) -> Tensor"); + m.impl("mxfp4_scaled_mm_cpu", torch::kCPU, &mxfp4_scaled_mm_cpu); + // quant + igemm m.def( "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? 
bias, ScalarType out_dtype, bool " diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h index 107022ffd237..a37bc6ba2467 100644 --- a/sgl-kernel/csrc/cpu/vec.h +++ b/sgl-kernel/csrc/cpu/vec.h @@ -145,6 +145,49 @@ inline __attribute__((always_inline)) __m512bh CVT_FP8_TO_BF16_EXT(__m256i a) { // bias for conversion of fp8 to bf16 1/256 in float32 #define kFP8_BIAS 0x3b800000 +// remove warning: ignoring attributes on template argument ‘__m512bh’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +#define MXFP4_VALUES \ + -6.0f, -4.0f, -3.0f, -2.0f, -1.5f, -1.0f, -0.5f, -0.0f, 6.0f, 4.0f, 3.0f, 2.0f, 1.5f, 1.0f, 0.5f, 0.0f + +// convert 64 mxfp4 to 2x bf16 vectors, expect input 32-way packing +inline std::tuple<__m512bh, __m512bh> cvt_mxfp4_e2m1_bf16_intrinsic_lut(__m256i a, __m512i s0, __m512i s1) { + // LUT + const __m512 values = _mm512_set_ps(MXFP4_VALUES); + const __m512i lut = (__m512i)(_mm512_cvtne2ps_pbh(values, values)); + + const __m512i abs_mask = _mm512_set1_epi16(0x7FFF); + const __m512i zero = _mm512_setzero_si512(); + + // expand values to 16-bit integers + __m512i x0 = _mm512_cvtepu8_epi16(a); + __m512i x1 = _mm512_srli_epi32(x0, 4); + + // LUT to convert mxfp4 values to bf16 + x0 = _mm512_permutexvar_epi16(x0, lut); + x1 = _mm512_permutexvar_epi16(x1, lut); + + // check for zeros + __mmask32 mask0 = _mm512_cmp_epi16_mask(_mm512_and_si512(x0, abs_mask), zero, _MM_CMPINT_EQ); + __mmask32 mask1 = _mm512_cmp_epi16_mask(_mm512_and_si512(x1, abs_mask), zero, _MM_CMPINT_EQ); + + // emulate bf16 mul with scale factor + x0 = _mm512_add_epi16(x0, s0); + x1 = _mm512_add_epi16(x1, s1); + + // blend with zero + x0 = _mm512_mask_blend_epi16(mask0, x0, zero); + x1 = _mm512_mask_blend_epi16(mask1, x1, zero); + + return std::make_tuple(__m512bh(x0), __m512bh(x1)); +} + +#define CVT_MXFP4_TO_BF16(a, s0, s1) cvt_mxfp4_e2m1_bf16_intrinsic_lut(a, s0, s1) + +#pragma GCC diagnostic pop + #endif // 
vector to scalar reduction