diff --git a/applications/dual_gemm/collective/xe_dual_gemm_mma.hpp b/applications/dual_gemm/collective/xe_dual_gemm_mma.hpp index 9e3ae57eaa..5f98666621 100644 --- a/applications/dual_gemm/collective/xe_dual_gemm_mma.hpp +++ b/applications/dual_gemm/collective/xe_dual_gemm_mma.hpp @@ -98,7 +98,7 @@ struct DualGemmMma, TileShape_, ElementA_ using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideA{})); //(m, k) using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideB{})); //(n, k) - + using MainloopTensors = cute::tuple; using CopyThreadShape = Shape<_1, Int>; // Host side kernel arguments diff --git a/applications/dual_gemm/kernel/xe_dual_gemm.hpp b/applications/dual_gemm/kernel/xe_dual_gemm.hpp index 6a40509412..8d87a65cf9 100644 --- a/applications/dual_gemm/kernel/xe_dual_gemm.hpp +++ b/applications/dual_gemm/kernel/xe_dual_gemm.hpp @@ -125,7 +125,7 @@ class DualGemm using TensorMKL = typename DualGemmMainloop::TensorMKL; using TensorNKL = typename DualGemmMainloop::TensorNKL; - + using MainloopTensors = cute::tuple; using TensorMK = decltype(TensorMKL{}(_, _, 0)); using TensorNK = decltype(TensorNKL{}(_, _, 0)); diff --git a/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_bf16_f16_s8.cpp b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_bf16_f16_s8.cpp new file mode 100644 index 0000000000..248d61ac43 --- /dev/null +++ b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_bf16_f16_s8.cpp @@ -0,0 +1,237 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief CUTLASS Intel BMG Grouped Gemm with mixed input types + + This example demonstrates how to dispatch a mixed precision GEMM (int8 and bfloat16 | half_t) on BMG, with + optional dequantization. The GemmMode enum describes the 3 modes of operation: + + - ConvertOnly: Narrower type is simply converted to the wider type before MMA + - ConvertAndScale: Narrower type is converted to wider type, then scaled + - ConvertAndScaleWithZeroPoint: Narrower type is converted to wider type, scaled and offset + + - Requirements: + - dequantization group size (options.g) must be multiple of k-block size + - scales & zeros must be MN-major + + The MMA operation itself takes bfloat16 input for both A and B, and so the narrower type is first + upcasted (inside the mainloop) prior to being passed into the MMA atom. + + Verification for this example is performed against a standard reference GEMM in the wider type. + The narrow-type input data are upcasted (or dequantized) externally before executing the + reference GEMM. + + Note: due to a bug in the IGC compiler, it's currently necessary to build this example with the + following environment variable set (CMake handles this for AOT compilation; for JIT, please set + this in your environment): + + export IGC_allowDecompose2DBlockFuncs=0 + + To build & run this example (from your build dir): + + $ ninja 10_bmg_grouped_gemm_bf16_s8 + $ ./examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_bf16_s8 + $ ninja 10_bmg_grouped_gemm_f16_s8_tensorwise + $ ./examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_s8_tensorwise + + Call with `--help` for information about available options +*/ + +#include "bmg_grouped_gemm_mixed_dtype_runner.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
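+  //
+  // For reference, the three GemmMode options described in the file comment correspond to the
+  // following per-element dequantization of a narrow value q with group-wise scale s and zero
+  // point z (illustrative pseudo-code only; the kernel performs the equivalent conversion inside
+  // the mainloop, which is why the dequantization group size (--g) must be a multiple of the
+  // K-block size):
+  //
+  //   ConvertOnly                  : wide = static_cast<MmaType>(q)
+  //   ConvertAndScale              : wide = static_cast<MmaType>(q) * s
+  //   ConvertAndScaleWithZeroPoint : wide = (static_cast<MmaType>(q) - z) * s
+  //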
+ using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementInputA = cutlass::QUANT_TYPE; // <- data type of elements in input matrix A + using ElementInputB = cutlass::MMA_TYPE; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using ElementZero = cutlass::MMA_TYPE; + using ElementScale = cutlass::MMA_TYPE; + using StrideScale = cute::Stride<_1, int64_t, int64_t>; + using StrideZero = StrideScale; + + using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; // U8 (1-byte) block copy for A (narrower type) + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // U16 (2-byte) block copy for B (wider type) + static_assert(sizeof(ElementInputA) == 1, "ElementA width must match GmemTiledCopyA U8"); + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + // Although this is a mixed type example, the actual MMA accepts bf16 input for both A and B: + using TiledMma = // M=8,N=16,K=16, D=f32,A=bf16,B=bf16,C=f32 + typename TiledMMAHelper::type>, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 3; // prefetch 3 iters of data for A and B + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group; + + // Default (Linear Combination) epilogue + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + XE_2D_U32x8x16_ST_N, + void, void>; + + // Use the helpers to avoid template arg repetition + using GemmAdapterBuilder = helpers::MixedGemmUniversalAdapterBuilder; + + using MixedBuilderQuantA = + helpers::MixedCollectiveMmaBuilder, + cutlass::gemm::TagToStrideB_t, + TiledMma, GmemTiledCopyA, GmemTiledCopyB>; + + using MixedBuilderQuantB = + helpers::MixedCollectiveMmaBuilder, + cutlass::gemm::TagToStrideB_t, + TiledMma, GmemTiledCopyB, GmemTiledCopyA>; + + // A-narrow Mainloop & GemmUniversalAdapter + using MainloopAConvertOnly = + MixedBuilderQuantA::CollectiveMma, + ElementInputB>; + using GemmAConvertOnly = + GemmAdapterBuilder::GemmUniversalAdapter; + + using MainloopAConvertAndScale = MixedBuilderQuantA::CollectiveMma< + cute::tuple, ElementInputB>; + using GemmAConvertAndScale = + GemmAdapterBuilder::GemmUniversalAdapter; + + using MainloopAConvertAndScaleWithZeroPoint = + MixedBuilderQuantA::CollectiveMma< + cute::tuple, ElementInputB>; + using GemmAConvertAndScaleWithZeroPoint = + GemmAdapterBuilder::GemmUniversalAdapter< + MainloopAConvertAndScaleWithZeroPoint>; + + // B-narrow Mainloop & GemmUniversalAdapter + using MainloopBConvertOnly = + MixedBuilderQuantB::CollectiveMma>; + using GemmBConvertOnly = + GemmAdapterBuilder::GemmUniversalAdapter; + + using MainloopBConvertAndScale = MixedBuilderQuantB::CollectiveMma< + ElementInputB, cute::tuple>; + using GemmBConvertAndScale = + GemmAdapterBuilder::GemmUniversalAdapter; + + using 
MainloopBConvertAndScaleWithZeroPoint = + MixedBuilderQuantB::CollectiveMma< + ElementInputB, cute::tuple>; + using GemmBConvertAndScaleWithZeroPoint = + GemmAdapterBuilder::GemmUniversalAdapter< + MainloopBConvertAndScaleWithZeroPoint>; + + if(options.a_narrower){ + std::cout << "Setting A as narrower type" << std::endl; + if(options.mode == GemmMode::ConvertOnly) { + std::cout << "Running in ConvertOnly mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } else if(options.mode == GemmMode::ConvertAndScale){ + std::cout << "Running in ConvertAndScale mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } else { + std::cout << "Running in ConvertAndScaleWithZeroPoint mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } + } else { + std::cout << "Setting B as narrower type" << std::endl; + if(options.mode == GemmMode::ConvertOnly) { + std::cout << "Running in ConvertOnly mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } else if(options.mode == GemmMode::ConvertAndScale){ + std::cout << "Running in ConvertAndScale mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } else { + std::cout << "Running in ConvertAndScaleWithZeroPoint mode." << std::endl; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } + } + + return 0; +} diff --git a/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_u4.cpp b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_u4.cpp new file mode 100755 index 0000000000..c0e893fcdc --- /dev/null +++ b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_u4.cpp @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief Mixed Precision BMG Grouped Gemm Example + + This example demonstrates how to dispatch a mixed precision Grouped GEMM on BMG, with optional dequantization. + The GemmMode enum describes the 3 modes of operation: + + - ConvertOnly: Narrower type is simply converted to the wider type before MMA + - ConvertAndScale: Narrower type is converted to wider type, then scaled + - ConvertAndScaleWithZeroPoint: Narrower type is converted to wider type, then scaled and shifted by zero point + - Limitations: + - group must be multiple of k-block size + - scales & zeros must be MN-major + + Note: due to a bug in the IGC compiler, it's currently necessary to build this example with the following + environment variable set: + export IGC_allowDecompose2DBlockFuncs=0 + To build & run this example (from your build dir): + + $ ninja 10_bmg_grouped_gemm_f16_u4 + $ ./examples/sycl/10_bmg_grouped_gemm_mixed_dtype/10_bmg_grouped_gemm_f16_u4 + + Call with `--help` for information about available options +*/ + +#include "bmg_grouped_gemm_mixed_dtype_runner.hpp" + +int main(int argc, const char** argv) { + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. 
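+  //
+  // Layout note (illustrative): in this example the int4 zero points are packed 8 at a time
+  // along the K (group) dimension, then laid out along N (see StrideZero below). For an element
+  // at (n, k) with dequantization group size g, the reference dequantization in the runner reads
+  //   k_group = k / g;
+  //   zero    = zeros(k_group % 8, n, k_group / 8)   // (position in pack, N, pack index)
+  // while the matching scale is simply scales(n, k_group).
+  //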
+  using ElementAccumulator = float;      // <- data type of accumulator
+  using ElementComputeEpilogue = float;  // <- data type of epilogue operations
+  using ElementInputA = uint4_t;         // <- data type of elements in input matrix A
+  using ElementInputB = half_t;          // <- data type of elements in input matrix B
+  using ElementOutput = half_t;          // <- data type of elements in output matrix D
+
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using LayoutD = cutlass::layout::RowMajor;
+
+  using ElementZero = int4_t;
+  using ElementScale = half_t;
+
+  using StrideScale = cute::Stride<_1, int64_t, int64_t>;
+  using StrideZero = cute::Stride<_8, cute::Stride<_1, int64_t>, int64_t>; // int4_t zero points are packed 8 at a time along the K dimension, then along the N dimension
+
+  using GmemTiledCopyA = XE_2D_U4x32x16_LD_T;
+  using GmemTiledCopyB = XE_2D_U16x16x32_LD_N;
+
+  // Workgroup-level tile
+  using TileShape = Shape<_16, _64, _64>;
+
+  using TiledMma =
+      typename TiledMMAHelper::type>, Layout,
+                               Layout, Stride<_2, _1, _0>>>::TiledMMA;
+
+  constexpr int PipelineStages = 3;
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16Group;
+
+  using EpilogueOp = cutlass::epilogue::fusion::LinearCombination;
+
+  using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks;
+  using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue<
+          EpilogueDispatchPolicy,
+          TileShape,
+          ElementAccumulator,
+          cutlass::gemm::TagToStrideC_t,
+          ElementOutput,
+          cutlass::gemm::TagToStrideC_t,
+          FusionCallBacks,
+          XE_2D_U32x8x16_LD_N,
+          void, void,
+          XE_2D_U16x8x16_ST_N,
+          void, void>;
+
+  // Use the helpers to avoid template arg repetition
+  using GemmAdapterBuilder = typename helpers::MixedGemmUniversalAdapterBuilder;
+
+  if(options.a_narrower){
+    // TODO: setting A as the narrower type is not supported for int4 yet
+    std::cout << "Setting A as the narrower type is not supported for int4 yet." << std::endl;
+  } else {
+    std::cout << "Setting B as narrower type" << std::endl;
+    using MixedBuilderQuant = helpers::MixedCollectiveMmaBuilder,
+        cutlass::gemm::TagToStrideB_t,
+        TiledMma, GmemTiledCopyB, GmemTiledCopyA>;
+    if(options.mode == GemmMode::ConvertOnly) {
+      std::cout << "Running in ConvertOnly mode." << std::endl;
+      using MainloopConvertOnly = MixedBuilderQuant::template CollectiveMma>;
+      using GemmConvertOnly = GemmAdapterBuilder::template GemmUniversalAdapter;
+      CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info));
+    } else if(options.mode == GemmMode::ConvertAndScale){
+      std::cout << "Running in ConvertAndScale mode." << std::endl;
+      using MainloopConvertAndScale = MixedBuilderQuant::template CollectiveMma<
+          ElementInputB, cute::tuple>;
+      using GemmConvertAndScale = GemmAdapterBuilder::template GemmUniversalAdapter;
+      CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info));
+    } else {
+      std::cout << "Running in ConvertAndScaleWithZeroPoint mode."
<< std::endl; + using MainloopConvertAndScaleWithZeroPoint = MixedBuilderQuant::template CollectiveMma< + ElementInputB, cute::tuple>; + using GemmConvertAndScaleWithZeroPoint = GemmAdapterBuilder::template GemmUniversalAdapter; + CUTLASS_CHECK(ExampleRunner{}.run(options, hw_info)); + } + } +} diff --git a/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/CMakeLists.txt b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/CMakeLists.txt new file mode 100644 index 0000000000..6ecf262dd4 --- /dev/null +++ b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/CMakeLists.txt @@ -0,0 +1,84 @@ +# Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
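+
+# The loop below builds one executable per (MMA_TYPE, QUANT_TYPE) combination from the shared
+# source 10_bmg_grouped_gemm_bf16_f16_s8.cpp, forwarding the two types as compile definitions.
+# For example (illustrative), the 10_bmg_grouped_gemm_bf16_s8 target compiles that source as if
+# it contained:
+#   using ElementInputA = cutlass::int8_t;      // from QUANT_TYPE=int8_t
+#   using ElementInputB = cutlass::bfloat16_t;  // from MMA_TYPE=bfloat16_t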
+ +set(TEST_GROUPS --groups=2) +set(TEST_MODE_1 --groups=2 --mode=1) +set(TEST_MODE_0 --groups=2 --mode=0) +set(TEST_A_NARROW --groups=2 --a_narrower) +set(TEST_GROUP_DEQUANT --g=128) + +set(MMA_T bfloat16_t half_t) +set(QUANT_T int8_t) +set(EXE_LIST "") + +foreach(MMA_TYPE IN LISTS MMA_T) + set(mma_name "10_bmg_grouped_gemm_bf16") + if(${MMA_TYPE} STREQUAL "half_t") + set(mma_name "10_bmg_grouped_gemm_f16") + endif() + foreach(QUANT_TYPE IN LISTS QUANT_T) + set(exe_name "${mma_name}_s8") + if(${MMA_TYPE} STREQUAL "half_t" AND ${QUANT_TYPE} STREQUAL "int8_t") + set(exe_name "${mma_name}_s8_tensorwise") + set(TEST_GROUP_DEQUANT --g=0) + set(TEST_A_NARROW --groups=2) + endif() + + cutlass_example_add_executable( + ${exe_name} + 10_bmg_grouped_gemm_bf16_f16_s8.cpp + TEST_COMMAND_OPTIONS + TEST_GROUPS + TEST_MODE_1 + TEST_MODE_0 + TEST_A_NARROW + TEST_GROUP_DEQUANT + ) + list(APPEND EXE_LIST ${exe_name}) + target_compile_definitions(${exe_name} PRIVATE MMA_TYPE=${MMA_TYPE} QUANT_TYPE=${QUANT_TYPE}) + endforeach() +endforeach() + +cutlass_example_add_executable( + 10_bmg_grouped_gemm_f16_u4 + 10_bmg_grouped_gemm_f16_u4.cpp + TEST_COMMAND_OPTIONS + TEST_GROUPS + TEST_MODE_1 + TEST_MODE_0 + TEST_A_NARROW + TEST_GROUP_DEQUANT +) + +if(NOT DPCPP_SYCL_TARGET STREQUAL "spir64") + # TODO(codeplay): Remove these once IGC block load loop hoisting bug is fixed + foreach(target_exe IN LISTS EXE_LIST) + target_link_options(${target_exe} PRIVATE -Xs "-options \"-igc_opts 'allowDecompose2DBlockFuncs=0'\"" ) + endforeach() + target_link_options(10_bmg_grouped_gemm_f16_u4 PRIVATE -Xs "-options \"-igc_opts 'allowDecompose2DBlockFuncs=0'\"" ) +endif() diff --git a/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/bmg_grouped_gemm_mixed_dtype_runner.hpp b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/bmg_grouped_gemm_mixed_dtype_runner.hpp new file mode 100644 index 0000000000..525fa8ae6d --- /dev/null +++ b/examples/sycl/10_bmg_grouped_gemm_mixed_dtype/bmg_grouped_gemm_mixed_dtype_runner.hpp @@ -0,0 +1,907 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Mixed Precision BMG Grouped Gemm Example Runner +*/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" +#include "helper.h" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +enum GemmMode { + ConvertOnly, + ConvertAndScale, + ConvertAndScaleWithZeroPoint +}; + +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +// Command line options parsing +struct Options { + + bool help; + bool error; + + bool a_narrower; + int mode; + int m, n, k, l, iterations, groups; + int g; + float alpha, beta; + std::vector problem_sizes_host; + + Options(): help(false), error(false), m(5120), n(4096), k(4096), l(1), iterations(20), + g(128), groups(2), mode(2), a_narrower(false), alpha(FLT_MAX), beta(FLT_MAX) { + + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("g", g, 128); + cmd.get_cmd_line_argument("groups", groups, 2); + cmd.get_cmd_line_argument("mode", mode, 2); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + if (cmd.check_cmd_line_flag("a_narrower")) { + a_narrower = true; + } + assert(groups > 0); + problem_sizes_host.clear(); + problem_sizes_host.reserve(groups); + for(int i = 0; i < groups; i++) { + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "BMG Grouped GEMM Mixed Type Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --mode= The mode to run the gemm. 0 is Convert Only, 1 is Convert and Scale, 2 is Convert and Scale with Zero Point\n" + << " --a_narrower If specified, make A the narrower type (B is narrower by default).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = static_cast(2) * static_cast(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +// Factory structs to factor out boilerplate code +namespace helpers{ + using namespace cutlass::gemm; + + template + struct MixedCollectiveMmaBuilder { + + template + using CollectiveMma = collective::CollectiveMma< + DispatchPolicy, TileShape, ElementA, LayoutA, ElementB, LayoutB, TiledMMA, + GmemTiledCopyA, void, void, cute::identity, GmemTiledCopyB, void, void, + cute::identity>; + }; + + template + struct MixedGemmUniversalAdapterBuilder { + template + using GemmUniversalAdapter = + device::GemmUniversalAdapter>; + }; + + template + struct MMAOp; + + template <> + struct MMAOp { + using type = XE_8x16x16_F32BF16BF16F32_TT; + }; + + template <> + struct MMAOp { + using type = XE_8x16x16_F32F16F16F32_TT; + }; + + template + struct RefTiledCopyB; + + template <> + struct RefTiledCopyB { + using type = XE_2D_U16x32x32_LD_V; + }; + + template <> + struct RefTiledCopyB { + using type = XE_2D_U16x16x16_LD_T; + }; +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Gemm +> +struct ExampleRunner { + + using CollectiveMainloop = typename Gemm::CollectiveMainloop; + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + + static constexpr bool AIsNarrower = CollectiveMainloop::IsATransformed; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + using ElementMMA = std::conditional_t; + using ElementQuant = std::conditional_t; + + using ElementScale = typename CollectiveMainloop::NonVoidElementScale; + using ElementZero = typename 
CollectiveMainloop::NonVoidElementZero; + // Scale and Zero share a stride since the layout and shapes must be the same. + using StrideScale = typename CollectiveMainloop::InternalNonVoidStrideScale; + using StrideZero = typename CollectiveMainloop::InternalNonVoidStrideZero; + + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + // + // Data members + // + + // Host-side allocations + std::vector offset_A; + std::vector offset_B; + std::vector offset_S; + std::vector offset_Z; + std::vector offset_C; + std::vector offset_D; + + std::vector stride_A_host; + std::vector stride_B_host; + std::vector stride_S_host; + std::vector stride_Z_host; + std::vector stride_C_host; + std::vector stride_D_host; + + std::vector alpha_host; + std::vector beta_host; + + // Device-side allocations + cutlass::DeviceAllocation problem_sizes; + + /// Initialization + cutlass::DeviceAllocation stride_A; + cutlass::DeviceAllocation stride_B; + cutlass::DeviceAllocation stride_S; + cutlass::DeviceAllocation stride_Z; + cutlass::DeviceAllocation stride_C; + cutlass::DeviceAllocation stride_D; + + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_A_dq; // Dequantized copy of A for validation + cutlass::DeviceAllocation block_B_dq; // Dequantized copy of B for validation + cutlass::DeviceAllocation block_S; + cutlass::DeviceAllocation block_Z; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_A_dq; + cutlass::DeviceAllocation ptr_B_dq; + cutlass::DeviceAllocation ptr_S; + cutlass::DeviceAllocation ptr_Z; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation ptr_ref_D; + + // Note, this is an array of pointers to alpha and beta scaling values per group + cutlass::DeviceAllocation alpha_device; + cutlass::DeviceAllocation beta_device; + cutlass::DeviceAllocation block_alpha; + cutlass::DeviceAllocation block_beta; + + // + // Methods + // + + template + void quantize_tensorwise(const SrcT* d_src, DstT* d_dst, const ElementScale* scale, const ElementZero* zero, size_t size, size_t L) { + SrcT* h_src = new SrcT[size * L]; + ElementScale* scale_h = new ElementScale[L]; + ElementZero* zero_h = new ElementZero[L]; + syclcompat::memcpy(h_src, d_src, size * L * sizeof(SrcT)); + syclcompat::memcpy(scale_h, scale, L * sizeof(ElementScale)); + syclcompat::memcpy(zero_h, zero, L * sizeof(ElementZero)); + + DstT* h_dst = new DstT[size * L]; + for(size_t j = 0; j < L; ++j) { + for (size_t i = 0; i < size; ++i) { + h_dst[i + j * size] = (static_cast(h_src[i + j * size]) - zero_h[j]) * scale_h[j]; + } + } + + syclcompat::memcpy(d_dst, h_dst, size * sizeof(DstT)); + } + + /// Populates a Gemm::Arguments structure from the given commandline options + auto args_from_options(const Options &options, const cutlass::KernelHardwareInfo& hw_info) + { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. 
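+      // In this case the scalar values are used directly: the per-group pointer arrays stay
+      // null and the alpha/beta group stride is 0, broadcasting one alpha/beta to every group.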
+ fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerXeGroup::RasterOrderOptions; + + // Per-GEMM problem shape info may only exist on the device. + return cute::make_tuple(cutlass::gemm::GemmUniversalMode::kGrouped, + typename Gemm::GemmKernel::ProblemShape{options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + fusion_args, hw_info, + typename Gemm::GemmKernel::TileSchedulerArguments{1, RasterOrderOptions::AlongN}); + + } + + bool verify(const Options &options) { + + // + // Compute reference output (default gemm kernel w/ ElementA == ElementB) + // + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = typename helpers::RefTiledCopyB::type; + + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = + typename TiledMMAHelper::type>, Layout, + Layout, Stride<_4, _1, _0>>>::TiledMMA; + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; + + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + + using CollectiveEpilogueRef = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + FusionCallBacks, + XE_2D_U32x8x16_LD_N, + void, void, + typename CollectiveEpilogue::GmemTiledCopyD, + void, void>; + + // Mainloop + using CollectiveMainloopRef = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementMMA, + cutlass::gemm::TagToStrideA_t, + ElementMMA, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + cutlass::KernelHardwareInfo hw_info; + + ElementOutput const epsilon(1e-2f); + ElementOutput const non_zero_floor(1e-4f); + bool passed = false; + + for(int i = 0; i < options.groups; i++) { + Shape problem_size = append<4>(options.problem_sizes_host[i], 1); + auto M = get<0>(problem_size); + auto N = get<1>(problem_size); + auto K = get<2>(problem_size); + using RefStrideA = cutlass::gemm::TagToStrideA_t; + using RefStrideB = cutlass::gemm::TagToStrideB_t; + using RefStrideC = cutlass::gemm::TagToStrideC_t; + RefStrideA stride_a = cutlass::make_cute_packed_stride(RefStrideA{}, {M, K, 1}); + RefStrideB stride_b = 
cutlass::make_cute_packed_stride(RefStrideB{}, {N, K, 1}); + RefStrideC stride_c = cutlass::make_cute_packed_stride(RefStrideC{}, {M, N, 1}); + + // allocate the reference memory + cutlass::DeviceAllocation block_ref_D; + block_ref_D.reset(i == options.groups - 1 ? block_D.size() - offset_D[i] : offset_D[i + 1] - offset_D[i]); + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A_dq.get() + offset_A[i], stride_a, block_B_dq.get() + offset_B[i], stride_b}, + {{alpha_host[i], beta_host[i]}, block_C.get() + offset_C[i], stride_c, block_ref_D.get(), stride_c}, + hw_info + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + syclcompat::wait(); + // compare_reference + passed |= cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get() + offset_D[i], block_ref_D.size(), epsilon, non_zero_floor); + syclcompat::wait(); + } + + return passed; + } + + /// Allocates device-side data + void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + int64_t total_elements_S = 0; + int64_t total_elements_Z = 0; + + // Compute total allocation sizes across group + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + // Offset into block allocation of each matrix base pointer + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_S.push_back(total_elements_S); + offset_Z.push_back(total_elements_Z); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + const int scale_k = options.g == 0 ? 1 : cute::ceil_div(K, options.g); + const int dq_mn_size = options.g == 0 ? 1 : AIsNarrower ? 
M : N; + total_elements_S += (dq_mn_size * scale_k); + total_elements_Z += (dq_mn_size * scale_k); + auto zero_elements_packed_along_k = get<0>(StrideZero{}); + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + auto stride_scale = cutlass::make_cute_packed_stride(StrideScale{}, {dq_mn_size, scale_k, 1}); + stride_S_host.push_back(stride_scale); + stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + auto stride_zero = [&]() { + if constexpr (is_tuple_v(StrideZero{}))>>) { + return make_stride(Int{}, + make_stride(_1{}, static_cast(zero_elements_packed_along_k * dq_mn_size)), + static_cast(dq_mn_size * scale_k)); + } else { + return stride_scale; + } + }(); + stride_Z_host.push_back(stride_zero); + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_A_dq.reset(total_elements_A); + block_B_dq.reset(total_elements_B); + block_S.reset(total_elements_S); + block_Z.reset(total_elements_Z); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); + } + + template + bool initialize_scale( + cutlass::DeviceAllocation& block, + Options const& options) { + + if (options.mode == GemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + } + return true; + } + + template + bool initialize_zero( + cutlass::DeviceAllocation& block, + Options const& options) { + + if (options.mode == GemmMode::ConvertAndScaleWithZeroPoint) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); + } else { + // No bias, so just initialize with 0 so we can use the same kernel to dequantize the data. 
+ std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + } + return true; + } + + template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout, + class ZeroLayout> + static void dequantize_B_int4(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + ZeroLayout const zero_layout, + int const group_size) { + std::vector dst(size(operand_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(dst.data(), (uint8_t*)dq_buffer, dst.size()); + + std::vector src(size(operand_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(src.data(), (uint8_t*)q_buffer, src.size()); + + std::vector scale(size(scale_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(scale.data(), (uint8_t*)scale_buffer, scale.size()); + + std::vector zero(size(zero_layout) * sizeof_bits_v / 8, 0); + cutlass::device_memory::copy_to_host(zero.data(), (uint8_t*)zero_buffer, zero.size()); + + syclcompat::wait(); + + auto dst_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(dst.data())), operand_layout); + + auto src_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + return make_tensor(cute::subbyte_iterator(src.data()), operand_layout); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(src.data())), operand_layout); + } + }(); + + auto scale_tensor = make_tensor(make_gmem_ptr(reinterpret_cast(scale.data())), scale_layout); + + auto zero_tensor = [&]() { + if constexpr (sizeof_bits_v < 8) { + auto flatten_tensor = flatten(make_tensor(cute::subbyte_iterator(zero.data()), zero_layout)); + static_assert(rank(flatten_tensor.layout()) == 4); + return make_tensor(flatten_tensor.data(), select<1, 0, 2, 3>(flatten_tensor.layout())); + } else { + return make_tensor(make_gmem_ptr(reinterpret_cast(zero.data())), zero_layout); + } + }(); + + auto N = size<0>(src_tensor); + auto K = size<1>(src_tensor); + auto L = size<2>(src_tensor); + + for (int l = 0; l < L; l++) { + for (int k= 0; k < K; k++) { + for (int n = 0; n < N; n++) { + using ret_type = cute::conditional_t >= 8, ElementZero, int8_t>; + ret_type a = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return static_cast(src_tensor(n, k, l)); + } else { + return static_cast(src_tensor(n, k, l).get()); + }}(); + + ret_type b = [&]() { + if constexpr (sizeof_bits_v >= 8) { + return static_cast(zero_tensor(n, k / group_size, l)); + } else { + auto zero_elements_packed_along_k = get<0>(zero_tensor.shape()); + return static_cast(zero_tensor((k / group_size) % zero_elements_packed_along_k, n, k / group_size / zero_elements_packed_along_k, l).get()); + } + }(); + + dst_tensor(n, k, l) = ((ElementScale)(a - b)) * scale_tensor(n, k / group_size, l); + } + } + } + + cutlass::device_memory::copy_to_device(dq_buffer, (DequantizedElement*)(raw_pointer_cast(dst_tensor.data())), dst_tensor.size()); + syclcompat::wait(); + } + + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(Options const& options) { + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_A_dq_host(options.groups); + std::vector ptr_B_dq_host(options.groups); + 
std::vector ptr_S_host(options.groups); + std::vector ptr_Z_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + // Compute offsets, alpha & beta over group on host + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_A_dq_host.at(i) = block_A_dq.get() + offset_A.at(i); + ptr_B_dq_host.at(i) = block_B_dq.get() + offset_B.at(i); + ptr_S_host.at(i) = block_S.get() + offset_S.at(i); + ptr_Z_host.at(i) = block_Z.get() + offset_Z.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + // Fill host vector of alpha & beta with random values if using per-group values + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + // Fill host ptr vectors with offset addresses into device alpha/beta blocks + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + // Allocate device memory & copy from host + ptr_A.reset(options.groups); + // Per-group alpha and beta + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_A_dq.reset(options.groups); + ptr_A_dq.copy_from_host(ptr_A_dq_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_B_dq.reset(options.groups); + ptr_B_dq.copy_from_host(ptr_B_dq_host.data()); + + ptr_S.reset(options.groups); + ptr_S.copy_from_host(ptr_S_host.data()); + + ptr_Z.reset(options.groups); + ptr_Z.copy_from_host(ptr_Z_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_S.reset(options.groups); + stride_S.copy_from_host(stride_S_host.data()); + + stride_Z.reset(options.groups); + stride_Z.copy_from_host(stride_Z_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + // Per-group alpha and beta ptrs + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_mixed_dtype_block(block_A, block_A_dq, seed + 2022); + initialize_mixed_dtype_block(block_B, block_B_dq, seed + 2023); + + initialize_block(block_C, seed + 2024); + + initialize_scale(block_S, options); + initialize_zero(block_Z, options); + + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + const int scale_k = options.g == 0 ? 1 : cute::ceil_div(K, options.g); + const int dq_mn_size = options.g == 0 ? 1 : AIsNarrower ? 
M : N; + + auto layout_A = make_layout(make_shape(M, K, 1), stride_A_host.at(i)); + auto layout_B = make_layout(make_shape(N, K, 1), stride_B_host.at(i)); + auto zero_elements_packed_along_k = get<0>(StrideZero{}); + auto shape_scale = cute::make_shape(dq_mn_size, scale_k, 1); + auto stride_scale = cutlass::make_cute_packed_stride(StrideScale{}, shape_scale); + auto shape_zero = [&]() { + if constexpr (is_tuple_v(StrideZero{}))>>) { + return cute::make_shape(dq_mn_size, + cute::make_shape(zero_elements_packed_along_k, cute::max(1, scale_k / zero_elements_packed_along_k)), + 1); + } else { + return shape_scale; + } + }(); + auto stride_zero = [&]() { + if constexpr (is_tuple_v(StrideZero{}))>>) { + return make_stride(Int{}, + make_stride(_1{}, static_cast(zero_elements_packed_along_k * dq_mn_size)), + static_cast(dq_mn_size * scale_k)); + } else { + return stride_scale; + } + }(); + auto layout_scale = make_layout(shape_scale, stride_scale); + auto layout_zero = make_layout(shape_zero, stride_zero); + + // Note that we are overwriting the relevant `block_X_dq` here, both were + // filled by initialize_mixed_dtype_block above + if (options.g != 0) { + if constexpr (AIsNarrower) { + dequantize(block_A_dq.get() + offset_A.at(i), block_A.get() + offset_A.at(i), layout_A, + block_S.get() + offset_S.at(i), block_Z.get() + offset_Z.at(i), layout_scale, layout_zero, + options.g); + } else { + if constexpr (cute::sizeof_bits_v < 8) { + dequantize_B_int4(block_B_dq.get() + offset_B.at(i), block_B.get() + offset_B.at(i), layout_B, + block_S.get() + offset_S.at(i), block_Z.get() + offset_Z.at(i), layout_scale, layout_zero, + options.g); + } else { + dequantize(block_B_dq.get() + offset_B.at(i), block_B.get() + offset_B.at(i), layout_B, + block_S.get() + offset_S.at(i), block_Z.get() + offset_Z.at(i), layout_scale, layout_zero, + options.g); + } + } + } else { + if constexpr (AIsNarrower) { + const size_t size_a = i == options.groups - 1 ? block_A.size() - offset_A[i] : offset_A[i + 1] - offset_A[i]; + quantize_tensorwise( + block_A.get() + offset_A.at(i), + block_A_dq.get() + offset_A.at(i), + block_S.get() + offset_S.at(i), + block_Z.get() + offset_Z.at(i), + size_a, 1 + ); + } else { + const size_t size_b = i == options.groups - 1 ? 
block_B.size() - offset_B[i] : offset_B[i + 1] - offset_B[i]; + quantize_tensorwise( + block_B.get() + offset_B.at(i), + block_B_dq.get() + offset_B.at(i), + block_S.get() + offset_S.at(i), + block_Z.get() + offset_Z.at(i), + size_b, 1 + ); + } + } + } + } + + cutlass::Status run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + allocate(options); + initialize(options); + + auto args_tuple = args_from_options(options, hw_info); + typename Gemm::GemmKernel::Arguments arguments { + get<0>(args_tuple), get<1>(args_tuple), + typename Gemm::GemmKernel::MainloopArguments{ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get(), ptr_S.get(), + stride_S.get(), ptr_Z.get(), stride_Z.get(), options.g}, + typename Gemm::GemmKernel::EpilogueArguments{get<2>(args_tuple), ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + get<3>(args_tuple), get<4>(args_tuple) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (gemm_op.can_implement(arguments) != cutlass::Status::kSuccess){ + std::cout << "Invalid Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::exit(1); + } + + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace.get())); + + // Run the GEMM + CUTLASS_CHECK(gemm_op.run()); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(options); + std::cout << "Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if(!passed) return cutlass::Status::kErrorInternal; + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double cute_average_time = double(cute_time) / double(options.iterations); + double gflops = options.gflops(cute_average_time / 1000.0, options.problem_sizes_host); + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << cute_average_time << " ms" << std::endl; + std::cout << " GFLOPS : " << gflops << std::endl; + } + + return cutlass::Status::kSuccess; + } +}; diff --git a/examples/sycl/CMakeLists.txt b/examples/sycl/CMakeLists.txt index 2ab920ccaf..ae4c080102 100644 --- a/examples/sycl/CMakeLists.txt +++ b/examples/sycl/CMakeLists.txt @@ -38,6 +38,7 @@ if(SYCL_INTEL_TARGET) add_subdirectory(07_bmg_dual_gemm) add_subdirectory(08_bmg_gemm_f8) add_subdirectory(09_bmg_grouped_gemm_f8) + add_subdirectory(10_bmg_grouped_gemm_mixed_dtype) endif() if (CUTLASS_ENABLE_SYCL) diff --git a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp index a5b2a3f249..e8b1709aad 100644 --- a/include/cutlass/epilogue/collective/xe_array_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_array_epilogue.hpp @@ -162,6 +162,7 @@ class CollectiveEpilogue< using TensorC = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideC{})); //(m, n) using TensorD = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideD{})); //(m, n) + using EpilogueTensors = cute::tuple; // Host side epilogue 
arguments struct Arguments { diff --git a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl index 7c6364e028..727623dc5d 100644 --- a/include/cutlass/gemm/collective/builders/xe_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/xe_mma_builder.inl @@ -161,9 +161,8 @@ struct CollectiveBuilder< KernelScheduleType, cute::enable_if_t< cute::is_any_of_v && - cute::is_same_v && - cute::is_any_of_v && - cute::is_any_of_v + cute::is_any_of_v && + cute::is_any_of_v > >{ @@ -175,7 +174,9 @@ struct CollectiveBuilder< static_assert(cute::is_any_of_v, "Intel multi-stage pipeline requires ElementC to be of type float, bfloat or half"); - using MMAAtom = typename pick_mma_atom::atom; + static constexpr bool isAtypeBig = cute::sizeof_bits_v > cute::sizeof_bits_v; + using MMAType = std::conditional_t; + using MMAAtom = typename pick_mma_atom::atom; static constexpr auto tile_M = get<0>(TileShape_MNK{}); static constexpr auto tile_N = get<1>(TileShape_MNK{}); @@ -193,12 +194,17 @@ struct CollectiveBuilder< using KernelSchedule = std::conditional_t, KernelXe, KernelScheduleType>; static constexpr int PipelineStages = IsGroup ? 2 : 3; - using DispatchPolicy = std::conditional_t, - cutlass::gemm::MainloopIntelXeXMX16>; - - using GmemTiledCopyA = decltype(select_copy_atom_16b, false>(tile_M/atoms_M{}, tile_K)); - using GmemTiledCopyB = decltype(select_copy_atom_16b, true>(tile_K, tile_N/atoms_N{})); + using DispatchPolicy = std::conditional_t, + std::conditional_t, cutlass::gemm::MainloopIntelXeXMX16Group, + cutlass::gemm::MainloopIntelXeXMX16GroupMixedPrecision>>; + + static constexpr bool isAtransposed = cute::is_same_v; + static constexpr bool isBtransposed = cute::is_same_v; + using GmemTiledCopyA = std::conditional_t == 8, std::conditional_t, + decltype(select_copy_atom_16b(tile_M/atoms_M{}, tile_K))>; + using GmemTiledCopyB = std::conditional_t == 4, std::conditional_t, + std::conditional_t == 8, std::conditional_t, + decltype(select_copy_atom_16b(tile_K, tile_N/atoms_N{}))>>; // Xe pipeline does not use shared memory using SmemLayoutAtomA = void; @@ -209,12 +215,15 @@ struct CollectiveBuilder< using TransformA = cute::identity; using TransformB = cute::identity; + using ElementA_ = std::conditional_t <= 8, cute::tuple, ElementA>; + using ElementB_ = std::conditional_t <= 8, cute::tuple, ElementB>; + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< DispatchPolicy, TileShape_MNK, - ElementA, + ElementA_, cutlass::gemm::TagToStrideA_t>, - ElementB, + ElementB_, cutlass::gemm::TagToStrideB_t>, TiledMma, GmemTiledCopyA, diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index df97578080..3249cc8be2 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -74,6 +74,7 @@ #include "cutlass/gemm/collective/xe_array_mma.hpp" #include "cutlass/gemm/collective/xe_array_mma_fp8.hpp" #include "cutlass/gemm/collective/xe_mma_mixed_input.hpp" +#include "cutlass/gemm/collective/xe_array_mma_mixed_input.hpp" #include "cutlass/gemm/collective/xe_mma_w8a8.hpp" #include "cutlass/gemm/collective/xe_mma_fp8_scaling.hpp" #endif diff --git a/include/cutlass/gemm/collective/xe_array_mma.hpp b/include/cutlass/gemm/collective/xe_array_mma.hpp index 8b049b0c56..42c2f8fdbd 100644 --- a/include/cutlass/gemm/collective/xe_array_mma.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma.hpp @@ -103,7 +103,7 @@ 
struct CollectiveMma, TileShape_, El using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) - + using MainloopTensors = cute::tuple; // Host side kernel arguments struct Arguments { ElementA const** ptr_A; diff --git a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp b/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp index e9e6b47752..99c4d15699 100644 --- a/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp +++ b/include/cutlass/gemm/collective/xe_array_mma_fp8.hpp @@ -104,7 +104,7 @@ struct CollectiveMma, TileShape_, using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) using TensorNKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideB{})); //(n, k) - + using MainloopTensors = cute::tuple; // Host side kernel arguments struct Arguments { ElementA const** ptr_A; diff --git a/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp b/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp new file mode 100644 index 0000000000..add024e9b8 --- /dev/null +++ b/include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp @@ -0,0 +1,797 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/fp8_to_fp16.h" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class Schedule, + class TileShape_, + class ElementAOptionalTuple, + class StrideA_, + class ElementBOptionalTuple, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopIntelXeXMX16GroupMixedPrecision, + TileShape_, + ElementAOptionalTuple, + StrideA_, + ElementBOptionalTuple, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ +private: + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + +public: + // + // Type Aliases + // + using DispatchPolicy = MainloopIntelXeXMX16GroupMixedPrecision; + using WorkgroupTileShape = TileShape_; + + + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + + using ElementMMA = cute::conditional_t; + using ElementQuant = cute::conditional_t; + + using ElementScale = cute::conditional_t, detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>>; + using StrideScale = cute::conditional_t, detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>>; + + using ElementZero = cute::conditional_t, detail::deduce_mixed_width_dtype_t<3, ElementBOptionalTuple>>; + using StrideZero = cute::conditional_t, detail::deduce_mixed_width_dtype_t<4, ElementBOptionalTuple>>; + + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is void. 
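+  // Summary (illustrative only): the scale / zero element types deduced above determine the
+  // conversion mode computed by get_conversion_mode() further below:
+  //   ElementScale == void                              -> ConversionMode::DirectConvert
+  //   ElementScale != void  &&  ElementZero == void     -> ConversionMode::ConvertAndScale
+  //   ElementScale != void  &&  ElementZero != void     -> ConversionMode::ConvertAndScaleWithZero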
+ using NonVoidElementScale = cute::conditional_t, ElementMMA, ElementScale>; + using NonVoidElementZero = cute::conditional_t, ElementMMA, ElementZero>; + + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideScale>; + using NonVoidStrideZero = cute::conditional_t, cute::Stride<_1, int64_t, int64_t> *, StrideZero>; + using InternalNonVoidStrideScale = cute::remove_pointer_t; + using InternalNonVoidStrideZero = cute::remove_pointer_t; + static constexpr auto zero_elements_packed_along_k = get<0>(InternalNonVoidStrideZero{}); + + using StrideA = StrideA_; + using InternalStrideA = cute::remove_pointer_t; + using StrideB = StrideB_; + using InternalStrideB = cute::remove_pointer_t; + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaType = typename TiledMma::ValTypeA; // ValTypeA and ValTypeB are always same and reflects MMA type on intel Xe + using LargerElementType = std::conditional_t<(cute::sizeof_bits_v > cute::sizeof_bits_v), + ElementA, + ElementB>; + + static_assert(!cute::is_same_v, "Mixed precision GEMM requires different types for A and B!"); + static_assert(std::is_same_v, "Transformation for A is not currently supported on Intel PVC"); + static_assert(std::is_same_v, "Transformation for B is not currently supported on Intel PVC"); + +private: + + static constexpr ConversionMode + get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } + else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } + else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || + KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + static constexpr bool ModeHasScalesZero = KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; +public: + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + + static constexpr auto BLK_M = get<0>(WorkgroupTileShape{}); + static constexpr auto BLK_N = get<1>(WorkgroupTileShape{}); + static constexpr auto BLK_K = get<2>(WorkgroupTileShape{}); + + static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape()); + static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape()); + + static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M); + static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N); + static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K); + using SubgroupTileShape = Shape; + + using GmemTiledCopyScale = typename scale_zero_copy_traits::type; + using GmemTiledCopyZero = typename scale_zero_copy_traits::type; + + static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K; + static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); + + using CopyThreadShape = Shape<_1, Int>; + using CopyThreadShapeRev = 
decltype(cute::reverse(CopyThreadShape{})); + + using traits_load_A = Copy_Traits; + using atom_load_A = Copy_Atom; + using val_layout_load_A = decltype(make_layout(shape_div(typename traits_load_A::BlockShape{}, CopyThreadShape{}))); + using Copy_A = decltype(make_tiled_copy(atom_load_A{}, Layout{}, val_layout_load_A{})); + + using traits_load_B = Copy_Traits; + using atom_load_B = Copy_Atom; + using val_layout_load_B = decltype(make_layout(shape_div(typename traits_load_B::BlockShape{}, CopyThreadShape{}))); + using Copy_B = decltype(make_tiled_copy(atom_load_B{}, Layout{}, val_layout_load_B{})); + + using traits_load_scale = Copy_Traits; + using atom_load_scale = Copy_Atom; + using val_layout_load_scale = decltype(make_layout(shape_div(typename traits_load_scale::BlockShape{}, CopyThreadShapeRev{}))); + using Copy_Scale = decltype(make_tiled_copy(atom_load_scale{}, Layout{}, val_layout_load_scale{})); + + using traits_load_zero = Copy_Traits; + using atom_load_zero = Copy_Atom; + using val_layout_load_zero = decltype(make_layout(shape_div(typename traits_load_zero::BlockShape{}, CopyThreadShapeRev{}))); + using Copy_Zero = decltype(make_tiled_copy(atom_load_zero{}, Layout{}, val_layout_load_zero{})); + + using TensorMKL = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalStrideA{})); //(m, k) + using TensorS = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), InternalNonVoidStrideScale{})); + + // Purpose of this struct is to create a pointer of required type + // for creating the TensorNKL type + template + struct GenPtrType; + + template + struct GenPtrType < 8>> { + static constexpr auto get_pointer() { + // For int4_t type, subbyte_iterator does not accept nullptr, + // so need to create a pointer of type int4_t to get this code + // working. 
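+          // The returned iterator is only consumed inside decltype(...) when forming the
+          // TensorNKL / TensorZ aliases below, so the uninitialized pointer is never
+          // dereferenced; it merely carries the sub-byte element type.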
+ T* ptr; + return cute::subbyte_iterator(ptr); + } + }; + + template + struct GenPtrType >= 8>> { + static constexpr auto get_pointer() { + return make_gmem_ptr(static_cast(nullptr)); + } + }; + + using TensorNKL = decltype(make_tensor(GenPtrType::get_pointer(), make_shape(0,0,0), InternalStrideB{})); //(n, k) + using TensorZ = decltype(make_tensor(GenPtrType::get_pointer(), make_shape(0,0,0), InternalNonVoidStrideScale{})); + using MainloopTensors = cute::tuple; + + // Host side kernel arguments + struct Arguments { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + NonVoidElementScale const** ptr_S = nullptr; + NonVoidStrideScale dS{}; + NonVoidElementZero const** ptr_Z = nullptr; + NonVoidStrideZero dZ{}; + int group_size = 1; + }; + + struct Params { + ElementA const** ptr_A; + StrideA dA; + ElementB const** ptr_B; + StrideB dB; + NonVoidElementScale const** ptr_S = nullptr; + NonVoidStrideScale dS{}; + NonVoidElementZero const** ptr_Z = nullptr; + NonVoidStrideZero dZ{}; + int group_size; + }; + + // + // Methods + // + + CollectiveMma() = default; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + return Params{args.ptr_A, args.dA, args.ptr_B, args.dB, args.ptr_S, args.dS, args.ptr_Z, args.dZ, args.group_size}; + } + + template + static bool + can_implement( + ProblemShape problem_shapes, + Arguments const& args) { + constexpr int copy_alignment_bits = 128; + constexpr int batch_alignment_bits = 512; + auto problem_shape_MNKL = append<4>(problem_shapes, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + + constexpr int min_aligned_elements_A = copy_alignment_bits / sizeof_bits::value; + constexpr int min_aligned_elements_B = copy_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_A = batch_alignment_bits / sizeof_bits::value; + constexpr int min_batch_aligned_elements_B = batch_alignment_bits / sizeof_bits::value; + for (int i = 0; i < problem_shapes.groups(); i++) { + auto problem_shape_MNKL = append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + implementable &= cutlass::detail::check_alignment(cute::make_shape(M,K,L), InternalStrideA{}); + implementable &= cutlass::detail::check_alignment(cute::make_shape(N,K,L), InternalStrideB{}); + + if (L > 1) { + implementable &= get<2>(InternalStrideA{}) % min_batch_aligned_elements_A == 0; + implementable &= get<2>(InternalStrideB{}) % min_batch_aligned_elements_B == 0; + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for XE 2D copy.\n"); + } + + return implementable; + } + + // Helper functions to select packing for conversion + template + struct select_packing { // Naive packing policy + static constexpr auto value() { + return Int, sizeof_bits_v))>{}; + } + }; + + template + CUTLASS_DEVICE typename std::enable_if_t == 4> + transform_quant( + Tensor const& in, + Tensor& out, + Tensor& tCrS_input, + Tensor tCrZ_input + ) { + // TODO (Codeplay): add assert here because int4 is not currently supported + static_assert(!IsATransformed); + + static_assert(is_rmem::value, "Input tensor for conversion must come from registers"); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(std::is_same_v); + + using SrcType = typename EngineIn::value_type; + using DstType = typename 
EngineOut::value_type; + using ZeroType = typename EngineZeros::value_type; + using ScaleType = typename EngineScales::value_type; + + static constexpr auto DPAS = decltype(size<0>(in))::value; + static constexpr auto N = decltype(size<1>(in))::value; + static constexpr auto K = decltype(size<2>(in))::value; + + using format_type = ushort; + static constexpr auto src_bits = sizeof_bits_v; + static constexpr auto scalar = sizeof_bits_v / src_bits; + static constexpr auto loop_cnt = decltype(size(out))::value / N; + static_assert((scalar % N) == 0); + + // for tuning performance + static constexpr auto vec_size = scalar; + static constexpr auto splits = loop_cnt / vec_size; + static_assert(vec_size <= scalar); + + // reshape tensors for easy access + auto s_tensor = make_tensor((format_type*)(raw_pointer_cast(in.data())), Shape, Int>{}); + auto d_tensor = make_tensor(out.data(), Shape, Int, Int>{}); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; n++) { + const auto ts = tCrS_input(n); + const auto tz = [&](){ + if constexpr (sizeof_bits_v >= 8) { + return tCrZ_input(n); + } else { + return tCrZ_input(n).get(); + } + }(); + + auto& src = *(cute::array*)(s_tensor(_, n).data()); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < splits; s++) { + auto idx = vec_size * s / scalar; + auto format_data = src[idx]; + + auto& dst = *(cute::array*)(d_tensor(_, s, n).data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < vec_size; i++) { + auto data = [&](){ + if constexpr (cutlass::platform::numeric_limits::is_signed) { + return static_cast((format_data >> (src_bits * i)) & 0xf); + } else { + return (format_data >> (src_bits * i)) & 0xf; + } + }(); + + if constexpr (ModeHasScales) { + if constexpr (IsATransformed) { + static_assert(dependent_false && "ATransform not support now"); + } else { + using ret_type = cute::conditional_t >= 8, ZeroType, int8_t>; + ret_type minus(data); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + minus = static_cast(data) - static_cast(tz); + } + dst[i] = (static_cast(minus)) * ts; + } + } else { + dst[i] = static_cast(data); + } + } + } + } + } + + /// Utilities to transform A. 
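+  /// Both transform_quant() overloads apply the same per-element dequantization; with ts the
+  /// scale and tz the zero point associated with the element, the value fed to the MMA is:
+  ///   DirectConvert:            convert(in)
+  ///   ConvertAndScale:          convert(in) * ts
+  ///   ConvertAndScaleWithZero:  (convert(in) - tz) * ts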
+ template + CUTLASS_DEVICE typename std::enable_if_t >= 8> + transform_quant( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + Tensor& tCrS_input, + Tensor& tCrZ_input + ) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + if constexpr(cute::is_any_of_v + && cute::is_any_of_v) { + convert_FP8_to_FP16(make_tensor(reinterpret_cast(tCrA_load.data()), tCrA_load.layout()), tCrA_mma); + } else { + auto const& src = tCrA_load(_, _, _); + auto const& dst = tCrA_mma(_, _, _); + auto pSrc = const_cast(raw_pointer_cast(src.data())); + auto pDst = const_cast(raw_pointer_cast(dst.data())); + constexpr int num_elements = decltype(size(src))::value; + + // TODO(Codeplay): (perf) consider replacing `pack` with `num_elements` here - See xe_flash_attn_mma.hpp + constexpr int pack = decltype(select_packing::value())::value; + using Converter = cutlass::NumericArrayConverter; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + constexpr int iters = num_elements / pack; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < iters; ++i) { + SrcArray const* pSrcArr = reinterpret_cast(pSrc) + i; + DstArray* pDstArr = reinterpret_cast(pDst) + i; + *pDstArr = Converter::convert(*pSrcArr); + } + } + + if constexpr (ModeHasScales) { + if constexpr(IsATransformed){ + // The current scale load atom (1x32) gives 2 scale values to + // each thread. All threads need access to all other threads + // scale values, and each scale value is reused twice (unrolled) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 16; ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < 2; ++j) { + auto scale = shfl_sync(0xFFFFFFFF, tCrS_input(j), i); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero){ + auto zero = shfl_sync(0xFFFFFFFF, tCrZ_input(j), i); + tCrA_mma(_, _, 0)[j * 16 + i] -= zero; + tCrA_mma(_, _, 1)[j * 16 + i] -= zero; + } + tCrA_mma(_, _, 0)[j * 16 + i] *= scale; + tCrA_mma(_, _, 1)[j * 16 + i] *= scale; + } + } + } else { + static constexpr auto N = decltype(size<1>(tCrA_load))::value; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < N; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < decltype(size(tCrA_load))::value / N; ++i) { + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero){ + tCrA_mma(_, n, _)[i] -= tCrZ_input(n); + } + tCrA_mma(_, n, _)[i] *= tCrS_input(n); + } + } + } + } + } + +template +CUTLASS_DEVICE auto create_copies(LoadTensors const& load_tensors) { + Copy_A tiled_copy_a{Copy_A{}.with(get<0>(load_tensors))}; + Copy_B tiled_copy_b{Copy_B{}.with(get<1>(load_tensors))}; + + if constexpr(KernelConversionMode == ConversionMode::DirectConvert){ + return cute::make_tuple(tiled_copy_a, tiled_copy_b, Copy_Scale{}, Copy_Zero{}); + } + + Copy_Scale tiled_copy_scale{Copy_Scale{}.with(get<2>(load_tensors))}; + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale){ + return cute::make_tuple(tiled_copy_a, tiled_copy_b, tiled_copy_scale, Copy_Zero{}); + } + + Copy_Zero tiled_copy_zero{Copy_Zero{}.with(get<3>(load_tensors))}; + return cute::make_tuple(tiled_copy_a, tiled_copy_b, tiled_copy_scale, tiled_copy_zero); +} + + /// Perform a subgroup-scoped matrix multiply-accumulate + template + CUTLASS_DEVICE void + operator() ( + FrgTensorD &accum, + TensorA 
gA, + TensorB gB, + FrgTensorC const &src_accum, + KTileIterator k_tile_iter, int k_tile_count, + BlkCoord const &blk_coord, + int const &K_start, + int thread_idx, + Params const& mainloop, + LoadTensors const& load_tensors) + { + static_assert(is_rmem::value, "D tensor must be rmem resident."); + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + auto [tiled_copy_a, tiled_copy_b, tiled_copy_scale, tiled_copy_zero] = create_copies(load_tensors); + + // Partition the copying of A and B tiles across the threads + auto thr_copy_A = tiled_copy_a.get_slice(thread_idx); + auto thr_copy_B = tiled_copy_b.get_slice(thread_idx); + auto thr_copy_scale = tiled_copy_scale.get_slice(thread_idx); + auto thr_copy_zero = tiled_copy_zero.get_slice(thread_idx); + + // Instantiate the MMA object and get thread slice + TiledMma tiled_mma; + auto sg = syclcompat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize; + auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + + // Partition + Tensor tCgA = thr_mma.partition_A(gA); + Tensor tCgB = thr_mma.partition_B(gB); + + // Create fragments + Tensor mma_A = make_tensor(make_fragment_layout(tiled_copy_a, tCgA(_,_,_,0).shape())); + Tensor mma_B = make_tensor(make_fragment_layout(tiled_copy_b, tCgB(_,_,_,0).shape())); + + // If IsATransformed, we need modes M_atom, and M_iter from fragment_A + // layout else we need mode N_iter from fragment_B layout. + static constexpr auto scale_traits_size = decltype(size(typename GmemTiledCopyScale::BlockShape{}))::value / SubgroupSize; + static constexpr auto scale_traits_num = SG_N / size<1>(typename GmemTiledCopyScale::BlockShape{}); + using FragScaleLayout = std::conditional_t>, + Layout, Int, _1>>>; + Tensor fragment_scale_input = make_tensor(FragScaleLayout{}); + + static constexpr auto zero_traits_size = decltype(size(typename GmemTiledCopyZero::BlockShape{}))::value / SubgroupSize; + static constexpr auto zero_traits_num = SG_N * zero_elements_packed_along_k / size<1>(typename GmemTiledCopyZero::BlockShape{}); + using FragZeroLayout = std::conditional_t>, + Layout, Int, _1>>>; + Tensor fragment_zero_input = make_tensor (FragZeroLayout{}); + + // narrow input fragment + Tensor quant_frag = make_tensor( + std::conditional_t{}); + + static_assert(std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + // Retile for copy + auto [frag_copy_A, frag_copy_B] = [&](){ + if constexpr (IsATransformed) { + return std::make_pair(thr_copy_A.retile_D(quant_frag), thr_copy_B.retile_D(mma_B)); + } else { + return std::make_pair(thr_copy_A.retile_D(mma_A), thr_copy_B.retile_D(quant_frag)); + } + }(); + + Tensor copy_tCrS = thr_copy_scale.retile_D(fragment_scale_input); + Tensor copy_tCrZ = thr_copy_zero.retile_D(fragment_zero_input); + + // Retile global tile for copies + Tensor tAgA = thr_copy_A.retile_S(tCgA); + Tensor tBgB = thr_copy_B.retile_S(tCgB); + + auto tiled_prefetch_a = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_a);; + auto tiled_prefetch_b = cute::prefetch_selector,Int>, Num_SGs>(tiled_copy_b);; + auto thr_prefetch_A = tiled_prefetch_a.get_slice(thread_idx); + auto thr_prefetch_B = tiled_prefetch_b.get_slice(thread_idx); + + // Partition global tile for prefetch + auto pAgA = thr_prefetch_A.partition_S(gA); + auto pBgB = thr_prefetch_B.partition_S(gB); + + // + // Mainloop + // + // TODO(Codeplay): Define these coord tensors using proper cute logic + auto [m_idx, n_idx, k_idx, l_idx] = 
blk_coord; + const int m_coord = m_idx * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M; + const int n_coord = n_idx * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N; + const int l_coord = 0; + + Tensor copy_iter_s = [&](){ + if constexpr(IsATransformed){ + return make_tensor(make_inttuple_iter(make_coord(m_coord, 0, l_coord)), + make_layout(make_shape(_2{}, _1{}, _1{}, k_tile_count), + make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); + }else{ + return make_tensor(make_inttuple_iter(make_coord(n_coord, 0, l_coord)), + make_layout(make_shape(Int{}, Int{}, _1{}, k_tile_count), + make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyScale::BlockShape{}), _0{}, E<1>{} * _1{}))); + } + }(); + + Tensor copy_iter_z = [&](){ + if constexpr(IsATransformed){ + return make_tensor(make_inttuple_iter(make_coord(m_coord, 0, l_coord)), + make_layout(make_shape(_2{}, _1{}, _1{}, k_tile_count), + make_stride(E<0>{} * _16{}, E<0>{} * _32{}, _0{}, E<1>{} * _1{}))); + }else{ + return make_tensor(make_inttuple_iter(make_coord(n_coord * zero_elements_packed_along_k, 0, l_coord)), + make_layout(make_shape(Int{}, Int{}, _1{}, k_tile_count), + make_stride(E<0>{} * _16{}, E<0>{} * size<1>(typename GmemTiledCopyZero::BlockShape{}), _0{}, E<1>{} * _1{}))); + } + }(); + + #if CUTLASS_ENABLE_DEBUG_PRINTS + #define PRINT(x) print(#x ": "); print(x); print("\n"); + if (cutlass::thread(LOG_THREAD, LOG_GROUP)) { + print("======================= A: \n"); + PRINT(gA); + PRINT(tCgA); + PRINT(tAgA); + PRINT(mma_A); + PRINT(frag_copy_A); + + print("===================== B :\n"); + PRINT(gB); + PRINT(tCgB); + PRINT(tBgB); + PRINT(mma_B); + PRINT(frag_copy_B); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShape{}); + + PRINT(tiled_prefetch_a); + PRINT(tiled_prefetch_b); + PRINT(pAgA); + PRINT(pBgB); + } + #undef PRINT + #endif + + const int k_start_idx = crd2idx((*k_tile_iter), make_shape(K_start)); + constexpr int barrier_scope = 2; + int prefetch_k = k_start_idx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < DispatchPolicy::Stages; i++, prefetch_k++) { + prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + } + + const int k_reload_factor = mainloop.group_size / BLK_K; + + for (int k_tile = k_start_idx; k_tile < k_tile_count + k_start_idx; k_tile++, prefetch_k++) { + barrier_arrive(barrier_scope); + + // Copy gmem to rmem for the first k_tile + copy(tiled_copy_a, tAgA(_,_,_,k_tile), frag_copy_A); + copy(tiled_copy_b, tBgB(_,_,_,k_tile), frag_copy_B); + + if constexpr(ModeHasScales) { + copy(tiled_copy_scale, copy_iter_s(_, _, _, k_tile / k_reload_factor), copy_tCrS); + } + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + copy(tiled_copy_zero, copy_iter_z(_, _, _, k_tile / k_reload_factor / zero_elements_packed_along_k), copy_tCrZ); + } + + if(prefetch_k < k_tile_count) { + prefetch(tiled_prefetch_a, pAgA(_,_,_,prefetch_k)); + prefetch(tiled_prefetch_b, pBgB(_,_,_,prefetch_k)); + } + + if constexpr (IsATransformed) { + transform_quant(quant_frag, mma_A, fragment_scale_input, + fragment_zero_input); + } else { + if constexpr (ModeHasScalesZero && sizeof_bits_v < 8) { + transform_quant(quant_frag, mma_B, fragment_scale_input, fragment_zero_input((k_tile / k_reload_factor) % zero_traits_size, _, 0)); + } else { + transform_quant(quant_frag, mma_B, fragment_scale_input, fragment_zero_input); + } + } + + cute::gemm(tiled_mma, mma_A, mma_B, accum); + 
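+      // One k-tile is complete: A/B tiles were copied into registers, the narrow operand was
+      // dequantized by transform_quant(), and the partial products were accumulated. The
+      // barrier_wait below pairs with the barrier_arrive at the top of the loop to keep the
+      // subgroups of the work-group in step between k-tiles.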
barrier_wait(barrier_scope); + } + } + + template + CUTLASS_DEVICE MainloopTensors update_tensor_shape_stride( + Params const& mainloop_params, + int32_t const& next_group, + ProblemShape_MNKL const& problem_shape_mnkl) { + const int32_t M = get<0>(problem_shape_mnkl); + const int32_t N = get<1>(problem_shape_mnkl); + const int32_t K = get<2>(problem_shape_mnkl); + + ElementA const* ptr_A_curr_batch = reinterpret_cast(mainloop_params.ptr_A[next_group]); + auto ptr_B_curr_batch = [&]() { + if constexpr (sizeof_bits_v < 8) { + return cute::subbyte_iterator(mainloop_params.ptr_B[next_group]); + } else { + return make_gmem_ptr(static_cast(mainloop_params.ptr_B[next_group])); + } + }(); + + TensorMKL mA_mkl = make_tensor(make_gmem_ptr(ptr_A_curr_batch), make_shape(M, K, static_cast(1)), mainloop_params.dA[next_group]); + TensorNKL mB_nkl = make_tensor(ptr_B_curr_batch, make_shape(N, K,static_cast(1)), mainloop_params.dB[next_group]); + + if constexpr(KernelConversionMode == ConversionMode::DirectConvert){ + return cute::make_tuple(mA_mkl, mB_nkl, TensorS{}, TensorZ{}); + } + + auto scale_k = cute::ceil_div(K, mainloop_params.group_size); + TensorS mScale = make_tensor( + make_gmem_ptr(static_cast(mainloop_params.ptr_S[next_group])), + make_layout(make_shape(IsATransformed ? M : N, scale_k, static_cast(1)), mainloop_params.dS[next_group])); + + if constexpr(KernelConversionMode == ConversionMode::ConvertAndScale){ + return cute::make_tuple(mA_mkl, mB_nkl, mScale, TensorZ{}); + } + + auto ptr_Z = [&]() { + if constexpr (sizeof_bits_v < 8) { + return cute::subbyte_iterator(mainloop_params.ptr_Z[next_group]); + } else { + return make_gmem_ptr(static_cast(mainloop_params.ptr_Z[next_group])); + } + }(); + + TensorZ mZero = make_tensor(ptr_Z, + make_layout(make_shape(zero_elements_packed_along_k * (IsATransformed ? M : N), scale_k / zero_elements_packed_along_k, static_cast(1)), + make_stride(_1{}, static_cast(zero_elements_packed_along_k) * (IsATransformed ? M : N), static_cast(IsATransformed ? 
M : N) * scale_k))); + + return cute::make_tuple(mA_mkl, mB_nkl, mScale, mZero); + } +}; + + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index 42fc5e8d6d..4996cf8749 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -1036,8 +1036,12 @@ template struct MainloopIntelXeXMX16Group : MainloopIntelXeXMX16 { }; -template -struct MainloopIntelXeXMX16MixedPrecision : MainloopIntelXeXMX16 { +template +struct MainloopIntelXeXMX16GroupMixedPrecision : MainloopIntelXeXMX16 { +}; + +template +struct MainloopIntelXeXMX16MixedPrecision : MainloopIntelXeXMX16 { }; template diff --git a/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp b/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp index 25d7dc4591..61bae38de0 100644 --- a/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp +++ b/include/cutlass/gemm/kernel/xe_gemm_array_cooperative.hpp @@ -107,6 +107,9 @@ class GemmUniversal< using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; using SubgroupTileShape = typename CollectiveMainloop::SubgroupTileShape; + using MainloopTensors = typename CollectiveMainloop::MainloopTensors; + using EpilogueTensors = typename CollectiveEpilogue::EpilogueTensors; + // Kernel level shared memory storage struct SharedStorage { using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -252,8 +255,8 @@ class GemmUniversal< int32_t curr_group = -1; using ProblemShapeMNKL = Shape; ProblemShapeMNKL problem_shape_MNKL; - cute::tuple AB_tensors; - cute::tuple CD_tensors; + MainloopTensors AB_tensors; + EpilogueTensors CD_tensors; if (work_tile_info.is_valid()) { curr_group = work_tile_info.L_idx; diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 2ca7397532..35ed379fa9 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -73,6 +73,14 @@ if(CUTLASS_ENABLE_SYCL) xe_gemm_fp16_fp16_fp16_tensor_op_fp32_group_gemm.cpp ) + cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_mixed_dtype_tensorop_xe_group_gemm + xe_gemm_bf16_s8_fp32_tensor_op_fp32_group_gemm.cpp + xe_gemm_fp16_s8_fp32_tensor_op_fp32_group_gemm.cpp + xe_gemm_bf16_u4_fp32_tensor_op_fp32_group_gemm.cpp + xe_gemm_fp16_u4_fp32_tensor_op_fp32_group_gemm.cpp + ) + add_custom_target( cutlass_test_unit_gemm_device DEPENDS @@ -81,6 +89,7 @@ if(CUTLASS_ENABLE_SYCL) cutlass_test_unit_gemm_device_tensorop_epilogue_fusion_xe cutlass_test_unit_gemm_device_mixed_input_tensorop_xe cutlass_test_unit_gemm_device_tensorop_xe_group_gemm + cutlass_test_unit_gemm_device_mixed_dtype_tensorop_xe_group_gemm cutlass_test_unit_gemm_device_tensorop_xe ) @@ -92,6 +101,7 @@ if(CUTLASS_ENABLE_SYCL) test_unit_gemm_device_tensorop_epilogue_fusion_xe test_unit_gemm_device_mixed_input_tensorop_xe test_unit_gemm_device_tensorop_xe_group_gemm + test_unit_gemm_device_mixed_dtype_tensorop_xe_group_gemm ) else() # Dummy targets if not building for Intel diff --git a/test/unit/gemm/device/default_gemm_group_configuration.hpp b/test/unit/gemm/device/default_gemm_group_configuration.hpp index f3a4d54feb..33003e2863 100644 --- a/test/unit/gemm/device/default_gemm_group_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_group_configuration.hpp @@ -61,120 +61,51 @@ struct DefaultGemmGroupConfiguration { }; -// Intel 
XE MMA F32BF16 -template +// Intel XE MMA f16s8f32 +template struct DefaultGemmGroupConfiguration< arch::OpClassTensorOp, arch::IntelXe, - bfloat16_t, LayoutA, - bfloat16_t, LayoutB, + ElementA, LayoutA, + ElementB, LayoutB, float, LayoutC, ElementOutput> { - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = typename TiledMMAHelper, - Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - // A - static constexpr int kAlignmentA = 32; - using DefaultOperandA = cutlass::gemm::device::detail::DefaultGemm_TensorOpXe_OperandA< - bfloat16_t, LayoutA, kAlignmentA, 32>; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - // B - static constexpr int kAlignmentB = 32; - using DefaultOperandB = cutlass::gemm::device::detail::DefaultGemm_TensorOpXe_OperandB< - bfloat16_t, LayoutB, kAlignmentB, 32>; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - using EpilogueOp = epilogue::fusion::LinearCombination; - - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< - epilogue::IntelXeXMX16, - EpilogueOp, - TileShape, - decltype(tile_shape(TiledMma())) - >; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, - TileShape, Shape<_1, _1, _1>, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, LayoutC, 1, - ElementOutput, LayoutC, 1, - epilogue::IntelXeXMX16Group, - EpilogueOp - >::CollectiveOp; + static_assert(cute::is_any_of_v, "ElementA needs to be of 16 or 8 bit type"); + static_assert(cute::is_any_of_v, "ElementB needs to be of 16, 8 or 4 bit type"); + using TileShape = Shape<_256, _256, _32>; - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, - cute::bfloat16_t, LayoutA, 1, - cute::bfloat16_t, LayoutB, 1, + using CollectiveMainloop = typename gemm::collective::CollectiveBuilder< + arch::IntelXe, arch::OpClassTensorOp, + ElementA, LayoutA, 1, + ElementB, LayoutB, 1, float, TileShape, Shape<_1, _1, _1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelXePtrArrayCooperative + gemm::collective::StageCountAuto, + gemm::KernelXePtrArrayCooperative >::CollectiveOp; -}; - -// Intel XE MMA F32F16 -template -struct DefaultGemmGroupConfiguration< - arch::OpClassTensorOp, arch::IntelXe, - half_t, LayoutA, - half_t, LayoutB, - float, LayoutC, - ElementOutput> -{ - using TileShape = Shape<_256, _256, _32>; - - using TiledMma = typename TiledMMAHelper, - Layout, - Layout, Stride<_4, _1, _0>>>::TiledMMA; - - // A - static constexpr int kAlignmentA = 32; - using DefaultOperandA = cutlass::gemm::device::detail::DefaultGemm_TensorOpXe_OperandA< - half_t, LayoutA, kAlignmentA, 32>; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - // B - static constexpr int kAlignmentB = 32; - using DefaultOperandB = cutlass::gemm::device::detail::DefaultGemm_TensorOpXe_OperandB< - half_t, LayoutB, kAlignmentB, 32>; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + using TiledMma = typename CollectiveMainloop::TiledMma; using EpilogueOp = epilogue::fusion::LinearCombination; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< - epilogue::IntelXeXMX16, + using FusionCallBacks = epilogue::fusion::FusionCallbacks< + epilogue::IntelXeXMX16Group, EpilogueOp, TileShape, decltype(tile_shape(TiledMma())) >; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - 
cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, + using CollectiveEpilogue = typename epilogue::collective::CollectiveBuilder< + arch::IntelXe, arch::OpClassTensorOp, TileShape, Shape<_1, _1, _1>, - cutlass::epilogue::collective::EpilogueTileAuto, + epilogue::collective::EpilogueTileAuto, float, float, float, LayoutC, 1, ElementOutput, LayoutC, 1, epilogue::IntelXeXMX16Group, EpilogueOp >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::IntelXe, cutlass::arch::OpClassTensorOp, - cute::bfloat16_t, LayoutA, 1, - cute::bfloat16_t, LayoutB, 1, - float, - TileShape, Shape<_1, _1, _1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelXePtrArrayCooperative - >::CollectiveOp; }; diff --git a/test/unit/gemm/device/xe_gemm_bf16_s8_fp32_tensor_op_fp32_group_gemm.cpp b/test/unit/gemm/device/xe_gemm_bf16_s8_fp32_tensor_op_fp32_group_gemm.cpp new file mode 100644 index 0000000000..81560d044b --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_bf16_s8_fp32_tensor_op_fp32_group_gemm.cpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for Xe Group bf16_s8_fp32 +*/ + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" + +#include "default_gemm_configuration.hpp" +#include "default_gemm_group_configuration.hpp" +#include "gemm_testbed_3x_ptr_array.hpp" + +namespace cutlass { +namespace { +template +struct XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm { + using ProblemShape = gemm::GroupProblemShape>; // per group + using ElementA = cute::bfloat16_t; + using ElementB = cute::int8_t; + using ElementC = float; + using ElementAccumulator = float; + using LayoutC = layout::RowMajor; + + using Config = gemm::device::DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, arch::IntelXe, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>; + + using GemmKernel = gemm::kernel::GemmUniversal< + ProblemShape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue, + gemm::GroupScheduler + >; + + using Gemm = gemm::device::GemmUniversalAdapter; +}; + +TEST(XE_Device_Gemm_bf16t_s8t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16n_s8t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16t_s8n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16n_s8n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} +} +} // namespace cutlass diff --git a/test/unit/gemm/device/xe_gemm_bf16_u4_fp32_tensor_op_fp32_group_gemm.cpp b/test/unit/gemm/device/xe_gemm_bf16_u4_fp32_tensor_op_fp32_group_gemm.cpp new file mode 100644 index 0000000000..9f1a040765 --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_bf16_u4_fp32_tensor_op_fp32_group_gemm.cpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for Xe Group bf16_u4_fp32 +*/ + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" + +#include "default_gemm_configuration.hpp" +#include "default_gemm_group_configuration.hpp" +#include "gemm_testbed_3x_ptr_array.hpp" + +namespace cutlass { +namespace { +template +struct XE_Device_Gemm_bf16_u4_f32_tensor_op_f32_group_gemm { + using ProblemShape = gemm::GroupProblemShape>; // per group + using ElementA = cute::bfloat16_t; + using ElementB = cute::uint4_t; + using ElementC = float; + using ElementAccumulator = float; + using LayoutC = layout::RowMajor; + + using Config = gemm::device::DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, arch::IntelXe, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>; + + using GemmKernel = gemm::kernel::GemmUniversal< + ProblemShape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue, + gemm::GroupScheduler + >; + + using Gemm = gemm::device::GemmUniversalAdapter; +}; + +TEST(XE_Device_Gemm_bf16t_u4t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_bf16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16n_u4t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_bf16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16t_u4n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_bf16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_bf16n_u4n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_bf16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + 
EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} +} +} // namespace cutlass diff --git a/test/unit/gemm/device/xe_gemm_fp16_s8_fp32_tensor_op_fp32_group_gemm.cpp b/test/unit/gemm/device/xe_gemm_fp16_s8_fp32_tensor_op_fp32_group_gemm.cpp new file mode 100644 index 0000000000..1a298ea289 --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_fp16_s8_fp32_tensor_op_fp32_group_gemm.cpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Tests for Xe Group fp16_s8_fp32 +*/ + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" + +#include "default_gemm_configuration.hpp" +#include "default_gemm_group_configuration.hpp" +#include "gemm_testbed_3x_ptr_array.hpp" + +namespace cutlass { +namespace { +template +struct XE_Device_Gemm_fp16_s8_f32_tensor_op_f32_group_gemm { + using ProblemShape = gemm::GroupProblemShape>; // per group + using ElementA = cute::half_t; + using ElementB = cute::int8_t; + using ElementC = float; + using ElementAccumulator = float; + using LayoutC = layout::RowMajor; + + using Config = gemm::device::DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, arch::IntelXe, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>; + + using GemmKernel = gemm::kernel::GemmUniversal< + ProblemShape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue, + gemm::GroupScheduler + >; + + using Gemm = gemm::device::GemmUniversalAdapter; +}; + +TEST(XE_Device_Gemm_fp16t_s8t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_fp16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16n_s8t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_fp16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16t_s8n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_fp16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16n_s8n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_fp16_s8_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} +} +} // namespace cutlass diff --git a/test/unit/gemm/device/xe_gemm_fp16_u4_fp32_tensor_op_fp32_group_gemm.cpp b/test/unit/gemm/device/xe_gemm_fp16_u4_fp32_tensor_op_fp32_group_gemm.cpp new file mode 100644 index 0000000000..07d7966902 --- /dev/null +++ b/test/unit/gemm/device/xe_gemm_fp16_u4_fp32_tensor_op_fp32_group_gemm.cpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for Xe Group fp16_u4_fp32 +*/ + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" + +#include "default_gemm_configuration.hpp" +#include "default_gemm_group_configuration.hpp" +#include "gemm_testbed_3x_ptr_array.hpp" + +namespace cutlass { +namespace { +template +struct XE_Device_Gemm_fp16_u4_f32_tensor_op_f32_group_gemm { + using ProblemShape = gemm::GroupProblemShape>; // per group + using ElementA = cute::half_t; + using ElementB = cute::uint4_t; + using ElementC = float; + using ElementAccumulator = float; + using LayoutC = layout::RowMajor; + + using Config = gemm::device::DefaultGemmGroupConfiguration< + arch::OpClassTensorOp, arch::IntelXe, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementAccumulator>; + + using GemmKernel = gemm::kernel::GemmUniversal< + ProblemShape, + typename Config::CollectiveMainloop, + typename Config::CollectiveEpilogue, + gemm::GroupScheduler + >; + + using Gemm = gemm::device::GemmUniversalAdapter; +}; + +TEST(XE_Device_Gemm_fp16t_u4t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_fp16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16n_u4t_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::RowMajor; + using Gemm = XE_Device_Gemm_fp16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16t_u4n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::RowMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_fp16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 0.0)); +} + +TEST(XE_Device_Gemm_fp16n_u4n_f32t_tensor_op_f32_group_gemm, 256x256x32) { + using LayoutA = layout::ColumnMajor; + using LayoutB = layout::ColumnMajor; + using Gemm = XE_Device_Gemm_fp16_u4_f32_tensor_op_f32_group_gemm::Gemm; + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 1.0)); + EXPECT_TRUE(test::gemm::device::TestAll(1.0, 
0.0)); +} +} +} // namespace cutlass
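Taken together, the new unit tests show how the pieces added by this patch compose into a runnable mixed-dtype group GEMM. The following is a minimal sketch of that composition for bf16 A, s8 B, f32 accumulate/output with row-major operands; names and template-argument order follow the test files above (it is illustrative, not an additional file in the patch):

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"

#include "default_gemm_group_configuration.hpp"   // test-local helper introduced by this patch

// Per-group problem shape (M, N, K), as used by the group-GEMM tests.
using ProblemShape = cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;

// Wider operand A (bf16), narrow operand B (s8, converted to bf16 inside the mainloop),
// f32 C and f32 accumulator, mirroring XE_Device_Gemm_bf16_s8_f32_tensor_op_f32_group_gemm.
using Config = cutlass::gemm::device::DefaultGemmGroupConfiguration<
    cutlass::arch::OpClassTensorOp, cutlass::arch::IntelXe,
    cute::bfloat16_t, cutlass::layout::RowMajor,
    cute::int8_t,     cutlass::layout::RowMajor,
    float,            cutlass::layout::RowMajor,
    float>;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    ProblemShape,
    typename Config::CollectiveMainloop,   // mixed-dtype group mainloop selected by xe_mma_builder.inl
    typename Config::CollectiveEpilogue,
    cutlass::gemm::GroupScheduler>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// The unit tests above then exercise this type through the grouped-GEMM testbed,
// presumably as test::gemm::device::TestAll<Gemm>(alpha, beta).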