From 567d1dbd87b266a469a53d8f5104b7084a227836 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 14:04:19 +0800 Subject: [PATCH 01/11] chore: misc cleanup --- .editorconfig | 5 ++++- CMakeLists.txt | 4 ++-- setup.py | 8 ++++---- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.editorconfig b/.editorconfig index 10ac9729a..a9e8a6df4 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,7 +14,10 @@ insert_final_newline = true indent_size = 4 [*.{cpp,hpp,cxx,cc,c,h,cu,cuh}] -indent_size = 4 +indent_size = 2 + +[{*.cmake,CMakeLists.txt}] +indent_size = 2 [*.{yaml,yml}] indent_size = 2 diff --git a/CMakeLists.txt b/CMakeLists.txt index e40b7b027..80e9454fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,7 +56,7 @@ else() # Set default build type to RelWithDebInfo if not provided if(NOT CMAKE_BUILD_TYPE) - # Set default build type to Release if not provided + # Set default build type to Release if not provided set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) message(STATUS "Setting default build type to ${CMAKE_BUILD_TYPE}") endif() @@ -199,7 +199,7 @@ if(USE_CUDA) set(CUDA_MAJOR_VERSION ${CUDAToolkit_VERSION_MAJOR}) message(STATUS "Setting CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}") add_compile_definitions(CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}) - + list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) endif(USE_CUDA) diff --git a/setup.py b/setup.py index fc9a5ca59..d4c3152af 100644 --- a/setup.py +++ b/setup.py @@ -417,7 +417,7 @@ def patch_libs(libpath): subprocess.run([patchelf_path, '--set-rpath', '$ORIGIN', libpath]) -class TileLangBuilPydCommand(build_py): +class TileLangBuildPyCommand(build_py): """Customized setuptools install command - builds TVM after setting up LLVM.""" def run(self): @@ -643,7 +643,7 @@ def __init__(self, name, sourcedir=""): self.sourcedir = os.path.abspath(sourcedir) -class TilelangExtensionBuild(build_ext): +class TileLangExtensionBuild(build_ext): """ Custom build_ext command for CMake-based projects. 
@@ -929,8 +929,8 @@ def build_cmake(self, ext): CythonExtension("TileLangCython", sourcedir="."), ], cmdclass={ - "build_py": TileLangBuilPydCommand, + "build_py": TileLangBuildPyCommand, "sdist": TileLangSdistCommand, - "build_ext": TilelangExtensionBuild, + "build_ext": TileLangExtensionBuild, }, ) From cebf47b8cbcf0580b1370bec99dd2ec6823af6ba Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 14:44:34 +0800 Subject: [PATCH 02/11] feat: add pre-commit config --- .clang-format | 8 +++++++ .pre-commit-config.yaml | 48 +++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 7 +++++- 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 .clang-format create mode 100644 .pre-commit-config.yaml diff --git a/.clang-format b/.clang-format new file mode 100644 index 000000000..964712a78 --- /dev/null +++ b/.clang-format @@ -0,0 +1,8 @@ +--- +BasedOnStyle: LLVM +UseTab: Never +IndentWidth: 2 +ColumnLimit: 80 + +Language: Cpp +Standard: c++17 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..18e1dcad4 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,48 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +ci: + autofix_prs: true + autofix_commit_msg: "[Lint]: [pre-commit.ci] auto fixes [...]" + autoupdate_commit_msg: "[CI] [pre-commit.ci] autoupdate" + autoupdate_schedule: monthly +default_stages: [pre-commit, pre-push, manual] +exclude: '^(build|3rdparty)/.*$' # exclude build and 3rdparty directories +repos: + # - repo: https://github.com/pre-commit/pre-commit-hooks + # rev: v6.0.0 + # hooks: + # - id: check-symlinks + # - id: destroyed-symlinks + # - id: trailing-whitespace + # - id: end-of-file-fixer + # - id: check-added-large-files + # - id: check-merge-conflict + # fail_fast: true + # - id: check-executables-have-shebangs + # - id: check-shebang-scripts-are-executable + # - id: detect-private-key + # - id: check-yaml + # - id: check-toml + # - id: check-ast + # fail_fast: true + # - id: debug-statements + - repo: https://github.com/pre-commit/mirrors-clang-format + rev: v15.0.7 # sync with requirements-lint.txt + hooks: + - id: clang-format + exclude: \.json$ + # - repo: https://github.com/astral-sh/ruff-pre-commit + # rev: v0.6.5 # sync with requirements-lint.txt + # hooks: + # - id: ruff-check + # args: [--fix, --exit-non-zero-on-fix] + - repo: https://github.com/google/yapf + rev: v0.40.2 # sync with requirements-lint.txt + hooks: + - id: yapf + args: [--recursive, --in-place] + - repo: https://github.com/codespell-project/codespell + rev: v2.3.0 # sync with requirements-lint.txt + hooks: + - id: codespell + additional_dependencies: [".[toml]"] diff --git a/pyproject.toml b/pyproject.toml index 7193341dd..65901ead0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,11 @@ skip = [ ".venv" ] +[tool.ruff] +target-version = "py38" +line-length = 100 +output-format = "full" + [tool.ruff.lint] select = [ # pycodestyle @@ -57,4 +62,4 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] "3rdparty/**/*" = ["ALL"] -"examples/deepseek_v32/inference/**/*" = ["ALL"] \ No newline at end of file +"examples/deepseek_v32/inference/**/*" = ["ALL"] From 4ad758a41a8fb4a9a612aec4afc24383b10d3141 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 15:00:56 +0800 Subject: [PATCH 03/11] chore: update lint dependencies --- .pre-commit-config.yaml | 25 +++++++++++++++++-------- pyproject.toml | 4 ++++ requirements-lint.txt | 8 +++----- 3 files changed, 
24 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 18e1dcad4..1f305bf2d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,19 +30,28 @@ repos: rev: v15.0.7 # sync with requirements-lint.txt hooks: - id: clang-format - exclude: \.json$ - # - repo: https://github.com/astral-sh/ruff-pre-commit - # rev: v0.6.5 # sync with requirements-lint.txt - # hooks: - # - id: ruff-check - # args: [--fix, --exit-non-zero-on-fix] + exclude: | + (?ix)( + ^.+\.json$ + ) + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.0 # sync with requirements-lint.txt + hooks: + - id: ruff-check + args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/google/yapf - rev: v0.40.2 # sync with requirements-lint.txt + rev: v0.43.0 # sync with requirements-lint.txt hooks: - id: yapf args: [--recursive, --in-place] - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 # sync with requirements-lint.txt + rev: v2.4.1 # sync with requirements-lint.txt hooks: - id: codespell additional_dependencies: [".[toml]"] + exclude: | + (?x)( + ^.+\.(cpp|hpp|cxx|cc|c|h|cu|cuh)$| + ^.+\.svg$| + ^.*\brequirements\b.*\.txt$ + ) diff --git a/pyproject.toml b/pyproject.toml index 65901ead0..1d3755099 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,10 +53,14 @@ ignore = [ "E741", # line too long "E501", + # if-else-block instead of ternary + "SIM108", # key in dict.keys() "SIM118", # memory leaks "B019", + # zip without explicit strict + "B905", # No such file or directory "E902", ] diff --git a/requirements-lint.txt b/requirements-lint.txt index 46737db5d..92f61068d 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -1,8 +1,6 @@ # formatting -yapf==0.40.2 -toml==0.10.2 -tomli==2.0.1 -ruff==0.6.5 -codespell==2.3.0 +yapf==0.43.0 +ruff==0.14.0 +codespell[toml]==2.4.1 clang-format==15.0.7 clang-tidy==18.1.8 From 8b853b6618259243e8ed504c346dfc43ab19ecf6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 15:03:32 +0800 Subject: [PATCH 04/11] style: fix lint issues --- docs/deeplearning_operators/matmul.md | 4 ++-- examples/deepseek_v32/fp8_lighting_indexer.py | 2 ++ tilelang/jit/adapter/libgen.py | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/deeplearning_operators/matmul.md b/docs/deeplearning_operators/matmul.md index 490d731e0..fea036ebe 100644 --- a/docs/deeplearning_operators/matmul.md +++ b/docs/deeplearning_operators/matmul.md @@ -8,7 +8,7 @@ :class: myclass1 myclass2 :name: a-tip-reference - This document is still **experimental** and may be incomplete. + This document is still **experimental** and may be incomplete. Suggestions and improvements are highly encouraged—please submit a PR! 
::: @@ -256,4 +256,4 @@ For more advanced usage—including partial lowering, explicitly controlling thr * [BitBLAS](https://github.com/tile-ai/bitblas) * [Triton](https://github.com/openai/triton) * [Cutlass](https://github.com/NVIDIA/cutlass) -* [PyCUDA](https://documen.tician.de/pycuda/) +* [PyCUDA](https://documen.tician.de/pycuda/) diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 64df55cbb..279dd91c7 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -258,6 +258,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cost = mask.sum() return logits, cost + def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) @@ -302,5 +303,6 @@ def logits_fn(): print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") print(f"cost_ref: {cost_ref}") + if __name__ == "__main__": test_fp8_lighting_indexer() diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 89f127f0c..5d1143a67 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -64,7 +64,7 @@ def compile_lib(self, timeout: float = None): verbose = self.verbose if is_cuda_target(target): from tilelang.env import CUTLASS_INCLUDE_DIR - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 target_arch = get_target_arch(get_target_compute_version(target)) libpath = src.name.replace(".cu", ".so") @@ -111,7 +111,7 @@ def compile_lib(self, timeout: float = None): elif is_hip_target(target): from tilelang.env import COMPOSABLE_KERNEL_INCLUDE_DIR - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") rocm_path = find_rocm_path() arch = get_rocm_arch(rocm_path) @@ -128,7 +128,7 @@ def compile_lib(self, timeout: float = None): ] elif is_cpu_target(target): from tilelang.contrib.cc import get_cplus_compiler - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cpp", delete=False) # noqa: SIM115 libpath = src.name.replace(".cpp", ".so") command = [get_cplus_compiler(), "-std=c++17", "-fPIC", "-shared", src.name] @@ -228,7 +228,7 @@ def compile_lib(self, timeout: float = None): verbose = self.verbose if is_cuda_target(target): from tilelang.env import (CUDA_HOME, CUTLASS_INCLUDE_DIR, TILELANG_TEMPLATE_PATH) - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) # noqa: SIM115 libpath = src.name.replace(".cu", ".cubin") project_root = osp.join(osp.dirname(__file__), "..", "..") From f6d2cc7b8e1608edfb5f4846789043f8395173e6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 15:05:57 +0800 Subject: [PATCH 05/11] feat: add pre-commit hooks --- .pre-commit-config.yaml | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f305bf2d..32fce4601 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,24 +8,26 @@ 
ci: default_stages: [pre-commit, pre-push, manual] exclude: '^(build|3rdparty)/.*$' # exclude build and 3rdparty directories repos: - # - repo: https://github.com/pre-commit/pre-commit-hooks - # rev: v6.0.0 - # hooks: - # - id: check-symlinks - # - id: destroyed-symlinks - # - id: trailing-whitespace - # - id: end-of-file-fixer - # - id: check-added-large-files - # - id: check-merge-conflict - # fail_fast: true - # - id: check-executables-have-shebangs - # - id: check-shebang-scripts-are-executable - # - id: detect-private-key - # - id: check-yaml - # - id: check-toml - # - id: check-ast - # fail_fast: true - # - id: debug-statements + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-symlinks + - id: destroyed-symlinks + # FIXME: enable these hooks + # - id: trailing-whitespace + # - id: end-of-file-fixer + - id: check-added-large-files + - id: check-merge-conflict + fail_fast: true + # FIXME: enable these hooks + # - id: check-executables-have-shebangs + # - id: check-shebang-scripts-are-executable + - id: detect-private-key + - id: check-yaml + - id: check-toml + - id: check-ast + fail_fast: true + - id: debug-statements - repo: https://github.com/pre-commit/mirrors-clang-format rev: v15.0.7 # sync with requirements-lint.txt hooks: From bcc9c87f03704733de3799a629a9f5c76b2f1171 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Thu, 9 Oct 2025 15:19:11 +0800 Subject: [PATCH 06/11] fix: fix typos --- src/layout/gemm_layouts.cc | 2 +- src/op/parallel.cc | 4 ++-- src/target/codegen_cuda.cc | 2 +- src/target/ptx.h | 2 +- src/transform/inject_assumes.cc | 4 ++-- src/transform/loop_vectorize_dynamic.cc | 11 +++++++---- 6 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc index 659696fec..cbb7278f6 100644 --- a/src/layout/gemm_layouts.cc +++ b/src/layout/gemm_layouts.cc @@ -588,7 +588,7 @@ Layout makeGemmVoltaABLayout(int stride, int continuous, bool is_a, // ref: // https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/tensor_op_multiplicand_sm75.h#L54 -// Althought the four settings (T or NT) used distinct layouts in CUTLASS, they +// Although the four settings (T or NT) used distinct layouts in CUTLASS, they // appeared to result in the same mem layout Layout makeTensorOpMultiplicand(int mat_stride, int mat_continuous, int elementsize, int crosswise) { diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 9f1d92148..2a1135d7e 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -215,9 +215,9 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, return {}; if (level == InferLevel::kStrict) { LayoutMap results; - // Deduce buffers that shoule be complicated replicated. + // Deduce buffers that should be complicated replicated. // For example: - // for i in T.Parllel(m): + // for i in T.Parallel(m): // fragment[0] = x[i] // then fragment[0] must be replicated on all threads. for (const auto &[buffer, indices] : indice_map_) { diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 472a29ffe..87ac6f0ca 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2088,7 +2088,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op, DataType element_dtype = op->buffer->dtype; int lanes = op->dtype.lanes(); - // delcare type. + // declare type. 
if (value_dtype.lanes() == element_dtype.lanes()) { std::string ref = GetBufferRef(op->dtype, op->buffer.get(), index); HandleVolatileLoads(ref, op, os); diff --git a/src/target/ptx.h b/src/target/ptx.h index 15acb96b1..c1675a394 100644 --- a/src/target/ptx.h +++ b/src/target/ptx.h @@ -150,7 +150,7 @@ std::string PrintArriveBarrierAsm(const std::string &barrier); * \brief Print ptx barrier arrival with expect tx operation using * mbarrier.arrive.expect_tx \param barrier: The name of the barrier in shared * memory. \param byte_count: Increases the tx count of the mbarrier object to - * track completion of addtional async transactions. + * track completion of additional async transactions. */ std::string PrintArriveBarrierExpectTxAsm(const std::string &barrier, const std::string &byte_count); diff --git a/src/transform/inject_assumes.cc b/src/transform/inject_assumes.cc index a2ddfc4a0..d4c8a53c8 100644 --- a/src/transform/inject_assumes.cc +++ b/src/transform/inject_assumes.cc @@ -33,8 +33,8 @@ class AssumeInjector : public tvm::tir::StmtExprMutator { }; tvm::StructuralHash sh; tvm::StructuralEqual se; - // grouped by expr, since the amount of varidic shape symbols is usualy much - // smaller than buffer + // grouped by expr, since the amount of variadic shape symbols is usually + // much smaller than buffer std::vector items; // hash => index in items std::unordered_map> buckets; diff --git a/src/transform/loop_vectorize_dynamic.cc b/src/transform/loop_vectorize_dynamic.cc index 0756fce43..d02582726 100644 --- a/src/transform/loop_vectorize_dynamic.cc +++ b/src/transform/loop_vectorize_dynamic.cc @@ -243,9 +243,9 @@ class VectorizedBodyMutator : public StmtExprMutator { std::vector conditions_; }; -class VectorizedConditionExtracter : public StmtExprVisitor { +class VectorizedConditionExtractor : public StmtExprVisitor { public: - VectorizedConditionExtracter() = default; + VectorizedConditionExtractor() = default; std::vector GetConditions(const Stmt &body) { this->VisitStmt(body); return conditions_; @@ -268,6 +268,9 @@ class VectorizedConditionExtracter : public StmtExprVisitor { std::vector conditions_; }; +// backward-compatibility: extracter -> extractor +using VectorizedConditionExtracter = VectorizedConditionExtractor; + class NestedLoopChecker : public StmtExprVisitor { public: NestedLoopChecker() : loop_num_(0) {} @@ -391,8 +394,8 @@ class VectorizeRewriterDynamic : public StmtExprMutator { vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var); Stmt body = Substitute(fnode->body, vmap); - VectorizedConditionExtracter extracter; - std::vector conditions = extracter.GetConditions(body); + VectorizedConditionExtractor extractor; + std::vector conditions = extractor.GetConditions(body); VectorizedConditionMutator condition_mutator(inner_var, vector_size_); From a6d59fc9b77aa0ca2951a3375a158b56d844c1f5 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Oct 2025 11:43:02 +0800 Subject: [PATCH 07/11] chore: update .gitattributes --- .gitattributes | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitattributes b/.gitattributes index 2f6d49472..bbb14db37 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,10 @@ +* text eol=lf +*.bat eol=crlf + +*.svg binary +*.jpg binary +*.jpeg binary +*.png binary +*.gif binary + *.h linguist-language=C++ From 9011f4a32616015172d3f3359a16b901f476d67d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Oct 2025 03:44:05 +0000 Subject: [PATCH 08/11] 
[Lint]: [pre-commit.ci] auto fixes [...] --- .../minference/ops/vertical_slash_index.cu | 243 +++++------ maint/precision/cuda_ops.cu | 400 ++++++++++-------- src/tl_templates/cuda/compress_sm90.cu | 148 ++++--- src/tl_templates/cuda/cuda_bf16_fallbacks.cuh | 282 ++++++------ 4 files changed, 581 insertions(+), 492 deletions(-) diff --git a/examples/minference/ops/vertical_slash_index.cu b/examples/minference/ops/vertical_slash_index.cu index ae01f331b..8f49abe3a 100644 --- a/examples/minference/ops/vertical_slash_index.cu +++ b/examples/minference/ops/vertical_slash_index.cu @@ -3,157 +3,142 @@ #include +#include #include #include -#include #include #include -__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { - for (int idx = range_start; idx < range_end; idx += block_size) { - block_offset[block_count++] = idx; - } +__device__ void save_blocks(int *block_offset, int range_start, int range_end, + int block_size, int &block_count) { + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[block_count++] = idx; + } } __global__ void convert_vertical_slash_indexes_kernel( - const int* seqlens, // [BATCH, ] - const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] - int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] - int N_HEADS, - int N_ROWS, - int BLOCK_SIZE_M, - int BLOCK_SIZE_N, - int NNZ_V, - int NNZ_S -) { - const int batch_idx = blockIdx.y; - const int head_idx = blockIdx.x; - const int group_idx = blockIdx.z; + const int *seqlens, // [BATCH, ] + const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, int N_ROWS, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int NNZ_V, + int NNZ_S) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; - int seqlen = seqlens[batch_idx]; - int block_idx_m = group_idx * blockDim.x + threadIdx.x; - int start_m = block_idx_m * BLOCK_SIZE_M; - if (start_m >= seqlen) { - return; - } - int end_m = start_m + BLOCK_SIZE_M; - vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; - slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; - int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; - block_count += row_offset; - block_offset += row_offset * NNZ_S; - column_count += row_offset; - column_index += row_offset * NNZ_V; + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; - int 
tmp_col_cnt = 0, tmp_blk_cnt = 0; - int s = 0, v = 0; - int v_idx = vertical_indexes[v++]; - int s_idx = slash_indexes[s++]; - while (s_idx >= end_m) { - s_idx = slash_indexes[s++]; - } - s_idx = max(end_m - s_idx, BLOCK_SIZE_M); - int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; - while (1) { - if (v_idx < range_end) { - if (v_idx < range_start) { - column_index[tmp_col_cnt++] = v_idx; - } - if (v < NNZ_V) { - v_idx = vertical_indexes[v++]; - } else { - v_idx = end_m + BLOCK_SIZE_M; - } - } else { - if (s < NNZ_S) { - s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); - } else { - save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); - break; - } - if (s_idx > range_end + BLOCK_SIZE_M) { - save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); - range_start = s_idx - BLOCK_SIZE_M; - range_end = s_idx; - } else if (s_idx > range_end) { - range_end += BLOCK_SIZE_M; - } - } + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, + tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } } + } - block_count[0] = tmp_blk_cnt; - column_count[0] = tmp_col_cnt; + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; } void convert_vertical_slash_indexes_64x64( - const int* seqlens, // [BATCH, ] - const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] - int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] - int BATCH_SIZE, - int N_HEADS, - int N_ROWS, - int NNZ_V, - int NNZ_S -) { - const int BLOCK_SIZE_M = 64; - const int BLOCK_SIZE_N = 64; - const int N_THREADS = 64; - const dim3 dimBlock(N_THREADS); - const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); - convert_vertical_slash_indexes_kernel<<>>( - seqlens, vertical_indexes, slash_indexes, - block_count, block_offset, column_count, column_index, - N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S - ); + const int *seqlens, // [BATCH, ] + const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, int N_HEADS, int N_ROWS, int NNZ_V, int NNZ_S) { + const int BLOCK_SIZE_M = 64; + const int 
BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + seqlens, vertical_indexes, slash_indexes, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S); } std::vector convert_vertical_slash_indexes( - torch::Tensor seqlens, // [BATCH, ] - torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int context_size, - int block_size_M, - int block_size_N -) { - assert(block_size_M == 64); - assert(block_size_N == 64); + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, int block_size_M, int block_size_N) { + assert(block_size_M == 64); + assert(block_size_N == 64); - cudaSetDevice(seqlens.get_device()); + cudaSetDevice(seqlens.get_device()); - int batch_size = slash_indexes.size(0); - int num_heads = slash_indexes.size(1); - int nnz_slash = slash_indexes.size(2); - int nnz_vertical = vertical_indexes.size(2); - int num_rows = (context_size + block_size_M - 1) / block_size_M; + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; - torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); - torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); - torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); - torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + torch::Tensor block_count = + torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros( + {batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = + torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros( + {batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); - convert_vertical_slash_indexes_64x64( - seqlens.data_ptr(), - vertical_indexes.data_ptr(), - slash_indexes.data_ptr(), - block_count.data_ptr(), - block_offset.data_ptr(), - column_count.data_ptr(), - column_index.data_ptr(), - batch_size, - num_heads, - num_rows, - nnz_vertical, - nnz_slash - ); + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + nnz_vertical, nnz_slash); - return { block_count, block_offset, column_count, column_index }; + return {block_count, block_offset, column_count, column_index}; } diff --git a/maint/precision/cuda_ops.cu b/maint/precision/cuda_ops.cu index 519335751..6bd1856af 100644 --- a/maint/precision/cuda_ops.cu +++ b/maint/precision/cuda_ops.cu @@ -1,242 +1,294 @@ -#include #include #include #include +#include enum OperatorType { - OP_DIV, - OP_RECIPROCAL, - OP_EXP, - OP_LOG, - OP_SIN, - OP_COS, - OP_SQRT, - OP_TANH, - OP_RSQRT, - OP_INV_SQRT + OP_DIV, + OP_RECIPROCAL, + OP_EXP, + OP_LOG, + OP_SIN, + OP_COS, + OP_SQRT, + 
OP_TANH, + OP_RSQRT, + OP_INV_SQRT }; // ================= 精确版本 device 运算符 ================= -__device__ __forceinline__ float precise_div(float a, float b) { - return a / b; -} +__device__ __forceinline__ float precise_div(float a, float b) { return a / b; } __device__ __forceinline__ float precise_reciprocal(float x) { - return 1.0f / x; -} -__device__ __forceinline__ float precise_exp(float x) { - return expf(x); -} -__device__ __forceinline__ float precise_log(float x) { - return logf(x); -} -__device__ __forceinline__ float precise_sin(float x) { - return sinf(x); -} -__device__ __forceinline__ float precise_cos(float x) { - return cosf(x); -} -__device__ __forceinline__ float precise_sqrt(float x) { - return sqrtf(x); -} -__device__ __forceinline__ float precise_tanh(float x) { - return tanhf(x); -} -__device__ __forceinline__ float precise_rsqrt(float x) { - return rsqrtf(x); -} + return 1.0f / x; +} +__device__ __forceinline__ float precise_exp(float x) { return expf(x); } +__device__ __forceinline__ float precise_log(float x) { return logf(x); } +__device__ __forceinline__ float precise_sin(float x) { return sinf(x); } +__device__ __forceinline__ float precise_cos(float x) { return cosf(x); } +__device__ __forceinline__ float precise_sqrt(float x) { return sqrtf(x); } +__device__ __forceinline__ float precise_tanh(float x) { return tanhf(x); } +__device__ __forceinline__ float precise_rsqrt(float x) { return rsqrtf(x); } __device__ __forceinline__ float precise_inv_sqrt(float x) { - return 1.0f / sqrtf(x); + return 1.0f / sqrtf(x); } // ================= double 精确版本 device 运算符 ================= __device__ __forceinline__ double double_precise_div(double a, double b) { - return a / b; + return a / b; } __device__ __forceinline__ double double_precise_reciprocal(double x) { - return 1.0 / x; + return 1.0 / x; } __device__ __forceinline__ double double_precise_exp(double x) { - return exp(x); + return exp(x); } __device__ __forceinline__ double double_precise_log(double x) { - return log(x); + return log(x); } __device__ __forceinline__ double double_precise_sin(double x) { - return sin(x); + return sin(x); } __device__ __forceinline__ double double_precise_cos(double x) { - return cos(x); + return cos(x); } __device__ __forceinline__ double double_precise_sqrt(double x) { - return sqrt(x); + return sqrt(x); } __device__ __forceinline__ double double_precise_tanh(double x) { - return tanh(x); + return tanh(x); } __device__ __forceinline__ double double_precise_rsqrt(double x) { - return 1.0 / sqrt(x); + return 1.0 / sqrt(x); } __device__ __forceinline__ double double_precise_inv_sqrt(double x) { - return 1.0 / sqrt(x); + return 1.0 / sqrt(x); } // ================= 快速近似版本 device 运算符 ================= __device__ __forceinline__ float fast_div(float a, float b) { - return __fdividef(a, b); + return __fdividef(a, b); } __device__ __forceinline__ float fast_reciprocal(float x) { - float ret; - asm volatile("rcp.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; -} -__device__ __forceinline__ float fast_exp(float x) { - return __expf(x); -} -__device__ __forceinline__ float fast_log(float x) { - return __logf(x); -} -__device__ __forceinline__ float fast_sin(float x) { - return __sinf(x); -} -__device__ __forceinline__ float fast_cos(float x) { - return __cosf(x); -} + float ret; + asm volatile("rcp.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} +__device__ __forceinline__ float fast_exp(float x) { return __expf(x); } +__device__ __forceinline__ float fast_log(float x) 
{ return __logf(x); } +__device__ __forceinline__ float fast_sin(float x) { return __sinf(x); } +__device__ __forceinline__ float fast_cos(float x) { return __cosf(x); } __device__ __forceinline__ float fast_sqrt(float x) { - float ret; - asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; -} -__device__ __forceinline__ float fast_tanh(float x) { - return __tanhf(x); + float ret; + asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; } +__device__ __forceinline__ float fast_tanh(float x) { return __tanhf(x); } __device__ __forceinline__ float fast_rsqrt(float x) { - // return rsqrtf(x); - float ret; - asm volatile("rsqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; + // return rsqrtf(x); + float ret; + asm volatile("rsqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; } __device__ __forceinline__ float fast_inv_sqrt(float x) { - float ret; - asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return 1.0f / ret; + float ret; + asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return 1.0f / ret; } // ================= 精确版本 kernel ================= -__global__ void precise_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - float a = x[i]; - float b = (y != nullptr) ? y[i] : 0.0f; - float r = 0.0f; - switch (op_type) { - case OP_DIV: r = precise_div(a, b); break; - case OP_RECIPROCAL: r = precise_reciprocal(a); break; - case OP_EXP: r = precise_exp(a); break; - case OP_LOG: r = precise_log(a); break; - case OP_SIN: r = precise_sin(a); break; - case OP_COS: r = precise_cos(a); break; - case OP_SQRT: r = precise_sqrt(a); break; - case OP_TANH: r = precise_tanh(a); break; - case OP_RSQRT: r = precise_rsqrt(a); break; - case OP_INV_SQRT: r = precise_inv_sqrt(a); break; - } - result[i] = r; +__global__ void precise_operator_kernel(const float *x, const float *y, + float *result, int64_t n, + OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + float a = x[i]; + float b = (y != nullptr) ? y[i] : 0.0f; + float r = 0.0f; + switch (op_type) { + case OP_DIV: + r = precise_div(a, b); + break; + case OP_RECIPROCAL: + r = precise_reciprocal(a); + break; + case OP_EXP: + r = precise_exp(a); + break; + case OP_LOG: + r = precise_log(a); + break; + case OP_SIN: + r = precise_sin(a); + break; + case OP_COS: + r = precise_cos(a); + break; + case OP_SQRT: + r = precise_sqrt(a); + break; + case OP_TANH: + r = precise_tanh(a); + break; + case OP_RSQRT: + r = precise_rsqrt(a); + break; + case OP_INV_SQRT: + r = precise_inv_sqrt(a); + break; } + result[i] = r; + } } // ================= double 精确版本 kernel ================= -__global__ void double_precise_operator_kernel(const double* x, const double* y, double* result, int64_t n, OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - double a = x[i]; - double b = (y != nullptr) ? 
y[i] : 0.0; - double r = 0.0; - switch (op_type) { - case OP_DIV: r = double_precise_div(a, b); break; - case OP_RECIPROCAL: r = double_precise_reciprocal(a); break; - case OP_EXP: r = double_precise_exp(a); break; - case OP_LOG: r = double_precise_log(a); break; - case OP_SIN: r = double_precise_sin(a); break; - case OP_COS: r = double_precise_cos(a); break; - case OP_SQRT: r = double_precise_sqrt(a); break; - case OP_TANH: r = double_precise_tanh(a); break; - case OP_RSQRT: r = double_precise_rsqrt(a); break; - case OP_INV_SQRT: r = double_precise_inv_sqrt(a); break; - } - result[i] = r; +__global__ void double_precise_operator_kernel(const double *x, const double *y, + double *result, int64_t n, + OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + double a = x[i]; + double b = (y != nullptr) ? y[i] : 0.0; + double r = 0.0; + switch (op_type) { + case OP_DIV: + r = double_precise_div(a, b); + break; + case OP_RECIPROCAL: + r = double_precise_reciprocal(a); + break; + case OP_EXP: + r = double_precise_exp(a); + break; + case OP_LOG: + r = double_precise_log(a); + break; + case OP_SIN: + r = double_precise_sin(a); + break; + case OP_COS: + r = double_precise_cos(a); + break; + case OP_SQRT: + r = double_precise_sqrt(a); + break; + case OP_TANH: + r = double_precise_tanh(a); + break; + case OP_RSQRT: + r = double_precise_rsqrt(a); + break; + case OP_INV_SQRT: + r = double_precise_inv_sqrt(a); + break; } + result[i] = r; + } } // ================= 快速版本 kernel ================= -__global__ void fast_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - float a = x[i]; - float b = (y != nullptr) ? y[i] : 0.0f; - float r = 0.0f; - switch (op_type) { - case OP_DIV: r = fast_div(a, b); break; - case OP_RECIPROCAL: r = fast_reciprocal(a); break; - case OP_EXP: r = fast_exp(a); break; - case OP_LOG: r = fast_log(a); break; - case OP_SIN: r = fast_sin(a); break; - case OP_COS: r = fast_cos(a); break; - case OP_SQRT: r = fast_sqrt(a); break; - case OP_TANH: r = fast_tanh(a); break; - case OP_RSQRT: r = fast_rsqrt(a); break; - case OP_INV_SQRT: r = fast_inv_sqrt(a); break; - } - result[i] = r; +__global__ void fast_operator_kernel(const float *x, const float *y, + float *result, int64_t n, + OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + float a = x[i]; + float b = (y != nullptr) ? 
y[i] : 0.0f; + float r = 0.0f; + switch (op_type) { + case OP_DIV: + r = fast_div(a, b); + break; + case OP_RECIPROCAL: + r = fast_reciprocal(a); + break; + case OP_EXP: + r = fast_exp(a); + break; + case OP_LOG: + r = fast_log(a); + break; + case OP_SIN: + r = fast_sin(a); + break; + case OP_COS: + r = fast_cos(a); + break; + case OP_SQRT: + r = fast_sqrt(a); + break; + case OP_TANH: + r = fast_tanh(a); + break; + case OP_RSQRT: + r = fast_rsqrt(a); + break; + case OP_INV_SQRT: + r = fast_inv_sqrt(a); + break; } + result[i] = r; + } } // 精确版本 -void launch_precise_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const float* y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - precise_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) - ); +void launch_precise_operator(const at::Tensor &x, + const c10::optional &y, + at::Tensor &result, int op_type) { + int64_t n = x.numel(); + int threads = 256; + int blocks = (n + threads - 1) / threads; + const float *y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + precise_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, + static_cast(op_type)); } // double 精确版本 -void launch_double_precise_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const double* y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - double_precise_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) - ); +void launch_double_precise_operator(const at::Tensor &x, + const c10::optional &y, + at::Tensor &result, int op_type) { + int64_t n = x.numel(); + int threads = 256; + int blocks = (n + threads - 1) / threads; + const double *y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + double_precise_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, + static_cast(op_type)); } // 快速版本 -void launch_fast_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const float* y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - fast_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) - ); +void launch_fast_operator(const at::Tensor &x, + const c10::optional &y, + at::Tensor &result, int op_type) { + int64_t n = x.numel(); + int threads = 256; + int blocks = (n + threads - 1) / threads; + const float *y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + fast_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, + static_cast(op_type)); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_precise_operator", &launch_precise_operator, "CUDA Precise Operator", - py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); - m.def("launch_double_precise_operator", &launch_double_precise_operator, "CUDA Double Precise Operator", - py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); - m.def("launch_fast_operator", &launch_fast_operator, "CUDA Fast Operator", - py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); + 
m.def("launch_precise_operator", &launch_precise_operator, + "CUDA Precise Operator", py::arg("x"), py::arg("y") = c10::nullopt, + py::arg("result"), py::arg("op_type")); + m.def("launch_double_precise_operator", &launch_double_precise_operator, + "CUDA Double Precise Operator", py::arg("x"), + py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); + m.def("launch_fast_operator", &launch_fast_operator, "CUDA Fast Operator", + py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), + py::arg("op_type")); } \ No newline at end of file diff --git a/src/tl_templates/cuda/compress_sm90.cu b/src/tl_templates/cuda/compress_sm90.cu index 8bb236dd8..3ec1d7aac 100644 --- a/src/tl_templates/cuda/compress_sm90.cu +++ b/src/tl_templates/cuda/compress_sm90.cu @@ -13,60 +13,68 @@ using namespace cute; -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ - << " at: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ } -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ - << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ } -template +template std::tuple compress_impl(torch::Tensor A) { using ElementA = T; using ElementE = uint8_t; - using LayoutTagA = conditional_t; + using LayoutTagA = conditional_t; using ProblemShape = cute::Shape; using StrideA = cutlass::gemm::TagToStrideA_t; using StrideE = StrideA; // NOTE: this is derived from sparse sm90 mma atoms - // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp - using SparseE = conditional_t<(sizeof_bits_v == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; - static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; - using SparseConfig = cutlass::Sm90GemmSparseConfig< - cute::sparse_elem<2, ElementA>, GmmaMajorA, - SparseE, cute::C>; + // Ref: + // https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp + using SparseE = conditional_t<(sizeof_bits_v == 32), + cute::sparse_elem<4, ElementE>, + cute::sparse_elem<8, ElementE>>; + static constexpr GMMA::Major GmmaMajorA = + transposed ? 
cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; + using SparseConfig = + cutlass::Sm90GemmSparseConfig, GmmaMajorA, + SparseE, cute::C>; using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< ProblemShape, ElementA, LayoutTagA, SparseConfig>; - using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< - ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; + using CompressorKernel = + cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, + cutlass::arch::Sm90>; - using Compressor = cutlass::transform::device::TransformUniversalAdapter; + using Compressor = + cutlass::transform::device::TransformUniversalAdapter; TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); int M = -1; int K = -1; - int N = -1; // not used, but required for config + int N = -1; // not used, but required for config int L = 1; - if constexpr(transposed) { + if constexpr (transposed) { M = A.size(1); K = A.size(0); } else { @@ -75,24 +83,27 @@ std::tuple compress_impl(torch::Tensor A) { } ProblemShape problem_shape = make_tuple(M, N, K, L); - StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); CompressorUtility compressor_utility(problem_shape, stride_A); int ME = compressor_utility.get_metadata_m_physical(); int KE = compressor_utility.get_metadata_k_physical(); int KC = compressor_utility.get_tensorA_k_physical(); - StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + StrideE stride_E = + cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); auto dtype = A.dtype().toScalarType(); - torch::Tensor A_compressed = torch::zeros(KC * M, - torch::TensorOptions().dtype(dtype).device(A.device())); - torch::Tensor E = torch::zeros({ME, KE}, - torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); + torch::Tensor A_compressed = torch::zeros( + KC * M, torch::TensorOptions().dtype(dtype).device(A.device())); + torch::Tensor E = torch::zeros( + {ME, KE}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); cutlass::KernelHardwareInfo hw_info; hw_info.device_id = A.device().index(); hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); typename Compressor::Arguments arguments{problem_shape, { @@ -120,40 +131,45 @@ std::tuple compress_impl(torch::Tensor A) { } // block <= 128 -// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 -#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (BLOCK_K) { \ - case int(32 * FACTOR): return compress_impl(TENSOR); \ - case int(64 * FACTOR): return compress_impl(TENSOR); \ - case int(128 * FACTOR): return compress_impl(TENSOR); \ - default: \ - TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ - } \ +// Ref +// https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 +#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ + [&]() -> std::tuple { \ + 
switch (BLOCK_K) { \ + case int(32 * FACTOR): \ + return compress_impl(TENSOR); \ + case int(64 * FACTOR): \ + return compress_impl(TENSOR); \ + case int(128 * FACTOR): \ + return compress_impl(TENSOR); \ + default: \ + TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ + } \ }() -#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (dtype) { \ - case torch::kFloat32: \ - return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ - case torch::kFloat16: \ - case torch::kBFloat16: \ - return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ - case torch::kFloat8_e4m3fn: \ - return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ - case torch::kFloat8_e5m2: \ - return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ - case torch::kChar: \ - return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ - case torch::kByte: \ - return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ - default: \ - TORCH_CHECK(false, "Unsupported dtype"); \ - } \ +#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (dtype) { \ + case torch::kFloat32: \ + return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ + case torch::kFloat16: \ + case torch::kBFloat16: \ + return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ + case torch::kFloat8_e4m3fn: \ + return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ + case torch::kFloat8_e5m2: \ + return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ + case torch::kChar: \ + return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ + case torch::kByte: \ + return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ + default: \ + TORCH_CHECK(false, "Unsupported dtype"); \ + } \ }() -std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { +std::tuple +compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { auto dtype = A.dtype().toScalarType(); return transposed ? 
DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); } diff --git a/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh b/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh index f5641f616..e6dd9625a 100644 --- a/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh +++ b/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh @@ -26,232 +26,268 @@ namespace fastertransformer { #ifdef ENABLE_BF16 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; #else - return __bfloat1622float2(val); + return __bfloat1622float2(val); #endif } inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { + int8_t int8[2]; + int16_t int16; + }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; #else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { int8_t int8[2]; int16_t int16; }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { + int8_t int8[2]; + int16_t int16; + }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; #endif } inline __device__ __nv_bfloat162 float22bf162(const float2 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); + return __floats2bfloat162_rn(val.x, val.y); #else - return __float22bfloat162_rn(val); + return __float22bfloat162_rn(val); #endif } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; #else - return __bfloat162bfloat162(val); + return __bfloat162bfloat162(val); #endif } -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, + const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); #else - return __hadd2(x, y); + return __hadd2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, + const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) + 
__bfloat162float(y) ); + return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); #else - return __hadd(x, y); + return __hadd(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, + const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); #else - return __hsub2(x, y); + return __hsub2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, + const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); + return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); #else - return __hsub(x, y); + return __hsub(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, + const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); #else - return __hmul2(x, y); + return __hmul2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, + const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); #else - return __hmul(x, y); + return __hmul(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, + const __nv_bfloat162 y, + const __nv_bfloat162 z) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); #else - return __hfma2(x, y, z); + return __hfma2(x, y, z); #endif } -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, + const __nv_bfloat16 y, + const __nv_bfloat16 z) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16( __bfloat162float(x) * 
__bfloat162float(y) + __bfloat162float(z)); + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + + __bfloat162float(z)); #else - return __hfma(x, y, z); + return __hfma(x, y, z); #endif } inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x);; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + ; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); #else - return h2exp(x); + return h2exp(x); #endif } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, + const __nv_bfloat162 y) { + return bf16hmul2(x, y); +}; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, + const __nv_bfloat162 y) { + return bf16hadd2(x, y); +}; -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; t.x = x; t.y = y; return t; +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, + const __nv_bfloat16 y) { + __nv_bfloat162 t; + t.x = x; + t.y = y; + return t; } #endif -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, + __nv_bfloat16 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + + __bfloat162float(c)); #else - return a + b + c; + return a + b + c; #endif } -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, + __nv_bfloat16 c, __nv_bfloat16 d) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + + __bfloat162float(c) + __bfloat162float(d)); #else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); #endif } -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); #else - return a + b + c; + return a + b + c; #endif } -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +inline __device__ 
__nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, + __nv_bfloat16 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * + __bfloat162float(c)); #else - return a * b * c; + return a * b * c; #endif } -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); #else - return a * b * c; + return a * b * c; #endif } -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c, __nv_bfloat162 d) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); #else - return a * b * c + d; + return a * b * c + d; #endif } #endif // ENABLE_BF16 -} // namespace fastertransformer +} // namespace fastertransformer From 3ca61ac26d95c7b6cf57b400483ec6614b70b533 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Oct 2025 12:08:30 +0800 Subject: [PATCH 09/11] docs: update CONTRIBUTING.md --- CONTRIBUTING.md | 94 ++++++++++++++++++++++++++++++++++--------- requirements-lint.txt | 1 + 2 files changed, 76 insertions(+), 19 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 480f68d6e..de2dfb730 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,14 +2,19 @@ That would be awesome if you want to contribute something to TileLang! 
-- [Contributing](CONTRIBUTING.md#contributing) - - [Reporting Bugs](CONTRIBUTING.md#reporting-bugs) - - [Asking Questions](CONTRIBUTING.md#asking-questions) - - [Submitting Pull Requests](CONTRIBUTING.md#submitting-pull-requests) - - [Repository Setup](CONTRIBUTING.md#repository-setup) - - [Running Tests](CONTRIBUTING.md#running-tests) +### Table of Contents -## Reporting Bugs +- [Report Bugs](#report-bugs) +- [Ask Questions](#ask-questions) +- [Submit Pull Requests](#submit-pull-requests) +- [Setup Development Environment](#setup-development-environment) +- [Install Develop Version](#install-develop-version) +- [Lint Check](#lint-check) +- [Test Locally](#test-locally) +- [Build Wheels](#build-wheels) +- [Documentation](#documentation) + +## Report Bugs If you run into any weird behavior while using TileLang, feel free to open a new issue in this repository! Please run a **search before opening** a new issue, to make sure that someone else hasn't already reported or solved the bug you've found. @@ -18,35 +23,86 @@ Any issue you open must include: - Code snippet that reproduces the bug with a minimal setup. - A clear explanation of what the issue is. - -## Asking Questions +## Ask Questions Please ask questions in issues. -## Submitting Pull Requests +## Submit Pull Requests All pull requests are super welcomed and greatly appreciated! Issues in need of a solution are marked with a [`♥ help`](https://github.com/ianstormtaylor/TileLang/issues?q=is%3Aissue+is%3Aopen+label%3A%22%E2%99%A5+help%22) label if you're looking for somewhere to start. -Please run `./format.sh` before submitting a pull request to make sure that your code is formatted correctly. +If you're new to contributing to TileLang, you can follow the guidelines below before submitting a pull request. + +> [!NOTE] +> Please include tests and docs with every pull request if applicable! + +## Setup Development Environment + +Before contributing to TileLang, please follow the instructions below to setup. + +1. Fork TileLang ([fork](https://github.com/tile-ai/tilelang/fork)) on GitHub and clone the repository. + + ```bash + git clone --recurse-submodules git@github.com:<your-username>/tilelang.git # use the SSH protocol + cd tilelang + + git remote add upstream git@github.com:tile-ai/tilelang.git + ``` + +2. Setup a development environment: + + ```bash + uv venv --seed venv # use `python3 -m venv venv` if you don't have `uv` + + source venv/bin/activate + python3 -m pip install --upgrade pip setuptools wheel "build[uv]" + uv pip install --requirements requirements-dev.txt + ``` + +3. Setup the [`pre-commit`](https://pre-commit.com) hooks: + + ```bash + pre-commit install --install-hooks + ``` -Please include tests and docs with every pull request! +Then you are ready to rock. Thanks for contributing to TileLang! -## Repository Setup +## Install Develop Version -To run the build, you need to have the TileLang repository cloned to your computer. After that, you need to `cd` into the directory where you cloned it, and install the dependencies with `python`: +To install TileLang in "editable" mode, run: ```bash -python setup.py install +python3 -m pip install --no-build-isolation --verbose --editable . ``` +in the main directory.
This installation is removable by: -## Running Tests +```bash +python3 -m pip uninstall tilelang +``` + +## Lint Check + +To check the linting, run: + +```bash +pre-commit run --all-files +``` + +## Test Locally -To run the tests, start by building the project as described in the [Repository Setup](CONTRIBUTING.md#repository-setup) section. +To run the tests, start by building the project as described in the [Setup Development Environment](#setup-development-environment) section. Then you can rerun the tests with: -```text -python -m pytest testing +```bash +python3 -m pytest testing ``` +## Build Wheels + +_TBA_ + +## Documentation + +_TBA_ diff --git a/requirements-lint.txt b/requirements-lint.txt index 92f61068d..8025d3ce2 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -1,4 +1,5 @@ # formatting +pre-commit yapf==0.43.0 ruff==0.14.0 codespell[toml]==2.4.1 From a7a62bd00ebe014bb7000c9896e267f08d9b6646 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Oct 2025 12:40:47 +0800 Subject: [PATCH 10/11] chore: update default venv name --- .gitignore | 7 +++++++ CONTRIBUTING.md | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 5bcb6f773..eb96b1622 100644 --- a/.gitignore +++ b/.gitignore @@ -26,7 +26,14 @@ nnfusion.tar.gz # makeenv and test intermediate files tmp/ +.env +.envrc +.venv +env/ venv/ +ENV/ +env.bak/ +venv.bak/ .vscode/ .vs/ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de2dfb730..e4b45e24b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -52,9 +52,9 @@ Before contributing to TileLang, please follow the instructions below to setup. 2. Setup a development environment: ```bash - uv venv --seed venv # use `python3 -m venv venv` if you don't have `uv` + uv venv --seed .venv # use `python3 -m venv .venv` if you don't have `uv` - source venv/bin/activate + source .venv/bin/activate python3 -m pip install --upgrade pip setuptools wheel "build[uv]" uv pip install --requirements requirements-dev.txt ``` From 04393028b466cf9eb23a1d1c03435183b316a250 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 10 Oct 2025 13:57:49 +0800 Subject: [PATCH 11/11] chore: revert and exclude CUDA files --- .pre-commit-config.yaml | 1 + .../minference/ops/vertical_slash_index.cu | 243 ++++++----- maint/precision/cuda_ops.cu | 400 ++++++++---------- src/tl_templates/cuda/compress_sm90.cu | 148 +++---- src/tl_templates/cuda/cuda_bf16_fallbacks.cuh | 282 ++++++------ 5 files changed, 493 insertions(+), 581 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 32fce4601..2846e58ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,7 @@ repos: - id: clang-format exclude: | (?ix)( + ^.+\.(cu|cuh)$| ^.+\.json$ ) - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/examples/minference/ops/vertical_slash_index.cu b/examples/minference/ops/vertical_slash_index.cu index 8f49abe3a..ae01f331b 100644 --- a/examples/minference/ops/vertical_slash_index.cu +++ b/examples/minference/ops/vertical_slash_index.cu @@ -3,142 +3,157 @@ #include -#include #include #include +#include #include #include -__device__ void save_blocks(int *block_offset, int range_start, int range_end, - int block_size, int &block_count) { - for (int idx = range_start; idx < range_end; idx += block_size) { - block_offset[block_count++] = idx; - } +__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { + for (int idx = range_start; idx < 
range_end; idx += block_size) { + block_offset[block_count++] = idx; + } } __global__ void convert_vertical_slash_indexes_kernel( - const int *seqlens, // [BATCH, ] - const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] - int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] - int N_HEADS, int N_ROWS, int BLOCK_SIZE_M, int BLOCK_SIZE_N, int NNZ_V, - int NNZ_S) { - const int batch_idx = blockIdx.y; - const int head_idx = blockIdx.x; - const int group_idx = blockIdx.z; + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int N_HEADS, + int N_ROWS, + int BLOCK_SIZE_M, + int BLOCK_SIZE_N, + int NNZ_V, + int NNZ_S +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; - int seqlen = seqlens[batch_idx]; - int block_idx_m = group_idx * blockDim.x + threadIdx.x; - int start_m = block_idx_m * BLOCK_SIZE_M; - if (start_m >= seqlen) { - return; - } - int end_m = start_m + BLOCK_SIZE_M; - vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; - slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; - int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; - block_count += row_offset; - block_offset += row_offset * NNZ_S; - column_count += row_offset; - column_index += row_offset * NNZ_V; + int seqlen = seqlens[batch_idx]; + int block_idx_m = group_idx * blockDim.x + threadIdx.x; + int start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= seqlen) { + return; + } + int end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; - int tmp_col_cnt = 0, tmp_blk_cnt = 0; - int s = 0, v = 0; - int v_idx = vertical_indexes[v++]; - int s_idx = slash_indexes[s++]; - while (s_idx >= end_m) { - s_idx = slash_indexes[s++]; - } - s_idx = max(end_m - s_idx, BLOCK_SIZE_M); - int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; - while (1) { - if (v_idx < range_end) { - if (v_idx < range_start) { - column_index[tmp_col_cnt++] = v_idx; - } - if (v < NNZ_V) { - v_idx = vertical_indexes[v++]; - } else { - v_idx = end_m + BLOCK_SIZE_M; - } - } else { - if (s < NNZ_S) { - s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); - } else { - save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, - tmp_blk_cnt); - break; - } - if (s_idx > range_end + BLOCK_SIZE_M) { - save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, - tmp_blk_cnt); - range_start = s_idx - BLOCK_SIZE_M; - range_end = s_idx; - } else if (s_idx > range_end) { - range_end += BLOCK_SIZE_M; - } + int tmp_col_cnt = 0, tmp_blk_cnt = 0; + int s = 0, v = 0; + int v_idx = vertical_indexes[v++]; + 
int s_idx = slash_indexes[s++]; + while (s_idx >= end_m) { + s_idx = slash_indexes[s++]; + } + s_idx = max(end_m - s_idx, BLOCK_SIZE_M); + int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + v_idx = end_m + BLOCK_SIZE_M; + } + } else { + if (s < NNZ_S) { + s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + break; + } + if (s_idx > range_end + BLOCK_SIZE_M) { + save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } } - } - block_count[0] = tmp_blk_cnt; - column_count[0] = tmp_col_cnt; + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; } void convert_vertical_slash_indexes_64x64( - const int *seqlens, // [BATCH, ] - const int *vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - const int *slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int *block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int *block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] - int *column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] - int *column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] - int BATCH_SIZE, int N_HEADS, int N_ROWS, int NNZ_V, int NNZ_S) { - const int BLOCK_SIZE_M = 64; - const int BLOCK_SIZE_N = 64; - const int N_THREADS = 64; - const dim3 dimBlock(N_THREADS); - const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); - convert_vertical_slash_indexes_kernel<<>>( - seqlens, vertical_indexes, slash_indexes, block_count, block_offset, - column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, - NNZ_V, NNZ_S); + const int* seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int BATCH_SIZE, + int N_HEADS, + int N_ROWS, + int NNZ_V, + int NNZ_S +) { + const int BLOCK_SIZE_M = 64; + const int BLOCK_SIZE_N = 64; + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + seqlens, vertical_indexes, slash_indexes, + block_count, block_offset, column_count, column_index, + N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S + ); } std::vector convert_vertical_slash_indexes( - torch::Tensor seqlens, // [BATCH, ] - torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] - torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] - int context_size, int block_size_M, int block_size_N) { - assert(block_size_M == 64); - assert(block_size_N == 64); + torch::Tensor seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int context_size, + int block_size_M, + int block_size_N +) { + assert(block_size_M == 64); + assert(block_size_N == 64); - cudaSetDevice(seqlens.get_device()); + cudaSetDevice(seqlens.get_device()); - 
int batch_size = slash_indexes.size(0); - int num_heads = slash_indexes.size(1); - int nnz_slash = slash_indexes.size(2); - int nnz_vertical = vertical_indexes.size(2); - int num_rows = (context_size + block_size_M - 1) / block_size_M; + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; - torch::Tensor block_count = - torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); - torch::Tensor block_offset = torch::zeros( - {batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); - torch::Tensor column_count = - torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); - torch::Tensor column_index = torch::zeros( - {batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); + torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); + torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); + torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); - convert_vertical_slash_indexes_64x64( - seqlens.data_ptr(), vertical_indexes.data_ptr(), - slash_indexes.data_ptr(), block_count.data_ptr(), - block_offset.data_ptr(), column_count.data_ptr(), - column_index.data_ptr(), batch_size, num_heads, num_rows, - nnz_vertical, nnz_slash); + convert_vertical_slash_indexes_64x64( + seqlens.data_ptr(), + vertical_indexes.data_ptr(), + slash_indexes.data_ptr(), + block_count.data_ptr(), + block_offset.data_ptr(), + column_count.data_ptr(), + column_index.data_ptr(), + batch_size, + num_heads, + num_rows, + nnz_vertical, + nnz_slash + ); - return {block_count, block_offset, column_count, column_index}; + return { block_count, block_offset, column_count, column_index }; } diff --git a/maint/precision/cuda_ops.cu b/maint/precision/cuda_ops.cu index 6bd1856af..519335751 100644 --- a/maint/precision/cuda_ops.cu +++ b/maint/precision/cuda_ops.cu @@ -1,294 +1,242 @@ +#include #include #include #include -#include enum OperatorType { - OP_DIV, - OP_RECIPROCAL, - OP_EXP, - OP_LOG, - OP_SIN, - OP_COS, - OP_SQRT, - OP_TANH, - OP_RSQRT, - OP_INV_SQRT + OP_DIV, + OP_RECIPROCAL, + OP_EXP, + OP_LOG, + OP_SIN, + OP_COS, + OP_SQRT, + OP_TANH, + OP_RSQRT, + OP_INV_SQRT }; // ================= 精确版本 device 运算符 ================= -__device__ __forceinline__ float precise_div(float a, float b) { return a / b; } +__device__ __forceinline__ float precise_div(float a, float b) { + return a / b; +} __device__ __forceinline__ float precise_reciprocal(float x) { - return 1.0f / x; -} -__device__ __forceinline__ float precise_exp(float x) { return expf(x); } -__device__ __forceinline__ float precise_log(float x) { return logf(x); } -__device__ __forceinline__ float precise_sin(float x) { return sinf(x); } -__device__ __forceinline__ float precise_cos(float x) { return cosf(x); } -__device__ __forceinline__ float precise_sqrt(float x) { return sqrtf(x); } -__device__ __forceinline__ float precise_tanh(float x) { return tanhf(x); } -__device__ __forceinline__ float precise_rsqrt(float x) { return rsqrtf(x); } + return 1.0f / x; +} +__device__ __forceinline__ float precise_exp(float x) { + return expf(x); +} +__device__ __forceinline__ float precise_log(float x) { + return logf(x); +} 
+__device__ __forceinline__ float precise_sin(float x) { + return sinf(x); +} +__device__ __forceinline__ float precise_cos(float x) { + return cosf(x); +} +__device__ __forceinline__ float precise_sqrt(float x) { + return sqrtf(x); +} +__device__ __forceinline__ float precise_tanh(float x) { + return tanhf(x); +} +__device__ __forceinline__ float precise_rsqrt(float x) { + return rsqrtf(x); +} __device__ __forceinline__ float precise_inv_sqrt(float x) { - return 1.0f / sqrtf(x); + return 1.0f / sqrtf(x); } // ================= double 精确版本 device 运算符 ================= __device__ __forceinline__ double double_precise_div(double a, double b) { - return a / b; + return a / b; } __device__ __forceinline__ double double_precise_reciprocal(double x) { - return 1.0 / x; + return 1.0 / x; } __device__ __forceinline__ double double_precise_exp(double x) { - return exp(x); + return exp(x); } __device__ __forceinline__ double double_precise_log(double x) { - return log(x); + return log(x); } __device__ __forceinline__ double double_precise_sin(double x) { - return sin(x); + return sin(x); } __device__ __forceinline__ double double_precise_cos(double x) { - return cos(x); + return cos(x); } __device__ __forceinline__ double double_precise_sqrt(double x) { - return sqrt(x); + return sqrt(x); } __device__ __forceinline__ double double_precise_tanh(double x) { - return tanh(x); + return tanh(x); } __device__ __forceinline__ double double_precise_rsqrt(double x) { - return 1.0 / sqrt(x); + return 1.0 / sqrt(x); } __device__ __forceinline__ double double_precise_inv_sqrt(double x) { - return 1.0 / sqrt(x); + return 1.0 / sqrt(x); } // ================= 快速近似版本 device 运算符 ================= __device__ __forceinline__ float fast_div(float a, float b) { - return __fdividef(a, b); + return __fdividef(a, b); } __device__ __forceinline__ float fast_reciprocal(float x) { - float ret; - asm volatile("rcp.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; -} -__device__ __forceinline__ float fast_exp(float x) { return __expf(x); } -__device__ __forceinline__ float fast_log(float x) { return __logf(x); } -__device__ __forceinline__ float fast_sin(float x) { return __sinf(x); } -__device__ __forceinline__ float fast_cos(float x) { return __cosf(x); } + float ret; + asm volatile("rcp.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} +__device__ __forceinline__ float fast_exp(float x) { + return __expf(x); +} +__device__ __forceinline__ float fast_log(float x) { + return __logf(x); +} +__device__ __forceinline__ float fast_sin(float x) { + return __sinf(x); +} +__device__ __forceinline__ float fast_cos(float x) { + return __cosf(x); +} __device__ __forceinline__ float fast_sqrt(float x) { - float ret; - asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; + float ret; + asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} +__device__ __forceinline__ float fast_tanh(float x) { + return __tanhf(x); } -__device__ __forceinline__ float fast_tanh(float x) { return __tanhf(x); } __device__ __forceinline__ float fast_rsqrt(float x) { - // return rsqrtf(x); - float ret; - asm volatile("rsqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return ret; + // return rsqrtf(x); + float ret; + asm volatile("rsqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; } __device__ __forceinline__ float fast_inv_sqrt(float x) { - float ret; - asm volatile("sqrt.approx.f32 %0, %1;" : "=f"(ret) : "f"(x)); - return 1.0f / ret; + float ret; + asm volatile("sqrt.approx.f32 %0, 
%1;" : "=f"(ret) : "f"(x)); + return 1.0f / ret; } // ================= 精确版本 kernel ================= -__global__ void precise_operator_kernel(const float *x, const float *y, - float *result, int64_t n, - OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - float a = x[i]; - float b = (y != nullptr) ? y[i] : 0.0f; - float r = 0.0f; - switch (op_type) { - case OP_DIV: - r = precise_div(a, b); - break; - case OP_RECIPROCAL: - r = precise_reciprocal(a); - break; - case OP_EXP: - r = precise_exp(a); - break; - case OP_LOG: - r = precise_log(a); - break; - case OP_SIN: - r = precise_sin(a); - break; - case OP_COS: - r = precise_cos(a); - break; - case OP_SQRT: - r = precise_sqrt(a); - break; - case OP_TANH: - r = precise_tanh(a); - break; - case OP_RSQRT: - r = precise_rsqrt(a); - break; - case OP_INV_SQRT: - r = precise_inv_sqrt(a); - break; +__global__ void precise_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + float a = x[i]; + float b = (y != nullptr) ? y[i] : 0.0f; + float r = 0.0f; + switch (op_type) { + case OP_DIV: r = precise_div(a, b); break; + case OP_RECIPROCAL: r = precise_reciprocal(a); break; + case OP_EXP: r = precise_exp(a); break; + case OP_LOG: r = precise_log(a); break; + case OP_SIN: r = precise_sin(a); break; + case OP_COS: r = precise_cos(a); break; + case OP_SQRT: r = precise_sqrt(a); break; + case OP_TANH: r = precise_tanh(a); break; + case OP_RSQRT: r = precise_rsqrt(a); break; + case OP_INV_SQRT: r = precise_inv_sqrt(a); break; + } + result[i] = r; } - result[i] = r; - } } // ================= double 精确版本 kernel ================= -__global__ void double_precise_operator_kernel(const double *x, const double *y, - double *result, int64_t n, - OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - double a = x[i]; - double b = (y != nullptr) ? y[i] : 0.0; - double r = 0.0; - switch (op_type) { - case OP_DIV: - r = double_precise_div(a, b); - break; - case OP_RECIPROCAL: - r = double_precise_reciprocal(a); - break; - case OP_EXP: - r = double_precise_exp(a); - break; - case OP_LOG: - r = double_precise_log(a); - break; - case OP_SIN: - r = double_precise_sin(a); - break; - case OP_COS: - r = double_precise_cos(a); - break; - case OP_SQRT: - r = double_precise_sqrt(a); - break; - case OP_TANH: - r = double_precise_tanh(a); - break; - case OP_RSQRT: - r = double_precise_rsqrt(a); - break; - case OP_INV_SQRT: - r = double_precise_inv_sqrt(a); - break; +__global__ void double_precise_operator_kernel(const double* x, const double* y, double* result, int64_t n, OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + double a = x[i]; + double b = (y != nullptr) ? 
y[i] : 0.0; + double r = 0.0; + switch (op_type) { + case OP_DIV: r = double_precise_div(a, b); break; + case OP_RECIPROCAL: r = double_precise_reciprocal(a); break; + case OP_EXP: r = double_precise_exp(a); break; + case OP_LOG: r = double_precise_log(a); break; + case OP_SIN: r = double_precise_sin(a); break; + case OP_COS: r = double_precise_cos(a); break; + case OP_SQRT: r = double_precise_sqrt(a); break; + case OP_TANH: r = double_precise_tanh(a); break; + case OP_RSQRT: r = double_precise_rsqrt(a); break; + case OP_INV_SQRT: r = double_precise_inv_sqrt(a); break; + } + result[i] = r; } - result[i] = r; - } } // ================= 快速版本 kernel ================= -__global__ void fast_operator_kernel(const float *x, const float *y, - float *result, int64_t n, - OperatorType op_type) { - int64_t i = blockIdx.x * blockDim.x + threadIdx.x; - if (i < n) { - float a = x[i]; - float b = (y != nullptr) ? y[i] : 0.0f; - float r = 0.0f; - switch (op_type) { - case OP_DIV: - r = fast_div(a, b); - break; - case OP_RECIPROCAL: - r = fast_reciprocal(a); - break; - case OP_EXP: - r = fast_exp(a); - break; - case OP_LOG: - r = fast_log(a); - break; - case OP_SIN: - r = fast_sin(a); - break; - case OP_COS: - r = fast_cos(a); - break; - case OP_SQRT: - r = fast_sqrt(a); - break; - case OP_TANH: - r = fast_tanh(a); - break; - case OP_RSQRT: - r = fast_rsqrt(a); - break; - case OP_INV_SQRT: - r = fast_inv_sqrt(a); - break; +__global__ void fast_operator_kernel(const float* x, const float* y, float* result, int64_t n, OperatorType op_type) { + int64_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { + float a = x[i]; + float b = (y != nullptr) ? y[i] : 0.0f; + float r = 0.0f; + switch (op_type) { + case OP_DIV: r = fast_div(a, b); break; + case OP_RECIPROCAL: r = fast_reciprocal(a); break; + case OP_EXP: r = fast_exp(a); break; + case OP_LOG: r = fast_log(a); break; + case OP_SIN: r = fast_sin(a); break; + case OP_COS: r = fast_cos(a); break; + case OP_SQRT: r = fast_sqrt(a); break; + case OP_TANH: r = fast_tanh(a); break; + case OP_RSQRT: r = fast_rsqrt(a); break; + case OP_INV_SQRT: r = fast_inv_sqrt(a); break; + } + result[i] = r; } - result[i] = r; - } } // 精确版本 -void launch_precise_operator(const at::Tensor &x, - const c10::optional &y, - at::Tensor &result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const float *y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - precise_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, - static_cast(op_type)); +void launch_precise_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { + int64_t n = x.numel(); + int threads = 256; + int blocks = (n + threads - 1) / threads; + const float* y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + precise_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) + ); } // double 精确版本 -void launch_double_precise_operator(const at::Tensor &x, - const c10::optional &y, - at::Tensor &result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const double *y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - double_precise_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, - static_cast(op_type)); +void launch_double_precise_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { + int64_t n = x.numel(); 
+ int threads = 256; + int blocks = (n + threads - 1) / threads; + const double* y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + double_precise_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) + ); } // 快速版本 -void launch_fast_operator(const at::Tensor &x, - const c10::optional &y, - at::Tensor &result, int op_type) { - int64_t n = x.numel(); - int threads = 256; - int blocks = (n + threads - 1) / threads; - const float *y_ptr = nullptr; - if (y.has_value()) { - y_ptr = y.value().data_ptr(); - } - fast_operator_kernel<<>>( - x.data_ptr(), y_ptr, result.data_ptr(), n, - static_cast(op_type)); +void launch_fast_operator(const at::Tensor& x, const c10::optional& y, at::Tensor& result, int op_type) { + int64_t n = x.numel(); + int threads = 256; + int blocks = (n + threads - 1) / threads; + const float* y_ptr = nullptr; + if (y.has_value()) { + y_ptr = y.value().data_ptr(); + } + fast_operator_kernel<<>>( + x.data_ptr(), y_ptr, result.data_ptr(), n, static_cast(op_type) + ); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("launch_precise_operator", &launch_precise_operator, - "CUDA Precise Operator", py::arg("x"), py::arg("y") = c10::nullopt, - py::arg("result"), py::arg("op_type")); - m.def("launch_double_precise_operator", &launch_double_precise_operator, - "CUDA Double Precise Operator", py::arg("x"), - py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); - m.def("launch_fast_operator", &launch_fast_operator, "CUDA Fast Operator", - py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), - py::arg("op_type")); + m.def("launch_precise_operator", &launch_precise_operator, "CUDA Precise Operator", + py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); + m.def("launch_double_precise_operator", &launch_double_precise_operator, "CUDA Double Precise Operator", + py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); + m.def("launch_fast_operator", &launch_fast_operator, "CUDA Fast Operator", + py::arg("x"), py::arg("y") = c10::nullopt, py::arg("result"), py::arg("op_type")); } \ No newline at end of file diff --git a/src/tl_templates/cuda/compress_sm90.cu b/src/tl_templates/cuda/compress_sm90.cu index 3ec1d7aac..8bb236dd8 100644 --- a/src/tl_templates/cuda/compress_sm90.cu +++ b/src/tl_templates/cuda/compress_sm90.cu @@ -13,68 +13,60 @@ using namespace cute; -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - if (error != cutlass::Status::kSuccess) { \ - std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ - << " at: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ } -#define CUDA_CHECK(status) \ - { \ - cudaError_t error = status; \ - if (error != cudaSuccess) { \ - std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ - << " at line: " << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ } -template +template std::tuple compress_impl(torch::Tensor A) { using ElementA = T; 
using ElementE = uint8_t; - using LayoutTagA = conditional_t; + using LayoutTagA = conditional_t; using ProblemShape = cute::Shape; using StrideA = cutlass::gemm::TagToStrideA_t; using StrideE = StrideA; // NOTE: this is derived from sparse sm90 mma atoms - // Ref: - // https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp - using SparseE = conditional_t<(sizeof_bits_v == 32), - cute::sparse_elem<4, ElementE>, - cute::sparse_elem<8, ElementE>>; - static constexpr GMMA::Major GmmaMajorA = - transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; - using SparseConfig = - cutlass::Sm90GemmSparseConfig, GmmaMajorA, - SparseE, cute::C>; + // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp + using SparseE = conditional_t<(sizeof_bits_v == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; + static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; + using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, ElementA>, GmmaMajorA, + SparseE, cute::C>; using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< ProblemShape, ElementA, LayoutTagA, SparseConfig>; - using CompressorKernel = - cutlass::transform::kernel::StructuredSparseCompressor< - ProblemShape, ElementA, LayoutTagA, SparseConfig, - cutlass::arch::Sm90>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; - using Compressor = - cutlass::transform::device::TransformUniversalAdapter; + using Compressor = cutlass::transform::device::TransformUniversalAdapter; TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); int M = -1; int K = -1; - int N = -1; // not used, but required for config + int N = -1; // not used, but required for config int L = 1; - if constexpr (transposed) { + if constexpr(transposed) { M = A.size(1); K = A.size(0); } else { @@ -83,27 +75,24 @@ std::tuple compress_impl(torch::Tensor A) { } ProblemShape problem_shape = make_tuple(M, N, K, L); - StrideA stride_A = - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); CompressorUtility compressor_utility(problem_shape, stride_A); int ME = compressor_utility.get_metadata_m_physical(); int KE = compressor_utility.get_metadata_k_physical(); int KC = compressor_utility.get_tensorA_k_physical(); - StrideE stride_E = - cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); auto dtype = A.dtype().toScalarType(); - torch::Tensor A_compressed = torch::zeros( - KC * M, torch::TensorOptions().dtype(dtype).device(A.device())); - torch::Tensor E = torch::zeros( - {ME, KE}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); + torch::Tensor A_compressed = torch::zeros(KC * M, + torch::TensorOptions().dtype(dtype).device(A.device())); + torch::Tensor E = torch::zeros({ME, KE}, + torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); cutlass::KernelHardwareInfo hw_info; hw_info.device_id = A.device().index(); hw_info.sm_count = - 
cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); typename Compressor::Arguments arguments{problem_shape, { @@ -131,45 +120,40 @@ std::tuple compress_impl(torch::Tensor A) { } // block <= 128 -// Ref -// https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 -#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (BLOCK_K) { \ - case int(32 * FACTOR): \ - return compress_impl(TENSOR); \ - case int(64 * FACTOR): \ - return compress_impl(TENSOR); \ - case int(128 * FACTOR): \ - return compress_impl(TENSOR); \ - default: \ - TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ - } \ +// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 +#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (BLOCK_K) { \ + case int(32 * FACTOR): return compress_impl(TENSOR); \ + case int(64 * FACTOR): return compress_impl(TENSOR); \ + case int(128 * FACTOR): return compress_impl(TENSOR); \ + default: \ + TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ + } \ }() -#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ - [&]() -> std::tuple { \ - switch (dtype) { \ - case torch::kFloat32: \ - return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ - case torch::kFloat16: \ - case torch::kBFloat16: \ - return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ - case torch::kFloat8_e4m3fn: \ - return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ - case torch::kFloat8_e5m2: \ - return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ - case torch::kChar: \ - return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ - case torch::kByte: \ - return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ - default: \ - TORCH_CHECK(false, "Unsupported dtype"); \ - } \ +#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (dtype) { \ + case torch::kFloat32: \ + return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ + case torch::kFloat16: \ + case torch::kBFloat16: \ + return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ + case torch::kFloat8_e4m3fn: \ + return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ + case torch::kFloat8_e5m2: \ + return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ + case torch::kChar: \ + return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ + case torch::kByte: \ + return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ + default: \ + TORCH_CHECK(false, "Unsupported dtype"); \ + } \ }() -std::tuple -compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { +std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { auto dtype = A.dtype().toScalarType(); return transposed ? 
DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); } diff --git a/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh b/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh index e6dd9625a..f5641f616 100644 --- a/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh +++ b/src/tl_templates/cuda/cuda_bf16_fallbacks.cuh @@ -26,268 +26,232 @@ namespace fastertransformer { #ifdef ENABLE_BF16 inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; #else - return __bfloat1622float2(val); + return __bfloat1622float2(val); #endif } inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - union { - int8_t int8[2]; - int16_t int16; - }; - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; #else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - union { - int8_t int8[2]; - int16_t int16; - }; - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; #endif } inline __device__ __nv_bfloat162 float22bf162(const float2 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); + return __floats2bfloat162_rn(val.x, val.y); #else - return __float22bfloat162_rn(val); + return __float22bfloat162_rn(val); #endif } inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; #else - return __bfloat162bfloat162(val); + return __bfloat162bfloat162(val); #endif } -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, - const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); #else - return __hadd2(x, y); + return __hadd2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, - const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) + 
__bfloat162float(y)); + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); #else - return __hadd(x, y); + return __hadd(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, - const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); #else - return __hsub2(x, y); + return __hsub2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, - const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); + return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); #else - return __hsub(x, y); + return __hsub(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, - const __nv_bfloat162 y) { +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); #else - return __hmul2(x, y); + return __hmul2(x, y); #endif } -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, - const __nv_bfloat16 y) { +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); #else - return __hmul(x, y); + return __hmul(x, y); #endif } -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, - const __nv_bfloat162 y, - const __nv_bfloat162 z) { +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); #else - return __hfma2(x, y, z); + return __hfma2(x, y, z); #endif } -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, - const __nv_bfloat16 y, - const __nv_bfloat16 z) { +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * 
__bfloat162float(y) + - __bfloat162float(z)); + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); #else - return __hfma(x, y, z); + return __hfma(x, y, z); #endif } inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x); - ; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x);; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); #else - return h2exp(x); + return h2exp(x); #endif } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, - const __nv_bfloat162 y) { - return bf16hmul2(x, y); -}; -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, - const __nv_bfloat162 y) { - return bf16hadd2(x, y); -}; +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, - const __nv_bfloat16 y) { - __nv_bfloat162 t; - t.x = x; - t.y = y; - return t; +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; t.x = x; t.y = y; return t; } #endif -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, - __nv_bfloat16 c) { +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + - __bfloat162float(c)); + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); #else - return a + b + c; + return a + b + c; #endif } -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, - __nv_bfloat16 c, __nv_bfloat16 d) { +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + - __bfloat162float(c) + __bfloat162float(d)); + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); #else - return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); #endif } -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, - __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); #else - return a + b + c; + return a + b + c; #endif } -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, - __nv_bfloat16 c) { +inline __device__ 
__nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * - __bfloat162float(c)); + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); #else - return a * b * c; + return a * b * c; #endif } -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, - __nv_bfloat162 c) { +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); #else - return a * b * c; + return a * b * c; #endif } -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, - __nv_bfloat162 c, __nv_bfloat162 d) { +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); #else - return a * b * c + d; + return a * b * c + d; #endif } #endif // ENABLE_BF16 -} // namespace fastertransformer +} // namespace fastertransformer
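
With this final patch, the `clang-format` hook no longer touches CUDA sources at all: the previously reformatted `.cu`/`.cuh` files are restored to their original layout, and the hook's `exclude` pattern skips both CUDA and JSON files. As a rough illustration of how that pattern behaves (pre-commit evaluates `exclude` as a Python regular expression against each candidate path; the sample paths below are made up for demonstration only):

```python
import re

# Same pattern as the `exclude` entry added to .pre-commit-config.yaml.
# The inline (?ix) flags make it case-insensitive and verbose, so the line
# breaks and indentation inside the pattern are ignored.
EXCLUDE = re.compile(
    r"""(?ix)(
        ^.+\.(cu|cuh)$|
        ^.+\.json$
    )"""
)

# Hypothetical paths, purely for illustration.
paths = [
    "src/tl_templates/cuda/cuda_bf16_fallbacks.cuh",  # skipped by clang-format
    "maint/precision/cuda_ops.cu",                    # skipped by clang-format
    "src/op/gemm.cc",                                 # still formatted
    "some/config.json",                               # skipped by clang-format
]

for path in paths:
    print(path, "->", "excluded" if EXCLUDE.search(path) else "formatted")
```

Running something like `pre-commit run clang-format --all-files` after applying the series should then report the hook as passing without rewriting any `.cu`/`.cuh` files.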