diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 08aed0cb296a2..301fb0fbe82b0 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -859,6 +859,10 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) if (onnxruntime_USE_CUDA) + include(cuda_configuration) + setup_cuda_compiler() + setup_cuda_architectures() + enable_language(CUDA) message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") @@ -878,9 +882,6 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) endif() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) - message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4") - endif() if (WIN32) message( STATUS "Lean Attention unsupported in Windows") set(onnxruntime_USE_LEAN_ATTENTION OFF) @@ -1590,25 +1591,17 @@ if (onnxruntime_USE_CUDA) file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) endif() find_package(CUDAToolkit REQUIRED) - if (NOT CMAKE_CUDA_ARCHITECTURES) - # Note that we generate SASS+PTX code for specified cuda architectures by assigning "xy" - # To add SASS only, assign "xy-real" - # To add PTX only, assign "xy-virtual" - if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") - # Support for Jetson/Tegra ARM devices - set(CMAKE_CUDA_ARCHITECTURES "53-real;62-real;72-real;87") # TX1/Nano, TX2, Xavier, Orin - else() - if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) - # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version. 
- set(CMAKE_CUDA_ARCHITECTURES "37-real;50-real;52-real;60-real;70-real;75-real;80-real;86-real;89") - elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8) - set(CMAKE_CUDA_ARCHITECTURES "52-real;60-real;70-real;75-real;80-real;86-real;89-real;90") - else() - # https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html - set(CMAKE_CUDA_ARCHITECTURES "all") # Supporting all, including latest Blackwell B series & RTX 50 series - endif() - endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8) + add_definitions("-DENABLE_FP8") + message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag") endif() + + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) + add_definitions("-DENABLE_FP4") + message(STATUS "CUDA Toolkit version is greater or equal than 12.8, enable -DENABLE_FP4 flag") + endif() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch") diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake new file mode 100644 index 0000000000000..ef94ec25132e3 --- /dev/null +++ b/cmake/external/cuda_configuration.cmake @@ -0,0 +1,172 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations under the License. +# + +macro(setup_cuda_compiler) + # Determine CUDA version before enabling the language extension check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER + # if CMAKE_CUDA_COMPILER is not set + include(CheckLanguage) + if(NOT CMAKE_CUDA_COMPILER AND CMAKE_CUDA_HOST_COMPILER) + set(CMAKE_CUDA_HOST_COMPILER_BACKUP ${CMAKE_CUDA_HOST_COMPILER}) + endif() + check_language(CUDA) + if(CMAKE_CUDA_HOST_COMPILER_BACKUP) + set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CUDA_HOST_COMPILER_BACKUP}) + check_language(CUDA) + endif() + if(CMAKE_CUDA_COMPILER) + message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}") + if(NOT WIN32) # Linux + execute_process( + COMMAND "bash" "-c" "${CMAKE_CUDA_COMPILER} --version | grep -E -o 'V[0-9]+.[0-9]+.[0-9]+' | cut -c2-" + RESULT_VARIABLE _BASH_SUCCESS + OUTPUT_VARIABLE CMAKE_CUDA_COMPILER_VERSION + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT _BASH_SUCCESS EQUAL 0) + message(FATAL_ERROR "Failed to determine CUDA version") + endif() + + else() # Windows + execute_process( + COMMAND ${CMAKE_CUDA_COMPILER} --version + OUTPUT_VARIABLE versionString + RESULT_VARIABLE versionResult) + + if(versionResult EQUAL 0 AND versionString MATCHES "V[0-9]+\\.[0-9]+\\.[0-9]+") + string(REGEX REPLACE "V" "" version ${CMAKE_MATCH_0}) + set(CMAKE_CUDA_COMPILER_VERSION "${version}") + else() + message(FATAL_ERROR "Failed to determine CUDA version") + endif() + endif() + else() + message(FATAL_ERROR "No CUDA compiler found") + endif() + + set(CUDA_REQUIRED_VERSION "11.4") + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION) + message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}") + endif() +endmacro() + +macro(setup_cuda_architectures) + # cmake-format: off + # Initialize and normalize CMAKE_CUDA_ARCHITECTURES before enabling CUDA. 
+ # Special values: + # (1) `native` is resolved to HIGHEST available architecture. Fallback to `all` if detection failed. + # (2) `all` / `all-major` / unset is resolved to a default set of architectures we optimized and compiler supports. + # Numerical architectures: + # * For `-virtual` architectures, the last one is kept as it is, and the others are ignored. + # * `-real` suffix is automatically added for other cases. + # * Always use accelerated (`-a` suffix) target for supported real architectures. + # cmake-format: on + + if(CMAKE_CUDA_ARCHITECTURES STREQUAL "native") + # Detect highest available compute capability + set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch) + set(CUDAFILE ${CMAKE_SOURCE_DIR}/utils/detect_cuda_arch.cu) + execute_process(COMMAND ${CMAKE_CUDA_COMPILER} -lcuda ${CUDAFILE} -o ${OUTPUTFILE}) + message(VERBOSE "Detecting native CUDA compute capability") + execute_process( + COMMAND ${OUTPUTFILE} + RESULT_VARIABLE CUDA_RETURN_CODE + OUTPUT_VARIABLE CUDA_ARCH_OUTPUT) + if(NOT ${CUDA_RETURN_CODE} EQUAL 0) + message(WARNING "Detecting native CUDA compute capability - fail") + message(WARNING "CUDA compute capability detection failed, compiling for all optimized architectures") + unset(CMAKE_CUDA_ARCHITECTURES) + else() + message(STATUS "Detecting native CUDA compute capability - done") + set(CMAKE_CUDA_ARCHITECTURES "${CUDA_ARCH_OUTPUT}") + endif() + elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all") + unset(CMAKE_CUDA_ARCHITECTURES) + message(STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all enables a list of architectures OnnxRuntime optimized for, " + "not all architectures CUDA compiler supports.") + elseif(CMAKE_CUDA_ARCHITECTURES STREQUAL "all-major") + unset(CMAKE_CUDA_ARCHITECTURES) + message( + STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all-major enables a list of architectures OnnxRuntime optimized for, " + "not all major architectures CUDA compiler supports.") + else() + message(STATUS "Original CMAKE_CUDA_ARCHITECTURES : 
${CMAKE_CUDA_ARCHITECTURES}") + endif() + + if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if(CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu") + # Support for Jetson/Tegra ARM devices + set(CMAKE_CUDA_ARCHITECTURES "53;62;72;87") # TX1/Nano, TX2, Xavier, Orin + else() + if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12) + # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version. + set(CMAKE_CUDA_ARCHITECTURES "37;50;52;60;70;75;80;86;89") + elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8) + set(CMAKE_CUDA_ARCHITECTURES "52;60;70;75;80;86;89;90") + else() + set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;89;90;100;120") + endif() + endif() + endif() + + unset(CMAKE_CUDA_ARCHITECTURES_CLEAN) + unset(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if(CUDA_ARCH STREQUAL "") + continue() + endif() + + if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$") + set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH}) + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$") + list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$") + list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) + else() + message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}") + endif() + endforeach() + list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN) + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_CLEAN}) + + # CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without automatically added -real or -a suffix. 
+ set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}") + + set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120") + foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS) + if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}") + message(STATUS "Excluding SM ${CUDA_ARCH}") + endif() + endforeach() + + # Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90. + set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120") + unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED) + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL) + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real") + else() + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real") + endif() + endforeach() + + if(DEFINED CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL) + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL}") + endif() + + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NORMALIZED}) + + message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") +endmacro() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 6a7510a5d83bc..da46f29dacf5f 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -179,7 +179,7 @@ set(onnxruntime_NVCC_THREADS "1" CACHE STRING "Number of threads that NVCC can use for compilation.") target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() - + # Since CUDA 12.8, compiling diagnostics become stricter if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) target_compile_options(${target} PRIVATE "$<$:--relocatable-device-code=true>") @@ -261,6 +261,11 @@ set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) 
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") + if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) + target_compile_options(${target} PRIVATE $<$:-Xptxas=-w>) + target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS) + endif() + if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling target_link_libraries(${target} PRIVATE CUDA::cupti) endif() diff --git a/cmake/utils/detect_cuda_arch.cu b/cmake/utils/detect_cuda_arch.cu new file mode 100644 index 0000000000000..83fbc13dbff7f --- /dev/null +++ b/cmake/utils/detect_cuda_arch.cu @@ -0,0 +1,39 @@ +#include +#include +#include +#include +#include + +int main(int argc, char* argv[]) +{ + int n_devices = 0; + int rc = cudaGetDeviceCount(&n_devices); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + + std::vector> arch(n_devices); + for (int cd = 0; cd < n_devices; ++cd) + { + cudaDeviceProp dev; + int rc = cudaGetDeviceProperties(&dev, cd); + if (rc != cudaSuccess) + { + cudaError_t error = cudaGetLastError(); + std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl; + return rc; + } + else + { + arch[cd] = {dev.major, dev.minor}; + } + } + + std::pair best_cc = *std::max_element(begin(arch), end(arch)); + std::cout << best_cc.first << best_cc.second; + + return 0; +} diff --git a/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h new file mode 100644 index 0000000000000..06442c6e02ae0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/cuda_runtime_utils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace onnxruntime::llm::common { +inline int getDevice() { + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + return deviceID; +} + +inline int getSMVersion() { + int device{-1}; + CUDA_CALL_THROW(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMultiProcessorCount() { + int nSM{0}; + int deviceID{0}; + CUDA_CALL_THROW(cudaGetDevice(&deviceID)); + CUDA_CALL_THROW(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID)); + return nSM; +} +} // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/common/logger.h b/onnxruntime/contrib_ops/cuda/llm/common/logger.h new file mode 100644 index 0000000000000..a3992e751926d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/logger.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/shared_library/provider_api.h" + +#ifndef NDEBUG +#define ORT_LLM_LOG_TRACE(msg) LOGS_DEFAULT(VERBOSE) << msg +#define ORT_LLM_LOG_DEBUG(msg) LOGS_DEFAULT(VERBOSE) << msg +#else +#define ORT_LLM_LOG_TRACE(msg) +#define ORT_LLM_LOG_DEBUG(msg) +#endif + +#define ORT_LLM_LOG_INFO(msg) LOGS_DEFAULT(INFO) << msg +#define ORT_LLM_LOG_WARNING(msg) LOGS_DEFAULT(WARNING) << msg +#define ORT_LLM_LOG_ERROR(msg) LOGS_DEFAULT(ERROR) << msg diff --git a/onnxruntime/contrib_ops/cuda/llm/common/workspace.h b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h new file mode 100644 index 0000000000000..126884a941336 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/common/workspace.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include + +namespace onnxruntime::llm::common { + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) { + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) { + addr += to - addr % to; + } + return reinterpret_cast(addr); +} + +constexpr size_t alignSize(size_t size, size_t to) { + if ((size % to) != 0U) { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) { + uintptr_t addr = (uintptr_t)ptr; + addr += previousWorkspaceSize; + return alignPtr(reinterpret_cast(addr), alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) { + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? 
nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) { + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) { + size_t total = 0; + for (int i = 0; i < count; i++) { + total += workspaces[i]; + if (workspaces[i] % alignment) { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace onnxruntime::llm::common diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000000..6de056b44339d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator { + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator { + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. 
+template +struct DetagOperator { + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator { + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000000..63dca2f458e1a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +template +inline int compute_occupancy_for_kernel() { + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + CUDA_CALL_THROW(cudaGetDevice(&device)); + CUDA_CALL_THROW( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } else { + CUDA_CALL_THROW(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. 
+ return 0; + } + + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } else { + CUDA_CALL_THROW(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) { + CUDA_CALL_THROW( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } else { + CUDA_CALL_THROW(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 0000000000000..e0911460ef8a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief Functor performing linear combination with a maximum operation used by epilogues. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000000..1d7ff42d591e2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. + * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +struct EpilogueOpBiasSilu { +}; + +struct EpilogueOpBiasReLU { +}; + +struct EpilogueOpBiasFtGelu { +}; + +struct EpilogueOpBias { +}; + +struct EpilogueOpDefaultSilu { +}; + +struct EpilogueOpDefaultReLU { +}; + +struct EpilogueOpDefaultFtGelu { +}; + +struct EpilogueOpDefault { +}; + +template +struct Epilogue { + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = 
cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl new file mode 100644 index 0000000000000..a7146d99224eb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_RS Mixed Scaled GEMM +template +struct CollectiveBuilderInterleaved + || cute::is_same_v + || cute::is_same_v)>> +{ + +private: + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementPairA_>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementPairB_>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementPairA_>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementPairB_>; + static constexpr bool NeitherIsTuple + = !cute::is_tuple::value && !cute::is_tuple::value; + +public: + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementPairA_>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementPairB_>; + static_assert(cute::is_tuple::value ^ cute::is_tuple::value + || (NeitherIsTuple && (sizeof_bits::value != sizeof_bits::value)), + "Either A OR B must be a tuple or the widths of A and B must be different."); + + static constexpr bool IsANarrow = sizeof_bits::value < sizeof_bits::value; + + using GmemLayoutATag = GmemLayoutATag_; + using GmemLayoutBTag = GmemLayoutBTag_; + + using ElementPairA = cute::conditional_t, ElementPairA_>; + using ElementPairB = cute::conditional_t, ElementPairB_>; + + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + + static_assert(is_static::value); + 
static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B(); + static_assert(!IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. + static constexpr bool SwapAB = !IsATransformed; + + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? GmmaMajorA : GmmaMajorB; + + using ElementMma = cute::conditional_t; + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma + = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB + = decltype(detail::rs_smem_selector(TileShape_MNK{})), + decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_single_affine_transformed_input(StageCountType{}); + + using SmemCopyAtomA = 
cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using DispatchPolicy + = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. + using StrideA = TagToStrideA_t; + using StrideB = TagToStrideB_t; + + using CollectiveOp = CollectiveMmaInterleaved; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp new file mode 100644 index 0000000000000..97feaa2498bba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp" + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveBuilderInterleaved { + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_interleaved.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp new file mode 100644 index 0000000000000..ce56a9d717ceb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_mma_interleaved.hpp @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct CollectiveMmaInterleaved { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000000..499504439aa46 --- /dev/null +++ 
b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/sm90_mma_interleaved_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,1372 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template +struct CollectiveMmaInterleaved, + TileShape_, ElementAOptionalTuple, StrideA_, ElementBOptionalTuple, StrideB_, TiledMma_, GmemTiledCopyA_, + SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_> { + private: + template + static constexpr auto get_logical_ptr(PointerType const* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return subbyte_iterator(ptr); + } else { + return ptr; + } + } + + template + static constexpr auto get_smem_interleave_layout() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout(TileShape{})), Shape<_4, _4, _2, _4>>, + Stride<_128, Stride<_1, _8, _4, _32>>>{}; + } else if 
constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _4, _2>>, + Stride<_64, Stride<_1, _8, _2, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout(TileShape{})), Shape<_2, _4, _2, _4>>, + Stride<_64, Stride<_1, _4, _2, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + enum class ConversionMode { + DirectConvert, + ConvertAndScale, + ConvertAndScaleWithZero + }; + + using ScaleA = detail::deduce_mixed_width_dtype_t<1, ElementAOptionalTuple>; + using ScaleB = detail::deduce_mixed_width_dtype_t<1, ElementBOptionalTuple>; + using ZeroA = detail::deduce_mixed_width_dtype_t<2, ElementAOptionalTuple>; + using ZeroB = detail::deduce_mixed_width_dtype_t<2, ElementBOptionalTuple>; + + public: + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + using TileShape = TileShape_; + + static_assert(cute::is_tuple::value ^ cute::is_tuple::value, + "Either A OR B must be a tuple. It must take the from {ElementOperand, [ElementScale]," + "[ElementZero]}. Inputs in [] are optional."); + + using ElementA = detail::deduce_mixed_width_dtype_t<0, ElementAOptionalTuple>; + using ElementB = detail::deduce_mixed_width_dtype_t<0, ElementBOptionalTuple>; + static constexpr bool IsATransformed = cute::is_tuple::value; + using ElementScale = cute::conditional_t; + using ElementZero = cute::conditional_t; + // For cases where we can't have a void type, we can use this to allow the code to compile when the scale / zero is + // void. 
+ using NonVoidElementScale = cute::conditional_t, float, ElementScale>; + using NonVoidElementZero = cute::conditional_t, float, ElementZero>; + + using StrideA = StrideA_; + using StrideB = StrideB_; + // These are always MN major + using StrideScale = cute::Stride, int64_t, int64_t>; + // For cases where we can't have a void scale, we can use this to allow the code to compile when the scale is void. + using NonVoidStrideScale = cute::conditional_t, cute::Stride<_1, int64_t, int64_t>, StrideScale>; + + static_assert((IsATransformed && cutlass::gemm::detail::is_k_major()) || (!IsATransformed && cutlass::gemm::detail::is_k_major()), + "The transformed type must be K-major."); + + static_assert((IsATransformed && (sizeof(ElementB) == 2)) || (!IsATransformed && (sizeof(ElementA) == 2)) || (cutlass::gemm::detail::is_k_major() && cutlass::gemm::detail::is_k_major()), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + + static_assert(cutlass::gemm::detail::is_mn_major(), + "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using GmemTiledCopyScale = cute::SM90_TMA_LOAD; + + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + // Scale layout atom set after swapping. 
+ + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using SmemCopyAtomScale = Copy_Atom; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealInternalElementA = cute::conditional_t; + using RealInternalElementB = cute::conditional_t; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + // One threads per CTA are producers (1 for operand tile) + static constexpr int NumProducerThreadEvents = 1; + + using SmemLayoutAtomScale = Layout(InternalSmemLayoutAtomA{})), cute::Int<1>>>; + using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}), shape<1>(SmemLayoutAtomScale{}))); + static constexpr int type_factor = sizeof_bits::value / sizeof_bits::value; + + 
static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, + "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomScale{}) == 2, "SmemLayoutAtomScale must be rank 2"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must equal the tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, + "SmemLayoutAtomScale must evenly divide tile k shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape(InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + using Layout_Interleave = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + get_smem_interleave_layout())); + using SmemLayoutA_mma_interleave = decltype(tile_to_shape(Layout_Interleave{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideA>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + using SmemLayoutA_mma = decltype(cute::composition(SmemLayoutA{}.layout_a(), SmemLayoutA{}.offset(), + make_layout(make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + make_stride(get<2>(TileShape{}), _1{}, get<0>(TileShape{}) * get<2>(TileShape{}))))); + // cute::conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), + // Stride<_1, cute::Int(TileShape{})>, cute::Int(TileShape{}) * + // get<2>(TileShape{})>>, Stride(TileShape{})>, _1, + // cute::Int(TileShape{}) * get<2>(TileShape{})>>>{}))); + + using SmemLayoutB = decltype(tile_to_shape(InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, InternalStrideB>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + // It is assumed that the scales and zero-points share the same smem layout + using SmemLayoutScale = decltype(tile_to_shape(SmemLayoutAtomScale{}, + make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}), Int{}), + cute::conditional_t<::cutlass::gemm::detail::is_major<0, NonVoidStrideScale>(), Step<_2, _1, _3>, + Step<_1, _2, _3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(!cute::is_base_of::value && cute::is_base_of::value, + 
"MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // To relax them, we need to handle loading more than 1 row of scales for every main loop iteration. + // We must also handle updating the pipeline transaction bytes on the fly. + // NOTE: Deleting this assertion without required changes will cause the code to hang. + static_assert(size<1>(SmemLayoutAtomScale{}) == 1, "size<1>(SmemLayoutAtomScale) must be 1."); + + private: + static constexpr ConversionMode get_conversion_mode() { + if constexpr (cute::is_void_v) { + return ConversionMode::DirectConvert; + } else if constexpr (cute::is_void_v) { + return ConversionMode::ConvertAndScale; + } else { + return ConversionMode::ConvertAndScaleWithZero; + } + } + + static constexpr ConversionMode KernelConversionMode = get_conversion_mode(); + static constexpr bool ModeHasScales = KernelConversionMode == ConversionMode::ConvertAndScale || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; + + static constexpr auto elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } else if constexpr (ModeHasScales) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || KernelConversionMode == ConversionMode::ConvertAndScale) { + return 0; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } else { + static_assert( + cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These 
methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t compute_tma_transaction_bytes_mk() { + constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return baseline_bytes; + } else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return baseline_bytes + scale_tx_bytes; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return baseline_bytes + scale_tx_bytes + zero_tx_bytes; + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Type not handled in tma transaction bytes computation."); + } + } + + static constexpr uint32_t compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + public: + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + // Just pick the max alignment of A and B since it is required 
to be at least 128B + static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); + + static_assert(SmemAlignmentA >= 128 && SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage { + static constexpr int scale_elements = elements_per_smem_scale(); + static constexpr int zero_elements = elements_per_smem_zero(); + + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + cute::ArrayEngine smem_scale; + cute::ArrayEngine smem_zero; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + ElementScale const* ptr_S = nullptr; + NonVoidStrideScale dS{}; + int group_size = 0; + ElementZero const* ptr_Z = nullptr; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + private: + using Outer = CollectiveMmaInterleaved; + + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideA{}, static_cast(0)), InternalStrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + + using TMA_Scale = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. 
Scale is ALWAYS loaded with A for RF kernel + + using TMA_Zero = decltype(make_tma_copy(GmemTiledCopyScale{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(NonVoidStrideScale{}, static_cast(0)), NonVoidStrideScale{}), + SmemLayoutScale{}(_, _, cute::Int<0>{}), ScaleTileShape{}, + _1{})); // mcast along N mode for this M load, if any. Scale is ALWAYS loaded with A for RF kernel + + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), + repeat_like(InternalStrideB{}, static_cast(0)), InternalStrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Scale tma_load_scale; + TMA_Zero tma_load_zero; + int64_t scale_k; + int group_size; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void)workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; 
+ } + + Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M, K, L), dA)); + Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N, K, L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + typename Params::TMA_Scale tma_load_scale; + typename Params::TMA_Zero tma_load_zero; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, 0, 0, TmaTransactionBytes, + TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (ModeHasScales) { + auto scale_k = (K + args.group_size - 1) / args.group_size; + ElementScale const* ptr_S = args.ptr_S; + StrideScale dS = args.dS; + Tensor tensor_scale = make_tensor(get_logical_ptr(ptr_S), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_scale = make_tma_copy(GmemTiledCopyScale{}, tensor_scale, SmemLayoutScale{}(_, _, cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tensor_zero = make_tensor(get_logical_ptr(args.ptr_Z), make_layout(make_shape(M, scale_k, L), dS)); + tma_load_zero = make_tma_copy(GmemTiledCopyScale{}, tensor_zero, SmemLayoutScale{}(_, _, 
cute::Int<0>{}), + ScaleTileShape{}, _1{}); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_scale, tma_load_zero, scale_k, args.group_size, + TmaTransactionBytes, TmaTransactionBytesMK, TmaTransactionBytesNK}; + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in to_underlying_arguments."); + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + implementable = implementable && (args.ptr_S == nullptr); + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (ModeHasScales) { + int const scale_mn = SwapAB ? 
N : M; + int const scale_k = (K + args.group_size - 1) / args.group_size; + constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0)); + implementable = implementable && args.group_size != 0; + implementable = implementable && (args.ptr_S != nullptr); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + implementable = implementable && (args.ptr_Z == nullptr); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment( + cute::make_shape(scale_mn, scale_k, L), StrideScale{}); + implementable = implementable && (args.ptr_Z != nullptr); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in can_implement."); + } + + if (!implementable) { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk(); + static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk(); + static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor()); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in TMA prefetch."); + } + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
+ template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(gA_mkl, gB_nkl); + } else if constexpr (ModeHasScales) { + auto scale_k = mainloop_params.scale_k; + Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) + Tensor gZ_mkl = local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gS_mkl, gZ_mkl); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in load_init."); + } + } + + /// Perform a 
collective-scoped matrix multiply-accumulate + /// Producer Perspective + /// This overload gets triggered when we have scales. + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + static_assert(sizeof...(Ts) == 2, "Direct convert needs two inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + static_assert(sizeof...(Ts) == 3, "Scaled convert needs three inputs"); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(sizeof...(Ts) == 4, "Scaled and zero convert needs four inputs"); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in TMA load."); + } + + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A, B and Scales + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition 
the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_s = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + auto extra_input_partitions = partition_extra_tma_inputs( + mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, 
_, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // Nothing extra to do. + } else if constexpr (ModeHasScales) { + auto tSgS = get<0>(extra_input_partitions); + auto tSsS = get<1>(extra_input_partitions); + + // Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify + // tma transaction bytes on the fly. We must do a ceiling divide here to correctly handle with + // group_size == K. In that case, we don't require that K is a multiple of the threadblock tile K + int const ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{}); + int const scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K. + copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_, _, _, scale_load_k), + tSsS(_, _, _, write_stage)); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tZgZ = get<2>(extra_input_partitions); + auto tZsZ = get<3>(extra_input_partitions); + copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), + tZgZ(_, _, _, scale_load_k), tZsZ(_, _, _, write_stage)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for TMA copy op."); + } + + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be 
released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + template + constexpr auto interleave_for_mixed_input() { + if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 8) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_16, _32>>>{}; + } else if constexpr (cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2>>, + Stride, _0, Stride<_16>>>{}; + } else if constexpr (cute::sizeof_bits_v == 8 && cute::sizeof_bits_v == 16) { + return Layout, _1, Shape<_2, _2>>, + Stride, _0, Stride<_8, _16>>>{}; + } else { + static_assert(dependent_false, + "unsupported weight and activation, must be one of w4a8,w4a16,w8a16"); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum, + int k_tile_count, int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(cute::rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = 
make_tensor( + make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA_mma_interleave{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + // Layout of warp group to thread mapping + + static_assert(stride<0>(typename TiledMma::BLayout{}) == 0 and size<0>(typename TiledMma::BLayout{}) == NumThreadsPerWarpGroup, + "Stride of the first mode must be 0 and the size of the mode must be NumThreadsPerWarpGroup"); + + constexpr int MmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; + Layout warp_group_thread_layout = make_layout(Int{}, Int{}); + + int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0); + + TiledMma tiled_mma; + auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = mma_thread_slice.partition_A(sA); + auto mma_warpgroup_slice = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)); + + auto interleave_layout = interleave_for_mixed_input(); + + auto interleave_remapping = cute::flat_product(interleave_layout, Layout>>{}); + + Tensor tCsA_remapped = tCsA.compose(interleave_remapping); + + auto interleave_remapping_thread = right_inverse(interleave_layout); + + // Allocate fragments and descriptors + Tensor tCrA_mma = mma_thread_slice.partition_fragment_A(sA(_, _, Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = make_fragment_like(tCrA_mma); + + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = 
smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + // Compute the max vector length that can be used to copy A. This will match the vector width of the + // conversions used. It helps by allowing the compiler to convert using the same register that was used + // to load the data from smem. This significantly reduces the need to move data among registers. + // Note that this is correct even if copy fails to vectorize, since the granularity at which we perform + // the conversion does not impact correctness. + using A_CPY_VEC = decltype(max_common_vector(tCsA, tCrA_copy_view)); + using A_CPY_VEC_remapped = decltype(max_common_vector(tCsA_remapped, tCrA_copy_view)); + static_assert(A_CPY_VEC_remapped{} == 32 / cutlass::sizeof_bits::value, + "max_common_vector(tCsA_remapped, tCrA_copy_view) is 32 / cutlass::sizeof_bits::value"); + auto tCrA_mma_tmp = tCrA_mma.compose(interleave_remapping_thread); + auto tCrA_mma_inverse_mapping = tCrA_mma_tmp.compose(tCrA_mma.layout()); + + auto tCrA_load_tmp = tCrA_load.compose(interleave_remapping_thread); + auto tCrA_load_inverse_mapping = tCrA_load_tmp.compose(tCrA_load.layout()); + + // Partition of thread -> shared and thread -> RF + auto partitioned_extra_info = partition_extra_mma_info(mma_thread_slice, shared_tensors); + auto copy_partitions_extra_info = retile_extra_mma_info(tiled_mma, partitioned_extra_info, warp_group_thread_idx); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer 
warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + + constexpr int K_BLOCK_MAX = size<2>(tCrA_load); + + constexpr int kNumKIterationsPerWarpBLoad = type_factor / 2; + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, read_stage, kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, read_stage, kNumKIterationsPerWarpBLoad); + } + + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + if (k_block < K_BLOCK_MAX - 1) { + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + + 
--k_tile_count; + if (k_tile_count > 0) { + // Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to overwrite the A registers for + // the first mma. + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + warpgroup_wait(); + transform_A_kblock( + tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, kNumKIterationsPerWarpBLoad); + } + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 1; --k_tile_count) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + + warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this + // stage, so we can release prior barrier + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + + if (k_block == K_BLOCK_MAX - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy_A_and_extra_info(smem_tiled_copy_A, 
tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 0, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + if (K_BLOCK_MAX > 1) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, 1, smem_pipe_read.index(), kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, 0, + kNumKIterationsPerWarpBLoad); + } else { + if (k_block < K_BLOCK_MAX - 2) { // prefetch next block + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma_inverse_mapping(_, _, k_block), tCrB(_, _, k_block, read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait(); + if (k_block == K_BLOCK_MAX - 1) // release prior barrier + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + + if (k_block < K_BLOCK_MAX - 2) // prefetch next block + { + copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 2, read_stage, kNumKIterationsPerWarpBLoad); + } + + if (k_block < K_BLOCK_MAX - 1) { + 
copy_A_and_extra_info(smem_tiled_copy_A, tCsA_remapped, tCrA_copy_view, partitioned_extra_info, + copy_partitions_extra_info, k_block + 1, read_stage, kNumKIterationsPerWarpBLoad); + transform_A_kblock(tCrA_load, A_CPY_VEC_remapped{}, tCrA_mma, partitioned_extra_info, k_block + 1, + kNumKIterationsPerWarpBLoad); + } + } + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + private: + /// Utilities for any additional inputs inside of the TMA load + template + CUTLASS_DEVICE auto partition_extra_tma_inputs(Params const& mainloop_params, cute::tuple const& load_inputs, + TensorStorage& shared_tensors, uint2 const& cluster_local_block_id, int const m_coord, int const l_coord) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } else if constexpr (KernelConversionMode == 
ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled for input partitioning."); + } + } + + template + constexpr auto scale_remapping() { + if constexpr (cute::sizeof_bits_v == 8) { + return Layout, Stride<_1, _8, _4>>{}; + } else if constexpr (cute::sizeof_bits_v == 16) { + return Layout, Stride<_1, _4, _2>>{}; + } else { + static_assert(dependent_false, "cute::sizeof_bits_v must be 8 or 16"); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. 
+ template + CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& mma_thread_slice, TensorStorage& shared_tensors) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + Tensor sS = make_tensor( + make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + auto remappingScale = scale_remapping(); + Tensor tCsS_remapped = tCsS.compose(remappingScale, _, _, _); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_, _, Int<0>{})).shape()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS_remapped, tCrS); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor( + make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCsZ_remapped = tCsZ.compose(remappingScale, _, _, _); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_, _, Int<0>{})).shape()); + return cute::make_tuple(tCsS_remapped, tCrS, tCsZ_remapped, tCrZ); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. 
+ template + CUTLASS_DEVICE auto retile_extra_mma_info( + TiledMma const& tiled_mma, cute::tuple& partitioned_extra_info, int const warp_group_thread_idx) { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE void copy_A_and_extra_info(SmemTiledCopyA const& smem_tiled_copy_A, TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, int k_block, int read_stage, int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad == 1) { + copy(smem_tiled_copy_A, tCsA(_, _, k_block, read_stage), tCrA_copy_view(_, _, k_block)); + } else { + using reshape_layout = Layout, Int<1>, Int<2>>>; + auto tCrA_copy_view_reshaped = tCrA_copy_view.compose(reshape_layout{}); + if (k_block % kNumKIterationsPerWarpBLoad == 0) + copy(smem_tiled_copy_A, tCsA(_, _, k_block / 
kNumKIterationsPerWarpBLoad, read_stage), + tCrA_copy_view_reshaped(_, _, k_block / kNumKIterationsPerWarpBLoad)); + } + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage), tCrS_copy_view(_, _, k_block)); + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage), tCrZ_copy_view(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, + "Conversion mode not handled in A -> RF path."); + } + } else { + static_assert( + cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + /// Utilities to transform A. 
+ template + CUTLASS_DEVICE void transform_A_kblock(TCrA_load const& tCrA_load, cute::Int vec_A, + TCrA_mma& tCrA_mma, cute::tuple const& partitioned_extra_info, int const k_block, + int kNumKIterationsPerWarpBLoad) { + if (kNumKIterationsPerWarpBLoad != 1) { + if (k_block % kNumKIterationsPerWarpBLoad == 0) { + int k_block_load = k_block / kNumKIterationsPerWarpBLoad; + using reshape_layout = Layout, _1, _2>>; + auto tCrA_load_reshaped = tCrA_load.compose(reshape_layout{}); + auto tCra_mma_reshaped = tCrA_mma.compose(reshape_layout{}); + + using scale_reshape = Layout, _1, _1>, Stride, _0, _0>>; + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A( + tCrA_load_reshaped(_, _, k_block_load), vec_A, tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCra_mma_reshaped(_, _, k_block_load)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrS_reshaped = tCrS.compose(scale_reshape{}); + auto tCrZ = cute::get<3>(partitioned_extra_info); + auto tCrZ_reshaped = tCrZ.compose(scale_reshape{}); + transform_internal_A(tCrA_load_reshaped(_, _, k_block_load), vec_A, + make_fragment_like(tCra_mma_reshaped)(_, _, k_block_load), tCrS_reshaped(_, _, 0), + tCrZ_reshaped(_, _, 0), tCra_mma_reshaped(_, _, k_block_load)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } else { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + transform_internal_A(tCrA_load(_, _, k_block), vec_A, tCrA_mma(_, _, k_block)); + } else if constexpr 
(KernelConversionMode == ConversionMode::ConvertAndScale) { + auto tCrS = cute::get<1>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrA_mma(_, _, k_block)); + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCrS = cute::get<1>(partitioned_extra_info); + auto tCrZ = cute::get<3>(partitioned_extra_info); + transform_internal_A(tCrA_load(_, _, k_block), vec_A, + make_fragment_like(tCrA_mma)(_, _, k_block), tCrS(_, _, 0), tCrZ(_, _, 0), + tCrA_mma(_, _, k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + } + + /// Utilities for transforming the A operand prior to issuing tensorcore math. + template > + CUTLASS_DEVICE void convert_tensor(Tensor const& in, Tensor& out, + cute::Int width = {}) { + /// This is an element-wise conversion where we expect both tensors to have the same layout. + /// As a result, we can cast as a cutlass array to use the fast numeric converters without + /// worrying about indexing into the layout. + constexpr int N = cosize_v; + + /// The inputs must be backed by registers & be statically sized. 
+ static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(is_static_v, "Tensor layout for the conversion must be static"); + static_assert(cosize_v == size(TensorLayout{}), "Cosize and size of the layout must be equal."); + static_assert( + N % ConversionVectorWidth == 0, "Conversion vector width must divide cosize of the tensor layout."); + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + using Converter = std::conditional_t < cutlass::sizeof_bits_v, + cutlass::FastInterleavedAndBiasedNumericArrayConverter, + cutlass::NumericArrayConverter>; + + constexpr int NumIterations = N / ConversionVectorWidth; + + for (int ii = 0; ii < NumIterations; ++ii) { + SrcArray const* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())) + ii; + DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())) + ii; + *dst_array_ptr = Converter::convert(*src_array_ptr); + } + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& out) { + convert_tensor(in, out, a_vec_width); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = 
bfloat16_t(__hmul(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)))); + } else { + converted_inputs(j, i) *= scales(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. + convert_tensor(converted_inputs, out); + } + + template + CUTLASS_DEVICE void transform_internal_A(Tensor&& in, + cute::Int a_vec_width, Tensor&& converted_inputs, + Tensor&& scales, Tensor&& zeros, + Tensor&& out) { + static_assert(cute::is_same_v, + "Type of the engine input buffer must equal the scale buffer"); + + static_assert(cute::is_same_v, + "Type of the engine zero buffer must equal the scale buffer"); + + // First, we upcast the inputs to the scale type + convert_tensor(in, converted_inputs, a_vec_width); + + // Apply scales and broadcast across inputs, store in converted_inputs + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(converted_inputs); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(converted_inputs); ++j) { + if constexpr (cute::is_same_v) { + converted_inputs(j, i) = bfloat16_t(__hfma(reinterpret_cast<__nv_bfloat16 const&>(converted_inputs(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(scales(j, i)), + reinterpret_cast<__nv_bfloat16 const&>(zeros(j, i)))); + } else { + converted_inputs(j, i) = converted_inputs(j, i) * scales(j, i) + zeros(j, i); + } + } + } + + // Finally, we convert the scaled inputs to the mma type. 
+ convert_tensor(converted_inputs, out); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 0000000000000..c7f2a682323a0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,370 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. 
+*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. 
+ */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = 
ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. 
+ workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != 
cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. 
+ int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000000..83ebe2191717b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,149 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MixedGemmArchTraits { + static_assert(dependent_false, "Unrecognized parameterization"); +}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
+template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / 
cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value || cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with TmaWarpSpecializedGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 100644 index 0000000000000..fe4bc0940d9e8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,51 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits { + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000..a888ea3e71487 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,461 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename 
Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size), group_size(group_size), ref_A(ref_A), ref_B(ref_B), ref_scale(ref_scale), ref_zero(ref_zero), ref_C(ref_C), ref_D(ref_D), batch_count(serial_split_k_factor), output_op(output_op), gather_A_indices(gather_A_indices), gather_B_indices(gather_B_indices), scatter_D_indices(scatter_D_indices) { + } + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size), group_size(args.group_size), 
grid_tiled_shape(grid_tiled_shape), swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), params_A(args.ref_A.layout()), ref_A(args.ref_A), params_B(args.ref_B.layout()), ref_B(args.ref_B), params_scale(args.ref_scale.layout()), ref_scale(args.ref_scale), ref_zero(args.ref_zero), params_C(args.ref_C.layout()), ref_C(args.ref_C), params_D(args.ref_D.layout()), ref_D(args.ref_D), output_op(args.output_op), semaphore(static_cast(workspace)), gemm_k_size(gemm_k_size), gather_A_indices(args.gather_A_indices), gather_B_indices(args.gather_B_indices), scatter_D_indices(args.scatter_D_indices) { + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) { + static int const alignmentA = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const alignmentB = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const alignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const alignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 
64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, alignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, alignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, alignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, alignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) { + if (!args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } else { + if (args.ref_zero.good()) { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) { + if (args.group_size != 64 && args.group_size != 128) { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& /*args*/, cutlass::gemm::GemmCoord const& /*grid_tiled_shape*/) { + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) { + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * 
params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. 
+ */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 1000) + // Use SM80 implementation for GB10x, GB20x. + run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000..163a43238a425 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = onnxruntime::llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform 
const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm), batch_count(1) { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, 
TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_), problem_size(problem_size_), batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), quant_option(quant_option_), ref_alpha_col(ref_alpha_col_), ref_alpha_row(ref_alpha_row_), ref_C(ref_C_), ref_D(ref_D_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_D(0), epilogue_visitor(epilogue_visitor_) { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), params_A(0), params_B(0), params_alpha_col(0), params_C(0), params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), ptr_alpha_col(nullptr), ptr_alpha_row(nullptr), ptr_C(nullptr), ptr_D(nullptr), batch_stride_A(0), batch_stride_B(0) { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int 
gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), params_alpha_col(args.ref_alpha_col.layout()), params_alpha_row(args.ref_alpha_col.layout()), params_C(args.ref_C.layout()), params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), quant_option(args.quant_option), ptr_alpha_col(args.ref_alpha_col.data()), ptr_alpha_row(args.ref_alpha_row.data()), ptr_C(args.ref_C.data()), ptr_D(args.ref_D.data()), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = 
Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes 
one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, 
{problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + 
epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Ampere+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000000..c0656ac784830 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,112 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. 
+ + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB { +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. 
+template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 75>::type> { + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + + private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000..ef28dcc46cd21 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,117 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. 
+template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters { +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + 
typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000..8d73329ed7713 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,289 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> { + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> { + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + 
typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int 
kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using 
ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000..ae0cee20d3575 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,270 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> { + private: + using SmemScaleType = half_t; + + public: + using IteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> { + static_assert((MmaShape::kN % Alignment) == 0, ""); + + private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + + public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + 
+//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + 
cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> { + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must 
be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using 
IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000000..dfe99c271f547 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 
weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename 
ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma 
= DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + 
+#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = 
cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000..cb5ce0f72b362 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,336 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + using MmaElementA = bfloat16_t; + using MmaElementB = bfloat16_t; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< 
+ cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = 
cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A 
matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: 
+ static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = 
typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000..cad280febbe76 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. 
+ using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // 
namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000..78b6abb50513f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix 
+ typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 0000000000000..5db74039469c4 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,612 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of 
accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform 
on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. + struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B 
operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + 
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) { + iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + 
typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale 
iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + 
this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + warp_mma( + accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. 
+ if (group_start_iteration_B == 0) { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale 
needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000000..e992915cafeea --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 0000000000000..b362195834c87 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,431 @@ +/* + * 
Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase { + public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = 
typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + 
static constexpr bool RequiresTileInterleave = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx), smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * 
Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) { + arch::global_load( + tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) { + 
iterator_scale.add_tile_offset({1, 0}); + } else if (iterator_scale.group_size_ == 128) { + if constexpr (Shape::kK == 128) { + iterator_scale.add_tile_offset({1, 0}); + } else if constexpr (Shape::kK == 64) { + if (iterator_scale.row_groupsize64_ & 0x1) { + iterator_scale.add_tile_offset({1, 0}); + } + } else { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
+ TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. 
+ CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + warp_mma(accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000..e680493cf060a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/arch/mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000..21c787e91be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value && platform::is_same::value) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + 
SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + 
ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000..47f1bb240e8b3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,393 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" + +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class 
MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + 
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } else { + 
CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 && platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; 
+ + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) { + if constexpr (hasZero(QuantOp)) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + + private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h new file mode 100644 index 0000000000000..e48ef3f154883 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h @@ -0,0 +1,405 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#include "cute/tensor.hpp" + +namespace onnxruntime::llm { +namespace cutlass_extensions { + +// Note: The shapes are in the format MxNxK. 
The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. +enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=128 + CtaShape256x128x128B, +}; + +enum class CutlassTileConfigSM100 { + // 
Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + /* + * Grouped GEMM + */ + // M=64 + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // M=128 + CtaShape128x8x256B, + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + CtaShape128x128x256B, + CtaShape128x256x256B, + + // M=256 + CtaShape256x64x128B, + CtaShape256x128x128B, + CtaShape256x256x128B, +}; + +enum class MainloopScheduleType { + AUTO, // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. + PINGPONG, + COOPERATIVE, + WARPSPECIALIZED +}; + +#if 0 +static auto get_mainloop_schedule_name(MainloopScheduleType schedule) { + if (schedule == MainloopScheduleType::AUTO) { + return "auto"; + } else if (schedule == MainloopScheduleType::PINGPONG) { + return "pingpong"; + } else if (schedule == MainloopScheduleType::COOPERATIVE) { + return "cooperative"; + } else if (schedule == MainloopScheduleType::WARPSPECIALIZED) { + return "warpspecialized"; + } + return "unknown schedule"; +} +#endif + +enum class EpilogueScheduleType { + AUTO, // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main + // loop. 
+}; + +enum class TileShape { + TileShape_64x16x128, + TileShape_64x32x128, + TileShape_64x64x128, + TileShape_64x128x128, + TileShape_64x256x128, + TileShape_64x512x128, + TileShape_128x16x128, + TileShape_128x32x128, + TileShape_128x64x128, + TileShape_128x128x128, + TileShape_128x256x128 +}; + +template +constexpr auto get_tile_shape() { + using namespace cute; + if constexpr (Shape_MNK == TileShape::TileShape_64x16x128) { + return cute::Shape<_64, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x32x128) { + return cute::Shape<_64, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x64x128) { + return cute::Shape<_64, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x128x128) { + return cute::Shape<_64, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x256x128) { + return cute::Shape<_64, _256, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_64x512x128) { + return cute::Shape<_64, _512, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x16x128) { + return cute::Shape<_128, _16, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x32x128) { + return cute::Shape<_128, _32, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x64x128) { + return cute::Shape<_128, _64, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x128) { + return cute::Shape<_128, _128, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { + return cute::Shape<_128, _256, _128>{}; + } +} + +#if 0 +static auto get_tile_shape_name(TileShape Shape_MNK) { + if (Shape_MNK == TileShape::TileShape_64x16x128) { + return "64x16x128"; + } else if (Shape_MNK == TileShape::TileShape_64x32x128) { + return "64x32x128"; + } else if (Shape_MNK == TileShape::TileShape_64x64x128) { + return "64x64x128"; + } else if (Shape_MNK == TileShape::TileShape_64x128x128) { + return "64x128x128"; + } else if 
(Shape_MNK == TileShape::TileShape_64x256x128) { + return "64x256x128"; + } else if (Shape_MNK == TileShape::TileShape_64x512x128) { + return "64x512x128"; + } else if (Shape_MNK == TileShape::TileShape_128x16x128) { + return "128x16x128"; + } else if (Shape_MNK == TileShape::TileShape_128x32x128) { + return "128x32x128"; + } else if (Shape_MNK == TileShape::TileShape_128x64x128) { + return "128x64x128"; + } else if (Shape_MNK == TileShape::TileShape_128x128x128) { + return "128x128x128"; + } else if (Shape_MNK == TileShape::TileShape_128x256x128) { + return "128x256x128"; + } + return "Unknown shape"; +} +#endif + +enum class ClusterShape { + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x4x1, + ClusterShape_4x2x1, + ClusterShape_2x4x1, + ClusterShape_4x4x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +#if 0 +static auto get_cluster_shape_name(ClusterShape Shape_MNK) { + if (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { + return "1x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return "2x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return "1x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return "2x2x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_1x8x1) { + return "1x8x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return "8x1x1"; + } + return "Unknown shape"; +} + +template +constexpr auto get_cluster_shape() { + using namespace cute; + if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x1x1) { + return cute::Shape<_1, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x1x1) { + return cute::Shape<_2, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_1x2x1) { + return cute::Shape<_1, _2, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_2x2x1) { + return cute::Shape<_2, _2, _1>{}; + } else if constexpr (Shape_MNK == 
ClusterShape::ClusterShape_1x8x1) { + return cute::Shape<_1, _8, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { + return cute::Shape<_8, _1, _1>{}; + } +} +#endif + +struct CutlassGemmConfig { + enum CandidateConfigTypeParam : int { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + BLACKWELL = 1u << 4, + GROUPED_GEMM = 1u << 5, + FP8_ONLY = 1u << 6, + FP4_ONLY = 1u << 7 + }; + + CutlassTileConfig tile_config_sm80 = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + CutlassTileConfigSM100 tile_config_sm100 = CutlassTileConfigSM100::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool enableCudaKernel = false; + int sm_version = 80; // Use 80 as a catch all for <90 + bool is_tma_warp_specialized = false; + + CutlassGemmConfig() = default; + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config_sm80(tile_config), split_k_style(split_k_style), split_k_factor(split_k_factor), stages(stages), sm_version(80) { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(90), is_tma_warp_specialized(true) { + } + + CutlassGemmConfig(CutlassTileConfigSM100 tile_config_sm100, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, 
ClusterShape cluster_shape) + : tile_config_sm100(tile_config_sm100), mainloop_schedule(mainloop_schedule), epilogue_schedule(epilogue_schedule), cluster_shape(cluster_shape), sm_version(100), is_tma_warp_specialized(true) { + } + + int getTileConfigAsInt() const { + if (sm_version == 120) + return (int)tile_config_sm80; + if (sm_version >= 100) + return (int)tile_config_sm100; + if (sm_version == 90) + return (int)tile_config_sm90; + if (sm_version < 90) + return (int)tile_config_sm80; + assert(false && "Invalid SM version"); + return -1; + } + + std::string toString() const { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (is_tma_warp_specialized && getTileConfigAsInt() != (int)CutlassTileConfigSM90::ChooseWithHeuristic) { + assert(sm_version >= 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA Warp Specialized" + << "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt() + << "\n\tcluster shape ID: " << (int)cluster_shape + << "\n\tmainloop sched: " << (int)mainloop_schedule << "\n\tepi sched: " << (int)epilogue_schedule + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else if (tile_config_sm80 != onnxruntime::llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) { + assert(sm_version < 90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int)tile_config_sm80 << "\n\tstages: " << (int)stages + << "\n\tsplit k: " << (int)split_k_factor + << "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false"); + } else if (enableCudaKernel) { + tactic << "\n\tenable cuda kernel: " << (enableCudaKernel ? 
"true" : "false"); + } else { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) { + // clang-format off + if (config.is_tma_warp_specialized) + { + out << "tile_config_sm90_enum: " << config.getTileConfigAsInt() + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape) + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + else + { + out << "tile_config_enum: " << config.getTileConfigAsInt() + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages + << ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false"); + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace onnxruntime::llm + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000..86c45a865954e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. 
+template +struct FastInterleavedAndBiasedNumericArrayConverter { +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + 
for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. 
+ result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. 
+ // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. 
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. 
+#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000000..30df05f24257e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +struct ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000000..cf5ebdaeec261 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2017-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator { + public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename 
platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params), pointer_scale_(reinterpret_cast(const_cast(pointer_scale))), pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * 
sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issuing reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. 
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h 
b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 0000000000000..cc54764c2be50 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2017-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. +*/ + +#pragma once + +namespace cutlass { + +enum class WeightOnlyQuantOp { + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) { + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc new file mode 100644 index 0000000000000..d53fb558ba1a1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc @@ -0,0 +1,479 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#pragma GCC diagnostic ignored "-Wsign-compare" +#endif // __GNUC__ + +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" + +#include + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" +#include "core/common/common.h" + +#include +#include +#include +#include + +using namespace onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct TileShape { + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + return TileShape{128, 64}; + 
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + return TileShape{16, 256}; + default: + ORT_THROW("[get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, + int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) { + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + enum class CutlassGemmType : char { + Default, + WeightOnly, + Simt, + Int8, + Fp8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) { + gemm_type = CutlassGemmType::Simt; + } else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + gemm_type = CutlassGemmType::WeightOnly; + } else if (config_type_param & CutlassGemmConfig::INT8_ONLY) { + gemm_type = CutlassGemmType::Int8; + } else if (config_type_param & CutlassGemmConfig::FP8_ONLY) { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) { + case CutlassGemmType::Simt: + return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } else { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + 
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) { + if (sm == 89) { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } else { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } else { + if (sm == 89) { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } else { + return {}; + } + } + default: + return base_configs; + } +} + +std::vector get_candidate_tiles_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, 
CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + } else { + return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; + } +#endif +} + +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve +// compilation speed. +bool sm90_supports_mcast_along_m(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve +// compilation speed. 
+bool sm90_supports_mcast_along_n(CutlassTileConfigSM90 const tile) { +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +std::vector get_candidate_configs_sm90(CutlassGemmConfig::CandidateConfigTypeParam const config) { + auto tiles = get_candidate_tiles_sm90(config); + std::vector candidate_configs; + for (auto const& tile_config : tiles) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = sm90_supports_mcast_along_m(tile_config); + bool const has_n_mcast = sm90_supports_mcast_along_n(tile_config); + if (has_m_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } + + if (has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } + + if (has_m_mcast && has_n_mcast) { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig( + tiles[0], MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +std::vector 
get_candidate_configs_sm100(CutlassGemmConfig::CandidateConfigTypeParam const config) { +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM100 + return {CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, MainloopScheduleType::AUTO, + EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) { + std::vector candidate_configs; + if ((config & CutlassGemmConfig::FP4_ONLY) != 0) { + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + // candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + // MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1}); + candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B, + MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1}); + return candidate_configs; + } + + for (int cluster_m = 1; cluster_m <= 2; cluster_m++) { + bool Is2SM = cluster_m == 2; + for (int cluster_n = 1; cluster_n <= 2; cluster_n++) { + std::vector base = {// M=128 + CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B}; + + if (Is2SM) { + if (cluster_n == 1) { + 
base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B); + base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B); + } + + std::vector twosm = {// M=256 + CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B}; + std::copy(twosm.begin(), twosm.end(), std::back_inserter(base)); + } else { + if (cluster_n == 1) { + base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B); + if ((config & CutlassGemmConfig::FP8_ONLY) != 0) { + base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B); + } + } + + if (cluster_n == 1 && cluster_m == 1 && ((config & CutlassGemmConfig::FP8_ONLY) != 0)) { + base.push_back(CutlassTileConfigSM100::CtaShape128x8x256B); + } + + std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B, + CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B, + CutlassTileConfigSM100::CtaShape128x64x128B}; + std::copy(onesm.begin(), onesm.end(), std::back_inserter(base)); + } + + constexpr std::array, 2> cluster_shapes = + {{std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1}, + std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}}}; + + auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1]; + for (auto tile : base) { + CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster}; + candidate_configs.push_back(config); + } + } + } + return candidate_configs; + } else { + ORT_THROW("Not Implemented: SM100 GEMM candidates have not been defined."); + } +#endif + +} // namespace kernels + +std::vector get_candidate_configs( + int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) { + if ((config_type_param & CutlassGemmConfig::FP4_ONLY) && !(config_type_param & CutlassGemmConfig::BLACKWELL)) { + // FP4 is only supported on blackwell + return {}; + } + + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) { + return 
get_candidate_configs_sm90(config_type_param); + } + if (sm >= 100 && sm != 120 && (config_type_param & CutlassGemmConfig::BLACKWELL)) { + return get_candidate_configs_sm100(config_type_param); + } + + std::vector tiles = get_candidate_tiles(sm, config_type_param); + + std::vector candidate_configs; + bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) { + auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + // add cuda kernel profiler to tactics for weight-only plugins + if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) { + if (tiles.size() > 0) { + CutlassGemmConfig CudaKernelConfig(tiles[0], SplitKStyle::NO_SPLIT_K, 1, min_stages); + CudaKernelConfig.enableCudaKernel = true; + candidate_configs.push_back(CudaKernelConfig); + } + } + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + ORT_THROW( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. 
+ // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config_sm80); + int occupancy = occupancies[ii]; + + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config_sm80 != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } + + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + float const current_score = float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 ? 
SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + } else if (current_score == config_score && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor || current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config_sm80, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config_sm80 == CutlassTileConfig::ChooseWithHeuristic) { + ORT_THROW("Heuristic failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h new file mode 100644 index 0000000000000..b9b0301d78fc7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.h @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/tensor.hpp" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +struct should_filter_tma_warp_specialized_gemm_problem_shape { +#ifdef FAST_BUILD + using SupportedCtaShape = cute::Shape(TileShape{}))>; + using SupportedCgaShape = cute::Shape; + + constexpr static bool value = !cute::is_same_v || !cute::is_same_v; +#else + constexpr static bool value = false; +#endif +}; +template +constexpr static bool should_filter_tma_warp_specialized_gemm_problem_shape_v = should_filter_tma_warp_specialized_gemm_problem_shape::value; + +std::vector get_candidate_configs( + int sm, int const max_split_k, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + +onnxruntime::llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const /*num_experts*/, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc new file mode 100644 index 0000000000000..50ee944161538 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.cc @@ -0,0 +1,687 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" + +#include + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +using namespace onnxruntime::llm::common; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +struct LayoutDetails { + enum class Layout { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails { +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template <> +struct getLayoutDetails { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> { + LayoutDetails operator()() { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } 
+}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() { + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) { + LayoutDetails details; + switch (quant_type) { + case QuantType::W8_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_AFP8: + details = getLayoutDetailsForArchAndQuantType(); + break; + default: + ORT_THROW("Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) { + if (arch >= 75 && arch < 80) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 80 && arch < 90) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 90 && arch < 100) { + return getLayoutDetailsForArch(quant_type); + } else if (arch >= 100) { + return getLayoutDetailsForArch(quant_type); + } else { + ORT_THROW("Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// +// Throws an error for other architectures. +// The data is permuted such that: +// For W8_A16, each group of 16 rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For W4_A16, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// For W4_A8, see the map in the code. The idea is similar to above. +// The goal of this permutation is to ensure data ends up in the correct threads after +// we execute LDSM. It counteracts the effect of the data being of different widths. 
+// For more information about the expected layouts, see the MMA section in the PTX docs. +std::vector get_permutation_map(QuantType quant_type) { + if (quant_type == QuantType::W8_A16) { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } else if (quant_type == QuantType::W4_A16) { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, + 22, 23, 30, 31}; + } else if (quant_type == QuantType::W4_AFP8) { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, + 28, 29, 30, 31}; + } else { + ORT_THROW("Invalid quantization type for LDSM permutation"); + } +} + +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const num_vec_cols = num_cols / elts_in_int32; + + ORT_ENFORCE(num_rows % B_ROWS_PER_MMA == 0, + "Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of ", + B_ROWS_PER_MMA); + ORT_ENFORCE(num_cols % MMA_SHAPE_N == 0, + "Invalid shape for quantized tensor. 
On turing/Ampere, the number of cols must be a multiple of ", + MMA_SHAPE_N); + + ORT_ENFORCE(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < static_cast(num_rows); base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; + + const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a substantial +// amount of time to transpose leading to long preprocessing times. This seemed to be a big +// issue for relatively large models. +template +void subbyte_transpose_impl( + int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + + uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples + // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it + // allows GCC to emit vector instructions. + ORT_ENFORCE(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + "Number of bytes for rows and cols must be a multiple of ", VECTOR_WIDTH, ". However, num_rows_bytes = ", + col_bytes_trans, " and num_col_bytes = ", col_bytes); + + for (size_t expert = 0; expert < num_experts; ++expert) { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) { + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte + jj; + + const size_t logical_src_offset = matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if constexpr (bits_per_elt 
== 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if constexpr (bits_per_elt == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). + for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + ORT_THROW("Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + int const col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* 
quantized_tensor, + std::vector const& shape, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + if (quant_type == QuantType::W8_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_A16) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else if (quant_type == QuantType::W4_AFP8) { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } else { + ORT_THROW("Invalid quant_type"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) { + for (size_t ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no + // performance benefit and is purely so that int4 and int8 have the same layout. + // Pictorially, this does the following: + // bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + ORT_ENFORCE(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) { + size_t const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. 
+ for (size_t ii = 0; ii < num_bytes; ++ii) { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + ORT_ENFORCE( + transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); + ORT_ENFORCE(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + ORT_ENFORCE(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + int const src_idx = dest_idx < 4 ? 
2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + if (quant_type == QuantType::W8_A16) { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } else { + ORT_THROW("Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, LayoutDetails details) { + ORT_LLM_LOG_TRACE(__PRETTY_FUNCTION__); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + ORT_ENFORCE(!(num_rows % elts_in_int32), + "The number of rows must be a multiple of ", elts_in_int32, " but the number of rows is ", num_rows); + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); + + ORT_ENFORCE(!(num_rows % rows_per_tile), + "The number of rows must be a multiple of ", rows_per_tile, " but the number of rows is ", num_rows); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int64_t read_col = 0; read_col < static_cast(num_cols); ++read_col) { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) { + for (int vec_read_row = base_vec_row; + vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) { + const int64_t vec_write_row = interleave * base_vec_row + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = matrix_offset + read_col * num_vec_rows + vec_read_row; + const int64_t write_offset = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + int arch = getSMVersion(); + if (force_interleave && 
arch >= 90) { + // Workaround for MOE which doesn't have specialized Hopper/Blackwell kernels yet + arch = 80; + } + // Force use sm80 kernel for GB20x. + if (arch >= 100) { + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (auto const& dim : shape) { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) { + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1 && arch != 90) { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. + + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. 
+ +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. + */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave) { + ORT_ENFORCE(processed_quantized_weight, "Processed quantized tensor is NULL"); + ORT_ENFORCE(scale_ptr, "Scale output pointer is NULL"); + ORT_ENFORCE(input_weight_ptr, "Input weight pointer is NULL"); + + ORT_ENFORCE(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? 
shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; + + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < static_cast(num_experts); ++expert) { + WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = 0.f; + } + + for (size_t ii = 0; ii < num_rows; ++ii) { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (size_t jj = 0; jj < num_cols; ++jj) { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. 
+ for (size_t ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) { + if (bits_per_weigtht_element == 8) { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (bits_per_weigtht_element == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < static_cast(num_cols)) { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. 
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + ORT_THROW("Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm( + processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave) { + symmetric_quantize( + processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, float*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, half>( + int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( 
+ int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h new file mode 100644 index 0000000000000..3e83852228e24 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_preprocessors.h @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +enum class QuantType { + W8_A16, + W4_A16, + W4_AFP8 +}; + +constexpr int get_weight_quant_bits(QuantType quant_type) { + switch (quant_type) { + case QuantType::W8_A16: + return 8; + case QuantType::W4_A16: + return 4; + case QuantType::W4_AFP8: + return 4; + default: + ORT_THROW("Invalid quant_type"); + return -1; + } +} + +// Shapes here can be 2 or 3D. 
2-D shapes are [num_rows, num_cols] +// 3-D shapes are [num_experts, num_rows, num_cols] +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave = false); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave); + +// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight +// to implement a simple reference implementation. +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h new file mode 100644 index 0000000000000..1fe8035cbcdae --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" + +#include "cutlass/half.h" +#include + +#include "cutlass/bfloat16.h" +#include + +#include "cutlass/float8.h" +#include + +#if defined(ENABLE_FP4) +#include "cutlass/float_subbyte.h" +#include +#endif + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +/////////////////////////////////////////////////////////////////////////////////////////////////// +// nvinfer::DataType to Cutlass +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct CutlassType { + using type = void; +}; + +template <> +struct CutlassType { + using type = cutlass::half_t; +}; + +template <> +struct CutlassType { + using type = cutlass::bfloat16_t; +}; + +template <> +struct CutlassType { + using type = cutlass::float_e4m3_t; +}; + +#if defined(ENABLE_FP4) +template <> +struct CutlassType { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// CUDA to Cutlass + +template +struct CudaToCutlassTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCutlassTypeAdapter { + using type = cutlass::half_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_bfloat16> { + using type = cutlass::bfloat16_t; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct CudaToCutlassTypeAdapter<__nv_fp8_e5m2> { 
+ using type = cutlass::float_e5m2_t; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCutlassTypeAdapter<__nv_fp4_e2m1> { + using type = cutlass::float_e2m1_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Cutlass to CUDA + +template +struct CudaToCudaTypeAdapter { + using type = T; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = half; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_bfloat16; +}; + +#if defined(ENABLE_FP8) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e4m3; +}; + +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp8_e5m2; +}; +#endif + +#if defined(ENABLE_FP4) +template <> +struct CudaToCudaTypeAdapter { + using type = __nv_fp4_e2m1; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..47e662b9a88ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, cutlass::uint4b_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9452aa0e1fbe6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/bf16_int8_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner<__nv_bfloat16, uint8_t, + cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..4a22e0f1b2aac --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int4_gemm_scale_zeros.cu @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu new file mode 100644 index 0000000000000..9f4091be4cd07 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fp16_int8_gemm_scale_zeros.cu @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { +template class CutlassFpAIntBGemmRunner; +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 0000000000000..0141c76bbc031 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" +#include +#include + +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +// TRT Activation Type does not have Gelu or Silu +enum class ActivationType { + Gelu, + Relu, + Silu, + Identity, + InvalidType +}; + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {int8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. + In this case, B must be preprocessed using the cutlass weight only quant preprocessors. The weight preprocessor + will instantiate the layout and preprocess based on the instantiation, so layout changes should only require + modifications to mix_gemm_B_layout.h. 
+*/ + +class CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunnerInterface() {} + + virtual ~CutlassFpAIntBGemmRunnerInterface() {} + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, + int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) = 0; + + // Returns desired workspace size in bytes. 
+ virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0; + + virtual std::vector getConfigs() const = 0; + + protected: + static constexpr int SPLIT_K_LIMIT = 7; + static constexpr int MIN_M_TILE = 16; + static constexpr int MIN_N_TILE = 64; + + static constexpr int MAX_M_TILE_SM90 = 128; + static constexpr int MAX_N_TILE_SM90 = 256; +}; + +template +class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; + + void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, + void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, + cudaStream_t stream) override; + + // Disabled since the fused GEMM, activation kernels will not be used in v1. + + // void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n, + // int k, ActivationType activation_type, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t + // stream); + + // Returns desired workspace size in bytes. 
+ size_t getWorkspaceSize(int const m, int const n, int const k) override; + + std::vector getConfigs() const override; + + private: + template + void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); + + private: + int sm_; + int multi_processor_count_; +}; + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 0000000000000..715397270331b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,489 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/device/gemm_universal_base_compat.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/threadblock/default_mma.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +void generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + static_assert( +#ifdef 
ENABLE_FP8 + cutlass::platform::is_same::value || +#endif + cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); + + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + // We need separate config for each architecture since we will target different tensorcore instructions. For float, + // we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using EpilogueOp = + typename tkc::Epilogue::Op; + + using Operator = typename MixedGemmArchTraits::Operator; + using TaggedOperator = typename cutlass::arch::TagOperator::TaggedOperator; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm, Stages, true, + TaggedOperator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + int const ldb = cutlass::platform::is_same::value + ? 
n + : k * GemmKernel::kInterleave; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + if constexpr (cutlass::platform::is_same::value) { + if (group_size != 128) { + ORT_THROW("Only group size 128 supported for fine grained W4A(fp)8 kernels."); + } + } + if (group_size != 64 && group_size != 128) { + ORT_THROW("Only group size 64 and 128 supported for fine grained kernels."); + } + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; + ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); + typename Gemm::Arguments args({m, n, k}, group_size, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast(const_cast(weight_scales)), ld_scale_zero}, + {reinterpret_cast(const_cast(weight_zero_points)), ld_scale_zero}, + {reinterpret_cast(const_cast(biases)), 0}, + {reinterpret_cast(C), n}, gemm_config.split_k_factor, + {ElementAccumulator(alpha), output_op_beta}); + + // This assertion is enabled because because for the column interleaved layout, K MUST be a multiple of + // threadblockK. 
The reason for this is that the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the interleaved layout. We need to write + // our own predicated iterator in order to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || ((k / gemm_config.split_k_factor) % MixedGemmArchTraits::ThreadblockK))) { + ORT_THROW("Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + ORT_LLM_LOG_WARNING( + "Requested split-k but workspace size insufficient. Falling back to non-split-k implementation."); + // If requested split-k factor will require more workspace bytes, revert to standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } +} + +// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, +// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained +// quanitzation is only supported on Ampere+ GPUs. FP8 GEMM is only supported on Ada+ GPUs. 
+template +void filter_and_run_mixed_gemm(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr (Stages > 2 && arch::kMinComputeCapability < 80) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (Stages == 2 && arch::kMinComputeCapability >= 89) { + // Multistage only supported on Ampere + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else if constexpr (cutlass::platform::is_same::value && arch::kMinComputeCapability < 89) { + // FP8 activation type only supported on Ada+ GPUs + std::string err_msg = "Cutlass fpA_intB gemm not supported for arch " + std::to_string(arch::kMinComputeCapability) + " with activation type set to FP8"; + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + } else { + generic_mixed_gemm_kernelLauncher(A, B, weight_scales, weight_zero_points, biases, + alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + } +} + +template +void dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + 
cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.stages) { + case 2: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 3: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 4: + filter_and_run_mixed_gemm(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, + n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages); + ORT_THROW("[fpA_intB_gemm] Error:", err_msg); + break; + } +} + +template +constexpr bool is_fp8() { + return std::is_same_v || std::is_same_v; +} + +template +void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead. + constexpr bool any_is_fp8 = is_fp8() || is_fp8() || is_fp8() || is_fp8() || is_fp8(); + + constexpr bool all_types_are_the_same = std::is_same_v && std::is_same_v && std::is_same_v; + + constexpr bool is_valid_pre_hopper = (all_types_are_the_same && !any_is_fp8) || (arch::kMinComputeCapability == 89); + + if constexpr (is_valid_pre_hopper) { + // Note that SIMT configs are omitted here since they are not supported for fpA_intB. 
+ // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the + // best for mixed type gemms. + constexpr int tile_shape_k = 128 * 8 / cutlass::sizeof_bits::value; + switch (gemm_config.tile_config_sm80) { + case tkc::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<16, 64, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, tile_shape_k>>(A, B, weight_scales, weight_zero_points, biases, alpha, + C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfig::Undefined: + ORT_THROW("[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfig::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] gemm config 
should have already been set by " + "heuristic."); + break; + default: + ORT_THROW( + "[fpA_intB_gemm] Error:[dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } + } else { + // This is not a limitation in CUTLASS. We just do not need to support this case. + std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier."; + ORT_THROW("[fpA_intB_gemm] Error: [dispatch_gemm_to_cutlass] ", err_msg); + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + sm_ = ::onnxruntime::llm::common::getSMVersion(); + multi_processor_count_ = ::onnxruntime::llm::common::getMultiProcessorCount(); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); +} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + // std::string config_str = gemm_config.toString(); + // printf("######## sm=%d, alpha: %f m:%d n:%d, k:%d, group_size:%d, workspace_bytes:%zu config:%s\n", sm_, alpha, m, n, k, group_size, workspace_bytes, config_str.c_str()); + + if (sm_ >= 75 && sm_ < 80) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if ((sm_ >= 80 && sm_ < 89) || sm_ >= 100) { + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 
89) { +#if ENABLE_FP8 && ((__CUDACC_VER_MAJOR__ < 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) + if constexpr (cutlass::platform::is_same::value) { + ORT_THROW( + "[fpA_intB_gemm] Error: INT4xFP8 GEMM for Ada needs CUDA>=12.4"); + } +#endif + dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + workspace_ptr, workspace_bytes, gemm_config, stream, occupancy); + } else if (sm_ == 90) { + static_assert(!cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "ScaleZeroType must be half for activation=fp8"); + sm90_dispatch_gemm_to_cutlass(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr, + workspace_bytes, gemm_config, stream, occupancy); + } else { + ORT_THROW("[fpA_intB_gemm] Error:Arch unsupported for CUTLASS mixed type GEMM"); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, + char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) { + dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, + (ScaleZeroType const*)weight_scales, (ScaleZeroType const*)weight_zero_points, (BiasType const*)biases, + alpha, (OutputType*)C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); + } else { + ORT_THROW("Overload with scale, zero and group size only supported for fine grained bias template."); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases, + void* C, 
int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, + const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + gemm(A, B, weight_scales, weight_zero_points, biases, 1.f, C, m, n, k, group_size, gemmConfig, workspace_ptr, + workspace_bytes, stream); +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) { + dispatch_to_arch((ActivationType const*)A, (WeightType const*)B, + (ScaleZeroType const*)weight_scales, nullptr, nullptr, alpha, (OutputType*)C, m, n, k, k, gemmConfig, + workspace_ptr, workspace_bytes, stream, nullptr); + } else { + ORT_THROW("Overload with scale only (and no group size) only supported for per column scaling."); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm( + void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k, + tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream); +} + +template +std::vector +CutlassFpAIntBGemmRunner::getConfigs() const { + static constexpr bool is_weight_only = !std::is_same::value; + tkc::CutlassGemmConfig::CandidateConfigTypeParam config_type_param = tkc::CutlassGemmConfig::CandidateConfigTypeParam::HOPPER; + if (is_weight_only) { + config_type_param = static_cast( + config_type_param | tkc::CutlassGemmConfig::CandidateConfigTypeParam::WEIGHT_ONLY); + } + std::vector candidateConfigs = get_candidate_configs(sm_, SPLIT_K_LIMIT, config_type_param); + return 
candidateConfigs; +} + +template +size_t +CutlassFpAIntBGemmRunner::getWorkspaceSize( + int const m, int const n, int const /*k*/) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // For Hopper, we have to allocate large memory size in case for stream-K + if (sm_ == 90) { + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892 + // The above lines says sk_tiles = output_tiles - (static_cast(output_tiles / ctas_per_wave) - 1) * + // ctas_per_wave This means sk_tiles is at most 2 * ctas_per_wave, which is 2 * multi_processor_count_ + int const max_sk_tiles = 2 * multi_processor_count_; + + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L939 + // The above line says uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units); + // That means sk_units is at most ctas_per_sk_wave, which is multi_processor_count_ + int const max_sk_units = multi_processor_count_; + + // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L505 + // The above lines scales sk_tiles by the factor of static_cast(sk_units / sk_tiles + 2) + // That means the final sk_tiles is at most 2 * max_sk_tiles + max_sk_units; + int const max_sk_tiles_with_separate_reduction = 2 * max_sk_tiles + max_sk_units; + + return static_cast( + max_sk_tiles_with_separate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float)); + } + // These are the min tile sizes for each config, which would launch the maximum number of blocks + int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); + int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); + // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. 
+ return static_cast(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h new file mode 100644 index 0000000000000..432adb20079b6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/numeric/integral_constant.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" + +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h" + +namespace tkc = onnxruntime::llm::cutlass_extensions; + +using namespace cute; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example, +// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine grained +// quanitzation is only supported on Ampere+ GPUs. 
+template +void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.epilogue_schedule) { + case tkc::EpilogueScheduleType::AUTO: + using EpilogueScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>; + sm90_generic_mixed_gemm_kernelLauncher(A, B, weight_scales, + weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, + occupancy); + break; + default: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_epilogue_schedules] epilogue schedule config is invalid for " + "mixed type GEMM."); + break; + } +} + +/* + 1x1x1 cluster shape is are supported for any tile shape. + + 2x1x1 cluster shape is only supported for when the M tile is at least 128. + + 1x2x1 cluster shape is only supported when the N tile is at least 128. + + 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. + + We make the above restrictions to improve compilation speed in TRT-LLM, by pruning kernels + that may not be very useful in practice. 
+ */ +template +constexpr bool are_tile_shapes_supported() { + [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); + [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); + constexpr int cga_m = get<0>(ClusterShape{}); + constexpr int cga_n = get<1>(ClusterShape{}); + + if constexpr (cga_m == _1{} && cga_n == _1{}) { + return true; + } else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) { + return true; + } else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) { + return true; + } else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) { + return true; + } else { + return false; + } +} + +template +void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + constexpr bool tile_shapes_supported = are_tile_shapes_supported(); + + if constexpr (tile_shapes_supported) { + switch (gemm_config.mainloop_schedule) { + case tkc::MainloopScheduleType::AUTO: + using KernelScheduleType = cute::conditional_t(CTAShape{}) == Int<64>{}, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::gemm::KernelTmaWarpSpecializedCooperative>; + sm90_dispatch_epilogue_schedules(A, B, weight_scales, weight_zero_points, + biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] mainloop schedule config is invalid " + "for " + "mixed type GEMM."); + break; + } + } else { + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_mainloop_schedules] Unsupported CTA and Cluster shapes for " + "mixed type GEMM."); + } +} + 
+template +void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + switch (gemm_config.cluster_shape) { + case tkc::ClusterShape::ClusterShape_1x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x1x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_1x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::ClusterShape::ClusterShape_2x2x1: + sm90_dispatch_mainloop_schedules>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, + k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + default: + ORT_THROW("[fpA_intB_gemm][dispatch_CGA_config] Config is invalid for mixed type GEMM."); + break; + } +} + +template +void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales, + ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n, + int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, + cudaStream_t stream, int* occupancy = nullptr) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + // Note that SIMT 
configs are omitted here since they are not supported for fpA_intB. + // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best + // for mixed type gemms. + + constexpr int Ktile = 128 / sizeof(ActivationType); + using _Ktile = Int; + switch (gemm_config.tile_config_sm90) { + case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x16x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x32x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case 
tkc::CutlassTileConfigSM90::CtaShape128x64x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x128x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::CtaShape128x256x128B: + sm90_dispatch_gemm_config>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, + gemm_config, workspace, workspace_bytes, stream, occupancy); + break; + case tkc::CutlassTileConfigSM90::Undefined: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config undefined."); + break; + case tkc::CutlassTileConfigSM90::ChooseWithHeuristic: + ORT_THROW( + "[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] gemm config should have already been set by " + "heuristic."); + break; + default: + ORT_THROW("[fpA_intB_gemm][sm90_dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM."); + break; + } +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu new file mode 100644 index 0000000000000..468d53f336e55 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_1.generated.cu @@ -0,0 +1,264 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, 
onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, 
onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const 
__nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, 
cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, 
cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedPingpong, cutlass::epilogue::TmaWarpSpecialized> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu 
b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu new file mode 100644 index 0000000000000..0156c83840b09 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_gemm_launcher_2.generated.cu @@ -0,0 +1,516 @@ +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl" +namespace onnxruntime::llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const 
half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + 
+ +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, 
cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, 
onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, 
cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const half*, const uint8_t*, const half*, const half*, const half*, const float, +half*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void 
sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, 
const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const 
cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, 
+cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, 
onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const cutlass::uint4b_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, 
__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<16>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<32>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void 
sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<64>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, 
onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<128>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const 
__nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<1>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, 
const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +template void sm90_generic_mixed_gemm_kernelLauncher<__nv_bfloat16, uint8_t, __nv_bfloat16, __nv_bfloat16, __nv_bfloat16, +cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, onnxruntime::llm::cutlass_extensions::EpilogueOpBias, +cute::Shape, cute::Int<256>, cute::Int<64>>, cute::Shape, cute::Int<2>, cute::Int<1>>, +cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative> ( +const __nv_bfloat16*, const uint8_t*, const __nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, const float, +__nv_bfloat16*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); + + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h new file mode 100644 index 0000000000000..594ae1079c06e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/weight_only_quant_op.h" +#include + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, + onnxruntime::llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, + cudaStream_t stream, int* occupancy = nullptr); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl new file mode 100644 index 0000000000000..779ff88455703 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl @@ -0,0 +1,282 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" + +#include "contrib_ops/cuda/llm/cutlass_extensions/compute_occupancy.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/epilogue_helpers.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h" + +#include "contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/collective_builder_interleaved.hpp" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "core/common/common.h" +#include "contrib_ops/cuda/llm/common/cuda_runtime_utils.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/cutlass_heuristic.h" +#include "contrib_ops/cuda/llm/cutlass_type_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h" + +namespace tk = onnxruntime::llm::common; +namespace tkc = onnxruntime::llm::cutlass_extensions; + +using namespace cute; + +namespace onnxruntime::llm { +namespace kernels { +namespace cutlass_kernels { + +template +#ifdef COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher( + ActivationType const* A, WeightType const* B, + ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases, + float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig /*gemm_config*/, + char* workspace, size_t workspace_bytes, cudaStream_t 
stream, int* occupancy) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + using CutlassActivationType = typename CudaToCutlassTypeAdapter::type; + + if constexpr (!should_filter_tma_warp_specialized_gemm_problem_shape_v) { + using CutlassWeightType = typename CudaToCutlassTypeAdapter::type; + + using CutlassScaleZeroType = typename CudaToCutlassTypeAdapter::type; + using CutlassBiasType = typename CudaToCutlassTypeAdapter::type; + using CutlassOutputType = typename CudaToCutlassTypeAdapter::type; + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Activation type must be bfloat16, half, FP8"); + + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + "Weight type must be fp8, uint8_t or uint4_t"); + + static_assert(!std::is_same_v || + std::is_same_v, + "Scale/Zero type must be half for fp8 activation"); + + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + // This example manually swaps and transposes, so keep transpose of input layouts + using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; + using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + + using ElementZero = CutlassScaleZeroType; + using ElementScale = CutlassScaleZeroType; + + // C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast. 
+ using LayoutBias = cutlass::layout::RowMajor; + constexpr int AlignmentBias = 128 / cutlass::sizeof_bits::value; + + // D matrix configuration + using LayoutOutput = cutlass::layout::RowMajor; + constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ElementCompute = float; // Element type for epilogue computation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using KernelSchedule = MainloopScheduleType; + using EpilogueSchedule = EpilogueScheduleType; + + // Shrink the N dimension to match CTA_N if needed + constexpr int epi_tile_M = cute::min(shape<0>(TileShape{}), 128); // 64 or 128 + constexpr int epi_tile_N = cute::min(shape<1>(TileShape{}), 32); // Allow this to be 16 for some small N tiles. + using EpilogueTileType = cute::Shape, cute::Int>; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + static_assert(std::is_same_v, ""); + using EVT_bias_addition = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, // alpha * acc + bias + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch, // acc + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, CutlassBiasType, CutlassBiasType, + Stride<_1, _0, _0>, + AlignmentBias> // bias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use the explicit swap + transpose trick + // Void C since we don't use it. Prevents smem allocation. 
+ void, typename cutlass::layout::LayoutTranspose::type, AlignmentBias, CutlassOutputType, + typename cutlass::layout::LayoutTranspose::type, AlignmentOutput, EpilogueSchedule, + EVT_bias_addition>::CollectiveOp; + + using PackedScaleZero = cute::tuple; + using PackedScale = cute::tuple; + using ElementBCollectiveInfo = std::conditional_t; + + // We swap A and B operands to the builder here + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilderInterleaved< + ArchTag, + OperatorClass, ElementBCollectiveInfo, LayoutB_Transpose, AlignmentB, CutlassActivationType, + LayoutA_Transpose, AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using TileScheduler = cute::conditional_t(CTAShape{}) == Int<64>{}, cutlass::gemm::PersistentScheduler, + cutlass::gemm::StreamKScheduler>; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + if (occupancy != nullptr) { + *occupancy = onnxruntime::llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename GemmKernel::StrideA; + using StrideB = typename GemmKernel::StrideB; + using StrideC = typename GemmKernel::StrideC; + using StrideD = typename GemmKernel::StrideD; + using StrideS = typename CollectiveMainloop::StrideScale; + + if (weight_scales == nullptr) { + ORT_THROW("Weight scales must always be set to a non-null value."); + } + + if constexpr (cutlass::isFinegrained(QuantOp)) { + int cta_shape_k = cute::size<2>(TileShape{}); + if (group_size % cta_shape_k != 0) { + std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + if constexpr (QuantOp == 
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY) { + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero pointer must be a nullptr for scale only fine grained"); + } + } else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) { + if (weight_zero_points == nullptr) { + ORT_THROW("Weight zero pointer must be valid for scale and bias fine grained"); + } + } + } else { + if (group_size != k) { + ORT_THROW("Invalid group size for per column scaling kernels."); + } + + if (weight_zero_points != nullptr) { + ORT_THROW("Weight zero-points must be null when running per column scaling"); + } + } + + auto cutlass_scale_k = (k + group_size - 1) / group_size; + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1)); + + // Use the output as the bias to avoid making a tma descriptor with a nullptr. 
+ auto output_as_bias_type = reinterpret_cast(C); + + typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, + {reinterpret_cast(B), stride_B, + reinterpret_cast(A), stride_A, + reinterpret_cast(weight_scales), stride_S, + group_size, reinterpret_cast(weight_zero_points)}, + {{}, output_as_bias_type, stride_D, reinterpret_cast(C), stride_D}}; + + args.epilogue.thread = { + {alpha}, // alpha args + {}, // accumulator + {reinterpret_cast(biases), CutlassBiasType(0.f)}, // bias args + {} // end multiply_add + }; + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + ORT_LLM_LOG_ERROR("[fpA_intB_gemm] given workspace size insufficient."); + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement)); + ORT_THROW("[fpA_intB_gemm] ", err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(init_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. 
Error: " + std::string(cutlassGetStatusString(run_status)); + ORT_THROW("[fpA_intB_gemm] " + err_msg); + } + } else { + std::stringstream ss; + ss << "[fpA_intB_gemm] Config (" << (int64_t)cute::size<0>(CTAShape{}) << "," + << (int64_t)cute::size<1>(CTAShape{}) << "," << (int64_t)cute::size<2>(CTAShape{}) << ") (" + << (int64_t)cute::size<0>(ClusterShape{}) << "," << (int64_t)cute::size<1>(ClusterShape{}) << "," + << (int64_t)cute::size<2>(ClusterShape{}) << ") not compiled with FAST_BUILD."; + + ORT_THROW(ss.str()); + } +} +#else // COMPILE_HOPPER_TMA_GEMMS +void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const*, WeightType const*, + ScaleZeroType const*, ScaleZeroType const*, BiasType const*, + float const, OutputType*, int, int, int, int const, tkc::CutlassGemmConfig, + char*, size_t, cudaStream_t, int*) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + ORT_THROW("[fpA_intB_gemm] Please recompile with support for hopper by passing 90a-real as an arch."); +} +#endif // COMPILE_HOPPER_TMA_GEMMS + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu new file mode 100644 index 0000000000000..55beb8b9ca029 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.cu @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +__global__ void transposeScaleKernel( + const T* scale, + T* transposed_scale, + int n, int k_blocks) { + // Calculate the output matrix coordinates [row, col] for this thread + // The output matrix has dimensions [k_blocks, n] + int out_row = blockIdx.y * blockDim.y + threadIdx.y; + int out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n] + if (out_row < k_blocks && out_col < n) { + int in_row = out_col; + int in_col = out_row; + int64_t input_offset = static_cast(in_row) * k_blocks + in_col; + int64_t output_offset = static_cast(out_row) * n + out_col; + T scale_val = scale[input_offset]; + transposed_scale[output_offset] = scale_val; + } +} + +template +void launch_transpose_scale_kernel( + cudaStream_t stream, + const T* scale, + T* transposed_scale, + int n, int k_blocks) { + constexpr int BLOCK_SIZE = 16; + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + dim3 gridDim( + (n + blockDim.x - 1) / blockDim.x, // Grid size in x covers output columns (n) + (k_blocks + blockDim.y - 1) / blockDim.y // Grid size in y covers output rows (k_blocks) + ); + + transposeScaleKernel<<>>( + scale, + transposed_scale, + n, + k_blocks); +} + +// CUDA kernel to compute -scale * zero_point and transpose +// Each thread computes one element of the OUTPUT matrix (shape [k_blocks, n]) +template +__global__ void computeScaledZeroPointAndTransposeKernel( + const Z* zero_point, // Input zero_point matrix [n, k_blocks] or [n, (k_blocks + 1) / 2] if packed int4 + const T* transposed_scale, // transposed scale [k_blocks, n] + T* scaled_zero_point, // Output matrix [k_blocks, n] + int n, // Rows of input matrices + int k_blocks, // Columns of input matrices + float default_zero_point) { + // Calculate the output 
matrix coordinates [row, col] for this thread + // The output matrix has dimensions [k_blocks, n] + int out_row = blockIdx.y * blockDim.y + threadIdx.y; + int out_col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check bounds to ensure we are within the output matrix dimensions [k_blocks, n] + if (out_row < k_blocks && out_col < n) { + int in_row = out_col; + int in_col = out_row; + int64_t output_offset = static_cast(out_row) * n + out_col; + + // Perform the computation: scaled_zero_point[out_row, out_col] = -scale[in_row, in_col] * zero_point[in_row, in_col] + T scale_val = transposed_scale[output_offset]; + float zero_point_val; + if (zero_point != nullptr) { + if constexpr (is_zero_point_int4_packed) { // zero point is 4 bit, and two elements are packed into one byte. + int64_t packed_row_size = (k_blocks + 1) / 2; + int64_t packed_zp_offset = static_cast(in_row) * packed_row_size + in_col / 2; + uint8_t packed_zp = zero_point[packed_zp_offset]; + zero_point_val = static_cast((in_col & 0x01) ? 
(packed_zp >> 4) : (packed_zp & 0x0f)); + } else { + int64_t input_offset = static_cast(in_row) * k_blocks + in_col; + zero_point_val = static_cast(zero_point[input_offset]); + } + } else { + zero_point_val = default_zero_point; + } + + float result = static_cast(scale_val) * (-zero_point_val + default_zero_point); + scaled_zero_point[output_offset] = static_cast(result); + } +} + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point) { + assert(zero_point != nullptr); + constexpr int BLOCK_SIZE = 16; + dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE); + dim3 gridDim( + (n + blockDim.x - 1) / blockDim.x, // Grid size in x covers output columns (n) + (k_blocks + blockDim.y - 1) / blockDim.y // Grid size in y covers output rows (k_blocks) + ); + + computeScaledZeroPointAndTransposeKernel<<>>( + zero_point, + transposed_scale, + scaled_zero_point, + n, + k_blocks, + default_zero_point); +} + +// Explicit instantiations: +template void launch_transpose_scale_kernel( + cudaStream_t stream, + const half* scale, + half* transposed_scale, + int n, int k_blocks); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const half* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// zero point is 4 bits packed. 
+template void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const uint8_t* zero_point, + const half* transposed_scale, + half* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +// CUDA kernel to unpack uint4, transpose, and pack into int8 directly +__global__ void unpack_transpose_pack_uint4_to_int8_kernel_v2( + const unsigned char* __restrict__ packed_weight, + signed char* __restrict__ packed_transposed_weight, + int n, // original matrix rows + int k) // original matrix columns +{ + // The output 'packed_transposed_weight' has dimensions k x (n/2) bytes. + // Each thread processes one byte in the output. + int out_flat_idx = blockIdx.x * blockDim.x + threadIdx.x; + + // Total number of bytes in the output packed_transposed_weight matrix + int total_output_bytes = k * (n / 2); + + if (out_flat_idx < total_output_bytes) { + constexpr signed char default_zero_point = 8; + + // Calculate row and column in the output packed_transposed_weight matrix (k x n/2) + // out_row_packed: row in the k dimension of the output (0 to k-1) + // out_col_packed: column in the n/2 dimension of the output (0 to n/2 - 1) + const int out_row_packed = out_flat_idx / (n / 2); + const int out_col_packed = out_flat_idx % (n / 2); + + // These two int8 values will form the current output packed byte: + // val_0: corresponds to original_unpacked[2 * out_col_packed][out_row_packed] + // val_1: corresponds to original_unpacked[2 * out_col_packed + 1][out_row_packed] + + // --- Retrieve val_0 --- + // Its original (unpacked) row index was '2 * out_col_packed' + const int r_orig_0 = 2 * out_col_packed; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_0 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_0 resides + const int packed_weight_idx_0 = r_orig_0 * (k / 2) + c_orig_0 / 2; + + unsigned char packed_data_0 = packed_weight[packed_weight_idx_0]; + signed char val_0; + if 
((c_orig_0 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_0 = (signed char)(packed_data_0 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_0 = (signed char)(packed_data_0 >> 4) - default_zero_point; + } + + // --- Retrieve val_1 --- + // Its original (unpacked) row index was '2 * out_col_packed + 1' + const int r_orig_1 = 2 * out_col_packed + 1; + // Its original (unpacked) column index was 'out_row_packed' + const int c_orig_1 = out_row_packed; + + // Determine the flat index in the input 'packed_weight' (n x k/2) where val_1 resides + const int packed_weight_idx_1 = r_orig_1 * (k / 2) + c_orig_1 / 2; + + unsigned char packed_data_1 = packed_weight[packed_weight_idx_1]; + signed char val_1; + if ((c_orig_1 % 2) == 0) { // If original column is even, it's the lower 4 bits + val_1 = (signed char)(packed_data_1 & 0x0f) - default_zero_point; + } else { // If original column is odd, it's the upper 4 bits + val_1 = (signed char)(packed_data_1 >> 4) - default_zero_point; + } + + // Pack the two signed char values (now 8-bit, but we only care about their 4 LSBs) + // back into a single byte for the output. 
+ packed_transposed_weight[out_flat_idx] = (unsigned char)((val_0 & 0x0f) | ((val_1 & 0x0f) << 4)); + } +} + +void unpack_uint4_transposed_to_int8_direct_cuda( + cudaStream_t stream, void* packed_transposed_weight, const void* packed_weight, int n, int k) { + int total_output_bytes = k * (n / 2); + int threads_per_block = 256; + int num_blocks = (total_output_bytes + threads_per_block - 1) / threads_per_block; + + unpack_transpose_pack_uint4_to_int8_kernel_v2<<>>( + (const unsigned char*)packed_weight, + (signed char*)packed_transposed_weight, + n, + k); +} + +__global__ void transpose_uint8_matrix_and_convert_to_int8_kernel( + const uint8_t* __restrict__ input, // shape: (n, k) + int8_t* __restrict__ output, // shape: (k, n) + int n, int k) { + + int row = blockIdx.y * blockDim.y + threadIdx.y; // index in n + int col = blockIdx.x * blockDim.x + threadIdx.x; // index in k + + if (row < n && col < k) { + int input_idx = row * k + col; + int output_idx = col * n + row; + output[output_idx] = static_cast(static_cast(input[input_idx]) - 128); + } +} + +void transpose_uint8_matrix_and_convert_to_int8( + cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k) { + + dim3 blockDim(16, 16); + dim3 gridDim((k + blockDim.x - 1) / blockDim.x, + (n + blockDim.y - 1) / blockDim.y); + + transpose_uint8_matrix_and_convert_to_int8_kernel<<>>(input, output, n, k); +} + + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h new file mode 100644 index 0000000000000..61023b62d8a49 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include + +// Convert scale and zero_point from MatMulNBits to the format required by fpA_intB_gemm or fpA_intB_gemv kernels. +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +void launch_scaled_zero_point_kernel( + cudaStream_t stream, + const Z* zero_point, + const T* transposed_scale, + T* scaled_zero_point, + int n, int k_blocks, float default_zero_point); + +template +void launch_transpose_scale_kernel( + cudaStream_t stream, + const T* scale, + T* transposed_scale, + int n, int k_blocks); + +// Transpose uint4 weight matrix and add default zero points then pack as int8. +void unpack_uint4_transposed_to_int8_direct_cuda(cudaStream_t stream, + void* packed_transposed_weight, + const void* packed_weight, + int n, + int k); + +// Transpose uint8 weight matrix and add default zero points as int8. +void transpose_uint8_matrix_and_convert_to_int8(cudaStream_t stream, + int8_t* output, // shape: (k, n) + const uint8_t* input, // shape: (n, k) + int n, int k); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc new file mode 100644 index 0000000000000..8112562623791 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.cc @@ -0,0 +1,100 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/workspace.h" + +using namespace onnxruntime::llm::common; +using namespace onnxruntime::llm::kernels::cutlass_kernels; + +namespace onnxruntime::llm::kernels::weight_only { + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic( + int m, int n, int k, + WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream) { + int const originalN = mQuantBits == 8 ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + half* actPtr = reinterpret_cast(workspace); + void* weightPtr = nextWorkspacePtr(reinterpret_cast(actPtr), m * k * sizeof(half)); + half* inputScalesPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(weightPtr), n * k * sizeof(float))); + half* zerosPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(inputScalesPtr), k * originalN * sizeof(half) / mGroupSize)); + half* biasesPtr = reinterpret_cast( + nextWorkspacePtr(reinterpret_cast(zerosPtr), k * originalN * sizeof(half) / mGroupSize)); + half* outputPtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(biasesPtr), n * sizeof(half))); + char* workspacePtr = reinterpret_cast(nextWorkspacePtr(reinterpret_cast(outputPtr), m * originalN * sizeof(half))); + + if (!mHasZeros) { + zerosPtr = nullptr; + } + + if (!mHasBiases) { + biasesPtr = nullptr; + } + + if (tactic.enableCudaKernel) { + // run CUDA kernel + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + 
onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + actPtr, pre_quant_scale_ptr, weightPtr, + inputScalesPtr, zerosPtr, + biasesPtr, outputPtr, + alpha, m, originalN, k, mGroupSize, mCudaKernelType, apply_alpha_in_advance); + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(mArch, params, stream); + } else { + // run CUTLASS kernel + int const wsSize = mRunner->getWorkspaceSize(m, originalN, k); + if (mQuantBits == 8) { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, + m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } else { + mRunner->gemm(actPtr, reinterpret_cast(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, + outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream); + } + } +} + +void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k) { + // Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16) + int const originalN = mQuantBits == 8 ? 
n * FP16_INT8_RATIO : n * FP16_INT4_RATIO; + std::vector workspaces = { + maxM * k * sizeof(half), // A + k * n * sizeof(float), // B + k * originalN * sizeof(half) / mGroupSize, // scales + k * originalN * sizeof(half) / mGroupSize, // zeros + originalN * sizeof(half), // biases + maxM * originalN * sizeof(half), // C + mRunner->getWorkspaceSize(maxM, originalN, k) // workspace + }; + size_t bytes = calculateTotalWorkspaceSize(workspaces.data(), workspaces.size()); + setTmpWorkspaceSizeInBytes(bytes); +} + +std::vector WeightOnlyGroupwiseQuantGemmPluginProfiler::getTactics( + int /*m*/, int /*n*/, int /*k*/) const { + return mRunner->getConfigs(); +} + +bool WeightOnlyGroupwiseQuantGemmPluginProfiler::checkTactic(int m, int /*n*/, int /*k*/, Config const& tactic) const { + // stop to profile Cuda kernel for m >= 16 + if (tactic.enableCudaKernel) { + return m < 16; + } + return true; +} + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h new file mode 100644 index 0000000000000..7be77fa43d85d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h @@ -0,0 +1,86 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +#include +#include +#include +#include +#include +#include + +using WeightOnlyGemmRunner = onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; +using KernelType = onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + +namespace onnxruntime::llm::kernels::weight_only { +enum class WeightTypeId { + INT8 = 1, + INT4 = 2, +}; + +constexpr int32_t FP16_BITS = 16; +constexpr int32_t INT8_BITS = 8; +constexpr int32_t INT4_BITS = 4; +constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS; +constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS; + +class WeightOnlyGroupwiseQuantGemmPluginProfiler + : public GemmPluginProfiler { + public: + using Config = onnxruntime::llm::cutlass_extensions::CutlassGemmConfig; + + void setQuant(int bits, bool has_bias, bool has_zeros) { + mQuantBits = bits; + mHasBiases = has_bias; + mHasZeros = has_zeros; + } + + void setGroupSize(int groupSize) { + mGroupSize = groupSize; + } + + void setCudaKernelType(KernelType cudaKernelType, int arch) { + mCudaKernelType = cudaKernelType; + mArch = arch; + } + + protected: + void runTactic(int m, int n, int k, Config const& tactic, + char* workspace, cudaStream_t const& stream) override; + + void computeTmpSize(size_t maxM, size_t n, size_t k) override; + + std::vector getTactics(int m, int n, int k) const override; + + bool checkTactic(int m, int n, int k, Config const& tactic) const override; + + private: + bool mHasBiases; + bool mHasZeros; + int mQuantBits; + int mGroupSize; + KernelType mCudaKernelType; + int mArch; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h new file mode 
100644 index 0000000000000..4fa64ef329c57 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/details.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/cutlass_extensions/interleaved_numeric_conversion.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct kernel_type_traits; +#define KERNEL_TYPE_TRAITS_REGISTRY(KT, _isGroupwise, _isInt4) \ + template <> \ + struct kernel_type_traits { \ + static constexpr bool isGroupwise = _isGroupwise; \ + static constexpr bool isInt4 = _isInt4; \ + }; + +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4PerChannel, false, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8PerChannel, false, false); +KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4PerChannel, false, true); +#undef KERNEL_TYPE_TRAITS_REGISTRY + +// A 
generic memory iterator used for coalesced global memory access with optional enablement. +// Template parameters: +// Enable: If false, disables loading/storing. +// TVec: Vectorized type (e.g., float4, half2). +// Strided: Number of rows in a tile. +// Continuous: Number of contiguous vector elements to load/store at once. +// Scalar type (e.g., half). +template +class GMemIterator { + public: + __device__ __forceinline__ GMemIterator(T* addr, int offset, int step, int stride) + : addr_(Enable ? (addr + offset) : nullptr), step_(step), stride_(stride) { + } + + __device__ __forceinline__ void load(void* dst, int iter, int ii = 0) { + if constexpr (Enable) { +#pragma unroll + for (int jj = 0; jj < Continuous; ++jj) { + reinterpret_cast(dst)[jj] = reinterpret_cast(addr_ + iter * step_ + ii * stride_)[jj]; + } + } + } + + private: + T* addr_; + int step_; + int stride_; +}; + +struct FP16DetailsA { + using Type = half; + using Type2 = half2; + static constexpr int kElemBits = 16; +}; + +struct BF16DetailsA { + using Type = __nv_bfloat16; + using Type2 = __nv_bfloat162; + static constexpr int kElemBits = 16; +}; + +struct Int8DetailsW { + static constexpr int kElemBits = 8; +}; + +struct Int4DetailsW { + static constexpr int kElemBits = 4; +}; + +template +struct ColumnMajor { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsA::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i; + } + }; +}; + +template +struct ColumnMajorInterleavedForHopper { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / 
TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 1; + + static constexpr int kTypeFactor = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kTypeFactor * kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template +struct ColumnMajorInterleaved { + using DetailsA = TypeDetailsA; + using DetailsW = TypeDetailsW; + using AccessTypeA = float4; + using AccessTypeW = int4; + static constexpr int kAccessSize = 128; + static constexpr int kStepK = kAccessSize / TypeDetailsW::kElemBits; + static constexpr int kTileSize = TileSizeK; + static constexpr int kInterleave = 128 * 8 / (TileSizeK * TypeDetailsW::kElemBits); + + // constants for mapper + static constexpr int kElementGroupSizeA = TileSizeK / 32; + static constexpr int kElementGroupSizeW = kInterleave * kElementGroupSizeA; + static constexpr int kGroupOffsetA = 4 * kElementGroupSizeA; + + struct Mapper { + __device__ __forceinline__ int operator()(int i) { + return i % kElementGroupSizeA + (i % kGroupOffsetA) / kElementGroupSizeA * kElementGroupSizeW + i / kGroupOffsetA * kElementGroupSizeA; + } + }; +}; + +template class LayoutDetails_, + bool UseInterleavedConverter, int TileSizeK> +struct KernelDetails { + using TypeDetailsA = TypeDetailsA_; + using TypeDetailsW = TypeDetailsW_; + using LayoutDetails = LayoutDetails_; + using AccessTypeA = typename LayoutDetails::AccessTypeA; + using AccessTypeW = typename LayoutDetails::AccessTypeW; + static constexpr int kWarpSize = 32; + static constexpr int kStepK = LayoutDetails::kStepK; + static constexpr int kAccessNumA = kStepK * 
TypeDetailsA::kElemBits / (sizeof(AccessTypeA) * 8); + static constexpr int kAccessNumW = kStepK * TypeDetailsW::kElemBits / (sizeof(AccessTypeW) * 8); + static constexpr int kInterleave = LayoutDetails::kInterleave; + static constexpr int kThreadsPerInterleavedTile = LayoutDetails::kTileSize / kStepK; + static constexpr int kElemsPerByteW = 8 / TypeDetailsW::kElemBits; + static constexpr bool kUseInterleavedConverter = UseInterleavedConverter; +}; + +template +struct I2FConverter; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::FastInterleavedAndBiasedNumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct I2FConverter { + static_assert(std::is_same_v || std::is_same_v); + static_assert(WElemBits == 4 || WElemBits == 8); + using CutlassAType = std::conditional_t, cutlass::half_t, cutlass::bfloat16_t>; + using CutlassWType = std::conditional_t; + static constexpr int kConvertCount = 32 / WElemBits; + using Converter = cutlass::NumericArrayConverter; + using CvtSrcType = typename Converter::source_type; + using CvtResType = typename Converter::result_type; + + template + __device__ __forceinline__ static void convert(void* src, void* dst) { + static_assert(N % kConvertCount == 0); +#pragma unroll + for (int ii = 0; ii < N / kConvertCount; ++ii) { + reinterpret_cast(dst)[ii] = 
Converter::convert(reinterpret_cast(src)[ii]); + } + } +}; + +template +struct ConverterWrapper { + using TypeDetailsA = typename Details::TypeDetailsA; + using TypeDetailsW = typename Details::TypeDetailsW; + static constexpr bool kUseInterleavedConverter = Details::kUseInterleavedConverter; + using Converter = I2FConverter; +}; + +template +void select_gs(Params& params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h new file mode 100644 index 0000000000000..ff1a28661184f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" +#include "core/common/common.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +template +struct MathWrapper { +}; + +template <> +struct MathWrapper { + using Type = typename FP16DetailsA::Type; + using Type2 = typename FP16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { + return __half2half2(v); + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { + return __hfma2(a, b, c); + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { + return __hmul2(a, b); + } + + // __device__ __forceinline__ static Type2 deq2(Type2 const& weight, Type2 const& scale, Type2 const& zero_point) { + // return __hmul2(__hsub2(weight, zero_point), scale); + // } +}; + +template <> +struct MathWrapper { + using Type = typename BF16DetailsA::Type; + using Type2 = typename BF16DetailsA::Type2; + + __device__ __forceinline__ static Type2 to_vec2(Type const& v) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __bfloat162bfloat162(v); +#else + uint32_t val = 0; + Type2 ret = reinterpret_cast(val); + return ret; +#endif + } + + __device__ __forceinline__ static Type2 fma2(Type2 const& a, Type2 const& b, Type2 const& c) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hfma2(a, b, c); +#else + return to_vec2(static_cast(0.f)); +#endif + } + + __device__ __forceinline__ static Type2 mul2(Type2 const& a, Type2 const& b) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmul2(a, b); +#else + return to_vec2(static_cast(0.f)); +#endif + } +}; + +template +__device__ __forceinline__ void apply_scale(void* act, void* act_scale) { + using Type2 = typename MathWrapper::Type2; + static_assert(K % 2 == 0); + [[maybe_unused]] static constexpr int VecK = K / 2; + if constexpr 
(Enable) { + Type2* pa = reinterpret_cast(act); + Type2* pb = reinterpret_cast(act_scale); +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int k = 0; k < VecK; ++k) { + pa[m * VecK + k] = MathWrapper::mul2(pa[m * VecK + k], pb[k]); + } + } + } +} + +template +__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros, float alpha) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + using Converter = typename ConverterWrapper
::Converter; + static_assert(K % 2 == 0); + static constexpr int VecK = K / 2; +#pragma unroll + for (int n = 0; n < N; ++n) { + Converter::convert(reinterpret_cast(quantized_w) + n * K / Details::kElemsPerByteW, + reinterpret_cast(w) + n * K); + Type2 vec_scale, vec_zero; + if constexpr (ApplyAlphaInAdvance) { + // For W4A8, we assume scales/zero is always half data type, no matter activation dtype is bf16 or fp16 + Type scales_ = static_cast(reinterpret_cast(scales)[n]) * alpha; + vec_scale = MathWrapper::to_vec2(scales_); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2( + static_cast(reinterpret_cast(zeros)[n]) * alpha); + } + } else { + vec_scale = MathWrapper::to_vec2(reinterpret_cast(scales)[n]); + vec_zero = MathWrapper::to_vec2(static_cast(0.f)); + if constexpr (EnableZero) { + vec_zero = MathWrapper::to_vec2(reinterpret_cast(zeros)[n]); + } + } +#pragma unroll + for (int k = 0; k < VecK; ++k) { + reinterpret_cast(w)[n * VecK + k] = MathWrapper::fma2( + reinterpret_cast(w)[n * VecK + k], vec_scale, vec_zero); + } + } +} + +template +__device__ __forceinline__ void pack_to_vec2(void* dst, void* src, int n) { + using Type = typename MathWrapper::Type; + typename Details::LayoutDetails::Mapper mapper; + int n0 = n & ~0x1, n1 = n & 0x1; + for (int k = 0; k < K; ++k) { + int physical_idx = mapper(k); + reinterpret_cast(dst)[n0 * K + k * 2 + n1] = reinterpret_cast(src)[physical_idx]; + } +} + +template +__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act) { + using Type = typename MathWrapper::Type; + using Type2 = typename MathWrapper::Type2; + static_assert(N % 2 == 0); + static constexpr int VecN = N / 2; +#pragma unroll + for (int m = 0; m < M; ++m) { +#pragma unroll + for (int n = 0; n < VecN; ++n) { +#pragma unroll + for (int k = 0; k < K; ++k) { + reinterpret_cast(acc)[m * VecN + n] = MathWrapper::fma2( + reinterpret_cast(w_pack2)[n * K + k], + 
MathWrapper::to_vec2(reinterpret_cast(act)[m * K + k]), + reinterpret_cast(acc)[m * VecN + n]); + } + } + } +} + +template +__device__ __forceinline__ T warp_reduce_sum(T& val) { + val += __shfl_xor_sync(~0, val, 16); + val += __shfl_xor_sync(~0, val, 8); + if (Interleave != 2 && Interleave != 4) + val += __shfl_xor_sync(~0, val, 4); + if (Interleave != 4) + val += __shfl_xor_sync(~0, val, 2); + val += __shfl_xor_sync(~0, val, 1); + return val; +} + +template +__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha) { + using Type = typename MathWrapper::Type; + static constexpr int Interleave = Details::kInterleave; + static constexpr int ThreadsPerInterleavedTile = Details::kThreadsPerInterleavedTile; + static constexpr int WarpSize = Details::kWarpSize; + static constexpr int WarpNum = Threads / WarpSize; + static_assert(Threads % WarpSize == 0); + __shared__ float shmem[CtaM * CtaN * Interleave * WarpNum]; + int tid = threadIdx.x; + int warp_id = tid / WarpSize, lane_id = tid % WarpSize; +#pragma unroll + for (int m = 0; m < CtaM; ++m) { +#pragma unroll + for (int n = 0; n < CtaN; ++n) { + float v = static_cast(reinterpret_cast(tile_acc)[m * CtaN + n]); + v = warp_reduce_sum(v); + if (lane_id < Interleave * ThreadsPerInterleavedTile && lane_id % ThreadsPerInterleavedTile == 0) { + shmem[warp_id * CtaM * CtaN * Interleave + m * CtaN * Interleave + n * Interleave + lane_id / ThreadsPerInterleavedTile] = v; + } + } + } + __syncthreads(); +#pragma unroll + for (int ii = tid; ii < CtaM * CtaN * Interleave; ii += Threads) { + int m = ii / (CtaN * Interleave), n = ii % (CtaN * Interleave); + float val = 0.f, v_bias = 0.f; + if constexpr (EnableBias) { + v_bias = static_cast(reinterpret_cast(bias)[n]); + } +#pragma unroll + for (int jj = 0; jj < WarpNum; ++jj) { + val += shmem[jj * CtaM * CtaN * Interleave + ii]; + } + if constexpr (ApplyAlphaInAdvance) { + reinterpret_cast(out)[m * stride + n] = static_cast(val + 
v_bias); + } else { + reinterpret_cast(out)[m * stride + n] = static_cast(alpha * val + v_bias); + } + } +} + +template +__device__ __forceinline__ void fill(void* tile, T v) { +#pragma unroll + for (int ii = 0; ii < N; ++ii) { + reinterpret_cast(tile)[ii] = v; + } +} + +template +__global__ void kernel(TypeA* act, TypeA* act_scale, uint8_t* weight, TypeA* scales, TypeA* zeros, TypeA* bias, + TypeA* out, float alpha, int m, int n, int k) { + // ArgType ArgName DataType Shape Layout + // input act fp16/bf16 [m, k] RowMajor + // input act_scale fp16/bf16 [1, k] RowMajor + // input weight int4b/int8b [k, n] ColumnMajor or ColumnMajorInterleaved + // input scales fp16/bf16 [k / GroupSize, n] RowMajor + // input zeros fp16/bf16 [k / GroupSize, n] RowMajor + // input bias fp16/bf16 [1, n] RowMajor + // output out fp16/bf16 [m, n] RowMajor + + using AccessTypeA = typename Details::AccessTypeA; + using AccessTypeW = typename Details::AccessTypeW; + + static constexpr bool Mandatory = true; + static constexpr int StepK = Details::kStepK; + static constexpr int CtaK = StepK * Threads; + static_assert(CtaN % 2 == 0); + if constexpr (GroupSize != 0) { + static_assert((CtaK / Details::kInterleave) % GroupSize == 0); + } + + int const origin_k = k, interleaved_k = k * Details::kInterleave; + + int const tile_id_m = blockIdx.x, tile_id_n = blockIdx.y, tid = threadIdx.x; + int const offset_m = tile_id_m * CtaM, interleaved_offset_n = tile_id_n * CtaN; + int const real_offset_n = interleaved_offset_n * Details::kInterleave + ((tid * StepK / Details::LayoutDetails::kTileSize) % Details::kInterleave); + int const real_offset_k = (tid * StepK / (Details::kInterleave * Details::LayoutDetails::kTileSize)) * Details::LayoutDetails::kTileSize + ((tid * StepK) % Details::LayoutDetails::kTileSize); + + GMemIterator act_iterator( + act, offset_m * origin_k + real_offset_k, CtaK / Details::kInterleave, origin_k); + GMemIterator act_scale_iterator( + act_scale, real_offset_k, CtaK / 
Details::kInterleave, 0); + GMemIterator weight_iterator( + weight, + (interleaved_offset_n * interleaved_k + tid * StepK) / Details::kElemsPerByteW, CtaK / Details::kElemsPerByteW, + interleaved_k / Details::kElemsPerByteW); + + GMemIterator scales_iterator( + scales, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + GMemIterator zeros_iterator( + zeros, + (GroupSize != 0 ? real_offset_k / GroupSize * n : 0) + real_offset_n, + (GroupSize != 0 ? CtaK / Details::kInterleave / GroupSize * n : 0), Details::kInterleave); + + out += offset_m * n + tile_id_n * CtaN * Details::kInterleave; + if constexpr (EnableBias) { + bias += tile_id_n * CtaN * Details::kInterleave; + } + + TypeA tile_acc[CtaM * CtaN]; + fill(tile_acc, static_cast(0.f)); + + for (int idx_k = tid * StepK, iter = 0; idx_k < interleaved_k; idx_k += CtaK, ++iter) { + TypeA vec_act_scale[StepK]; + TypeA vec_scale[CtaN], vec_zero[CtaN]; + TypeA tile_a[StepK], tile_w[StepK], tile_w_pack2[CtaN * StepK]; + uint8_t tile_w_quantized[StepK / Details::kElemsPerByteW]; +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + scales_iterator.load(vec_scale + i, iter, i); + zeros_iterator.load(vec_zero + i, iter, i); + } + act_scale_iterator.load(vec_act_scale, iter); +#pragma unroll + for (int i = 0; i < CtaN; ++i) { + weight_iterator.load(tile_w_quantized, iter, i); + dequantize( + tile_w, tile_w_quantized, vec_scale + i, vec_zero + i, alpha); + pack_to_vec2(tile_w_pack2, tile_w, i); + } +#pragma unroll + for (int i = 0; i < CtaM; ++i) { + act_iterator.load(tile_a, iter, i); + apply_scale(tile_a, vec_act_scale); + mma(tile_acc + i * CtaN, tile_w_pack2, tile_a); + } + } + epilogue(out, n, tile_acc, bias, alpha); +} + +template +void exec_kernel(Params& params, cudaStream_t s) { + using T = typename Details::TypeDetailsA::Type; + if (params.m % CtaM || params.n % (CtaN * Details::kInterleave)) { + throw 
std::runtime_error("launch failed"); + } + dim3 grid(params.m / CtaM, params.n / (CtaN * Details::kInterleave)); + dim3 block(Threads); + kernel<<>>( + reinterpret_cast(params.act), + reinterpret_cast(params.act_scale), + reinterpret_cast(params.weight), + reinterpret_cast(params.scales), + reinterpret_cast(params.zeros), + reinterpret_cast(params.bias), + reinterpret_cast(params.out), + params.alpha, + params.m, params.n, params.k); +} + +template +void dispatcher(Params& params, cudaStream_t s) { +#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \ + do { \ + if (params.m == target_m) { \ + exec_kernel(params, s); \ + return; \ + } \ + } while (0); + + if constexpr (EnableZero) { + DISPATCHER_FOR_M(1, 1, 4, 128); + DISPATCHER_FOR_M(2, 2, 4, 128); + DISPATCHER_FOR_M(3, 3, 4, 128); + DISPATCHER_FOR_M(4, 4, 4, 128); + DISPATCHER_FOR_M(5, 5, 4, 128); + DISPATCHER_FOR_M(6, 6, 4, 128); + DISPATCHER_FOR_M(7, 7, 4, 128); + DISPATCHER_FOR_M(8, 8, 4, 128); + DISPATCHER_FOR_M(9, 9, 4, 128); + DISPATCHER_FOR_M(10, 10, 4, 128); + DISPATCHER_FOR_M(11, 11, 4, 128); + DISPATCHER_FOR_M(12, 12, 4, 128); + DISPATCHER_FOR_M(13, 13, 4, 128); + DISPATCHER_FOR_M(14, 14, 4, 128); + DISPATCHER_FOR_M(15, 15, 4, 128); + } else { + DISPATCHER_FOR_M(1, 1, 8, 128); + DISPATCHER_FOR_M(2, 2, 8, 128); + DISPATCHER_FOR_M(3, 3, 8, 128); + DISPATCHER_FOR_M(4, 4, 8, 128); + DISPATCHER_FOR_M(5, 5, 8, 128); + DISPATCHER_FOR_M(6, 6, 8, 128); + DISPATCHER_FOR_M(7, 7, 8, 128); + DISPATCHER_FOR_M(8, 8, 8, 128); + DISPATCHER_FOR_M(9, 9, 8, 128); + DISPATCHER_FOR_M(10, 10, 8, 128); + DISPATCHER_FOR_M(11, 11, 8, 128); + DISPATCHER_FOR_M(12, 12, 8, 128); + DISPATCHER_FOR_M(13, 13, 8, 128); + DISPATCHER_FOR_M(14, 14, 8, 128); + DISPATCHER_FOR_M(15, 15, 8, 128); + } + throw std::runtime_error("unsupported m"); +#undef DISPATCHER_FOR_M +} + +template +void check_pointer(Params& params, cudaStream_t s) { + assert(!params.act_scale); // act_scale is not supported for now. 
+ assert(!params.apply_alpha_in_advance); // apply_alpha_in_advance is not supported for now. + + if (params.zeros && params.bias) { + dispatcher(params, s); + } else if (!params.zeros && params.bias) { + dispatcher(params, s); + } else if (params.zeros && !params.bias) { + dispatcher(params, s); + } else { + dispatcher(params, s); + } +} + +template +void select_gs(Params& params, cudaStream_t s) { + if constexpr (isGroupwise) { + if (params.groupsize == 64) { + check_pointer(params, s); + return; + } else if (params.groupsize == 128) { + check_pointer(params, s); + return; + } + } + + ORT_THROW("unsupported block_size: ", params.groupsize); +} + +#define INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(KType, A, B, Layout, ConverterInterleave, KTile) \ + template void select_gs::isGroupwise, \ + KernelDetails>(Params & params, cudaStream_t s); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu new file mode 100644 index 0000000000000..e2c008884c998 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu new file mode 100644 index 0000000000000..8cd96c44421e5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu new file mode 100644 index 0000000000000..1eb5f51bdffdc --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu new file mode 100644 index 0000000000000..f5872841e1acb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_bf16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu new file mode 100644 index 0000000000000..f6b76e67b20ba --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4.cu @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 64); + +// KTile=128 for Ada w4a8 +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu new file mode 100644 index 0000000000000..2ca88285d4cfe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int4_hopper.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +// INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( +// KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true, 128); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu new file mode 100644 index 0000000000000..7a00e1ba35f80 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu new file mode 100644 index 0000000000000..4a8506ca6bbde --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher_fp16_int8_hopper.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "contrib_ops/cuda/llm/fpA_intB_gemv/dispatcher.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS( + KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true, 64); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu new file mode 100644 index 0000000000000..32cd607d36480 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.cu @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemv/details.h" + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +void kernel_launcher(int arch, Params& params, cudaStream_t s) { +#define EXEC(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + +// This is not used since there is no alpha for MatMulNBits currently. 
+#define EXEC_W4A8(KType, A, B, Layout, ConverterInterleave) \ + if (params.type == KType && params.apply_alpha_in_advance) { \ + select_gs::isGroupwise, KernelDetails>( \ + params, s); \ + return; \ + } + + if (arch >= 75 && arch < 80) { + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 80 && arch < 90 || arch >= 100) { + // if (arch == 89 || arch >= 120) + // { + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + // } + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true); + } else if (arch >= 90) { + // Dispatchers for W4A8 groupwise + // EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + // EXEC_W4A8(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + + EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleavedForHopper, true); + EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleavedForHopper, true); + } +#undef EXEC_W4A8 +#undef EXEC +} + +bool is_supported(int arch, KernelType kernel_type) { +#define SUPPORT(Type) \ + if (kernel_type == 
Type) \ + return true; + + if (arch >= 75 && arch < 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + } else if (arch >= 80) { + SUPPORT(KernelType::FP16Int8Groupwise); + SUPPORT(KernelType::FP16Int4Groupwise); + + SUPPORT(KernelType::BF16Int8Groupwise); + SUPPORT(KernelType::BF16Int4Groupwise); + } + return false; +#undef SUPPORT +} + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h new file mode 100644 index 0000000000000..db2860c6b265c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/fpA_intB_gemv/fpA_intB_gemv.h @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include + +namespace onnxruntime::llm { +namespace kernels { +namespace fpA_intB_gemv { + +enum class KernelType { + FP16Int8Groupwise, + FP16Int4Groupwise, + FP16Int8PerChannel, + FP16Int4PerChannel, + BF16Int8Groupwise, + BF16Int4Groupwise, + BF16Int8PerChannel, + BF16Int4PerChannel +}; + +struct Params { + using Pointer = void*; + using ConstPointer = void const*; + Pointer act; + Pointer act_scale; + Pointer weight; + Pointer scales; + Pointer zeros; + Pointer bias; + Pointer out; + float alpha; + int m; + int n; + int k; + int groupsize; + KernelType type; + bool apply_alpha_in_advance; + + Params(ConstPointer _act, ConstPointer _act_scale, ConstPointer _weight, ConstPointer _scales, ConstPointer _zeros, + ConstPointer _bias, Pointer _out, float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type, + bool _apply_alpha_in_advance = false) + : act(const_cast(_act)), + act_scale(const_cast(_act_scale)), + weight(const_cast(_weight)), + scales(const_cast(_scales)), + zeros(const_cast(_zeros)), + bias(const_cast(_bias)), + out(_out), + alpha(_alpha), + m(_m), + n(_n), + k(_k), + groupsize(_groupsize), + type(_type), + apply_alpha_in_advance(_apply_alpha_in_advance) { + } +}; + +void kernel_launcher(int arch, Params& params, cudaStream_t s); + +bool is_supported(int arch, KernelType kernel_type); + +} // namespace fpA_intB_gemv +} // namespace kernels +} // namespace onnxruntime::llm diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc new file mode 100644 index 0000000000000..893ff27c068f8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc @@ -0,0 +1,311 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "contrib_ops/cuda/llm/gemm_profiler.h" +#include "contrib_ops/cuda/llm/common/logger.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +#include + +namespace onnxruntime::llm::kernels::weight_only { + +template +GemmPluginProfiler::GemmPluginProfiler() { + mMNKProfileMap = std::make_shared(); + + // set SKIP_GEMM_PLUGIN_PROFILINGS=1 to avoid tactics profilings + auto const skipEnv = std::getenv("SKIP_GEMM_PLUGIN_PROFILINGS"); + mSkip = (skipEnv != NULL && std::stoi(skipEnv)); + if (mSkip) { + ORT_LLM_LOG_DEBUG( + "SKIP_GEMM_PLUGIN_PROFILINGS is set. Skipping GEMM plugin profilings. 
It could result in runtime error " + "if default tactic is not defined."); + } +} + +// template +// void GemmPluginProfiler::serialize( +// char*& buffer, GemmIdType const& gemmId) const +// { +// auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + +// // Save number of profiles for given GEMM ID +// write(buffer, static_cast(mProfileMap->size())); +// for (auto const& pair : *mProfileMap) +// { +// // Save pair of M to the best GEMM config +// write(buffer, pair); +// } +// } + +// template +// void GemmPluginProfiler::deserialize( +// char const*& data, GemmDims& dims, GemmIdType const& gemmId) +// { +// // NOTE: this mutex is not needed since each thread owns its private map, but will put here for +// // consistency +// writer_lock lock(mMNKProfileMap->mutex); + +// mDims = dims; + +// // GemmId gemmId(dims.n, dims.k); +// if (!mMNKProfileMap->existsMProfileMap(gemmId)) +// { +// // Create GEMM with GEMM ID if it does not exist +// mMNKProfileMap->createMProfileMap(gemmId); +// } +// // Populate map with profiles of GEMM ID +// auto profileMap = mMNKProfileMap->getMProfileMap(gemmId); +// int selectedMapSize; +// read(data, selectedMapSize); +// for (int ii = 0; ii < selectedMapSize; ++ii) +// { +// std::pair> config; +// read(data, config); +// profileMap->insert(config); +// } +// } + +// template +// size_t GemmPluginProfiler::getSerializationSize( +// GemmIdType const& gemmId) const +// { +// reader_lock lock(mMNKProfileMap->mutex); +// return sizeof(int) + // size of the tactics map +// mMNKProfileMap->getMProfileMap(gemmId)->size() +// * sizeof(std::pair>); // size of the tactics map +// } + +template +int GemmPluginProfiler::getMaxProfileM() const { + return 8192; +} + +template +void GemmPluginProfiler::initTmpData( + int /*m*/, int /*n*/, int /*k*/, char* /*workspace*/, size_t /*size*/, cudaStream_t /*stream*/) { + /* Do nothing */ +} + +template +void GemmPluginProfiler::profileTactics( + RunnerPtr const& runner, nvinfer::DataType const& 
type, GemmDims const& dims, GemmIdType const& gemmId, + bool hasWeightOnlyCudaKernel) { + writer_lock lock(mMNKProfileMap->mutex); + + if (!dims.isInitialized()) { + return; + } + + mRunner = runner; + mType = type; + + int const maxM = std::min(nextPowerOfTwo(dims.maxM), getMaxProfileM()); + computeTmpSize(maxM, dims.n, dims.k); + + if (!mMNKProfileMap->existsMProfileMap(gemmId)) { + // Create map for GEMM ID + mMNKProfileMap->createMProfileMap(gemmId); + } + + if (mSkip) { + return; + } + + auto mProfileMap = mMNKProfileMap->getMProfileMap(gemmId); + bool isAllocated{false}; + + auto profileTactics = [&mProfileMap, &isAllocated, this](int m, int n, int k) { + if (mProfileMap->count(m) == 0) { + if (!isAllocated) { + // Allocate tmp data to run GEMMs + allocateTmpData(); + isAllocated = true; + } + initTmpData(m, n, k, mWorkspaceTmp, mTmpWorkspaceSizeInBytes, mStream); + auto tactics = this->getTactics(m, n, k); + + // Profile different tactics for particular m and insert best config to the map + mProfileMap->insert({m, this->profileTacticsForProblem(m, n, k, tactics)}); + } + }; + + CUDA_CALL_THROW(cudaStreamCreate(&mStream)); + + int const startMinMRounded = nextPowerOfTwo(dims.minM); + + if (hasWeightOnlyCudaKernel) { + // Profile tactics for finer granularity of M, + // if CUDA kernel is enabled for weight-only plugins + int minM = dims.minM; + for (int m = std::max(1, minM); m < std::min(16, maxM); m += 1) { + profileTactics(m, dims.n, dims.k); + } + + for (int m = 16; m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + } + } else { + // Profile tactics for CUTLASS kernel only + for (int m = std::max(1, startMinMRounded); m < maxM; m *= 2) { + profileTactics(m, dims.n, dims.k); + } + } + + profileTactics(maxM, dims.n, dims.k); + + if (isAllocated) { + // Free tmp data + freeTmpData(); + } + CUDA_CALL_THROW(cudaStreamDestroy(mStream)); +} + +template +std::optional GemmPluginProfiler::getBestConfig( + int m, GemmIdType const& gemmId) const { + 
reader_lock lock(mMNKProfileMap->mutex); + + if (mSkip) { + ORT_LLM_LOG_TRACE("Skip is set, no best config is set for this instance"); + return std::nullopt; + } + + int const mRounded = std::min(std::max(1, nextPowerOfTwo(m)), getMaxProfileM()); + fflush(stdout); + + if (mMNKProfileMap->getMProfileMap(gemmId)->count(m) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(m); + } else if (mMNKProfileMap->getMProfileMap(gemmId)->count(mRounded) > 0) { + return mMNKProfileMap->getMProfileMap(gemmId)->at(mRounded); + } else { + std::ostringstream msg; + msg << "Cannot find best tactic for m=" << m << " and GEMM ID " << gemmId; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } +} + +template +void GemmPluginProfiler::allocateTmpData() { + ORT_ENFORCE(mTmpWorkspaceSizeInBytes > 0, "tmpWorkspaceSizeInBytes must be larger than 0"); + auto const status = cudaMalloc(&mWorkspaceTmp, mTmpWorkspaceSizeInBytes); + ORT_ENFORCE(status == cudaSuccess, "Can't allocate tmp workspace for GEMM tactics profiling."); +} + +template +void GemmPluginProfiler::freeTmpData() { + auto const status = cudaFree(mWorkspaceTmp); + ORT_ENFORCE(status == cudaSuccess, "Can't free tmp workspace for GEMM tactics profiling."); +} + +template +std::optional GemmPluginProfiler::profileTacticsForProblem( + int m, int n, int k, std::vector const& tactics) { + ORT_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + + float bestTime = std::numeric_limits::max(); + Config bestConfig; + bool foundOne = false; + + // Iterate over all tactics for given M, N and K + for (size_t ii = 0; ii < tactics.size(); ++ii) { + Config const& candidateConfig = tactics[ii]; + float time = std::numeric_limits::max(); + try { + if (!checkTactic(m, n, k, candidateConfig)) { + continue; + } + // Profile particular tactic for given M, N and K + time = profileTacticForProblem(m, n, k, candidateConfig); + foundOne = true; + } catch (std::exception const& e) { + std::ostringstream msg; + msg << "Cannot profile configuration " << 
ii; + if constexpr (std::is_same_v) { + msg << ": " << candidateConfig.toString(); + } + msg << "\n (for" + << " m=" << m << ", n=" << n << ", k=" << k << ")" + << ", reason: \"" << e.what() << "\". Skipped"; + ORT_LLM_LOG_TRACE(msg.str()); + cudaGetLastError(); // Reset the last cudaError to cudaSuccess. + continue; + } + + // Choose the fastest tactic + if (time < bestTime) { + bestConfig = candidateConfig; + bestTime = time; + } + } + + if (!foundOne) { + std::ostringstream msg; + msg << "Have not found any valid GEMM config for shape (" + << "m=" << m << ", n=" << n << ", k=" << k << "). Will try to use default or fail at runtime"; + ORT_LLM_LOG_WARNING(msg.str()); + return std::nullopt; + } + + return {bestConfig}; +} + +template +float GemmPluginProfiler::profileTacticForProblem( + int m, int n, int k, Config const& tactic) { + constexpr int warmup = 5; + constexpr int runs = 10; + + cudaStream_t stream = mStream; + + // Warmup the execution + for (int i = 0; i < warmup; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + cudaEvent_t start; + cudaEvent_t stop; + CUDA_CALL_THROW(cudaEventCreate(&start)); + CUDA_CALL_THROW(cudaEventCreate(&stop)); + CUDA_CALL_THROW(cudaStreamSynchronize(stream)); + CUDA_CALL_THROW(cudaEventRecord(start, stream)); + + // Profile GEMM + for (int i = 0; i < runs; ++i) { + runTactic(m, n, k, tactic, mWorkspaceTmp, stream); + } + + CUDA_CALL_THROW(cudaEventRecord(stop, stream)); + + CUDA_CALL_THROW(cudaEventSynchronize(stop)); + + float elapsed; + CUDA_CALL_THROW(cudaEventElapsedTime(&elapsed, start, stop)); + + CUDA_CALL_THROW(cudaEventDestroy(start)); + CUDA_CALL_THROW(cudaEventDestroy(stop)); + + return elapsed / runs; +} + +template class GemmPluginProfiler, GemmIdCore, + GemmIdCoreHash>; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h new file mode 100644 index 0000000000000..0ab9b91e7f43c 
--- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/gemm_profiler.h @@ -0,0 +1,283 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "contrib_ops/cuda/llm/nv_infer_datatype.h" +#include "core/common/common.h" + +namespace onnxruntime::llm::kernels::weight_only { + +struct GemmDims { + int64_t minM; + int64_t maxM; + int64_t n; + int64_t k; + + GemmDims() + : minM(-1), maxM(-1), n(-1), k(-1) { + } + + GemmDims(int64_t minM_, int64_t maxM_, int64_t n_, int64_t k_) + : minM(minM_), maxM(maxM_), n(n_), k(k_) { + } + + [[nodiscard]] bool isInitialized() const { + return minM >= 0 && maxM >= 0 && n >= 0 && k >= 0; + } +}; + +// Unique ID of GEMM +// In our case GEMM is uniqly identified by N and K +class GemmIdCore { + public: + int n; + int k; + nvinfer::DataType dtype; + + GemmIdCore(int n_, int k_, nvinfer::DataType const& dtype_) + : n(n_), k(k_), dtype(dtype_) { + } + + GemmIdCore() + : n(-1), k(-1), dtype(nvinfer::DataType::kFLOAT) // dtype does not matter here + { + } + + bool operator==(GemmIdCore const& id) const { + return isEqual(id); + } + + friend std::ostream& operator<<(std::ostream& out, GemmIdCore const& id) { + out << "(N;K)=(" << id.n << ";" << id.k << "),"; + 
out << " type=" << static_cast(id.dtype); + return out; + } + + protected: + bool isEqual(GemmIdCore const& id) const { + return n == id.n && k == id.k && dtype == id.dtype; + } +}; + +// Hash of GemmId +struct GemmIdCoreHash { + std::size_t operator()(GemmIdCore const& id) const { + auto h1 = std::hash{}(id.n); + auto h2 = std::hash{}(id.k); + auto h3 = std::hash{}(static_cast(id.dtype)); + return h1 ^ h2 ^ h3; + } +}; + +// class GemmIdCublas : public GemmIdCore { +// public: +// bool transA{}; +// bool transB{}; +// nvinfer::DataType outputDtype; + +// GemmIdCublas(int n_, int k_, nvinfer::DataType const& dtype_, bool transA_, bool transB_, +// nvinfer::DataType const& output_dtype_) +// : GemmIdCore(n_, k_, dtype_), transA(transA_), transB(transB_), outputDtype(output_dtype_) { +// } + +// GemmIdCublas() {} + +// bool operator==(GemmIdCublas const& id) const { +// return isEqual(id) && transA == id.transA && transB == id.transB && outputDtype == id.outputDtype; +// } + +// friend std::ostream& operator<<(std::ostream& out, GemmIdCublas const& id) { +// out << "(N;K)=(" << id.n << ";" << id.k << "),"; +// out << " type=" << static_cast(id.dtype); +// out << " transA=" << id.transA; +// out << " transB=" << id.transB; +// out << " outputDtype=" << static_cast(id.outputDtype); +// return out; +// } +// }; + +// // Hash of GemmIdCublas +// struct GemmIdCublasHash { +// std::size_t operator()(GemmIdCublas const& id) const { +// auto h1 = std::hash{}(id.n); +// auto h2 = std::hash{}(id.k); +// auto h3 = std::hash{}(static_cast(id.dtype)); +// auto h4 = std::hash{}(id.transA); +// auto h5 = std::hash{}(id.transB); +// auto h6 = std::hash{}(static_cast(id.outputDtype)); +// return h1 ^ h2 ^ h3 ^ h4 ^ h5 ^ h6; +// } +// }; + +template +class GemmPluginProfiler { + public: + // Map for single GEMM for different Ms (GEMM dimension) to the best config for particular M + using MProfileMap = std::unordered_map>; + using MProfileMapPtr = std::shared_ptr; + + // requires 
exclusive ownership to write to *this + using reader_lock = std::unique_lock; + // requires shared ownership to read from other + using writer_lock = std::shared_lock; + + // Struct of continuing map if GEMMs to the best profiles for different Ms + struct MNKProfileMap { + // Mutex guarding map + std::shared_timed_mutex mutex; + // Map from GEMM Id to profile for particular GEMM + std::unordered_map profileMap; + + bool existsMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + return iter != profileMap.end(); + } + + void createMProfileMap(GemmIdType const& id) { + profileMap[id] = std::make_shared(); + } + + MProfileMapPtr getMProfileMap(GemmIdType const& id) { + auto const iter = profileMap.find(id); + if (iter == profileMap.end()) { + ORT_THROW("Cannot find ID (", id, ") in the profile map. Abort."); + } + return iter->second; + } + }; + + using MNKProfileMapPtr = std::shared_ptr; + + GemmPluginProfiler(); + + virtual ~GemmPluginProfiler() = default; + + // void serialize(char*& buffer, GemmIdType const& gemmId) const; + + // void deserialize(char const*& data, GemmDims& dims, GemmIdType const& gemmId); + // size_t getSerializationSize(GemmIdType const& gemmId) const; + + void profileTactics(RunnerPtr const& runner, nvinfer::DataType const& type, GemmDims const& dims, + GemmIdType const& gemmId, bool hasWeightOnlyCudaKernel = false); + + void setSelectionTactics(MNKProfileMapPtr const& map) { + mMNKProfileMap = map; + } + + void setTmpWorkspaceSizeInBytes(size_t bytes) { + mTmpWorkspaceSizeInBytes = bytes; + } + + void setSkip(bool skip) { + mSkip = mSkip || skip; + } + + std::optional getBestConfig(int m, GemmIdType const& gemmId) const; + + virtual int getMaxProfileM() const; + + protected: + virtual void runTactic(int m, int n, int k, Config const& tactic, char* workspace, cudaStream_t const& stream) = 0; + + virtual void computeTmpSize(size_t maxM, size_t n, size_t k) = 0; + + virtual bool checkTactic(int /*m*/, int /*n*/, int 
/*k*/, Config const& /*tactic*/) const { + return true; + } + + virtual std::vector getTactics(int m, int n, int k) const = 0; + + virtual void initTmpData(int m, int n, int k, char* workspace, size_t size, cudaStream_t stream); + + private: + void allocateTmpData(); + + void freeTmpData(); + + std::optional profileTacticsForProblem(int m, int n, int k, std::vector const& tactics); + + float profileTacticForProblem(int m, int n, int k, Config const& tactic); + + int nextPowerOfTwo(int v) const { + --v; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + return ++v; + } + + protected: + RunnerPtr mRunner{nullptr}; + + nvinfer::DataType mType{}; + + private: + MNKProfileMapPtr mMNKProfileMap{}; + + size_t mTmpWorkspaceSizeInBytes{0}; + + char* mWorkspaceTmp{nullptr}; + + cudaStream_t mStream; + + GemmDims mDims{}; + + bool mSkip{false}; +}; + +template +class GemmPluginProfilerManager { + public: + using MNKProfileMap = typename GemmPluginProfilerType::MNKProfileMap; + using MNKProfileMapPtr = typename GemmPluginProfilerType::MNKProfileMapPtr; + using GemmPluginProfilerPtr = std::shared_ptr; + + GemmPluginProfilerManager() { + mMNKProfileMap = std::make_shared(); + } + + GemmPluginProfilerPtr createGemmPluginProfiler(bool inference, bool skip = false) { + auto profiler = std::make_shared(); + profiler->setSkip(skip); + // If the profiler is created during the engine build, + // mMNKProfileMap is shared between different profilers to minimize the time spent on the profiling + // and do not repeat profiling for the GEMMs of the same shape. 
+ if (!inference) { + profiler->setSelectionTactics(mMNKProfileMap); + } + return profiler; + } + + private: + MNKProfileMapPtr mMNKProfileMap{}; +}; + +} // namespace onnxruntime::llm::kernels::weight_only diff --git a/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py new file mode 100644 index 0000000000000..678102c809b63 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/generate_kernels.py @@ -0,0 +1,397 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Generate fpA intB GEMM kernels: +# pip install nvidia-cutlass +# python generate_kernels.py -a "90" -o ./fpA_intB_gemm/launchers + +import argparse +import enum +import os +from itertools import product + +from cutlass_library import ( + DataType, + DataTypeNames, + DataTypeSize, + DataTypeTag, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + GemmKind, + GemmKindNames, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, +) + + +################################################################################ +# Epilogue Tag enum and string utils +class LlmEpilogueTag(enum.Enum): + epilogue_op_default = enum.auto() + epilogue_op_bias = enum.auto() + epilogue_op_silu = enum.auto() + epilogue_op_gelu = enum.auto() + + +class LlmEpilogueFusion(enum.Enum): + epilogue_fusion_none = enum.auto() + epilogue_fusion_finalize = enum.auto() + + +EpiTagNames = { + LlmEpilogueTag.epilogue_op_default: "lc", # linear combination + LlmEpilogueTag.epilogue_op_bias: "lc_bias", # linear combination with bias addition + LlmEpilogueTag.epilogue_op_silu: "silu", # silu or swiglu + LlmEpilogueTag.epilogue_op_gelu: "gelu", # gelu or geglu +} + +EpiTag = { + LlmEpilogueTag.epilogue_op_default: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefault", + LlmEpilogueTag.epilogue_op_bias: "onnxruntime::llm::cutlass_extensions::EpilogueOpBias", + LlmEpilogueTag.epilogue_op_silu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultSilu", + LlmEpilogueTag.epilogue_op_gelu: "onnxruntime::llm::cutlass_extensions::EpilogueOpDefaultFtGelu", +} + +EpiFusion = { + LlmEpilogueFusion.epilogue_fusion_none: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE", + LlmEpilogueFusion.epilogue_fusion_finalize: "onnxruntime::llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE", +} + +EpiFusionSuffixes = { + None: "", + LlmEpilogueFusion.epilogue_fusion_none: "EpilogueFusion_NONE", + 
LlmEpilogueFusion.epilogue_fusion_finalize: "EpilogueFusion_FINALIZE", +} + + +################################################################################ +# Quantization Operation and string utils +class LlmQuantOp(enum.Enum): + per_column_scale_only = enum.auto() + finegrained_scale_only = enum.auto() + finegrained_scale_and_zeros = enum.auto() + none = enum.auto() + + +QuantOpNames = { + LlmQuantOp.per_column_scale_only: "cs", + LlmQuantOp.finegrained_scale_only: "fgs", + LlmQuantOp.finegrained_scale_and_zeros: "fgsz", + LlmQuantOp.none: "noquant", +} + +QuantOpTag = { + LlmQuantOp.per_column_scale_only: "cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY", + LlmQuantOp.finegrained_scale_only: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY", + LlmQuantOp.finegrained_scale_and_zeros: "cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS", + LlmQuantOp.none: "void", +} + +################################################################################ +# The activations, biases, scales and zeros are instantiated using CUDA types, +# not CUTLASS types. This map materializes the name of the CUDA type. + + +def get_data_type_bits(type): + return DataTypeSize[type] + + +def get_data_type_names(type): + return DataTypeNames[type] + + +CudaTypeName = { + DataType.e4m3: "__nv_fp8_e4m3", + DataType.bf16: "__nv_bfloat16", + DataType.f16: "half", + DataType.f32: "float", +} + + +################################################################################ +# A data structure holding all info to instantiate gemm launchers in TRT LLM. 
+class LlmGemmLauncher: + def __init__( + self, + gemm_kind, + arch, + act_type, + weight_type, + scalezero_type, + bias_type, + output_type, + quant_op, + epi_tag, + cta_shape, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + epi_fusion=None, + ): + self.gemm_kind = gemm_kind + self.arch = arch + self.act_type = act_type + self.weight_type = weight_type + self.scalezero_type = scalezero_type + self.bias_type = bias_type + self.output_type = output_type + self.quant_op = quant_op + self.epi_tag = epi_tag + self.cta_shape = cta_shape + self.warp_shape = warp_shape + self.stages = stages + self.cga_shape = cga_shape + self.mainloop_schedule = mainloop_schedule + self.epi_schedule = epi_schedule + self.epi_fusion = epi_fusion + + def __repr__(self): + kernel_prefix = f"{GemmKindNames[self.gemm_kind]}_sm{self.arch}_{get_data_type_names(self.act_type)}_{get_data_type_names(self.weight_type)}_{get_data_type_names(self.scalezero_type)}_{get_data_type_names(self.bias_type)}_{get_data_type_names(self.output_type)}_{QuantOpNames[self.quant_op]}_{EpiTagNames[self.epi_tag]}_{self.cta_shape[0]}x{self.cta_shape[1]}x{self.cta_shape[2]}_{self.warp_shape[0]}x{self.warp_shape[1]}x{self.warp_shape[2]}_{self.stages}" + + hopper_suffix = f"_{self.cga_shape[0]}x{self.cga_shape[1]}x{self.cga_shape[2]}{KernelScheduleSuffixes[self.mainloop_schedule]}{EpilogueScheduleSuffixes[self.epi_schedule]}{EpiFusionSuffixes[self.epi_fusion]}" + + if self.arch >= 90: + return kernel_prefix + hopper_suffix + elif self.arch > 100: + raise ValueError(f"SM{self.arch} not supported yet.") + return kernel_prefix + + +################################################################################ +def tuple_to_cute_shape(shape): + return f"cute::Shape, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>" + + +def instantiate_operation_tma_warp_specialized(operation): + act_tag = CudaTypeName[operation.act_type] + scale_zero_tag = CudaTypeName[operation.scalezero_type] + bias_tag = 
CudaTypeName[operation.bias_type] + out_tag = CudaTypeName[operation.output_type] + + quant_op = QuantOpTag[operation.quant_op] + epi_tag = EpiTag[operation.epi_tag] + + cute_cta_shape = tuple_to_cute_shape(operation.cta_shape) + cute_cga_shape = tuple_to_cute_shape(operation.cga_shape) + + kernel_sched = KernelScheduleTag[operation.mainloop_schedule] + epi_sched = EpilogueScheduleTag[operation.epi_schedule] + + assert operation.gemm_kind == GemmKind.Gemm + weight_tag = DataTypeTag[operation.weight_type] + + return f""" +template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag}, +{quant_op}, {epi_tag}, +{cute_cta_shape}, {cute_cga_shape}, +{kernel_sched}, {epi_sched}> ( +const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float, +{out_tag}*, int, int, int, const int, onnxruntime::llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int* +); +""" + + +def instantiate_operation(insts_list, operation): + if operation.arch >= 90: + insts_list.append(instantiate_operation_tma_warp_specialized(operation)) + + +def get_file_content(launcher_inl_files, operations): + assert operations + include_list = list() + for file in launcher_inl_files: + include_list.append(f'#include "{file}"') + includes = "\n".join(include_list) + + insts_list = list() + for op in operations: + instantiate_operation(insts_list, op) + instantiations = "\n".join(insts_list) + + file_content = f"""{includes} +namespace onnxruntime::llm +{{ +namespace kernels +{{ +namespace cutlass_kernels +{{ + +{instantiations} + +}} // namespace cutlass_kernels +}} // namespace kernels +}} // namespace onnxruntime::llm +""" + return file_content + + +def write_file(launcher_inl_files, operations, output_file): + os.makedirs(os.path.dirname(output_file), exist_ok=True) + # Avoid changing modified time if file content is up to date + content = 
get_file_content(launcher_inl_files, operations) + if os.path.exists(output_file): + with open(output_file) as f: + if f.read() == content: + return + with open(output_file, mode="w") as f: + f.write(content) + + +def elementwise(x, y, f): + return tuple(f(a, b) for (a, b) in zip(x, y, strict=False)) + + +def is_gemm_op_valid(op): + tile_m, tile_n, _ = op.cta_shape + cga_m, cga_n, _ = op.cga_shape + + if cga_m == 1 and cga_n == 1: + return True + + if cga_m == 2 and cga_n == 1 and tile_m >= 128: + return True + + if cga_m == 1 and cga_n == 2 and tile_n >= 128: + return True + + if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128: + return True + + return False + + +################################################################################ +def generate_sm90_mixed_gemm_operations(enable_fp8=False, enable_scale_only=False): + arch = 90 + + # For legacy reasons, we use unsigned types for the weights. The instanitated template + # will remap those back to the signed type. 
+ # Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type) + supported_dtypes = [ + (DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16), + (DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16, DataType.bf16), + (DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16, DataType.bf16), + ] + + if enable_fp8: + supported_dtypes = [ + *supported_dtypes, + (DataType.e4m3, DataType.u4, DataType.f16, DataType.f16, DataType.f16), + (DataType.e4m3, DataType.u4, DataType.f16, DataType.bf16, DataType.bf16), + ] + + quant_ops = [LlmQuantOp.finegrained_scale_and_zeros] + + if enable_scale_only: + quant_ops = [ + *quant_ops, + LlmQuantOp.finegrained_scale_only, + ] + + epi_tags = [LlmEpilogueTag.epilogue_op_bias] + + m_tiles = [64, 128] + n_tiles = [16, 32, 64, 128, 256] + cta_shapes_mn = product(m_tiles, n_tiles) + + warp_shape = [4, 1, 1] + stages = 0 # auto + + cga_shapes = product([1, 2], [1, 2], [1]) + + partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn, cga_shapes) + + operations = list() + for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args: + max_k_bits = 128 * 8 + cta_shape_k = max_k_bits // get_data_type_bits(dtype_combo[0]) + cta_shape_mnk = (*cta_shape_mn, cta_shape_k) + + use_coop = cta_shape_mn[0] == 128 + mainloop_schedule = ( + KernelScheduleType.TmaWarpSpecializedCooperative + if use_coop + else KernelScheduleType.TmaWarpSpecializedPingpong + ) + epi_schedule = ( + EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized + ) + + mixed_gemm_operation = LlmGemmLauncher( + GemmKind.Gemm, + arch, + *dtype_combo, + quant_op, + epi_tag, + cta_shape_mnk, + warp_shape, + stages, + cga_shape, + mainloop_schedule, + epi_schedule, + ) + + if is_gemm_op_valid(mixed_gemm_operation): + operations.append(mixed_gemm_operation) + + return operations + + 
+def generate_sm90_operations(is_arch_enabled): + operations = generate_sm90_mixed_gemm_operations() + return operations + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Print the output directory") + + parser.add_argument("-o", "--output_dir", type=str, required=True, help="Path to the output directory") + parser.add_argument("-a", "--architectures", type=str, required=True, help="Architectures to generate kernels for") + + args = parser.parse_args() + + arches = args.architectures.split(";") + + output_dir = os.path.abspath(args.output_dir) + + include_map = { + (GemmKind.Gemm, 90): ["contrib_ops/cuda/llm/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"], + } + + def has_arch(sm): + return f"{sm}" in arches or f"{sm}-real" in arches + + # The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads. + # Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve. + operations = [] + operations += generate_sm90_operations(has_arch(90)) + + op_groups = dict() + for op in operations: + dict_key = (op.gemm_kind, op.arch, op.cta_shape[0]) + op_group = op_groups.get(dict_key, list()) + op_group.append(op) + op_groups[dict_key] = op_group + + file_counter = 1 + for key, value in op_groups.items(): + gemm_kind, _, _ = key + out_file = os.path.join(output_dir, f"fpA_intB_gemm_launcher_{file_counter}.generated.cu") + write_file(include_map[key[:2]], value, out_file) + file_counter += 1 diff --git a/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h new file mode 100644 index 0000000000000..52e8eb225c79c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/llm/nv_infer_datatype.h @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// This is corresponding to nvinfer1 namespace used by TensorRT. Add it to avoid dependency on TensorRT. +namespace onnxruntime::llm::nvinfer { + +enum class DataType : int32_t { + //! 32-bit floating point format. + kFLOAT = 0, + + //! IEEE 16-bit floating-point format -- has a 5 bit exponent and 11 bit significand. + kHALF = 1, + + //! Signed 8-bit integer representing a quantized floating-point value. + kINT8 = 2, + + //! Signed 32-bit integer format. + kINT32 = 3, + + //! 8-bit boolean. 0 = false, 1 = true, other values undefined. + kBOOL = 4, + + //! Unsigned 8-bit integer format. + //! Cannot be used to represent quantized floating-point values. + kUINT8 = 5, + + //! Signed 8-bit floating point with + //! 1 sign bit, 4 exponent bits, 3 mantissa bits, and exponent-bias 7. + kFP8 = 6, + + //! Brain float -- has an 8 bit exponent and 8 bit significand. + kBF16 = 7, + + //! Signed 64-bit integer type. + kINT64 = 8, + + //! Signed 4-bit integer type. 
+ kINT4 = 9, + + kFP4 = 10, +}; +} // namespace onnxruntime::llm::nvinfer diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index ed6021530018f..3f485f0abdcb1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -9,21 +9,231 @@ #include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" +#include "contrib_ops/cpu/utils/dump_tensor.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise.cuh" +#include "contrib_ops/cuda/llm/fpA_intB_gemm/fpA_intB_gemm.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" +#include "contrib_ops/cuda/llm/cutlass_preprocessors.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" -#include "matmul_nbits.cuh" -#include "dequantize_blockwise.cuh" + +constexpr int MatMulNBits_Input_B = 1; +constexpr int MatMulNBits_Input_Scale = 2; +constexpr int MatMulNBits_Input_ZeroPoint = 3; namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using onnxruntime::llm::kernels::weight_only::WeightTypeId; +static GemmPluginProfilerManager s_profilerManager; + +template +void MatMulNBits::InitGemmProfiler(int sm) { + gemmProfiler_ = s_profilerManager.createGemmPluginProfiler(/*inference*/ false); + + if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + weightOnlyGemmRunner_ = std::make_shared>(); + } + } else if constexpr (std::is_same_v) { + if (nbits_ == 8) { + weightOnlyGemmRunner_ = std::make_shared>(); + } else if (nbits_ == 4) { + 
weightOnlyGemmRunner_ = std::make_shared>(); + } + } + + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = nbits_ == 8 ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + gemmProfiler_->setCudaKernelType(cuda_kernel_type, sm); + gemmProfiler_->setQuant(nbits_, has_bias_, has_zero_points_); + gemmProfiler_->setGroupSize(block_size_); +} + +template +void MatMulNBits::RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m) { + // Number of 16-bit elements after casting int8/int4 to fp16. + int n_16b = N_ / (nbits_ == 8 ? 2 : 4); + + gemmId_ = GemmIdCore(n_16b, K_, onnxruntime::llm::nvinfer::DataType::kHALF); + + GemmDims dims = {min_m, max_m, n_16b, K_}; + gemmProfiler_->profileTactics(weightOnlyGemmRunner_, gemmId_.dtype, dims, gemmId_, hasWeightOnlyCudaKernel); +} + +template +Status MatMulNBits::PrePack(const Tensor& /* tensor */, int /* input_idx */, AllocatorPtr /*alloc*/, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + +template <> +Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + if (has_fpA_intB_gemm_) { + cudaStream_t stream = cudaStreamLegacy; // Use default stream for prepacking. 
+ if (input_idx == MatMulNBits_Input_B) { + ORT_RETURN_IF_ERROR(PrePack_B(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_Scale) { + ORT_RETURN_IF_ERROR(PrePack_Scale(tensor, alloc, stream)); + is_packed = true; + } else if (input_idx == MatMulNBits_Input_ZeroPoint) { + if (has_zero_points_) { + ORT_RETURN_IF_ERROR(PrePack_ZeroPoint(tensor, alloc, stream)); + is_packed = true; + } + } + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_B([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t packed_weight_bytes = n * k / (8 / nbits_); + + // uint8 does not need to be packed so we do not need to allocate extra space. + IAllocatorUniquePtr packed_transposed_weight_space = this->GetTransientScratchBuffer(packed_weight_bytes); + int8_t* packed_transposed_weight = reinterpret_cast(packed_transposed_weight_space.get()); + + fpA_intB_weight_buffer_ = IAllocator::MakeUniquePtr(alloc, packed_weight_bytes, true); // Transient buffer. + + int8_t* preprocessed_weight = reinterpret_cast(fpA_intB_weight_buffer_.get()); + + const uint8_t* blob_data = tensor.Data(); + if (nbits_ == 4) { + // Transpose the weight and add default zero point. 
+ onnxruntime::llm::kernels::fpA_intB_gemv::unpack_uint4_transposed_to_int8_direct_cuda( + stream, packed_transposed_weight, blob_data, n, k); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::transpose_uint8_matrix_and_convert_to_int8( + stream, packed_transposed_weight, blob_data, n, k); + } + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + auto tranpose_weight_buffer = this->AllocateBufferOnCPUPinned(packed_weight_bytes); + CUDA_RETURN_IF_ERROR(cudaMemcpy(tranpose_weight_buffer.get(), packed_transposed_weight, packed_weight_bytes, cudaMemcpyDeviceToHost)); + + auto processed_weight_buffer = this->AllocateBufferOnCPUPinned(n * k / (8 / nbits_)); + bool force_interleave = false; + + using onnxruntime::llm::kernels::cutlass_kernels::QuantType; + QuantType quant_type = nbits_ == 4 ? QuantType::W4_A16 : QuantType::W8_A16; + + // TODO: Add a cuda kernle for preprocessing so that we can avoid copying the data back to CPU. + onnxruntime::llm::kernels::cutlass_kernels::preprocess_weights_for_mixed_gemm( + reinterpret_cast(processed_weight_buffer.get()), + reinterpret_cast(tranpose_weight_buffer.get()), + {static_cast(k), static_cast(n)}, + quant_type, + force_interleave); + + CUDA_RETURN_IF_ERROR(cudaMemcpy(preprocessed_weight, processed_weight_buffer.get(), n * k / (8 / nbits_), cudaMemcpyHostToDevice)); + CUDA_RETURN_IF_ERROR(cudaDeviceSynchronize()); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed transposed_weight in GPU", packed_transposed_weight, k, n * nbits_ / 8); + DUMP_TENSOR_D("preprocessed_weight", reinterpret_cast(preprocessed_weight), k, n * nbits_ / 8); + } + + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_Scale([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * 
sizeof(T); + + fpA_intB_scale_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + + typedef typename ToCudaType::MappedType CudaT; + CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + onnxruntime::llm::kernels::fpA_intB_gemv::launch_transpose_scale_kernel(stream, reinterpret_cast(tensor.Data()), transposed_scales, n, k_blocks); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("transposed_scales", transposed_scales, k_blocks, n); + } + return Status::OK(); +} + +template +Status MatMulNBits::PrePack_ZeroPoint([[maybe_unused]] const Tensor& tensor, + [[maybe_unused]] AllocatorPtr alloc, + [[maybe_unused]] cudaStream_t stream) { + if constexpr (std::is_same_v) { + size_t n = static_cast(N_); + size_t k = static_cast(K_); + + size_t k_blocks = (k + block_size_ - 1) / block_size_; + size_t scale_bytes = n * k_blocks * sizeof(T); + + typedef typename ToCudaType::MappedType CudaT; + const CudaT* transposed_scales = reinterpret_cast(fpA_intB_scale_buffer_.get()); + + fpA_intB_zero_buffer_ = IAllocator::MakeUniquePtr(alloc, scale_bytes, true); // Transient buffer. + CudaT* scaled_zero_points = reinterpret_cast(fpA_intB_zero_buffer_.get()); + + constexpr float kDefaultZeroPoint4Bit = 8.0f; + constexpr float kDefaultZeroPoint8Bit = 128.0f; + const float default_zero_point = nbits_ == 4 ? kDefaultZeroPoint4Bit : kDefaultZeroPoint8Bit; + const auto* zero_points_data = tensor.DataRaw(); + + // The scaled zero point will be zero for the default zero point, so there is no need to scale when it is nullptr. 
+ if (!tensor.IsDataType()) { // zero point is uint8_t type + if (nbits_ == 4) { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } else { + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + } else { // zero point is not uint8_t type + onnxruntime::llm::kernels::fpA_intB_gemv::launch_scaled_zero_point_kernel( + stream, reinterpret_cast(zero_points_data), + transposed_scales, scaled_zero_points, n, k_blocks, default_zero_point); + } + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("scaled_zero_points", scaled_zero_points, k_blocks, n); + } + return Status::OK(); +} template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { + const bool is_prepacked = has_fpA_intB_gemm_; const Tensor* a = ctx->Input(0); - const Tensor* b = ctx->Input(1); - const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* b = is_prepacked ? nullptr : ctx->Input(1); + const Tensor* scales = is_prepacked ? nullptr : ctx->Input(2); + const Tensor* zero_points = is_prepacked ? nullptr : ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); const Tensor* bias = ctx->Input(5); @@ -35,19 +245,17 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { a, b, scales, zero_points, reorder_idx, bias, N_, K_, block_size_, nbits_)); const auto* a_data = a->Data(); - const uint8_t* blob_data = b->Data(); - const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const uint8_t* blob_data = is_prepacked ? nullptr : b->Data(); + const auto* scales_data = is_prepacked ? 
nullptr : scales->Data(); + const auto* zero_points_data = (is_prepacked || zero_points == nullptr) ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); - - typedef typename ToCudaType::MappedType CudaT; + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); constexpr bool transa = false; constexpr bool transb = true; MatMulComputeHelper helper; TensorShape b_shape({N_, K_}); - ORT_RETURN_IF_ERROR( - helper.Compute(a->Shape(), b_shape, transa, transb)); + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); Tensor* Y = ctx->Output(0, helper.OutputShape()); @@ -55,6 +263,61 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); + cudaStream_t stream = static_cast(ctx->GetComputeStream()->GetHandle()); + + typedef typename ToCudaType::MappedType CudaT; + CudaT* out_data = reinterpret_cast(Y->MutableData()); + + int m = SafeInt(helper.M()); + int n = SafeInt(helper.N()); + int k = SafeInt(helper.K()); + + DUMP_TENSOR_INIT(); + + if constexpr (std::is_same::value) { + if (has_fpA_intB_gemm_) { + auto const& bestTactic = gemmProfiler_->getBestConfig(m, gemmId_); + + DUMP_STRING("Best tactic: m=", m, " n=", n, " k=", k, " group_size=", block_size_, bestTactic->toString()); + + if (bestTactic->enableCudaKernel) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + + void const* pre_quant_scale_ptr = nullptr; + bool apply_alpha_in_advance = false; + float alpha = 1.0f; + onnxruntime::llm::kernels::fpA_intB_gemv::Params params( + a_data, pre_quant_scale_ptr, fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), has_zero_points_ ? 
fpA_intB_zero_buffer_.get() : nullptr, + bias_data, out_data, + alpha, m, n, k, block_size_, cuda_kernel_type, apply_alpha_in_advance); + + onnxruntime::llm::kernels::fpA_intB_gemv::kernel_launcher(sm_, params, stream); + } else { + const size_t workspace_size = weightOnlyGemmRunner_->getWorkspaceSize(m, n, k); + auto workspace_buffer = GetScratchBuffer(workspace_size, ctx->GetComputeStream()); + + weightOnlyGemmRunner_->gemm( + a_data, + fpA_intB_weight_buffer_.get(), + fpA_intB_scale_buffer_.get(), + has_zero_points_ ? fpA_intB_zero_buffer_.get() : nullptr, + bias_data, + 1.f, + out_data, + m, n, k, + block_size_, + *bestTactic, + reinterpret_cast(workspace_buffer.get()), + workspace_size, + stream); + } + + return Status::OK(); + } + } + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { bool done = (nbits_ == 8) ? TryMatMul8Bits( reinterpret_cast(Y->MutableData()), @@ -62,24 +325,24 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())) + stream) : TryMatMul4Bits( reinterpret_cast(Y->MutableData()), reinterpret_cast(a_data), blob_data, reinterpret_cast(scales_data), static_cast(zero_points_data), - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), + m, + n, + k, SafeInt(block_size_), GetDeviceProp().sharedMemPerBlock, - static_cast(ctx->GetComputeStream()->GetHandle())); + stream); if (done) { return Status::OK(); } @@ -104,7 +367,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize8Bits( reinterpret_cast(b_data), @@ -115,7 +378,7 @@ Status 
MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block ORT_RETURN_IF_ERROR(DequantizeBlockwise8b( @@ -127,7 +390,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // 4 bits if (column_wise_quant_blk_) { @@ -145,7 +408,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } else { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), @@ -156,7 +419,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } else { // row-wise block @@ -171,11 +434,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { column_wise_quant_blk_, SafeInt(K_), SafeInt(N_), - static_cast(ctx->GetComputeStream()->GetHandle()))); + stream)); } } - DUMP_TENSOR_INIT(); DUMP_TENSOR_D("DeQuantized", b_data, N_, K_padded); const CudaT alpha = ToCudaType::FromFloat(1.f); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index f5c2c6c4e4fdf..02740d905c7c7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -10,11 +10,27 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/llm/fpA_intB_gemm_profiler.h" +#include "core/platform/env_var_utils.h" namespace onnxruntime { namespace contrib { namespace cuda { using namespace onnxruntime::cuda; 
+using onnxruntime::llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner; +using onnxruntime::llm::kernels::weight_only::GemmDims; +using onnxruntime::llm::kernels::weight_only::GemmIdCore; +using onnxruntime::llm::kernels::weight_only::GemmPluginProfilerManager; +using onnxruntime::llm::kernels::weight_only::WeightOnlyGroupwiseQuantGemmPluginProfiler; +using GemmProfilerPtr = std::shared_ptr; +using WeightOnlyGemmRunnerPtr = std::shared_ptr; + +// Environment variable to configure fpA_intB_gemm for experiments. Set it to 0 to disable, 1 to eanble all. +constexpr const char* kFpAIntBGemmOption = "ORT_FPA_INTB_GEMM"; +constexpr int kFpAIntBGemmOption_All = 0x01; +constexpr int kFpAIntBGemmOption_Gemv = 0x02; +constexpr int kFpAIntBGemmOption_Int4 = 0x04; +constexpr int kFpAIntBGemmOption_Int8 = 0x08; template class MatMulNBits final : public CudaKernel { @@ -24,16 +40,91 @@ class MatMulNBits final : public CudaKernel { ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + + constexpr size_t kInputIndexScale = 2; + constexpr size_t kInputIndexZeroPoints = 3; + constexpr size_t kInputIndexGroupIndex = 4; + constexpr size_t kInputIndexBias = 5; + + has_zero_points_ = info.GetInputCount() > kInputIndexZeroPoints && info.node().InputDefs()[kInputIndexZeroPoints]->Exists(); + has_g_idx_ = info.GetInputCount() > kInputIndexGroupIndex && info.node().InputDefs()[kInputIndexGroupIndex]->Exists(); + has_bias_ = info.GetInputCount() > kInputIndexBias && info.node().InputDefs()[kInputIndexBias]->Exists(); + sm_ = this->GetDeviceProp().major * 10 + this->GetDeviceProp().minor; + + if (has_zero_points_) { + int32_t zero_point_type = info.node().InputDefs()[kInputIndexZeroPoints]->TypeAsProto()->tensor_type().elem_type(); + int32_t scale_type = info.node().InputDefs()[kInputIndexScale]->TypeAsProto()->tensor_type().elem_type(); + 
is_zero_points_scale_same_type_ = (zero_point_type == scale_type); + } + + if constexpr (std::is_same::value) { + int option = ParseEnvironmentVariableWithDefault(kFpAIntBGemmOption, 0); + if ((option & (static_cast(nbits_) | kFpAIntBGemmOption_All)) != 0 && + (block_size_ == 64 || block_size_ == 128) && + (nbits_ == 4 || nbits_ == 8) && + !has_g_idx_ && has_zero_points_ && !has_bias_ && + N_ % (nbits_ == 8 ? 32 : 64) == 0 && + K_ % block_size_ == 0 && + sm_ >= 75) { + if ((option & (kFpAIntBGemmOption_Gemv | kFpAIntBGemmOption_All)) != 0) { + using onnxruntime::llm::kernels::fpA_intB_gemv::KernelType; + KernelType cuda_kernel_type = (nbits_ == 8) ? KernelType::FP16Int8Groupwise : KernelType::FP16Int4Groupwise; + if (onnxruntime::llm::kernels::fpA_intB_gemv::is_supported(sm_, cuda_kernel_type)) { + has_fpA_intB_gemv_ = true; + } + } + + InitGemmProfiler(sm_); + + constexpr int max_m = 8291; + RunGemmProfile(has_fpA_intB_gemv_, 1, max_m); + has_fpA_intB_gemm_ = true; + } + } + +#ifndef NDEBUG + printf("n=%d, k=%d, block_size=%d, bits=%d, zp_bits=%d, g_idx=%d, bias=%d, gemv=%d, gemm=%d\n", + int(N_), int(K_), int(block_size_), int(nbits_), + has_zero_points_ ? (is_zero_points_scale_same_type_ ? int(sizeof(T)) * 8 : int(nbits_)) : int(0), + int(has_g_idx_ ? 1 : 0), int(has_bias_ ? 
1 : 0), + int(has_fpA_intB_gemv_), int(has_fpA_intB_gemm_)); +#endif } Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: + void InitGemmProfiler(int sm); + void RunGemmProfile(bool hasWeightOnlyCudaKernel, int min_m, int max_m); + + Status PrePack_B(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_Scale(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + Status PrePack_ZeroPoint(const Tensor& tensor, AllocatorPtr alloc, cudaStream_t stream); + int64_t K_; int64_t N_; int64_t block_size_; int64_t nbits_; + int sm_{0}; bool column_wise_quant_blk_{true}; + + bool has_g_idx_{false}; + bool has_bias_{false}; + bool has_zero_points_{false}; + bool is_zero_points_scale_same_type_{false}; + bool has_fpA_intB_gemv_{false}; + bool has_fpA_intB_gemm_{false}; + + WeightOnlyGemmRunnerPtr weightOnlyGemmRunner_{nullptr}; + mutable GemmProfilerPtr gemmProfiler_{nullptr}; + GemmIdCore gemmId_{}; + + IAllocatorUniquePtr fpA_intB_weight_buffer_; + IAllocatorUniquePtr fpA_intB_scale_buffer_; + IAllocatorUniquePtr fpA_intB_zero_buffer_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h index 2b2b726e62c79..63e2ab8e9cb9b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_call.h @@ -15,30 +15,30 @@ std::conditional_t CudaCall( ERRTYPE retCode, const char* exprString, const char* libName, SUCCTYPE successCode, const char* msg, const char* file, const int line); -#define CUDA_CALL(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL(expr) 
(::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL(expr) (CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUDNN_CALL2(expr, m) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) +#define CUSPARSE_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL2(expr, m) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, m, __FILE__, __LINE__)) -#define CUFFT_CALL(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) -#define CUDA_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) -#define CUBLAS_CALL_THROW(expr) (CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDA_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__)) +#define CUBLAS_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUBLAS", CUBLAS_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUSPARSE_CALL_THROW(expr) (CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CURAND_CALL_THROW(expr) 
(CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUSPARSE_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUSPARSE", CUSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CURAND_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CURAND", CURAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) // the cudnn configuration call that doesn't need set stream -#define CUDNN_CALL_THROW(expr) (CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) +#define CUDNN_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUDNN", CUDNN_STATUS_SUCCESS, "", __FILE__, __LINE__)) -#define CUFFT_CALL_THROW(expr) (CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) +#define CUFFT_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "CUFFT", CUFFT_SUCCESS, "", __FILE__, __LINE__)) #ifdef ORT_USE_NCCL -#define NCCL_CALL(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) -#define NCCL_CALL_THROW(expr) (CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) +#define NCCL_CALL_THROW(expr) (::onnxruntime::CudaCall((expr), #expr, "NCCL", ncclSuccess, "", __FILE__, __LINE__)) #endif } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 043f7ed57b8b0..f8739b859bef5 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -20,6 +20,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -486,6 +487,7 @@ void RunTest(int64_t 
M, int64_t N, int64_t K, int64_t block_size, int64_t accura if (use_float16) { opts.output_abs_error = fp16_abs_error; + opts.output_rel_error = use_float16 ? 0.001f : 0.0005f; } std::vector> execution_providers; @@ -548,11 +550,8 @@ TEST(MatMulNBits, Float16Large) { // absolute error of 0.08, but the A10 has errors going as high as 0.22. Ultimately, given the large number // of elements in this test, ULPs should probably be used instead of absolute/relative tolerances. float abs_error = 0.3f; -#elif USE_WEBGPU - // Use absolute error of 0.1 for WebGPU with subgroup implementation - float abs_error = 0.1f; #else - float abs_error = 0.05f; + float abs_error = 0.1f; #endif for (auto block_size : {16, 32, 64, 128}) { @@ -564,6 +563,53 @@ TEST(MatMulNBits, Float16Large) { } } +#ifdef USE_CUDA +TEST(MatMulNBits, Fp16_Int4_Int4ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_Fp16ZeroPoint) { + float abs_error = 0.1f; + constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = false; + constexpr bool has_zeropoint = true; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} + +TEST(MatMulNBits, Fp16_Int4_NoZeroPoint) { + float abs_error = 0.1f; + 
constexpr bool use_float16 = true; + constexpr bool has_g_idx = false; + constexpr bool zp_is_4bit = true; + constexpr bool has_zeropoint = false; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + for (auto block_size : {64, 128}) { + RunTest(1, 256, 1024, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + RunTest(32, 1024, 2048, block_size, 0, has_zeropoint, use_float16, has_g_idx, zp_is_4bit, abs_error); + } +} +#endif + #endif // defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc index 63677094b1b4b..39f6958d47a12 100644 --- a/onnxruntime/test/contrib_ops/matmul_8bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_8bits_test.cc @@ -21,6 +21,7 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "test/util/include/scoped_env_vars.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/util/qmath.h" @@ -206,28 +207,26 @@ void RunTest8Bits(const TestOptions8Bits& opts) { } template -void TestMatMul8BitsTyped() { +void TestMatMul8BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { TestOptions8Bits base_opts{}; base_opts.M = M, base_opts.N = N, base_opts.K = K; base_opts.block_size = block_size; base_opts.accuracy_level = accuracy_level; - if (base_opts.accuracy_level == 4) { - base_opts.output_abs_error = 0.1f; - base_opts.output_rel_error = 0.02f; - } else if constexpr (std::is_same::value) { - base_opts.output_abs_error = 0.055f; - base_opts.output_rel_error = 0.02f; - } + base_opts.output_abs_error = abs_error; + base_opts.output_rel_error = rel_error; { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; + opts.has_bias = false; RunTest8Bits(opts); 
} { TestOptions8Bits opts = base_opts; opts.has_zero_point = true; + opts.has_bias = false; RunTest8Bits(opts); } @@ -235,6 +234,7 @@ void TestMatMul8BitsTyped() { #if !defined(USE_CUDA) && !defined(USE_WEBGPU) { TestOptions8Bits opts = base_opts; + opts.has_zero_point = false; opts.has_bias = true; RunTest8Bits(opts); } @@ -249,7 +249,7 @@ void TestMatMul8BitsTyped() { } } // namespace -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { +TEST(MatMulNBits, Float32_8b_AccuracyLevel4) { TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); TestMatMul8BitsTyped(); @@ -285,9 +285,25 @@ TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float) { } #if defined(USE_CUDA) || defined(USE_WEBGPU) -TEST(MatMulNBits, Float32_8b_AccuracyLevel4_Float16) { - TestMatMul8BitsTyped(); - TestMatMul8BitsTyped(); +TEST(MatMulNBits, Float16_8b_AccuracyLevel4) { + constexpr float abs_error = 0.055f; + constexpr float rel_error = 0.02f; + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); +} +#endif + +#if defined(USE_CUDA) +TEST(MatMulNBits, Fp16_Int8_Cuda) { + constexpr float abs_error = 0.5f; + constexpr float rel_error = 0.05f; + + ScopedEnvironmentVariables scoped_env_vars{EnvVarMap{{"ORT_FPA_INTB_GEMM", "1"}}}; + + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); + TestMatMul8BitsTyped(abs_error, rel_error); } #endif