diff --git a/csrc/xpu/onednn/fp4_gemm_w4a4.h b/csrc/xpu/onednn/fp4_gemm_w4a4.h new file mode 100644 index 000000000..167597e5e --- /dev/null +++ b/csrc/xpu/onednn/fp4_gemm_w4a4.h @@ -0,0 +1,168 @@ +#pragma once + +#include +#include +#include + +#include "onednn_ext.h" +#include "onednn_runtime.h" + +namespace oneDNN { + +static inline void dnnl_matmul_w4a4_fp4( + torch::Tensor& result, // dst, [b, m, n] + const torch::Tensor& mat1, // src, [b, m, k] + const torch::Tensor& mat2, // quantized weight, [k, n] transpose + bool is_nt, + const std::optional& bias, + const torch::Tensor& m1_sc, + const torch::Tensor& m2_sc) { + auto src_sz = mat1.sizes(); + auto o_sz = result.sizes(); + + const int m = std::reduce( + src_sz.begin(), src_sz.end() - 1, 1, std::multiplies()); + const int n = o_sz.back(); // presume channel last format + const int k = (*(src_sz.end() - 1)) * 2; + + // get joint dtypes + joint_dtypes_t jd; + auto out_dtype = result.scalar_type(); + auto m1_sc_dtype = m1_sc.scalar_type(); + auto m2_sc_dtype = m2_sc.scalar_type(); + if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) { + TORCH_CHECK( + m2_sc_dtype == at::ScalarType::Float8_e8m0fnu, + "Mismatched scale data types in mxfp4 matmul: ", + m1_sc_dtype, + " vs ", + m2_sc_dtype); + jd = out_dtype == at::ScalarType::BFloat16 ? joint_dtypes_t::mxfp4_bf16 + : joint_dtypes_t::mxfp4_f16; + } else { + TORCH_INTERNAL_ASSERT( + false, "Unsupported scale type for fp4 matmul: ", m1_sc_dtype); + } + + // get bias type + bias_type_t b_type = get_bias_type(bias, m, n); + + trans_type_t tt = trans_type_t::nn; + if (is_nt) { + // transpose mat2 + tt = trans_type_t::nt; + } + + // get lda ldb and ldc + auto mat1_strides = mat1.strides(); + int64_t leading_dim = -1; + if (mat1.dim() == 2) { + leading_dim = 0; + } else if (mat1.dim() == 3) { + leading_dim = mat1_strides[0] < mat1_strides[1] ? 0 : 1; + } else { + TORCH_CHECK( + false, "Unsupported input dimension for fp4 matmul: ", mat1.dim()); + } + int64_t lda = 2 * mat1_strides[leading_dim]; + int64_t ldb = mat2.strides()[mat2.dim() - 1] == 1 + ? mat2.strides()[mat2.dim() - 2] + : 2 * (mat2.strides()[mat2.dim() - 1]); + int64_t ldc = result.strides()[leading_dim]; + + auto f_attr = [&](dnnl::primitive_attr& pattr) { + pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + + if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) { + pattr.set_scales( + DNNL_ARG_SRC, + /* mask */ (1 << 0) + (1 << 1), + {1, 32}, + get_onednn_dtype(m1_sc)); + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 0) + (1 << 1), + {32, 1}, + get_onednn_dtype(m2_sc)); + } else { + if (m1_sc.numel() == 1) { + pattr.set_scales( + DNNL_ARG_SRC, + /* mask */ 0, + {}, + get_onednn_dtype(m1_sc)); + /* per tensor quant */ + } else { + pattr.set_scales( + DNNL_ARG_SRC, + /* mask */ (1 << 0) + (1 << 1), + {1, k}, + get_onednn_dtype(m1_sc)); + /* per token quant */ + } + + if (m2_sc.numel() == 1) { + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ 0, + {}, + get_onednn_dtype(m2_sc)); + /* per tensor quant */ + } else { + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 1), + {}, + get_onednn_dtype(m2_sc)); + /* per channel quant */ + } + } + }; + + int arg_off = 0; + + // ************************************************************ + // get device, engine, stream + const int dev_id = c10::xpu::getCurrentXPUStream().device_index(); + at::Device curDevice = at::Device(at::kXPU, dev_id); + auto engine = GpuEngineManager::Instance().get_engine(curDevice); + + int m1_sc_group_size = m1_sc.numel(); + int m2_sc_group_size = m2_sc.numel(); + int sc_group_size = (m1_sc_group_size << 8) | m2_sc_group_size; + auto& matmul_ext = matmul_primitive_create_and_cache( + jd, tt, b_type, m, n, k, lda, ldb, ldc, dev_id, f_attr, sc_group_size); + + matmul_ext.set_attribute( + arg_off++, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, + m2_sc.data_ptr(), + [&]() { + return make_onednn_memory( + get_onednn_md(m2_sc), engine, m2_sc.data_ptr()); + }); + matmul_ext.set_attribute( + arg_off++, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, m1_sc.data_ptr(), [&]() { + return make_onednn_memory( + get_onednn_md(m1_sc), engine, m1_sc.data_ptr()); + }); + + std::vector> arg_handles; + arg_handles.reserve(8); + + arg_handles.emplace_back(DNNL_ARG_SRC, mat1.data_ptr()); + arg_handles.emplace_back(DNNL_ARG_WEIGHTS, mat2.data_ptr()); + arg_handles.emplace_back(DNNL_ARG_DST, result.data_ptr()); + if (get_shape(b_type) != bias_shape_t::none) { + arg_handles.emplace_back(DNNL_ARG_BIAS, bias.value().data_ptr()); + } + int scratchpad_size = matmul_ext.get_scratchpad_size(); + torch::Tensor scratchpad_tensor = at::empty( + {scratchpad_size}, mat1.options().dtype(at::kByte), c10::nullopt); + arg_handles.emplace_back(DNNL_ARG_SCRATCHPAD, scratchpad_tensor.data_ptr()); + + auto& strm = GpuStreamManager::Instance().get_stream(); + auto qfp4_matmul_event = + matmul_ext.execute(strm, engine, std::move(arg_handles), arg_off); +} +} // namespace oneDNN diff --git a/csrc/xpu/onednn/fp8_gemm_w8a8.h b/csrc/xpu/onednn/fp8_gemm_w8a8.h index 176fa18b6..8c2368639 100644 --- a/csrc/xpu/onednn/fp8_gemm_w8a8.h +++ b/csrc/xpu/onednn/fp8_gemm_w8a8.h @@ -39,7 +39,7 @@ static inline void dnnl_matmul_w8a8_fp8( : joint_dtypes_t::f8_e4m3_f16; } else { TORCH_INTERNAL_ASSERT( - false, "Unsupported data type for fp8 matmul: ", mat1.scalar_type()); + false, "Unsupported data type for fp8 matmul: ", in_dtype); } // get bias type @@ -68,38 +68,60 @@ static inline void dnnl_matmul_w8a8_fp8( : mat2.strides()[mat2.dim() - 1]; int64_t ldc = result.strides()[leading_dim]; + auto m1_sc_dtype = m1_sc.scalar_type(); + auto m2_sc_dtype = m2_sc.scalar_type(); auto f_attr = [&](dnnl::primitive_attr& pattr) { pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - if (m1_sc.numel() == 1) { - pattr.set_scales( - DNNL_ARG_SRC, - /* mask */ 0, - {}, - get_onednn_dtype(m1_sc)); - /* per tensor quant */ - } else { + + if (m1_sc_dtype == at::ScalarType::Float8_e8m0fnu) { + TORCH_CHECK( + m2_sc_dtype == at::ScalarType::Float8_e8m0fnu, + "Mismatched scale data types in mxfp8 matmul: ", + m1_sc_dtype, + " vs ", + m2_sc_dtype); pattr.set_scales( DNNL_ARG_SRC, /* mask */ (1 << 0) + (1 << 1), - {1, k}, + {1, 32}, get_onednn_dtype(m1_sc)); - /* per token quant */ - } - - if (m2_sc.numel() == 1) { pattr.set_scales( DNNL_ARG_WEIGHTS, - /* mask */ 0, - {}, + /* mask */ (1 << 0) + (1 << 1), + {32, 1}, get_onednn_dtype(m2_sc)); - /* per tensor quant */ } else { - pattr.set_scales( - DNNL_ARG_WEIGHTS, - /* mask */ (1 << 1), - {}, - get_onednn_dtype(m2_sc)); - /* per channel quant */ + if (m1_sc.numel() == 1) { + pattr.set_scales( + DNNL_ARG_SRC, + /* mask */ 0, + {}, + get_onednn_dtype(m1_sc)); + /* per tensor quant */ + } else { + pattr.set_scales( + DNNL_ARG_SRC, + /* mask */ (1 << 0) + (1 << 1), + {1, k}, + get_onednn_dtype(m1_sc)); + /* per token quant */ + } + + if (m2_sc.numel() == 1) { + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ 0, + {}, + get_onednn_dtype(m2_sc)); + /* per tensor quant */ + } else { + pattr.set_scales( + DNNL_ARG_WEIGHTS, + /* mask */ (1 << 1), + {}, + get_onednn_dtype(m2_sc)); + /* per channel quant */ + } } }; diff --git a/csrc/xpu/onednn/onednn_ext.h b/csrc/xpu/onednn/onednn_ext.h index a87f288c9..fe024069c 100644 --- a/csrc/xpu/onednn/onednn_ext.h +++ b/csrc/xpu/onednn/onednn_ext.h @@ -52,6 +52,8 @@ enum class joint_dtypes_t { f8_e5m2_bf16, f8_e4m3_f16, f8_e4m3_bf16, + mxfp4_bf16, + mxfp4_f16, }; template @@ -195,6 +197,30 @@ struct onednn_types_mapper { } }; +template <> +struct onednn_types_mapper { + static inline std:: + tuple + get() { + return std::make_tuple( + memory::data_type::f4_e2m1, + memory::data_type::f4_e2m1, + memory::data_type::bf16); + } +}; + +template <> +struct onednn_types_mapper { + static inline std:: + tuple + get() { + return std::make_tuple( + memory::data_type::f4_e2m1, + memory::data_type::f4_e2m1, + memory::data_type::f16); + } +}; + static inline memory::data_type get_onednn_dtype(const at::Tensor& tensor, bool allow_undef = false) { switch (tensor.scalar_type()) { @@ -218,6 +244,10 @@ get_onednn_dtype(const at::Tensor& tensor, bool allow_undef = false) { return memory::data_type::f8_e4m3; case at::ScalarType::Float8_e5m2: return memory::data_type::f8_e5m2; + case at::ScalarType::Float8_e8m0fnu: + return memory::data_type::e8m0; + case at::ScalarType::Float4_e2m1fn_x2: + return memory::data_type::f4_e2m1; default: if (!allow_undef) { TORCH_CHECK( @@ -1016,6 +1046,34 @@ static inline primitive_ext& matmul_primitive_create_and_cache( attr, scale_group_size, zp_group_size); + case joint_dtypes_t::mxfp4_bf16: + return matmul_primitive_create_and_cache( + Tt, + b_type, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); + case joint_dtypes_t::mxfp4_f16: + return matmul_primitive_create_and_cache( + Tt, + b_type, + m, + n, + k, + lda, + ldb, + ldc, + device_id, + attr, + scale_group_size, + zp_group_size); default: throw std::runtime_error("Only support int4 and fp8 gemm ..."); } diff --git a/csrc/xpu/onednn/onednn_matmul.cpp b/csrc/xpu/onednn/onednn_matmul.cpp index 061e6c5f5..2a3dace10 100644 --- a/csrc/xpu/onednn/onednn_matmul.cpp +++ b/csrc/xpu/onednn/onednn_matmul.cpp @@ -1,4 +1,5 @@ #include +#include "fp4_gemm_w4a4.h" #include "fp8_gemm_w8a8.h" #include "fp8_gemm_w8a16.h" #include "int4_gemm_w4a16.h" @@ -9,6 +10,10 @@ inline bool is_supported_fp8(at::ScalarType t) { (t == at::ScalarType::Float8_e4m3fn); } +inline bool is_supported_fp4(at::ScalarType t) { + return t == at::ScalarType::Float4_e2m1fn_x2; +} + torch::Tensor check_and_create_output_tensor( const torch::Tensor& A, const torch::Tensor& B, @@ -93,6 +98,30 @@ torch::Tensor fp8_gemm_w8a16( return result; } +torch::Tensor fp4_gemm( + const torch::Tensor& A, + const torch::Tensor& B, + const torch::Tensor& A_scale, + const torch::Tensor& B_scale, + std::optional out_dtype, + const std::optional& bias) { + const at::DeviceGuard device_guard(A.device()); + torch::Tensor result = check_and_create_output_tensor(A, B, out_dtype); + auto a_st = A.scalar_type(); + auto b_st = B.scalar_type(); + TORCH_CHECK( + is_supported_fp4(a_st) && is_supported_fp4(b_st) && a_st == b_st, + "input and weight must be f4_e2m1x2 or f4_e2m1x2 for fp4 matmul"); + TORCH_CHECK( + result.scalar_type() == torch::kFloat16 || + result.scalar_type() == torch::kBFloat16, + "output must be float16 or bfloat16 for fp4 matmul"); + // check if nt format + bool is_nt = B.strides()[B.dim() - 2] == 1; + oneDNN::dnnl_matmul_w4a4_fp4(result, A, B, is_nt, bias, A_scale, B_scale); + return result; +} + torch::Tensor int4_gemm_w4a16( const torch::Tensor& A_, // src, [b, m, k] const torch::Tensor& B, // quantized weight, [k, n] diff --git a/csrc/xpu/ops.h b/csrc/xpu/ops.h index d7eaad965..1f26b7554 100644 --- a/csrc/xpu/ops.h +++ b/csrc/xpu/ops.h @@ -5,7 +5,7 @@ /** * Make sure the shape of A and B is correctly setting before calling below gemm * method implemented with OneDNN. - * A should be one of [b, m, k] and [m, k] + * A should be one of [b, m, k] and [m, k] or [m, k//2] in fp4 precision * B should be [k, n] or [k//8, n] in int4 precision, where [k//8, n] indicates * a packed representation with 8 int4 values packed into one byte along the k * dimension. @@ -24,6 +24,14 @@ torch::Tensor fp8_gemm_w8a16( const std::optional& B_scale_, const std::optional& bias_); +torch::Tensor fp4_gemm( + const torch::Tensor& A, + const torch::Tensor& B, + const torch::Tensor& A_scale, + const torch::Tensor& B_scale, + std::optional out_dtype, + const std::optional& bias); + torch::Tensor int4_gemm_w4a16( const torch::Tensor& A_, const torch::Tensor& B, diff --git a/csrc/xpu/torch_bindings.cpp b/csrc/xpu/torch_bindings.cpp index cdfa12844..05075919b 100644 --- a/csrc/xpu/torch_bindings.cpp +++ b/csrc/xpu/torch_bindings.cpp @@ -19,6 +19,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) { "Tensor? bias_) -> Tensor"); xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16); + xpu_ops.def( + "fp4_gemm(Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "ScalarType? out_dtype, Tensor? bias_) -> Tensor"); + xpu_ops.impl("fp4_gemm", torch::kXPU, &fp4_gemm); + xpu_ops.def( "int4_gemm_w4a16(Tensor A, Tensor B, Tensor? bias, Tensor B_scale, " "Tensor B_zp, int group_size, Tensor? g_idx) -> Tensor"); diff --git a/tests/register_ops.py b/tests/register_ops.py index 0c965d441..c5ee99d21 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -329,6 +329,18 @@ def fp8_gemm(input: torch.Tensor, weight: torch.Tensor, scale_wei, bias) +def fp4_gemm( + input: torch.Tensor, + weight: torch.Tensor, + scale_act: torch.Tensor, + scale_wei: torch.Tensor, + out_dtype: Optional[torch.dtype], + bias: Optional[torch.Tensor], +): + return torch.ops._xpu_C.fp4_gemm(input, weight, scale_act, scale_wei, + out_dtype, bias) + + def fp8_gemm_w8a16(input: torch.Tensor, weight: torch.Tensor, scale_wei: Optional[torch.Tensor], scale_act: Optional[torch.Tensor]): diff --git a/tests/test_fp4_gemm_onednn.py b/tests/test_fp4_gemm_onednn.py new file mode 100644 index 000000000..95c7184ea --- /dev/null +++ b/tests/test_fp4_gemm_onednn.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.ops.mx_utils import (FP4_EBITS, FP4_MBITS, _floatx_unpacked_to_f32, + from_blocked_format, to_mxfp, unpack_uint4) +from tests.register_ops import fp4_gemm + +OUT_DTYPES = [torch.float16, torch.bfloat16] +MNK_FACTORS = [ + (1, 32, 1024), + (1, 32, 2048), + (1, 32, 5120), + (8, 512, 2048), + (512, 1024, 2048), + (1024, 2048, 2048), +] + +MINI_MX_MNK_FACTORS = [ + (1, 32, 32), + (1, 64, 32), + (32, 32, 32), + (32, 64, 32), +] + +# override pytest parameters when enable mini pytest +MINI_PYTEST_PARAMS = { + "default": { + "mnk_factors": MINI_MX_MNK_FACTORS, + } +} + + +def _convert_to_mxfp4_with_hp_ref(t): + # Convert a tensor to mxfp4, returning: + # t_hp : reconstructed bf16 version of t_lp + # t_lp : fp4_e2m1x2 tensor + # t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled) + t_scale, t_lp = to_mxfp(t, format="mxfp4") + t_hp = from_blocked_format( + _floatx_unpacked_to_f32(unpack_uint4(t_lp), FP4_EBITS, FP4_MBITS), + t_scale, + blocksize=32, + ) + + return t_hp, t_lp, t_scale + + +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) +def test_mxfp4_gemm(mnk_factors, out_dtype): + m, n, k = mnk_factors + + inputs = torch.randn((m, k), dtype=out_dtype).xpu() * 0.01 + weights = torch.randn((n, k), dtype=out_dtype).xpu() * 0.01 + + # Reference: to_mxfp operates on float32 or bfloat16. + if out_dtype is torch.half: + inputs = inputs.to(torch.float) + weights = weights.to(torch.float) + inputs_hp, inputs_lp, inputs_scale = _convert_to_mxfp4_with_hp_ref(inputs) + weights_hp, weights_lp, weights_scale = _convert_to_mxfp4_with_hp_ref( + weights) + + output = fp4_gemm( + inputs_lp, + weights_lp.transpose(0, 1), + inputs_scale, + weights_scale, + out_dtype, + torch.Tensor(), + ) + + output_ref = torch.matmul(inputs_hp.to(out_dtype), + weights_hp.to(out_dtype).t()) + torch.testing.assert_close(output.to(torch.float), + output_ref.to(torch.float), + atol=5e-2, + rtol=5e-2) diff --git a/tests/test_fp8_gemm_onednn.py b/tests/test_fp8_gemm_onednn.py index fe5c540f0..eaef35520 100644 --- a/tests/test_fp8_gemm_onednn.py +++ b/tests/test_fp8_gemm_onednn.py @@ -3,15 +3,18 @@ import torch from tests.ops.fp8_quant_op import scaled_fp8_quant +from tests.ops.mx_utils import from_blocked_format, to_mxfp from tests.register_ops import fp8_gemm, fp8_gemm_w8a16 -BATCHES = [1] +BATCHES = [1, 2, 8] +OUT_DTYPES = [torch.float16, torch.bfloat16] MNK_FACTORS = [ - (1, 4096, 1), (1, 32, 1024), - (4, 16, 1024), - (8, 32, 1024), - (8, 512, 1024), + (1, 32, 2048), + (1, 32, 5120), + (8, 512, 2048), + (512, 1024, 2048), + (1024, 2048, 2048), ] MINI_MNK_FACTORS = [ @@ -20,9 +23,17 @@ (4, 32, 16), ] -#override pytest parameters when enable mini pytest +MINI_MX_MNK_FACTORS = [ + (1, 32, 32), + (1, 64, 32), + (32, 32, 32), + (32, 64, 32), +] + +# override pytest parameters when enable mini pytest MINI_PYTEST_PARAMS = { "test_fp8_gemm_w8a16": { + "batch": 1, "mnk_factors": MINI_MNK_FACTORS[:1], }, "test_fp8_gemm_per_tensor": { @@ -34,39 +45,44 @@ "test_fp8_gemm_w8a16_per_channel": { "mnk_factors": MINI_MNK_FACTORS[:1], }, + "test_mxfp8_gemm": { + "mnk_factors": MINI_MX_MNK_FACTORS, + }, } @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) @pytest.mark.parametrize("trans_wei", [True, False]) @pytest.mark.parametrize("is_mbk", [True, False]) @pytest.mark.parametrize("batch", BATCHES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_fp8_gemm_w8a16(fp8_dtype, dtype, trans_wei, is_mbk, batch, +def test_fp8_gemm_w8a16(fp8_dtype, out_dtype, trans_wei, is_mbk, batch, mnk_factors): seed = 1234 torch.manual_seed(seed) m, n, k = mnk_factors - input = torch.randn([batch, m, k], dtype=dtype, - device=torch.device("xpu")) / 10.0 + input = torch.randn( + [batch, m, k], dtype=out_dtype, device=torch.device("xpu")) / 10.0 if trans_wei: - weight = torch.ones([n, k], dtype=dtype).xpu() + weight = torch.ones([n, k], dtype=out_dtype).xpu() else: - weight = torch.ones([k, n], dtype=dtype).xpu() - scale_wei = (torch.ones(batch) * 4).xpu() - scale_shape = None + weight = torch.ones([k, n], dtype=out_dtype).xpu() + scale_wei = torch.tensor(4.0).xpu() - weight_fp8, _ = scaled_fp8_quant(weight, scale_wei, False, False, - fp8_dtype, scale_shape) + weight_fp8, _ = scaled_fp8_quant(weight, + scale_wei, + fp8_dtype=fp8_dtype, + group_shape=(-1, 1)) + weight_fp8_hp = weight_fp8.to(out_dtype) * scale_wei.to(out_dtype) # reference fp16 gemm if trans_wei: - output_ref = torch.matmul(input, weight.t()) + output_ref = torch.matmul(input, weight_fp8_hp.t()) else: - output_ref = torch.matmul(input, weight) + output_ref = torch.matmul(input, weight_fp8_hp) # onednn fp8 gemm if is_mbk: @@ -83,37 +99,33 @@ def test_fp8_gemm_w8a16(fp8_dtype, dtype, trans_wei, is_mbk, batch, @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) @pytest.mark.parametrize("is_nt", [True, False]) @pytest.mark.parametrize("batch", BATCHES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_fp8_gemm_per_tensor(fp8_dtype, dtype, is_nt, batch, mnk_factors): +def test_fp8_gemm_per_tensor(fp8_dtype, out_dtype, is_nt, batch, mnk_factors): seed = 1234 torch.manual_seed(seed) m, n, k = mnk_factors input = torch.randn( - [batch * m, k], dtype=dtype, device=torch.device("xpu")) / 10.0 - weight = torch.randn([n, k], dtype=dtype).xpu() / 10.0 + [batch, m, k], dtype=out_dtype, device=torch.device("xpu")) / 10.0 + weight = torch.randn([n, k], dtype=out_dtype).xpu() / 10.0 - scale_src = (torch.ones(batch) * 4).xpu() - scale_wei = (torch.ones(batch) * 4).xpu() + scale_src = torch.tensor(4.0).xpu() + scale_wei = torch.tensor(4.0).xpu() input_fp8, _ = scaled_fp8_quant(input.reshape(-1, k), scale_src, - False, - False, fp8_dtype=fp8_dtype) + weight_fp8, _ = scaled_fp8_quant(weight, scale_wei, fp8_dtype=fp8_dtype) - weight_fp8, _ = scaled_fp8_quant(weight, - scale_wei, - False, - False, - fp8_dtype=fp8_dtype) + input_fp8_hp = input_fp8.to(out_dtype) * scale_src.to(out_dtype) + weight_fp8_hp = weight_fp8.to(out_dtype) * scale_wei.to(out_dtype) # reference fp16 gemm - output_ref = torch.matmul(input, weight.t()) + output_ref = torch.matmul(input_fp8_hp, weight_fp8_hp.t()) weight_fp8 = weight_fp8.transpose(0, 1) if is_nt: @@ -122,7 +134,7 @@ def test_fp8_gemm_per_tensor(fp8_dtype, dtype, is_nt, batch, mnk_factors): output_fp8 = fp8_gemm( input_fp8, weight_fp8, - dtype, + out_dtype, scale_src, scale_wei, torch.Tensor(), @@ -132,30 +144,31 @@ def test_fp8_gemm_per_tensor(fp8_dtype, dtype, is_nt, batch, mnk_factors): @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) @pytest.mark.parametrize("is_nt", [True, False]) @pytest.mark.parametrize("batch", BATCHES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_fp8_gemm_per_channel(fp8_dtype, dtype, is_nt, batch, mnk_factors): +def test_fp8_gemm_per_channel(fp8_dtype, out_dtype, is_nt, batch, mnk_factors): seed = 1234 torch.manual_seed(seed) m, n, k = mnk_factors input = torch.randn( - [batch * m, k], dtype=dtype, device=torch.device("xpu")) / 10.0 - weight = torch.randn([n, k], dtype=dtype).xpu() / 10.0 + [batch, m, k], dtype=out_dtype, device=torch.device("xpu")) / 10.0 + weight = torch.randn([n, k], dtype=out_dtype).xpu() / 10.0 input_fp8, scale_src_fp8 = scaled_fp8_quant(input.reshape(-1, k), use_per_token_if_dynamic=True, fp8_dtype=fp8_dtype) - weight_fp8, scale_wei_fp8 = scaled_fp8_quant(weight, use_per_token_if_dynamic=True, fp8_dtype=fp8_dtype) # reference fp16 gemm - output_ref = torch.matmul(input, weight.t()) + input_fp8_hp = input_fp8.to(out_dtype) * scale_src_fp8.to(out_dtype) + weight_fp8_hp = weight_fp8.to(out_dtype) * scale_wei_fp8.to(out_dtype) + output_ref = torch.matmul(input_fp8_hp, weight_fp8_hp.t()) weight_fp8 = weight_fp8.transpose(0, 1) if is_nt: @@ -164,7 +177,7 @@ def test_fp8_gemm_per_channel(fp8_dtype, dtype, is_nt, batch, mnk_factors): output_fp8 = fp8_gemm( input_fp8, weight_fp8, - dtype, + out_dtype, scale_src_fp8, scale_wei_fp8, torch.Tensor(), @@ -174,21 +187,21 @@ def test_fp8_gemm_per_channel(fp8_dtype, dtype, is_nt, batch, mnk_factors): @pytest.mark.parametrize("fp8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) @pytest.mark.parametrize("is_nt", [True, False]) @pytest.mark.parametrize("is_mbk", [True, False]) @pytest.mark.parametrize("batch", BATCHES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -def test_fp8_gemm_w8a16_per_channel(fp8_dtype, dtype, is_nt, is_mbk, batch, +def test_fp8_gemm_w8a16_per_channel(fp8_dtype, out_dtype, is_nt, is_mbk, batch, mnk_factors): seed = 1234 torch.manual_seed(seed) m, n, k = mnk_factors - input = torch.randn([batch, m, k], dtype=dtype, - device=torch.device("xpu")) / 10.0 - weight = torch.randn([n, k], dtype=dtype).xpu() / 10.0 + input = torch.randn( + [batch, m, k], dtype=out_dtype, device=torch.device("xpu")) / 10.0 + weight = torch.randn([n, k], dtype=out_dtype).xpu() / 10.0 weight_fp8, scale_wei_fp8 = scaled_fp8_quant(weight, use_per_token_if_dynamic=True, @@ -197,7 +210,7 @@ def test_fp8_gemm_w8a16_per_channel(fp8_dtype, dtype, is_nt, is_mbk, batch, scale_wei_flat = scale_wei_fp8.flatten() # reference: dequantize weight then fp16/bf16 matmul - weight_dequant = weight_fp8.to(dtype) * scale_wei_fp8.to(dtype) + weight_dequant = weight_fp8.to(out_dtype) * scale_wei_fp8.to(out_dtype) output_ref = torch.matmul(input, weight_dequant.t()) weight_fp8_t = weight_fp8.transpose(0, 1) @@ -215,3 +228,44 @@ def test_fp8_gemm_w8a16_per_channel(fp8_dtype, dtype, is_nt, is_mbk, batch, output_fp8 = output_fp8.transpose(0, 1) if is_mbk else output_fp8 torch.testing.assert_close(output_fp8, output_ref, atol=5e-2, rtol=5e-2) + + +def _convert_to_mxfp8_with_hp_ref(t): + # Convert a tensor to mxfp8, returning: + # t_hp : reconstructed bf16 version of t_lp + # t_lp : fp8_e4m3 tensor + # t_scale: fp8_e8m0 block-wise scaling factors (non-swizzled) + t_scale, t_lp = to_mxfp(t, format="mxfp8") + t_hp = from_blocked_format(t_lp, t_scale, blocksize=32) + + return t_hp, t_lp, t_scale + + +@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) +@pytest.mark.parametrize("out_dtype", OUT_DTYPES) +def test_mxfp8_gemm(mnk_factors, out_dtype): + m, n, k = mnk_factors + inputs = torch.randn((m, k), dtype=out_dtype).xpu() * 0.01 + weights = torch.randn((n, k), dtype=out_dtype).xpu() * 0.01 + + # Reference: to_mxfp operates on float32 or bfloat16. + if out_dtype == torch.half: + inputs = inputs.to(torch.float32) + weights = weights.to(torch.float32) + + inputs_hp, inputs_lp, inputs_scale = _convert_to_mxfp8_with_hp_ref(inputs) + weights_hp, weights_lp, weights_scale = _convert_to_mxfp8_with_hp_ref( + weights) + + output = fp8_gemm( + inputs_lp, + weights_lp.transpose(0, 1), + out_dtype, + inputs_scale, + weights_scale, + torch.Tensor(), + ) + + output_ref = torch.matmul(inputs_hp.to(out_dtype), + weights_hp.to(out_dtype).t()) + torch.testing.assert_close(output, output_ref, atol=5e-2, rtol=5e-2)