diff --git a/Jenkinsfile b/Jenkinsfile index 80392bfbedd..949d0c0b897 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -349,6 +349,20 @@ def cmake_build(Map conf=[:]){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." } } + if (params.RUN_CK_TILE_FLEX_ATTENTION_TESTS){ + try{ + archiveArtifacts "perf_tile_flex_attn_*.log" + if (arch_type == 1){ + stash includes: "perf_tile_flex_attn_gfx90a.log", name: "perf_tile_flex_attn_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_tile_flex_attn_gfx942.log", name: "perf_tile_flex_attn_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ archiveArtifacts "perf_tile_gemm_*.log" @@ -663,6 +677,15 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } + if (params.RUN_CK_TILE_FLEX_ATTENTION_TESTS){ + try{ + unstash "perf_tile_flex_attn_log_gfx90a" + unstash "perf_tile_flex_attn_log_gfx942" + } + catch(Exception err){ + echo "could not locate the Flex Attention performance logs: ${err.getMessage()}." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ unstash "perf_tile_gemm_log_gfx942" @@ -713,7 +736,7 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_FLEX_ATTENTION_TESTS=false;RUN_CK_TILE_GEMM_TESTS=true 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true @@ -793,6 +816,10 @@ pipeline { name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, description: "Run the ck_tile FMHA tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_FLEX_ATTENTION_TESTS", + defaultValue: false, + description: "Run the ck_tile FLEX ATTENTION tests (default: ON)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", defaultValue: true, @@ -984,6 +1011,51 @@ pipeline { } } } + stage("Run RUN_CK_TILE_FLEX_ATTENTION_TESTS Test") + { + + parallel + { + stage("Run RUN_CK_TILE_FLEX_ATTENTION_TESTS Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FLEX_ATTENTION_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_flexattn_fwd && \ + cd ../ && + example/ck_tile/18_flexattn/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run RUN_CK_TILE_FLEX_ATTENTION_TESTS Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FLEX_ATTENTION_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_flexattn_fwd && \ + cd ../ && + example/ck_tile/18_flexattn/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run CK_TILE_GEMM Tests") { parallel diff --git a/example/ck_tile/18_flexattn/CMakeLists.txt b/example/ck_tile/18_flexattn/CMakeLists.txt new file mode 100644 index 00000000000..4651fddcb42 --- /dev/null +++ b/example/ck_tile/18_flexattn/CMakeLists.txt @@ -0,0 +1,105 @@ +# validate user-specified fmha_fwd API list +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") +if(FMHA_FWD_ENABLE_APIS STREQUAL "all") + set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) +endif() + +variable_watch(FMHA_SCORE_MOD_F) +set(FMHA_SCORE_MOD_F [[s + static_cast((q_idx - v_idx) % 8)]]) +# set(FMHA_SCORE_MOD_F [[s]]) + +variable_watch(FMHA_PRE_SOFTMAX_F) +# set(FMHA_PRE_SOFTMAX_F [[static_cast(tanh(s*1.0)/1.0)]]) +set(FMHA_PRE_SOFTMAX_F [[s]]) + +foreach(api ${FMHA_FWD_ENABLE_APIS}) + if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS) + message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.") + endif() +endforeach() + +# "fwd" is a must-have api for the fmha_fwd example, add it if not specified +if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_ENABLE_APIS "fwd") +endif() + +string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +# generate a list of kernels, but not actually emit files at config sta +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + RESULT_VARIABLE ret +) +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.") +endif() + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + "--score_mod_expr=${FMHA_SCORE_MOD_F}" + "--pre_softmax_expr=${FMHA_PRE_SOFTMAX_F}" + VERBATIM +) + +set(EXAMPLE_FMHA_FWD "tile_example_flexattn_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_FMHA_FWD}") +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... because they are auto-generated +if(FMHA_FWD_FAST_EXP2) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) +endif() + +# conditionally enable call to the fwd_splitkv API in fmha_fwd example +if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0) +endif() + +# conditionally enable call to the fwd_appendkv API in fmha_fwd example +if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) +endif() + +# Allow comparing floating points directly in order to check sentinel values +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) + +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_TILE_SCORE_MOD_F=${FMHA_SCORE_MOD_F}") +list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS "-DCK_PRE_SOFTMAX_F=${FMHA_PRE_SOFTMAX_F}") + +target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/18_flexattn/bias.hpp b/example/ck_tile/18_flexattn/bias.hpp new file mode 100644 index 00000000000..d8da0224cc5 --- /dev/null +++ b/example/ck_tile/18_flexattn/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flex_fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/18_flexattn/codegen/__init__.py b/example/ck_tile/18_flexattn/codegen/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/18_flexattn/codegen/cmake_config.py b/example/ck_tile/18_flexattn/codegen/cmake_config.py new file mode 100644 index 00000000000..03ebfd67021 --- /dev/null +++ b/example/ck_tile/18_flexattn/codegen/cmake_config.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +GEN_DIR = "" # in Cmake, have to generate files in same folder \ No newline at end of file diff --git a/example/ck_tile/18_flexattn/codegen/cpp_symbol_map.py b/example/ck_tile/18_flexattn/codegen/cpp_symbol_map.py new file mode 100644 index 00000000000..332707eafd1 --- /dev/null +++ b/example/ck_tile/18_flexattn/codegen/cpp_symbol_map.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +FWD_DTYPE_MAP = { + "fp16" : "FmhaFwdFp16", + "bf16" : "FmhaFwdBf16", + "fp8" : "FmhaFwdFp8", + "fp8fp16": "FmhaFwdFp8Fp16", + "fp8bf16": "FmhaFwdFp8Bf16" +} + +BWD_DTYPE_MAP = { + "fp16": "FmhaBwdFp16", + "bf16": "FmhaBwdBf16" +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +_MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +_MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +def get_mask_map(mask : str): + if mask == "generic": + return _MASK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_MAP + else: + assert False + return None + +_MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +_MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +def get_mask_check_map(mask : str): + if mask == "generic": + return _MASK_CHECK_MAP + elif mask == "simplified": + return _MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + +DROPOUT_MAP = { + "no" : "ck_tile::BlockDropoutBwd", + "dropout_wg32" : "ck_tile::BlockDropoutBwd", + "dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd", + "dropout_wg16" : "ck_tile::BlockDropoutBwd", + "dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd" +} + +DROPOUT_CHECK_MAP = { + "no" : "t.has_dropout == false", + "dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true", + "dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false", + "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true", +} + +ROPE_MAP = { + "no" : "ck_tile::RotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" +} + +ROPE_CHECK_MAP = { + "no" : "rope_enum::none", + "inter" : "rope_enum::interleaved", + "half" : "rope_enum::half_rotated" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", + "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} diff --git a/example/ck_tile/18_flexattn/codegen/ops/__init__.py b/example/ck_tile/18_flexattn/codegen/ops/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py b/example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py new file mode 100644 index 00000000000..b0fc09e2e46 --- /dev/null +++ b/example/ck_tile/18_flexattn/codegen/ops/fmha_fwd.py @@ -0,0 +1,564 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +struct score_mod_def_{F_idx} {{ + using TScore = typename fmha_pipeline_{F_idx}::SaccDataType; + template + CK_TILE_HOST_DEVICE TScore operator()(TScore s, + ck_tile::index_t b, + ck_tile::index_t h, + ck_tile::index_t q_idx, + ck_tile::index_t v_idx, + std::optional context = std::nullopt) const {{ + (void) s; (void) h; (void) b; (void) q_idx; (void) v_idx; (void) context; + return {F_score_mod_expr}; + }} +}}; + +struct pre_softmax_def_{F_idx} {{ + using TScore = typename fmha_pipeline_{F_idx}::SaccDataType; + CK_TILE_HOST_DEVICE TScore operator()(TScore s + ) const {{ + (void) s; + return {F_pre_softmax_expr}; + }} +}}; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; + // std::cout << "dtype=" << t.data_type << " qdim=" << t.hdim_q << " vdim=" << t.hdim_v << std::endl; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias != 'no' : n += f'_{self.F_bias}' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_dropout == 't' : n += '_dropout' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + first_i = True + for i, dtype in enumerate(self.pool.keys()): + if dtype != "bf16": + continue + per_hdim_case=str() + first_j = True + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + first_k = True + for k, trait in enumerate(traits): + if trait.dropout == "t": + continue + if trait.lse == "t": + continue + if trait.pipeline_tag not in ("qr", "qr_async"): + continue + if_k = 'if' if first_k else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + first_k = False + if_j = 'if' if first_j else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + first_j = False + if_i = 'if' if first_i else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + first_i = False + + if not per_dtypes: + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + F_score_mod_expr: str + F_pre_softmax_expr:str + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_score_mod_expr = self.F_score_mod_expr, + F_pre_softmax_expr = self.F_pre_softmax_expr) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr : str, pre_softmax_expr : str) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["f"], ["f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + # for dtype in FWD_DTYPE_MAP.keys(): + for dtype in ["bf16"]: + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + F_score_mod_expr=score_mod_expr, + F_pre_softmax_expr=pre_softmax_expr, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/18_flexattn/fmha_fwd.cpp b/example/ck_tile/18_flexattn/fmha_fwd.cpp new file mode 100644 index 00000000000..ca03c433680 --- /dev/null +++ b/example/ck_tile/18_flexattn/fmha_fwd.cpp @@ -0,0 +1,1622 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "fmha_fwd.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ref/naive_attention.hpp" +#include "mask.hpp" +#include "rotary.hpp" +#include "utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()" +#endif + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)") + .insert("mode", "0", "kernel mode. 0:batch, 1:group") + .insert("b", "2", "batch size") + .insert("h", "8", "num of head, for q") + .insert("h_k", + "-1", + "num of head, for k/v, -1 means equal to h\n" + "if not equal to h, then this is GQA/MQA case") + .insert( + "s", + "3328", + "seqlen_q. if group-mode, means the average value of seqlen_q\n" + "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n" + "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)") + .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s") + .insert("s_knew", + "0", + "seqlen_k for new key/value, 0 means not to use this at all; " + "-1 to choose s_knew in [1, s] randomly.") + .insert("s_kpad", + "-1", + "seqlen_k stride between 2 batches, currently used in group-mode only\n" + "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n" + "along seqlen, instead of packed. same as xformer kv_padding") + .insert("d", "128", "head dim for q, k") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") + .insert("scale_s", + "0", + "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" + "note when squant=1, this value will be modified by range_q/k") + .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") + .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") + .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") + .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") + .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") + .insert("squant", + "auto", + "if using static quantization fusion or not. auto: fp8 will default use squant, " + "other will not\n" + "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " + "P and O.\n" + "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " + "range_p, range_o") + .insert("iperm", + "1", + "permute input\n" + "if true, will be b*h*s*d, else b*s*h*d") + .insert("operm", "1", "permute output") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") + .insert("prec", "bf16", "data type. fp16/bf16/fp8/bf8") + .insert("mask", + "0", + "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" + "'t', top-left causal mask, 'b', bottom-r causal mask\n" + "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n" + "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n" + "'xt:window_size', xformer style masking from top-left, window_size negative is " + "causal, positive is swa\n" + "'xb:window_size', xformer style masking from bottom-r, window_size negative is " + "causal, positive is swa\n" + "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for " + "now)") + .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") + .insert("lse", "0", "0 not store lse, 1 store lse") + .insert("kname", "0", "if set to 1 will print kernel name") + .insert("init", + "uf", + "init method. ui, uniform random int, ni, normalized random int\n" + "uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, " + "quantization") + .insert("seed", + "11939", + "random seed used for initializing input tensors. 0 for " + "non-deterministic seed") + .insert("p_drop", "0", "0~1 probability of dropout") + .insert("drop_seed", "1", "seed for random number generator") + .insert("drop_offset", "0", "offset for random number generator") + .insert("drop_prefs", + "0", + "seed and offset values are present on GPU; 0 - host, 1 - device/GPU") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert( + "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all") + .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE") + .insert("num_splits", + "1", + "# of splits for key/value. 0 to determine actual number by heuristic") + .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe") + .insert("cache_batch_idx", "0", "whether to use index map to the kvcache") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || + ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + efficiency.push_back(0.f); + } + else + { + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency.push_back(eff); + } + } + for(int num_splits = 1; num_splits <= max_splits; num_splits++) + { + if(!is_split_eligible(num_splits)) + { + continue; + } + if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) + { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int override_num_splits_if_necessary( + int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return num_splits; + } + + hipDeviceProp_t props{}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return num_splits; + } + + // tile size should match the generate.py + const int kM0 = 64; + const int kN1 = hdim_v; + + const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); + const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); + + if(num_splits < 1 && p_drop == 0.0f) + { + return num_splits_heuristic( + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + } + + return num_splits; +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k < 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v < 0) + hdim_v = hdim_q; + + ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API + if(seqlen_knew != 0) + { + std::cerr << "fmha_fwd_appendkv() is not enabled. ignoring the 's_knew' option" + << std::endl; + seqlen_knew = 0; + } +#endif + if(seqlen_knew < 0) + { + seqlen_knew = randint(1, arg_parser.get_int("s"), seed); + } + + ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + if(0 < rotary_dim) + { + std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; + return false; + } + } +#if !CK_TILE_FMHA_FWD_APPENDKV_API + else if(0 < rotary_dim) + { + std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option" + << std::endl; + rotary_dim = 0; + } +#endif + // to use fmha_fwd_appendkv(), make sure it's in batch mode + const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim); + if(need_append_kvcache && mode == mode_enum::group) + { + std::cerr << "fmha_fwd_appendkv() will be invoked. ignoring the 'mode' option" << std::endl; + mode = mode_enum::batch; + } + if(!(rotary_dim <= hdim_q)) + { + std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; + return false; + } + else if(!(rotary_dim % 16 == 0)) + { + std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl; + return false; + } + + ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) + { + std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" + << std::endl; + page_block_size = 0; + } +#endif + if(!(page_block_size % 128 == 0)) + { + std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" + << std::endl; + return false; + } + + bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); +#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API + if(use_cache_batch_idx) + { + std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } +#else + if(use_cache_batch_idx) + { + if(0 < page_block_size) + { + std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + else if(mode == mode_enum::group) + { + std::cerr << "group mode will not use cache_batch_idx. ignoring the " + "'cache_batch_idx' option" + << std::endl; + use_cache_batch_idx = false; + } + } +#endif + const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + + auto [seqlen_qs, seqlen_ks, seqlen_kpads] = + decode_seqlen(mode, + batch, + arg_parser.get_str("s"), + arg_parser.get_str("s_k"), + arg_parser.get_str("s_kpad"), + /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0, + need_append_kvcache); + // compute kvcache seqlen_k (before appending knew/vnew) + auto cache_seqlen_ks = seqlen_ks; + std::transform(cache_seqlen_ks.begin(), + cache_seqlen_ks.end(), + cache_seqlen_ks.begin(), + [&](auto seqlen_k) { return seqlen_k - seqlen_knew; }); + +#if 0 + // clang-format off + std::cout << "seqlen_qs:"; for(auto xx : seqlen_qs) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_ks:"; for(auto xx : seqlen_ks) { std::cout << xx << ","; } std::cout << std::endl; + std::cout << "seqlen_kpads:"; for(auto xx : seqlen_kpads) { std::cout << xx << ","; } std::cout << std::endl; + // clang-format on +#endif + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale_s = arg_parser.get_float("scale_s"); + if(scale_s == .0f) + scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + + std::string squant_str = arg_parser.get_str("squant"); + bool squant = [&]() { + if(squant_str == "auto") + { + if(data_type == "fp8") + return true; + else + return false; + } + else + return atoi(squant_str.c_str()) != 0 ? true : false; + }(); + + std::string vlayout = arg_parser.get_str("vlayout"); + bool lse = arg_parser.get_bool("lse"); + + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); + mask_info mask = mask_info::decode( + arg_parser.get_str("mask"), seqlen_qs[0], seqlen_ks[0]); // TODO: we don't need x/y anymore + + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + + if(p_drop < 0.0f || p_drop > 1.0f) + { + std::cerr << "The value of p_drop should be 0~1" << std::endl; + return false; + } + + bool s_randval = false; + if(p_drop > 0.0f && do_validation != 0) + { + s_randval = true; + } + + std::string init_method = arg_parser.get_str("init"); + + const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); + + ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); +#if !CK_TILE_FMHA_FWD_SPLITKV_API + if(num_splits != 1) + { + std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl; + num_splits = 1; + } +#endif + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + stream_warmup, + stream_repeat, + arg_parser.get_str("timer") == std::string("gpu")}; + + const auto seqstart_q_host = to_seqstarts(seqlen_qs); + const auto seqstart_k_host = to_seqstarts(seqlen_ks); + const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using RandValOutputDataType = typename TypeConfig::RandValOutputDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float scale_p = 1.f; + float scale_o = 1.f; + + if(squant) + { + scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); + scale_p = p_dtype_max / range_p; + scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); + } + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + auto max_seqlen_k = std::numeric_limits::min(); + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + if(max_seqlen_k < real_seqlen_k) + { + max_seqlen_k = real_seqlen_k; + } + + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + const ck_tile::index_t max_num_page_blocks = + (0 < page_block_size + ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size)) + : 0); + + // legalize num_splits according to other options + if(num_splits < 1) + { + num_splits = override_num_splits_if_necessary( + batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + } + if(128 < num_splits) + { + std::cerr << "num_splits greater than 128 is not supported" << std::endl; + return false; + } +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < p_drop && (1 < num_splits || use_kvcache)) + { + std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" + << std::endl; + p_drop = 0.0f; + } +#endif + + static const auto get_lengths = [](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + bool is_v_rowmajor = vlayout == std::string("r"); + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_ks[0] + : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() + : seqstart_k_with_padding_host.back())); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + 0 < page_block_size + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) + : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode + ck_tile::HostTensor knew_host( + 0 < seqlen_knew + ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor v_host( + 0 < page_block_size + ? (is_v_rowmajor + ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v) + : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) + : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); + ck_tile::HostTensor vnew_host( + 0 < seqlen_knew + ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) + : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew)) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor bias_host( + bias.type == bias_enum::elementwise_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + + auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); + + ck_tile::HostTensor lse_acc_host( + 1 < num_splits || use_kvcache + ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} + : std::array{1, 1, 1, 1}); + ck_tile::HostTensor o_acc_host( + 1 < num_splits || use_kvcache ? std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); + + // batch mode of lse data layout is [batch, nhead, seqlen_q] + // group mode of lse data layout is [nhead, total_seqlen_q] + ck_tile::HostTensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + ck_tile::HostTensor randval_host( + p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) + : std::array{1, 1, 1, 1}); + + ck_tile::HostTensor block_table_host( + 0 < page_block_size ? std::array{batch, max_num_page_blocks / batch} + : std::array{1, 1}); + + ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx + ? std::array{batch} + : std::array{1}); + + if(init_method == "ui" || init_method == "0") + { + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "ni") + { + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistributionIntegerValue{-3.f, 3.f, seed}(bias_host); + } + else if(init_method == "uf" || init_method == "1") + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } + else if(init_method == "nf") + { + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(q_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(k_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(knew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(v_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(vnew_host); + ck_tile::FillNormalDistribution{0.f, 3.f, seed}(bias_host); + } + else if(init_method == "tf" || init_method == "2") + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(knew_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(vnew_host); + ck_tile::FillTrigValue{}(bias_host); + } + else if(init_method == "ufq" || init_method == "uf:q" || + init_method == "3") // suitable for fp8 quantization + { + ck_tile::FillUniformDistribution{-q_dtype_max, q_dtype_max, seed}(q_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-k_dtype_max, k_dtype_max, seed}(knew_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(v_host); + ck_tile::FillUniformDistribution{-v_dtype_max, v_dtype_max, seed}(vnew_host); + + // bias_fp8 = qscale_bias * bias_fp32 + float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); + // Assume bias is in [-1.f, 1.f] in original fp32 + ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == static_cast(nhead)); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } + iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); + iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0); + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); + ck_tile::DeviceMem cache_seqlen_k_buf( + need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + knew_buf.ToDevice(knew_host.data()); + v_buf.ToDevice(v_host.data()); + vnew_buf.ToDevice(vnew_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() + : seqstart_k_with_padding_host.data()); + seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] + ? seqlen_ks.data() + : nullptr); + cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); + rotary_cos_buf.ToDevice(rotary_cos_host.data()); + rotary_sin_buf.ToDevice(rotary_sin_host.data()); + drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); + drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); + block_table_buf.ToDevice(block_table_host.data()); + cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if(permute) return std::string("bhsd"); + else return std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if(iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] + << (seqlen_kpads[0] < 0 ? "" + : (std::string("(") + std::to_string(seqlen_kpads[0]) + ")")) + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout; +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(0 < rotary_dim) + { + std::cout << ", rotary_dim:" << rotary_dim << "(" + << (is_rotary_interleaved ? "inter" : "half") << ")"; + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits) + { + std::cout << ", num_splits:" << num_splits; + } + if(0 < page_block_size) + { + std::cout << ", page_block_size:" << page_block_size; + } + if(use_cache_batch_idx) + { + std::cout << ", cache_batch_idx:" << use_cache_batch_idx; + } +#endif + std::cout << std::flush; + + const auto init_traits = [&](auto& traits) { + traits.hdim_q = hdim_q; + traits.hdim_v = hdim_v; + traits.data_type = data_type; + traits.is_v_rowmajor = is_v_rowmajor; + + if constexpr(std::is_same_v>) + { + traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved + : rope_enum::half_rotated) + : rope_enum::none); + } + else // fmha_fwd_traits or fmha_splitkv_traits + { + traits.is_group_mode = (mode == mode_enum::group); + traits.mask_type = mask.type; + traits.bias_type = bias.type; + // traits.has_lse = lse; + // traits.do_fp8_static_quant = squant; + + // if constexpr(std::is_same_v>) + // { + // traits.has_dropout = (p_drop > 0.0f); + // } + } + }; + + const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size) + : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); + }(); + const ck_tile::index_t stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? seqlen_knew : nhead_k * seqlen_knew; + }(); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_randval = (max_seqlen_k); + const ck_tile::index_t stride_o_acc = (hdim_v); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = + (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q) + : (i_perm ? shape_seqlen_k * hdim_q : hdim_q)); + const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v) + : (i_perm ? shape_seqlen_k * hdim_v : hdim_v); + else + return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size) + : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k); + }(); + const ck_tile::index_t nhead_stride_vnew = [&]() { + if(is_v_rowmajor) + return i_perm ? seqlen_knew * hdim_v : hdim_v; + else + return i_perm ? hdim_v * seqlen_knew : seqlen_knew; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = + (0 < page_block_size ? (nhead_k * page_block_size * hdim_q) + : (nhead_k * shape_seqlen_k * hdim_q)); + const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q); + const ck_tile::index_t batch_stride_v = + (0 < page_block_size ? (nhead_k * hdim_v * page_block_size) + : (nhead_k * hdim_v * shape_seqlen_k)); + const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + // setup split_stride_* arguments (only used in split-kv kernel) + const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); + const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); + + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.v_ptr = v_buf.GetDeviceBuffer(); + + args.batch = batch; + args.seqlen_q = shape_seqlen_q; // unused in group mode + args.hdim_q = hdim_q; + args.hdim_v = hdim_v; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + + args.stride_q = stride_q; + args.stride_k = stride_k; + args.stride_v = stride_v; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.nhead_stride_v = nhead_stride_v; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.batch_stride_v = batch_stride_v; + + if constexpr(std::is_same_v>) + { + args.knew_ptr = knew_buf.GetDeviceBuffer(); + args.vnew_ptr = vnew_buf.GetDeviceBuffer(); + args.seqlen_knew = seqlen_knew; + + args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer(); + + args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr); + args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr); + args.rotary_dim = rotary_dim; + args.has_mask = (mask.type != mask_enum::no_mask); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.stride_knew = stride_knew; + args.stride_vnew = stride_vnew; + args.nhead_stride_knew = nhead_stride_knew; + args.nhead_stride_vnew = nhead_stride_vnew; + args.batch_stride_knew = batch_stride_knew; + args.batch_stride_vnew = batch_stride_vnew; + } + else // fmha_fwd_args or fmha_fwd_splitkv_args + { + args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(); + args.lse_ptr = lse_buf.GetDeviceBuffer(); + args.o_ptr = o_buf.GetDeviceBuffer(); + + args.seqstart_q_ptr = + (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr); + args.seqstart_k_ptr = + (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr); + args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0] + ? seqlen_k_buf.GetDeviceBuffer() + : nullptr); + + args.seqlen_k = shape_seqlen_k; // unused in group mode (or kvcache enabled) + args.max_seqlen_q = max_seqlen_q; + + args.scale_s = scale_s; + args.scale_p = scale_p; + args.scale_o = scale_o; + + args.stride_bias = + (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); + args.stride_o = stride_o; + args.nhead_stride_bias = nhead_stride_bias; + args.nhead_stride_lse = nhead_stride_lse; + args.nhead_stride_o = nhead_stride_o; + args.batch_stride_bias = batch_stride_bias; + args.batch_stride_lse = batch_stride_lse; + args.batch_stride_o = batch_stride_o; + + args.window_size_left = mask.left; + args.window_size_right = mask.right; + args.mask_type = static_cast(mask.type); + + if constexpr(std::is_same_v>) + { + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + { + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + } + else + { + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + } + } + else if constexpr(std::is_same_v>) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + + args.num_splits = num_splits; + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + } + }; + + const float appendkv_ave_time = [&] { +#if CK_TILE_FMHA_FWD_APPENDKV_API + if(need_append_kvcache) + { + fmha_fwd_appendkv_traits fwd_appendkv_traits; + init_traits(fwd_appendkv_traits); + + fmha_fwd_appendkv_args fwd_appendkv_args; + init_args(fwd_appendkv_args); + + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + } +#endif + return 0.0f; + }(); + + const float fwd_ave_time = [&] { +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(1 < num_splits || use_kvcache) + { + fmha_fwd_splitkv_traits fmha_splitkv_traits; + init_traits(fmha_splitkv_traits); + + fmha_fwd_splitkv_args fmha_splitkv_args; + init_args(fmha_splitkv_args); + + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + } +#endif + fmha_fwd_traits fmha_traits; + init_traits(fmha_traits); + + fmha_fwd_args fmha_args; + init_args(fmha_args); + + return fmha_fwd(fmha_traits, fmha_args, stream_config); + }(); + + if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + const float ave_time = (appendkv_ave_time + fwd_ave_time); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(do_validation == 0) + { + std::cout << std::flush << std::endl; + return true; + } + if(do_validation == 2) + { + // NOTE: use gpu to do validation + ck_tile::naive_attention_fwd_traits naive_t; + naive_t.q_type = data_type; + naive_t.k_type = data_type; + naive_t.v_type = data_type; + naive_t.o_type = data_type; + naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd"; + naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd"; + naive_t.variation = 0; // TODO? + naive_t.quant_algo = 0; + + ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes()); + + ck_tile::naive_attention_fwd_args naive_a; + naive_a.q_ptr = q_buf.GetDeviceBuffer(); + naive_a.k_ptr = k_buf.GetDeviceBuffer(); + naive_a.v_ptr = v_buf.GetDeviceBuffer(); + naive_a.o_ptr = o_naive_buf.GetDeviceBuffer(); + naive_a.scale_s = scale_s; + naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer + naive_a.page_table_ptr = + nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn) + naive_a.hdim = hdim_q; + naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different + naive_a.batch_q = batch; + naive_a.batch_kv = batch; + naive_a.batch_ratio_kv = 1; // batch_q / batch_kv + naive_a.seqlen_q = seqlen_qs[0]; + naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field + naive_a.nhead_q = nhead; + naive_a.nhead_kv = nhead_k; + naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv + naive_a.page_size = 0; // if paged, the seqlen-kv for each block + + ck_tile::stream_config naive_s{}; + + naive_attention_fwd(naive_t, naive_a, naive_s); + + auto o_naive_ref = o_naive_buf.ToHost(); + o_buf.FromDevice(o_host.data()); // TODO: ugly + + auto [rtol_, atol_] = get_elimit(init_method); + bool pass_ = ck_tile::check_err( + o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); + std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl; + return pass_; + } + + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + randval_buf.FromDevice(randval_host.data()); + + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); + + float p_undrop = 1.0 - p_drop; + uint8_t p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + float rp_undrop = 1.0 / p_undrop; + + bool pass = true; + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t cache_b_idx = + (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = + (mode == mode_enum::batch + ? 0 + : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref({nhead, hdim_v, real_seqlen_k}); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); }); + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // optionally apply RoPE to the q_host_ref + if(0 < rotary_dim) + { + decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); + + ck_tile::reference_batched_rotary_position_embedding( + q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro, + /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask); + + q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(i_perm) { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); + }); + } else { + k_host_ref.ForEach([&](auto& self, auto i) { + self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); + }); + } + } else +#endif + { + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Knew to the end of K + if(0 < seqlen_knew) + { + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); + if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); }); + else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); }); + + // optionally apply RoPE to the knew_host_ref + auto* real_knew_host_ref = &knew_host_ref; + std::optional knew_host_ref_ro; + if(0 < rotary_dim) + { + knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); + + auto [rotary_cos_slice, rotary_sin_slice] = + slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); + + ck_tile::reference_batched_rotary_position_embedding( + knew_host_ref, + rotary_cos_slice, + rotary_sin_slice, + is_rotary_interleaved, + knew_host_ref_ro.value()); + + real_knew_host_ref = &knew_host_ref_ro.value(); + } + + (*real_knew_host_ref).ForEach([&](auto& self, auto i) { + k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i); + }); + } +#endif +#if CK_TILE_FMHA_FWD_SPLITKV_API + if(0 < page_block_size) { + if(is_v_rowmajor) { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); + }); + } + } + else + { + if(i_perm) { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); + }); + } else { + v_host_ref.ForEach([&](auto& self, auto i) { + self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size); + }); + } + } + } else +#endif + { + if(is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else + { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); }); + } + } + +#if CK_TILE_FMHA_FWD_APPENDKV_API + // copy Vnew to the end of V + if(0 < seqlen_knew) + { + ck_tile::HostTensor vnew_host_ref({nhead, hdim_v, seqlen_knew}); + if(is_v_rowmajor) + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); }); + } + else + { + if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); }); + else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); }); + } + + vnew_host_ref.ForEach([&](auto& self, auto i) { + v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i); + }); + } +#endif + // clang-format on + + // reference + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}); + +#ifndef CK_TILE_SCORE_MOD_F +#error "must be defined" +#else +#define XSTR(x) STR(x) +#define STR(x) #x +#pragma message "host score_mod_f: " XSTR(CK_TILE_SCORE_MOD_F) +#endif + + auto score_mod = [](auto s, + ck_tile::index_t b, + ck_tile::index_t h, + ck_tile::index_t q_idx, + ck_tile::index_t v_idx) { + ck_tile::detail::swallow(s, b, h, q_idx, v_idx); + return CK_TILE_SCORE_MOD_F; + }; + + s_host_ref.ForEach([&](auto& self, auto i) { + auto new_score = score_mod(self(i), wb, i[0], i[1], i[2]); + // printf("host score_mod at (%d %lu %lu %lu), score before: %f, score after: %f\n", + // wb, i[0], i[1], i[2], self(i), new_score); + self(i) = new_score; + }); + + auto scale_def = ck_tile::scales(scale_s); + s_host_ref.ForEach([&](auto& self, auto i) { self(i) = scale_def(self(i)); }); + + if(bias.type == bias_enum::elementwise_bias) + { + // elementwise bias + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope + : -current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + + auto pre_softmax = [](auto s) { + // ck_tile::detail::swallow(s); + return CK_PRE_SOFTMAX_F; + }; + s_host_ref.ForEach([&](auto& self, auto i) { + auto new_val = pre_softmax(self(i)); + self(i) = new_val; + }); + + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + } + + if(p_drop > 0) + { + ck_tile::HostTensor randval_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + randval_host_ref.ForEach([&](auto& self, auto idx) { + self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); + }); + ck_tile::reference_batched_dropout( + p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); + } + + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err( + o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + }); + + cur_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/18_flexattn/fmha_fwd.hpp b/example/ck_tile/18_flexattn/fmha_fwd.hpp new file mode 100644 index 00000000000..5fc12a61271 --- /dev/null +++ b/example/ck_tile/18_flexattn/fmha_fwd.hpp @@ -0,0 +1,823 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flex_fmha.hpp" + +#include "bias.hpp" +#include "mask.hpp" +#include "rotary.hpp" + +#include +#include +#include + +struct FmhaFwdFp16 +{ +}; + +struct FmhaFwdBf16 +{ +}; + +struct FmhaFwdFp8 +{ +}; + +struct FmhaFwdBf8 +{ +}; + +struct FmhaFwdFp8Fp16 +{ +}; + +struct FmhaFwdFp8Bf16 +{ +}; + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* + seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + +struct fmha_fwd_splitkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_acc_ptr; + void* o_acc_ptr; + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + ck_tile::index_t num_splits; + + float scale_s; + float scale_p; + float scale_o; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o_acc; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_lse_acc; + ck_tile::index_t nhead_stride_o_acc; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_lse_acc; + ck_tile::index_t batch_stride_o_acc; + ck_tile::index_t batch_stride_o; + ck_tile::index_t split_stride_lse_acc; + ck_tile::index_t split_stride_o_acc; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +struct fmha_fwd_appendkv_args +{ + void* q_ptr; + void* k_ptr; + const void* knew_ptr; + void* v_ptr; + const void* vnew_ptr; + + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_knew; + ck_tile::index_t batch; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0 + const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0 + ck_tile::index_t rotary_dim; + bool has_mask; + + void* block_table_ptr; + ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + + const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache) + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_knew; + ck_tile::index_t stride_v; + ck_tile::index_t stride_vnew; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_knew; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_vnew; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_knew; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_vnew; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + +template +auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_acc_ptr, + args.o_acc_ptr, + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o_acc, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.split_stride_lse_acc, + args.split_stride_o_acc, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = Kernel::GridSize( + args.batch, args.nhead_q, args.nhead_k, args.max_seqlen_q, args.hdim_v, args.num_splits); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel argumentszs + if constexpr(Kernel::kIsGroupMode) + { + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqstart_q_ptr, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + else + { // create batch mode kernel arguments + return Kernel::MakeKargs(args.lse_acc_ptr, + args.o_acc_ptr, + args.lse_ptr, + args.o_ptr, + args.batch, + args.seqlen_q, + args.hdim_v, + args.num_splits, + args.scale_o, + args.stride_o_acc, + args.stride_o, + args.nhead_stride_lse_acc, + args.nhead_stride_o_acc, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_lse_acc, + args.batch_stride_o_acc, + args.batch_stride_lse, + args.batch_stride_o, + args.split_stride_lse_acc, + args.split_stride_o_acc); + } + }(); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + + return ck_tile::make_tuple(kargs, grids); +} + +template +auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = Kernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.knew_ptr, + args.v_ptr, + args.vnew_ptr, + args.seqlen_q, + args.seqlen_k_ptr, + args.seqlen_knew, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.rotary_cos_ptr, + args.rotary_sin_ptr, + args.rotary_dim, + args.has_mask, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.stride_q, + args.stride_k, + args.stride_knew, + args.stride_v, + args.stride_vnew, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_knew, + args.nhead_stride_v, + args.nhead_stride_vnew, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_knew, + args.batch_stride_v, + args.batch_stride_vnew); + + dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew); + + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +template +struct fmha_fwd_splitkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_get_name_(); + +template +struct fmha_fwd_splitkv_combine_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args); + +template +std::string fmha_fwd_splitkv_combine_get_name_(); + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_appendkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_; + static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_; + static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_; + static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSk = kPadSk_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr auto RotaryEnum = RotaryEnum_; + static constexpr bool kIsPagedKV = kIsPagedKV_; +}; + +template +float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); + +// This is the public API, will be generated by script +struct fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + // bool has_lse; + // bool has_dropout; + // bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); + +struct fmha_fwd_splitkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, + fmha_fwd_splitkv_args, + const ck_tile::stream_config&); + +struct fmha_fwd_appendkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_v_rowmajor; + rope_enum rope_type; +}; +float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, + fmha_fwd_appendkv_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/18_flexattn/generate.py b/example/ck_tile/18_flexattn/generate.py new file mode 100644 index 00000000000..59db8f19af0 --- /dev/null +++ b/example/ck_tile/18_flexattn/generate.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import pkgutil +import sys +from typing import List, Optional + +import codegen.ops +from codegen.cmake_config import * + + +class HandlerId(IntEnum): + LIST_BLOBS = 0 + WRITE_BLOBS = 1 + +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + if full_module_name not in sys.modules: + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) + +def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + + for api in api_list: + handler = handlers[api][HandlerId.WRITE_BLOBS] + handler(output_dir, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr) + +# list all the files that will be generated +def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl, score_mod_expr, pre_softmax_expr) -> None: + assert output_file is not None + file_path = Path(output_file) + + # create an empty file / drop its contents if it exists + open(file_path, "w").close() + + for api in api_list: + handler = handlers[api][HandlerId.LIST_BLOBS] + handler(file_path, kernel_filter, receipt, mask_impl, score_mod_expr, pre_softmax_expr) + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK fmha kernel", + ) + parser.add_argument( + "-d", + "--direction", # we keep 'direction' option for backward compatibility + "-a", + "--api", + default='fwd', + required=False, + help="supply API(s) to generate (default: fwd). separated by comma." + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim\n" + \ + " 2: Only generate instance for Flash attention integration" + ) + + parser.add_argument( + "--score_mod_expr", + default="s", + required=False, + help="flex attention's score mod function, a cpp expression with `s`, `b`, `h`, `q_idx`, and `v_idx` variables" + ) + + parser.add_argument( + "--pre_softmax_expr", + default="s", + required=False, + help="flex attention's pre_softmax function, a cpp expression with `s` variable" + ) + + args = parser.parse_args() + api_list = args.direction.split(',') + if args.list_blobs is not None: + list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr) + else: + write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask, score_mod_expr=args.score_mod_expr, pre_softmax_expr=args.pre_softmax_expr) diff --git a/example/ck_tile/18_flexattn/mask.hpp b/example/ck_tile/18_flexattn/mask.hpp new file mode 100644 index 00000000000..37978132e3e --- /dev/null +++ b/example/ck_tile/18_flexattn/mask.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/flex_fmha.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? + tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/18_flexattn/rotary.hpp b/example/ck_tile/18_flexattn/rotary.hpp new file mode 100644 index 00000000000..346f2a5e7ee --- /dev/null +++ b/example/ck_tile/18_flexattn/rotary.hpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include + +// keep sync with RotaryEmbeddingEnum +enum class rope_enum +{ + none = 0, + interleaved = 1, + half_rotated = 2, +}; + +template +std::tuple, ck_tile::HostTensor> +generate_rotary_cos_sin(ck_tile::index_t seqlen, + ck_tile::index_t rotary_dim, + std::optional seed = std::nullopt) +{ + // return dummy tensors if we won't apply RoPE at all + if(rotary_dim <= 0) + { + ck_tile::HostTensor dummy({1, 1}); + return std::make_tuple(dummy, dummy); + } + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_real_distribution generator(0.0f, 1.0f); + + const ck_tile::index_t num_rows = seqlen * 2; + const ck_tile::index_t num_cols = rotary_dim / 2; + + using std::begin, std::end; + + ck_tile::HostTensor angle({num_rows, num_cols}); + std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; }); + + ck_tile::HostTensor cos({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) { + return ck_tile::type_convert(std::cos(origin_value)); + }); + + ck_tile::HostTensor sin({num_rows, num_cols}); + std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) { + return ck_tile::type_convert(std::sin(origin_value)); + }); + + return std::make_tuple(cos, sin); +} + +template +std::tuple, ck_tile::HostTensor> +slice_rotary_cos_sin(const ck_tile::HostTensor& cos, + const ck_tile::HostTensor& sin, + ck_tile::index_t seqlen_offset, + ck_tile::index_t seqlen) +{ + assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2); + assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1)); + + assert(static_cast(seqlen_offset + seqlen) <= cos.get_length(0)); + + const ck_tile::index_t num_rows = seqlen; + const ck_tile::index_t num_cols = cos.get_length(1); + + ck_tile::HostTensor cos_pt({num_rows, num_cols}); + cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); }); + + ck_tile::HostTensor sin_pt({num_rows, num_cols}); + sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); }); + + return std::make_tuple(cos_pt, sin_pt); +} diff --git a/example/ck_tile/18_flexattn/script/benchmark_fwd.sh b/example/ck_tile/18_flexattn/script/benchmark_fwd.sh new file mode 100755 index 00000000000..5a92e4e58b1 --- /dev/null +++ b/example/ck_tile/18_flexattn/script/benchmark_fwd.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_flexattn_fwd -type f | head -n 1)" +VALID=0 + +for prec in "bf16" ; do +for perm in 0 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done diff --git a/example/ck_tile/18_flexattn/script/run_full_test.sh b/example/ck_tile/18_flexattn/script/run_full_test.sh new file mode 100755 index 00000000000..6be1c2630ee --- /dev/null +++ b/example/ck_tile/18_flexattn/script/run_full_test.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. + +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +example/ck_tile/18_flexattn/script/smoke_test_fwd.sh + +#run performance benchmarks +export fmha_fwd_log="perf_tile_flex_attn_$GPU_arch.log" +print_log_header $fmha_fwd_log $env_type $branch $host_name +echo "Running performance benchmark for tile_flex_attn_fwd" +example/ck_tile/18_flexattn/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log +echo "Finishing performance benchmark for tile_flex_attn_fwd" + diff --git a/example/ck_tile/18_flexattn/script/smoke_test_fwd.sh b/example/ck_tile/18_flexattn/script/smoke_test_fwd.sh new file mode 100755 index 00000000000..5199b2c3676 --- /dev/null +++ b/example/ck_tile/18_flexattn/script/smoke_test_fwd.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# TODO: run this script from CK root or build directory +EXE="$(find . -name tile_example_flexattn_fwd -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +TEST_SPLITKV=0 +TEST_APPENDKV=0 +# options: +# -s: run splitkv tests +# -a: run appendkv tests +while getopts ":sa" opt; do + case "${opt}" in + s) + TEST_SPLITKV=1 + ;; + a) + TEST_APPENDKV=1 + ;; + *) + ;; + esac +done + +run_fp16_bf16_tests() { + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" + + if [ $TEST_SPLITKV -eq 1 ] ; then + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" + fi + + for prec in "bf16" ; do + for mode in 0 1; do + for perm in 0 ; do + for vlayout in "r" ; do + for hdim in 32 64 128 256 ; do + for lse in 0 ; do + for bias in "n" ; do + for p_drop in 0.0 ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do + + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done ; done ; done + done ; +} + +run_fp8_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp16_appendkv_tests() { + for s in $(seq 63 1 65) ; do + for s_k in 65 129 ; do + for s_knew in 0 64 $s_k ; do + for hdim in 32 64 128 256 ; do + for ri in 0 1 ; do + for rdim in 0 16 32 $hdim ; do + for page_block_size in 0 128 ; do + for cache_batch_idx in 0 1 ; do + + $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + + done ; done ; done ; done ; done + done ; done ; done +} + +set -x + +run_fp16_bf16_tests +# run_fp8_tests + +# if [ $TEST_APPENDKV -eq 1 ] ; then +# run_fp16_appendkv_tests +# fi + +set +x diff --git a/example/ck_tile/18_flexattn/utils.hpp b/example/ck_tile/18_flexattn/utils.hpp new file mode 100644 index 00000000000..faf3f08437a --- /dev/null +++ b/example/ck_tile/18_flexattn/utils.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? "batch" : "group"); +} + +std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min = -1, // if not negative, clamp min + int32_t seqlen_max = -1, // if not negative, clamp max + std::optional seed = std::nullopt) +{ + assert(0 < count); + + seqlen_min = (0 < seqlen_min ? seqlen_min : 1); + seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits::max()); + assert(seqlen_min <= seqlen_max); + + std::vector seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max)); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is in range [seqlen_min, seqlen_max] + if(seqlens[to_decrease] == seqlen_min) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + if(seqlens[to_increase] >= seqlen_max) + { + continue; + } + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlen_avg, + int32_t seqlen_min = -1, + int32_t seqlen_max = -1, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed)); +} + +// return random integer generated uniformly in range [low, high] +template +auto randint(Int low, Int high, std::optional seed = std::nullopt) + -> std::enable_if_t, Int> +{ + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + return dist(engine); +} + +// return random integers generated uniformly in range [low, high] +template +auto randints(ForwardIterator first, + ForwardIterator last, + Int low, + Int high, + std::optional seed = std::nullopt) + -> std::enable_if_t> +{ + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution dist(low, high); + + std::generate(first, last, [&] { return dist(engine); }); +} + +/* + * decode the seqlen string from cmdline + * example (assume batch=3) + * q_val=1,2,3 k_val=4,5,6 -> OK + * q_val=1,2,3 -> OK, k same as q + * q_val=1,2 -> OK, q will rand remaining 1 element, k same as q + * q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element + * q_val=1,2,3,4 -> OK, but ignore exceed one + * + * q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q + * q_val=1,2 k_val=4 -> not OK, k must have same splits with q + */ +std::tuple, + std::vector, + std::vector> +decode_seqlen(mode_enum mode, + ck_tile::index_t batch, + std::string q_val, + std::string k_val, + std::string k_pad_val, + ck_tile::index_t seqlen_k_min = 0, + bool need_append_kvcache = false, + std::optional seed = std::nullopt) +{ +#define _S2I_(str_) static_cast(std::atoi((str_).c_str())) + if(mode == mode_enum::batch) + { + ck_tile::index_t q = _S2I_(q_val); + ck_tile::index_t k = _S2I_(k_val); + + auto s_q = std::vector(batch, q); + auto s_k = [&] { + const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k); + std::vector seqlen_ks(batch, seqlen_k_max); + + if(1 < batch && need_append_kvcache) + { + // to keep the original s_k value, we always use seqlen_k_max in first batch + randints(std::next(seqlen_ks.begin()), + seqlen_ks.end(), + seqlen_k_min, + seqlen_k_max, + seed); + return seqlen_ks; + } + + return seqlen_ks; + }(); + auto s_kpad = std::vector(batch, -1); // TODO: batch not support k_padding + + // s_k should be greater than or equal to seqlen_k_min if provided + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + return std::make_tuple(s_q, s_k, s_kpad); + } + else + { + ck_tile::index_t idx = 0; + std::string::size_type pos_q = 0; + std::string::size_type pos_k = 0; + std::string::size_type pos_kp = 0; + std::vector s_q; + std::vector s_k; + std::vector s_kpad; + while(true) + { + auto found_q = q_val.find(',', pos_q); + auto found_k = k_val.find(',', pos_k); + auto found_kp = k_pad_val.find(',', pos_kp); + + ck_tile::index_t q = _S2I_( + q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q)); + ck_tile::index_t k = _S2I_( + k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k)); + ck_tile::index_t kp = _S2I_(k_pad_val.substr( + pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp)); + + s_q.push_back(q); + s_k.push_back(k < 0 ? q : k); + s_kpad.push_back(kp); + + // s_k should be greater than or equal to seqlen_k_min + if(s_k.back() < seqlen_k_min) + { + std::ostringstream msg; + msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back() + << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")"; + throw std::runtime_error(msg.str()); + } + + idx++; + if(found_q == std::string::npos || idx >= batch) + { + break; + } + pos_q = found_q + 1; + pos_k = found_k == std::string::npos ? pos_k : found_k + 1; + pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1; + } + if(idx < batch) + { + auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed); + auto rem_k = + generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed); + + s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); + s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); + s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back()); + } + return std::make_tuple(s_q, s_k, s_kpad); + } +#undef _S2I_ +} + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = std::atoi(v); + return r; +} + +template +std::enable_if_t> iota_shuffle(RandomAccessIterator first, + RandomAccessIterator last, + Int value, + std::optional seed = std::nullopt) +{ + std::iota(first, last, value); + + std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}()); + std::shuffle(first, last, engine); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed359..74c0f6db1de 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flexattn) add_subdirectory(35_batched_transpose) diff --git a/include/ck_tile/ops/flex_fmha.hpp b/include/ck_tile/ops/flex_fmha.hpp new file mode 100644 index 00000000000..c168be13aad --- /dev/null +++ b/include/ck_tile/ops/flex_fmha.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" +#include "ck_tile/ops/fmha/block/page_block_navigator.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" +#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp new file mode 100644 index 00000000000..39b7ef12743 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_flex_fwd_kernel.hpp @@ -0,0 +1,1401 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaFwdKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdDropoutSeedOffset + { + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; + } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + const int32_t* seqstart_k_ptr; + const int32_t* seqlen_k_ptr; + }; + + using Kargs = std::conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t seqlen_k, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqlen_q, + seqlen_k, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_randval, + batch_stride_lse, + batch_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + reinterpret_cast(seqstart_q_ptr), + reinterpret_cast(seqstart_k_ptr), + reinterpret_cast(seqlen_k_ptr)}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + + return kargs; + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + // std::variant<> can't take in a list initializer, overload for backward compatibility + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + const void* seqstart_k_ptr, + const void* seqlen_k_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + float scale_s, + float scale_p, + float scale_o, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + const std::tuple& drop_seed_offset) + { + return MakeKargsImpl( + q_ptr, + k_ptr, + v_ptr, + bias_ptr, + rand_val_ptr, + lse_ptr, + o_ptr, + seqstart_q_ptr, + seqstart_k_ptr, + seqlen_k_ptr, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_randval, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_randval, + nhead_stride_lse, + nhead_stride_o, + window_size_left, + window_size_right, + mask_type, + p_drop, + s_randval, + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_, + bool has_padded_seqlen_k = false) + { + // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr) + if(has_padded_seqlen_k) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + bool has_padded_seqlen_k = false; + + if constexpr(kIsGroupMode) + has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr); + + if(has_padded_seqlen_k) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + batch_offset_k = key_start * kargs.stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias + key_start; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; + kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + + if(kargs.seqlen_k_ptr != nullptr) + { + kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch]; + } + else + { + const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; + kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0]; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k + + batch_offset_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen_k, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen_k)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen_k), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + // may have state inside + auto score_mod_def = ScoreModFunction_{}; + + auto score_mod_arg = + [b = i_batch, h = i_nhead, score_mod_def](typename ScoreModFunction_::TScore s, + ck_tile::index_t q_idx, + ck_tile::index_t v_idx) { + auto new_score = score_mod_def( + s, b, h, q_idx, v_idx); // printf("device score_mod at (%d %d %d %d), score + // before: %f, score after: %f score_clip: %f\n", + // b, h, q_idx, v_idx, s, new_score, new_score_after_clip); + return new_score; + }; + + auto pre_softmax_def = PreSoftmaxFunction_{}; + auto pre_softmax_arg = [pre_softmax_def](typename PreSoftmaxFunction_::TScore s) { + return pre_softmax_def(s); + }; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + pre_softmax_arg, // s_acc_element_func + score_mod_arg, + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + identity{}, + k_dram_window, + identity{}, + v_dram_window, + identity{}, + bias_dram_window, + identity{}, + randval_dram_window, + lse_dram_window, + identity{}, + pre_softmax_arg, + score_mod_arg, + identity{}, + identity{}, + mask, + position_encoding, + kargs.scale_s, + smem_ptr, + dropout); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp new file mode 100644 index 00000000000..4e48fe9e34c --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs.hpp @@ -0,0 +1,627 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +template +struct BlockFmhaPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + if constexpr(kQKHeaddim <= 32) + { + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const ScoreModFunction& score_mod, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // K tile in LDS + KDataType* k_lds_ptr = static_cast(static_cast( + static_cast(smem_ptr) + Policy::template GetSmemSizeQ())); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + + auto q = load_tile(q_dram_window); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // Note: here occ are all cleard, return it + // Note: q loaded but no fence, ignore it. + return o_acc; + } + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + auto q_tile = tile_elementwise_in(q_element_func, q); + + // prefetch K tile + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(2 <= k0_loops); + static_assert(1 <= k1_loops); + do + { + // STAGE 1, QK gemm + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + + auto k_block_tile = load_tile(k_dram_window); + { + move_tile_window(k_dram_window, {0, kK0}); + clear_tile(s_acc); // initialize C + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + k_block_tile = load_tile(k_dram_window); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + __builtin_amdgcn_sched_barrier( + 0); // prevent from messing up the order of global loads + } + + if constexpr(k0_loops > 2) + { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, i_k0 * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + move_tile_window(k_dram_window, {0, kK0}); + + store_tile( + k_lds_window, + tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 + k_block_tile = load_tile(k_dram_window); // global read i + 2 + }); + } + + const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + { // tail + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 2) * kK0>{}, + sequence{}), + k_lds_window); + block_sync_lds(); + + store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); + block_sync_lds(); + + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_window); + } + + // STAGE 2, scale_s, add bias, mask, softmax + // FlexAttention score modifier operates directly on scores + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); + }); + }); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + dropout.template Run( + smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window); + } + + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_prefetch); + store_tile( + v_lds_window, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch + } + move_tile_window(v_dram_window, {0, kK1}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + const auto v = load_tile(v_dram_window); // load next v + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v); + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_window, + tile_elementwise_in(v_element_func, v)); // store next v + } + move_tile_window(v_dram_window, {0, kK1}); + }); + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + // tail + { + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); + block_sync_lds(); + } + } while(++i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp new file mode 100644 index 00000000000..d4acab79746 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_flex_qr_ks_vs_async.hpp @@ -0,0 +1,764 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const ScoreModFunction& score_mod, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + auto k_lds_load = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor(i_buf)), + Policy::template MakeKLdsLoadBlockDescriptor(i_buf).get_lengths(), + {0, 0}); + }, + number{}); +#else + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); +#endif + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window( + q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = make_tile_window( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); // K DRAM tile window for + // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), +#if K_LDS_LOAD_USE_OFFSET_TRANSFORM + k_lds_load[number{})>{}]); + +#else + get_slice_tile( + k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); +#endif + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + // FlexAttention score modifier operates directly on scores + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) = score_mod(s_acc(i_j_idx), row, col); + }); + }); + } + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } +}; + +} // namespace ck_tile