Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "3rdparty/spdlog"]
path = 3rdparty/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "3rdparty/cccl"]
path = 3rdparty/cccl
url = https://github.com/NVIDIA/cccl.git
1 change: 1 addition & 0 deletions 3rdparty/cccl
Submodule cccl added at 876867
1 change: 1 addition & 0 deletions build_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def ln(source: str, target: str) -> None:

ln("3rdparty/cutlass", "cutlass")
ln("3rdparty/spdlog", "spdlog")
ln("3rdparty/cccl", "cccl")
ln("csrc", "csrc")
ln("include", "include")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,7 @@ struct KernelTraits<2> {

template <>
struct KernelTraits<1> {
#if CUDA_VERSION >= 12090
using MaxOp = cuda::maximum<>;
#else
using MaxOp = cub::Max;
#endif
using PackedType = float;
};

Expand Down Expand Up @@ -944,11 +940,7 @@ __global__ void finalizeDeepSeekKernel(KernelParams params) {
float constexpr E4m3MaxVal{448.f};

// Compute the absolute max
#if CUDA_VERSION >= 12090
float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cuda::maximum<>{});
#else
float aMax = BlockReduce(temp_storage).Reduce(fabsf(acc), cub::Max{});
#endif

if (threadIdx.x == 0) {
if (params.outDqSfsPtr) {
Expand Down
12 changes: 11 additions & 1 deletion flashinfer/jit/cpp_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,16 @@ def join_multiline(vs: List[str]) -> str:
return " $\n ".join(vs)


def get_cccl_includes() -> List:
Comment thread
kahyunnam marked this conversation as resolved.
"""Get vendored CCCL include directories (added with -I for CTK override precedence)."""
return [p.resolve() for p in jit_env.CCCL_INCLUDE_DIRS]


def get_system_includes(cuda_home: str) -> List:
"""Get list of system include directories."""
system_includes = [
sysconfig.get_path("include"),
"$cuda_home/include",
"$cuda_home/include/cccl",
tvm_ffi.libinfo.find_include_path(),
tvm_ffi.libinfo.find_dlpack_include_path(),
jit_env.FLASHINFER_INCLUDE_DIR.resolve(),
Expand All @@ -121,6 +125,7 @@ def build_common_cflags(
extra_include_dirs: Optional[List[Path]] = None,
) -> List[str]:
"""Build common compilation flags."""
cccl_includes = get_cccl_includes()
system_includes = get_system_includes(cuda_home)

common_cflags = []
Expand All @@ -130,6 +135,11 @@ def build_common_cflags(
if extra_include_dirs is not None:
for extra_dir in extra_include_dirs:
common_cflags.append(f"-I{extra_dir.resolve()}")
# Vendored CCCL headers use -I (not -isystem) so they take precedence
# over the CTK-bundled copy. CCCL headers use #pragma system_header
# internally to suppress warnings. See https://github.com/NVIDIA/cccl/issues/527
for cccl_dir in cccl_includes:
common_cflags.append(f"-I{cccl_dir}")
for sys_dir in system_includes:
common_cflags.append(f"-isystem {sys_dir}")

Expand Down
5 changes: 5 additions & 0 deletions flashinfer/jit/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,8 @@ def _get_workspace_dir_name() -> pathlib.Path:
_package_root / "data" / "cutlass" / "tools" / "util" / "include",
]
SPDLOG_INCLUDE_DIR: pathlib.Path = _package_root / "data" / "spdlog" / "include"
CCCL_INCLUDE_DIRS: list[pathlib.Path] = [
_package_root / "data" / "cccl" / "cub",
_package_root / "data" / "cccl" / "libcudacxx" / "include",
_package_root / "data" / "cccl" / "thrust",
]
90 changes: 19 additions & 71 deletions include/flashinfer/fastdiv.cuh
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
/*
* Copyright 2014 Maxim Milakov
*
* The code is based on the Chapter 10 of Hacker's Delight book by Henry S. Warren, Jr.
* The struct is adapted from https://github.com/milakov/int_fastdiv/blob/master/int_fastdiv.h
* by Maxim Milakov, the difference is that here we use uint32_t instead of int32_t.
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,90 +16,42 @@
#ifndef FLASHINFER_FASTDIV_CUH_
#define FLASHINFER_FASTDIV_CUH_
#include <cstdint>
#include <cuda/cmath>

namespace flashinfer {

// API-compatible wrapper around cuda::fast_mod_div<uint32_t>.
// Preserves the default constructor, implicit conversions, and divmod()
// method expected by existing call sites throughout the attention kernels.
struct uint_fastdiv {
uint32_t d;
uint32_t m;
uint32_t s;
uint32_t a;

__host__ __device__ uint_fastdiv() : d(0), m(0), s(0), a(0) {}
__host__ __device__ uint_fastdiv() : impl_(1), d_(0) {}

__host__ uint_fastdiv(uint32_t d) : d(d) {
unsigned int p, nc, delta, q1, r1, q2, r2;
a = 0;
nc = unsigned(-1) - unsigned(-d) % d;
p = 31;
q1 = 0x80000000 / nc;
r1 = 0x80000000 - q1 * nc;
q2 = 0x7FFFFFFF / d;
r2 = 0x7FFFFFFF - q2 * d;
do {
p++;
if (r1 >= nc - r1) {
q1 = 2 * q1 + 1;
r1 = 2 * r1 - nc;
} else {
q1 = 2 * q1;
r1 = 2 * r1;
}
if (r2 + 1 >= d - r2) {
if (q2 >= 0x7FFFFFFF) a = 1;
q2 = 2 * q2 + 1;
r2 = 2 * r2 + 1 - d;
} else {
if (q2 >= 0x80000000) a = 1;
q2 = 2 * q2;
r2 = 2 * r2 + 1;
}
delta = d - 1 - r2;
} while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0)));
m = q2 + 1;
s = p - 32;
}
__host__ uint_fastdiv(uint32_t d) : impl_(d ? d : 1), d_(d) {}

__host__ __device__ __forceinline__ operator unsigned int() const { return d; }
__host__ __device__ __forceinline__ operator unsigned int() const { return d_; }

__host__ __device__ __forceinline__ void divmod(uint32_t n, uint32_t& q, uint32_t& r) const {
if (d == 1) {
q = n;
} else {
#ifdef __CUDA_ARCH__
q = __umulhi(m, n);
#else
q = (((unsigned long long)((long long)m * (long long)n)) >> 32);
#endif
q += a * n;
q >>= s;
}
r = n - q * d;
q = n / impl_;
r = n - q * d_;
}

private:
cuda::fast_mod_div<uint32_t> impl_;
uint32_t d_;
};

__host__ __device__ __forceinline__ uint32_t operator/(const uint32_t n,
const uint_fastdiv& divisor) {
uint32_t q;
if (divisor.d == 1) {
q = n;
} else {
#ifdef __CUDA_ARCH__
q = __umulhi(divisor.m, n);
#else
q = (((unsigned long long)((long long)divisor.m * (long long)n)) >> 32);
#endif
q += divisor.a * n;
q >>= divisor.s;
}
uint32_t q, r;
divisor.divmod(n, q, r);
return q;
}

__host__ __device__ __forceinline__ uint32_t operator%(const uint32_t n,
const uint_fastdiv& divisor) {
uint32_t quotient = n / divisor;
uint32_t remainder = n - quotient * divisor;
return remainder;
uint32_t q, r;
divisor.divmod(n, q, r);
return r;
}

} // namespace flashinfer
Expand Down
7 changes: 0 additions & 7 deletions include/flashinfer/sampling.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,8 @@
#include "utils.cuh"
#include "vec_dtypes.cuh"

// Define reduction operators based on CUDA version
// CUDA 13 (12.9+) deprecated cub::Max/Min in favor of cuda::maximum/minimum
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif

namespace flashinfer {

Expand Down
29 changes: 2 additions & 27 deletions include/flashinfer/trtllm/fmha/kernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <cmath>
#include <cstdint>
#include <cuda/cmath>
#include <cute/tensor.hpp>

#include "../../utils.cuh"
Expand All @@ -33,33 +34,7 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

//
// CCCL >= 3.1.0 (CUDA CTK 13.1) introduces the fast_mod_div math operations.
// The following code makes sure that the host initialization works with older CUDA CTK versions.
//

// Refer to
// https://github.com/NVIDIA/cccl/blob/main/libcudacxx/include/cuda/__cmath/fast_modulo_division.h#L76-L81
// about how to compute the fast modulo division.
struct FastModDivInt32 {
public:
FastModDivInt32(int32_t divisor) : mDivisor(divisor) {
mShift = std::max(ceilLog2(mDivisor) - 1, 0);
mMultiplier = static_cast<uint32_t>(
flashinfer::ceil_div(uint64_t(1) << (32 + mShift), static_cast<uint64_t>(mDivisor)));
}

private:
int32_t ceilLog2(int32_t value) const {
return static_cast<int32_t>(std::ceil(std::log2(value)));
}

private:
int32_t mDivisor = 1;
uint32_t mMultiplier = 0;
uint32_t mAdd = 0;
int32_t mShift = 0;
};
using FastModDivInt32 = cuda::fast_mod_div<int32_t>;

////////////////////////////////////////////////////////////////////////////////////////////////////
using Dtype = Data_type;
Expand Down
40 changes: 10 additions & 30 deletions include/flashinfer/trtllm/fmha/lse.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,43 +18,23 @@ limitations under the License.

#include <cuda.h>

#include <cmath>
#include <cub/device/device_transform.cuh>

#include "../../math.cuh"
#include "../../utils.cuh"

namespace flashinfer {

__global__ void ComputeLSEFromMDKernel(float2* __restrict__ md, float* __restrict__ lse, int n) {
int elem_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (elem_idx >= n) return;
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
float2 md_elem = md[elem_idx];
float m = md_elem.x;
float d = md_elem.y;
lse[elem_idx] = math::log2e * m + math::ptx_log2(d);
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}
// Unary functor that maps one float2 element to a single float:
//   math::log2e * md_elem.x + log2f(md_elem.y)
// NOTE(review): x and y presumably hold a running max and a softmax
// denominator (the "MD" in the name) — confirm against the producer of `md`.
struct MDToLSE {
// Callable from both host and device code (__host__ __device__).
// Takes the element by value and returns the combined log2-domain value.
__host__ __device__ float operator()(float2 md_elem) const {
return math::log2e * md_elem.x + log2f(md_elem.y);
}
};

inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool launch_with_pdl,
inline cudaError_t ComputeLSEFromMD(float2* md, float* lse, int n, bool /*launch_with_pdl*/,
Copy link
Copy Markdown
Member Author

@kahyunnam kahyunnam Apr 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for reviewers: launch_with_pdl is unused — DeviceTransform enables PDL unconditionally on SM90+ via its internal launcher. On pre-Hopper GPUs, the PDL instructions compile to no-ops. This means callers that pass false will still get PDL when the GPU supports it, which is probably harmless (PDL is a performance hint, not semantic).

cudaStream_t stream) {
int num_threads = std::min(1024, UpPowerOfTwo(n));
int num_blocks = ceil_div(n, num_threads);
cudaLaunchConfig_t config;
config.gridDim = num_blocks;
config.blockDim = num_threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = launch_with_pdl;
config.numAttrs = 1;
config.attrs = attrs;

FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, ComputeLSEFromMDKernel, md, lse, n));
return cudaSuccess;
return cub::DeviceTransform::Transform(md, lse, n, MDToLSE{}, stream);
}

}; // namespace flashinfer
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ exclude = ["flashinfer-jit-cache*", "flashinfer-cubin*"]
"flashinfer.data" = "."
"flashinfer.data.cutlass" = "3rdparty/cutlass"
"flashinfer.data.spdlog" = "3rdparty/spdlog"
"flashinfer.data.cccl" = "3rdparty/cccl"

[tool.setuptools.package-data]
"flashinfer" = ["_build_meta.py"]
"flashinfer.data" = ["csrc/**", "include/**"]
"flashinfer.data.cutlass" = ["include/**", "tools/util/include/**"]
"flashinfer.data.spdlog" = ["include/**"]
"flashinfer.data.cccl" = ["cub/cub/**", "libcudacxx/include/**", "thrust/thrust/**"]

[tool.mypy]
files = ["flashinfer"]
Expand Down
13 changes: 13 additions & 0 deletions scripts/modal_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,19 @@ def _run_flashinfer_command(command: str) -> str:
check=True,
)

if not os.path.exists("3rdparty/cccl/cub"):
print("=== Downloading CCCL ===")
subprocess.run(
[
"git",
"clone",
"--depth=1",
"https://github.com/NVIDIA/cccl.git",
"3rdparty/cccl",
],
check=True,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

# Run the user command
print(f"=== Running command: {command} ===")
import shlex
Expand Down
Loading