diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..10225b431d --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +.PHONY: fmt + +fmt: + cargo fmt + ruff format + find mistralrs-* -type f \( -name "*.metal" -o -name "*.c" -o -name "*.cu" -o -name "*.hpp" -o -name "*.h" -o -name "*.cpp" \) -exec clang-format -i {} + \ No newline at end of file diff --git a/examples/server/stream_completion_bench.py b/examples/server/stream_completion_bench.py index 7a51c6da13..56e0456f24 100644 --- a/examples/server/stream_completion_bench.py +++ b/examples/server/stream_completion_bench.py @@ -31,7 +31,7 @@ def run(): request(stream=False) finished = datetime.now() - print(f"Duration: {finished-now}") + print(f"Duration: {finished - now}") print("\nStreaming: ") print("=" * 15) @@ -42,7 +42,7 @@ def run(): pass finished = datetime.now() - print(f"Duration: {finished-now}") + print(f"Duration: {finished - now}") if __name__ == "__main__": diff --git a/mistralrs-bench/src/main.rs b/mistralrs-bench/src/main.rs index 2bf4e77f32..f427aa6ca4 100644 --- a/mistralrs-bench/src/main.rs +++ b/mistralrs-bench/src/main.rs @@ -370,7 +370,11 @@ fn main() -> anyhow::Result<()> { #[cfg(feature = "metal")] let device = Device::new_metal(0)?; #[cfg(not(feature = "metal"))] - let device = Device::cuda_if_available(0)?; + let device = if cfg!(feature = "nccl") { + Device::Cpu + } else { + Device::cuda_if_available(0)? + }; if let Some(seed) = args.seed { device.set_seed(seed)?; @@ -429,7 +433,7 @@ fn main() -> anyhow::Result<()> { DeviceMapSetting::Auto(auto_device_map_params) }; - let no_paged_attn = if device.is_cuda() { + let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") { args.no_paged_attn } else if device.is_metal() { !args.paged_attn diff --git a/mistralrs-core/build.rs b/mistralrs-core/build.rs index 1fae6e92ae..2e1d26fcf2 100644 --- a/mistralrs-core/build.rs +++ b/mistralrs-core/build.rs @@ -7,7 +7,7 @@ fn main() { use std::{path::PathBuf, vec}; println!("cargo:rerun-if-changed=build.rs"); let build_dir = PathBuf::from(std::env::var("OUT_DIR").unwrap()); - let lib_files = vec!["src/cuda/nonzero_bitwise.cu"]; + let lib_files = vec!["src/cuda/nonzero_bitwise.cu", "src/cuda/sort.cu"]; for lib_file in lib_files.iter() { println!("cargo:rerun-if-changed={lib_file}"); } @@ -23,7 +23,9 @@ fn main() { .arg("--expt-relaxed-constexpr") .arg("--expt-extended-lambda") .arg("--use_fast_math") - .arg("--verbose"); + .arg("--verbose") + .arg("--compiler-options") + .arg("-fPIC"); // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { diff --git a/mistralrs-core/src/cublaslt/mod.rs b/mistralrs-core/src/cublaslt/mod.rs index 60d2c81526..ba9d8f2d81 100644 --- a/mistralrs-core/src/cublaslt/mod.rs +++ b/mistralrs-core/src/cublaslt/mod.rs @@ -20,7 +20,7 @@ static mut CUBLASLT: Option = None; pub static CUBLASLT_HANDLE: Lazy>> = Lazy::new(|| Mutex::new(None)); -pub fn setup_cublas_lt_wrapper() { +pub fn setup_cublas_lt_wrapper(device: Device) { unsafe { INIT.call_once(|| { #[cfg(not(feature = "cuda"))] @@ -34,21 +34,17 @@ pub fn setup_cublas_lt_wrapper() { // Then check if we can create a device // Then check that the device is CUDA use candle_core::cuda_backend::cudarc::driver; - CUBLASLT = driver::result::init() - .ok() - .and_then(|_| Device::cuda_if_available(0).ok()) - .and_then(|device| match device { - Device::Cuda(_) => Some(CublasLtWrapper { - cublaslt: CublasLt::new(&device).unwrap(), - }), - _ => None, - }); - tracing::info!("Initialized 
cuBLASlt handle"); + CUBLASLT = match device { + Device::Cuda(_) => Some(CublasLtWrapper { + cublaslt: CublasLt::new(&device).unwrap(), + }), + _ => None, + } } + #[allow(static_mut_refs)] + let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); + *CUBLASLT_HANDLE.lock().unwrap() = cublaslt; }); - #[allow(static_mut_refs)] - let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); - *CUBLASLT_HANDLE.lock().unwrap() = cublaslt; } } diff --git a/mistralrs-core/src/cuda/ffi.rs b/mistralrs-core/src/cuda/ffi.rs index f945b0ad10..344ba7c6d6 100644 --- a/mistralrs-core/src/cuda/ffi.rs +++ b/mistralrs-core/src/cuda/ffi.rs @@ -1,16 +1,21 @@ use std::ffi::c_void; +#[cfg(feature = "cuda")] +type FfiCudaStream = candle_core::cuda::cudarc::driver::sys::CUstream; +#[cfg(not(feature = "cuda"))] +type FfiCudaStream = *const std::ffi::c_void; + #[allow(dead_code)] extern "C" { - pub(crate) fn count_nonzero_bf16(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_f16(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_f32(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_f64(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_u8(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32) -> u32; - pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32) -> u32; + pub(crate) fn count_nonzero_bf16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_f16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_f32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_f64(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_u8(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_u32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_i16(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_i64(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; + pub(crate) fn count_nonzero_i32(d_in: *const c_void, N: u32, stream: FfiCudaStream) -> u32; pub(crate) fn nonzero_bf16( d_in: *const c_void, N: u32, @@ -18,6 +23,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_f16( d_in: *const c_void, @@ -26,6 +32,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_f32( d_in: *const c_void, @@ -34,6 +41,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_f64( d_in: *const c_void, @@ -42,6 +50,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_u8( d_in: *const c_void, @@ -50,6 +59,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_u32( d_in: *const c_void, @@ -58,6 +68,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_i64( d_in: *const c_void, @@ -66,6 +77,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn 
nonzero_i16( d_in: *const c_void, @@ -74,6 +86,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn nonzero_i32( d_in: *const c_void, @@ -82,6 +95,7 @@ extern "C" { dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: FfiCudaStream, ); pub(crate) fn bitwise_and_u8( @@ -161,4 +175,117 @@ extern "C" { pub(crate) fn leftshift_u32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_i64(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); pub(crate) fn leftshift_i32(d_in1: *const c_void, d_out: *mut c_void, N: u32, k: i32); + + pub(crate) fn asort_asc_f32( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_f16( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_bf16( + x: *const c_void, + dst: *const c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_f64( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_u8( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_u32( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_asc_i64( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_f32( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_f16( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_bf16( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_f64( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_u8( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_u32( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); + pub(crate) fn asort_desc_i64( + x: *const c_void, + dst: *mut c_void, + nrows: i32, + ncols: i32, + inplace: bool, + stream: i64, + ); } diff --git a/mistralrs-core/src/cuda/nonzero_bitwise.cu b/mistralrs-core/src/cuda/nonzero_bitwise.cu index ced2f9ff82..ca35c56034 100644 --- a/mistralrs-core/src/cuda/nonzero_bitwise.cu +++ b/mistralrs-core/src/cuda/nonzero_bitwise.cu @@ -1,8 +1,22 @@ // Get inspiration from // https://github.com/pytorch/pytorch/blob/65aa16f968af2cd18ff8c25cc657e7abda594bfc/aten/src/ATen/native/cuda/Nonzero.cu +#include #include #include -#include +#include + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) int next_power_of_2(const uint32_t num_nonzero) { int result = 1; @@ -20,32 +34,36 @@ template struct NonZeroOp { // count the number of non-zero elements in an array, to better allocate memory template -void count_nonzero(const T *d_in, const uint32_t N, uint32_t *h_out) { +void count_nonzero(const T *d_in, const uint32_t N, uint32_t *h_out, + cudaStream_t stream) { cub::TransformInputIterator, const T *> itr( d_in, NonZeroOp()); size_t temp_storage_bytes = 0; - size_t *d_num_nonzero; - cudaMalloc((void **)&d_num_nonzero, sizeof(uint32_t)); - cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, d_num_nonzero, N); + uint32_t *d_num_nonzero; + CUDA_CHECK( + cudaMallocAsync((void **)&d_num_nonzero, sizeof(uint32_t), stream)); + CUDA_CHECK(cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, + d_num_nonzero, N, stream)); void **d_temp_storage; - cudaMalloc(&d_temp_storage, temp_storage_bytes); - cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, itr, d_num_nonzero, - N); - cudaMemcpy(h_out, d_num_nonzero, sizeof(uint32_t), cudaMemcpyDeviceToHost); - cudaFree(d_num_nonzero); - cudaFree(d_temp_storage); + CUDA_CHECK(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream)); + CUDA_CHECK(cub::DeviceReduce::Sum(d_temp_storage, temp_storage_bytes, itr, + d_num_nonzero, N, stream)); + CUDA_CHECK(cudaMemcpyAsync(h_out, d_num_nonzero, sizeof(uint32_t), + cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaFreeAsync(d_num_nonzero, stream)); + CUDA_CHECK(cudaFreeAsync(d_temp_storage, stream)); } #define COUNT_NONZERO_OP(TYPENAME, RUST_NAME) \ - extern "C" uint32_t count_nonzero_##RUST_NAME(const TYPENAME *d_in, \ - uint32_t N) { \ + extern "C" uint32_t count_nonzero_##RUST_NAME( \ + const TYPENAME *d_in, uint32_t N, cudaStream_t stream) { \ uint32_t result; \ - count_nonzero(d_in, N, &result); \ + count_nonzero(d_in, N, &result, stream); \ return result; \ } #define COUNT_NONZERO_OP_DUMMY(RUST_NAME) \ - extern "C" uint32_t count_nonzero_##RUST_NAME(const uint16_t *d_in, \ - uint32_t N) { \ + extern "C" uint32_t count_nonzero_##RUST_NAME( \ + const uint16_t *d_in, uint32_t N, cudaStream_t stream) { \ return 0; \ } @@ -90,45 +108,53 @@ __global__ void transform_indices(const uint32_t *temp_indices, // get the indices of non-zero elements in an array template void nonzero(const T *d_in, const uint32_t N, const uint32_t num_nonzero, - const uint32_t *dims, const uint32_t num_dims, uint32_t *d_out) { + const uint32_t *dims, const uint32_t num_dims, uint32_t *d_out, + cudaStream_t stream) { cub::TransformInputIterator, const T *> itr( d_in, NonZeroOp()); cub::CountingInputIterator counting_itr(0); uint32_t *out_temp; uint32_t *num_selected_out; - cudaMalloc((void **)&out_temp, num_nonzero * sizeof(uint32_t)); - cudaMalloc((void **)&num_selected_out, sizeof(uint32_t)); + CUDA_CHECK(cudaMallocAsync((void **)&out_temp, num_nonzero * sizeof(uint32_t), + stream)); + CUDA_CHECK( + cudaMallocAsync((void **)&num_selected_out, sizeof(uint32_t), stream)); size_t temp_storage_bytes = 0; - cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, - out_temp, num_selected_out, N); + CUDA_CHECK(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, + counting_itr, itr, out_temp, + num_selected_out, N, stream)); void **d_temp_storage; - cudaMalloc(&d_temp_storage, temp_storage_bytes); - cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, counting_itr, - itr, out_temp, num_selected_out, (int)N); + CUDA_CHECK(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, 
stream)); + CUDA_CHECK(cub::DeviceSelect::Flagged(d_temp_storage, temp_storage_bytes, + counting_itr, itr, out_temp, + num_selected_out, (int)N, stream)); int nthreads = next_power_of_2(num_nonzero); if (nthreads > 1024) { nthreads = 1024; } const int nblocks = (num_nonzero + nthreads - 1) / nthreads; - transform_indices<<>>(out_temp, num_nonzero, dims, - num_dims, d_out); - cudaDeviceSynchronize(); - cudaFree(out_temp); - cudaFree(d_temp_storage); - cudaFree(num_selected_out); + transform_indices<<>>(out_temp, num_nonzero, + dims, num_dims, d_out); + CUDA_CHECK(cudaGetLastError()); + + CUDA_CHECK(cudaFreeAsync(out_temp, stream)); + CUDA_CHECK(cudaFreeAsync(d_temp_storage, stream)); + CUDA_CHECK(cudaFreeAsync(num_selected_out, stream)); } #define NONZERO_OP(TYPENAME, RUST_NAME) \ - extern "C" void nonzero_##RUST_NAME( \ - const TYPENAME *d_in, uint32_t N, uint32_t num_nonzero, \ - const uint32_t *dims, uint32_t num_dims, uint32_t *d_out) { \ - nonzero(d_in, N, num_nonzero, dims, num_dims, d_out); \ + extern "C" void nonzero_##RUST_NAME(const TYPENAME *d_in, uint32_t N, \ + uint32_t num_nonzero, \ + const uint32_t *dims, uint32_t num_dims, \ + uint32_t *d_out, cudaStream_t stream) { \ + nonzero(d_in, N, num_nonzero, dims, num_dims, d_out, stream); \ } #define NONZERO_OP_DUMMY(RUST_NAME) \ - extern "C" void nonzero_##RUST_NAME( \ - const uint16_t *d_in, uint32_t N, uint32_t num_nonzero, \ - const uint32_t *dims, uint32_t num_dims, uint32_t *d_out) { \ + extern "C" void nonzero_##RUST_NAME(const uint16_t *d_in, uint32_t N, \ + uint32_t num_nonzero, \ + const uint32_t *dims, uint32_t num_dims, \ + uint32_t *d_out, cudaStream_t stream) { \ assert(false); \ } @@ -187,7 +213,7 @@ void bitwise_and(const T *d_in1, const T *d_in2, T *d_out, int N) { } const int nblocks = (N + nthreads - 1) / nthreads; bitwise_and__kernel<<>>(d_in1, d_in2, d_out, N); - cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); } template @@ -198,7 +224,7 @@ void bitwise_or(const T *d_in1, const T *d_in2, T *d_out, int N) { } const int nblocks = (N + nthreads - 1) / nthreads; bitwise_or__kernel<<>>(d_in1, d_in2, d_out, N); - cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); } template @@ -209,7 +235,7 @@ void bitwise_xor(const T *d_in1, const T *d_in2, T *d_out, int N) { } const int nblocks = (N + nthreads - 1) / nthreads; bitwise_xor__kernel<<>>(d_in1, d_in2, d_out, N); - cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); } #define BITWISE_OP(TYPENAME, RUST_NAME) \ @@ -235,8 +261,8 @@ BITWISE_OP(int64_t, i64) BITWISE_OP(int32_t, i32) template -__global__ void leftshift_kernel(const T *d_in1, T *d_out, - const uint32_t N, const int32_t k) { +__global__ void leftshift_kernel(const T *d_in1, T *d_out, const uint32_t N, + const int32_t k) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { d_out[idx] = d_in1[idx] << k; @@ -251,13 +277,13 @@ void leftshift(const T *d_in1, T *d_out, int N, const int32_t k) { } const int nblocks = (N + nthreads - 1) / nthreads; leftshift_kernel<<>>(d_in1, d_out, N, k); - cudaDeviceSynchronize(); + CUDA_CHECK(cudaGetLastError()); } -#define LEFTSHIFT_OP(TYPENAME, RUST_NAME) \ - extern "C" void leftshift_##RUST_NAME(const TYPENAME *d_in1, \ - TYPENAME *d_out, uint32_t N, int32_t k) { \ - leftshift(d_in1, d_out, N, k); \ +#define LEFTSHIFT_OP(TYPENAME, RUST_NAME) \ + extern "C" void leftshift_##RUST_NAME( \ + const TYPENAME *d_in1, TYPENAME *d_out, uint32_t N, int32_t k) { \ + leftshift(d_in1, d_out, N, k); \ } LEFTSHIFT_OP(uint8_t, u8) diff --git 
a/mistralrs-core/src/cuda/sort.cu b/mistralrs-core/src/cuda/sort.cu new file mode 100644 index 0000000000..59dcd288a8 --- /dev/null +++ b/mistralrs-core/src/cuda/sort.cu @@ -0,0 +1,132 @@ +#include "cuda_bf16.h" +#include "cuda_fp16.h" +#include +#include +template inline __device__ void swap(T &a, T &b) { + T tmp = a; + a = b; + b = tmp; +} + +template +__global__ void bitonic_sort_kernel(T *arr, uint32_t *dst, int j, int k) { + unsigned int i, ij; + i = threadIdx.x + blockDim.x * blockIdx.x; + ij = i ^ j; + + if (ij > i) { + if constexpr (ascending) { + if ((i & k) == 0) { + if (arr[i] > arr[ij]) { + swap(arr[i], arr[ij]); + swap(dst[i], dst[ij]); + } + } else { + if (arr[i] < arr[ij]) { + swap(arr[i], arr[ij]); + swap(dst[i], dst[ij]); + } + } + } + + if constexpr (!ascending) { + if ((i & k) != 0) { + if (arr[i] > arr[ij]) { + swap(arr[i], arr[ij]); + swap(dst[i], dst[ij]); + } + } else { + if (arr[i] < arr[ij]) { + swap(arr[i], arr[ij]); + swap(dst[i], dst[ij]); + } + } + } + } + __syncthreads(); +} + +int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +#define ASORT_OP(T, RUST_NAME, ASC) \ + extern "C" void RUST_NAME(void *x1, void *dst1, const int nrows, \ + const int ncols, bool inplace, int64_t stream) { \ + T *x = reinterpret_cast(x1); \ + uint32_t *dst = reinterpret_cast(dst1); \ + const cudaStream_t custream = (cudaStream_t)stream; \ + int ncols_pad = next_power_of_2(ncols); \ + T *x_row_padded; \ + uint32_t *dst_row_padded; \ + cudaMallocAsync((void **)&x_row_padded, ncols_pad * sizeof(T), custream); \ + cudaMallocAsync((void **)&dst_row_padded, ncols_pad * sizeof(uint32_t), \ + custream); \ + uint32_t *indices_padded = \ + (uint32_t *)malloc(ncols_pad * sizeof(uint32_t)); \ + for (int i = 0; i < ncols_pad; i++) { \ + indices_padded[i] = i; \ + } \ + T *values_padded = (T *)malloc((ncols_pad - ncols) * sizeof(T)); \ + for (int i = 0; i < ncols_pad - ncols; i++) { \ + values_padded[i] = \ + ASC ? std::numeric_limits::max() : std::numeric_limits::min(); \ + } \ + int max_threads_per_block = 1024; \ + int threads_per_block = \ + max_threads_per_block > ncols_pad ? 
ncols_pad : max_threads_per_block; \ + int blocks_per_row = \ + (ncols_pad + threads_per_block - 1) / threads_per_block; \ + for (int row = 0; row < nrows; row++) { \ + T *x_row = x + row * ncols; \ + uint32_t *dst_row = dst + row * ncols; \ + cudaMemcpyAsync(x_row_padded, x_row, ncols * sizeof(T), \ + cudaMemcpyDeviceToDevice, custream); \ + if (ncols_pad - ncols > 0) \ + cudaMemcpyAsync(x_row_padded + ncols, values_padded, \ + (ncols_pad - ncols) * sizeof(T), \ + cudaMemcpyHostToDevice, custream); \ + cudaMemcpyAsync(dst_row_padded, indices_padded, \ + ncols_pad * sizeof(uint32_t), cudaMemcpyHostToDevice, \ + custream); \ + for (int k = 2; k <= ncols_pad; k <<= 1) { \ + for (int j = k >> 1; j > 0; j = j >> 1) { \ + bitonic_sort_kernel \ + <<>>( \ + x_row_padded, dst_row_padded, j, k); \ + } \ + } \ + if (inplace) \ + cudaMemcpyAsync(x_row, x_row_padded, ncols * sizeof(T), \ + cudaMemcpyDeviceToDevice, custream); \ + cudaMemcpyAsync(dst_row, dst_row_padded, ncols * sizeof(uint32_t), \ + cudaMemcpyDeviceToDevice, custream); \ + } \ + cudaFreeAsync(x_row_padded, custream); \ + cudaFreeAsync(dst_row_padded, custream); \ + cudaStreamSynchronize(custream); \ + free(indices_padded); \ + free(values_padded); \ + } + +ASORT_OP(__nv_bfloat16, asort_asc_bf16, true) +ASORT_OP(__nv_bfloat16, asort_desc_bf16, false) + +ASORT_OP(__half, asort_asc_f16, true) +ASORT_OP(__half, asort_desc_f16, false) + +ASORT_OP(float, asort_asc_f32, true) +ASORT_OP(double, asort_asc_f64, true) +ASORT_OP(uint8_t, asort_asc_u8, true) +ASORT_OP(uint32_t, asort_asc_u32, true) +ASORT_OP(int64_t, asort_asc_i64, true) + +ASORT_OP(float, asort_desc_f32, false) +ASORT_OP(double, asort_desc_f64, false) +ASORT_OP(uint8_t, asort_desc_u8, false) +ASORT_OP(uint32_t, asort_desc_u32, false) +ASORT_OP(int64_t, asort_desc_i64, false) diff --git a/mistralrs-core/src/daemon.rs b/mistralrs-core/src/daemon.rs index ffd33f39b0..54c1491a81 100644 --- a/mistralrs-core/src/daemon.rs +++ b/mistralrs-core/src/daemon.rs @@ -5,7 +5,7 @@ use serde_big_array::BigArray; pub(crate) const IS_DAEMON_FLAG: &str = "__MISTRALRS_DAEMON_INTERNAL"; -pub(crate) fn is_daemon() -> bool { +pub fn is_daemon() -> bool { std::env::var(IS_DAEMON_FLAG).is_ok() } diff --git a/mistralrs-core/src/device_map.rs b/mistralrs-core/src/device_map.rs index 44ba8a8f24..f2168395a5 100644 --- a/mistralrs-core/src/device_map.rs +++ b/mistralrs-core/src/device_map.rs @@ -155,10 +155,22 @@ impl DeviceMapSetting { for DeviceLayerMapMetadata { ordinal, layers } in device_layers.as_ref().unwrap() { - let dev = match device { - Device::Cpu => Device::Cpu, - Device::Cuda(_) => Device::cuda_if_available(*ordinal)?, - Device::Metal(_) => Device::new_metal(*ordinal)?, + let dev = match device.location() { + DeviceLocation::Cpu => Device::Cpu, + DeviceLocation::Cuda { gpu_id: device_ord } => { + if device_ord == *ordinal { + device.clone() + } else { + Device::new_cuda_with_stream(*ordinal)? + } + } + DeviceLocation::Metal { gpu_id: device_ord } => { + if device_ord == *ordinal { + device.clone() + } else { + Device::new_metal(*ordinal)? 
+ } + } }; if !device.is_cpu() { dev.set_seed(original_seed.unwrap())?; diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index e3aef792eb..d89e53324e 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -55,7 +55,7 @@ mod paged_attention; #[cfg(not(any(all(feature = "cuda", target_family = "unix"), feature = "metal")))] use dummy_paged_attention as paged_attention; mod attention; -pub(crate) mod daemon; +pub mod daemon; mod diffusion_models; mod pipeline; mod prefix_cacher; @@ -250,13 +250,12 @@ impl MistralRsBuilder { } #[cfg(feature = "cuda")] -fn set_gemm_reduced_precision_f16() { +fn set_gemm_reduced_precision_f16(device: candle_core::Device) { use mistralrs_quant::INHIBIT_GEMM_F16; - use candle_core::{DType, Device, Tensor}; + use candle_core::{DType, Tensor}; - // NOTE(EricLBuehler): When we support multi-GPU inference, we should check for each gpu here - let a = Tensor::zeros((2, 2), DType::BF16, &Device::new_cuda(0).unwrap()).unwrap(); + let a = Tensor::zeros((2, 2), DType::BF16, &device).unwrap(); candle_core::cuda::set_gemm_reduced_precision_bf16(true); match a.matmul(&a) { Ok(_) => tracing::info!("Enabling GEMM reduced precision in BF16."), @@ -269,7 +268,7 @@ fn set_gemm_reduced_precision_f16() { } } - let a = Tensor::zeros((2, 2), DType::F16, &Device::new_cuda(0).unwrap()).unwrap(); + let a = Tensor::zeros((2, 2), DType::F16, &device).unwrap(); candle_core::cuda::set_gemm_reduced_precision_f16(true); match a.matmul(&a) { Ok(_) => tracing::info!("Enabling GEMM reduced precision in F16."), @@ -284,7 +283,7 @@ fn set_gemm_reduced_precision_f16() { } #[cfg(not(feature = "cuda"))] -fn set_gemm_reduced_precision_f16() {} +fn set_gemm_reduced_precision_f16(_device: candle_core::Device) {} impl Drop for MistralRs { fn drop(&mut self) { @@ -317,9 +316,9 @@ impl MistralRs { ModelCategory::Diffusion => true, }; if !gemm_full_precision_f16.unwrap_or(false) && model_supports_reduced_gemm { - set_gemm_reduced_precision_f16(); + set_gemm_reduced_precision_f16(get_mut_arcmutex!(pipeline).device()); } - setup_cublas_lt_wrapper(); + setup_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device()); let truncate_sequence = truncate_sequence.unwrap_or(false); let no_kv_cache = no_kv_cache.unwrap_or(false); @@ -399,7 +398,7 @@ impl MistralRs { request_sender.send(req).await.unwrap(); let resp = receiver.recv().await.unwrap(); - assert!(resp.is_ok()); + resp.unwrap(); continue; } Request::Tokenize(mut x) => { @@ -409,7 +408,7 @@ impl MistralRs { request_sender.send(req).await.unwrap(); let resp = receiver.recv().await.unwrap(); - assert!(resp.is_ok()); + resp.unwrap(); continue; } Request::Normal(mut x) => { @@ -420,7 +419,7 @@ impl MistralRs { request_sender.send(req).await.unwrap(); let resp = receiver.recv().await.unwrap(); - assert!(resp.as_result().is_ok()); + resp.as_result().unwrap(); continue; } Request::TerminateAllSeqsNextStep => { diff --git a/mistralrs-core/src/models/deepseek2.rs b/mistralrs-core/src/models/deepseek2.rs index fae471686d..888472b2bd 100644 --- a/mistralrs-core/src/models/deepseek2.rs +++ b/mistralrs-core/src/models/deepseek2.rs @@ -301,22 +301,12 @@ impl Attention { (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?; - let mut q = Tensor::zeros( - (bs, self.num_attention_heads, seq_len, self.q_head_dim), - q_pe.dtype(), - q_pe.device(), - )?; - q = q.slice_assign(&[&.., &.., &.., &(..self.cfg.qk_nope_head_dim)], &q_nope)?; - q = q.slice_assign(&[&.., &.., &.., &(self.cfg.qk_nope_head_dim..)], &q_pe)?; - - let mut k = 
Tensor::zeros( - (bs, self.num_attention_heads, seq_len, self.q_head_dim), - k_pe.dtype(), - k_pe.device(), - )?; - k = k.slice_assign(&[&.., &.., &.., &(..self.cfg.qk_nope_head_dim)], &k_nope)?; - let k_pe = k_pe.repeat((1, k.dim(1)?, 1, 1))?; - k = k.slice_assign(&[&.., &.., &.., &(self.cfg.qk_nope_head_dim..)], &k_pe)?; + let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?; + let mut k = Tensor::cat( + &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?], + D::Minus1, + )? + .contiguous()?; let mut attn_out = match &self.paged_attn { Some(paged_attn) => match metadata { diff --git a/mistralrs-core/src/models/deepseek3.rs b/mistralrs-core/src/models/deepseek3.rs index 73acb3edb3..462ff1a047 100644 --- a/mistralrs-core/src/models/deepseek3.rs +++ b/mistralrs-core/src/models/deepseek3.rs @@ -34,7 +34,6 @@ serde_default_fn!(f64, routed_scaling_factor, 1.0); serde_default_fn!(TopkMethod, topk_method, TopkMethod::Greedy); serde_default_fn!(usize, moe_layer_freq, 1); serde_default_fn!(usize, first_k_dense_replace, 0); -serde_default_fn!(bool, norm_topk_prob, false); serde_default_fn!(ScoringFunc, scoring_func, ScoringFunc::Softmax); serde_default_fn!(Activation, hidden_act, Activation::Silu); serde_default_fn!(bool, tie_word_embeddings, false); @@ -77,9 +76,6 @@ pub struct DeepSeekV3Config { pub(crate) moe_layer_freq: usize, #[serde(default = "first_k_dense_replace")] pub(crate) first_k_dense_replace: usize, - // k dense layers - #[serde(default = "norm_topk_prob")] - pub(crate) norm_topk_prob: bool, #[serde(default = "scoring_func")] scoring_func: ScoringFunc, #[serde(default = "hidden_act")] @@ -305,22 +301,12 @@ impl Attention { (q_pe, k_pe) = self.rotary_emb.forward(&q_pe, &k_pe, seqlen_offsets)?; - let mut q = Tensor::zeros( - (bs, self.num_attention_heads, seq_len, self.q_head_dim), - q_pe.dtype(), - q_pe.device(), - )?; - q = q.slice_assign(&[&.., &.., &.., &(..self.cfg.qk_nope_head_dim)], &q_nope)?; - q = q.slice_assign(&[&.., &.., &.., &(self.cfg.qk_nope_head_dim..)], &q_pe)?; - - let mut k = Tensor::zeros( - (bs, self.num_attention_heads, seq_len, self.q_head_dim), - k_pe.dtype(), - k_pe.device(), - )?; - k = k.slice_assign(&[&.., &.., &.., &(..self.cfg.qk_nope_head_dim)], &k_nope)?; - let k_pe = k_pe.repeat((1, k.dim(1)?, 1, 1))?; - k = k.slice_assign(&[&.., &.., &.., &(self.cfg.qk_nope_head_dim..)], &k_pe)?; + let q = Tensor::cat(&[&q_nope, &q_pe], D::Minus1)?.contiguous()?; + let mut k = Tensor::cat( + &[&k_nope, &k_pe.repeat((1, self.num_attention_heads, 1, 1))?], + D::Minus1, + )? + .contiguous()?; let mut attn_out = match &self.paged_attn { Some(paged_attn) => match metadata { @@ -572,7 +558,7 @@ impl MoeGate { } }; - if self.top_k > 1 && self.cfg.norm_topk_prob { + if matches!(self.cfg.scoring_func, ScoringFunc::Sigmoid) { let denmoninator = (topk_weight.sum_keepdim(D::Minus1)? 
+ 1e-20)?; topk_weight = topk_weight.broadcast_div(&denmoninator)?; } diff --git a/mistralrs-core/src/ops.rs b/mistralrs-core/src/ops.rs index c9ddfa6cac..7562b5efb5 100644 --- a/mistralrs-core/src/ops.rs +++ b/mistralrs-core/src/ops.rs @@ -389,23 +389,29 @@ impl NonZero { } #[cfg(feature = "cuda")] -fn count_nonzero_cuda(dtype: candle_core::DType, d_in: *const c_void, n: u32) -> u32 { +fn count_nonzero_cuda( + dtype: candle_core::DType, + d_in: *const c_void, + n: u32, + stream: candle_core::cuda::cudarc::driver::sys::CUstream, +) -> u32 { unsafe { match dtype { - candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n), - candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n), - candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n), - candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n), - candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n), - candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n), - candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n), - candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n), - candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n), + candle_core::DType::U8 => ffi::count_nonzero_u8(d_in, n, stream), + candle_core::DType::U32 => ffi::count_nonzero_u32(d_in, n, stream), + candle_core::DType::I64 => ffi::count_nonzero_i64(d_in, n, stream), + candle_core::DType::I16 => ffi::count_nonzero_i16(d_in, n, stream), + candle_core::DType::I32 => ffi::count_nonzero_i32(d_in, n, stream), + candle_core::DType::BF16 => ffi::count_nonzero_bf16(d_in, n, stream), + candle_core::DType::F16 => ffi::count_nonzero_f16(d_in, n, stream), + candle_core::DType::F32 => ffi::count_nonzero_f32(d_in, n, stream), + candle_core::DType::F64 => ffi::count_nonzero_f64(d_in, n, stream), candle_core::DType::F8E4M3 => todo!(), } } } +#[allow(clippy::too_many_arguments)] #[cfg(feature = "cuda")] fn nonzero_cuda( dtype: candle_core::DType, @@ -415,33 +421,36 @@ fn nonzero_cuda( dims: *const c_void, num_dims: u32, d_out: *mut c_void, + stream: candle_core::cuda::cudarc::driver::sys::CUstream, ) { unsafe { match dtype { - candle_core::DType::U8 => ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out), + candle_core::DType::U8 => { + ffi::nonzero_u8(d_in, n, num_nonzero, dims, num_dims, d_out, stream) + } candle_core::DType::U32 => { - ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_u32(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::I64 => { - ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::I32 => { - ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_i64(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::I16 => { - ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_i16(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::BF16 => { - ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_bf16(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::F16 => { - ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_f16(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::F32 => { - ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out) + ffi::nonzero_f32(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::F64 => { - ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out) + 
ffi::nonzero_f64(d_in, n, num_nonzero, dims, num_dims, d_out, stream) } candle_core::DType::F8E4M3 => todo!(), } @@ -498,7 +507,9 @@ impl CustomOp1 for NonZero { candle_core::DType::F8E4M3 => todo!(), } as *const c_void; let n = layout.shape().elem_count(); - let num_nonzero = count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?); + + let num_nonzero = + count_nonzero_cuda(storage.dtype(), d_in, u32::try_from(n)?, *dev.cu_stream()); let d_out = unsafe { dev.alloc::(num_nonzero as usize * layout.dims().len()) } .map_err(|_| Error::Msg("Failed to allocate memory for nonzero result".to_string()))?; let d_out_ptr = *d_out.device_ptr() as *mut c_void; @@ -519,6 +530,7 @@ impl CustomOp1 for NonZero { d_dims_ptr, u32::try_from(layout.dims().len())?, d_out_ptr, + *dev.cu_stream(), ); let shape = Shape::from_dims(&[num_nonzero as usize, layout.dims().len()]); let dst = candle_core::CudaStorage::wrap_cuda_slice(d_out, dev); @@ -550,6 +562,174 @@ impl NonZeroOp for Tensor { } } +#[allow(dead_code)] +#[derive(Debug, Clone)] +struct ArgSort { + asc: bool, + last_dim: usize, + inplace: bool, +} + +impl candle_core::CustomOp1 for ArgSort { + fn name(&self) -> &'static str { + "argsort" + } + + fn cpu_fwd( + &self, + _: &candle_core::CpuStorage, + _: &candle_core::Layout, + ) -> Result<(candle_core::CpuStorage, candle_core::Shape)> { + panic!("not implemented!") + } + + #[allow(clippy::cast_possible_truncation)] + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &candle_core::CudaStorage, + layout: &candle_core::Layout, + ) -> Result<(candle_core::CudaStorage, candle_core::Shape)> { + use candle_core::backend::BackendStorage; + use candle_core::cuda_backend::cudarc::driver::DevicePtr; + use candle_core::cuda_backend::CudaStorageSlice; + use candle_core::cuda_backend::WrapErr; + let dev = storage.device(); + let elem_count = layout.shape().elem_count(); + let ncols = self.last_dim as i32; + let nrows = elem_count as i32 / ncols; + let dst = unsafe { dev.alloc::(elem_count) }.w()?; + + use std::ffi::c_void; + + let src = match &storage.slice { + CudaStorageSlice::U8(inp) => inp.device_ptr(), + CudaStorageSlice::U32(inp) => inp.device_ptr(), + CudaStorageSlice::I64(inp) => inp.device_ptr(), + CudaStorageSlice::BF16(inp) => inp.device_ptr(), + CudaStorageSlice::F16(inp) => inp.device_ptr(), + CudaStorageSlice::F32(inp) => inp.device_ptr(), + CudaStorageSlice::F64(inp) => inp.device_ptr(), + _ => candle_core::bail!("Unexpected dtype in asort"), + }; + let src_ptr = *src as *const c_void; + let dst_ptr = *dst.device_ptr() as *mut c_void; + let stream = *dev.cu_stream() as i64; + unsafe { + if self.asc { + match storage.dtype() { + candle_core::DType::U8 => { + ffi::asort_asc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::U32 => { + ffi::asort_asc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::I64 => { + ffi::asort_asc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::BF16 => { + ffi::asort_asc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F16 => { + ffi::asort_asc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F32 => { + ffi::asort_asc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F64 => { + ffi::asort_asc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + _ => candle_core::bail!("Unexpected dtype in asort"), + } + } else { + match storage.dtype() { + 
candle_core::DType::U8 => { + ffi::asort_desc_u8(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::U32 => { + ffi::asort_desc_u32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::I64 => { + ffi::asort_desc_i64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::BF16 => { + ffi::asort_desc_bf16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F16 => { + ffi::asort_desc_f16(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F32 => { + ffi::asort_desc_f32(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + candle_core::DType::F64 => { + ffi::asort_desc_f64(src_ptr, dst_ptr, nrows, ncols, self.inplace, stream) + } + _ => candle_core::bail!("Unexpected dtype in asort"), + } + } + } + let dst_ret = candle_core::cuda_backend::CudaStorage { + slice: CudaStorageSlice::U32(dst), + device: dev.clone(), + }; + Ok((dst_ret, layout.shape().clone())) + } +} + +#[allow(dead_code)] +pub trait ArgSortOp { + fn arg_sort(&self, asc: bool) -> Result; + fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)>; +} + +impl ArgSortOp for Tensor { + /// Returns the indices that sort the tensor along the last dimension. + /// + /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in + /// descending order. The sort is unstable so there is no guarantees on the final order when it + /// comes to ties. + fn arg_sort(&self, asc: bool) -> Result { + if !self.is_contiguous() { + return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" }); + } + let last_dim = match self.dims().last() { + Some(last_dim) => *last_dim, + None => candle_core::bail!("empty last-dim in arg-sort"), + }; + // No need for a backward pass for arg sort. + self.apply_op1_no_bwd(&ArgSort { + asc, + last_dim, + inplace: false, + }) + } + + /// Sorts the tensor along the last dimension, returns the sorted tensor together with the + /// sorted indexes. + /// + /// If `asc` is `true`, sorting is in ascending order. Otherwise sorting is performed in + /// descending order. The sort is unstable so there is no guarantees on the final order when it + /// comes to ties. 
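
// Editorial illustration (not part of this patch): a minimal usage sketch of the new
// `ArgSortOp` / `TopKLastDimOp` surface added in ops.rs. `xs` is a hypothetical Tensor.
// `sort` and `arg_sort` dispatch to the bitonic-sort kernels in sort.cu, so they assume a
// contiguous, CUDA-backed tensor (the CPU path of `ArgSort` is unimplemented); `topk`
// falls back to `sort_last_dim` on non-CUDA builds.
//
//     use crate::ops::{ArgSortOp, TopKLastDimOp};
//     let (sorted, indices) = xs.sort(false)?; // descending values + their original positions
//     let order = xs.arg_sort(true)?;          // ascending argsort indices over the last dim
//     let top8 = xs.topk(8)?;                  // TopKOutput { values, indices }, sorted descending
//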
+ fn sort(&self, asc: bool) -> Result<(Tensor, Tensor)> { + if !self.is_contiguous() { + return Err(candle_core::Error::RequiresContiguous { op: "arg_sort" }); + } + let last_dim = match self.dims().last() { + Some(last_dim) => *last_dim, + None => candle_core::bail!("empty last-dim in arg-sort"), + }; + let sorted = self.copy()?; + + let asort = sorted.apply_op1_no_bwd(&ArgSort { + asc, + last_dim, + inplace: true, + })?; + + Ok((sorted, asort)) + } +} + #[allow(dead_code)] pub struct TopKOutput { pub values: Tensor, @@ -571,29 +751,28 @@ pub trait TopKLastDimOp { impl TopKLastDimOp for Tensor { fn topk(&self, topk: usize) -> Result { // Sorted descending - let sorted_indices = self.arg_sort_last_dim(false)?; + #[cfg(feature = "cuda")] + let (values, sorted_indices) = self.sort(false)?; + #[cfg(not(feature = "cuda"))] + let (values, sorted_indices) = self.sort_last_dim(false)?; let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?; + let topk_values = values.narrow(D::Minus1, 0, topk)?.contiguous()?; Ok(TopKOutput { - values: self.gather(&topk_indices, D::Minus1)?, + values: topk_values, indices: topk_indices, }) } fn topk_unsorted(&self, topk: usize) -> Result { // Sorted descending - let sorted_indices_all = self.arg_sort_last_dim(false)?; - let topk_indices_sorted = sorted_indices_all - .narrow(D::Minus1, 0, topk)? - .contiguous()?; - let topk_values_sorted = self.gather(&topk_indices_sorted, D::Minus1)?; - + let TopKOutput { values, indices } = self.topk(topk)?; // Reorder the indices ascending - let reorder_indices = topk_indices_sorted.arg_sort_last_dim(true)?; - let topk_indices_unsorted = topk_indices_sorted - .to_dtype(DType::F32)? - .gather(&reorder_indices, D::Minus1)? - .to_dtype(DType::U32)?; - let topk_values_unsorted = topk_values_sorted.gather(&reorder_indices, D::Minus1)?; + #[cfg(feature = "cuda")] + let reorder_indices = indices.arg_sort(true)?; + #[cfg(not(feature = "cuda"))] + let reorder_indices = indices.arg_sort_last_dim(true)?; + let topk_indices_unsorted = indices.gather(&reorder_indices, D::Minus1)?; + let topk_values_unsorted = values.gather(&reorder_indices, D::Minus1)?; Ok(TopKOutput { values: topk_values_unsorted, indices: topk_indices_unsorted, diff --git a/mistralrs-core/src/paged_attention/mod.rs b/mistralrs-core/src/paged_attention/mod.rs index 63bb78c84e..e1c4dbf482 100644 --- a/mistralrs-core/src/paged_attention/mod.rs +++ b/mistralrs-core/src/paged_attention/mod.rs @@ -128,10 +128,11 @@ pub fn calculate_cache_config( min_mem_gpu = min_mem_gpu.min(mem_gpu); } - // Cap at kv cache for max seq len - let mem_for_toks = - ctxt_to_blocks!(config.max_seq_len(), dtype_size, block_size, config) / SIZE_IN_MB; - let mem_gpu = min_mem_gpu.min(mem_for_toks); + // // Cap at kv cache for max seq len + // let mem_for_toks = + // ctxt_to_blocks!(config.max_seq_len(), dtype_size, block_size, config) / SIZE_IN_MB; + // let mem_gpu = min_mem_gpu.min(mem_for_toks); + let mem_gpu = min_mem_gpu; let num_gpu_blocks = mb_to_blocks!(mem_gpu * SIZE_IN_MB, dtype_size, block_size, config); let num_cpu_blocks = mb_to_blocks!(mem_cpu * SIZE_IN_MB, dtype_size, block_size, config); diff --git a/mistralrs-core/src/pipeline/normal.rs b/mistralrs-core/src/pipeline/normal.rs index 281e4e0783..7c6435761d 100644 --- a/mistralrs-core/src/pipeline/normal.rs +++ b/mistralrs-core/src/pipeline/normal.rs @@ -302,12 +302,17 @@ impl Loader for NormalLoader { let available_devices = if let Ok(payload) = env::var(daemon::IS_DAEMON_FLAG) { let payload: WorkerTransferData = 
serde_json::from_str(&payload)?; let WorkerTransferData::Init { id: _, worker_rank } = payload; - vec![candle_core::Device::new_cuda(worker_rank + 1)?] + vec![candle_core::Device::new_cuda_with_stream(worker_rank + 1)?] } else if use_nccl { - vec![candle_core::Device::new_cuda(0)?] + vec![candle_core::Device::new_cuda_with_stream(0)?] } else { device_map::get_all_similar_devices(device)? }; + let device = if use_nccl { + available_devices[0].clone() + } else { + device.clone() + }; // If auto, convert to Map if not using nccl if use_nccl { @@ -407,12 +412,12 @@ impl Loader for NormalLoader { let pipeline_mapper = mapper.into_mapper( self.inner.get_total_device_mapping_num_layers(&config)?, - device, + &device, self.config.topology.as_ref(), )?; let mapper = mapper.into_mapper( self.inner.get_total_device_mapping_num_layers(&config)?, - device, + &device, self.config.topology.as_ref(), )?; let mut layer_devices = Vec::new(); @@ -469,6 +474,7 @@ impl Loader for NormalLoader { let multi_progress = Arc::new(MultiProgress::new()); let mut model = if use_nccl { + let device = available_devices[0].clone(); #[cfg(not(feature = "nccl"))] warn!( "NCCL support was included in the build, be sure to build with `--features nccl`." @@ -495,10 +501,9 @@ impl Loader for NormalLoader { info!("Local tensor parallel world size is {local_world_size}"); info!("Global tensor parallel world size is {global_world_size}"); - let mut id = mistralrs_quant::Id::new(); - // TP uses parallel pipelines. let name = daemon::ipc_name()?; + let mut id; let local_rank = if let Ok(payload) = env::var(daemon::IS_DAEMON_FLAG) { let payload: WorkerTransferData = serde_json::from_str(&payload)?; let WorkerTransferData::Init { @@ -511,6 +516,7 @@ impl Loader for NormalLoader { stream.write_all(b"ready\n")?; worker_rank + 1 } else { + id = mistralrs_quant::Id::new(); let num_workers = mistralrs_quant::distributed::get_global_tp_size_from_devices()? 
- 1; let mut children = Vec::new(); @@ -596,7 +602,7 @@ impl Loader for NormalLoader { // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html?ncclcomminitrank#ncclcomminitrank let comm = mistralrs_quant::Comm::from_device( id, - device, + &device, local_rank + rank_offset, global_world_size, )?; @@ -624,7 +630,6 @@ impl Loader for NormalLoader { }; info!("Loading all ranks."); - let device = available_devices[0].clone(); // The mapper is specific to this pipeline let mapper = DeviceMapSetting::Nccl { nm_device: available_devices[0].clone(), @@ -900,7 +905,7 @@ impl Loader for NormalLoader { paged_attn_config.block_size, dtype, model.config(), - device, + &device, &pipeline_mapper .get_unique_devices() .into_iter() diff --git a/mistralrs-paged-attn/build.rs b/mistralrs-paged-attn/build.rs index 78d43eacf9..a7222c138d 100644 --- a/mistralrs-paged-attn/build.rs +++ b/mistralrs-paged-attn/build.rs @@ -26,7 +26,20 @@ pub use backend::{copy_blocks, paged_attention, reshape_and_cache, swap_blocks}; println!("cargo:rerun-if-changed=src/cuda/pagedattention.cu"); println!("cargo:rerun-if-changed=src/cuda/copy_blocks_kernel.cu"); println!("cargo:rerun-if-changed=src/cuda/reshape_and_cache_kernel.cu"); - let mut builder = bindgen_cuda::Builder::default(); + let mut builder = bindgen_cuda::Builder::default() + .arg("-std=c++17") + .arg("-O3") + .arg("-U__CUDA_NO_HALF_OPERATORS__") + .arg("-U__CUDA_NO_HALF_CONVERSIONS__") + .arg("-U__CUDA_NO_HALF2_OPERATORS__") + .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__") + .arg("--expt-relaxed-constexpr") + .arg("--expt-extended-lambda") + .arg("--use_fast_math") + .arg("--verbose") + .arg("--compiler-options") + .arg("-fPIC"); + // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { builder = builder.arg("--compiler-options"); diff --git a/mistralrs-paged-attn/src/cuda/attention/attention_dtypes.h b/mistralrs-paged-attn/src/cuda/attention/attention_dtypes.h index 88b4eddec7..7ae40cc8f1 100644 --- a/mistralrs-paged-attn/src/cuda/attention/attention_dtypes.h +++ b/mistralrs-paged-attn/src/cuda/attention/attention_dtypes.h @@ -1,6 +1,6 @@ #pragma once #include "attention_generic.cuh" +#include "dtype_bfloat16.cuh" #include "dtype_float16.cuh" #include "dtype_float32.cuh" -#include "dtype_bfloat16.cuh" diff --git a/mistralrs-paged-attn/src/cuda/backend/paged_attention.rs b/mistralrs-paged-attn/src/cuda/backend/paged_attention.rs index 486d506ea5..c92c015f89 100644 --- a/mistralrs-paged-attn/src/cuda/backend/paged_attention.rs +++ b/mistralrs-paged-attn/src/cuda/backend/paged_attention.rs @@ -203,6 +203,7 @@ impl PagedAttention { q_stride as c_int, kv_block_stride as c_int, kv_head_stride as c_int, + *dev.cu_stream(), internal_type, ) } @@ -241,6 +242,7 @@ impl PagedAttention { q_stride as c_int, kv_block_stride as c_int, kv_head_stride as c_int, + *dev.cu_stream(), internal_type, ) } @@ -393,6 +395,8 @@ fn update_cache< let vc = vc.as_cuda_slice::()?; let s = s.as_cuda_slice::()?; + let dev = k.device(); + // Get cuda views for all tensors let k = k.slice(k_l.start_offset()..); let v = v.slice(v_l.start_offset()..); @@ -453,6 +457,7 @@ fn update_cache< x as c_int, key_stride, value_stride, + *dev.cu_stream(), internal_type, ) } diff --git a/mistralrs-paged-attn/src/cuda/copy_blocks_kernel.cu b/mistralrs-paged-attn/src/cuda/copy_blocks_kernel.cu index 7179b33ac4..5ba0b207fe 100644 --- a/mistralrs-paged-attn/src/cuda/copy_blocks_kernel.cu +++ b/mistralrs-paged-attn/src/cuda/copy_blocks_kernel.cu @@ 
-1,17 +1,17 @@ #include // Grid: (num_layers, num_pairs) -template -__device__ void copy_blocks_internal_kernel( - int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { +template +__device__ void +copy_blocks_internal_kernel(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { const int layer_idx = blockIdx.x; const int pair_idx = blockIdx.y; - scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); - scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]); + scalar_t *key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t *value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; @@ -30,53 +30,61 @@ __device__ void copy_blocks_internal_kernel( } // Monomorphize the generics ourselves -extern "C" __global__ void copy_blocks_kernel_u8(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void +copy_blocks_kernel_u8(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -extern "C" __global__ void copy_blocks_kernel_u32(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void +copy_blocks_kernel_u32(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -extern "C" __global__ void copy_blocks_kernel_i64(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void +copy_blocks_kernel_i64(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -extern "C" __global__ void copy_blocks_kernel_f32(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void +copy_blocks_kernel_f32(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -extern "C" __global__ void copy_blocks_kernel_f64(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void 
+copy_blocks_kernel_f64(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -// f16, bf16 are special cases: We use a 16-bit integer to simulate the bit width. -// SAFETY: This is technically UB due to aliasing, but it is OK because the width is compatible. -extern "C" __global__ void copy_blocks_kernel_f16(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +// f16, bf16 are special cases: We use a 16-bit integer to simulate the bit +// width. SAFETY: This is technically UB due to aliasing, but it is OK because +// the width is compatible. +extern "C" __global__ void +copy_blocks_kernel_f16(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } -extern "C" __global__ void copy_blocks_kernel_bf16(int64_t* key_cache_ptrs, - int64_t* value_cache_ptrs, - const int64_t* __restrict__ block_mapping, - const int numel_per_block) { - copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, block_mapping, numel_per_block); +extern "C" __global__ void +copy_blocks_kernel_bf16(int64_t *key_cache_ptrs, int64_t *value_cache_ptrs, + const int64_t *__restrict__ block_mapping, + const int numel_per_block) { + copy_blocks_internal_kernel(key_cache_ptrs, value_cache_ptrs, + block_mapping, numel_per_block); } \ No newline at end of file diff --git a/mistralrs-paged-attn/src/cuda/cuda_compat.h b/mistralrs-paged-attn/src/cuda/cuda_compat.h index 90986179f9..897559487c 100644 --- a/mistralrs-paged-attn/src/cuda/cuda_compat.h +++ b/mistralrs-paged-attn/src/cuda/cuda_compat.h @@ -1,29 +1,28 @@ #pragma once - #ifndef USE_ROCM - #define VLLM_LDG(arg) __ldg(arg) +#define VLLM_LDG(arg) __ldg(arg) #else - #define VLLM_LDG(arg) *(arg) +#define VLLM_LDG(arg) *(arg) #endif #ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) #else - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif #ifndef USE_ROCM - #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) #else - #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif #ifndef USE_ROCM - #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ - cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) #else - #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ - hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif - diff --git 
a/mistralrs-paged-attn/src/cuda/ffi.rs b/mistralrs-paged-attn/src/cuda/ffi.rs index 0b59da73f3..7ef017ca0b 100644 --- a/mistralrs-paged-attn/src/cuda/ffi.rs +++ b/mistralrs-paged-attn/src/cuda/ffi.rs @@ -1,5 +1,7 @@ use core::ffi::{c_int, c_long, c_void}; +use candle_core::cuda::cudarc::driver::sys::CUstream; + extern "C" { pub fn reshape_and_cache( key: *const c_void, @@ -15,6 +17,7 @@ extern "C" { x: c_int, key_stride: c_int, value_stride: c_int, + stream: CUstream, dtype: u32, ); @@ -40,6 +43,7 @@ extern "C" { q_stride: c_int, kv_block_stride: c_int, kv_head_stride: c_int, + stream: CUstream, dtype: u32, ); @@ -68,6 +72,7 @@ extern "C" { q_stride: c_int, kv_block_stride: c_int, kv_head_stride: c_int, + stream: CUstream, dtype: u32, ); diff --git a/mistralrs-paged-attn/src/cuda/pagedattention.cu b/mistralrs-paged-attn/src/cuda/pagedattention.cu index 540950cd68..6af6ec8a1e 100644 --- a/mistralrs-paged-attn/src/cuda/pagedattention.cu +++ b/mistralrs-paged-attn/src/cuda/pagedattention.cu @@ -1,5 +1,6 @@ /* - * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Copyright (c) 2023, The vLLM team. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * @@ -16,6 +17,7 @@ * limitations under the License. */ #include +#include #ifdef USE_ROCM #include @@ -35,11 +37,21 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + namespace vllm { // Utility function for attention softmax. -template -inline __device__ float block_sum(float* red_smem, float sum) { +template +inline __device__ float block_sum(float *red_smem, float sum) { // Decompose the thread index into warp / lane. int warp = threadIdx.x / WARP_SIZE; int lane = threadIdx.x % WARP_SIZE; @@ -74,44 +86,41 @@ inline __device__ float block_sum(float* red_smem, float sum) { } inline __device__ float fast_tanh(float x) { - #if defined(__CUDA_ARCH__) - #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) - float y; - asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); - return y; - #else - return ::tanhf(x); - #endif - #else +#if defined(__CUDA_ARCH__) +#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + float y; + asm volatile("tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); + return y; +#else + return ::tanhf(x); +#endif +#else return std::tanh(x); - #endif +#endif } // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> // Zero means no partitioning. +template // Zero means no partitioning. 
__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const float softcapping, - const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const uint32_t* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, const float softcapping, + const uint32_t + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const uint32_t *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -123,22 +132,29 @@ __device__ void paged_attention_kernel( } const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. 
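// Illustrative sketch (not part of the diff): the range arithmetic above maps a
// partition index to a half-open block range via DIVIDE_ROUND_UP and MIN. The
// standalone host-side reproduction below uses hypothetical sizes so the
// rounding and clamping are easy to check by hand.
#include <algorithm>
#include <cstdio>

int main() {
  const int BLOCK_SIZE = 16, PARTITION_SIZE = 512; // assumed template parameters
  const int context_len = 1000, partition_idx = 1; // hypothetical runtime inputs
  const int num_context_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;   // DIVIDE_ROUND_UP -> 63
  const int num_blocks_per_partition = PARTITION_SIZE / BLOCK_SIZE;             // 32
  const int start_block_idx = partition_idx * num_blocks_per_partition;         // 32
  const int end_block_idx =
      std::min(start_block_idx + num_blocks_per_partition, num_context_blocks); // min(64, 63) = 63
  std::printf("partition %d covers blocks [%d, %d) of %d\n", partition_idx,
              start_block_idx, end_block_idx, num_context_blocks);
  return 0;
}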
const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -148,13 +164,14 @@ __device__ void paged_attention_kernel( const int num_heads = gridDim.x; const int num_queries_per_kv = num_heads / num_kv_heads; const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; @@ -167,23 +184,26 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the query, and the second thread has + // 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because q is + // split from a qkv tensor, it may not be contiguous. 
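// Illustrative sketch (not part of the diff): the constants above size each
// query/key fetch so that one thread group moves 16 bytes per step. The host
// code below applies the same formula the comment describes, with an assumed
// 2-byte scalar (half/bf16), WARP_SIZE=32 and BLOCK_SIZE=16 (so a thread group
// of 2); all values are assumptions for illustration.
#include <algorithm>
#include <cstdio>

int main() {
  const int WARP_SIZE = 32, BLOCK_SIZE = 16;
  const int scalar_bytes = 2;                                                // sizeof(half)
  const int thread_group_size = std::max(WARP_SIZE / BLOCK_SIZE, 1);         // 2
  const int vec_size = std::max(16 / (thread_group_size * scalar_bytes), 1); // 4
  std::printf("THREAD_GROUP_SIZE=%d VEC_SIZE=%d -> %d bytes per group per step\n",
              thread_group_size, vec_size,
              thread_group_size * vec_size * scalar_bytes);
  return 0;
}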
+ const scalar_t *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs // Memory planning. extern __shared__ char shared_mem[]; // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); + float *logits = reinterpret_cast(shared_mem); // Workspace for reduction. __shared__ float red_smem[2 * NUM_WARPS]; @@ -196,45 +216,51 @@ __device__ void paged_attention_kernel( // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - const uint32_t* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + const uint32_t *block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. 
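// Illustrative sketch (not part of the diff): the loop above assigns query
// vectors to thread groups in a strided fashion (group g takes i = g, g+NTG, ...),
// and within a group each lane offset owns vec_idx = offset + i * THREAD_GROUP_SIZE,
// written to the shared slot q_vecs[offset][i]. The host-side loop below just
// prints that ownership pattern for small, hypothetical sizes.
#include <cstdio>

int main() {
  const int NUM_THREAD_GROUPS = 4, NUM_VECS_PER_THREAD = 8, THREAD_GROUP_SIZE = 2; // assumed
  for (int g = 0; g < NUM_THREAD_GROUPS; ++g)
    for (int i = g; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS)
      for (int offset = 0; offset < THREAD_GROUP_SIZE; ++offset)
        std::printf("thread group %d, lane offset %d -> q vec %d (slot q_vecs[%d][%d])\n",
                    g, offset, offset + i * THREAD_GROUP_SIZE, offset, i);
  return 0;
}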
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const scalar_t *k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Apply softcapping if (softcapping != 1.0) { qk = fast_tanh(qk / softcapping) * softcapping; } // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += + (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -287,13 +313,12 @@ __device__ void paged_attention_kernel( // If partitioning is enabled, store the max logit and exp_sum. if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float *max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; + float *exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; *exp_sums_ptr = exp_sum; } @@ -305,7 +330,8 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -316,32 +342,37 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64 - // because int32 can lead to overflow when this variable is multiplied by large numbers - // (e.g., kv_block_stride). - const int64_t physical_block_number = static_cast(block_table[block_idx]); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. 
However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const scalar_t *v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); + V_vec v_vec = *reinterpret_cast(v_ptr + offset); if (block_idx == num_context_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the context, - // we should explicitly zero out the values since they may contain NaNs. - // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t *v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -360,18 +391,18 @@ __device__ void paged_attention_kernel( accs[i] = acc; } - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. __syncthreads(); // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); + float *out_smem = reinterpret_cast(shared_mem); #pragma unroll for (int i = NUM_WARPS; i > 1; i /= 2) { int mid = i / 2; // Upper warps write to shared memory. if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; + float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -384,7 +415,7 @@ __device__ void paged_attention_kernel( // Lower warps update the output. if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; + const float *src = &out_smem[warp_idx * HEAD_SIZE]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -398,9 +429,9 @@ __device__ void paged_attention_kernel( // Write the final output. 
if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t *out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -412,75 +443,71 @@ __device__ void paged_attention_kernel( } // Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> +template __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const float softcapping, - const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const uint32_t* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, const float softcapping, + const uint32_t + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const uint32_t *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, softcapping, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, softcapping, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride); } // Grid: (num_heads, num_seqs, max_num_partitions). 
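// Illustrative sketch (not part of the diff): paged_attention_v1_kernel above is
// a thin wrapper that instantiates the shared worker with PARTITION_SIZE = 0
// ("no partitioning") and null exp_sums/max_logits, while the v2 kernel that
// follows passes real reduction buffers. The reduced, host-only illustration
// below (names ours, not the real kernels) shows just that dispatch shape.
#include <cstdio>

template <int PARTITION_SIZE>
void attention_worker(const float *exp_sums, const float *max_logits) {
  constexpr bool use_partitioning = PARTITION_SIZE > 0;
  std::printf("use_partitioning=%d exp_sums=%p max_logits=%p\n",
              use_partitioning ? 1 : 0, (const void *)exp_sums,
              (const void *)max_logits);
}

void attention_v1() { attention_worker<0>(nullptr, nullptr); } // single pass over the context
void attention_v2(const float *e, const float *m) { attention_worker<512>(e, m); } // partitioned pass

int main() {
  const float e[1] = {0.f}, m[1] = {0.f};
  attention_v1();
  attention_v2(e, m);
  return 0;
}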
-template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const float softcapping, - const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const uint32_t* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, softcapping, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, const float softcapping, + const uint32_t + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const uint32_t *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + softcapping, block_tables, context_lens, max_num_blocks_per_seq, + alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } // Grid: (num_heads, num_seqs). 
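// Illustrative sketch (not part of the diff): the v2 kernel above writes one
// (max_logit, exp_sum) pair per (seq, head, partition) plus a partial output of
// head_size scalars, which is what the exp_sums/max_logits/tmp_out buffers in
// its signature hold. The host-side arithmetic below sizes those scratch
// buffers for hypothetical shapes (all values assumed, not from the diff).
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t num_seqs = 8, num_heads = 32, head_size = 128; // hypothetical model/batch
  const int64_t max_context_len = 4096, PARTITION_SIZE = 512;  // assumed partition size
  const int64_t max_num_partitions =
      (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;               // 8
  const int64_t exp_sums_elems = num_seqs * num_heads * max_num_partitions;  // 2048 floats (same for max_logits)
  const int64_t tmp_out_elems = exp_sums_elems * head_size;                  // 262144 scalars
  std::printf("partitions=%lld exp_sums/max_logits=%lld floats tmp_out=%lld scalars\n",
              (long long)max_num_partitions, (long long)exp_sums_elems,
              (long long)tmp_out_elems);
  return 0;
}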
-template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> +template __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const uint32_t* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const float + *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float + *__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const uint32_t *__restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -488,9 +515,11 @@ __global__ void paged_attention_v2_reduce_kernel( const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { out_ptr[i] = tmp_out_ptr[i]; } @@ -508,9 +537,10 @@ __global__ void paged_attention_v2_reduce_kernel( __shared__ float red_smem[2 * NUM_WARPS]; // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float *shared_max_logits = reinterpret_cast(shared_mem); + const float *max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float max_logit = -FLT_MAX; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { const float l = max_logits_ptr[i]; @@ -539,9 +569,11 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + float *shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float l = shared_max_logits[i]; @@ -554,14 +586,17 @@ __global__ void paged_attention_v2_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. 
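// Illustrative sketch (not part of the diff): the reduce kernel above combines
// the per-partition softmax pieces in a numerically stable way. In our notation,
// with m_j a partition's max logit, s_j its exp-sum and o_j its partial output,
// the final value is (sum_j o_j * s_j * exp(m_j - m)) / (sum_k s_k * exp(m_k - m)),
// where m = max_j m_j (plus the 1e-6 guard seen above). A tiny host-side check
// with hypothetical numbers:
#include <cmath>
#include <cstdio>

int main() {
  const float m[2] = {2.0f, 3.5f};   // per-partition max logits
  const float s[2] = {4.0f, 6.0f};   // per-partition exp sums
  const float o[2] = {0.25f, 0.75f}; // per-partition partial outputs (one output element)
  const float m_global = std::fmax(m[0], m[1]);
  float num = 0.f, den = 0.f;
  for (int j = 0; j < 2; ++j) {
    const float w = s[j] * std::exp(m[j] - m_global); // rescaled exp-sum
    num += o[j] * w;
    den += w;
  }
  std::printf("combined output = %f\n", num / (den + 1e-6f));
  return 0;
}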
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } from_float(out_ptr[i], acc); } @@ -569,53 +604,33 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - reinterpret_cast(out), \ - reinterpret_cast(query), \ - reinterpret_cast(key_cache), \ - reinterpret_cast(value_cache), \ - num_kv_heads, \ - scale, \ - softcapping, \ - block_tables, \ - context_lens, \ - max_num_blocks_per_seq, \ - reinterpret_cast(alibi_slopes), \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void *)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + reinterpret_cast(out), reinterpret_cast(query), \ + reinterpret_cast(key_cache), \ + reinterpret_cast(value_cache), num_kv_heads, scale, \ + softcapping, block_tables, context_lens, max_num_blocks_per_seq, \ + reinterpret_cast(alibi_slopes), q_stride, kv_block_stride, \ + kv_head_stride); // TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128> -void paged_attention_v1_launcher( - void *out, - void *query, - void *key_cache, - void *value_cache, - void* __restrict__ alibi_slopes, - int num_kv_heads, - float scale, - float softcapping, - uint32_t *block_tables, - uint32_t *context_lens, - int max_context_len, - - int num_seqs, - int num_heads, - int head_size, - int max_num_blocks_per_seq, - int q_stride, - int kv_block_stride, - int kv_head_stride - ) { +template +void paged_attention_v1_launcher(void *out, void *query, void *key_cache, + void *value_cache, + void *__restrict__ alibi_slopes, + int num_kv_heads, float scale, + float softcapping, uint32_t *block_tables, + uint32_t *context_lens, int max_context_len, + + int num_seqs, int num_heads, int head_size, + int max_num_blocks_per_seq, int q_stride, + int kv_block_stride, int kv_head_stride, + cudaStream_t stream) { // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); // assert(head_size % thread_group_size == 0); @@ -623,7 +638,8 @@ void paged_attention_v1_launcher( // NOTE: alibi_slopes is optional. It may be nullptr. 
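// Illustrative sketch (not part of the diff): the launcher above now takes the
// CUDA stream from its caller (the following hunk drops the old hard-coded
// `const cudaStream_t stream = 0;`), and the extern "C" entry points check
// cudaGetLastError() after dispatch. The self-contained .cu sketch below shows
// the same pattern with a placeholder kernel; it is not the real launcher.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CUDA_CHECK(call)                                                       \
  do {                                                                         \
    cudaError_t err = (call);                                                  \
    if (err != cudaSuccess) {                                                  \
      std::fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__,    \
                   cudaGetErrorString(err));                                   \
      std::exit(err);                                                          \
    }                                                                          \
  } while (0)

__global__ void dummy_kernel(float *out) { out[threadIdx.x] = (float)threadIdx.x; }

void launch_on(cudaStream_t stream, float *d_out) {
  dummy_kernel<<<1, 32, 0, stream>>>(d_out); // launch on the caller's stream, not stream 0
  CUDA_CHECK(cudaGetLastError());            // catches launch-configuration errors
}

int main() {
  float *d_out = nullptr;
  CUDA_CHECK(cudaMalloc(&d_out, 32 * sizeof(float)));
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));
  launch_on(stream, d_out);
  CUDA_CHECK(cudaStreamSynchronize(stream));
  CUDA_CHECK(cudaStreamDestroy(stream));
  CUDA_CHECK(cudaFree(d_out));
  return 0;
}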
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = + DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len @@ -632,99 +648,77 @@ void paged_attention_v1_launcher( dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); - const cudaStream_t stream = 0; switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V1(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + break; } } -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - alibi_slopes, \ - num_kv_heads, \ - scale, \ - softcapping, \ - block_tables, \ - context_lens, \ - max_context_len, \ - num_seqs, \ - num_heads, \ - head_size, \ - max_num_blocks_per_seq, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, alibi_slopes, num_kv_heads, scale, \ + softcapping, block_tables, context_lens, max_context_len, num_seqs, \ + num_heads, head_size, max_num_blocks_per_seq, q_stride, kv_block_stride, \ + kv_head_stride, stream); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
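// Illustrative sketch (not part of the diff): the two sizes computed above feed
// the kernel's dynamic shared memory, a float per (padded) context token for the
// softmax logits plus a warp-reduction workspace of (NUM_WARPS / 2) * head_size
// floats. Upstream vLLM takes the larger of the two as shared_mem_size; treat
// that combination, and all the numbers below, as assumptions for illustration.
#include <algorithm>
#include <cstdio>

int main() {
  const int NUM_THREADS = 128, WARP_SIZE = 32, BLOCK_SIZE = 16; // assumed
  const int max_context_len = 1000, head_size = 128;            // hypothetical
  const int num_warps = NUM_THREADS / WARP_SIZE;                               // 4
  const int padded_max_context_len =
      (max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;            // 1008
  const int logits_size = padded_max_context_len * (int)sizeof(float);         // 4032 B
  const int outputs_size = (num_warps / 2) * head_size * (int)sizeof(float);   // 1024 B
  std::printf("logits=%dB outputs=%dB -> shared_mem_size=%dB\n", logits_size,
              outputs_size, std::max(logits_size, outputs_size));
  return 0;
}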
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + break; \ } extern "C" void paged_attention_v1( - void *out, // [num_seqs, num_heads, head_size] - void *query, // [num_seqs, num_heads, head_size] - void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - void *value_cache, // [num_blocks, num_heads, head_size, block_size] - void *alibi_slopes, // [num_heads] - int32_t num_kv_heads, - float scale, - float softcapping, - uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] - uint32_t *context_lens, // [num_seqs] - int32_t block_size, - int32_t max_context_len, - - int32_t num_seqs, - int32_t num_heads, - int32_t head_size, - int32_t max_num_blocks_per_seq, - int32_t q_stride, - int32_t kv_block_stride, - int32_t kv_head_stride, - - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 - ) { + void *out, // [num_seqs, num_heads, head_size] + void *query, // [num_seqs, num_heads, head_size] + void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + void *value_cache, // [num_blocks, num_heads, head_size, block_size] + void *alibi_slopes, // [num_heads] + int32_t num_kv_heads, float scale, float softcapping, + uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] + uint32_t *context_lens, // [num_seqs] + int32_t block_size, int32_t max_context_len, + + int32_t num_seqs, int32_t num_heads, int32_t head_size, + int32_t max_num_blocks_per_seq, int32_t q_stride, int32_t kv_block_stride, + int32_t kv_head_stride, cudaStream_t stream, + + uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 +) { if (dtype == 2) { CALL_V1_LAUNCHER_BLOCK_SIZE(float); } else if (dtype == 0) { @@ -732,71 +726,42 @@ extern "C" void paged_attention_v1( } else if (dtype == 1) { CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } + CUDA_CHECK(cudaGetLastError()); } -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums, \ - max_logits, \ - tmp_out_ptr, \ - reinterpret_cast(query), \ - reinterpret_cast(key_cache), \ - reinterpret_cast(value_cache), \ - num_kv_heads, \ - scale, \ - softcapping, \ - block_tables, \ - context_lens, \ - max_num_blocks_per_seq, \ - reinterpret_cast(alibi_slopes), \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - reinterpret_cast(out), \ - exp_sums, \ - max_logits, \ - tmp_out_ptr, \ - context_lens, \ - max_num_partitions); - -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums, max_logits, tmp_out_ptr, reinterpret_cast(query), \ + reinterpret_cast(key_cache), \ + reinterpret_cast(value_cache), num_kv_heads, scale, \ + softcapping, block_tables, context_lens, max_num_blocks_per_seq, \ + reinterpret_cast(alibi_slopes), q_stride, kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + reinterpret_cast(out), exp_sums, max_logits, tmp_out_ptr, \ + context_lens, max_num_partitions); + +template void paged_attention_v2_launcher( - void *out, - 
float *exp_sums, - float *max_logits, - void *tmp_out, - void *query, - void *key_cache, - void *value_cache, - void *alibi_slopes, - int num_kv_heads, - float scale, - float softcapping, - uint32_t *block_tables, - uint32_t *context_lens, - int max_context_len, - - int num_seqs, - int num_heads, - int head_size, - int max_num_blocks_per_seq, - int q_stride, - int kv_block_stride, - int kv_head_stride - - ) { + void *out, float *exp_sums, float *max_logits, void *tmp_out, void *query, + void *key_cache, void *value_cache, void *alibi_slopes, int num_kv_heads, + float scale, float softcapping, uint32_t *block_tables, + uint32_t *context_lens, int max_context_len, + + int num_seqs, int num_heads, int head_size, int max_num_blocks_per_seq, + int q_stride, int kv_block_stride, int kv_head_stride, cudaStream_t stream + +) { // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); // NOTE: alibi_slopes is optional. It may be nullptr. - T* tmp_out_ptr = reinterpret_cast(tmp_out); + T *tmp_out_ptr = reinterpret_cast(tmp_out); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); @@ -811,105 +776,81 @@ void paged_attention_v2_launcher( int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); - const cudaStream_t stream = 0; switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V2(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V2(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V2(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V2(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V2(128); - break; - case 192: - LAUNCH_PAGED_ATTENTION_V2(192); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V2(256); - break; - default: - break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + break; } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - alibi_slopes, \ - num_kv_heads, \ - scale, \ - softcapping, \ - block_tables, \ - context_lens, \ - max_context_len, \ - num_seqs, \ - num_heads, \ - head_size, \ - max_num_blocks_per_seq, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + alibi_slopes, num_kv_heads, scale, softcapping, block_tables, \ + context_lens, max_context_len, num_seqs, num_heads, head_size, \ + max_num_blocks_per_seq, q_stride, kv_block_stride, kv_head_stride, \ + stream); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
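// Illustrative sketch (not part of the diff): both extern "C" entry points
// dispatch on the same dtype code (0 => f16, 1 => bf16, 2 => f32) through the
// CALL_*_LAUNCHER_BLOCK_SIZE macros, and elsewhere in this diff f16 data is
// handled as raw 16-bit integer bit patterns (the reshape_and_cache entry point
// below uses uint16_t for dtype 0). The host-only mapping here is ours, just to
// document the convention; it is not code from the repository.
#include <cstdint>
#include <cstdio>

enum class PaDtype : uint32_t { F16 = 0, BF16 = 1, F32 = 2 };

const char *scalar_type_for(uint32_t dtype) {
  switch (static_cast<PaDtype>(dtype)) {
  case PaDtype::F16:  return "f16 (raw uint16_t bits)";
  case PaDtype::BF16: return "bf16 (__nv_bfloat16)";
  case PaDtype::F32:  return "f32 (float)";
  default:            return "unsupported";
  }
}

int main() {
  for (uint32_t d = 0; d < 4; ++d)
    std::printf("dtype=%u -> %s\n", d, scalar_type_for(d));
  return 0;
}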
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V2_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V2_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V2_LAUNCHER(T, 32); \ - break; \ - default: \ - break; \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER(T, 32); \ + break; \ + default: \ + break; \ } extern "C" void paged_attention_v2( - void *out, // [num_seqs, num_heads, head_size] - float *exp_sums, // [num_seqs, num_heads, max_num_partitions] - float *max_logits, // [num_seqs, num_heads, max_num_partitions] - void *tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - void *query, // [num_seqs, num_heads, head_size] - void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - void *value_cache, // [num_blocks, num_heads, head_size, block_size] - void *alibi_slopes, // [num_heads] - int32_t num_kv_heads, - float scale, - float softcapping, - uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] - uint32_t *context_lens, // [num_seqs] - int32_t block_size, - int32_t max_context_len, - - int32_t num_seqs, - int32_t num_heads, - int32_t head_size, - int32_t max_num_blocks_per_seq, - int32_t q_stride, - int32_t kv_block_stride, - int32_t kv_head_stride, - - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 - ) { + void *out, // [num_seqs, num_heads, head_size] + float *exp_sums, // [num_seqs, num_heads, max_num_partitions] + float *max_logits, // [num_seqs, num_heads, max_num_partitions] + void *tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + void *query, // [num_seqs, num_heads, head_size] + void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + void *value_cache, // [num_blocks, num_heads, head_size, block_size] + void *alibi_slopes, // [num_heads] + int32_t num_kv_heads, float scale, float softcapping, + uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] + uint32_t *context_lens, // [num_seqs] + int32_t block_size, int32_t max_context_len, + + int32_t num_seqs, int32_t num_heads, int32_t head_size, + int32_t max_num_blocks_per_seq, int32_t q_stride, int32_t kv_block_stride, + int32_t kv_head_stride, cudaStream_t stream, + + uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 +) { if (dtype == 2) { CALL_V2_LAUNCHER_BLOCK_SIZE(float); } else if (dtype == 0) { @@ -917,6 +858,7 @@ extern "C" void paged_attention_v2( } else if (dtype == 1) { CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } + CUDA_CHECK(cudaGetLastError()); } #undef WARP_SIZE diff --git a/mistralrs-paged-attn/src/cuda/reshape_and_cache_kernel.cu b/mistralrs-paged-attn/src/cuda/reshape_and_cache_kernel.cu index 165a674d6a..d2679f899b 100644 --- a/mistralrs-paged-attn/src/cuda/reshape_and_cache_kernel.cu +++ b/mistralrs-paged-attn/src/cuda/reshape_and_cache_kernel.cu @@ -1,6 +1,7 @@ -#include #include +#include #include +#include #include "cuda_compat.h" @@ -9,21 +10,29 @@ #include #include +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + namespace vllm { -template +template __global__ void reshape_and_cache_kernel( - const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] - const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* 
__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] - const int64_t* __restrict__ slot_mapping, // [num_tokens] - const int key_stride, - const int value_stride, - const int num_heads, - const int head_size, - const int block_size, - const int x) { + const scalar_t *__restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t *__restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t *__restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -44,64 +53,50 @@ __global__ void reshape_and_cache_kernel( const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; key_cache[tgt_key_idx] = key[src_key_idx]; value_cache[tgt_value_idx] = value[src_value_idx]; } } -#define CALL_RESHAPE_AND_CACHE(T) \ - vllm::reshape_and_cache_kernel<<>>( \ - reinterpret_cast(key), \ - reinterpret_cast(value), \ - reinterpret_cast(key_cache), \ - reinterpret_cast(value_cache), \ - slot_mapping, \ - key_stride, \ - value_stride, \ - num_heads, \ - head_size, \ - block_size, \ - x); - +#define CALL_RESHAPE_AND_CACHE(T) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key), reinterpret_cast(value), \ + reinterpret_cast(key_cache), reinterpret_cast(value_cache), \ + slot_mapping, key_stride, value_stride, num_heads, head_size, \ + block_size, x); } // namespace vllm extern "C" void reshape_and_cache( - void *key, // [num_tokens, num_heads, head_size] - void *value, // [num_tokens, num_heads, head_size] - void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - void *value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t* slot_mapping, // [num_tokens] + void *key, // [num_tokens, num_heads, head_size] + void *value, // [num_tokens, num_heads, head_size] + void *key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + void *value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t *slot_mapping, // [num_tokens] - int32_t num_tokens, - int32_t num_heads, - int32_t head_size, - int32_t block_size, - int32_t x, - int32_t key_stride, - int32_t value_stride, + int32_t num_tokens, int32_t num_heads, int32_t head_size, + int32_t block_size, int32_t x, int32_t key_stride, int32_t value_stride, + cudaStream_t stream, - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 - ) -{ + uint32_t dtype // 0 
=> f16; 1 => bf16; 2 => f32 +) { dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); - const cudaStream_t stream = 0; - if (dtype == 0){ + if (dtype == 0) { CALL_RESHAPE_AND_CACHE(uint16_t); } else if (dtype == 1) { CALL_RESHAPE_AND_CACHE(__nv_bfloat16); } else if (dtype == 2) { CALL_RESHAPE_AND_CACHE(float); } + CUDA_CHECK(cudaGetLastError()); } \ No newline at end of file diff --git a/mistralrs-paged-attn/src/metal/kernels/copy_blocks.metal b/mistralrs-paged-attn/src/metal/kernels/copy_blocks.metal index 809b8d71b9..952370f9e2 100644 --- a/mistralrs-paged-attn/src/metal/kernels/copy_blocks.metal +++ b/mistralrs-paged-attn/src/metal/kernels/copy_blocks.metal @@ -65,57 +65,49 @@ struct _MLX_BFloat16 { ///////////////////////////////////////////////////////////////////////////// // Conversions to bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) thread : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) device : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) constant : bits_(float_to_bfloat_bits(static_cast(x))) {} ///////////////////////////////////////////////////////////////////////////// // Conversions from bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const thread { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const threadgroup { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const device { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const constant { return static_cast(bfloat_bits_to_float(bits_)); } @@ -133,29 +125,29 @@ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { ///////////////////////////////////////////////////////////////////////////// // Binary operators -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ 
static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } ///////////////////////////////////////////////////////////////////////////// // Arithmetic Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, float, half, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); bfloat_binop(+, operator+); @@ -165,14 +157,14 @@ bfloat_binop(/, operator/); ///////////////////////////////////////////////////////////////////////////// // Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, half, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); bfloat_compop(>, operator>); @@ -189,27 +181,27 @@ bfloat_compop(!=, operator!=); ///////////////////////////////////////////////////////////////////////////// // Inplace Operators -#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, itype rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } \ - constexpr METAL_FUNC addr_space itype& __operator__( \ - addr_space itype& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + 
return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ - bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); -#define bfloat_inplace_op(itype) \ - bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ - bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ - bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ bfloat_inplace_op_addr_space_helper(/, operator/=, itype); bfloat_inplace_op(float); @@ -225,16 +217,16 @@ bfloat_inplace_op(uint64_t); #undef bfloat_inplace_op_addr_space_helper #undef bfloat_inplace_op -#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ - bfloat_inplace_op_helper(__op__, __operator__, device); \ - bfloat_inplace_op_helper(__op__, __operator__, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ bfloat_inplace_op_helper(__op__, __operator__, threadgroup); bfloat_inplace_op_addr_space_helper(+, operator+=); @@ -253,54 +245,47 @@ typedef struct _MLX_BFloat16 bfloat16_t; #endif +template +[[kernel]] void copy_blocks(device T *key_cache [[buffer(0)]], + device T *value_cache [[buffer(1)]], + const device int64_t *block_mapping [[buffer(2)]], + device const int &numel_per_block, + uint gid [[thread_position_in_grid]], + uint tid [[thread_position_in_threadgroup]], + uint threads_per_threadgroup + [[threads_per_threadgroup]]) { + const int pair_idx = gid; + + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + + // Copy key cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } - - - -template -[[kernel]] void copy_blocks( - device T* key_cache [[buffer(0)]], - device T* value_cache [[buffer(1)]], - const device int64_t* block_mapping [[buffer(2)]], - device const int& numel_per_block, - uint gid [[thread_position_in_grid]], - uint tid [[thread_position_in_threadgroup]], - uint threads_per_threadgroup [[threads_per_threadgroup]]) -{ - const int pair_idx = gid; - - int64_t src_block_number = 
block_mapping[2 * pair_idx]; - int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; - - const int64_t src_block_offset = src_block_number * numel_per_block; - const int64_t dst_block_offset = dst_block_number * numel_per_block; - - // Copy key cache blocks - for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { - int64_t src_offset = src_block_offset + i; - int64_t dst_offset = dst_block_offset + i; - key_cache[dst_offset] = key_cache[src_offset]; - } - - // Copy value cache blocks - for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { - int64_t src_offset = src_block_offset + i; - int64_t dst_offset = dst_block_offset + i; - value_cache[dst_offset] = value_cache[src_offset]; - } + // Copy value cache blocks + for (int i = tid; i < numel_per_block; i += threads_per_threadgroup) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } } -#define instantiate_copy_blocks(type) \ - template [[host_name("copy_blocks_" #type)]] \ - [[kernel]] void copy_blocks( \ - device type* key_cache_ptrs [[buffer(0)]], \ - device type * value_cache_ptrs [[buffer(1)]], \ - const device int64_t* block_mapping [[buffer(2)]], \ - device const int& numel_per_block, \ - uint gid [[thread_position_in_grid]], \ - uint tid [[thread_position_in_threadgroup]], \ - uint threads_per_threadgroup [[threads_per_threadgroup]]); - -instantiate_copy_blocks(float) -instantiate_copy_blocks(bfloat16_t) -instantiate_copy_blocks(half) +#define instantiate_copy_blocks(type) \ + template [[host_name("copy_blocks_" #type)]] [[kernel]] void \ + copy_blocks(device type * key_cache_ptrs [[buffer(0)]], \ + device type * value_cache_ptrs [[buffer(1)]], \ + const device int64_t *block_mapping [[buffer(2)]], \ + device const int &numel_per_block, \ + uint gid [[thread_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_copy_blocks(float) instantiate_copy_blocks(bfloat16_t) + instantiate_copy_blocks(half) diff --git a/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal b/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal index 3458f67712..ade9ed3745 100644 --- a/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal +++ b/mistralrs-paged-attn/src/metal/kernels/pagedattention.metal @@ -1,7 +1,7 @@ // Updated from MLX commit has f70764a -#include #include +#include using namespace metal; @@ -69,57 +69,49 @@ struct _MLX_BFloat16 { ///////////////////////////////////////////////////////////////////////////// // Conversions to bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) thread : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) device : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) constant : bits_(float_to_bfloat_bits(static_cast(x))) {} ///////////////////////////////////////////////////////////////////////////// // Conversions from bfloat - template < - typename T, - typename = 
typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const thread { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const threadgroup { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const device { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const constant { return static_cast(bfloat_bits_to_float(bits_)); } @@ -137,29 +129,29 @@ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { ///////////////////////////////////////////////////////////////////////////// // Binary operators -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } ///////////////////////////////////////////////////////////////////////////// // Arithmetic Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, float, half, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); bfloat_binop(+, operator+); @@ -169,14 +161,14 @@ bfloat_binop(/, operator/); ///////////////////////////////////////////////////////////////////////////// // Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, 
__operator__, bool, half, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); bfloat_compop(>, operator>); @@ -193,27 +185,27 @@ bfloat_compop(!=, operator!=); ///////////////////////////////////////////////////////////////////////////// // Inplace Operators -#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, itype rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } \ - constexpr METAL_FUNC addr_space itype& __operator__( \ - addr_space itype& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ - bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); -#define bfloat_inplace_op(itype) \ - bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ - bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ - bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ bfloat_inplace_op_addr_space_helper(/, operator/=, itype); bfloat_inplace_op(float); @@ -229,16 +221,16 @@ bfloat_inplace_op(uint64_t); #undef bfloat_inplace_op_addr_space_helper #undef bfloat_inplace_op -#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define 
bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ - bfloat_inplace_op_helper(__op__, __operator__, device); \ - bfloat_inplace_op_helper(__op__, __operator__, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ bfloat_inplace_op_helper(__op__, __operator__, threadgroup); bfloat_inplace_op_addr_space_helper(+, operator+=); @@ -260,114 +252,76 @@ typedef struct _MLX_BFloat16 bfloat16_t; // ========================================== Generic vector types // A vector type to store Q, K, V elements. -template -struct Vec {}; +template struct Vec {}; // A vector type to store FP32 accumulators. -template -struct FloatVec {}; +template struct FloatVec {}; // Template vector operations. -template -inline Acc mul(A a, B b); +template inline Acc mul(A a, B b); -template -inline float sum(T v); +template inline float sum(T v); -template -inline float dot(T a, T b) { +template inline float dot(T a, T b) { return sum(mul(a, b)); } -template -inline float dot(T a, T b) { +template inline float dot(T a, T b) { return sum(mul(a, b)); } - - // FP32 vector data types. struct Float8_ { float4 x; float4 y; }; -template<> -struct Vec { +template <> struct Vec { using Type = float; }; -template<> -struct Vec { +template <> struct Vec { using Type = float2; }; -template<> -struct Vec { +template <> struct Vec { using Type = float4; }; -template<> -struct Vec { +template <> struct Vec { using Type = Float8_; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float2; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float4; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = Float8_; }; -template<> -inline float mul(float a, float b) { - return a*b; -} +template <> inline float mul(float a, float b) { return a * b; } -template<> -inline float2 mul(float2 a, float2 b) { - return a*b; -} +template <> inline float2 mul(float2 a, float2 b) { return a * b; } -template<> -inline float4 mul(float4 a, float4 b) { - return a*b; -} +template <> inline float4 mul(float4 a, float4 b) { return a * b; } -template<> -inline Float8_ mul(Float8_ a, Float8_ b) { +template <> inline Float8_ mul(Float8_ a, Float8_ b) { Float8_ c; c.x = a.x * b.x; c.y = a.y * b.y; return c; } -template<> -inline float sum(float a) { - return a; -} +template <> inline float sum(float a) { return a; } -template<> -inline float sum(float2 a) { - return a.x + a.y; -} +template <> inline float sum(float2 a) { return a.x + a.y; } -template<> -inline float sum(float4 a) { - return a.x + a.y + a.z + a.w; -} +template <> inline float sum(float4 a) { return a.x + a.y + a.z + a.w; } -template<> -inline float sum(Float8_ a) { - return sum(a.x) + sum(a.y); -} +template <> inline float sum(Float8_ a) { return sum(a.x) + sum(a.y); } inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { Float8_ res; @@ -376,21 +330,10 @@ inline Float8_ fma(Float8_ a, Float8_ b, Float8_ c) { return res; } -inline void from_float(thread float& dst, float src) { - dst = src; -} -inline void from_float(thread float2& dst, float2 src) { - dst = src; -} -inline void from_float(thread float4& dst, float4 src) { - dst = src; -} -inline void from_float(thread Float8_& dst, Float8_ src) { - dst = src; -} - - - +inline void from_float(thread float &dst, float src) { dst = src; 
} +inline void from_float(thread float2 &dst, float2 src) { dst = src; } +inline void from_float(thread float4 &dst, float4 src) { dst = src; } +inline void from_float(thread Float8_ &dst, Float8_ src) { dst = src; } // BF16 vector data types. // #if defined(__HAVE_BFLOAT__) @@ -560,65 +503,50 @@ struct Bfloat8_ { Bfloat4_ y; }; -template<> -struct Vec { +template <> struct Vec { using Type = bfloat16_t; }; -template<> -struct Vec { +template <> struct Vec { using Type = Bfloat2_; }; -template<> -struct Vec { +template <> struct Vec { using Type = Bfloat4_; }; -template<> -struct Vec { +template <> struct Vec { using Type = Bfloat8_; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float2; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float4; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = Float8_; }; -template<> -inline float mul(bfloat16_t a, bfloat16_t b) { +template <> inline float mul(bfloat16_t a, bfloat16_t b) { return (float)a * (float)b; } -template<> -inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { - return a*b; -} +template <> inline bfloat16_t mul(bfloat16_t a, bfloat16_t b) { return a * b; } -template<> -inline float2 mul(Bfloat2_ a, Bfloat2_ b) { +template <> inline float2 mul(Bfloat2_ a, Bfloat2_ b) { float2 a_f((float)a.x, (float)a.y); float2 b_f((float)b.x, (float)b.y); return a_f * b_f; } -template<> -inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { +template <> inline Bfloat2_ mul(Bfloat2_ a, Bfloat2_ b) { Bfloat2_ c; c.x = a.x * b.x; c.y = a.y * b.y; return c; } -template<> -inline float4 mul(Bfloat4_ a, Bfloat4_ b) { +template <> inline float4 mul(Bfloat4_ a, Bfloat4_ b) { float2 x = mul(a.x, b.x); float2 y = mul(a.y, b.y); float4 c; @@ -628,54 +556,39 @@ inline float4 mul(Bfloat4_ a, Bfloat4_ b) { c.w = y.y; return c; } -template<> -inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { +template <> inline Bfloat4_ mul(Bfloat4_ a, Bfloat4_ b) { Bfloat4_ c; c.x = mul(a.x, b.x); c.y = mul(a.y, b.y); return c; } -template<> -inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { +template <> inline Float8_ mul(Bfloat8_ a, Bfloat8_ b) { Float8_ c; c.x = mul(a.x, b.x); c.y = mul(a.y, b.y); return c; } -template<> -inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { +template <> inline Bfloat8_ mul(Bfloat8_ a, Bfloat8_ b) { Bfloat8_ c; c.x = mul(a.x, b.x); c.y = mul(a.y, b.y); return c; } -template<> -inline float sum(bfloat16_t a) { - return (float)a; -} +template <> inline float sum(bfloat16_t a) { return (float)a; } -template<> -inline float sum(Bfloat2_ a) { - return (float)a.x + (float)a.y; -} +template <> inline float sum(Bfloat2_ a) { return (float)a.x + (float)a.y; } -template<> -inline float sum(Bfloat4_ a) { - return sum(a.x) + sum(a.y); -} +template <> inline float sum(Bfloat4_ a) { return sum(a.x) + sum(a.y); } -template<> -inline float sum(Bfloat8_ a) { - return sum(a.x) + sum(a.y); -} +template <> inline float sum(Bfloat8_ a) { return sum(a.x) + sum(a.y); } inline float fma(bfloat16_t a, bfloat16_t b, float c) { return (float)a * (float)b + c; } inline bfloat16_t fma(bfloat16_t a, bfloat16_t b, bfloat16_t c) { - return a*b+c; + return a * b + c; } inline float2 fma(Bfloat2_ a, Bfloat2_ b, float2 c) { @@ -720,20 +633,20 @@ inline Bfloat8_ fma(Bfloat8_ a, Bfloat8_ b, Bfloat8_ c) { return c; } -inline void from_float(thread bfloat16_t& dst, float src) { +inline void from_float(thread bfloat16_t &dst, float src) { 
dst = static_cast(src); } -inline void from_float(thread Bfloat2_& dst, float2 src) { +inline void from_float(thread Bfloat2_ &dst, float2 src) { dst.x = static_cast(src.x); dst.y = static_cast(src.y); } -inline void from_float(thread Bfloat4_& dst, float4 src) { +inline void from_float(thread Bfloat4_ &dst, float4 src) { dst.x.x = static_cast(src.x); dst.x.y = static_cast(src.y); dst.y.x = static_cast(src.z); dst.y.y = static_cast(src.w); } -inline void from_float(thread Bfloat8_& dst, Float8_ src) { +inline void from_float(thread Bfloat8_ &dst, Float8_ src) { Bfloat4_ x; Bfloat4_ y; from_float(x, src.x); @@ -744,79 +657,52 @@ inline void from_float(thread Bfloat8_& dst, Float8_ src) { // #endif - - - - // FP16 vector data types. struct Half8_ { half4 x; half4 y; }; -template<> -struct Vec { +template <> struct Vec { using Type = half; }; -template<> -struct Vec { +template <> struct Vec { using Type = half2; }; -template<> -struct Vec { +template <> struct Vec { using Type = half4; }; -template<> -struct Vec { +template <> struct Vec { using Type = Half8_; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float2; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = float4; }; -template<> -struct FloatVec { +template <> struct FloatVec { using Type = Float8_; }; -template<> -inline float mul(half a, half b) { - return (float)a * (float)b; -} -template<> -inline half mul(half a, half b) { - return a*b; -} +template <> inline float mul(half a, half b) { return (float)a * (float)b; } +template <> inline half mul(half a, half b) { return a * b; } -template<> -inline float2 mul(half2 a, half2 b) { +template <> inline float2 mul(half2 a, half2 b) { return (float2)a * (float2)b; } -template<> -inline half2 mul(half2 a, half2 b) { - return a * b; -} +template <> inline half2 mul(half2 a, half2 b) { return a * b; } -template<> -inline float4 mul(half4 a, half4 b) { +template <> inline float4 mul(half4 a, half4 b) { return (float4)a * (float4)b; } -template<> -inline half4 mul(half4 a, half4 b) { - return a * b; -} +template <> inline half4 mul(half4 a, half4 b) { return a * b; } -template<> -inline Float8_ mul(Half8_ a, Half8_ b) { +template <> inline Float8_ mul(Half8_ a, Half8_ b) { float4 x = mul(a.x, b.x); float4 y = mul(a.y, b.y); Float8_ c; @@ -824,37 +710,22 @@ inline Float8_ mul(Half8_ a, Half8_ b) { c.y = y; return c; } -template<> -inline Half8_ mul(Half8_ a, Half8_ b) { +template <> inline Half8_ mul(Half8_ a, Half8_ b) { Half8_ c; c.x = mul(a.x, b.x); c.y = mul(a.y, b.y); return c; } -template<> -inline float sum(half a) { - return (float)a; -} +template <> inline float sum(half a) { return (float)a; } -template<> -inline float sum(half2 a) { - return (float)a.x + (float)a.y; -} +template <> inline float sum(half2 a) { return (float)a.x + (float)a.y; } -template<> -inline float sum(half4 a) { - return sum(a.x) + sum(a.y); -} +template <> inline float sum(half4 a) { return sum(a.x) + sum(a.y); } -template<> -inline float sum(Half8_ a) { - return sum(a.x) + sum(a.y); -} +template <> inline float sum(Half8_ a) { return sum(a.x) + sum(a.y); } -inline float fma(half a, half b, float c) { - return (float)a * (float)b + c; -} +inline float fma(half a, half b, float c) { return (float)a * (float)b + c; } inline float2 fma(half2 a, half2 b, float2 c) { return (float2)a * (float2)b + c; @@ -879,20 +750,20 @@ inline Half8_ fma(Half8_ a, Half8_ b, Half8_ c) { 
return c; } -inline void from_float(thread half& dst, float src) { +inline void from_float(thread half &dst, float src) { dst = static_cast(src); } -inline void from_float(thread half2& dst, float2 src) { +inline void from_float(thread half2 &dst, float2 src) { dst.x = static_cast(src.x); dst.y = static_cast(src.y); } -inline void from_float(thread half4& dst, float4 src) { +inline void from_float(thread half4 &dst, float4 src) { dst.x = static_cast(src.x); dst.y = static_cast(src.y); dst.z = static_cast(src.z); dst.w = static_cast(src.w); } -inline void from_float(thread Half8_& dst, Float8_ src) { +inline void from_float(thread Half8_ &dst, Float8_ src) { half4 x; half4 y; from_float(x, src.x); @@ -904,7 +775,7 @@ inline void from_float(thread Half8_& dst, Float8_ src) { // ========================================== Dot product utilities // TODO(EricLBuehler): optimize with vectorization -template +template inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { // Compute the parallel products for Q*K^T (treat vector lanes separately). using A_vec = typename FloatVec::Type; @@ -923,10 +794,10 @@ inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { return qk; } -template -struct Qk_dot { - template - static inline float dot(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) { +template struct Qk_dot { + template + static inline float dot(const threadgroup Vec (&q)[N], + const thread Vec (&k)[N]) { return qk_dot_(q, k); } }; @@ -934,8 +805,9 @@ struct Qk_dot { // ========================================== Block sum utility // Utility function for attention softmax. -template -inline float block_sum(threadgroup float* red_smem, float sum, uint simd_tid, uint simd_lid) { +template +inline float block_sum(threadgroup float *red_smem, float sum, uint simd_tid, + uint simd_lid) { // Compute the sum per simdgroup. #pragma unroll for (int mask = NUM_SIMD_LANES / 2; mask >= 1; mask /= 2) { @@ -967,7 +839,6 @@ inline float block_sum(threadgroup float* red_smem, float sum, uint simd_tid, ui // ========================================== Paged Attention kernel - #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
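The qk_dot_/Qk_dot helpers above multiply Q and K lane-wise in the wider FloatVec type and reduce with sum(), so half and bfloat16 inputs are accumulated in fp32; block_sum then folds the per-lane partials across the simdgroup and threadgroup. A scalar Rust sketch of the same ideas, including the bit-level bfloat16 emulation used by _MLX_BFloat16 (the narrowing helper below rounds to nearest-even but, unlike float_to_bfloat_bits, does not special-case NaN):

// bfloat16 stored as raw bits: widening to f32 is a 16-bit shift.
fn bf16_bits_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

// Narrowing with round-to-nearest-even on the discarded low 16 bits.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    let round_bit = (bits >> 16) & 1;
    ((bits + 0x7FFF + round_bit) >> 16) as u16
}

// Q.K with fp32 accumulation and the scale factor applied afterwards,
// mirroring `scale * Qk_dot::dot(q_vecs, k_vecs)` in the kernel.
fn qk_dot(scale: f32, q_bits: &[u16], k_bits: &[u16]) -> f32 {
    scale
        * q_bits
            .iter()
            .zip(k_bits)
            .map(|(&q, &k)| bf16_bits_to_f32(q) * bf16_bits_to_f32(k))
            .sum::<f32>()
}

fn main() {
    let q: Vec<u16> = [0.5f32, -1.25, 2.0].map(f32_to_bf16_bits).to_vec();
    let k: Vec<u16> = [1.0f32, 0.5, 0.25].map(f32_to_bf16_bits).to_vec();
    println!("qk = {}", qk_dot(0.125, &q, &k)); // 0.046875
}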
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -975,31 +846,40 @@ inline float block_sum(threadgroup float* red_smem, float sum, uint simd_tid, ui constant bool use_partitioning [[function_constant(10)]]; constant bool use_alibi [[function_constant(20)]]; -template +template [[kernel]] void paged_attention( - device float* exp_sums [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads, max_num_partitions] - device float* max_logits [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads, max_num_partitions] - device T* out [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] - device const T* q [[buffer(3)]], // [num_seqs, num_heads, head_size] - device const T* k_cache [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] - device const T* v_cache [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] - const constant int& num_kv_heads [[buffer(6)]], // [num_heads] - const constant float& scale [[buffer(7)]], - const constant float& softcapping [[buffer(8)]], - device const uint32_t* block_tables [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq] - device const uint32_t* context_lens [[buffer(10)]], // [num_seqs] - const constant int& max_num_blocks_per_seq [[buffer(11)]], - device const float* alibi_slopes [[buffer(12), function_constant(use_alibi)]], // [num_heads] - const constant int& q_stride [[buffer(13)]], - const constant int& kv_block_stride [[buffer(14)]], - const constant int& kv_head_stride [[buffer(15)]], - threadgroup char* shared_mem [[threadgroup(0)]], + device float *exp_sums + [[buffer(0), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device float *max_logits + [[buffer(1), function_constant(use_partitioning)]], // [num_seqs, num_heads, + // max_num_partitions] + device T *out + [[buffer(2)]], // [num_seqs, num_heads, max_num_partitions, head_size] + device const T *q [[buffer(3)]], // [num_seqs, num_heads, head_size] + device const T *k_cache + [[buffer(4)]], // [num_blocks, num_kv_heads, head_size/x, block_size, x] + device const T *v_cache + [[buffer(5)]], // [num_blocks, num_kv_heads, head_size, block_size] + const constant int &num_kv_heads [[buffer(6)]], // [num_heads] + const constant float &scale [[buffer(7)]], + const constant float &softcapping [[buffer(8)]], + device const uint32_t *block_tables + [[buffer(9)]], // [num_seqs, max_num_blocks_per_seq] + device const uint32_t *context_lens [[buffer(10)]], // [num_seqs] + const constant int &max_num_blocks_per_seq [[buffer(11)]], + device const float *alibi_slopes + [[buffer(12), function_constant(use_alibi)]], // [num_heads] + const constant int &q_stride [[buffer(13)]], + const constant int &kv_block_stride [[buffer(14)]], + const constant int &kv_head_stride [[buffer(15)]], + threadgroup char *shared_mem [[threadgroup(0)]], uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], uint3 threadgroups_per_grid [[threadgroups_per_grid]], uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], uint simd_tid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]] -) { + uint simd_lid [[thread_index_in_simdgroup]]) { const int seq_idx = threadgroup_position_in_grid.y; const int partition_idx = threadgroup_position_in_grid.z; const int max_num_partitions = threadgroups_per_grid.z; @@ -1012,22 +892,29 @@ template ::Type; using Q_vec = typename Vec::Type; @@ -1055,20 +942,22 @@ template (q_ptr + vec_idx * 
VEC_SIZE); + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); } threadgroup_barrier(mem_flags::mem_threadgroup); // Use fp32 on softmax logits for better accuracy - threadgroup float* logits = reinterpret_cast(shared_mem); + threadgroup float *logits = reinterpret_cast(shared_mem); // Workspace for reduction threadgroup float red_smem[2 * NUM_WARPS]; @@ -1081,45 +970,52 @@ template (block_table[block_idx]); + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). + const int64_t physical_block_number = + static_cast(block_table[block_idx]); // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. + // For example, if the thread group size is 4, then the first thread in the + // group has 0, 4, 8, ... th vectors of the key, and the second thread has + // 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; + const int physical_block_offset = + (thread_group_idx + i * NUM_SIMD_LANES) % BLOCK_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; K_vec k_vecs[NUM_VECS_PER_THREAD]; #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const device T* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const device T *k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); } // Compute dot product. // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Apply softcapping if (softcapping != 1.0) { qk = precise::tanh(qk / softcapping) * softcapping; } // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += + (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -1160,7 +1056,8 @@ template (&red_smem[NUM_WARPS], exp_sum, simd_tid, simd_lid); + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum, + simd_tid, simd_lid); // Compute softmax. const float inv_sum = divide(1.f, exp_sum + 1e-6f); @@ -1171,13 +1068,13 @@ template (block_table[block_idx]); + // because int32 can lead to overflow when this variable is multiplied by + // large numbers (e.g., kv_block_stride). 
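The key pass above widens the physical block number to int64 before multiplying by kv_block_stride (so the product cannot overflow int32 on large caches), indexes the key cache laid out as [num_blocks, num_kv_heads, head_size/x, block_size, x], and then soft-caps the scaled dot product with tanh when softcapping != 1.0. A scalar Rust sketch of that offset math and the soft-cap, with parameter names taken from the kernel and the concrete values in main chosen arbitrarily:

// Offset of one key element in the paged key cache:
// [num_blocks, num_kv_heads, head_size / x, block_size, x].
// `head_elem` is the element index within the head (the kernel derives it
// from vec_idx * VEC_SIZE); `block_offset` is the token slot in the block.
fn key_cache_offset(
    physical_block_number: u32, // widened to i64 before any multiplication
    kv_head_idx: i64,
    kv_block_stride: i64,
    kv_head_stride: i64,
    block_offset: i64,
    head_elem: i64,
    block_size: i64,
    x: i64, // packing factor of the innermost dimension
) -> i64 {
    let base = i64::from(physical_block_number) * kv_block_stride
        + kv_head_idx * kv_head_stride
        + block_offset * x;
    let offset1 = head_elem / x; // which x-wide chunk of the head
    let offset2 = head_elem % x; // position inside that chunk
    base + offset1 * block_size * x + offset2
}

// Logit soft-capping as in the kernel: a no-op when softcapping == 1.0.
fn softcap(qk: f32, softcapping: f32) -> f32 {
    if softcapping != 1.0 {
        (qk / softcapping).tanh() * softcapping
    } else {
        qk
    }
}

fn main() {
    println!("offset = {}", key_cache_offset(3, 1, 4096, 1024, 5, 17, 16, 8));
    println!("capped = {}", softcap(75.0, 50.0));
}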
+ const int64_t physical_block_number = + static_cast(block_table[block_idx]); const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - Float_L_vec logits_float_vec = *reinterpret_cast(logits + token_idx - start_token_idx); + Float_L_vec logits_float_vec = *reinterpret_cast( + logits + token_idx - start_token_idx); from_float(logits_vec, logits_float_vec); - const device T* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const device T *v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -1219,13 +1120,15 @@ template (v_ptr + offset); + // See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + V_vec v_vec = *reinterpret_cast(v_ptr + offset); if (block_idx == num_context_blocks - 1) { - thread T* v_vec_ptr = reinterpret_cast(&v_vec); + thread T *v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = + token_idx + j < context_len ? v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -1249,13 +1152,14 @@ template (shared_mem); + threadgroup float *out_smem = + reinterpret_cast(shared_mem); #pragma unroll for (int i = NUM_WARPS; i > 1; i /= 2) { int mid = i / 2; // Upper warps write to shared memory. if (warp_idx >= mid && warp_idx < i) { - threadgroup float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; + threadgroup float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -1268,7 +1172,7 @@ template +template [[kernel]] void paged_attention_v2_reduce( - device T* out [[buffer(0)]], - const device float* exp_sums [[buffer(1)]], - const device float* max_logits [[buffer(2)]], - const device T* tmp_out [[buffer(3)]], - device uint32_t* context_lens [[buffer(4)]], - const constant int& max_num_partitions [[buffer(5)]], - threadgroup char* shared_mem [[threadgroup(0)]], + device T *out [[buffer(0)]], const device float *exp_sums [[buffer(1)]], + const device float *max_logits [[buffer(2)]], + const device T *tmp_out [[buffer(3)]], + device uint32_t *context_lens [[buffer(4)]], + const constant int &max_num_partitions [[buffer(5)]], + threadgroup char *shared_mem [[threadgroup(0)]], uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], uint3 threadgroups_per_grid [[threadgroups_per_grid]], uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], uint3 threads_per_threadgroup [[threads_per_threadgroup]], uint simd_tid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]] -) { + uint simd_lid [[thread_index_in_simdgroup]]) { const int num_heads = threadgroups_per_grid.x; const int head_idx = threadgroup_position_in_grid.x; const int seq_idx = threadgroup_position_in_grid.y; @@ -1318,10 +1221,13 @@ template (shared_mem); - const device float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + threadgroup float *shared_max_logits = + reinterpret_cast(shared_mem); + const device float *max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + + 
head_idx * max_num_partitions; float max_logit = -FLT_MAX; - for (int i = thread_position_in_threadgroup.x; i < num_partitions; i += threads_per_threadgroup.x) { + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { const float l = max_logits_ptr[i]; shared_max_logits[i] = l; max_logit = max(max_logit, l); @@ -1367,116 +1276,162 @@ template (shared_mem + sizeof(float) * num_partitions); - const device float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; + threadgroup float *shared_exp_sums = reinterpret_cast( + shared_mem + sizeof(float) * num_partitions); + const device float *exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; float global_exp_sum = 0.0f; - for (int i = thread_position_in_threadgroup.x; i < num_partitions; i += threads_per_threadgroup.x) { + for (int i = thread_position_in_threadgroup.x; i < num_partitions; + i += threads_per_threadgroup.x) { float l = shared_max_logits[i]; float rescaled_exp_sum = exp_sums_ptr[i] * exp(l - max_logit); global_exp_sum += rescaled_exp_sum; shared_exp_sums[i] = rescaled_exp_sum; } threadgroup_barrier(mem_flags::mem_threadgroup); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); + global_exp_sum = block_sum( + &red_smem[NUM_WARPS], global_exp_sum, simd_tid, simd_lid); const float inv_global_exp_sum = divide(1.0f, global_exp_sum + 1e-6f); // Aggregate tmp_out to out. - const device T* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - device T* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const device T *tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + device T *out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll - for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; i += NUM_THREADS) { + for (int i = thread_position_in_threadgroup.x; i < HEAD_SIZE; + i += NUM_THREADS) { float acc = 0.0f; for (int j = 0; j < num_partitions; ++j) { - acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + acc += float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; } out_ptr[i] = T(acc); } } -#define instantiate_paged_attention_inner(type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \ - template [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size "_nt" #num_threads "_nsl" #num_simd_lanes "_ps" #partition_size)]] \ - [[kernel]] void paged_attention( \ - device float* exp_sums [[buffer(0), function_constant(use_partitioning)]], \ - device float* max_logits [[buffer(1), function_constant(use_partitioning)]], \ - device type* out [[buffer(2)]], \ - device const type* q [[buffer(3)]], \ - device const type* k_cache [[buffer(4)]], \ - device const type* v_cache [[buffer(5)]], \ - const constant int& num_kv_heads [[buffer(6)]], \ - const constant float& scale [[buffer(7)]], \ - const constant float& softcapping [[buffer(8)]], \ - device const uint32_t* block_tables [[buffer(9)]], \ - device const uint32_t* context_lens [[buffer(10)]], \ - const constant int& max_num_blocks_per_seq [[buffer(11)]], \ - device const float* alibi_slopes [[buffer(12), function_constant(use_alibi)]], \ - const constant int& q_stride [[buffer(13)]], \ - const constant 
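The v2 reduce kernel above merges the per-partition partial outputs in the numerically stable way: take the global max logit, rescale each partition's exponent sum by exp(partition_max - global_max), and average the partition outputs with those rescaled weights (with the same 1e-6 guard on the denominator). A sketch of that combine for a single sequence/head pair, with tmp_out assumed to be a row-major [num_partitions, head_size] slice:

// Combine per-partition attention outputs, mirroring
// paged_attention_v2_reduce for one (seq_idx, head_idx).
fn reduce_partitions(
    tmp_out: &[f32],    // [num_partitions, head_size] partial outputs
    max_logits: &[f32], // per-partition max logit
    exp_sums: &[f32],   // per-partition sum of exp(logit - partition max)
    head_size: usize,
) -> Vec<f32> {
    let num_partitions = max_logits.len();
    let global_max = max_logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Rescale every partition's exponent sum to the global max logit.
    let rescaled: Vec<f32> = max_logits
        .iter()
        .zip(exp_sums)
        .map(|(&l, &s)| s * (l - global_max).exp())
        .collect();
    let inv_global = 1.0 / (rescaled.iter().sum::<f32>() + 1e-6);

    (0..head_size)
        .map(|i| {
            (0..num_partitions)
                .map(|j| tmp_out[j * head_size + i] * rescaled[j] * inv_global)
                .sum()
        })
        .collect()
}

fn main() {
    // Two partitions over a head of size 2.
    let out = reduce_partitions(&[1.0, 2.0, 3.0, 4.0], &[0.0, 1.0], &[1.0, 1.0], 2);
    println!("{out:?}");
}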
int& kv_block_stride [[buffer(14)]], \ - const constant int& kv_head_stride [[buffer(15)]], \ - threadgroup char* shared_mem [[threadgroup(0)]], \ - uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ - uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ - uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ - uint simd_tid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); \ - -#define instantiate_paged_attention_v2_reduce_inner(type, head_size, num_threads, num_simd_lanes, partition_size) \ - template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size "_nt" #num_threads "_nsl" #num_simd_lanes "_ps" #partition_size)]] \ - [[kernel]] void paged_attention_v2_reduce( \ - device type* out [[buffer(0)]], \ - const device float* exp_sums [[buffer(1)]], \ - const device float* max_logits [[buffer(2)]], \ - const device type* tmp_out [[buffer(3)]], \ - device uint32_t* context_lens [[buffer(4)]], \ - const constant int& max_num_partitions [[buffer(5)]], \ - threadgroup char* shared_mem [[threadgroup(0)]], \ - uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ - uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ - uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ - uint3 threads_per_threadgroup [[threads_per_threadgroup]], \ - uint simd_tid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); \ - - -#define instantiate_paged_attention_heads(type, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 64, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 80, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 96, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 112, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 128, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 192, block_size, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_inner(type, 256, block_size, num_threads, num_simd_lanes, partition_size) - -#define instantiate_paged_attention_v2_reduce_heads(type, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 80, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 96, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 112, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 128, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 192, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_v2_reduce_inner(type, 256, num_threads, num_simd_lanes, partition_size) - -#define instantiate_paged_attention_block_size(type, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_heads(type, 8, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_heads(type, 16, num_threads, num_simd_lanes, partition_size) \ - instantiate_paged_attention_heads(type, 32, num_threads, num_simd_lanes, partition_size) +#define 
instantiate_paged_attention_inner( \ + type, head_size, block_size, num_threads, num_simd_lanes, partition_size) \ + template \ + [[host_name("paged_attention_" #type "_hs" #head_size "_bs" #block_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention( \ + device float *exp_sums \ + [[buffer(0), function_constant(use_partitioning)]], \ + device float *max_logits \ + [[buffer(1), function_constant(use_partitioning)]], \ + device type *out [[buffer(2)]], device const type *q [[buffer(3)]], \ + device const type *k_cache [[buffer(4)]], \ + device const type *v_cache [[buffer(5)]], \ + const constant int &num_kv_heads [[buffer(6)]], \ + const constant float &scale [[buffer(7)]], \ + const constant float &softcapping [[buffer(8)]], \ + device const uint32_t *block_tables [[buffer(9)]], \ + device const uint32_t *context_lens [[buffer(10)]], \ + const constant int &max_num_blocks_per_seq [[buffer(11)]], \ + device const float *alibi_slopes \ + [[buffer(12), function_constant(use_alibi)]], \ + const constant int &q_stride [[buffer(13)]], \ + const constant int &kv_block_stride [[buffer(14)]], \ + const constant int &kv_head_stride [[buffer(15)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup \ + [[thread_position_in_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_v2_reduce_inner( \ + type, head_size, num_threads, num_simd_lanes, partition_size) \ + template [[host_name("paged_attention_v2_reduce_" #type "_hs" #head_size \ + "_nt" #num_threads "_nsl" #num_simd_lanes \ + "_ps" #partition_size)]] [[kernel]] void \ + paged_attention_v2_reduce( \ + device type * out [[buffer(0)]], \ + const device float *exp_sums [[buffer(1)]], \ + const device float *max_logits [[buffer(2)]], \ + const device type *tmp_out [[buffer(3)]], \ + device uint32_t *context_lens [[buffer(4)]], \ + const constant int &max_num_partitions [[buffer(5)]], \ + threadgroup char *shared_mem [[threadgroup(0)]], \ + uint3 threadgroup_position_in_grid [[threadgroup_position_in_grid]], \ + uint3 threadgroups_per_grid [[threadgroups_per_grid]], \ + uint3 thread_position_in_threadgroup [[thread_position_in_threadgroup]], \ + uint3 threads_per_threadgroup [[threads_per_threadgroup]], \ + uint simd_tid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_paged_attention_heads(type, block_size, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, 64, block_size, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, 80, block_size, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, 96, block_size, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_inner(type, 112, block_size, \ + num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_inner( \ + type, 128, block_size, num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_inner( \ + type, 192, block_size, num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_inner( \ + type, 256, block_size, num_threads, \ + num_simd_lanes, partition_size) + +#define 
instantiate_paged_attention_v2_reduce_heads( \ + type, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner(type, 64, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 80, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 96, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 112, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 128, num_threads, num_simd_lanes, partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 192, num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_v2_reduce_inner( \ + type, 256, num_threads, num_simd_lanes, \ + partition_size) + +#define instantiate_paged_attention_block_size(type, num_threads, \ + num_simd_lanes, partition_size) \ + instantiate_paged_attention_heads(type, 8, num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_heads(type, 16, num_threads, num_simd_lanes, \ + partition_size) \ + instantiate_paged_attention_heads(type, 32, num_threads, \ + num_simd_lanes, partition_size) // TODO: tune num_threads = 256 // NOTE: partition_size = 0 -#define instantiate_paged_attention_v1(type, num_simd_lanes) \ +#define instantiate_paged_attention_v1(type, num_simd_lanes) \ instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 0) // TODO: tune num_threads = 256 // NOTE: partition_size = 512 -#define instantiate_paged_attention_v2(type, num_simd_lanes) \ - instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512) \ - instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, 512) +#define instantiate_paged_attention_v2(type, num_simd_lanes) \ + instantiate_paged_attention_block_size(type, 256, num_simd_lanes, 512) \ + instantiate_paged_attention_v2_reduce_heads(type, 256, num_simd_lanes, \ + 512) instantiate_paged_attention_v1(float, 32) -instantiate_paged_attention_v1(bfloat16_t, 32) -instantiate_paged_attention_v1(half, 32) + instantiate_paged_attention_v1(bfloat16_t, 32) + instantiate_paged_attention_v1(half, 32) -instantiate_paged_attention_v2(float, 32) -instantiate_paged_attention_v2(bfloat16_t, 32) -instantiate_paged_attention_v2(half, 32) + instantiate_paged_attention_v2(float, 32) + instantiate_paged_attention_v2(bfloat16_t, 32) + instantiate_paged_attention_v2(half, 32) diff --git a/mistralrs-paged-attn/src/metal/kernels/reshape_and_cache.metal b/mistralrs-paged-attn/src/metal/kernels/reshape_and_cache.metal index fe0b89f63d..3d12c6563f 100644 --- a/mistralrs-paged-attn/src/metal/kernels/reshape_and_cache.metal +++ b/mistralrs-paged-attn/src/metal/kernels/reshape_and_cache.metal @@ -65,57 +65,49 @@ struct _MLX_BFloat16 { ///////////////////////////////////////////////////////////////////////////// // Conversions to bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) thread : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) device : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, 
- typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) constant : bits_(float_to_bfloat_bits(static_cast(x))) {} ///////////////////////////////////////////////////////////////////////////// // Conversions from bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const thread { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const threadgroup { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const device { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() constant { return static_cast(bfloat_bits_to_float(bits_)); } @@ -133,29 +125,29 @@ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { ///////////////////////////////////////////////////////////////////////////// // Binary operators -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } ///////////////////////////////////////////////////////////////////////////// // Arithmetic Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, float, half, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); bfloat_binop(+, operator+); @@ -165,14 +157,14 @@ bfloat_binop(/, operator/); 
///////////////////////////////////////////////////////////////////////////// // Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, half, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); bfloat_compop(>, operator>); @@ -189,27 +181,27 @@ bfloat_compop(!=, operator!=); ///////////////////////////////////////////////////////////////////////////// // Inplace Operators -#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, itype rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } \ - constexpr METAL_FUNC addr_space itype& __operator__( \ - addr_space itype& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ - bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); -#define bfloat_inplace_op(itype) \ - bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ - bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ - bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ bfloat_inplace_op_addr_space_helper(/, operator/=, itype); bfloat_inplace_op(float); @@ -225,16 +217,16 @@ bfloat_inplace_op(uint64_t); #undef bfloat_inplace_op_addr_space_helper #undef bfloat_inplace_op -#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ 
static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ - bfloat_inplace_op_helper(__op__, __operator__, device); \ - bfloat_inplace_op_helper(__op__, __operator__, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ bfloat_inplace_op_helper(__op__, __operator__, threadgroup); bfloat_inplace_op_addr_space_helper(+, operator+=); @@ -253,29 +245,24 @@ typedef struct _MLX_BFloat16 bfloat16_t; #endif - - - - - - -template +template [[kernel]] void reshape_and_cache( - const device T* __restrict__ key [[buffer(0)]], // [num_tokens, num_heads, head_size] - const device T* __restrict__ value [[buffer(1)]], // [num_tokens, num_heads, head_size] - device T* __restrict__ key_cache [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x] - device T* __restrict__ value_cache [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size] - const device int64_t* __restrict__ slot_mapping [[buffer(4)]], // [num_tokens] - device const int& key_stride, - device const int& value_stride, - device const int& num_heads, - device const int& head_size, - device const int& block_size, - device const int& x, + const device T *__restrict__ key + [[buffer(0)]], // [num_tokens, num_heads, head_size] + const device T *__restrict__ value + [[buffer(1)]], // [num_tokens, num_heads, head_size] + device T *__restrict__ key_cache + [[buffer(2)]], // [num_blocks, num_heads, head_size/x, block_size, x] + device T *__restrict__ value_cache + [[buffer(3)]], // [num_blocks, num_heads, head_size, block_size] + const device int64_t *__restrict__ slot_mapping + [[buffer(4)]], // [num_tokens] + device const int &key_stride, device const int &value_stride, + device const int &num_heads, device const int &head_size, + device const int &block_size, device const int &x, uint gid [[threadgroup_position_in_grid]], uint tid [[thread_position_in_threadgroup]], - uint threads_per_threadgroup [[threads_per_threadgroup]] -) { + uint threads_per_threadgroup [[threads_per_threadgroup]]) { const int64_t token_idx = gid; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -296,38 +283,33 @@ template const int x_idx = head_offset / x; const int x_offset = head_offset % x; - const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x - + head_idx * (head_size / x) * block_size * x - + x_idx * block_size * x - + block_offset * x - + x_offset; - const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size - + head_idx * head_size * block_size - + head_offset * block_size - + block_offset; + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; key_cache[tgt_key_idx] = key[src_key_idx]; value_cache[tgt_value_idx] = value[src_value_idx]; } } -#define instantiate_reshape_and_cache(type) \ - template [[host_name("reshape_and_cache_" 
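reshape_and_cache scatters each (token, head, head_offset) element into the two cache layouts using the tgt_key_idx / tgt_value_idx formulas in the hunk above; the key cache packs the head dimension in chunks of x while the value cache keeps it contiguous per block. A Rust sketch of just that index arithmetic (negative slot_mapping entries, which the kernel skips as padding, are not handled here):

// Destination indices for one element scattered by reshape_and_cache:
//   key_cache:   [num_blocks, num_heads, head_size / x, block_size, x]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
fn cache_target_indices(
    slot_idx: i64,
    head_idx: i64,
    head_offset: i64,
    num_heads: i64,
    head_size: i64,
    block_size: i64,
    x: i64,
) -> (i64, i64) {
    let block_idx = slot_idx / block_size;
    let block_offset = slot_idx % block_size;
    let x_idx = head_offset / x;
    let x_offset = head_offset % x;

    let tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
        + head_idx * (head_size / x) * block_size * x
        + x_idx * block_size * x
        + block_offset * x
        + x_offset;

    let tgt_value_idx = block_idx * num_heads * head_size * block_size
        + head_idx * head_size * block_size
        + head_offset * block_size
        + block_offset;

    (tgt_key_idx, tgt_value_idx)
}

fn main() {
    let (k, v) = cache_target_indices(37, 2, 11, 8, 64, 16, 8);
    println!("key index = {k}, value index = {v}");
}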
#type)]] \ - [[kernel]] void reshape_and_cache( \ - const device type* __restrict__ key [[buffer(0)]], \ - const device type* __restrict__ value [[buffer(1)]], \ - device type* __restrict__ key_cache [[buffer(2)]], \ - device type* __restrict__ value_cache [[buffer(3)]], \ - const device int64_t* __restrict__ slot_mapping [[buffer(4)]], \ - device const int& key_stride, \ - device const int& value_stride, \ - device const int& num_heads, \ - device const int& head_size, \ - device const int& block_size, \ - device const int& x, \ - uint gid [[threadgroup_position_in_grid]], \ - uint tid [[thread_position_in_threadgroup]], \ - uint threads_per_threadgroup [[threads_per_threadgroup]]); - -instantiate_reshape_and_cache(float) -instantiate_reshape_and_cache(bfloat16_t) -instantiate_reshape_and_cache(half) +#define instantiate_reshape_and_cache(type) \ + template [[host_name("reshape_and_cache_" #type)]] [[kernel]] void \ + reshape_and_cache( \ + const device type *__restrict__ key [[buffer(0)]], \ + const device type *__restrict__ value [[buffer(1)]], \ + device type *__restrict__ key_cache [[buffer(2)]], \ + device type *__restrict__ value_cache [[buffer(3)]], \ + const device int64_t *__restrict__ slot_mapping [[buffer(4)]], \ + device const int &key_stride, device const int &value_stride, \ + device const int &num_heads, device const int &head_size, \ + device const int &block_size, device const int &x, \ + uint gid [[threadgroup_position_in_grid]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint threads_per_threadgroup [[threads_per_threadgroup]]); + +instantiate_reshape_and_cache(float) instantiate_reshape_and_cache(bfloat16_t) + instantiate_reshape_and_cache(half) diff --git a/mistralrs-pyo3/Cargo.toml b/mistralrs-pyo3/Cargo.toml index c53214a3b7..9b2a9eb971 100644 --- a/mistralrs-pyo3/Cargo.toml +++ b/mistralrs-pyo3/Cargo.toml @@ -44,3 +44,4 @@ metal = ["candle-core/metal", "mistralrs-core/metal"] flash-attn = ["cuda", "mistralrs-core/flash-attn"] accelerate = ["mistralrs-core/accelerate"] mkl = ["mistralrs-core/mkl"] +nccl = ["mistralrs-core/nccl"] diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 3427951ed5..43e61d8d82 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -44,7 +44,11 @@ static DEVICE: OnceLock> = OnceLock::new(); #[cfg(not(feature = "metal"))] fn get_device(seed: Option) -> &'static Result { DEVICE.get_or_init(|| { - let device = Device::cuda_if_available(0)?; + let device = if cfg!(feature = "nccl") { + Device::Cpu + } else { + Device::cuda_if_available(0)? 
+ }; if let Some(seed) = seed { device.set_seed(seed)?; } @@ -648,7 +652,7 @@ impl Runner { None => DeviceMapSetting::Auto(auto_map_params), }; - let no_paged_attn = if device.is_cuda() { + let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") { no_paged_attn } else if device.is_metal() { !paged_attn diff --git a/mistralrs-quant/build.rs b/mistralrs-quant/build.rs index af651c5e8d..9349789809 100644 --- a/mistralrs-quant/build.rs +++ b/mistralrs-quant/build.rs @@ -108,7 +108,9 @@ fn main() { .arg("--expt-relaxed-constexpr") .arg("--expt-extended-lambda") .arg("--use_fast_math") - .arg("--verbose"); + .arg("--verbose") + .arg("--compiler-options") + .arg("-fPIC"); // https://github.com/EricLBuehler/mistral.rs/issues/286 if let Some(cuda_nvcc_flags_env) = CUDA_NVCC_FLAGS { diff --git a/mistralrs-quant/kernels/bitsandbytes/dequant.cu b/mistralrs-quant/kernels/bitsandbytes/dequant.cu index 7dce07d8d4..2905b41ce7 100644 --- a/mistralrs-quant/kernels/bitsandbytes/dequant.cu +++ b/mistralrs-quant/kernels/bitsandbytes/dequant.cu @@ -3,201 +3,226 @@ #include #include -typedef enum DataType_t -{ - General8bit = 0, - FP4 = 1, - NF4 = 2, +typedef enum DataType_t { + General8bit = 0, + FP4 = 1, + NF4 = 2, } DataType_t; -__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) -{ +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) { float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 111 - return 0.25000000f*absmax*sign; // 1111 + if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 111 + return 0.25000000f * absmax * sign; // 1111 else - return 0.16666667f*absmax*sign; // 1110 + return 0.16666667f * absmax * sign; // 1110 + else if ((val & 0b0001) == 1) // 110 + return 0.50000000f * absmax * sign; // 1101 else - if((val & 0b0001) == 1) // 110 - return 0.50000000f*absmax*sign; // 1101 - else - return 0.33333333f*absmax*sign; // 1100 - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 1.00000000f*absmax*sign; // 1011 - else - return 0.66666667f*absmax*sign; // 1010 + return 0.33333333f * absmax * sign; // 1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 1.00000000f * absmax * sign; // 1011 else - if((val & 0b0001) == 1) // 100 - return 5.208333333e-03f*absmax*sign; // 1001 - else - return 0.00000000f*absmax*sign; // 1000 + return 0.66666667f * absmax * sign; // 1010 + else if ((val & 0b0001) == 1) // 100 + return 5.208333333e-03f * absmax * sign; // 1001 + else + return 0.00000000f * absmax * sign; // 1000 } -__device__ float dDequantizeNF4(unsigned char val) -{ +__device__ float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py - if((val & 0b1000) == 8) - if((val & 0b0100) == 4) // 1 - if((val & 0b0010) == 2) // 11 - if((val & 0b0001) == 1) // 111 + if ((val & 0b1000) == 8) + if ((val & 0b0100) == 4) // 1 + if ((val & 0b0010) == 2) // 11 + if ((val & 0b0001) == 1) // 111 return 1.0f; else return 0.7229568362236023f; + else if ((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; else - if((val & 0b0001) == 1) // 110 - return 0.5626170039176941f; - else - return 0.44070982933044434f; - else - if((val & 0b0010) == 2) //10 - if((val & 0b0001) == 1) // 101 - return 0.33791524171829224f; - else - return 0.24611230194568634f; + return 0.44070982933044434f; + else if ((val & 0b0010) 
== 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; else - if((val & 0b0001) == 1) // 100 - return 0.16093020141124725f; - else - return 0.07958029955625534f; + return 0.24611230194568634f; + else if ((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; - else - if((val & 0b0100) == 4) // 0 - if((val & 0b0010) == 2) //01 - if((val & 0b0001) == 1) // 011 - return 0.0f; - else - return -0.09105003625154495f; + else if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 011 + return 0.0f; else - if((val & 0b0001) == 1) // 010 - return -0.18477343022823334f; - else - return -0.28444138169288635f; + return -0.09105003625154495f; + else if ((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; else - if((val & 0b0010) == 2) //00 - if((val & 0b0001) == 1) // 001 - return -0.39491748809814453f; - else - return -0.5250730514526367f; - else - if((val & 0b0001) == 1) // 000 - return -0.6961928009986877f; - else - return -1.0f; - + return -0.28444138169288635f; + else if ((val & 0b0010) == 2) // 00 + if ((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else if ((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) -{ +template +__global__ void kDequantizeBlockwise(float *code, unsigned char *A, + float *absmax, T *out, const int blocksize, + const int n) { const int n_load = (gridDim.x * TILE_SIZE); int valid_items_load = 0; int valid_items_store = 0; const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; - typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + typedef cub::BlockLoad + LoadChar; + typedef cub::BlockStore 0) ? 2 : 1), + cub::BLOCK_STORE_WARP_TRANSPOSE> + StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) - { - if(DATA_TYPE > 0) - { - valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; - valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; - } - else - { + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + if (DATA_TYPE > 0) { + valid_items_load = + (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; + } else { valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; valid_items_store = n - i > TILE_SIZE ? 
TILE_SIZE : n - i; } - local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + local_abs_max = + __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (blocksize)]); __syncthreads(); LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - switch(DATA_TYPE) - { - case General8bit: - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; - break; - case FP4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); - vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); - } - break; - case NF4: - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - { - vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; - vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; - } - break; + switch (DATA_TYPE) { + case General8bit: +// load code through read-only cache via __ldg +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]]) * local_abs_max; + break; + case FP4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: +#pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; + vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; + } + break; } __syncthreads(); - StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i * 2 : i]), vals, + valid_items_store); } } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream) -{ - int num_blocks = n/blocksize; +template +void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, + int blocksize, const int n, cudaStream_t stream) { + int num_blocks = n / blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; int tile_size = (DATA_TYPE > 0) ? 
1024 : 512; - if(DATA_TYPE > 0) - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n); + if (DATA_TYPE > 0) + kDequantizeBlockwise + <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>( + code, A, absmax, out, blocksize / 2, n); else - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); + kDequantizeBlockwise + <<<(n + tile_size - 1) / tile_size, 64, 0, stream>>>(code, A, absmax, + out, blocksize, n); } -extern "C" void dequantize_blockwise_f32_int8(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f32_int8(float *code, unsigned char *A, + float *absmax, float *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, + stream); } -extern "C" void dequantize_blockwise_f32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f32_fp4(float *code, unsigned char *A, + float *absmax, float *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } -extern "C" void dequantize_blockwise_f32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f32_nf4(float *code, unsigned char *A, + float *absmax, float *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } -extern "C" void dequantize_blockwise_f16_int8(float *code, unsigned char *A, float *absmax, __half *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise<__half, General8bit>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f16_int8(float *code, unsigned char *A, + float *absmax, __half *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__half, General8bit>(code, A, absmax, out, blocksize, n, + stream); } -extern "C" void dequantize_blockwise_f16_fp4(float *code, unsigned char *A, float *absmax, __half *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise<__half, FP4>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f16_fp4(float *code, unsigned char *A, + float *absmax, __half *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__half, FP4>(code, A, absmax, out, blocksize, n, stream); } -extern "C" void dequantize_blockwise_f16_nf4(float *code, unsigned char *A, float *absmax, __half *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise<__half, NF4>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_f16_nf4(float *code, unsigned char *A, + float *absmax, __half *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__half, NF4>(code, A, absmax, out, blocksize, n, stream); } // #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_blockwise_bf16_int8(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream) { - 
dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_bf16_int8(float *code, unsigned char *A, + float *absmax, + __nv_bfloat16 *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, + blocksize, n, stream); } -extern "C" void dequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise<__nv_bfloat16, FP4>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_bf16_fp4(float *code, unsigned char *A, + float *absmax, __nv_bfloat16 *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__nv_bfloat16, FP4>(code, A, absmax, out, blocksize, n, + stream); } -extern "C" void dequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream) { - dequantizeBlockwise<__nv_bfloat16, NF4>(code, A, absmax, out, blocksize, n, stream); +extern "C" void dequantize_blockwise_bf16_nf4(float *code, unsigned char *A, + float *absmax, __nv_bfloat16 *out, + int blocksize, const int n, + cudaStream_t stream) { + dequantizeBlockwise<__nv_bfloat16, NF4>(code, A, absmax, out, blocksize, n, + stream); } // #endif diff --git a/mistralrs-quant/kernels/blockwise_fp8/blockwise_fp8.cu b/mistralrs-quant/kernels/blockwise_fp8/blockwise_fp8.cu index 9f60d74cc8..3ae5a07d17 100644 --- a/mistralrs-quant/kernels/blockwise_fp8/blockwise_fp8.cu +++ b/mistralrs-quant/kernels/blockwise_fp8/blockwise_fp8.cu @@ -1,10 +1,21 @@ #include #include +#include #include #include #include +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", __FILE__, __LINE__, \ + cudaGetErrorString(err)); \ + exit(err); \ + } \ + } while (0) + template __global__ void dequant_fp8_blockwise_kernel( const __nv_fp8_e4m3 *__restrict__ weight, const float *__restrict__ scale, @@ -19,14 +30,6 @@ __global__ void dequant_fp8_blockwise_kernel( int start_y = grid_y * weight_block_size_y; int start_x = grid_x * weight_block_size_x; - // Use threadIdx to cover elements in the tile. - int local_y = threadIdx.y; - int local_x = threadIdx.x; - - // Compute global indices. - int weight_y = start_y + local_y; - int weight_x = start_x + local_x; - // Load the block's scale factor into shared memory. __shared__ float block_scale; if (threadIdx.x == 0 && threadIdx.y == 0) { @@ -34,67 +37,70 @@ __global__ void dequant_fp8_blockwise_kernel( } __syncthreads(); // Ensure all threads see the loaded value. - // Bounds check: if within the dimensions of the weight matrix. - if (weight_y < weight_height && weight_x < weight_width) { - int pos = weight_y * weight_row_stride + weight_x; - float w_val = - __half2float(__nv_cvt_fp8_to_halfraw(weight[pos].__x, __NV_E4M3)); - // Use the shared scale factor. - output[pos] = static_cast(w_val * block_scale); + // Loop over the tile using a fixed blockDim, covering the whole tile. 
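The loop that follows replaces the earlier one-thread-per-element mapping: the launchers below now always pass a fixed dim3(32, 32) thread block, and each thread strides across the tile so arbitrary weight_block_size values are still fully covered. A minimal standalone sketch of the same block-stride tiling pattern; tile_scale_kernel, tile_h, and tile_w are illustrative names, not part of this patch:

#include <cuda_runtime.h>

// Each block owns one (tile_h x tile_w) tile; threads stride by blockDim so
// the whole tile is covered even when it is larger than the thread block.
__global__ void tile_scale_kernel(const float *in, float *out, int height,
                                  int width, int tile_h, int tile_w) {
  const int tile_y = blockIdx.y * tile_h; // top-left corner of this tile
  const int tile_x = blockIdx.x * tile_w;
  for (int ly = threadIdx.y; ly < tile_h; ly += blockDim.y) {
    for (int lx = threadIdx.x; lx < tile_w; lx += blockDim.x) {
      const int y = tile_y + ly;
      const int x = tile_x + lx;
      if (y < height && x < width) // bounds check for partial edge tiles
        out[y * width + x] = 2.0f * in[y * width + x]; // stand-in for the real scaling
    }
  }
}

// Launched the same way as in the hunk below: one block per tile, fixed 32x32 threads.
//   dim3 block(32, 32);
//   dim3 grid((width + tile_w - 1) / tile_w, (height + tile_h - 1) / tile_h);
//   tile_scale_kernel<<<grid, block>>>(d_in, d_out, height, width, tile_h, tile_w);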
+ for (int local_y = threadIdx.y; local_y < weight_block_size_y; + local_y += blockDim.y) { + for (int local_x = threadIdx.x; local_x < weight_block_size_x; + local_x += blockDim.x) { + int weight_y = start_y + local_y; + int weight_x = start_x + local_x; + if (weight_y < weight_height && weight_x < weight_width) { + int pos = weight_y * weight_row_stride + weight_x; + float w_val = + __half2float(__nv_cvt_fp8_to_halfraw(weight[pos].__x, __NV_E4M3)); + output[pos] = static_cast(w_val * block_scale); + } + } } } extern "C" void launch_dequant_fp8_blockwise_kernel_f32( const __nv_fp8_e4m3 *d_weight, const float *d_scale, float *d_output, int weight_height, int weight_width, int weight_row_stride, - int scale_stride, int weight_block_size_y, int weight_block_size_x) { - // Calculate grid dimensions. + int scale_stride, int weight_block_size_y, int weight_block_size_x, + cudaStream_t stream) { int grid_y = (weight_height + weight_block_size_y - 1) / weight_block_size_y; int grid_x = (weight_width + weight_block_size_x - 1) / weight_block_size_x; - - // Set block dimensions to match the block size. - dim3 blockDim(weight_block_size_x, weight_block_size_y); + dim3 blockDim(32, 32); dim3 gridDim(grid_x, grid_y); - dequant_fp8_blockwise_kernel - <<>>(d_weight, d_scale, d_output, weight_height, - weight_width, weight_row_stride, scale_stride, - weight_block_size_y, weight_block_size_x); + dequant_fp8_blockwise_kernel<<>>( + d_weight, d_scale, d_output, weight_height, weight_width, + weight_row_stride, scale_stride, weight_block_size_y, + weight_block_size_x); + CUDA_CHECK(cudaGetLastError()); } extern "C" void launch_dequant_fp8_blockwise_kernel_f16( const __nv_fp8_e4m3 *d_weight, const float *d_scale, __half *d_output, int weight_height, int weight_width, int weight_row_stride, - int scale_stride, int weight_block_size_y, int weight_block_size_x) { - // Calculate grid dimensions. + int scale_stride, int weight_block_size_y, int weight_block_size_x, + cudaStream_t stream) { int grid_y = (weight_height + weight_block_size_y - 1) / weight_block_size_y; int grid_x = (weight_width + weight_block_size_x - 1) / weight_block_size_x; - - // Set block dimensions to match the block size. - dim3 blockDim(weight_block_size_x, weight_block_size_y); + dim3 blockDim(32, 32); dim3 gridDim(grid_x, grid_y); - dequant_fp8_blockwise_kernel<__half> - <<>>(d_weight, d_scale, d_output, weight_height, - weight_width, weight_row_stride, scale_stride, - weight_block_size_y, weight_block_size_x); + dequant_fp8_blockwise_kernel<__half><<>>( + d_weight, d_scale, d_output, weight_height, weight_width, + weight_row_stride, scale_stride, weight_block_size_y, + weight_block_size_x); + CUDA_CHECK(cudaGetLastError()); } extern "C" void launch_dequant_fp8_blockwise_kernel_bf16( const __nv_fp8_e4m3 *d_weight, const float *d_scale, __nv_bfloat16 *d_output, int weight_height, int weight_width, int weight_row_stride, int scale_stride, int weight_block_size_y, - int weight_block_size_x) { - // Calculate grid dimensions. + int weight_block_size_x, cudaStream_t stream) { int grid_y = (weight_height + weight_block_size_y - 1) / weight_block_size_y; int grid_x = (weight_width + weight_block_size_x - 1) / weight_block_size_x; - - // Set block dimensions to match the block size. 
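These launchers now take an explicit cudaStream_t and surface launch failures through the new CUDA_CHECK(cudaGetLastError()) call. A hedged host-side sketch of driving the f32 variant; the shapes, the assumption of one float scale per weight block stored row-major, and the uninitialized device buffers are all illustrative, not taken from this patch:

#include <cuda_fp8.h>
#include <cuda_runtime.h>

// Declaration matching the launcher defined in this file.
extern "C" void launch_dequant_fp8_blockwise_kernel_f32(
    const __nv_fp8_e4m3 *d_weight, const float *d_scale, float *d_output,
    int weight_height, int weight_width, int weight_row_stride,
    int scale_stride, int weight_block_size_y, int weight_block_size_x,
    cudaStream_t stream);

int main() {
  const int h = 256, w = 512, block = 128; // illustrative shapes
  __nv_fp8_e4m3 *d_weight;
  float *d_scale, *d_out;
  cudaMalloc(&d_weight, h * w * sizeof(__nv_fp8_e4m3));
  cudaMalloc(&d_scale, (h / block) * (w / block) * sizeof(float)); // one scale per tile (assumption)
  cudaMalloc(&d_out, h * w * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Row stride equals the width for a densely packed weight matrix; the
  // scale_stride of w / block assumes the per-tile scales are laid out row-major.
  launch_dequant_fp8_blockwise_kernel_f32(d_weight, d_scale, d_out, h, w,
                                          /*weight_row_stride=*/w,
                                          /*scale_stride=*/w / block,
                                          /*weight_block_size_y=*/block,
                                          /*weight_block_size_x=*/block, stream);
  cudaStreamSynchronize(stream); // launch errors are already reported via CUDA_CHECK

  cudaFree(d_weight);
  cudaFree(d_scale);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
  return 0;
}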
- dim3 blockDim(weight_block_size_x, weight_block_size_y); + dim3 blockDim(32, 32); dim3 gridDim(grid_x, grid_y); - dequant_fp8_blockwise_kernel<__nv_bfloat16> - <<>>(d_weight, d_scale, d_output, weight_height, - weight_width, weight_row_stride, scale_stride, - weight_block_size_y, weight_block_size_x); + dequant_fp8_blockwise_kernel<__nv_bfloat16><<>>( + d_weight, d_scale, d_output, weight_height, weight_width, + weight_row_stride, scale_stride, weight_block_size_y, + weight_block_size_x); + CUDA_CHECK(cudaGetLastError()); } \ No newline at end of file diff --git a/mistralrs-quant/kernels/gptq/q_gemm.cu b/mistralrs-quant/kernels/gptq/q_gemm.cu index 38b1a637dc..43fabeb737 100644 --- a/mistralrs-quant/kernels/gptq/q_gemm.cu +++ b/mistralrs-quant/kernels/gptq/q_gemm.cu @@ -6,8 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include #include +#include #include "compat.cuh" #include "matrix_view.cuh" @@ -27,110 +27,118 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) #if defined(USE_ROCM) - #include +#include __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm( hipblasHandle_t handle, hipblasOperation_t transA, - hipblasOperation_t transB, int m, int n, int k, const half* alpha, - const half* AP, int lda, const half* BP, int ldb, const half* beta, - half* CP, int ldc) { + hipblasOperation_t transB, int m, int n, int k, const half *alpha, + const half *AP, int lda, const half *BP, int ldb, const half *beta, + half *CP, int ldc) { return hipblasHgemm(handle, transA, transB, m, n, k, - reinterpret_cast(alpha), - reinterpret_cast(AP), lda, - reinterpret_cast(BP), ldb, - reinterpret_cast(beta), - reinterpret_cast(CP), ldc); + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); } - #define hipblasHgemm __compat_hipblasHgemm +#define hipblasHgemm __compat_hipblasHgemm - // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. - #define rocblas_operation_none HIPBLAS_OP_N - #define rocblas_hgemm __compat_hipblasHgemm +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. 
+#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm #endif -__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, const half2 g_result) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); return __hadd2(result, g_result); } -__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr) { +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); return __half2float(__low2half(result)) + __half2float(__high2half(result)); } -__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half* a_ptr, +__forceinline__ __device__ half2 dot22_8(half2 (&dq)[4], const half *a_ptr, const half2 g_result, const half qs_h) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half* a_ptr, +__forceinline__ __device__ half2 dot22_16(half2 (&dq)[8], const half *a_ptr, const half2 g_result, const half qs_h) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half* a_ptr, +__forceinline__ __device__ half2 dot22_32(half2 (&dq)[16], const half *a_ptr, const half2 g_result, const half qs_h) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); return __hfma2(result, __halves2half2(qs_h, qs_h), g_result); } -__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half* a_ptr, +__forceinline__ __device__ float dot22_8_f(half2 (&dq)[4], const half *a_ptr, const float g_result, const float qs_f) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 4; i++) + result = __hfma2(dq[i], *a2_ptr++, result); float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half* a_ptr, +__forceinline__ __device__ float dot22_16_f(half2 (&dq)[8], const half *a_ptr, const float g_result, const float qs_f) { half2 result = {}; - const 
half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half* a_ptr, +__forceinline__ __device__ float dot22_32_f(half2 (&dq)[16], const half *a_ptr, const float g_result, const float qs_f) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result)); return fma(result_f, qs_f, g_result); } -__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, +__forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half *a_ptr, const half g_result, const half qs_h) { // Use FP32 accumulator to avoid potential overflow since unscaled weights are @@ -153,41 +161,43 @@ __forceinline__ __device__ half dot22_8_h(half2 (&dq)[4], const half* a_ptr, return __hadd(result_h, g_result); } -__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half* a_ptr, +__forceinline__ __device__ half dot22_16_h(half2 (&dq)[8], const half *a_ptr, const half g_result, const half qs_h) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 8; i++) + result = __hfma2(dq[i], *a2_ptr++, result); half result_h = __hadd(__low2half(result), __high2half(result)); return __hfma(result_h, qs_h, g_result); } -__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half* a_ptr, +__forceinline__ __device__ half dot22_32_h(half2 (&dq)[16], const half *a_ptr, const half g_result, const half qs_h) { half2 result = {}; - const half2* a2_ptr = (const half2*)a_ptr; + const half2 *a2_ptr = (const half2 *)a_ptr; #pragma unroll - for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result); + for (int i = 0; i < 16; i += 1) + result = __hfma2(dq[i], *a2_ptr++, result); half result_h = __hadd(__low2half(result), __high2half(result)); return __hfma(result_h, qs_h, g_result); } -typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half*, const uint32_t*, - const uint32_t*, const half*, - half*, const int, const int, +typedef void (*fp_gemm_half_q_half_gptq_kernel)(const half *, const uint32_t *, + const uint32_t *, const half *, + half *, const int, const int, const int, const int, - const int*); + const int *); template __global__ void gemm_half_q_half_gptq_4bit_kernel( - const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const int *__restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, 
size_m, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); @@ -211,8 +221,8 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( if (offset_k + t < end_k) { for (int m = 0; m < m_count; ++m) { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; half a0; if (b_q_perm) @@ -224,11 +234,12 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( } // Zero output - if (n >= size_n) return; + if (n >= size_n) + return; if (blockIdx.z == 0) { for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; } __syncthreads(); @@ -241,8 +252,8 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( // a, b offset int qk = offset_k / (32 / 4); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; int a_stride = BLOCK_KN_SIZE; // Initial group @@ -276,7 +287,7 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( #pragma unroll for (int j = 0; j < 4; j++) { - const int4* b_ptr4 = (int4*)b_ptr; + const int4 *b_ptr4 = (int4 *)b_ptr; int4 load_int4 = *b_ptr4; half2 dq[4][4]; @@ -309,7 +320,7 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( } for (int m = 0; m < m_count; m++) { - half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), @@ -321,11 +332,11 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( template __global__ void gemm_half_q_half_gptq_2bit_kernel( - const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const int *__restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); @@ -349,8 +360,8 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( if (offset_k + t < end_k) { for (int m = 0; m < m_count; ++m) { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; half a0; if (b_q_perm) @@ -362,11 +373,12 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( } // Zero output - if (n >= size_n) return; + if (n >= size_n) + return; if (blockIdx.z == 0) { for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; } __syncthreads(); @@ -379,8 +391,8 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( // a, b offset int qk = offset_k / (32 / 2); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; int a_stride = BLOCK_KN_SIZE; // Initial group @@ -403,7 +415,7 @@ __global__ void 
gemm_half_q_half_gptq_2bit_kernel( #pragma unroll for (int j = 0; j < 1; j++) { - const int4* b_ptr4 = (int4*)b_ptr; + const int4 *b_ptr4 = (int4 *)b_ptr; int4 load_int4 = *b_ptr4; half2 dq[4][8]; @@ -432,7 +444,7 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( } for (int m = 0; m < m_count; m++) { - half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); atomicAdd(out, result01); @@ -442,11 +454,11 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( template __global__ void gemm_half_q_half_gptq_3bit_kernel( - const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const int *__restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); @@ -470,8 +482,8 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( if (offset_k + t < end_k) { for (int m = 0; m < m_count; ++m) { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; half a0; if (b_q_perm) @@ -483,11 +495,12 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( } // Zero output - if (n >= size_n) return; + if (n >= size_n) + return; if (blockIdx.z == 0) { for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; } __syncthreads(); @@ -500,8 +513,8 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( // a, b offset int qk = offset_k / 32 * 3; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; int a_stride = BLOCK_KN_SIZE; // Initial group @@ -525,11 +538,11 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( #pragma unroll for (int j = 0; j < 1; j++) { int4 load_int4[3]; - load_int4[0] = *((int4*)b_ptr); + load_int4[0] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*)b_ptr); + load_int4[1] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*)b_ptr); + load_int4[2] = *((int4 *)b_ptr); b_ptr += size_n; half2 dq[4][16]; @@ -560,7 +573,7 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( } for (int m = 0; m < m_count; m++) { - half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); atomicAdd(out, result01); @@ -570,11 +583,11 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( template __global__ void gemm_half_q_half_gptq_8bit_kernel( - const half* __restrict__ a, const uint32_t* __restrict__ b_q_weight, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, half* __restrict__ c, + const half *__restrict__ a, const uint32_t *__restrict__ b_q_weight, + const uint32_t *__restrict__ 
b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, half *__restrict__ c, const int size_m, const int size_n, const int size_k, const int groups, - const int* __restrict__ b_q_perm) { + const int *__restrict__ b_q_perm) { MatrixView_half a_(a, size_m, size_k); MatrixView_half_rw c_(c, size_m, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); @@ -598,8 +611,8 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( if (offset_k + t < end_k) { for (int m = 0; m < m_count; ++m) { - const half* a_ptr = a_.item_ptr(offset_m + m, 0); - half* block_a_ptr = block_a[m]; + const half *a_ptr = a_.item_ptr(offset_m + m, 0); + half *block_a_ptr = block_a[m]; half a0; if (b_q_perm) @@ -611,11 +624,12 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( } // Zero output - if (n >= size_n) return; + if (n >= size_n) + return; if (blockIdx.z == 0) { for (int m = 0; m < m_count; m++) - *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + *((uint64_t *)c_.item_ptr(offset_m + m, n)) = 0; } __syncthreads(); @@ -628,8 +642,8 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( // a, b offset int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; - const half* a_ptr = &block_a[0][0]; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; + const half *a_ptr = &block_a[0][0]; int a_stride = BLOCK_KN_SIZE; // Initial group @@ -653,9 +667,9 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( #pragma unroll for (int j = 0; j < 4; j++) { int4 load_int4[2]; - load_int4[0] = *((int4*)b_ptr); + load_int4[0] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*)b_ptr); + load_int4[1] = *((int4 *)b_ptr); b_ptr += size_n; half2 dq[4][4]; @@ -684,7 +698,7 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( } for (int m = 0; m < m_count; m++) { - half2* out = (half2*)c_.item_ptr(offset_m + m, n); + half2 *out = (half2 *)c_.item_ptr(offset_m + m, n); half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]); half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]); atomicAdd(out, result01); @@ -692,14 +706,19 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( } } -fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( - bool first_block, const int m_count, const int bit) { -#define SELECT_KERNEL(M_COUNT) \ - if (m_count == M_COUNT) { \ - if (bit == 2) return gemm_half_q_half_gptq_2bit_kernel; \ - if (bit == 3) return gemm_half_q_half_gptq_3bit_kernel; \ - if (bit == 4) return gemm_half_q_half_gptq_4bit_kernel; \ - if (bit == 8) return gemm_half_q_half_gptq_8bit_kernel; \ +fp_gemm_half_q_half_gptq_kernel +pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count, + const int bit) { +#define SELECT_KERNEL(M_COUNT) \ + if (m_count == M_COUNT) { \ + if (bit == 2) \ + return gemm_half_q_half_gptq_2bit_kernel; \ + if (bit == 3) \ + return gemm_half_q_half_gptq_3bit_kernel; \ + if (bit == 4) \ + return gemm_half_q_half_gptq_4bit_kernel; \ + if (bit == 8) \ + return gemm_half_q_half_gptq_8bit_kernel; \ } #if BLOCK_M_SIZE_MAX >= 1 SELECT_KERNEL(1); @@ -728,11 +747,10 @@ fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel( return NULL; } -extern "C" void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, const int* b_q_perm, - half* c, int size_m, int size_n, int size_k, - int m_count, int groups, int bit) { +extern "C" void gemm_half_q_half_cuda_part( + const half *a, const uint32_t *b_q_weight, const uint32_t *b_gptq_qzeros, + const 
half *b_gptq_scales, const int *b_q_perm, half *c, int size_m, + int size_n, int size_k, int m_count, int groups, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -744,16 +762,15 @@ extern "C" void gemm_half_q_half_cuda_part(const half* a, const uint32_t* b_q_we fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count, bit); - kernel<<>>(a, b_q_weight, b_gptq_qzeros, - b_gptq_scales, c, size_m, size_n, - size_k, groups, b_q_perm); + kernel<<>>(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + c, size_m, size_n, size_k, groups, b_q_perm); } __global__ void reconstruct_exllama_8bit_kernel( - const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q8_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -768,12 +785,14 @@ __global__ void reconstruct_exllama_8bit_kernel( int t = threadIdx.x; if (b_q_perm) { - if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; } // Column int n = offset_n + t * 4; - if (n >= size_n) return; + if (n >= size_n) + return; // Find initial group int groupsize = size_k / groups; @@ -783,7 +802,7 @@ __global__ void reconstruct_exllama_8bit_kernel( // b offset int qk = offset_k / (32 / 8); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; // Initial zeros/scale int zeros[4]; @@ -806,9 +825,9 @@ __global__ void reconstruct_exllama_8bit_kernel( for (int p = 0; p < 4; p++) { int4 load_int4[2]; - load_int4[0] = *((int4*)b_ptr); + load_int4[0] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*)b_ptr); + load_int4[1] = *((int4 *)b_ptr); b_ptr += size_n; half2 dq[4][4]; @@ -824,7 +843,8 @@ __global__ void reconstruct_exllama_8bit_kernel( // half* dqh = (half*)dq; if (b_q_perm) { for (int j = 0; j < 4; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), @@ -832,7 +852,8 @@ __global__ void reconstruct_exllama_8bit_kernel( } } else { for (int j = 0; j < 4; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); @@ -847,10 +868,10 @@ __global__ void reconstruct_exllama_8bit_kernel( } __global__ void reconstruct_exllama_4bit_kernel( - const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + 
const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -865,12 +886,14 @@ __global__ void reconstruct_exllama_4bit_kernel( int t = threadIdx.x; if (b_q_perm) { - if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; } // Column int n = offset_n + t * 4; - if (n >= size_n) return; + if (n >= size_n) + return; // Find initial group int groupsize = size_k / groups; @@ -880,7 +903,7 @@ __global__ void reconstruct_exllama_4bit_kernel( // b offset int qk = offset_k / (32 / 4); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; // Initial zeros/scale int zeros[4]; @@ -913,7 +936,7 @@ __global__ void reconstruct_exllama_4bit_kernel( for (int p = 0; p < 4; p++) { half2 dq[4][4]; - const int4* b_ptr4 = (int4*)b_ptr; + const int4 *b_ptr4 = (int4 *)b_ptr; int4 load_int4 = *b_ptr4; dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, @@ -929,7 +952,8 @@ __global__ void reconstruct_exllama_4bit_kernel( // half* dqh = (half*)dq; if (b_q_perm) { for (int j = 0; j < 4; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), @@ -937,7 +961,8 @@ __global__ void reconstruct_exllama_4bit_kernel( } } else { for (int j = 0; j < 4; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); @@ -952,10 +977,10 @@ __global__ void reconstruct_exllama_4bit_kernel( } __global__ void reconstruct_exllama_3bit_kernel( - const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q3_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -970,12 +995,14 @@ __global__ void reconstruct_exllama_3bit_kernel( int t = threadIdx.x; if (b_q_perm) { - if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; } // Column int n = offset_n + t * 4; - if (n >= size_n) return; + if (n >= size_n) + return; // Find initial group int groupsize = size_k / groups; @@ -985,7 +1012,7 @@ __global__ void reconstruct_exllama_3bit_kernel( // b offset int qk = offset_k / 32 * 3; - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; // Initial zeros/scale int zeros[4]; @@ -1008,11 +1035,11 @@ 
__global__ void reconstruct_exllama_3bit_kernel( for (int p = 0; p < 1; p++) { int4 load_int4[3]; - load_int4[0] = *((int4*)b_ptr); + load_int4[0] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[1] = *((int4*)b_ptr); + load_int4[1] = *((int4 *)b_ptr); b_ptr += size_n; - load_int4[2] = *((int4*)b_ptr); + load_int4[2] = *((int4 *)b_ptr); b_ptr += size_n; half2 dq[4][16]; @@ -1027,7 +1054,8 @@ __global__ void reconstruct_exllama_3bit_kernel( if (b_q_perm) { for (int j = 0; j < 16; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), @@ -1035,7 +1063,8 @@ __global__ void reconstruct_exllama_3bit_kernel( } } else { for (int j = 0; j < 16; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); @@ -1050,10 +1079,10 @@ __global__ void reconstruct_exllama_3bit_kernel( } __global__ void reconstruct_exllama_2bit_kernel( - const uint32_t* __restrict__ b_q_weight, const int* __restrict__ b_q_perm, - const uint32_t* __restrict__ b_gptq_qzeros, - const half* __restrict__ b_gptq_scales, const int size_k, const int size_n, - const int groups, half* __restrict__ b) { + const uint32_t *__restrict__ b_q_weight, const int *__restrict__ b_q_perm, + const uint32_t *__restrict__ b_gptq_qzeros, + const half *__restrict__ b_gptq_scales, const int size_k, const int size_n, + const int groups, half *__restrict__ b) { MatrixView_half_rw b_(b, size_k, size_n); MatrixView_q2_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); @@ -1068,12 +1097,14 @@ __global__ void reconstruct_exllama_2bit_kernel( int t = threadIdx.x; if (b_q_perm) { - if (offset_k + t < size_k) perm[t] = b_q_perm[offset_k + t]; + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; } // Column int n = offset_n + t * 4; - if (n >= size_n) return; + if (n >= size_n) + return; // Find initial group int groupsize = size_k / groups; @@ -1083,7 +1114,7 @@ __global__ void reconstruct_exllama_2bit_kernel( // b offset int qk = offset_k / (32 / 2); - const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const uint32_t *b_ptr = b_q_weight + qk * size_n + n; // Initial zeros/scale int zeros[4]; @@ -1105,7 +1136,7 @@ __global__ void reconstruct_exllama_2bit_kernel( } for (int p = 0; p < 2; p++) { - const int4* b_ptr4 = (int4*)b_ptr; + const int4 *b_ptr4 = (int4 *)b_ptr; int4 load_int4 = *b_ptr4; half2 dq[4][8]; @@ -1118,7 +1149,8 @@ __global__ void reconstruct_exllama_2bit_kernel( // half* dqh = (half*)dq; if (b_q_perm) { for (int j = 0; j < 8; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = __hmul2(scales[v], dq[v][j]); b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), @@ -1126,7 +1158,8 @@ __global__ void reconstruct_exllama_2bit_kernel( } } else { for (int j = 0; j < 8; j++) { - for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + for (int v = 0; v < 4; v++) + dq[v][j] = 
__hmul2(scales[v], dq[v][j]); b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); @@ -1140,11 +1173,11 @@ __global__ void reconstruct_exllama_2bit_kernel( } } -extern "C" void reconstruct_exllama(const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, const int* b_q_perm, - half* out, int height, int width, int groups, - int bit) { +extern "C" void reconstruct_exllama(const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, + const int *b_q_perm, half *out, int height, + int width, int groups, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1166,9 +1199,9 @@ extern "C" void reconstruct_exllama(const uint32_t* b_q_weight, } __global__ void gemm_half_q_half_alt_4bit_kernel( - const half2* __restrict__ vec, const uint32_t* __restrict__ mat, - half* __restrict__ mul, const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + const half2 *__restrict__ vec, const uint32_t *__restrict__ mat, + half *__restrict__ mul, const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx, int batch, int height, int width) { int zero_width = width / 8; int vec_height = height * 4; @@ -1197,7 +1230,8 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( } if (blockIdx.z == 0) { - for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); } __syncthreads(); @@ -1265,9 +1299,9 @@ __global__ void gemm_half_q_half_alt_4bit_kernel( } __global__ void gemm_half_q_half_alt_8bit_kernel( - const half2* __restrict__ vec, const uint32_t* __restrict__ mat, - half* __restrict__ mul, const half* __restrict__ scales, - const uint32_t* __restrict__ zeros, const int* __restrict__ g_idx, + const half2 *__restrict__ vec, const uint32_t *__restrict__ mat, + half *__restrict__ mul, const half *__restrict__ scales, + const uint32_t *__restrict__ zeros, const int *__restrict__ g_idx, int batch, int height, int width) { int zero_width = width / 4; int vec_height = height * 2; @@ -1288,7 +1322,8 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( } if (blockIdx.z == 0) { - for (int m = 0; m < b_end; m++) mul[(b + m) * width + w] = __int2half_rn(0); + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); } __syncthreads(); @@ -1351,11 +1386,11 @@ __global__ void gemm_half_q_half_alt_8bit_kernel( } } -extern "C" void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, - const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, const int* b_g_idx, - half* c, int size_m, int size_n, int size_k, - int bit) { +extern "C" void gemm_half_q_half_alt(const half *a, const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, + const int *b_g_idx, half *c, int size_m, + int size_n, int size_k, int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1369,24 +1404,25 @@ extern "C" void gemm_half_q_half_alt(const half* a, const uint32_t* b_q_weight, kernel = gemm_half_q_half_alt_8bit_kernel; } - kernel<<>>( - (const half2*)a, b_q_weight, c, b_gptq_scales, b_gptq_qzeros, b_g_idx, - size_m, size_k / 32 * bit, size_n); + kernel<<>>((const half2 *)a, b_q_weight, c, + b_gptq_scales, b_gptq_qzeros, b_g_idx, + size_m, size_k / 32 * bit, size_n); } template -__global__ void reconstruct_gptq_kernel(const uint32_t* 
__restrict__ w, - const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, - const int* __restrict__ g_idx, +__global__ void reconstruct_gptq_kernel(const uint32_t *__restrict__ w, + const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, + const int *__restrict__ g_idx, const int height, const int width, const int group, - half* __restrict__ out) { + half *__restrict__ out) { // Start of block int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; int row = blockIdx.y * 32 / bit; - if (column >= width) return; + if (column >= width) + return; // Views @@ -1395,7 +1431,7 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, T w_zeros_(w_zeros, group, width); uint32_t w_read = w[blockIdx.y * width + column]; - half* out_ptr = out_.item_ptr(row, column); + half *out_ptr = out_.item_ptr(row, column); #pragma unroll for (int s = 0; s < 32; s += bit) { @@ -1411,14 +1447,15 @@ __global__ void reconstruct_gptq_kernel(const uint32_t* __restrict__ w, } __global__ void reconstruct_gptq_3bit_kernel( - const uint32_t* __restrict__ w, const half* __restrict__ w_scales, - const uint32_t* __restrict__ w_zeros, const int* __restrict__ g_idx, + const uint32_t *__restrict__ w, const half *__restrict__ w_scales, + const uint32_t *__restrict__ w_zeros, const int *__restrict__ g_idx, const int height, const int width, const int group, - half* __restrict__ out) { + half *__restrict__ out) { // Start of block int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; int row = blockIdx.y * 32; - if (column >= width) return; + if (column >= width) + return; // Views @@ -1429,7 +1466,7 @@ __global__ void reconstruct_gptq_3bit_kernel( uint32_t w1 = w[(blockIdx.y * 3) * width + column]; uint32_t w2 = w[(blockIdx.y * 3 + 1) * width + column]; uint32_t w3 = w[(blockIdx.y * 3 + 2) * width + column]; - half* out_ptr = out_.item_ptr(row, column); + half *out_ptr = out_.item_ptr(row, column); #pragma unroll for (int i = 0; i < 32; i += 1) { @@ -1453,9 +1490,11 @@ __global__ void reconstruct_gptq_3bit_kernel( } } -extern "C" void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_gptq_qzeros, - const half* b_gptq_scales, const int* b_g_idx, half* out, - int height, int width, int groups, int bit) { +extern "C" void reconstruct_gptq(const uint32_t *b_q_weight, + const uint32_t *b_gptq_qzeros, + const half *b_gptq_scales, const int *b_g_idx, + half *out, int height, int width, int groups, + int bit) { dim3 blockDim, gridDim; blockDim.x = BLOCK_KN_SIZE; blockDim.y = 1; @@ -1472,9 +1511,8 @@ extern "C" void reconstruct_gptq(const uint32_t* b_q_weight, const uint32_t* b_g gridDim.y = DIVIDE(height, 32); } - kernel<<>>(b_q_weight, b_gptq_scales, - b_gptq_qzeros, b_g_idx, height, - width, groups, out); + kernel<<>>(b_q_weight, b_gptq_scales, b_gptq_qzeros, + b_g_idx, height, width, groups, out); } /* @@ -1532,12 +1570,13 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a, } */ -__global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, +__global__ void shuffle_4bit_kernel(uint32_t *__restrict__ b_q_weight, const int size_k, const int size_n) { int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; + if (n >= size_n) + return; int k = 0; - uint32_t* b_ptr = b_q_weight + n; + uint32_t *b_ptr = b_q_weight + n; while (k < size_k) { shuffle_4bit_8(b_ptr, size_n); b_ptr += 1 * size_n; @@ -1545,12 +1584,13 @@ __global__ void shuffle_4bit_kernel(uint32_t* __restrict__ b_q_weight, } } -__global__ void 
shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, +__global__ void shuffle_8bit_kernel(uint32_t *__restrict__ b_q_weight, const int size_k, const int size_n) { int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; + if (n >= size_n) + return; int k = 0; - uint32_t* b_ptr = b_q_weight + n; + uint32_t *b_ptr = b_q_weight + n; while (k < size_k) { shuffle_8bit_4(b_ptr, size_n); b_ptr += 1 * size_n; @@ -1558,12 +1598,13 @@ __global__ void shuffle_8bit_kernel(uint32_t* __restrict__ b_q_weight, } } -__global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, +__global__ void shuffle_2bit_kernel(uint32_t *__restrict__ b_q_weight, const int size_k, const int size_n) { int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; + if (n >= size_n) + return; int k = 0; - uint32_t* b_ptr = b_q_weight + n; + uint32_t *b_ptr = b_q_weight + n; while (k < size_k) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; @@ -1571,12 +1612,13 @@ __global__ void shuffle_2bit_kernel(uint32_t* __restrict__ b_q_weight, } } -__global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, +__global__ void shuffle_3bit_kernel(uint32_t *__restrict__ b_q_weight, const int size_k, const int size_n) { int n = blockIdx.x * THREADS_X + threadIdx.x; - if (n >= size_n) return; + if (n >= size_n) + return; int k = 0; - uint32_t* b_ptr = b_q_weight + n; + uint32_t *b_ptr = b_q_weight + n; while (k < size_k) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; @@ -1584,15 +1626,16 @@ __global__ void shuffle_3bit_kernel(uint32_t* __restrict__ b_q_weight, } } -__global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, +__global__ void make_sequential_4bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, const int w_width) { - const uint64_t* w2 = (uint64_t*)w; - uint64_t* w_new2 = (uint64_t*)w_new; + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; int w2_stride = w_width >> 1; int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; + if (w2_column >= w2_stride) + return; int w_new2_row = blockIdx.y; int q_perm_idx = w_new2_row << 3; uint64_t dst = 0; @@ -1615,15 +1658,16 @@ __global__ void make_sequential_4bit_kernel(const uint32_t* __restrict__ w, w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, +__global__ void make_sequential_2bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, const int w_width) { - const uint64_t* w2 = (uint64_t*)w; - uint64_t* w_new2 = (uint64_t*)w_new; + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; int w2_stride = w_width >> 1; int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; + if (w2_column >= w2_stride) + return; int w_new2_row = blockIdx.y; int q_perm_idx = w_new2_row << 4; uint64_t dst = 0; @@ -1646,12 +1690,13 @@ __global__ void make_sequential_2bit_kernel(const uint32_t* __restrict__ w, w_new2[w_new2_row * w2_stride + w2_column] = dst; } -__global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, +__global__ void make_sequential_3bit_kernel(const uint32_t *__restrict__ w, + uint32_t 
*__restrict__ w_new, + const int *__restrict__ q_perm, const int w_width) { int w_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w_column >= w_width) return; + if (w_column >= w_width) + return; int w_new_row = blockIdx.y * 3; int q_perm_idx = blockIdx.y << 5; uint32_t dst[3] = {0, 0, 0}; @@ -1729,15 +1774,16 @@ __global__ void make_sequential_3bit_kernel(const uint32_t* __restrict__ w, w_new[(w_new_row + 2) * w_width + w_column] = dst[2]; } -__global__ void make_sequential_8bit_kernel(const uint32_t* __restrict__ w, - uint32_t* __restrict__ w_new, - const int* __restrict__ q_perm, +__global__ void make_sequential_8bit_kernel(const uint32_t *__restrict__ w, + uint32_t *__restrict__ w_new, + const int *__restrict__ q_perm, const int w_width) { - const uint64_t* w2 = (uint64_t*)w; - uint64_t* w_new2 = (uint64_t*)w_new; + const uint64_t *w2 = (uint64_t *)w; + uint64_t *w_new2 = (uint64_t *)w_new; int w2_stride = w_width >> 1; int w2_column = THREADS_X * blockIdx.x + threadIdx.x; - if (w2_column >= w2_stride) return; + if (w2_column >= w2_stride) + return; int w_new2_row = blockIdx.y; int q_perm_idx = w_new2_row << 2; uint64_t dst = 0; diff --git a/mistralrs-quant/kernels/hqq/hqq.cu b/mistralrs-quant/kernels/hqq/hqq.cu index 2efd7ac858..f662c849e7 100644 --- a/mistralrs-quant/kernels/hqq/hqq.cu +++ b/mistralrs-quant/kernels/hqq/hqq.cu @@ -1,9 +1,9 @@ // https://github.com/mobiusml/hqq/blob/master/hqq/kernels/hqq_aten_cuda_kernel.cu +#include #include #include #include -#include #if __CUDA_ARCH__ >= 530 #include "cuda_fp16.h" @@ -12,103 +12,142 @@ #include "cuda_bf16.h" #endif -inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b;} -#define BLOCK_SIZE 256 //~256 +inline unsigned int cdiv(unsigned int a, unsigned int b) { + return (a + b - 1) / b; +} +#define BLOCK_SIZE 256 //~256 #define SHARED_SIZE 512 //~512 /*******************************************************************************************************************************************/ /************* 8-bit *************/ /*******************************************************************************************************************************************/ -//Simple +// Simple template -__global__ void dequantize_8bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - int n = h*w; - if(i>=n) return; - - int j = i % w; - W_r[i] = ((T)(Wq_packed[i]) - zero[j])*scale[j]; +__global__ void dequantize_8bit_u8_kernel(unsigned char *Wq_packed, T *scale, + T *zero, T *W_r, int h, int w) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int n = h * w; + if (i >= n) + return; + + int j = i % w; + W_r[i] = ((T)(Wq_packed[i]) - zero[j]) * scale[j]; } -extern "C" void dequantize_8bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_8bit_u8_kernel_f32(unsigned char *Wq_packed, + float *scale, float *zero, + float *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #if __CUDA_ARCH__ >= 530 -extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void 
dequantize_8bit_u8_kernel_f16(unsigned char *Wq_packed, + __half *scale, __half *zero, + __half *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_8bit_u8_kernel_f16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char *Wq_packed, + __nv_bfloat16 *scale, + __nv_bfloat16 *zero, + __nv_bfloat16 *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_8bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_8bit_u8_kernel_bf16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif - /*******************************************************************************************************************************************/ /************* 4-bit *************/ /*******************************************************************************************************************************************/ -//Simple -/*__global__ void unpack_4bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - if(i>=n) return; +// Simple +/*__global__ void unpack_4bit_u8_kernel(unsigned char* Wq_packed, unsigned char* +Wq_unpacked, int n) { int i = blockIdx.x*blockDim.x + threadIdx.x; if(i>=n) +return; - Wq_unpacked[i] = (Wq_packed[i] & 0xF0) >> 4; //First chunk - Wq_unpacked[i + n] = (Wq_packed[i] & 0x0F); //Second chunk + Wq_unpacked[i] = (Wq_packed[i] & 0xF0) >> 4; //First chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x0F); //Second chunk }*/ -//Simple +// Simple template -__global__ void dequantize_4bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - int n = h*w; - if(i>=n) return; - - int j = i % w; - //W_r[i] = (T)((Wq_packed[i] & 0xF0) >> 4);//((T)((Wq_packed[i] & 0xF0) >> 4) - zero[j])*scale[j]; //First chunk - //W_r[i + n] = (T)((Wq_packed[i] & 0x0F)) + (T)10000;//((T)((Wq_packed[i] & 0x0F)) - zero[j])*scale[j]; //Second chunk - W_r[i] = ((T)((Wq_packed[i] & 0xF0) >> 4) - zero[j])*scale[j]; //First chunk - W_r[i + n] = ((T)((Wq_packed[i] & 0x0F)) - zero[j])*scale[j]; //Second chunk +__global__ void dequantize_4bit_u8_kernel(unsigned char *Wq_packed, T *scale, + T *zero, T *W_r, int h, int w) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int n = h * w; + if (i >= n) + return; + + int j = i % w; + // W_r[i] = (T)((Wq_packed[i] & 0xF0) >> 4);//((T)((Wq_packed[i] & 0xF0) + // >> 4) - zero[j])*scale[j]; //First chunk W_r[i + n] = (T)((Wq_packed[i] & + // 0x0F)) + (T)10000;//((T)((Wq_packed[i] & 0x0F)) - zero[j])*scale[j]; + // //Second chunk + W_r[i] = ((T)((Wq_packed[i] & 0xF0) >> 4) - zero[j]) * scale[j]; // First 
+ // chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x0F)) - zero[j]) * scale[j]; // Second + // chunk } -extern "C" void dequantize_4bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_4bit_u8_kernel_f32(unsigned char *Wq_packed, + float *scale, float *zero, + float *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #if __CUDA_ARCH__ >= 530 -extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char *Wq_packed, + __half *scale, __half *zero, + __half *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_4bit_u8_kernel_f16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char *Wq_packed, + __nv_bfloat16 *scale, + __nv_bfloat16 *zero, + __nv_bfloat16 *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_4bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif @@ -116,244 +155,314 @@ extern "C" void dequantize_4bit_u8_kernel_bf16(unsigned char* Wq_packed, uint16_ /************* 2-bit *************/ /*******************************************************************************************************************************************/ -//Simple -/*__global__ void unpack_2bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - if(i>=n) return; +// Simple +/*__global__ void unpack_2bit_u8_kernel(unsigned char* Wq_packed, unsigned char* +Wq_unpacked, int n) { int i = blockIdx.x*blockDim.x + threadIdx.x; if(i>=n) +return; - Wq_unpacked[i] = (Wq_packed[i] & 0xC0) >> 6; //1st chunk - Wq_unpacked[i + n] = (Wq_packed[i] & 0x30) >> 4; //2nd chunk - Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x0C) >> 2; //3rd chunk - Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x03); //4th chunk + Wq_unpacked[i] = (Wq_packed[i] & 0xC0) >> 6; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x30) >> 4; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x0C) >> 2; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x03); //4th chunk }*/ - -//Simple +// Simple template -__global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* 
zero, T* W_r, int h, int w) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - int n = h*w; - if(i>=n) return; - - int j = i % w; - W_r[i] = ((T)((Wq_packed[i] & 0xC0) >> 6) - zero[j])*scale[j]; //1st chunk - W_r[i + n] = ((T)((Wq_packed[i] & 0x30) >> 4) - zero[j])*scale[j]; //2nd chunk - W_r[i + n*2] = ((T)((Wq_packed[i] & 0x0C) >> 2) - zero[j])*scale[j]; //3rd chunk - W_r[i + n*3] = ((T)((Wq_packed[i] & 0x03)) - zero[j])*scale[j]; //4th chunk +__global__ void dequantize_2bit_u8_kernel(unsigned char *Wq_packed, T *scale, + T *zero, T *W_r, int h, int w) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int n = h * w; + if (i >= n) + return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0xC0) >> 6) - zero[j]) * scale[j]; // 1st chunk + W_r[i + n] = + ((T)((Wq_packed[i] & 0x30) >> 4) - zero[j]) * scale[j]; // 2nd chunk + W_r[i + n * 2] = + ((T)((Wq_packed[i] & 0x0C) >> 2) - zero[j]) * scale[j]; // 3rd chunk + W_r[i + n * 3] = + ((T)((Wq_packed[i] & 0x03)) - zero[j]) * scale[j]; // 4th chunk } -extern "C" void dequantize_2bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_2bit_u8_kernel_f32(unsigned char *Wq_packed, + float *scale, float *zero, + float *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #if __CUDA_ARCH__ >= 530 -extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char *Wq_packed, + __half *scale, __half *zero, + __half *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_2bit_u8_kernel_f16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char *Wq_packed, + __nv_bfloat16 *scale, + __nv_bfloat16 *zero, + __nv_bfloat16 *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_2bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_2bit_u8_kernel_bf16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif - // //Shared // template -// __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, int w) { -// int i = blockIdx.x*blockDim.x + threadIdx.x; -// int n = h*w; -// int s = threadIdx.x; +// __global__ void dequantize_2bit_u8_kernel(unsigned char* Wq_packed, scalar_t* +// 
scale, scalar_t* zero, scalar_t* W_r, int h, int w) { int i = +// blockIdx.x*blockDim.x + threadIdx.x; int n = h*w; int s = threadIdx.x; // if(i>=n) return; // __shared__ unsigned char shared[BLOCK_SIZE]; // __shared__ scalar_t shared_meta[BLOCK_SIZE][2]; - + // int j = i % w; // shared[s] = Wq_packed[i]; // shared_meta[s][0] = zero[j]; // shared_meta[s][1] = scale[j]; // __syncthreads(); - -// W_r[i] = (scalar_t((shared[s] & 0xC0) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk -// W_r[i + n] = (scalar_t((shared[s] & 0x30) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk -// W_r[i + n*2] = (scalar_t((shared[s] & 0x0C) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk -// W_r[i + n*3] = (scalar_t((shared[s] & 0x03)) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk +// W_r[i] = (scalar_t((shared[s] & 0xC0) >> 6) - +// shared_meta[s][0])*shared_meta[s][1]; //1st chunk W_r[i + n] = +// (scalar_t((shared[s] & 0x30) >> 4) - shared_meta[s][0])*shared_meta[s][1]; +// //2nd chunk W_r[i + n*2] = (scalar_t((shared[s] & 0x0C) >> 2) - +// shared_meta[s][0])*shared_meta[s][1]; //3rd chunk W_r[i + n*3] = +// (scalar_t((shared[s] & 0x03)) - shared_meta[s][0])*shared_meta[s][1]; +// //4th chunk // } - - /*******************************************************************************************************************************************/ /************* 1-bit *************/ /*******************************************************************************************************************************************/ -//Simple -/*__global__ void unpack_1bit_u8_kernel(unsigned char* Wq_packed, unsigned char* Wq_unpacked, int n) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - if(i>=n) return; - - Wq_unpacked[i] = (Wq_packed[i] & 0x80) >> 7; //1st chunk - Wq_unpacked[i + n] = (Wq_packed[i] & 0x40) >> 6; //2nd chunk - Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x20) >> 5; //3rd chunk - Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x10) >> 4; //4th chunk - Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x08) >> 3; //5th chunk - Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x04) >> 2; //6th chunk - Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x02) >> 1; //7th chunk - Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x01); //8th chunk +// Simple +/*__global__ void unpack_1bit_u8_kernel(unsigned char* Wq_packed, unsigned char* +Wq_unpacked, int n) { int i = blockIdx.x*blockDim.x + threadIdx.x; if(i>=n) +return; + + Wq_unpacked[i] = (Wq_packed[i] & 0x80) >> 7; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x40) >> 6; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x20) >> 5; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x10) >> 4; //4th chunk + Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x08) >> 3; //5th chunk + Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x04) >> 2; //6th chunk + Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x02) >> 1; //7th chunk + Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x01); //8th chunk }*/ -//Simple +// Simple template -__global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - int n = h*w; - if(i>=n) return; - - int j = i % w; - W_r[i] = ((T)((Wq_packed[i] & 0x80) >> 7) - zero[j])*scale[j]; //1st chunk - W_r[i + n] = ((T)((Wq_packed[i] & 0x40) >> 6) - zero[j])*scale[j]; //2nd chunk - W_r[i + n*2] = ((T)((Wq_packed[i] & 0x20) >> 5) - zero[j])*scale[j]; //3rd chunk - W_r[i + n*3] = ((T)((Wq_packed[i] & 0x10) >> 4) - zero[j])*scale[j]; //4th chunk - W_r[i + n*4] = ((T)((Wq_packed[i] & 0x08) >> 3) 
- zero[j])*scale[j]; //5th chunk - W_r[i + n*5] = ((T)((Wq_packed[i] & 0x04) >> 2) - zero[j])*scale[j]; //6th chunk - W_r[i + n*6] = ((T)((Wq_packed[i] & 0x02) >> 1) - zero[j])*scale[j]; //7th chunk - W_r[i + n*7] = ((T)((Wq_packed[i] & 0x01)) - zero[j])*scale[j]; //8th chunk +__global__ void dequantize_1bit_u8_kernel(unsigned char *Wq_packed, T *scale, + T *zero, T *W_r, int h, int w) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int n = h * w; + if (i >= n) + return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0x80) >> 7) - zero[j]) * scale[j]; // 1st chunk + W_r[i + n] = + ((T)((Wq_packed[i] & 0x40) >> 6) - zero[j]) * scale[j]; // 2nd chunk + W_r[i + n * 2] = + ((T)((Wq_packed[i] & 0x20) >> 5) - zero[j]) * scale[j]; // 3rd chunk + W_r[i + n * 3] = + ((T)((Wq_packed[i] & 0x10) >> 4) - zero[j]) * scale[j]; // 4th chunk + W_r[i + n * 4] = + ((T)((Wq_packed[i] & 0x08) >> 3) - zero[j]) * scale[j]; // 5th chunk + W_r[i + n * 5] = + ((T)((Wq_packed[i] & 0x04) >> 2) - zero[j]) * scale[j]; // 6th chunk + W_r[i + n * 6] = + ((T)((Wq_packed[i] & 0x02) >> 1) - zero[j]) * scale[j]; // 7th chunk + W_r[i + n * 7] = + ((T)((Wq_packed[i] & 0x01)) - zero[j]) * scale[j]; // 8th chunk } -extern "C" void dequantize_1bit_u8_kernel_f32(unsigned char* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_1bit_u8_kernel_f32(unsigned char *Wq_packed, + float *scale, float *zero, + float *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #if __CUDA_ARCH__ >= 530 -extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char *Wq_packed, + __half *scale, __half *zero, + __half *W_r, int h, int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_1bit_u8_kernel_f16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char *Wq_packed, + __nv_bfloat16 *scale, + __nv_bfloat16 *zero, + __nv_bfloat16 *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_1bit_u8_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_1bit_u8_kernel_bf16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif // //Shared // template -// __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* scale, scalar_t* zero, scalar_t* W_r, int h, 
int w) { -// int i = blockIdx.x*blockDim.x + threadIdx.x; -// int s = threadIdx.x; -// int n = h*w; +// __global__ void dequantize_1bit_u8_kernel(unsigned char* Wq_packed, scalar_t* +// scale, scalar_t* zero, scalar_t* W_r, int h, int w) { int i = +// blockIdx.x*blockDim.x + threadIdx.x; int s = threadIdx.x; int n = h*w; // if(i>=n) return; // __shared__ unsigned char shared[BLOCK_SIZE]; // __shared__ scalar_t shared_meta[BLOCK_SIZE][2]; - + // int j = i % w; // shared[s] = Wq_packed[i]; // shared_meta[s][0] = zero[j]; // shared_meta[s][1] = scale[j]; // __syncthreads(); -// W_r[i] = (scalar_t((shared[s] & 0x80) >> 7) - shared_meta[s][0])*shared_meta[s][1]; //1st chunk -// W_r[i + n] = (scalar_t((shared[s] & 0x40) >> 6) - shared_meta[s][0])*shared_meta[s][1]; //2nd chunk -// W_r[i + n*2] = (scalar_t((shared[s] & 0x20) >> 5) - shared_meta[s][0])*shared_meta[s][1]; //3rd chunk -// W_r[i + n*3] = (scalar_t((shared[s] & 0x10) >> 4) - shared_meta[s][0])*shared_meta[s][1]; //4th chunk -// W_r[i + n*4] = (scalar_t((shared[s] & 0x08) >> 3) - shared_meta[s][0])*shared_meta[s][1]; //5th chunk -// W_r[i + n*5] = (scalar_t((shared[s] & 0x04) >> 2) - shared_meta[s][0])*shared_meta[s][1]; //6th chunk -// W_r[i + n*6] = (scalar_t((shared[s] & 0x02) >> 1) - shared_meta[s][0])*shared_meta[s][1]; //7th chunk -// W_r[i + n*7] = (scalar_t((shared[s] & 0x01)) - shared_meta[s][0])*shared_meta[s][1]; //8th chunk +// W_r[i] = (scalar_t((shared[s] & 0x80) >> 7) - +// shared_meta[s][0])*shared_meta[s][1]; //1st chunk W_r[i + n] = +// (scalar_t((shared[s] & 0x40) >> 6) - shared_meta[s][0])*shared_meta[s][1]; +// //2nd chunk W_r[i + n*2] = (scalar_t((shared[s] & 0x20) >> 5) - +// shared_meta[s][0])*shared_meta[s][1]; //3rd chunk W_r[i + n*3] = +// (scalar_t((shared[s] & 0x10) >> 4) - shared_meta[s][0])*shared_meta[s][1]; +// //4th chunk W_r[i + n*4] = (scalar_t((shared[s] & 0x08) >> 3) - +// shared_meta[s][0])*shared_meta[s][1]; //5th chunk W_r[i + n*5] = +// (scalar_t((shared[s] & 0x04) >> 2) - shared_meta[s][0])*shared_meta[s][1]; +// //6th chunk W_r[i + n*6] = (scalar_t((shared[s] & 0x02) >> 1) - +// shared_meta[s][0])*shared_meta[s][1]; //7th chunk W_r[i + n*7] = +// (scalar_t((shared[s] & 0x01)) - shared_meta[s][0])*shared_meta[s][1]; +// //8th chunk // } - /*******************************************************************************************************************************************/ /************* 3-bit *************/ /*******************************************************************************************************************************************/ -//Simple -/*__global__ void unpack_3bit_32_kernel(int32_t* Wq_packed, unsigned char* Wq_unpacked, int n) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - if(i>=n) return; - - Wq_unpacked[i] = (Wq_packed[i] & 0x38000000) >> 27; //1st chunk - Wq_unpacked[i + n] = (Wq_packed[i] & 0x07000000) >> 24; //2nd chunk - Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x00E00000) >> 21; //3rd chunk - Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x001C0000) >> 18; //4th chunk - Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x00038000) >> 15; //5th chunk - Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x00007000) >> 12; //6th chunk - Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x00000E00) >> 9; //7th chunk - Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x000001C0) >> 6; //8th chunk - Wq_unpacked[i + n*8] = (Wq_packed[i] & 0x00000038) >> 3; //9th chunk - Wq_unpacked[i + n*9] = (Wq_packed[i] & 0x00000007); //10th chunk +// Simple +/*__global__ void unpack_3bit_32_kernel(int32_t* Wq_packed, 
unsigned char* +Wq_unpacked, int n) { int i = blockIdx.x*blockDim.x + threadIdx.x; if(i>=n) +return; + + Wq_unpacked[i] = (Wq_packed[i] & 0x38000000) >> 27; //1st chunk + Wq_unpacked[i + n] = (Wq_packed[i] & 0x07000000) >> 24; //2nd chunk + Wq_unpacked[i + n*2] = (Wq_packed[i] & 0x00E00000) >> 21; //3rd chunk + Wq_unpacked[i + n*3] = (Wq_packed[i] & 0x001C0000) >> 18; //4th chunk + Wq_unpacked[i + n*4] = (Wq_packed[i] & 0x00038000) >> 15; //5th chunk + Wq_unpacked[i + n*5] = (Wq_packed[i] & 0x00007000) >> 12; //6th chunk + Wq_unpacked[i + n*6] = (Wq_packed[i] & 0x00000E00) >> 9; //7th chunk + Wq_unpacked[i + n*7] = (Wq_packed[i] & 0x000001C0) >> 6; //8th chunk + Wq_unpacked[i + n*8] = (Wq_packed[i] & 0x00000038) >> 3; //9th chunk + Wq_unpacked[i + n*9] = (Wq_packed[i] & 0x00000007); //10th chunk }*/ - -//Simple +// Simple template -__global__ void dequantize_3bit_32_kernel(int32_t* Wq_packed, T* scale, T* zero, T* W_r, int h, int w) { - int i = blockIdx.x*blockDim.x + threadIdx.x; - int n = h*w; - if(i>=n) return; - - int j = i % w; - W_r[i] = ((T)((Wq_packed[i] & 0x38000000) >> 27) - zero[j])*scale[j]; //1st chunk - W_r[i + n] = ((T)((Wq_packed[i] & 0x07000000) >> 24) - zero[j])*scale[j]; //2nd chunk - W_r[i + n*2] = ((T)((Wq_packed[i] & 0x00E00000) >> 21) - zero[j])*scale[j]; //3rd chunk - W_r[i + n*3] = ((T)((Wq_packed[i] & 0x001C0000) >> 18) - zero[j])*scale[j]; //4th chunk - W_r[i + n*4] = ((T)((Wq_packed[i] & 0x00038000) >> 15) - zero[j])*scale[j]; //5th chunk - W_r[i + n*5] = ((T)((Wq_packed[i] & 0x00007000) >> 12) - zero[j])*scale[j]; //6th chunk - W_r[i + n*6] = ((T)((Wq_packed[i] & 0x00000E00) >> 9) - zero[j])*scale[j]; //7th chunk - W_r[i + n*7] = ((T)((Wq_packed[i] & 0x000001C0) >> 6) - zero[j])*scale[j]; //8th chunk - W_r[i + n*8] = ((T)((Wq_packed[i] & 0x00000038) >> 3) - zero[j])*scale[j]; //9th chunk - W_r[i + n*9] = ((T)((Wq_packed[i] & 0x00000007)) - zero[j])*scale[j]; //10th chunk +__global__ void dequantize_3bit_32_kernel(int32_t *Wq_packed, T *scale, T *zero, + T *W_r, int h, int w) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + int n = h * w; + if (i >= n) + return; + + int j = i % w; + W_r[i] = ((T)((Wq_packed[i] & 0x38000000) >> 27) - zero[j]) * + scale[j]; // 1st chunk + W_r[i + n] = ((T)((Wq_packed[i] & 0x07000000) >> 24) - zero[j]) * + scale[j]; // 2nd chunk + W_r[i + n * 2] = ((T)((Wq_packed[i] & 0x00E00000) >> 21) - zero[j]) * + scale[j]; // 3rd chunk + W_r[i + n * 3] = ((T)((Wq_packed[i] & 0x001C0000) >> 18) - zero[j]) * + scale[j]; // 4th chunk + W_r[i + n * 4] = ((T)((Wq_packed[i] & 0x00038000) >> 15) - zero[j]) * + scale[j]; // 5th chunk + W_r[i + n * 5] = ((T)((Wq_packed[i] & 0x00007000) >> 12) - zero[j]) * + scale[j]; // 6th chunk + W_r[i + n * 6] = + ((T)((Wq_packed[i] & 0x00000E00) >> 9) - zero[j]) * scale[j]; // 7th chunk + W_r[i + n * 7] = + ((T)((Wq_packed[i] & 0x000001C0) >> 6) - zero[j]) * scale[j]; // 8th chunk + W_r[i + n * 8] = + ((T)((Wq_packed[i] & 0x00000038) >> 3) - zero[j]) * scale[j]; // 9th chunk + W_r[i + n * 9] = + ((T)((Wq_packed[i] & 0x00000007)) - zero[j]) * scale[j]; // 10th chunk } -extern "C" void dequantize_3bit_32_kernel_f32(int32_t* Wq_packed, float* scale, float* zero, float* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_3bit_32_kernel_f32(int32_t *Wq_packed, float *scale, + float *zero, float *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, 
W_r, + h, w); } #if __CUDA_ARCH__ >= 530 -extern "C" void dequantize_3bit_32_kernel_f16(int32_t* Wq_packed, __half* scale, __half* zero, __half* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_3bit_32_kernel_f16(int32_t *Wq_packed, __half *scale, + __half *zero, __half *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_3bit_32_kernel_f16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_3bit_32_kernel_f16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif #if __CUDA_ARCH__ >= 800 -extern "C" void dequantize_3bit_32_kernel_bf16(int32_t* Wq_packed, __nv_bfloat16* scale, __nv_bfloat16* zero, __nv_bfloat16* W_r, int h, int w) { - int blocks = cdiv(h*w, BLOCK_SIZE); - dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, h, w); +extern "C" void dequantize_3bit_32_kernel_bf16(int32_t *Wq_packed, + __nv_bfloat16 *scale, + __nv_bfloat16 *zero, + __nv_bfloat16 *W_r, int h, + int w) { + int blocks = cdiv(h * w, BLOCK_SIZE); + dequantize_3bit_32_kernel<<>>(Wq_packed, scale, zero, W_r, + h, w); } #else -extern "C" void dequantize_3bit_32_kernel_bf16(unsigned char* Wq_packed, uint16_t* scale, uint16_t* zero, uint16_t* W_r, int h, int w) { - assert(false); +extern "C" void dequantize_3bit_32_kernel_bf16(unsigned char *Wq_packed, + uint16_t *scale, uint16_t *zero, + uint16_t *W_r, int h, int w) { + assert(false); } #endif \ No newline at end of file diff --git a/mistralrs-quant/kernels/marlin/dummy_marlin_kernel.cu b/mistralrs-quant/kernels/marlin/dummy_marlin_kernel.cu index f6e73f4ea4..b6098861e2 100644 --- a/mistralrs-quant/kernels/marlin/dummy_marlin_kernel.cu +++ b/mistralrs-quant/kernels/marlin/dummy_marlin_kernel.cu @@ -1,41 +1,18 @@ #include -extern "C" void marlin_4bit_f16( - const void* A, - const void* B, - void* s, - void* C, - int prob_m, - int prob_k, - int prob_n, - void* workspace, - int groupsize -) { +extern "C" void marlin_4bit_f16(const void *A, const void *B, void *s, void *C, + int prob_m, int prob_k, int prob_n, + void *workspace, int groupsize) { assert(false); } -extern "C" void marlin_4bit_bf16( - const void* A, - const void* B, - void* s, - void* C, - int prob_m, - int prob_k, - int prob_n, - void* workspace, - int groupsize -) { +extern "C" void marlin_4bit_bf16(const void *A, const void *B, void *s, void *C, + int prob_m, int prob_k, int prob_n, + void *workspace, int groupsize) { assert(false); } - -extern "C" void gptq_marlin_repack( - void* weight, - void* perm, - void* out, - int size_k, - int size_n, - int num_bits -) { +extern "C" void gptq_marlin_repack(void *weight, void *perm, void *out, + int size_k, int size_n, int num_bits) { assert(false); } diff --git a/mistralrs-quant/kernels/marlin/marlin_kernel.cu b/mistralrs-quant/kernels/marlin/marlin_kernel.cu index 5928eec28d..f890793794 100644 --- a/mistralrs-quant/kernels/marlin/marlin_kernel.cu +++ b/mistralrs-quant/kernels/marlin/marlin_kernel.cu @@ -14,48 +14,45 @@ * limitations under the License. 
*/ - #ifndef MARLIN_CUDA_KERNEL_CUH #define MARLIN_CUDA_KERNEL_CUH +#include "marlin/marlin_dtypes.cuh" +#include #include #include -#include #include -#include "marlin/marlin_dtypes.cuh" using namespace marlin; // m16n8k16 tensor core mma instruction with fp16/bf16 inputs and fp32 // output/accumulation. template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); +__device__ inline void mma(const typename ScalarType::FragA &a_frag, + const typename ScalarType::FragB &frag_b, + typename ScalarType::FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); } } // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); +__device__ inline void ldsm4(typename ScalarType::FragA &frag_a, + const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) @@ -65,8 +62,7 @@ __device__ inline void ldsm4(typename ScalarType::FragA& frag_a, // Lookup-table based 3-input logical operation; explicitly used for // dequantization as the compiler does not seem to automatically recognize it in // all cases. 
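// Illustrative sketch, not part of this patch: PTX `lop3.b32` evaluates an
// arbitrary three-input boolean function selected by an 8-bit immediate LUT.
// Under the standard convention that the inputs map to 0xF0, 0xCC and 0xAA,
// the immediate for (a & b) | c is (0xf0 & 0xcc) | 0xaa == 0xEA, which is the
// mask-and-bias combination the 4-bit dequantizers rely on. A host-side
// emulation is handy for sanity-checking an immediate before baking it into
// the template parameter:
constexpr unsigned lop3_ref(unsigned a, unsigned b, unsigned c, unsigned lut) {
  unsigned r = 0;
  for (int bit = 0; bit < 32; ++bit) {
    unsigned idx = (((a >> bit) & 1u) << 2) | (((b >> bit) & 1u) << 1) |
                   ((c >> bit) & 1u);
    r |= ((lut >> idx) & 1u) << bit; // LUT bit (4a + 2b + c) is the result bit
  }
  return r;
}
static_assert(lop3_ref(0x12345678u, 0x0f0f0f0fu, 0xa0a0a0a0u,
                       (0xf0 & 0xcc) | 0xaa) ==
                  ((0x12345678u & 0x0f0f0f0fu) | 0xa0a0a0a0u),
              "LUT 0xEA computes (a & b) | c");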
-template -__device__ inline int lop3(int a, int b, int c) { +template __device__ inline int lop3(int a, int b, int c) { int res; asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) @@ -74,7 +70,6 @@ __device__ inline int lop3(int a, int b, int c) { return res; } - template __device__ inline typename ScalarType::FragB dequant(int q); @@ -83,8 +78,7 @@ __device__ inline typename ScalarType::FragB dequant(int q); // changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h template <> -__device__ inline typename ScalarType::FragB - dequant(int q) { +__device__ inline typename ScalarType::FragB dequant(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -97,17 +91,17 @@ __device__ inline typename ScalarType::FragB const int MUL = 0x2c002c00; const int ADD = 0xd480d480; typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } template <> __device__ inline typename ScalarType::FragB - dequant(int q) { +dequant(int q) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; @@ -121,50 +115,50 @@ __device__ inline typename ScalarType::FragB static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC308C308; - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); + frag_b[0] = __hfma2(*reinterpret_cast(&lo), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); return frag_b; } // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, +__device__ inline void scale(typename ScalarType::FragB &frag_b, + typename ScalarType::FragS &frag_s, int i) { using scalar_t2 = typename ScalarType::scalar_t2; scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); frag_b[0] = __hmul2(frag_b[0], s); frag_b[1] = __hmul2(frag_b[1], s); } template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ s, // fp16 quantization scales of shape - // (k/groupsize)xn - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ s, // fp16 quantization scales of shape + // (k/groupsize)xn + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * @@ -205,11 +199,11 @@ __global__ void Marlin( int slice_row = (iters * blockIdx.x) % k_tiles; int slice_col_par = (iters * blockIdx.x) / k_tiles; int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice + int slice_iters; // number of threadblock tiles in the current slice int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top // We can easily implement parallel problem execution by just remapping // indices and advancing global pointers @@ -225,22 +219,27 @@ __global__ void Marlin( auto init_slice = [&]() { slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; slice_count = 1; slice_idx = 
0; int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); if (col_first <= k_tiles * (slice_col_par + 1)) { int col_off = col_first - k_tiles * slice_col_par; slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; + if (col_off > 0) + slice_count++; int delta_first = iters * blockIdx.x - col_first; if (delta_first < 0 || (col_off == 0 && delta_first == 0)) slice_idx = slice_count - 1; else { slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; + if (col_off > 0) + slice_idx--; } } if (slice_col == n_tiles) { @@ -252,30 +251,29 @@ __global__ void Marlin( }; init_slice(); - int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; // stride of the A matrix in global memory // We typically use `constexpr` to indicate that this value is a compile-time // constant constexpr int a_sh_stride = - 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory + 16 * thread_k_blocks / 8; // stride of an A matrix tile in shared memory constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / - 8; // delta between subsequent A tiles in global memory + 8; // delta between subsequent A tiles in global memory int a_gl_rd_delta_i = a_gl_stride * - (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile constexpr int a_sh_wr_delta = - a_sh_stride * - (threads / a_gl_rd_delta_o); // between shared memory writes + a_sh_stride * (threads / a_gl_rd_delta_o); // between shared memory writes constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / - (thread_n_blocks / 4)); // between shared memory tile reads + (thread_n_blocks / 4)); // between shared memory tile reads constexpr int a_sh_rd_delta_i = - a_sh_stride * 16; // within a shared memory tile + a_sh_stride * 16; // within a shared memory tile constexpr int a_sh_stage = - a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, - a_sh_wr_delta); // number of shared write iterations for a tile + a_sh_wr_delta); // number of shared write iterations for a tile int b_gl_stride = 16 * prob_n / 32; constexpr int b_sh_stride = 32 * thread_n_blocks / 4; @@ -328,7 +326,7 @@ __global__ void Marlin( // needed if there are more threads than required for a certain tilesize or // when the batchsize is not a multiple of 16. bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; @@ -347,13 +345,13 @@ __global__ void Marlin( // loop unrolls, all shared memory accesses are static, we simply precompute // both transformed reads and writes. 
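// Illustrative sketch only (a generic XOR bank swizzle, not the exact
// transform_a used by this kernel): the reason the transformed offsets are
// precomputed into small arrays is that, once the surrounding loops are fully
// unrolled, every shared-memory access becomes a compile-time-known address,
// so the swizzle costs nothing at the point of use.
__device__ inline int xor_swizzle8(int row, int col, int row_stride) {
  // hypothetical 8-way swizzle: permute columns within a row by the low bits
  // of the row index to spread accesses across shared-memory banks
  return row * row_stride + (col ^ (row & 7));
}
// usage pattern, mirroring a_sh_wr_trans / a_sh_rd_trans:
//   int trans[ITERS];
// #pragma unroll
//   for (int i = 0; i < ITERS; i++)
//     trans[i] = xor_swizzle8(base_row + i * row_delta, col, stride);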
int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < thread_m_blocks; j++) a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); @@ -363,16 +361,16 @@ __global__ void Marlin( // runtime; we break dependencies between subsequent accesses with a tile by // maintining multiple pointers (we have enough registers), a tiny // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; extern __shared__ int4 sh[]; // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_s = sh_b + (stages * b_sh_stage); + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_s = sh_b + (stages * b_sh_stage); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; I4 frag_b_quant[2]; @@ -381,25 +379,25 @@ __global__ void Marlin( // Zero accumulators. auto zero_accums = [&]() { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; + reinterpret_cast(frag_c)[i] = 0; }; // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { cp_async4_pred( &sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], a_sh_wr_pred[i]); } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; @@ -410,8 +408,9 @@ __global__ void Marlin( // and would need to be modified to support smaller groups. static_assert(group_blocks >= thread_k_blocks); if (pipe % (group_blocks / thread_k_blocks) == 0) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if (s_sh_wr_pred) cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + if (s_sh_wr_pred) + cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -442,35 +441,38 @@ __global__ void Marlin( // This assumes group_blocks >= thread_k_blocks // and would need to be modified to support smaller groups. 
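// Worked example, not from this patch (assuming groupsize = 128, so
// group_blocks = 128 / 16 = 8, and thread_k = 64, so thread_k_blocks =
// 64 / 16 = 4): one scale group then spans exactly two k-tiles, and the
// `pipe % (group_blocks / thread_k_blocks) == 0` test above fetches a fresh
// scale row only every second pipeline stage. If group_blocks were smaller
// than thread_k_blocks, a single tile would straddle two scale groups and one
// scale fetch per stage would no longer suffice -- exactly the case the
// static_assert below rules out.
constexpr int scale_fetch_period(int group_blocks, int thread_k_blocks) {
  return group_blocks / thread_k_blocks; // pipeline stages between scale loads
}
static_assert(scale_fetch_period(8, 4) == 2, "groupsize 128, thread_k 64");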
static_assert(group_blocks >= thread_k_blocks); - int4* sh_s_stage = + int4 *sh_s_stage = sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; } - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - frag_b_quant[k % 2] = *reinterpret_cast( + ldsm4(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); }; // Execute the actual tensor core matmul of a sub-tile. auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll for (int j = 0; j < 4; j++) { int b_quant = frag_b_quant[k % 2][j]; int b_quant_shift = b_quant >> 8; FragB frag_b0 = dequant(b_quant); // If there are no groups, we can just scale the final output once and can // avoid doing so for each weight. - if (group_blocks != -1) scale(frag_b0, frag_s[k % 2][j], 0); + if (group_blocks != -1) + scale(frag_b0, frag_s[k % 2][j], 0); FragB frag_b1 = dequant(b_quant_shift); - if (group_blocks != -1) scale(frag_b1, frag_s[k % 2][j], 1); - #pragma unroll + if (group_blocks != -1) + scale(frag_b1, frag_s[k % 2][j], 1); +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); @@ -495,38 +497,38 @@ __global__ void Marlin( // unnecessary read or write iterations, e.g., for two warps we write only // once by warp 1 and read only once by warp 0. 
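// Illustrative sketch, not part of this patch: a simpler two-step block
// reduction (warp shuffles, then one warp folds the per-warp partials). The
// reduction below is more elaborate -- it works on whole accumulator fragments
// and staggers reads and writes across warps as the comment above explains --
// but the goal is the same: combine per-warp partial sums through shared
// memory with as little traffic as possible. Assumes blockDim.x is a multiple
// of 32 and `warp_sums` holds one float per warp.
__device__ inline float block_reduce_sum_sketch(float v, float *warp_sums) {
  int lane = threadIdx.x % 32;
  int warp = threadIdx.x / 32;
  for (int off = 16; off > 0; off /= 2) // intra-warp reduction
    v += __shfl_down_sync(0xffffffffu, v, off);
  if (lane == 0)
    warp_sums[warp] = v; // one shared-memory write per warp
  __syncthreads();
  int n_warps = blockDim.x / 32;
  v = (threadIdx.x < n_warps) ? warp_sums[lane] : 0.0f;
  if (warp == 0) // first warp folds the per-warp partials
    for (int off = 16; off > 0; off /= 2)
      v += __shfl_down_sync(0xffffffffu, v, off);
  return v; // total is valid in thread 0
}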
- #pragma unroll +#pragma unroll for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll +#pragma unroll for (int i = red_off; i > 0; i /= 2) { if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4 * 2; j++) { int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; } sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; } } __syncthreads(); } if (red_idx == 0) { - #pragma unroll +#pragma unroll for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; } } @@ -557,39 +559,39 @@ __global__ void Marlin( int row = (threadIdx.x % 32) / 4; if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. 
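// Illustrative sketch, not from this patch, of what the cp_async4 /
// cp_async_fence / cp_async_wait helpers (assumed to come from
// marlin/marlin_dtypes.cuh) boil down to on sm_80+: a 16-byte global-to-shared
// cp.async, a group commit, and a wait. The _pred variant used below
// additionally guards the copy with a runtime predicate.
__device__ inline void cp_async4_sketch(void *smem_dst, const void *glob_src) {
  unsigned smem = static_cast<unsigned>(__cvta_generic_to_shared(smem_dst));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem),
               "l"(glob_src));
}
__device__ inline void cp_async_fence_sketch() {
  asm volatile("cp.async.commit_group;\n" ::);
}
template <int N> __device__ inline void cp_async_wait_sketch() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
}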
+#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m); } cp_async_fence(); cp_async_wait<0>(); } - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks * 4; i++) { if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { if (!first) { int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( + reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); + Dtype::num2float(reinterpret_cast(&c_red)[j]); } } if (!last) { int4 c; - #pragma unroll +#pragma unroll for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); } C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = @@ -623,7 +625,7 @@ __global__ void Marlin( // We first reorder in shared memory to guarantee the most efficient final // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { + auto write = [&](int idx, float c0, float c1, FragS &s) { scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); @@ -633,13 +635,13 @@ __global__ void Marlin( res = __hmul2(res, s[0]); } - ((scalar_t2*)sh)[idx] = res; + ((scalar_t2 *)sh)[idx] = res; }; if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll +#pragma unroll for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll +#pragma unroll for (int j = 0; j < 4; j++) { int wr = c_sh_wr + 8 * j; write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], @@ -656,7 +658,7 @@ __global__ void Marlin( } __syncthreads(); - #pragma unroll +#pragma unroll for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { @@ -670,8 +672,9 @@ __global__ void Marlin( // Start global fetch and register load pipelines. auto start_pipes = [&]() { - #pragma unroll - for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); +#pragma unroll + for (int i = 0; i < stages - 1; i++) + fetch_to_shared(i, i, i < slice_iters); zero_accums(); wait_for_stage(); fetch_to_registers(0, 0); @@ -681,12 +684,12 @@ __global__ void Marlin( // Main loop. while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines have - // even length meaning that the next iteration will always start at index 0. - #pragma unroll +// We unroll over both the global fetch and the register load pipeline to +// ensure all shared memory accesses are static. Note that both pipelines have +// even length meaning that the next iteration will always start at index 0. 
+#pragma unroll for (int pipe = 0; pipe < stages;) { - #pragma unroll +#pragma unroll for (int k = 0; k < b_sh_wr_iters; k++) { fetch_to_registers(k + 1, pipe % stages); if (k == b_sh_wr_iters - 2) { @@ -698,7 +701,8 @@ __global__ void Marlin( matmul(k); } slice_iters--; - if (slice_iters == 0) break; + if (slice_iters == 0) + break; } a_gl_rd += a_gl_rd_delta_o * stages; @@ -711,7 +715,8 @@ __global__ void Marlin( // For per-column scales, we only fetch them here in the final step before // write-out if (group_blocks == -1 && last) { - if (s_sh_wr_pred) cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); + if (s_sh_wr_pred) + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce(); @@ -719,17 +724,17 @@ __global__ void Marlin( cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice barrier_acquire(&locks[slice_col], slice_idx); global_reduce(slice_idx == 0, last); barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last) // only the last block in a slice actually writes the result write_result(); slice_row = 0; slice_col_par++; @@ -738,12 +743,13 @@ __global__ void Marlin( if (slice_iters) { a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll +#pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; } s_gl_rd = s_sh_stride * slice_col + threadIdx.x; start_pipes(); @@ -752,14 +758,16 @@ __global__ void Marlin( } } - // 8 warps are a good choice since every SM has 4 schedulers and having more // than 1 warp per schedule allows some more latency hiding. At the same time, // we want relatively few warps to have many registers per warp and small tiles. 
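// Quick arithmetic behind the comment above (illustrative only): 256 threads
// are 256 / 32 = 8 warps, i.e. 2 warps per scheduler on an SM with 4 warp
// schedulers -- enough to hide some latency while leaving each warp a large
// register budget. Note also that every entry in the config tables below keeps
// thread_k * thread_n / num_threads equal to 64, so the per-thread share of a
// tile stays constant across the listed shapes.
constexpr int warps_per_scheduler(int threads, int schedulers = 4) {
  return (threads / 32) / schedulers;
}
static_assert(warps_per_scheduler(256) == 2, "default Marlin block size");
static_assert(128 * 128 / 256 == 64 && 64 * 256 / 256 == 64 &&
                  128 * 64 / 128 == 64 && 64 * 128 / 128 == 64,
              "per-thread tile share is the same for every listed config");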
-const int USER_THREADS = 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -const int SHARED_MEM = 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) -static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +const int SHARED_MEM = + 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) +static constexpr int pack_factor_4bit = + 8; // We have 8 4-bit vals inside a 32 bit #define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ GROUP_BLOCKS, NUM_THREADS) \ @@ -767,13 +775,14 @@ static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute(Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, \ - SHARED_MEM); \ - Marlin<<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, SHARED_MEM); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, prob_m, prob_n, prob_k, locks); \ } typedef struct { @@ -785,22 +794,22 @@ typedef struct { thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N }; thread_config_t large_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X }; -bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, +bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, int prob_k) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || @@ -851,28 +860,27 @@ thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { return thread_config_t{-1, -1, -1}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + 
__CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) -template -void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, int prob_k, - int prob_n, void* workspace, int groupsize - ) { - - int dev = 0; - cudaStream_t stream = 0; +template +void marlin_matmul(const void *A, const void *B, void *s, void *C, int prob_m, + int prob_k, int prob_n, void *workspace, int groupsize, + cudaStream_t stream) { + + int dev = 0; int thread_k = -1; - int thread_n = -1; - int sms = -1; + int thread_n = -1; + int sms = -1; int max_par = 16; int tot_m = prob_m; @@ -909,12 +917,12 @@ void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, i return; } - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; - int* locks = (int*)workspace; + int *locks = (int *)workspace; for (int i = 0; i < tot_m_blocks; i += 4) { int thread_m_blocks = tot_m_blocks - i; @@ -924,7 +932,8 @@ void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, i // Note that parallel > 1 currently only works for inputs without any // padding par = (16 * thread_m_blocks - pad) / 64; - if (par > max_par) par = max_par; + if (par > max_par) + par = max_par; prob_m = 64 * par; i += 4 * (par - 1); thread_m_blocks = 4; @@ -948,40 +957,28 @@ void marlin_matmul(const void* A, const void* B, void* s, void* C, int prob_m, i } } -extern "C" void marlin_4bit_f16( - const void* A, - const void* B, - void* s, - void* C, - int prob_m, - int prob_k, - int prob_n, - void* workspace, - int groupsize -) { - marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); +extern "C" void marlin_4bit_f16(const void *A, const void *B, void *s, void *C, + int prob_m, int prob_k, int prob_n, + void *workspace, int groupsize, + cudaStream_t stream) { + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize, + stream); } -extern "C" void marlin_4bit_bf16( - const void* A, - const void* B, - void* s, - void* C, - int prob_m, - int prob_k, - int prob_n, - void* workspace, - int groupsize -) { - marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, groupsize); +extern "C" void marlin_4bit_bf16(const void *A, const void *B, void *s, void *C, + int prob_m, int prob_k, int prob_n, + void *workspace, int groupsize, + cudaStream_t stream) { + marlin_matmul(A, B, s, C, prob_m, prob_k, prob_n, workspace, + groupsize, stream); } - template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) { +__global__ void +gptq_marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, int size_k, + int size_n) { constexpr int pack_factor = 32 / num_bits; int k_tiles = size_k / tile_k_size; @@ -1009,8 +1006,8 @@ __global__ void gptq_marlin_repack_kernel( constexpr int perm_size = tile_k_size / 4; - int4* sh_perm_ptr = sh; - int4* sh_pipe_ptr = sh_perm_ptr; + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; if constexpr (has_perm) { 
sh_pipe_ptr += perm_size; } @@ -1024,7 +1021,7 @@ __global__ void gptq_marlin_repack_kernel( auto load_perm_to_shared = [&](int k_tile_id) { int first_k_int4 = (k_tile_id * tile_k_size) / 4; - int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); if (threadIdx.x < perm_size) { sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; @@ -1040,22 +1037,22 @@ __global__ void gptq_marlin_repack_kernel( int first_n = n_tile_id * tile_n_size; - int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; if constexpr (has_perm) { if (threadIdx.x < stage_size) { int k_id = threadIdx.x / stage_n_threads; int n_id = threadIdx.x % stage_n_threads; - uint32_t const* sh_perm_int_ptr = - reinterpret_cast(sh_perm_ptr); + uint32_t const *sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; int src_k_packed = src_k / pack_factor; cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast(&( + reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); } @@ -1068,7 +1065,7 @@ __global__ void gptq_marlin_repack_kernel( int first_k_packed = first_k / pack_factor; cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( + reinterpret_cast( &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); } @@ -1099,10 +1096,10 @@ __global__ void gptq_marlin_repack_kernel( constexpr int sh_stride = 64; constexpr uint32_t mask = (1 << num_bits) - 1; - int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; - uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); - uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); uint32_t vals[8]; @@ -1127,13 +1124,13 @@ __global__ void gptq_marlin_repack_kernel( uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; - #pragma unroll +#pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; @@ -1153,7 +1150,7 @@ __global__ void gptq_marlin_repack_kernel( constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -1165,7 +1162,7 @@ __global__ void gptq_marlin_repack_kernel( uint32_t res1 = 0; uint32_t res2 = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -1177,14 +1174,14 @@ __global__ void gptq_marlin_repack_kernel( }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; - #pragma unroll +#pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; @@ -1195,7 +1192,7 @@ __global__ void gptq_marlin_repack_kernel( start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % 
repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -1207,26 +1204,18 @@ __global__ void gptq_marlin_repack_kernel( } } - #define CALL_IF2(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - gptq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } +#define CALL_IF2(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } -extern "C" void gptq_marlin_repack( - void* weight, - void* perm, - void* out, - int size_k, - int size_n, - int num_bits -) { +extern "C" void gptq_marlin_repack(void *weight, void *perm, void *out, + int size_k, int size_n, int num_bits) { // Verify compatibility with marlin tile of 16x64 assert(size_k % tile_k_size == 0); assert(size_n % tile_n_size == 0); @@ -1236,13 +1225,12 @@ extern "C" void gptq_marlin_repack( bool has_perm = true; int dev = 0; // Get ptrs - uint32_t const* b_q_weight_ptr = - reinterpret_cast(weight); - uint32_t const* perm_ptr = reinterpret_cast(perm); - uint32_t* out_ptr = reinterpret_cast(out); + uint32_t const *b_q_weight_ptr = reinterpret_cast(weight); + uint32_t const *perm_ptr = reinterpret_cast(perm); + uint32_t *out_ptr = reinterpret_cast(out); // Get dev info - cudaStream_t stream = 0; + cudaStream_t stream = 0; int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); @@ -1260,7 +1248,6 @@ extern "C" void gptq_marlin_repack( else { assert(false); } - } #endif diff --git a/mistralrs-quant/kernels/ops/ops.cu b/mistralrs-quant/kernels/ops/ops.cu index b9d026790f..5dbfb2e464 100644 --- a/mistralrs-quant/kernels/ops/ops.cu +++ b/mistralrs-quant/kernels/ops/ops.cu @@ -20,7 +20,6 @@ __global__ void bitwise_or__kernel(const T *d_in1, const T *d_in2, T *d_out, } } - template void bitwise_or(const T *d_in1, const T *d_in2, T *d_out, int N) { int nthreads = mq_next_power_of_2(N); @@ -29,13 +28,12 @@ void bitwise_or(const T *d_in1, const T *d_in2, T *d_out, int N) { } const int nblocks = (N + nthreads - 1) / nthreads; bitwise_or__kernel<<>>(d_in1, d_in2, d_out, N); - cudaDeviceSynchronize(); } #define BITWISE_OP(TYPENAME, RUST_NAME) \ - extern "C" void mq_bitwise_or_##RUST_NAME(const TYPENAME *d_in1, \ - const TYPENAME *d_in2, \ - TYPENAME *d_out, uint32_t N) { \ + extern "C" void mq_bitwise_or_##RUST_NAME(const TYPENAME *d_in1, \ + const TYPENAME *d_in2, \ + TYPENAME *d_out, uint32_t N) { \ bitwise_or(d_in1, d_in2, d_out, N); \ } @@ -45,8 +43,8 @@ BITWISE_OP(int64_t, i64) BITWISE_OP(int32_t, i32) template -__global__ void leftshift_kernel(const T *d_in1, T *d_out, - const uint32_t N, const int32_t k) { +__global__ void leftshift_kernel(const T *d_in1, T *d_out, const uint32_t N, + const int32_t k) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < N) { d_out[idx] = d_in1[idx] << k; @@ -61,13 +59,12 @@ void leftshift(const T *d_in1, T *d_out, int N, const int32_t k) { } const int nblocks = (N + nthreads - 1) / nthreads; leftshift_kernel<<>>(d_in1, d_out, N, k); - cudaDeviceSynchronize(); } -#define LEFTSHIFT_OP(TYPENAME, RUST_NAME) \ - extern "C" void mq_leftshift_##RUST_NAME(const TYPENAME *d_in1, \ - TYPENAME *d_out, uint32_t N, int32_t k) { \ - 
leftshift(d_in1, d_out, N, k); \ +#define LEFTSHIFT_OP(TYPENAME, RUST_NAME) \ + extern "C" void mq_leftshift_##RUST_NAME( \ + const TYPENAME *d_in1, TYPENAME *d_out, uint32_t N, int32_t k) { \ + leftshift(d_in1, d_out, N, k); \ } LEFTSHIFT_OP(uint8_t, u8) diff --git a/mistralrs-quant/src/blockwise_fp8/ffi.rs b/mistralrs-quant/src/blockwise_fp8/ffi.rs index da5427c80b..fa01253f49 100644 --- a/mistralrs-quant/src/blockwise_fp8/ffi.rs +++ b/mistralrs-quant/src/blockwise_fp8/ffi.rs @@ -14,6 +14,7 @@ extern "C" { scale_stride: i32, weight_block_size_y: i32, weight_block_size_x: i32, + stream: candle_core::cuda::cudarc::driver::sys::CUstream, ); pub(crate) fn launch_dequant_fp8_blockwise_kernel_f16( @@ -26,6 +27,7 @@ extern "C" { scale_stride: i32, weight_block_size_y: i32, weight_block_size_x: i32, + stream: candle_core::cuda::cudarc::driver::sys::CUstream, ); pub(crate) fn launch_dequant_fp8_blockwise_kernel_bf16( @@ -38,5 +40,6 @@ extern "C" { scale_stride: i32, weight_block_size_y: i32, weight_block_size_x: i32, + stream: candle_core::cuda::cudarc::driver::sys::CUstream, ); } diff --git a/mistralrs-quant/src/blockwise_fp8/ops.rs b/mistralrs-quant/src/blockwise_fp8/ops.rs index c1a437a4a3..04b54177d7 100644 --- a/mistralrs-quant/src/blockwise_fp8/ops.rs +++ b/mistralrs-quant/src/blockwise_fp8/ops.rs @@ -143,6 +143,8 @@ impl CustomOp2 for Fp8BlockwiseDequantize { candle_core::bail!("Expected scale to be rank 2"); } + let dev = weight_s.device(); + let weight = weight_s .as_cuda_slice::()? .slice(weight_l.start_offset()..); @@ -174,6 +176,7 @@ impl CustomOp2 for Fp8BlockwiseDequantize { scale_stride, weight_block_size_y, weight_block_size_x, + *dev.cu_stream(), ) }; CudaStorage::wrap_cuda_slice(output, weight_s.device().clone()) @@ -194,6 +197,7 @@ impl CustomOp2 for Fp8BlockwiseDequantize { scale_stride, weight_block_size_y, weight_block_size_x, + *dev.cu_stream(), ) }; CudaStorage::wrap_cuda_slice(output, weight_s.device().clone()) @@ -214,6 +218,7 @@ impl CustomOp2 for Fp8BlockwiseDequantize { scale_stride, weight_block_size_y, weight_block_size_x, + *dev.cu_stream(), ) }; CudaStorage::wrap_cuda_slice(output, weight_s.device().clone()) diff --git a/mistralrs-quant/src/cublaslt/mod.rs b/mistralrs-quant/src/cublaslt/mod.rs index c11540e6c8..d51f60900e 100644 --- a/mistralrs-quant/src/cublaslt/mod.rs +++ b/mistralrs-quant/src/cublaslt/mod.rs @@ -25,7 +25,7 @@ static mut CUBLASLT: Option = None; pub static CUBLASLT_HANDLE: Lazy>> = Lazy::new(|| Mutex::new(None)); -pub fn maybe_init_cublas_lt_wrapper() { +pub fn maybe_init_cublas_lt_wrapper(device: Device) { unsafe { INIT.call_once(|| { #[cfg(not(feature = "cuda"))] @@ -39,15 +39,12 @@ pub fn maybe_init_cublas_lt_wrapper() { // Then check if we can create a device // Then check that the device is CUDA use candle_core::cuda_backend::cudarc::driver; - CUBLASLT = driver::result::init() - .ok() - .and_then(|_| Device::cuda_if_available(0).ok()) - .and_then(|device| match device { - Device::Cuda(_) => Some(CublasLtWrapper { - cublaslt: CublasLt::new(&device).unwrap(), - }), - _ => None, - }); + CUBLASLT = match device { + Device::Cuda(_) => Some(CublasLtWrapper { + cublaslt: CublasLt::new(&device).unwrap(), + }), + _ => None, + } } #[allow(static_mut_refs)] let cublaslt: Option<&'static CublasLtWrapper> = CUBLASLT.as_ref(); diff --git a/mistralrs-quant/src/distributed/mod.rs b/mistralrs-quant/src/distributed/mod.rs index e447602c47..e185dff589 100644 --- a/mistralrs-quant/src/distributed/mod.rs +++ b/mistralrs-quant/src/distributed/mod.rs @@ -69,6 
+69,7 @@ mod ops { impl Comm { pub fn from_device(id: Id, dev: &Device, rank: usize, world_size: usize) -> Result { let device = dev.as_cuda_device()?.cuda_device(); + assert_eq!(rank, device.ordinal()); Ok(Self { comm: cudarc::nccl::Comm::from_rank(device, rank, world_size, id.0) .map_err(|e| e.0) @@ -133,6 +134,8 @@ mod ops { Some((0, l)) if l == s.len() => s, Some(_) | None => candle_core::bail!("input has to be contiguous"), }; + assert_eq!(dev.ordinal(), self.comm.rank()); + assert!(elem_count > 0); let mut dst = unsafe { dev.alloc::(elem_count) }.w()?; self.comm .comm diff --git a/mistralrs-quant/src/fp8/mod.rs b/mistralrs-quant/src/fp8/mod.rs index c42082e3bc..24ad1f3585 100644 --- a/mistralrs-quant/src/fp8/mod.rs +++ b/mistralrs-quant/src/fp8/mod.rs @@ -66,7 +66,7 @@ impl QuantMethod for FP8Linear { fn forward(&self, x: &Tensor) -> Result { // Batch matrix multiplication - maybe_init_cublas_lt_wrapper(); + maybe_init_cublas_lt_wrapper(x.device().clone()); match *CUBLASLT_HANDLE.lock().unwrap() { Some(handle) => { diff --git a/mistralrs-quant/src/fp8/quantize.rs b/mistralrs-quant/src/fp8/quantize.rs index 575dbd7dd4..dc41dd279a 100644 --- a/mistralrs-quant/src/fp8/quantize.rs +++ b/mistralrs-quant/src/fp8/quantize.rs @@ -116,7 +116,7 @@ mod tests { let mut x = Tensor::rand(0., 1., (1, 16, 32), &dev)?.to_dtype(DType::F32)?; // Batch matrix multiplication - maybe_init_cublas_lt_wrapper(); + maybe_init_cublas_lt_wrapper(x.device().clone()); let handle = CUBLASLT_HANDLE.lock().unwrap().unwrap(); diff --git a/mistralrs-quant/src/gptq/marlin_backend.rs b/mistralrs-quant/src/gptq/marlin_backend.rs index 200ad8a1a1..c0aed7498d 100644 --- a/mistralrs-quant/src/gptq/marlin_backend.rs +++ b/mistralrs-quant/src/gptq/marlin_backend.rs @@ -90,6 +90,7 @@ impl GPTQMatMul { size_n as i32, workspace_ptr, groupsize, + *dev.cu_stream(), ); } } else if x.dtype() == DType::BF16 { @@ -104,6 +105,7 @@ impl GPTQMatMul { size_n as i32, workspace_ptr, groupsize, + *dev.cu_stream(), ); } } diff --git a/mistralrs-quant/src/gptq/marlin_ffi.rs b/mistralrs-quant/src/gptq/marlin_ffi.rs index c988d21f4e..a51a5485c7 100644 --- a/mistralrs-quant/src/gptq/marlin_ffi.rs +++ b/mistralrs-quant/src/gptq/marlin_ffi.rs @@ -1,5 +1,7 @@ use std::os::raw::c_void; +use candle_core::cuda::cudarc::driver::sys::CUstream; + /// THIS IS AUTOGENERATED BY `build.rs`. DO NOT CHANGE! /// It indicated if the Marlin kernels were actually compiled. 
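The recurring change in these hunks is threading the caller's CUDA stream through the FFI: the Rust side now passes *dev.cu_stream(), and the C entry points forward it to every kernel launch instead of launching on the default stream. A minimal CUDA C++ sketch of that pattern follows, with a hypothetical scale_rows kernel standing in for the real ones.

#include <cuda_runtime.h>

__global__ void scale_rows_kernel(const float *in, float *out, int n, float k) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n)
    out[idx] = in[idx] * k;
}

// The extern "C" signature ends with the stream, matching how the Rust
// declarations in this patch gained a trailing `stream: CUstream` argument.
extern "C" void scale_rows(const float *in, float *out, int n, float k,
                           cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Launch on the caller-provided stream so the work stays ordered with the
  // rest of the device's queue; no cudaDeviceSynchronize is needed here.
  scale_rows_kernel<<<blocks, threads, 0, stream>>>(in, out, n, k);
}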
pub(crate) const HAVE_MARLIN_KERNELS: bool = true; @@ -16,6 +18,7 @@ extern "C" { n: i32, workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero groupsize: i32, + stream: CUstream, ); pub(crate) fn marlin_4bit_bf16( @@ -28,6 +31,7 @@ extern "C" { n: i32, workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero groupsize: i32, + stream: CUstream, ); pub(crate) fn gptq_marlin_repack( diff --git a/mistralrs-quant/src/metal_kernels/bitwise.metal b/mistralrs-quant/src/metal_kernels/bitwise.metal index 212f624620..caaab5db46 100644 --- a/mistralrs-quant/src/metal_kernels/bitwise.metal +++ b/mistralrs-quant/src/metal_kernels/bitwise.metal @@ -1,47 +1,37 @@ #include template -[[kernel]] void bitwise_or( - const device T* a [[buffer(0)]], - const device T* b [[buffer(1)]], - device T* output [[buffer(2)]], - uint tid [[ thread_position_in_grid ]] -) { - output[tid] = a[tid] | b[tid]; +[[kernel]] void bitwise_or(const device T *a [[buffer(0)]], + const device T *b [[buffer(1)]], + device T *output [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + output[tid] = a[tid] | b[tid]; } -#define instantiate_bitwise_or(type) \ - template [[host_name("bitwise_or_" #type)]] \ - [[kernel]] void bitwise_or( \ - const device type* a [[buffer(0)]], \ - const device type* b [[buffer(1)]], \ - device type* out [[buffer(2)]], \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_bitwise_or(type) \ + template [[host_name("bitwise_or_" #type)]] [[kernel]] void \ + bitwise_or(const device type *a [[buffer(0)]], \ + const device type *b [[buffer(1)]], \ + device type *out [[buffer(2)]], \ + uint tid [[thread_position_in_grid]]); -instantiate_bitwise_or(uint8_t) -instantiate_bitwise_or(uint32_t) -instantiate_bitwise_or(int64_t) -instantiate_bitwise_or(int) +instantiate_bitwise_or(uint8_t) instantiate_bitwise_or(uint32_t) + instantiate_bitwise_or(int64_t) instantiate_bitwise_or(int) -template -[[kernel]] void bitwise_leftshift( - const device T* a [[buffer(0)]], - device T* output [[buffer(1)]], - device const uint& k, - uint tid [[ thread_position_in_grid ]] -) { - output[tid] = a[tid] << k; + template + [[kernel]] void bitwise_leftshift(const device T *a [[buffer(0)]], + device T *output [[buffer(1)]], + device const uint &k, + uint tid [[thread_position_in_grid]] + ) { + output[tid] = a[tid] << k; } -#define instantiate_bitwise_leftshift(type) \ - template [[host_name("bitwise_leftshift_" #type)]] \ - [[kernel]] void bitwise_leftshift( \ - const device type* a [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - device const uint& k, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_bitwise_leftshift(type) \ + template [[host_name("bitwise_leftshift_" #type)]] [[kernel]] void \ + bitwise_leftshift( \ + const device type *a [[buffer(0)]], device type *out [[buffer(1)]], \ + device const uint &k, uint tid [[thread_position_in_grid]]); -instantiate_bitwise_leftshift(uint8_t) -instantiate_bitwise_leftshift(uint32_t) -instantiate_bitwise_leftshift(int64_t) -instantiate_bitwise_leftshift(int) +instantiate_bitwise_leftshift(uint8_t) instantiate_bitwise_leftshift(uint32_t) + instantiate_bitwise_leftshift(int64_t) instantiate_bitwise_leftshift(int) diff --git a/mistralrs-quant/src/metal_kernels/bnb_dequantize.metal b/mistralrs-quant/src/metal_kernels/bnb_dequantize.metal index a79a31ab78..36fb4269d5 100644 --- a/mistralrs-quant/src/metal_kernels/bnb_dequantize.metal +++ 
b/mistralrs-quant/src/metal_kernels/bnb_dequantize.metal @@ -66,57 +66,49 @@ struct _MLX_BFloat16 { ///////////////////////////////////////////////////////////////////////////// // Conversions to bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) thread : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) device : bits_(float_to_bfloat_bits(static_cast(x))) {} - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC _MLX_BFloat16(T x) constant : bits_(float_to_bfloat_bits(static_cast(x))) {} ///////////////////////////////////////////////////////////////////////////// // Conversions from bfloat - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const thread { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const threadgroup { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const device { return static_cast(bfloat_bits_to_float(bits_)); } - template < - typename T, - typename = typename enable_if>::type> + template >::type> constexpr METAL_FUNC operator T() const constant { return static_cast(bfloat_bits_to_float(bits_)); } @@ -134,29 +126,29 @@ constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { ///////////////////////////////////////////////////////////////////////////// // Binary operators -#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ - constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } -#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ - constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ - } \ - constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ - return static_cast(lhs) __op__ static_cast(rhs); \ +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ } ///////////////////////////////////////////////////////////////////////////// // Arithmetic Operators -#define bfloat_binop(_op_, _operator_) \ - bfloat_binop_base( \ - _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(_op_, _operator_, float, float, float); \ - bfloat_binop_helper(_op_, _operator_, float, half, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ - bfloat_binop_helper(_op_, _operator_, 
_MLX_BFloat16, uint32_t, float); \ - bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base(_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, \ + _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); bfloat_binop(+, operator+); @@ -166,14 +158,14 @@ bfloat_binop(/, operator/); ///////////////////////////////////////////////////////////////////////////// // Comparison ops -#define bfloat_compop(__op__, __operator__) \ - bfloat_binop_base( \ - __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ - bfloat_binop_helper(__op__, __operator__, bool, float, float); \ - bfloat_binop_helper(__op__, __operator__, bool, half, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ - bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base(__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, \ + float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); bfloat_compop(>, operator>); @@ -190,27 +182,27 @@ bfloat_compop(!=, operator!=); ///////////////////////////////////////////////////////////////////////////// // Inplace Operators -#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, itype rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ - } \ - constexpr METAL_FUNC addr_space itype& __operator__( \ - addr_space itype& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype &__operator__(addr_space itype &lhs, \ + _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ - bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ - bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); -#define bfloat_inplace_op(itype) \ - bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ - bfloat_inplace_op_addr_space_helper(-, 
operator-=, itype); \ - bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ bfloat_inplace_op_addr_space_helper(/, operator/=, itype); bfloat_inplace_op(float); @@ -226,16 +218,16 @@ bfloat_inplace_op(uint64_t); #undef bfloat_inplace_op_addr_space_helper #undef bfloat_inplace_op -#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ - constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ - addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ - lhs = static_cast(lhs) __op__ static_cast(rhs); \ - return lhs; \ +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16 &__operator__( \ + addr_space _MLX_BFloat16 &lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ } -#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ - bfloat_inplace_op_helper(__op__, __operator__, device); \ - bfloat_inplace_op_helper(__op__, __operator__, thread); \ +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ bfloat_inplace_op_helper(__op__, __operator__, threadgroup); bfloat_inplace_op_addr_space_helper(+, operator+=); @@ -254,221 +246,207 @@ typedef struct _MLX_BFloat16 bfloat16_t; #endif - float dequantize_fp4_tree(unsigned char val, float absmax) { - float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; - if ((val & 0b0100) == 4) { // 0 - if ((val & 0b0010) == 2) { // 01 - if ((val & 0b0001) == 1) { // 111 - return 0.25000000f * absmax * sign; // 1111 - } else { - return 0.16666667f * absmax * sign; // 1110 - } + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if ((val & 0b0100) == 4) { // 0 + if ((val & 0b0010) == 2) { // 01 + if ((val & 0b0001) == 1) { // 111 + return 0.25000000f * absmax * sign; // 1111 + } else { + return 0.16666667f * absmax * sign; // 1110 + } + } else { + if ((val & 0b0001) == 1) { // 110 + return 0.50000000f * absmax * sign; // 1101 + } else { + return 0.33333333f * absmax * sign; // 1100 + } + } + } else { + if ((val & 0b0010) == 2) { // 10 + if ((val & 0b0001) == 1) { // 101 + return 1.00000000f * absmax * sign; // 1011 + } else { + return 0.66666667f * absmax * sign; // 1010 + } + } else { + if ((val & 0b0001) == 1) { // 100 + return 5.208333333e-03f * absmax * sign; // 1001 + } else { + return 0.00000000f * absmax * sign; // 1000 + } + } + } +} + +float dequantize_nf4(unsigned char val) { + if ((val & 0b1000) == 8) { + if ((val & 0b0100) == 4) { // 1 + if ((val & 0b0010) == 2) { // 11 + if ((val & 0b0001) == 1) { // 111 + return 1.0f; + } else { + return 0.7229568362236023f; + } + } else { + if ((val & 0b0001) == 1) { // 110 + return 0.5626170039176941f; } else { - if ((val & 0b0001) == 1) { // 110 - return 0.50000000f * absmax * sign; // 1101 - } else { - return 0.33333333f * absmax * sign; // 1100 - } + return 0.44070982933044434f; } + } } else { - if ((val & 0b0010) == 2) { // 10 - if ((val & 0b0001) == 1) { // 101 - return 1.00000000f * absmax * sign; // 1011 - } else { - return 0.66666667f * absmax * sign; // 1010 - } + if ((val & 0b0010) == 2) { // 10 + if ((val & 0b0001) == 1) { // 101 + return 0.33791524171829224f; } else { - if ((val & 0b0001) == 1) { // 100 - return 5.208333333e-03f * absmax * sign; // 1001 - } else { - return 0.00000000f * absmax * sign; // 1000 - } + return 0.24611230194568634f; } + } else { + if ((val & 0b0001) == 1) { // 100 + return 0.16093020141124725f; + } else { + return 0.07958029955625534f; + } + } } -} - -float dequantize_nf4(unsigned char val) { - if ((val & 0b1000) == 8) { - if ((val & 0b0100) == 4) { // 1 - if ((val & 0b0010) == 2) { // 11 - if ((val & 0b0001) == 1) { // 111 - return 1.0f; - } else { - return 0.7229568362236023f; - } - } else { - if ((val & 0b0001) == 1) { // 110 - return 0.5626170039176941f; - } else { - return 0.44070982933044434f; - } - } + } else { + if ((val & 0b0100) == 4) { // 0 + if ((val & 0b0010) == 2) { // 01 + if ((val & 0b0001) == 1) { // 011 + return 0.0f; + } else { + return -0.09105003625154495f; + } + } else { + if ((val & 0b0001) == 1) { // 010 + return -0.18477343022823334f; } else { - if ((val & 0b0010) == 2) { // 10 - if ((val & 0b0001) == 1) { // 101 - return 0.33791524171829224f; - } else { - return 0.24611230194568634f; - } - } else { - if ((val & 0b0001) == 1) { // 100 - return 0.16093020141124725f; - } else { - return 0.07958029955625534f; - } - } + return -0.28444138169288635f; } + } } else { - if ((val & 0b0100) == 4) { // 0 - if ((val & 0b0010) == 2) { // 01 - if ((val & 0b0001) == 1) { // 011 - return 0.0f; - } else { - return -0.09105003625154495f; - } - } else { - if ((val & 0b0001) == 1) { // 010 - return -0.18477343022823334f; - } else { - return -0.28444138169288635f; - } - } + if ((val & 0b0010) == 2) { // 00 + if ((val & 0b0001) == 1) { // 001 + return -0.39491748809814453f; } else { - if ((val & 0b0010) == 2) { // 00 - if ((val & 0b0001) == 1) { // 001 - return -0.39491748809814453f; - } else { - return -0.5250730514526367f; - } - } else { - if ((val & 0b0001) == 1) { // 000 - return -0.6961928009986877f; - } else { - return -1.0f; - } - } + return -0.5250730514526367f; } + } else { + if ((val & 0b0001) == 
1) { // 000 + return -0.6961928009986877f; + } else { + return -1.0f; + } + } } + } } template -[[kernel]] void kernel_dequantize_nf4( - const device float* code [[buffer(0)]], - const device uchar* input [[buffer(1)]], - const device float* absmax [[buffer(2)]], - device T* out [[buffer(3)]], - device const int& blocksize, - device const int& n, - uint id [[thread_position_in_grid]]) { - - int block_idx = id * blocksize; - int valid_items = (n > blocksize + block_idx) ? blocksize : (n - block_idx); - int block_end = block_idx + valid_items; - - for (int i = block_idx; i < block_end; ++i) { - float local_abs_max = absmax[block_idx / (blocksize / 2)]; - - uint8_t input_value = static_cast(input[i]); - float high_nibble = dequantize_nf4(input_value >> 4); - float low_nibble = dequantize_nf4(input_value & 0x0F); - - out[i * 2] = static_cast(high_nibble * local_abs_max); - out[i * 2 + 1] = static_cast(low_nibble * local_abs_max); - } +[[kernel]] void kernel_dequantize_nf4(const device float *code [[buffer(0)]], + const device uchar *input [[buffer(1)]], + const device float *absmax [[buffer(2)]], + device T *out [[buffer(3)]], + device const int &blocksize, + device const int &n, + uint id [[thread_position_in_grid]]) { + + int block_idx = id * blocksize; + int valid_items = (n > blocksize + block_idx) ? blocksize : (n - block_idx); + int block_end = block_idx + valid_items; + + for (int i = block_idx; i < block_end; ++i) { + float local_abs_max = absmax[block_idx / (blocksize / 2)]; + + uint8_t input_value = static_cast(input[i]); + float high_nibble = dequantize_nf4(input_value >> 4); + float low_nibble = dequantize_nf4(input_value & 0x0F); + + out[i * 2] = static_cast(high_nibble * local_abs_max); + out[i * 2 + 1] = static_cast(low_nibble * local_abs_max); + } } template -[[kernel]] void kernel_dequantize_fp4( - const device float* code [[buffer(0)]], - const device uchar* input [[buffer(1)]], - const device float* absmax [[buffer(2)]], - device T* out [[buffer(3)]], - device const int& blocksize, - device const int& n, - uint id [[thread_position_in_grid]]) { - - int block_idx = id * blocksize; - int valid_items = (n > blocksize + block_idx) ? blocksize : (n - block_idx); - int block_end = block_idx + valid_items; - - for (int i = block_idx; i < block_end; ++i) { - float local_abs_max = absmax[block_idx / (blocksize / 2)]; - - // Extract the high and low nibbles from the input value - uint8_t input_value = static_cast(input[i]); - float high_nibble = dequantize_fp4_tree(input_value >> 4, local_abs_max); - float low_nibble = dequantize_fp4_tree(input_value & 0x0F, local_abs_max); - - out[i * 2] = static_cast(high_nibble); - out[i * 2 + 1] = static_cast(low_nibble); - } +[[kernel]] void kernel_dequantize_fp4(const device float *code [[buffer(0)]], + const device uchar *input [[buffer(1)]], + const device float *absmax [[buffer(2)]], + device T *out [[buffer(3)]], + device const int &blocksize, + device const int &n, + uint id [[thread_position_in_grid]]) { + + int block_idx = id * blocksize; + int valid_items = (n > blocksize + block_idx) ? 
blocksize : (n - block_idx); + int block_end = block_idx + valid_items; + + for (int i = block_idx; i < block_end; ++i) { + float local_abs_max = absmax[block_idx / (blocksize / 2)]; + + // Extract the high and low nibbles from the input value + uint8_t input_value = static_cast(input[i]); + float high_nibble = dequantize_fp4_tree(input_value >> 4, local_abs_max); + float low_nibble = dequantize_fp4_tree(input_value & 0x0F, local_abs_max); + + out[i * 2] = static_cast(high_nibble); + out[i * 2 + 1] = static_cast(low_nibble); + } } template -[[kernel]] void kernel_dequantize_int8( - const device float* code [[buffer(0)]], - const device uchar* input [[buffer(1)]], - const device float* absmax [[buffer(2)]], - device T* out [[buffer(3)]], - device const int& blocksize, - device const int& n, - uint id [[thread_position_in_grid]]) { - - int block_idx = id * blocksize; - int valid_items = (n > blocksize + block_idx) ? blocksize : (n - block_idx); - int block_end = block_idx + valid_items; - - for (int i = block_idx; i < block_end; ++i) { - float local_abs_max = absmax[block_idx / blocksize]; - - out[i] = static_cast(code[input[i]] * local_abs_max); - } +[[kernel]] void kernel_dequantize_int8(const device float *code [[buffer(0)]], + const device uchar *input [[buffer(1)]], + const device float *absmax [[buffer(2)]], + device T *out [[buffer(3)]], + device const int &blocksize, + device const int &n, + uint id [[thread_position_in_grid]]) { + + int block_idx = id * blocksize; + int valid_items = (n > blocksize + block_idx) ? blocksize : (n - block_idx); + int block_end = block_idx + valid_items; + + for (int i = block_idx; i < block_end; ++i) { + float local_abs_max = absmax[block_idx / blocksize]; + + out[i] = static_cast(code[input[i]] * local_abs_max); + } } -#define instantiate_dequantize_nf4(type) \ -template [[host_name("kernel_dequantize_nf4_" #type )]] \ -[[kernel]] void kernel_dequantize_nf4( \ - const device float* code [[buffer(0)]], \ - const device uchar* input [[buffer(1)]], \ - const device float* absmax [[buffer(2)]], \ - device type* out [[buffer(3)]], \ - device const int& blocksize, \ - device const int& n, \ - uint id [[thread_position_in_grid]]); \ - -instantiate_dequantize_nf4(float) -instantiate_dequantize_nf4(bfloat16_t) -instantiate_dequantize_nf4(half) - - -#define instantiate_dequantize_fp4(type) \ -template [[host_name("kernel_dequantize_fp4_" #type )]] \ -[[kernel]] void kernel_dequantize_fp4( \ - const device float* code [[buffer(0)]], \ - const device uchar* input [[buffer(1)]], \ - const device float* absmax [[buffer(2)]], \ - device type* out [[buffer(3)]], \ - device const int& blocksize, \ - device const int& n, \ - uint id [[thread_position_in_grid]]); \ - -instantiate_dequantize_fp4(float) -instantiate_dequantize_fp4(bfloat16_t) -instantiate_dequantize_fp4(half) - - -#define instantiate_dequantize_int8(type) \ -template [[host_name("kernel_dequantize_int8_" #type )]] \ -[[kernel]] void kernel_dequantize_int8( \ - const device float* code [[buffer(0)]], \ - const device uchar* input [[buffer(1)]], \ - const device float* absmax [[buffer(2)]], \ - device type* out [[buffer(3)]], \ - device const int& blocksize, \ - device const int& n, \ - uint id [[thread_position_in_grid]]); \ - -instantiate_dequantize_int8(float) -instantiate_dequantize_int8(bfloat16_t) -instantiate_dequantize_int8(half) +#define instantiate_dequantize_nf4(type) \ + template [[host_name("kernel_dequantize_nf4_" #type)]] [[kernel]] void \ + kernel_dequantize_nf4( \ + const device float *code 
[[buffer(0)]], \ + const device uchar *input [[buffer(1)]], \ + const device float *absmax [[buffer(2)]], \ + device type *out [[buffer(3)]], device const int &blocksize, \ + device const int &n, uint id [[thread_position_in_grid]]); + +instantiate_dequantize_nf4(float) instantiate_dequantize_nf4(bfloat16_t) + instantiate_dequantize_nf4(half) + +#define instantiate_dequantize_fp4(type) \ + template [[host_name("kernel_dequantize_fp4_" #type)]] [[kernel]] void \ + kernel_dequantize_fp4( \ + const device float *code [[buffer(0)]], \ + const device uchar *input [[buffer(1)]], \ + const device float *absmax [[buffer(2)]], \ + device type *out [[buffer(3)]], device const int &blocksize, \ + device const int &n, uint id [[thread_position_in_grid]]); + + instantiate_dequantize_fp4(float) instantiate_dequantize_fp4(bfloat16_t) + instantiate_dequantize_fp4(half) + +#define instantiate_dequantize_int8(type) \ + template [[host_name("kernel_dequantize_int8_" #type)]] [[kernel]] void \ + kernel_dequantize_int8( \ + const device float *code [[buffer(0)]], \ + const device uchar *input [[buffer(1)]], \ + const device float *absmax [[buffer(2)]], \ + device type *out [[buffer(3)]], device const int &blocksize, \ + device const int &n, uint id [[thread_position_in_grid]]); + + instantiate_dequantize_int8(float) + instantiate_dequantize_int8(bfloat16_t) + instantiate_dequantize_int8(half) diff --git a/mistralrs-quant/src/metal_kernels/hqq_dequantize.metal b/mistralrs-quant/src/metal_kernels/hqq_dequantize.metal index aa5b65efb0..bf08c68da3 100644 --- a/mistralrs-quant/src/metal_kernels/hqq_dequantize.metal +++ b/mistralrs-quant/src/metal_kernels/hqq_dequantize.metal @@ -5,200 +5,199 @@ //********************************/ template -[[kernel]] void dequantize_8bit( - const device char* weight [[buffer(0)]], - const device T* scale [[buffer(1)]], - const device T* zero [[buffer(2)]], - device T* output [[buffer(3)]], - device const uint& h, - device const uint& w, - uint tid [[ thread_position_in_grid ]] -) { - uint j = tid % w; - output[tid] = ((T)(weight[tid]) - zero[j])*scale[j]; +[[kernel]] void dequantize_8bit(const device char *weight [[buffer(0)]], + const device T *scale [[buffer(1)]], + const device T *zero [[buffer(2)]], + device T *output [[buffer(3)]], + device const uint &h, device const uint &w, + uint tid [[thread_position_in_grid]]) { + uint j = tid % w; + output[tid] = ((T)(weight[tid]) - zero[j]) * scale[j]; } -#define instantiate_dequantize_8bit(type) \ - template [[host_name("dequantize_8bit_" #type)]] \ - [[kernel]] void dequantize_8bit( \ - const device char* weight [[buffer(0)]], \ - const device type* scale [[buffer(1)]], \ - const device type* zero [[buffer(2)]], \ - device type* output [[buffer(3)]], \ - device const uint& h, \ - device const uint& w, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_dequantize_8bit(type) \ + template [[host_name("dequantize_8bit_" #type)]] [[kernel]] void \ + dequantize_8bit(const device char *weight [[buffer(0)]], \ + const device type *scale [[buffer(1)]], \ + const device type *zero [[buffer(2)]], \ + device type *output [[buffer(3)]], \ + device const uint &h, device const uint &w, \ + uint tid [[thread_position_in_grid]]); instantiate_dequantize_8bit(float) #if defined(__HAVE_BFLOAT__) -instantiate_dequantize_8bit(bfloat) + instantiate_dequantize_8bit(bfloat) #endif -instantiate_dequantize_8bit(half) - - -/*********************************/ -/************* 4-bit *************/ -//********************************/ - -template 
-[[kernel]] void dequantize_4bit( - const device char* weight [[buffer(0)]], - const device T* scale [[buffer(1)]], - const device T* zero [[buffer(2)]], - device T* output [[buffer(3)]], - device const uint& h, - device const uint& w, - uint tid [[ thread_position_in_grid ]] -) { - uint n = h*w; - uint j = tid % w; - output[tid] = ((T)((weight[tid] & 0xF0) >> 4) - zero[j])*scale[j]; // First chunk - output[tid + n] = ((T)((weight[tid] & 0x0F)) - zero[j])*scale[j]; // Second chunk + instantiate_dequantize_8bit(half) + + /*********************************/ + /************* 4-bit *************/ + //********************************/ + + template + [[kernel]] void dequantize_4bit(const device char *weight [[buffer(0)]], + const device T *scale [[buffer(1)]], + const device T *zero [[buffer(2)]], + device T *output [[buffer(3)]], + device const uint &h, device const uint &w, + uint tid [[thread_position_in_grid]] + ) { + uint n = h * w; + uint j = tid % w; + output[tid] = + ((T)((weight[tid] & 0xF0) >> 4) - zero[j]) * scale[j]; // First chunk + output[tid + n] = + ((T)((weight[tid] & 0x0F)) - zero[j]) * scale[j]; // Second chunk } -#define instantiate_dequantize_4bit(type) \ - template [[host_name("dequantize_4bit_" #type)]] \ - [[kernel]] void dequantize_4bit( \ - const device char* weight [[buffer(0)]], \ - const device type* scale [[buffer(1)]], \ - const device type* zero [[buffer(2)]], \ - device type* output [[buffer(3)]], \ - device const uint& h, \ - device const uint& w, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_dequantize_4bit(type) \ + template [[host_name("dequantize_4bit_" #type)]] [[kernel]] void \ + dequantize_4bit(const device char *weight [[buffer(0)]], \ + const device type *scale [[buffer(1)]], \ + const device type *zero [[buffer(2)]], \ + device type *output [[buffer(3)]], \ + device const uint &h, device const uint &w, \ + uint tid [[thread_position_in_grid]]); instantiate_dequantize_4bit(float) #if defined(__HAVE_BFLOAT__) -instantiate_dequantize_4bit(bfloat) + instantiate_dequantize_4bit(bfloat) #endif -instantiate_dequantize_4bit(half) - - -/*********************************/ -/************* 2-bit *************/ -//********************************/ - -template -[[kernel]] void dequantize_2bit( - const device char* weight [[buffer(0)]], - const device T* scale [[buffer(1)]], - const device T* zero [[buffer(2)]], - device T* output [[buffer(3)]], - device const uint& h, - device const uint& w, - uint tid [[ thread_position_in_grid ]] -) { - uint n = h*w; - uint j = tid % w; - output[tid] = ((T)((weight[tid] & 0xC0) >> 6) - zero[j])*scale[j]; // 1st chunk - output[tid + n] = ((T)((weight[tid] & 0x30) >> 4) - zero[j])*scale[j]; // 2nd chunk - output[tid + n*2] = ((T)((weight[tid] & 0x0C) >> 2) - zero[j])*scale[j]; // 3rd chunk - output[tid + n*3] = ((T)((weight[tid] & 0x03)) - zero[j])*scale[j]; // 4th chunk + instantiate_dequantize_4bit(half) + + /*********************************/ + /************* 2-bit *************/ + //********************************/ + + template + [[kernel]] void dequantize_2bit(const device char *weight [[buffer(0)]], + const device T *scale [[buffer(1)]], + const device T *zero [[buffer(2)]], + device T *output [[buffer(3)]], + device const uint &h, device const uint &w, + uint tid [[thread_position_in_grid]] + ) { + uint n = h * w; + uint j = tid % w; + output[tid] = + ((T)((weight[tid] & 0xC0) >> 6) - zero[j]) * scale[j]; // 1st chunk + output[tid + n] = + ((T)((weight[tid] & 0x30) >> 4) - zero[j]) * scale[j]; // 2nd 
chunk + output[tid + n * 2] = + ((T)((weight[tid] & 0x0C) >> 2) - zero[j]) * scale[j]; // 3rd chunk + output[tid + n * 3] = + ((T)((weight[tid] & 0x03)) - zero[j]) * scale[j]; // 4th chunk } -#define instantiate_dequantize_2bit(type) \ - template [[host_name("dequantize_2bit_" #type)]] \ - [[kernel]] void dequantize_2bit( \ - const device char* weight [[buffer(0)]], \ - const device type* scale [[buffer(1)]], \ - const device type* zero [[buffer(2)]], \ - device type* output [[buffer(3)]], \ - device const uint& h, \ - device const uint& w, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_dequantize_2bit(type) \ + template [[host_name("dequantize_2bit_" #type)]] [[kernel]] void \ + dequantize_2bit(const device char *weight [[buffer(0)]], \ + const device type *scale [[buffer(1)]], \ + const device type *zero [[buffer(2)]], \ + device type *output [[buffer(3)]], \ + device const uint &h, device const uint &w, \ + uint tid [[thread_position_in_grid]]); instantiate_dequantize_2bit(float) #if defined(__HAVE_BFLOAT__) -instantiate_dequantize_2bit(bfloat) + instantiate_dequantize_2bit(bfloat) #endif -instantiate_dequantize_2bit(half) - - -/*********************************/ -/************* 1-bit *************/ -//********************************/ - -template -[[kernel]] void dequantize_1bit( - const device char* weight [[buffer(0)]], - const device T* scale [[buffer(1)]], - const device T* zero [[buffer(2)]], - device T* output [[buffer(3)]], - device const uint& h, - device const uint& w, - uint tid [[ thread_position_in_grid ]] -) { - uint n = h*w; - uint j = tid % w; - output[tid] = ((T)((weight[tid] & 0x80) >> 7) - zero[j])*scale[j]; // 1st chunk - output[tid + n] = ((T)((weight[tid] & 0x40) >> 6) - zero[j])*scale[j]; // 2nd chunk - output[tid + n*2] = ((T)((weight[tid] & 0x20) >> 5) - zero[j])*scale[j]; // 3rd chunk - output[tid + n*3] = ((T)((weight[tid] & 0x10) >> 4) - zero[j])*scale[j]; // 4th chunk - output[tid + n*4] = ((T)((weight[tid] & 0x08) >> 3) - zero[j])*scale[j]; // 5th chunk - output[tid + n*5] = ((T)((weight[tid] & 0x04) >> 2) - zero[j])*scale[j]; // 6th chunk - output[tid + n*6] = ((T)((weight[tid] & 0x02) >> 1) - zero[j])*scale[j]; // 7th chunk - output[tid + n*7] = ((T)((weight[tid] & 0x01)) - zero[j])*scale[j]; // 8th chunk + instantiate_dequantize_2bit(half) + + /*********************************/ + /************* 1-bit *************/ + //********************************/ + + template + [[kernel]] void dequantize_1bit(const device char *weight [[buffer(0)]], + const device T *scale [[buffer(1)]], + const device T *zero [[buffer(2)]], + device T *output [[buffer(3)]], + device const uint &h, device const uint &w, + uint tid [[thread_position_in_grid]] + ) { + uint n = h * w; + uint j = tid % w; + output[tid] = + ((T)((weight[tid] & 0x80) >> 7) - zero[j]) * scale[j]; // 1st chunk + output[tid + n] = + ((T)((weight[tid] & 0x40) >> 6) - zero[j]) * scale[j]; // 2nd chunk + output[tid + n * 2] = + ((T)((weight[tid] & 0x20) >> 5) - zero[j]) * scale[j]; // 3rd chunk + output[tid + n * 3] = + ((T)((weight[tid] & 0x10) >> 4) - zero[j]) * scale[j]; // 4th chunk + output[tid + n * 4] = + ((T)((weight[tid] & 0x08) >> 3) - zero[j]) * scale[j]; // 5th chunk + output[tid + n * 5] = + ((T)((weight[tid] & 0x04) >> 2) - zero[j]) * scale[j]; // 6th chunk + output[tid + n * 6] = + ((T)((weight[tid] & 0x02) >> 1) - zero[j]) * scale[j]; // 7th chunk + output[tid + n * 7] = + ((T)((weight[tid] & 0x01)) - zero[j]) * scale[j]; // 8th chunk } -#define 
instantiate_dequantize_1bit(type) \ - template [[host_name("dequantize_1bit_" #type)]] \ - [[kernel]] void dequantize_1bit( \ - const device char* weight [[buffer(0)]], \ - const device type* scale [[buffer(1)]], \ - const device type* zero [[buffer(2)]], \ - device type* output [[buffer(3)]], \ - device const uint& h, \ - device const uint& w, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_dequantize_1bit(type) \ + template [[host_name("dequantize_1bit_" #type)]] [[kernel]] void \ + dequantize_1bit(const device char *weight [[buffer(0)]], \ + const device type *scale [[buffer(1)]], \ + const device type *zero [[buffer(2)]], \ + device type *output [[buffer(3)]], \ + device const uint &h, device const uint &w, \ + uint tid [[thread_position_in_grid]]); instantiate_dequantize_1bit(float) #if defined(__HAVE_BFLOAT__) -instantiate_dequantize_1bit(bfloat) + instantiate_dequantize_1bit(bfloat) #endif -instantiate_dequantize_1bit(half) - - -/*********************************/ -/************* 3-bit *************/ -//********************************/ - -template -[[kernel]] void dequantize_3bit( - const device int* weight [[buffer(0)]], - const device T* scale [[buffer(1)]], - const device T* zero [[buffer(2)]], - device T* output [[buffer(3)]], - device const uint& h, - device const uint& w, - uint tid [[ thread_position_in_grid ]] -) { - uint n = h*w; - uint j = tid % w; - output[tid] = ((T)((weight[tid] & 0x38000000) >> 27) - zero[j])*scale[j]; // 1st chunk - output[tid + n] = ((T)((weight[tid] & 0x07000000) >> 24) - zero[j])*scale[j]; // 2nd chunk - output[tid + n*2] = ((T)((weight[tid] & 0x00E00000) >> 21) - zero[j])*scale[j]; // 3rd chunk - output[tid + n*3] = ((T)((weight[tid] & 0x001C0000) >> 18) - zero[j])*scale[j]; // 4th chunk - output[tid + n*4] = ((T)((weight[tid] & 0x00038000) >> 15) - zero[j])*scale[j]; // 5th chunk - output[tid + n*5] = ((T)((weight[tid] & 0x00007000) >> 12) - zero[j])*scale[j]; // 6th chunk - output[tid + n*6] = ((T)((weight[tid] & 0x00000E00) >> 9) - zero[j])*scale[j]; // 7th chunk - output[tid + n*7] = ((T)((weight[tid] & 0x000001C0) >> 6) - zero[j])*scale[j]; // 8th chunk - output[tid + n*8] = ((T)((weight[tid] & 0x00000038) >> 3) - zero[j])*scale[j]; // 9th chunk - output[tid + n*9] = ((T)((weight[tid] & 0x00000007)) - zero[j])*scale[j]; // 10th chunk + instantiate_dequantize_1bit(half) + + /*********************************/ + /************* 3-bit *************/ + //********************************/ + + template + [[kernel]] void dequantize_3bit(const device int *weight [[buffer(0)]], + const device T *scale [[buffer(1)]], + const device T *zero [[buffer(2)]], + device T *output [[buffer(3)]], + device const uint &h, device const uint &w, + uint tid [[thread_position_in_grid]] + ) { + uint n = h * w; + uint j = tid % w; + output[tid] = + ((T)((weight[tid] & 0x38000000) >> 27) - zero[j]) * scale[j]; // 1st chunk + output[tid + n] = + ((T)((weight[tid] & 0x07000000) >> 24) - zero[j]) * scale[j]; // 2nd chunk + output[tid + n * 2] = + ((T)((weight[tid] & 0x00E00000) >> 21) - zero[j]) * scale[j]; // 3rd chunk + output[tid + n * 3] = + ((T)((weight[tid] & 0x001C0000) >> 18) - zero[j]) * scale[j]; // 4th chunk + output[tid + n * 4] = + ((T)((weight[tid] & 0x00038000) >> 15) - zero[j]) * scale[j]; // 5th chunk + output[tid + n * 5] = + ((T)((weight[tid] & 0x00007000) >> 12) - zero[j]) * scale[j]; // 6th chunk + output[tid + n * 6] = + ((T)((weight[tid] & 0x00000E00) >> 9) - zero[j]) * scale[j]; // 7th chunk + output[tid + n * 7] = + 
((T)((weight[tid] & 0x000001C0) >> 6) - zero[j]) * scale[j]; // 8th chunk + output[tid + n * 8] = + ((T)((weight[tid] & 0x00000038) >> 3) - zero[j]) * scale[j]; // 9th chunk + output[tid + n * 9] = + ((T)((weight[tid] & 0x00000007)) - zero[j]) * scale[j]; // 10th chunk } -#define instantiate_dequantize_3bit(type) \ - template [[host_name("dequantize_3bit_" #type)]] \ - [[kernel]] void dequantize_3bit( \ - const device int* weight [[buffer(0)]], \ - const device type* scale [[buffer(1)]], \ - const device type* zero [[buffer(2)]], \ - device type* output [[buffer(3)]], \ - device const uint& h, \ - device const uint& w, \ - uint tid [[ thread_position_in_grid ]]); +#define instantiate_dequantize_3bit(type) \ + template [[host_name("dequantize_3bit_" #type)]] [[kernel]] void \ + dequantize_3bit(const device int *weight [[buffer(0)]], \ + const device type *scale [[buffer(1)]], \ + const device type *zero [[buffer(2)]], \ + device type *output [[buffer(3)]], \ + device const uint &h, device const uint &w, \ + uint tid [[thread_position_in_grid]]); instantiate_dequantize_3bit(float) #if defined(__HAVE_BFLOAT__) -instantiate_dequantize_3bit(bfloat) + instantiate_dequantize_3bit(bfloat) #endif -instantiate_dequantize_3bit(half) + instantiate_dequantize_3bit(half) diff --git a/mistralrs-quant/src/unquantized/mod.rs b/mistralrs-quant/src/unquantized/mod.rs index 3174561aa7..be06610cf1 100644 --- a/mistralrs-quant/src/unquantized/mod.rs +++ b/mistralrs-quant/src/unquantized/mod.rs @@ -52,7 +52,7 @@ impl QuantMethod for UnquantLinear { fn forward(&self, a: &Tensor) -> Result { // Batch matrix multiplication - maybe_init_cublas_lt_wrapper(); + maybe_init_cublas_lt_wrapper(a.device().clone()); let w = match *a.dims() { [b1, b2, _, _] => self.w.broadcast_left((b1, b2))?, diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 1d5eed0d9c..3cf375e8e6 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -317,6 +317,8 @@ async fn main() -> Result<()> { let device = if args.cpu { args.no_paged_attn = true; Device::Cpu + } else if cfg!(feature = "nccl") { + Device::Cpu } else { Device::cuda_if_available(0)? }; @@ -377,7 +379,7 @@ async fn main() -> Result<()> { DeviceMapSetting::Auto(auto_device_map_params) }; - let no_paged_attn = if device.is_cuda() { + let no_paged_attn = if device.is_cuda() || cfg!(feature = "nccl") { args.no_paged_attn } else if device.is_metal() { !args.paged_attn diff --git a/scripts/generate_uqff_card.py b/scripts/generate_uqff_card.py index 0bd48f6246..5e94162ee9 100644 --- a/scripts/generate_uqff_card.py +++ b/scripts/generate_uqff_card.py @@ -65,7 +65,7 @@ "Enter topology used to make UQFF with multiple quantizations: " ) topologies[file] = topology - output += f"|{",".join(quants)} (see topology for this file)|" + output += f"|{','.join(quants)} (see topology for this file)|" else: output += f"|{quants.strip().upper()}|"
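As a plain C++ rendering of the per-element math in the HQQ dequantize kernels above (an illustration only, not a replacement for the Metal code): for 4-bit weights each packed byte holds two quantized values, the high nibble written to output[tid] and the low nibble to output[tid + h*w], both dequantized as (q - zero[col]) * scale[col]; the 2-, 1- and 3-bit variants follow the same pattern with more chunks per packed word.

#include <cstddef>
#include <cstdint>
#include <vector>

// Reference dequantizer mirroring the Metal dequantize_4bit kernel.
void dequantize_4bit_ref(const std::vector<int8_t> &weight, // h*w packed bytes
                         const std::vector<float> &scale,   // per column, length w
                         const std::vector<float> &zero,    // per column, length w
                         std::vector<float> &output,        // receives 2*h*w values
                         uint32_t h, uint32_t w) {
  const uint32_t n = h * w;
  output.resize(2 * static_cast<std::size_t>(n));
  for (uint32_t tid = 0; tid < n; ++tid) {
    const uint32_t j = tid % w; // column index selects scale/zero
    const uint8_t byte = static_cast<uint8_t>(weight[tid]);
    output[tid]     = (float((byte & 0xF0) >> 4) - zero[j]) * scale[j]; // first chunk
    output[tid + n] = (float(byte & 0x0F) - zero[j]) * scale[j];        // second chunk
  }
}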