6 changes: 5 additions & 1 deletion sgl-kernel/csrc/cpu/common.h
@@ -45,7 +45,7 @@ namespace {
} \
}()

// dispatch: bfloat16, float16, int8_t, fp8_e4m3
// dispatch: bfloat16, float16, int8_t, fp8_e4m3, uint8_t(mxfp4/int4)
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
@@ -65,6 +65,10 @@ namespace {
using packed_t = at::Float8_e4m3fn; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Byte: { \
using packed_t = uint8_t; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
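// A minimal, self-contained sketch of the new at::kByte dispatch path (illustration
// only; requires <cstdio>; print_packed_elem_size is not part of this change):
// packed_t binds to uint8_t, i.e. one byte carries two mxfp4 / int4 values.
inline void print_packed_elem_size(at::ScalarType st) {
  CPU_DISPATCH_PACKED_TYPES(st, [&] { std::printf("sizeof(packed_t) = %zu\n", sizeof(packed_t)); });
}
// print_packed_elem_size(at::kByte) prints "sizeof(packed_t) = 1"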
84 changes: 80 additions & 4 deletions sgl-kernel/csrc/cpu/gemm.cpp
@@ -65,6 +65,43 @@ inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restr
s8s8_compensation<BLOCK_N>(packed, K);
}

// uint8_t: mxfp4 or int4
// pack to vnni2 format as they are computed with bfloat16
//
// from [N, K'/2, 2] to [K'/2, N, 2], view 2x int4 as uint8:
// from [N, K ] to [K, N ] where K = K'/2
//
template <>
inline void pack_vnni<uint8_t>(uint8_t* __restrict__ packed, const uint8_t* __restrict__ weight, int N, int K) {
constexpr int BLOCK_N = block_size_n();

uint8_t unpacked[2 * BLOCK_N];

// 32-way pack (align with BLOCK_N), faster for avx512 unpacking
//
// for a range of (64):
// {0, 1, 2, ..., 63}
//
// original format:
// { 1|0, 3|2, ..., 63|62}
//
// packed format:
//   {32|0, 33|1, ..., 63|31}
//
for (int k = 0; k < K; ++k) {
// unpack first
for (int n = 0; n < N; ++n) {
uint8_t value = weight[n * K + k];
unpacked[n * 2 + 0] = value & 0xF; // lower 4 bits
unpacked[n * 2 + 1] = value >> 4; // higher 4 bits
}
// re-pack to 32-way
for (int n = 0; n < N; ++n) {
packed[k * N + n] = (unpacked[n + BLOCK_N] << 4) | unpacked[n];
}
}
}
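// A scalar round-trip sketch (illustration only, not used by the kernel), assuming
// BLOCK_N == 32 as in the comment above: packed byte n of a row holds value n of a
// 64-value group in its low nibble and value n + 32 in its high nibble.
inline void unpack_row_reference(uint8_t* __restrict__ out /* 2 * BLOCK_N values */,
                                 const uint8_t* __restrict__ packed_row /* BLOCK_N bytes */) {
  constexpr int BLOCK_N = 32;
  for (int n = 0; n < BLOCK_N; ++n) {
    out[n] = packed_row[n] & 0xF;           // values 0..31 of the group
    out[n + BLOCK_N] = packed_row[n] >> 4;  // values 32..63 of the group
  }
}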

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
@@ -600,9 +637,12 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);

// mxfp4 or int4 are packed with uint8
const int64_t actual_IC = st == at::kByte ? IC * 2 : IC;

// we handle 2 TILE_N at a time.
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
TORCH_CHECK(actual_IC % TILE_K == 0, "invalid weight input features ", actual_IC);

constexpr int64_t BLOCK_N = block_size_n();
const int64_t NB = div_up(OC, BLOCK_N);
@@ -611,13 +651,14 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
auto packed_weight = at::empty({}, weight.options());
const int64_t stride = OC * IC;

// Note: for `kByte` (uint8), it represents either `mxfp4` or `int4`.
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn || st == at::kByte,
"expect weight to be bfloat16, float16, int8, fp8_e4m3 or uint8(mxfp4 or int4).");

CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
const int packed_row_size = get_row_size<packed_t>(IC);
const int packed_row_size = get_row_size<packed_t>(actual_IC);
auto sizes = weight.sizes().vec();
sizes[ndim - 1] = packed_row_size;
packed_weight.resize_(sizes);
@@ -646,6 +687,41 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
return packed_weight;
}
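// Hypothetical calling sketch for the uint8 path (shapes for illustration only;
// OC and the logical K must satisfy the TILE_N / TILE_K checks above): an
// int4 / mxfp4 weight of logical shape [OC, K] arrives as uint8 [OC, K / 2],
// two nibbles per byte, and actual_IC recovers K from the byte dimension.
inline at::Tensor pack_uint8_weight_example() {
  at::Tensor w = at::randint(0, 256, {4096, 4096 / 2}, at::kByte);  // [OC, K / 2] nibble pairs
  return convert_weight_packed(w);  // same byte count, reordered to the VNNI2 nibble layout
}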

at::Tensor convert_scale_packed(at::Tensor& scale) {
CHECK_INPUT(scale);

const int64_t ndim = scale.ndimension();
TORCH_CHECK(ndim == 2 || ndim == 3, "expect scale to be 2d or 3d, got ", ndim, "d tensor.");
const auto st = scale.scalar_type();
const int64_t E = ndim == 3 ? scale.size(0) : 1;
const int64_t N = ndim == 3 ? scale.size(1) : scale.size(0);
// number of groups, e.g. K/32
const int64_t G = ndim == 3 ? scale.size(2) : scale.size(1);

constexpr int64_t BLOCK_N = block_size_n();
  TORCH_CHECK(N % BLOCK_N == 0, "invalid scale out features ", N);
const int64_t NB = N / BLOCK_N;

auto packed_scale = at::empty_like(scale);
TORCH_CHECK(st == at::kByte, "expect scale to be uint8.");

const uint8_t* s_data = scale.data_ptr<uint8_t>();
uint8_t* packed_data = packed_scale.data_ptr<uint8_t>();

// parallel on src {E, NB, BLOCK_N, G}, dst {E, NB, G, BLOCK_N}
at::parallel_for(0, E * NB * BLOCK_N * G, 0, [&](int64_t begin, int64_t end) {
int64_t e{0}, nb{0}, n{0}, g{0};
data_index_init(begin, e, E, nb, NB, n, BLOCK_N, g, G);

for (int64_t i = begin; i < end; ++i) {
packed_data[e * N * G + nb * G * BLOCK_N + g * BLOCK_N + n] = s_data[i];
// move to the next index
data_index_step(e, E, nb, NB, n, BLOCK_N, g, G);
}
});
return packed_scale;
}
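// For the 3-d case this loop is just a transpose of the last two block dimensions;
// a non-parallel reference sketch (illustration only, assumes a contiguous scale
// tensor; packed_scale_reference is not part of this change):
inline at::Tensor packed_scale_reference(const at::Tensor& scale, int64_t BLOCK_N) {
  const int64_t E = scale.size(0), N = scale.size(1), G = scale.size(2);
  // [E, N, G] -> view [E, NB, BLOCK_N, G] -> swap to [E, NB, G, BLOCK_N] -> flatten back
  return scale.view({E, N / BLOCK_N, BLOCK_N, G}).transpose(2, 3).contiguous().view({E, N, G});
}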

// mat1 : [M, K]
// mat2 : [N, K] ([K, N] if use_fma_gemm)
// bias : [N]
30 changes: 30 additions & 0 deletions sgl-kernel/csrc/cpu/gemm.h
@@ -33,6 +33,11 @@ inline bool can_use_brgemm<int8_t>(int M) {
return M > 4;
}

template <>
inline bool can_use_brgemm<uint8_t>(int M) {
return M > 4;
}

template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
return M > 4;
@@ -52,6 +57,12 @@ inline int64_t get_row_size<int8_t>(int64_t K) {
return K + sizeof(int32_t);
}

// uint8: mxfp4 or int4; K is the logical number of int4 values, two packed per byte
template <>
inline int64_t get_row_size<uint8_t>(int64_t K) {
return K >> 1;
}

inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
@@ -287,3 +298,22 @@ void tinygemm_kernel(
int64_t ldc_s,
bool store_out,
bool use_brgemm);

// mxfp4
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const uint8_t* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const uint8_t* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K,
bool do_unpack = true);
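
// Context for the new arguments (illustration only; the kernel's actual decode may
// differ): under the usual OCP MX conventions, B holds two FP4 (E2M1) values per
// byte and scale one shared E8M0 byte per block_size_K group. Requires <cmath>
// and <cstdint>; these helpers are not part of this header.
inline float decode_fp4_e2m1(uint8_t nibble) {
  // FP4 E2M1 value table: {0, 0.5, 1, 1.5, 2, 3, 4, 6} with a sign bit.
  static constexpr float kLut[16] = {
      0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
      -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
  return kLut[nibble & 0xF];
}
inline float decode_e8m0_scale(uint8_t s) {
  return std::ldexp(1.0f, static_cast<int>(s) - 127);  // 2^(s - 127); 0xFF is NaN in the MX spec
}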