diff --git a/.gitmodules b/.gitmodules index a45558d527..6bdbddf144 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/spdlog"] path = 3rdparty/spdlog url = https://github.com/gabime/spdlog.git +[submodule "3rdparty/cccl"] + path = 3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git diff --git a/3rdparty/cccl b/3rdparty/cccl new file mode 160000 index 0000000000..876867684f --- /dev/null +++ b/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit 876867684f7fac130e0f5911236e0a92a970d4fd diff --git a/build_backend.py b/build_backend.py index d14f5787f7..9bcac12470 100644 --- a/build_backend.py +++ b/build_backend.py @@ -105,6 +105,7 @@ def ln(source: str, target: str) -> None: ln("3rdparty/cutlass", "cutlass") ln("3rdparty/spdlog", "spdlog") + ln("3rdparty/cccl", "cccl") ln("csrc", "csrc") ln("include", "include") diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu index 50cabaeacc..74f7d45109 100644 --- a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu @@ -187,11 +187,7 @@ struct KernelTraits<2> { template <> struct KernelTraits<1> { -#if CUDA_VERSION >= 12090 using MaxOp = cuda::maximum<>; -#else - using MaxOp = cub::Max; -#endif using PackedType = float; }; @@ -944,11 +940,7 @@ __global__ void finalizeDeepSeekKernel(KernelParams params) { float constexpr E4m3MaxVal{448.f}; // Compute the absolute max -#if CUDA_VERSION >= 12090 float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cuda::maximum<>{}); -#else - float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cub::Max{}); -#endif if (threadIdx.x == 0) { if (params.outDqSfsPtr) { diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index 9611ac001f..1cc34b0d7b 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -95,12 +95,16 @@ def join_multiline(vs: List[str]) -> str: return " $\n ".join(vs) +def get_cccl_includes() -> List: + """Get vendored CCCL include directories (added with -I for CTK override precedence).""" + return [p.resolve() for p in jit_env.CCCL_INCLUDE_DIRS] + + def get_system_includes(cuda_home: str) -> List: """Get list of system include directories.""" system_includes = [ sysconfig.get_path("include"), "$cuda_home/include", - "$cuda_home/include/cccl", tvm_ffi.libinfo.find_include_path(), tvm_ffi.libinfo.find_dlpack_include_path(), jit_env.FLASHINFER_INCLUDE_DIR.resolve(), @@ -121,6 +125,7 @@ def build_common_cflags( extra_include_dirs: Optional[List[Path]] = None, ) -> List[str]: """Build common compilation flags.""" + cccl_includes = get_cccl_includes() system_includes = get_system_includes(cuda_home) common_cflags = [] @@ -130,6 +135,11 @@ def build_common_cflags( if extra_include_dirs is not None: for extra_dir in extra_include_dirs: common_cflags.append(f"-I{extra_dir.resolve()}") + # Vendored CCCL headers use -I (not -isystem) so they take precedence + # over the CTK-bundled copy. CCCL headers use #pragma system_header + # internally to suppress warnings. See https://github.com/NVIDIA/cccl/issues/527 + for cccl_dir in cccl_includes: + common_cflags.append(f"-I{cccl_dir}") for sys_dir in system_includes: common_cflags.append(f"-isystem {sys_dir}") diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index 8cb4c0faa1..16194339a8 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -152,3 +152,8 @@ def _get_workspace_dir_name() -> pathlib.Path: _package_root / "data" / "cutlass" / "tools" / "util" / "include", ] SPDLOG_INCLUDE_DIR: pathlib.Path = _package_root / "data" / "spdlog" / "include" +CCCL_INCLUDE_DIRS: list[pathlib.Path] = [ + _package_root / "data" / "cccl" / "cub", + _package_root / "data" / "cccl" / "libcudacxx" / "include", + _package_root / "data" / "cccl" / "thrust", +] diff --git a/include/flashinfer/fastdiv.cuh b/include/flashinfer/fastdiv.cuh index 305241c1af..4a3984ef65 100644 --- a/include/flashinfer/fastdiv.cuh +++ b/include/flashinfer/fastdiv.cuh @@ -1,9 +1,5 @@ /* - * Copyright 2014 Maxim Milakov - * - * The code is based on the Chapter 10 of Hacker's Delight book by Henry S. Warren, Jr. - * The struct is adapted from https://github.com/milakov/int_fastdiv/blob/master/int_fastdiv.h - * by Maxim Milakov, the difference is that here we use uint32_t instead of int32_t. + * Copyright (c) 2024 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,90 +16,42 @@ #ifndef FLASHINFER_FASTDIV_CUH_ #define FLASHINFER_FASTDIV_CUH_ #include +#include namespace flashinfer { +// API-compatible wrapper around cuda::fast_mod_div. +// Preserves the default constructor, implicit conversions, and divmod() +// method expected by existing call sites throughout the attention kernels. struct uint_fastdiv { - uint32_t d; - uint32_t m; - uint32_t s; - uint32_t a; - - __host__ __device__ uint_fastdiv() : d(0), m(0), s(0), a(0) {} + __host__ __device__ uint_fastdiv() : impl_(1), d_(0) {} - __host__ uint_fastdiv(uint32_t d) : d(d) { - unsigned int p, nc, delta, q1, r1, q2, r2; - a = 0; - nc = unsigned(-1) - unsigned(-d) % d; - p = 31; - q1 = 0x80000000 / nc; - r1 = 0x80000000 - q1 * nc; - q2 = 0x7FFFFFFF / d; - r2 = 0x7FFFFFFF - q2 * d; - do { - p++; - if (r1 >= nc - r1) { - q1 = 2 * q1 + 1; - r1 = 2 * r1 - nc; - } else { - q1 = 2 * q1; - r1 = 2 * r1; - } - if (r2 + 1 >= d - r2) { - if (q2 >= 0x7FFFFFFF) a = 1; - q2 = 2 * q2 + 1; - r2 = 2 * r2 + 1 - d; - } else { - if (q2 >= 0x80000000) a = 1; - q2 = 2 * q2; - r2 = 2 * r2 + 1; - } - delta = d - 1 - r2; - } while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0))); - m = q2 + 1; - s = p - 32; - } + __host__ uint_fastdiv(uint32_t d) : impl_(d ? d : 1), d_(d) {} - __host__ __device__ __forceinline__ operator unsigned int() const { return d; } + __host__ __device__ __forceinline__ operator unsigned int() const { return d_; } __host__ __device__ __forceinline__ void divmod(uint32_t n, uint32_t& q, uint32_t& r) const { - if (d == 1) { - q = n; - } else { -#ifdef __CUDA_ARCH__ - q = __umulhi(m, n); -#else - q = (((unsigned long long)((long long)m * (long long)n)) >> 32); -#endif - q += a * n; - q >>= s; - } - r = n - q * d; + q = n / impl_; + r = n - q * d_; } + + private: + cuda::fast_mod_div impl_; + uint32_t d_; }; __host__ __device__ __forceinline__ uint32_t operator/(const uint32_t n, const uint_fastdiv& divisor) { - uint32_t q; - if (divisor.d == 1) { - q = n; - } else { -#ifdef __CUDA_ARCH__ - q = __umulhi(divisor.m, n); -#else - q = (((unsigned long long)((long long)divisor.m * (long long)n)) >> 32); -#endif - q += divisor.a * n; - q >>= divisor.s; - } + uint32_t q, r; + divisor.divmod(n, q, r); return q; } __host__ __device__ __forceinline__ uint32_t operator%(const uint32_t n, const uint_fastdiv& divisor) { - uint32_t quotient = n / divisor; - uint32_t remainder = n - quotient * divisor; - return remainder; + uint32_t q, r; + divisor.divmod(n, q, r); + return r; } } // namespace flashinfer diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 2dd0ae33b0..67100d163c 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -37,15 +37,8 @@ #include "utils.cuh" #include "vec_dtypes.cuh" -// Define reduction operators based on CUDA version -// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum -#if CUDA_VERSION >= 12090 using MaxReduceOp = cuda::maximum<>; using MinReduceOp = cuda::minimum<>; -#else -using MaxReduceOp = cub::Max; -using MinReduceOp = cub::Min; -#endif namespace flashinfer { diff --git a/include/flashinfer/trtllm/fmha/kernelParams.h b/include/flashinfer/trtllm/fmha/kernelParams.h index e69d583854..541c6f7614 100644 --- a/include/flashinfer/trtllm/fmha/kernelParams.h +++ b/include/flashinfer/trtllm/fmha/kernelParams.h @@ -25,6 +25,7 @@ #include #include +#include #include #include "../../utils.cuh" @@ -33,33 +34,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// -// CCCL >= 3.1.0 (CUDA CTK 13.1) introduces the fast_mod_div math operations. -// The following code makes sure that the host initialization works with older CUDA CTK versions. -// - -// Refer to -// https://github.com/NVIDIA/cccl/blob/main/libcudacxx/include/cuda/__cmath/fast_modulo_division.h#L76-L81 -// about how to compute the fast modulo division. -struct FastModDivInt32 { - public: - FastModDivInt32(int32_t divisor) : mDivisor(divisor) { - mShift = std::max(ceilLog2(mDivisor) - 1, 0); - mMultiplier = static_cast( - flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast(mDivisor))); - } - - private: - int32_t ceilLog2(int32_t value) const { - return static_cast(std::ceil(std::log2(value))); - } - - private: - int32_t mDivisor = 1; - uint32_t mMultiplier = 0; - uint32_t mAdd = 0; - int32_t mShift = 0; -}; +using FastModDivInt32 = cuda::fast_mod_div; //////////////////////////////////////////////////////////////////////////////////////////////////// using Dtype = Data_type; diff --git a/include/flashinfer/trtllm/fmha/lse.cuh b/include/flashinfer/trtllm/fmha/lse.cuh index b41d084ace..b031ad9207 100644 --- a/include/flashinfer/trtllm/fmha/lse.cuh +++ b/include/flashinfer/trtllm/fmha/lse.cuh @@ -18,43 +18,23 @@ limitations under the License. #include +#include +#include + #include "../../math.cuh" #include "../../utils.cuh" namespace flashinfer { -__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) { - int elem_idx = blockIdx.x * blockDim.x + threadIdx.x; - if (elem_idx >= n) return; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); -#endif - float2 md_elem = md[elem_idx]; - float m = md_elem.x; - float d = md_elem.y; - lse[elem_idx] = math::log2e * m + math::ptx_log2(d); -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); -#endif -} +struct MDToLSE { + __host__ __device__ float operator()(float2 md_elem) const { + return math::log2e * md_elem.x + log2f(md_elem.y); + } +}; -inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl, +inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/, cudaStream_t stream) { - int num_threads = std::min(1024, UpPowerOfTwo(n)); - int num_blocks = ceil_div(n, num_threads); - cudaLaunchConfig_t config; - config.gridDim = num_blocks; - config.blockDim = num_threads; - config.dynamicSmemBytes = 0; - config.stream = stream; - cudaLaunchAttribute attrs[1]; - attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = launch_with_pdl; - config.numAttrs = 1; - config.attrs = attrs; - - FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n)); - return cudaSuccess; + return cub::DeviceTransform::Transform(md, lse, n, MDToLSE{}, stream); } }; // namespace flashinfer diff --git a/pyproject.toml b/pyproject.toml index 0c85cbe13d..a854c1de4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,12 +56,14 @@ exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"] "flashinfer.data" = "." "flashinfer.data.cutlass" = "3rdparty/cutlass" "flashinfer.data.spdlog" = "3rdparty/spdlog" +"flashinfer.data.cccl" = "3rdparty/cccl" [tool.setuptools.package-data] "flashinfer" = ["_build_meta.py"] "flashinfer.data" = ["csrc/**", "include/**"] "flashinfer.data.cutlass" = ["include/**", "tools/util/include/**"] "flashinfer.data.spdlog" = ["include/**"] +"flashinfer.data.cccl" = ["cub/cub/**", "libcudacxx/include/**", "thrust/thrust/**"] [tool.mypy] files = ["flashinfer"] diff --git a/scripts/modal_runner.py b/scripts/modal_runner.py index 574d006ff1..8897a848ee 100644 --- a/scripts/modal_runner.py +++ b/scripts/modal_runner.py @@ -113,6 +113,19 @@ def _run_flashinfer_command(command: str) -> str: check=True, ) + if not os.path.exists("3rdparty/cccl/cub"): + print("=== Downloading CCCL ===") + subprocess.run( + [ + "git", + "clone", + "--depth=1", + "https://github.com/NVIDIA/cccl.git", + "3rdparty/cccl", + ], + check=True, + ) + # Run the user command print(f"=== Running command: {command} ===") import shlex