Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "attention-rs"
version = "0.4.1"
version = "0.4.3"
edition = "2021"
description = "High-performance LLM attention kernels and operations (PagedAttention, FlashInfer, Mamba, MoE, RoPE) for Candle, optimized for CUDA and Metal."
repository = "https://github.com/guoqingbao/attention.rs"
Expand All @@ -12,25 +12,25 @@ categories = ["algorithms", "hardware-support", "science"]
license = "MIT"

[dependencies]
candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" }
candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048" }
candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "157b048", optional = true }
candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038" }
candle-nn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038" }
#candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "5bed038", optional = true }
serde = { version = "1.0.190", features = ["serde_derive"] }
serde_json = "1.0.108"
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
tracing = "0.1.40"
parking_lot = "0.12.4"
rayon="1.10.0"
kernels = { path = "./src/kernels", version="0.4.1", optional = true}
kernels = { path = "./src/kernels", version="0.4.2", optional = true}
metal = { version = "0.27.0", features = ["mps"], optional = true }
metal-kernels = { path = "./src/metal-kernels", version="0.1.9", optional = true}
flashattn-rs = { git = "https://github.com/guoqingbao/flashattn.rs.git", version="0.1.0", rev = "a59e803", optional = true }

[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "dep:kernels"]
graph = ["cuda", "candle-core/graph"]
flash-attn = ["dep:candle-flash-attn"]
flash-decoding = ["dep:candle-flash-attn", "candle-flash-attn/flash-decoding", "kernels/no-fp8-kvcache"]
flash-context = ["dep:candle-flash-attn", "candle-flash-attn/flash-context", "kernels/no-fp8-kvcache"]
flashattn = ["dep:flashattn-rs", "flashattn-rs/flash-context", "kernels/no-fp8-kvcache"]
flash-decoding = ["dep:flashattn-rs", "flashattn-rs/flash-decoding", "kernels/no-fp8-kvcache"]
no-marlin = ["dep:kernels", "kernels/no-marlin"]
no-fp8-kvcache = ["dep:kernels", "kernels/no-fp8-kvcache"]
metal = ["candle-core/metal", "candle-nn/metal", "dep:metal-kernels", "dep:metal"]
Expand Down
2 changes: 1 addition & 1 deletion ReadMe.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ attention-rs = { git = "https://github.com/guoqingbao/attention.rs" }

- `cuda`: Enable CUDA kernels and optimizations.
- `metal`: Enable Metal kernels for Apple Silicon.
- `flash-attn`: Enable Flash Attention integration.
- `flashattn`: Enable Flash Attention integration.
- `flashinfer`: Enable FlashInfer integration.
- `cutlass`: Enable CUTLASS-optimized FP8 kernels (requires CUDA).

Expand Down
6 changes: 3 additions & 3 deletions src/flashinfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ thread_local! {
}

/// Reports whether `group_size` is in the hard-coded allow-list of GQA group sizes.
///
/// The prefill path calls this with `num_qo_heads / num_kv_heads` and bails out with an
/// error when it returns `false`.
///
/// NOTE(review): this diff view shows the pre-change and post-change `matches!` lines
/// back to back; the second one is the same list with `6` added.
fn is_supported_flashinfer_gqa_group_size(group_size: usize) -> bool {
matches!(group_size, 1 | 2 | 3 | 4 | 8 | 16 | 32 | 64)
matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64)
}

/// Reports whether `group_size` is in the hard-coded allow-list of group sizes for the
/// decode path (per the function name; the callers are not visible in this view).
///
/// NOTE(review): this diff view shows the pre-change and post-change `matches!` lines
/// back to back; the second one is the same list with `6` added.
fn is_supported_flashinfer_decode_group_size(group_size: usize) -> bool {
matches!(group_size, 1 | 2 | 3 | 4 | 8 | 16 | 32 | 64)
matches!(group_size, 1 | 2 | 3 | 4 | 6 | 8 | 16 | 32 | 64)
}

