168 changes: 168 additions & 0 deletions csrc/xpu/onednn/fp4_gemm_w4a4.h
@@ -0,0 +1,168 @@
#pragma once

#include <c10/xpu/XPUStream.h>
#include <dnnl.hpp>
#include <torch/torch.h>

#include "onednn_ext.h"
#include "onednn_runtime.h"

namespace oneDNN {

static inline void dnnl_matmul_w4a4_fp4(
torch::Tensor& result, // dst, [b, m, n]
const torch::Tensor& mat1, // src, [b, m, k]
    const torch::Tensor& mat2, // quantized weight, [k, n]; stored transposed when is_nt is true
    bool is_nt,
    const std::optional<torch::Tensor>& bias,
    const torch::Tensor& m1_sc, // activation (src) scales
    const torch::Tensor& m2_sc) { // weight scales
auto src_sz = mat1.sizes();
auto o_sz = result.sizes();

const int m = std::reduce(
src_sz.begin(), src_sz.end() - 1, 1, std::multiplies<int64_t>());
const int n = o_sz.back(); // presume channel last format
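  // Each Float4_e2m1fn_x2 element packs two fp4 values, so the logical K is twice the stored innermost size.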
const int k = (*(src_sz.end() - 1)) * 2;

// get joint dtypes
joint_dtypes_t jd;
auto out_dtype = result.scalar_type();
auto m1_sc_dtype = m1_sc.scalar_type();
auto m2_sc_dtype = m2_sc.scalar_type();
if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) {
TORCH_CHECK(
m2_sc_dtype == at::ScalarType::Float8_e8m0fnu,
"Mismatched scale data types in mxfp4 matmul: ",
m1_sc_dtype,
" vs ",
m2_sc_dtype);
jd = out_dtype == at::ScalarType::BFloat16 ? joint_dtypes_t::mxfp4_bf16
: joint_dtypes_t::mxfp4_f16;
} else {
TORCH_INTERNAL_ASSERT(
false, "Unsupported scale type for fp4 matmul: ", m1_sc_dtype);
}

// get bias type
bias_type_t b_type = get_bias_type(bias, m, n);

trans_type_t tt = trans_type_t::nn;
if (is_nt) {
// transpose mat2
tt = trans_type_t::nt;
}

// get lda ldb and ldc
auto mat1_strides = mat1.strides();
int64_t leading_dim = -1;
if (mat1.dim() == 2) {
leading_dim = 0;
} else if (mat1.dim() == 3) {
leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1;
} else {
TORCH_CHECK(
false, "Unsupported input dimension for fp4 matmul: ", mat1.dim());
}
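  // mat1 strides count packed x2 elements; the factor of 2 converts the leading dimension to
  // logical fp4 element counts (an assumption, mirroring the k computation above).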
int64_t lda = 2 * mat1_strides[leading_dim];
int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1
? mat2.strides()[mat2.dim() - 2]
: 2 * (mat2.strides()[mat2.dim() - 1]);
int64_t ldc = result.strides()[leading_dim];

auto f_attr = [&](dnnl::primitive_attr& pattr) {
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) {
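      // MX path: e8m0 scales are applied per 32-element group along K for both activations and weights.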
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ (1 << 0) + (1 << 1),
{1, 32},
get_onednn_dtype(m1_sc));
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ (1 << 0) + (1 << 1),
{32, 1},
get_onednn_dtype(m2_sc));
} else {
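      // Non-MX scales: per-tensor when a single value is given, otherwise per-token for activations and per-channel for weights.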
if (m1_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ 0,
{},
get_onednn_dtype(m1_sc));
/* per tensor quant */
} else {
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ (1 << 0) + (1 << 1),
{1, k},
get_onednn_dtype(m1_sc));
/* per token quant */
}

if (m2_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ 0,
{},
get_onednn_dtype(m2_sc));
/* per tensor quant */
} else {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ (1 << 1),
{},
get_onednn_dtype(m2_sc));
/* per channel quant */
}
}
};

int arg_off = 0;

// ************************************************************
// get device, engine, stream
const int dev_id = c10::xpu::getCurrentXPUStream().device_index();
at::Device curDevice = at::Device(at::kXPU, dev_id);
auto engine = GpuEngineManager::Instance().get_engine(curDevice);

int m1_sc_group_size = m1_sc.numel();
int m2_sc_group_size = m2_sc.numel();
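  // Pack both scale-tensor sizes into one cache key so primitives built for different scale
  // shapes are cached separately (assumes each value fits in 8 bits).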
int sc_group_size = (m1_sc_group_size << 8) | m2_sc_group_size;
auto& matmul_ext = matmul_primitive_create_and_cache(
jd, tt, b_type, m, n, k, lda, ldb, ldc, dev_id, f_attr, sc_group_size);

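  // Attach the weight and activation scale tensors as scale arguments of the cached primitive.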
matmul_ext.set_attribute(
arg_off++,
DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS,
m2_sc.data_ptr(),
[&]() {
return make_onednn_memory(
get_onednn_md(m2_sc), engine, m2_sc.data_ptr());
});
matmul_ext.set_attribute(
arg_off++, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc.data_ptr(), [&]() {
return make_onednn_memory(
get_onednn_md(m1_sc), engine, m1_sc.data_ptr());
});

std::vector<std::pair<int, void*>> arg_handles;
arg_handles.reserve(8);

arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr());
arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr());
arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr());
if (get_shape(b_type) != bias_shape_t::none) {
arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr());
}
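  // scratchpad_mode::user was set in f_attr, so the scratchpad buffer is allocated here and passed in explicitly.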
int scratchpad_size = matmul_ext.get_scratchpad_size();
torch::Tensor scratchpad_tensor = at::empty(
{scratchpad_size}, mat1.options().dtype(at::kByte), c10::nullopt);
arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr());

auto& strm = GpuStreamManager::Instance().get_stream();
auto qfp4_matmul_event =
matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off);
}
} // namespace oneDNN
68 changes: 45 additions & 23 deletions csrc/xpu/onednn/fp8_gemm_w8a8.h
@@ -39,7 +39,7 @@ static inline void dnnl_matmul_w8a8_fp8(
: joint_dtypes_t::f8_e4m3_f16;
} else {
TORCH_INTERNAL_ASSERT(
false, "Unsupported data type for fp8 matmul: ", mat1.scalar_type());
false, "Unsupported data type for fp8 matmul: ", in_dtype);
}

// get bias type
@@ -68,38 +68,60 @@
: mat2.strides()[mat2.dim() - 1];
int64_t ldc = result.strides()[leading_dim];

auto m1_sc_dtype = m1_sc.scalar_type();
auto m2_sc_dtype = m2_sc.scalar_type();
auto f_attr = [&](dnnl::primitive_attr& pattr) {
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (m1_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ 0,
{},
get_onednn_dtype(m1_sc));
/* per tensor quant */
} else {

if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) {
TORCH_CHECK(
m2_sc_dtype == at::ScalarType::Float8_e8m0fnu,
"Mismatched scale data types in mxfp8 matmul: ",
m1_sc_dtype,
" vs ",
m2_sc_dtype);
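      // mxfp8: 32-element scale groups along K, mirroring the mxfp4 path.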
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ (1 << 0) + (1 << 1),
{1, k},
{1, 32},
get_onednn_dtype(m1_sc));
/* per token quant */
}

if (m2_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ 0,
{},
/* mask */ (1 << 0) + (1 << 1),
{32, 1},
get_onednn_dtype(m2_sc));
/* per tensor quant */
} else {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ (1 << 1),
{},
get_onednn_dtype(m2_sc));
/* per channel quant */
if (m1_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ 0,
{},
get_onednn_dtype(m1_sc));
/* per tensor quant */
} else {
pattr.set_scales(
DNNL_ARG_SRC,
/* mask */ (1 << 0) + (1 << 1),
{1, k},
get_onednn_dtype(m1_sc));
/* per token quant */
}

if (m2_sc.numel() == 1) {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ 0,
{},
get_onednn_dtype(m2_sc));
/* per tensor quant */
} else {
pattr.set_scales(
DNNL_ARG_WEIGHTS,
/* mask */ (1 << 1),
{},
get_onednn_dtype(m2_sc));
/* per channel quant */
}
}
};

58 changes: 58 additions & 0 deletions csrc/xpu/onednn/onednn_ext.h
@@ -52,6 +52,8 @@ enum class joint_dtypes_t {
f8_e5m2_bf16,
f8_e4m3_f16,
f8_e4m3_bf16,
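  // fp4 (e2m1) activations and weights with e8m0 MX block scales; output in bf16 or f16.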
mxfp4_bf16,
mxfp4_f16,
};

template <joint_dtypes_t Ts>
@@ -195,6 +197,30 @@ struct onednn_types_mapper<joint_dtypes_t::f8_e4m3_bf16> {
}
};

template <>
struct onednn_types_mapper<joint_dtypes_t::mxfp4_bf16> {
static inline std::
tuple<memory::data_type, memory::data_type, memory::data_type>
get() {
return std::make_tuple(
memory::data_type::f4_e2m1,
memory::data_type::f4_e2m1,
memory::data_type::bf16);
}
};

template <>
struct onednn_types_mapper<joint_dtypes_t::mxfp4_f16> {
static inline std::
tuple<memory::data_type, memory::data_type, memory::data_type>
get() {
return std::make_tuple(
memory::data_type::f4_e2m1,
memory::data_type::f4_e2m1,
memory::data_type::f16);
}
};

static inline memory::data_type
get_onednn_dtype(const at::Tensor& tensor, bool allow_undef = false) {
switch (tensor.scalar_type()) {
@@ -218,6 +244,10 @@ get_onednn_dtype(const at::Tensor& tensor, bool allow_undef = false) {
return memory::data_type::f8_e4m3;
case at::ScalarType::Float8_e5m2:
return memory::data_type::f8_e5m2;
case at::ScalarType::Float8_e8m0fnu:
return memory::data_type::e8m0;
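    // Packed two-per-byte fp4 tensors map onto oneDNN's f4_e2m1 element type.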
case at::ScalarType::Float4_e2m1fn_x2:
return memory::data_type::f4_e2m1;
default:
if (!allow_undef) {
TORCH_CHECK(
@@ -1016,6 +1046,34 @@ static inline primitive_ext& matmul_primitive_create_and_cache(
attr,
scale_group_size,
zp_group_size);
case joint_dtypes_t::mxfp4_bf16:
return matmul_primitive_create_and_cache<joint_dtypes_t::mxfp4_bf16, F>(
Tt,
b_type,
m,
n,
k,
lda,
ldb,
ldc,
device_id,
attr,
scale_group_size,
zp_group_size);
case joint_dtypes_t::mxfp4_f16:
return matmul_primitive_create_and_cache<joint_dtypes_t::mxfp4_f16, F>(
Tt,
b_type,
m,
n,
k,
lda,
ldb,
ldc,
device_id,
attr,
scale_group_size,
zp_group_size);
default:
      throw std::runtime_error("Only int4, fp8, and mxfp4 gemm are supported ...");
}
29 changes: 29 additions & 0 deletions csrc/xpu/onednn/onednn_matmul.cpp
@@ -1,4 +1,5 @@
#include <vector>
#include "fp4_gemm_w4a4.h"
#include "fp8_gemm_w8a8.h"
#include "fp8_gemm_w8a16.h"
#include "int4_gemm_w4a16.h"
@@ -9,6 +10,10 @@ inline bool is_supported_fp8(at::ScalarType t) {
(t == at::ScalarType::Float8_e4m3fn);
}

inline bool is_supported_fp4(at::ScalarType t) {
return t == at::ScalarType::Float4_e2m1fn_x2;
}

torch::Tensor check_and_create_output_tensor(
const torch::Tensor& A,
const torch::Tensor& B,
@@ -93,6 +98,30 @@ torch::Tensor fp8_gemm_w8a16(
return result;
}

torch::Tensor fp4_gemm(
const torch::Tensor& A,
const torch::Tensor& B,
const torch::Tensor& A_scale,
const torch::Tensor& B_scale,
std::optional<c10::ScalarType> out_dtype,
const std::optional<torch::Tensor>& bias) {
const at::DeviceGuard device_guard(A.device());
torch::Tensor result = check_and_create_output_tensor(A, B, out_dtype);
auto a_st = A.scalar_type();
auto b_st = B.scalar_type();
TORCH_CHECK(
      is_supported_fp4(a_st) && is_supported_fp4(b_st) && a_st == b_st,
      "input and weight must both be Float4_e2m1fn_x2 for fp4 matmul");
TORCH_CHECK(
result.scalar_type() == torch::kFloat16 ||
result.scalar_type() == torch::kBFloat16,
"output must be float16 or bfloat16 for fp4 matmul");
  // Weights are in nt (transposed) layout when the second-to-last dimension is contiguous.
bool is_nt = B.strides()[B.dim() - 2] == 1;
oneDNN::dnnl_matmul_w4a4_fp4(result, A, B, is_nt, bias, A_scale, B_scale);
return result;
}

torch::Tensor int4_gemm_w4a16(
const torch::Tensor& A_, // src, [b, m, k]
const torch::Tensor& B, // quantized weight, [k, n]