fn is_supported_flashinfer_decode_shape(group_size: usize, head_dim: usize) -> bool {
Expand Down Expand Up @@ -811,7 +811,7 @@ impl FlashInferPrefill {
let group_size = self.num_qo_heads / self.num_kv_heads;
if !is_supported_flashinfer_gqa_group_size(group_size) {
candle::bail!(
"flashinfer prefill only supports gqa group_size in [1,2,3,4,8,16,32,64], got {}",
"flashinfer prefill only supports gqa group_size in [1,2,3,4,6,8,16,32,64], got {}",
group_size
);
}
Expand Down
173 changes: 171 additions & 2 deletions src/fp8_linear.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,61 @@
#[cfg(all(feature = "cuda", feature = "cutlass"))]
#[cfg(feature = "cuda")]
use crate::cuda_utils;
#[cfg(feature = "cuda")]
use crate::kernels::ffi;
#[cfg(feature = "metal")]
use crate::metal_kernels;
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
use candle_core::cuda_backend::cudarc::driver::CudaSlice;
#[cfg(feature = "cuda")]
use candle_core::cuda_backend::cudarc::driver::DevicePtr;
#[cfg(feature = "cuda")]
use candle_core::cuda_backend::WrapErr;
use candle_core::{DType, Device, Result, Tensor};
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
use std::cell::RefCell;

/// Device scratch buffer cached for the FlashInfer FP8 blockscale GEMM, together with
/// the bookkeeping needed to decide whether it can be reused.
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
struct FlashInferFp8Workspace {
// Device allocation whose pointer is handed to the GEMM as its workspace.
buffer: CudaSlice<u8>,
// Number of `u8` elements requested when `buffer` was allocated.
size: usize,
// Ordinal of the CUDA device `buffer` was allocated on.
device_ordinal: usize,
}

#[cfg(all(feature = "cuda", feature = "cutlass"))]
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
thread_local! {
// Per-thread cache of the FlashInfer FP8 workspace. Starts out empty (`None`) and is
// filled or replaced by `get_or_init_flashinfer_fp8_workspace`.
static FLASHINFER_FP8_WORKSPACE: RefCell<Option<FlashInferFp8Workspace>> = const { RefCell::new(None) };
}

/// Returns the calling thread's cached FlashInfer FP8 workspace, (re)allocating on demand.
///
/// The cached buffer is reused when it was allocated on the same device ordinal as `dev`
/// and holds at least `required_size` bytes. Otherwise a new buffer of
/// `max(required_size, 1)` bytes is allocated on `dev` and replaces the cached one.
///
/// # Returns
/// The raw device pointer of the cached buffer and the buffer's recorded size. The
/// pointer refers to memory owned by the thread-local cache, so it stays valid only
/// until a later call on this thread replaces the cached buffer.
///
/// # Errors
/// Propagates the allocation error when a new buffer cannot be allocated; the
/// previously cached buffer (if any) is left in place in that case.
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
fn get_or_init_flashinfer_fp8_workspace(
    dev: &candle_core::cuda_backend::CudaDevice,
    required_size: usize,
) -> Result<(*mut std::ffi::c_void, usize)> {
    FLASHINFER_FP8_WORKSPACE.with(|cache| {
        let mut cached = cache.borrow_mut();
        let ordinal = dev.ordinal();

        // The cached buffer can be handed out again only if it lives on the same
        // device and is large enough for this request.
        let reusable = matches!(
            cached.as_ref(),
            Some(ws) if ws.device_ordinal == ordinal && ws.size >= required_size
        );

        if !reusable {
            // Never request a zero-length allocation.
            let alloc_size = std::cmp::max(required_size, 1);
            let buffer = unsafe { dev.alloc::<u8>(alloc_size) }.w()?;
            *cached = Some(FlashInferFp8Workspace {
                device_ordinal: ordinal,
                size: alloc_size,
                buffer,
            });
        }

        // The cache is guaranteed to be populated at this point: it was either
        // reusable or has just been replaced above.
        let workspace = cached.as_ref().unwrap();
        let raw_ptr = *workspace.buffer.device_ptr() as *mut std::ffi::c_void;
        Ok((raw_ptr, workspace.size))
    })
}

#[cfg(feature = "cuda")]
fn get_cuda_slice<
T: candle_core::cuda_backend::cudarc::driver::DeviceRepr + candle_core::cuda_backend::CudaDType,
>(
Expand Down Expand Up @@ -222,6 +269,128 @@ pub fn fp8_matmul(
Ok(output)
}

/// FP8 Matrix Multiplication using FlashInfer/TensorRT-LLM SM90 blockwise GEMM.
///
/// The bf16 `input` is first quantized on the device in groups of 128 along `K`
/// (`fp8_quantize_per_token_group_launch`), then multiplied with the already-quantized
/// `weight` through `flashinfer_fp8_blockscale_fp8`.
///
/// This path expects Hopper-native blockwise scales in `[N/128, K/128]` layout and
/// relies on the underlying runner's small-`M` swapAB optimization for decode.
///
/// # Arguments
/// * `input` - contiguous bf16 tensor of shape `[M, K]`.
/// * `weight` - contiguous row-major `u8` tensor of shape `[N, K]`.
/// * `weight_scale` - contiguous row-major f32 tensor of shape `[ceil(N / 128), K / 128]`.
///
/// # Returns
/// A bf16 tensor of shape `[M, N]` allocated on `input`'s device.
///
/// # Errors
/// Returns an error when the shapes, dtypes or contiguity requirements above are not
/// met, when `K` is not a multiple of 128 or `N` is not a multiple of 64, when the
/// device is not a CUDA device reporting an SM version in `90..100`, when a dimension
/// does not fit in the `i32` parameters of the kernels, or when the GEMM reports a
/// non-zero status.
#[cfg(all(feature = "cuda", feature = "flashinfer"))]
pub fn fp8_matmul_flashinfer(
    input: &Tensor,
    weight: &Tensor,
    weight_scale: &Tensor,
) -> Result<Tensor> {
    let (m, k) = input.dims2()?;
    let (n, k_w) = weight.dims2()?;

    if k != k_w {
        candle_core::bail!(
            "Shape mismatch in fp8_matmul_flashinfer: input [{}, {}], weight [{}, {}]",
            m,
            k,
            n,
            k_w
        );
    }

    if input.dtype() != DType::BF16 {
        candle_core::bail!("fp8_matmul_flashinfer requires bf16 input");
    }
    if weight.dtype() != DType::U8 || weight_scale.dtype() != DType::F32 {
        candle_core::bail!("fp8_matmul_flashinfer requires u8 weights and f32 scales");
    }
    if !input.is_contiguous() {
        candle_core::bail!("fp8_matmul_flashinfer requires contiguous input");
    }
    if !weight.is_contiguous() {
        candle_core::bail!("fp8_matmul_flashinfer requires contiguous row-major weight");
    }
    if !weight_scale.is_contiguous() {
        candle_core::bail!("fp8_matmul_flashinfer requires contiguous row-major weight_scale");
    }
    if k % 128 != 0 {
        candle_core::bail!("fp8_matmul_flashinfer requires K divisible by 128");
    }
    if n % 64 != 0 {
        candle_core::bail!("fp8_matmul_flashinfer requires N divisible by 64");
    }

    let expected_scale = ((n + 127) / 128, k / 128);
    if weight_scale.dims2()? != expected_scale {
        candle_core::bail!(
            "fp8_matmul_flashinfer expects weight_scale shape [{}, {}], got {:?}",
            expected_scale.0,
            expected_scale.1,
            weight_scale.dims()
        );
    }

    let dev = input.device();
    let cu_dev = dev.as_cuda_device()?;
    // A failed SM query is treated as version 0, which the range check below rejects.
    let sm_version = cuda_utils::sm_version(cu_dev).unwrap_or(0) as usize;
    if !(90..100).contains(&sm_version) {
        candle_core::bail!("fp8_matmul_flashinfer requires Hopper (sm90)");
    }

    // The kernels take their sizes as `i32`; reject anything that would silently wrap
    // in an `as i32` cast instead of launching with corrupted dimensions.
    let to_i32 = |value: usize, what: &str| -> Result<i32> {
        if value > i32::MAX as usize {
            candle_core::bail!(
                "fp8_matmul_flashinfer: {what} ({value}) does not fit in the kernel's i32 parameter"
            );
        }
        Ok(value as i32)
    };

    let k_over_128 = k / 128;
    let m_i32 = to_i32(m, "M")?;
    let n_i32 = to_i32(n, "N")?;
    let k_i32 = to_i32(k, "K")?;
    let groups_per_row = to_i32(k_over_128, "K / 128")?;
    // `m * k_over_128` cannot overflow `usize`: it is bounded by the element count of `input`.
    let num_groups = to_i32(m * k_over_128, "M * K / 128")?;

    let stream = *cu_dev.cu_stream() as i64;
    let m_padded = (m + 4 - 1) / 4 * 4;
    let out = Tensor::zeros((m, n), DType::BF16, dev)?;
    let input_q = Tensor::zeros((m, k), DType::U8, dev)?;
    // FlashInfer/DeepGEMM expects scales_a to use an M-aligned leading stride.
    // Their own tests allocate [K/128, M_padded] and treat only the first M columns as live.
    let input_scale = Tensor::zeros((k_over_128, m_padded), DType::F32, dev)?;
    let scale_stride = to_i32(input_scale.stride()[0], "input scale stride")?;
    let q_ptr = get_cuda_slice::<u8>(&input_q)? as *mut std::ffi::c_void;
    let s_ptr = get_cuda_slice::<f32>(&input_scale)? as *mut f32;
    let inp_ptr = get_cuda_slice::<half::bf16>(input)? as *const std::ffi::c_void;

    // SAFETY: `inp_ptr`, `q_ptr` and `s_ptr` come from tensors that are alive for the
    // whole function (`input`, `input_q`, `input_scale`), and the sizes passed match
    // the shapes those tensors were created/validated with above.
    // NOTE(review): the meaning of the two boolean flags is defined by the kernel's
    // FFI declaration, which is not visible here — confirm before changing them.
    unsafe {
        ffi::fp8_quantize_per_token_group_launch(
            inp_ptr,
            q_ptr,
            s_ptr,
            num_groups,
            128,
            groups_per_row,
            scale_stride,
            false,
            true,
            stream,
        );
    }

    // SAFETY: plain size query taking only scalar arguments.
    let required_ws =
        unsafe { ffi::flashinfer_fp8_blockscale_workspace_size_fp8(m_i32, n_i32, k_i32) };
    let (workspace_ptr, workspace_size) =
        get_or_init_flashinfer_fp8_workspace(cu_dev, required_ws)?;

    let weight_ptr = get_cuda_slice::<u8>(weight)? as *const std::ffi::c_void;
    let weight_scale_ptr = get_cuda_slice::<f32>(weight_scale)? as *const f32;
    let out_ptr = get_cuda_slice::<half::bf16>(&out)? as *mut std::ffi::c_void;

    // SAFETY: all pointers refer to tensors (or the thread-local workspace) that
    // outlive this call, and `m`/`n`/`k` are the dimensions validated above.
    let status = unsafe {
        ffi::flashinfer_fp8_blockscale_fp8(
            q_ptr as *const std::ffi::c_void,
            s_ptr as *const f32,
            weight_ptr,
            weight_scale_ptr,
            out_ptr,
            m_i32,
            n_i32,
            k_i32,
            workspace_ptr,
            workspace_size,
            stream,
        )
    };
    if status != 0 {
        candle_core::bail!("flashinfer fp8 blockscale gemm failed with status {status}");
    }

    Ok(out)
}

/// FP8 Matrix Multiplication using CUTLASS blockwise kernels (SM90+).
///
/// # Arguments
Expand Down
Loading