diff --git a/.gitignore b/.gitignore index b5195629e5cf..b1513ef0ddb0 100644 --- a/.gitignore +++ b/.gitignore @@ -181,6 +181,7 @@ _build/ # hip files generated by PyTorch *.hip *_hip* +hip_compat.h # Benchmark dataset *.json diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e4d70851e46e..8ac2a861e659 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -30,6 +30,7 @@ def main(args: argparse.Namespace): trust_remote_code=args.trust_remote_code, dtype=args.dtype, enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, ) for batch_size in args.batch_size: @@ -152,6 +153,14 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--enforce-eager', action='store_true', help='enforce eager mode and disable CUDA graph') + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=['auto', 'fp8'], + default='auto', + help='Data type for kv cache storage. If "auto", will use model data ' + 'type. FP8_E5M2 is only supported on cuda version greater than 11.8. ' + 'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.') parser.add_argument( '--profile', action='store_true', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 3aac479c01bd..13684f829ae7 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -71,6 +71,7 @@ def run_vllm( dtype: str, max_model_len: Optional[int], enforce_eager: bool, + kv_cache_dtype: str, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -83,6 +84,7 @@ def run_vllm( dtype=dtype, max_model_len=max_model_len, enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, ) # Add the requests to the engine. 
@@ -206,7 +208,8 @@ def main(args: argparse.Namespace): args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, - args.max_model_len, args.enforce_eager) + args.max_model_len, args.enforce_eager, + args.kv_cache_dtype) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -284,6 +287,14 @@ def main(args: argparse.Namespace): parser.add_argument("--enforce-eager", action="store_true", help="enforce eager execution") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8"], + default="auto", + help='Data type for kv cache storage. If "auto", will use model data ' + 'type. FP8_E5M2 is only supported on cuda version greater than 11.8. ' + 'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 935393e9942c..472dc444b2c5 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,9 +1,11 @@ +from typing import Optional import argparse import random import time import torch +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm._C import ops NUM_BLOCKS = 1024 @@ -23,6 +25,7 @@ def main( dtype: torch.dtype, seed: int, do_profile: bool, + kv_cache_dtype: Optional[str] = None, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -59,15 +62,10 @@ def main( block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") # Create the KV cache. 
- x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) - key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda") - key_cache.uniform_(-scale, scale) - value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device="cuda") - value_cache.uniform_(-scale, scale) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, + dtype) + key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. output = torch.empty_like(query) @@ -106,6 +104,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) elif version == "v2": ops.paged_attention_v2( @@ -123,6 +122,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) else: raise ValueError(f"Invalid version: {version}") @@ -168,16 +168,19 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: default="half") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") + parser.add_argument( + "--kv-cache-dtype", + type=str, + choices=["auto", "fp8"], + default="auto", + help='Data type for kv cache storage. If "auto", will use model data ' + 'type. FP8_E5M2 is only supported on cuda version greater than 11.8. 
' + 'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.') args = parser.parse_args() print(args) if args.num_query_heads % args.num_kv_heads != 0: raise ValueError("num_query_heads must be divisible by num_kv_heads") - dtype_to_torch_dtype = { - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - } main( version=args.version, num_seqs=args.batch_size, @@ -187,7 +190,8 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: head_size=args.head_size, block_size=args.block_size, use_alibi=args.use_alibi, - dtype=dtype_to_torch_dtype[args.dtype], + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], seed=args.seed, do_profile=args.profile, + kv_cache_dtype=args.kv_cache_dtype, ) diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h index 88b4eddec7fc..64f86381d9db 100644 --- a/csrc/attention/attention_dtypes.h +++ b/csrc/attention/attention_dtypes.h @@ -4,3 +4,4 @@ #include "dtype_float16.cuh" #include "dtype_float32.cuh" #include "dtype_bfloat16.cuh" +#include "dtype_fp8.cuh" diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 9dcacfbe47d4..2dca62c80a2d 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -25,6 +25,11 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#if defined(ENABLE_FP8_E5M2) +#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "../quantization/fp8/amd_detail/quant_utils.cuh" +#endif #include @@ -79,17 +84,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE = 0> // Zero means no partitioning. 
__device__ void paged_attention_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -145,6 +152,9 @@ __device__ void paged_attention_kernel( constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); using K_vec = typename Vec::Type; using Q_vec = typename Vec::Type; +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) + using Quant_vec = typename Vec::Type; +#endif constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; @@ -176,7 +186,7 @@ __device__ void paged_attention_kernel( // x == THREAD_GROUP_SIZE * VEC_SIZE // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); + constexpr int x = 16 / sizeof(cache_t); float qk_max = -FLT_MAX; // Iterate over the key blocks. 
@@ -202,13 +212,27 @@ __device__ void paged_attention_kernel( #pragma unroll for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. + k_vecs[j] = fp8_e5m2_unscaled::vec_conversion(k_vec_quant); +#elif defined(ENABLE_FP8_E4M3) + Quant_vec k_vec_quant = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + // Vector conversion from Quant_vec to K_vec. + k_vecs[j] = fp8_e4m3::vec_conversion(k_vec_quant); +#else + assert(false); +#endif + } else { + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } } // Compute dot product. 
@@ -282,6 +306,9 @@ __device__ void paged_attention_kernel( constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; using L_vec = typename Vec::Type; +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) + using V_quant_vec = typename Vec::Type; +#endif using Float_L_vec = typename FloatVec::Type; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; @@ -307,14 +334,29 @@ __device__ void paged_attention_kernel( L_vec logits_vec; from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); - const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride; + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); + V_vec v_vec; + if constexpr (IS_FP8_KV_CACHE) { +#if defined(ENABLE_FP8_E5M2) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8_e5m2_unscaled::vec_conversion(v_quant_vec); +#elif defined(ENABLE_FP8_E4M3) + V_quant_vec v_quant_vec = *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8_e4m3::vec_conversion(v_quant_vec); +#else + assert(false); +#endif + } else { + v_vec = *reinterpret_cast(v_ptr + offset); + } if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. @@ -395,14 +437,16 @@ __device__ void paged_attention_kernel( // Grid: (num_heads, num_seqs, 1). 
template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> + int NUM_THREADS, + bool IS_FP8_KV_CACHE> __global__ void paged_attention_v1_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -412,7 +456,7 @@ __global__ void paged_attention_v1_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - paged_attention_kernel( + paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); @@ -421,17 +465,19 @@ __global__ void paged_attention_v1_kernel( // Grid: (num_heads, num_seqs, max_num_partitions). 
template< typename scalar_t, + typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, int NUM_THREADS, + bool IS_FP8_KV_CACHE, int PARTITION_SIZE> __global__ void paged_attention_v2_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] @@ -441,7 +487,7 @@ __global__ void paged_attention_v2_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - paged_attention_kernel( + paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); @@ -550,10 +596,10 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ + ((void*)vllm::paged_attention_v1_kernel), shared_mem_size); \ + vllm::paged_attention_v1_kernel<<>>( \ out_ptr, \ query_ptr, \ key_cache_ptr, \ @@ -571,7 +617,9 @@ __global__ void paged_attention_v2_reduce_kernel( // TODO(woosuk): Tune NUM_THREADS. 
template< typename T, + typename CACHE_T, int BLOCK_SIZE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128> void paged_attention_v1_launcher( torch::Tensor& out, @@ -602,8 +650,8 @@ void paged_attention_v1_launcher( T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -647,35 +695,35 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ alibi_slopes); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. 
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } void paged_attention_v1( @@ -689,20 +737,36 @@ void paged_attention_v1( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if 
(query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } } #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ + vllm::paged_attention_v2_kernel \ <<>>( \ exp_sums_ptr, \ max_logits_ptr, \ @@ -730,7 +794,9 @@ void paged_attention_v1( template< typename T, + typename CACHE_T, int BLOCK_SIZE, + bool IS_FP8_KV_CACHE, int NUM_THREADS = 128, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( @@ -768,8 +834,8 @@ void paged_attention_v2_launcher( float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -816,34 +882,34 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ - out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ - query, \ - key_cache, \ - value_cache, \ - num_kv_heads, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + paged_attention_v2_launcher( \ + out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ + query, \ + key_cache, \ + value_cache, \ + num_kv_heads, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ alibi_slopes); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 
64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE) \ switch (block_size) { \ case 8: \ - CALL_V2_LAUNCHER(T, 8); \ + CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, 16); \ + CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, 32); \ + CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -864,15 +930,30 @@ void paged_attention_v2( torch::Tensor& context_lens, // [num_seqs] int block_size, int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype) { + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8") { + if (query.dtype() == at::ScalarType::Float) { + CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } } else { - TORCH_CHECK(false, "Unsupported 
data type: ", query.dtype()); + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); } } diff --git a/csrc/attention/dtype_fp8.cuh b/csrc/attention/dtype_fp8.cuh new file mode 100644 index 000000000000..d11dee91ebe8 --- /dev/null +++ b/csrc/attention/dtype_fp8.cuh @@ -0,0 +1,35 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8_E5M2 +#include +#endif + +namespace vllm { +#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) +// fp8 vector types for quantization of kv cache + +template<> +struct Vec { + using Type = uint8_t; +}; + +template<> +struct Vec { + using Type = uint16_t; +}; + +template<> +struct Vec { + using Type = uint32_t; +}; + +template<> +struct Vec { + using Type = uint2; +}; +#endif // defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3) + +} // namespace vllm diff --git a/csrc/cache.h b/csrc/cache.h index b26faad2ca81..aafee5524fe2 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -20,7 +20,8 @@ void reshape_and_cache( torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping); + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); void gather_cached_kv( torch::Tensor& key, @@ -28,3 +29,8 @@ torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping); + +// Just for unittest +void convert_fp8( + torch::Tensor& src_cache, + torch::Tensor& dst_cache); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b7523cb4c3b5..f8a7951f057d 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -4,12 +4,22 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#if defined(ENABLE_FP8_E5M2) +#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" +#elif defined(ENABLE_FP8_E4M3) +#include "quantization/fp8/amd_detail/quant_utils.cuh" +#endif #include #include #include #include +#ifdef USE_ROCM + #include + typedef __hip_bfloat16 __nv_bfloat16; +#endif + void swap_blocks( torch::Tensor& src, torch::Tensor&
dst, @@ -131,7 +141,7 @@ void copy_blocks( dim3 block(std::min(1024, numel_per_block)); const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), @@ -143,12 +153,12 @@ void copy_blocks( namespace vllm { -template +template __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] - scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, @@ -185,19 +195,48 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = key[src_key_idx]; - value_cache[tgt_value_idx] = value[src_value_idx]; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (is_fp8_kv_cache) { +#if defined(ENABLE_FP8_E5M2) + key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key); + value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value); +#elif defined(ENABLE_FP8_E4M3) + key_cache[tgt_key_idx] = fp8_e4m3::vec_conversion(tgt_key); + value_cache[tgt_value_idx] = fp8_e4m3::vec_conversion(tgt_value); +#else + assert(false); +#endif + } else { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; + } } } } // namespace vllm +#define 
CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ + vllm::reshape_and_cache_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), \ + key_stride, \ + value_stride, \ + num_heads, \ + head_size, \ + block_size, \ + x); + void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& slot_mapping) // [num_tokens] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) { int num_tokens = key.size(0); int num_heads = key.size(1); @@ -212,23 +251,25 @@ void reshape_and_cache( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - key.scalar_type(), - "reshape_and_cache_kernel", - [&] { - vllm::reshape_and_cache_kernel<<>>( - key.data_ptr(), - value.data_ptr(), - key_cache.data_ptr(), - value_cache.data_ptr(), - slot_mapping.data_ptr(), - key_stride, - value_stride, - num_heads, - head_size, - block_size, - x); - }); + if (kv_cache_dtype == "auto") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, float, false); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); + } + } else if (kv_cache_dtype == "fp8") { + if (key.dtype() == at::ScalarType::Float) { + CALL_RESHAPE_AND_CACHE(float, uint8_t, true); + } else if (key.dtype() == at::ScalarType::Half) { + CALL_RESHAPE_AND_CACHE(uint16_t, 
uint8_t, true); + } else if (key.dtype() == at::ScalarType::BFloat16) { + CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); + } + } else { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } } namespace vllm { @@ -256,12 +297,12 @@ __global__ void gather_cached_kv_kernel( for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { const int tgt_key_idx = token_idx * key_stride + i; const int tgt_value_idx = token_idx * value_stride + i; - + const int head_idx = i / head_size; const int head_offset = i % head_size; const int x_idx = head_offset / x; // the offset of the [head_size/x] dimension const int x_offset = head_offset % x; - + const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x @@ -373,7 +414,7 @@ void gather_cached_kv( dim3 block(std::min(num_heads * head_size, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( key.scalar_type(), "gather_cached_kv_kernel_optimized", [&] { @@ -391,3 +432,66 @@ void gather_cached_kv( x); }); } + +namespace vllm { + +template +__global__ void convert_fp8_kernel( + const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; +#if defined(ENABLE_FP8_E5M2) + dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]); +#elif defined(ENABLE_FP8_E4M3) + dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]); +#else + assert(false); +#endif + } +} + +} // namespace vllm + +#define CALL_CONVERT_FP8(Tout, Tin) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), \ + block_stride); + +void 
convert_fp8( + torch::Tensor& src_cache, + torch::Tensor& dst_cache) +{ + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK( + src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t); + } +} diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 0ae9cd641598..85fdfc091e94 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,3 +14,13 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da141..0c123617411b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -13,7 +13,8 @@ void paged_attention_v1( torch::Tensor& context_lens, int block_size, int max_context_len, - const c10::optional& alibi_slopes); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); void paged_attention_v2( torch::Tensor& out, @@ -29,7 +30,8 @@ void paged_attention_v2( torch::Tensor& context_lens, int block_size, int max_context_len, - const c10::optional& alibi_slopes); + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype); void rms_norm( torch::Tensor& out, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index e6683c446154..9ee3cb07b689 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -74,6 +74,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "gather_cached_kv", &gather_cached_kv, "Gather key and value from the cache into contiguous QKV tensors"); + cache_ops.def( + "convert_fp8", + &convert_fp8, + "Convert the key and value cache to fp8 data type"); // Cuda utils pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); diff --git a/csrc/quantization/fp8/amd_detail/hip_float8.h b/csrc/quantization/fp8/amd_detail/hip_float8.h new file mode 100644 index 000000000000..87c7c9ce6610 --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/hip_float8.h @@ -0,0 +1,167 @@ +#pragma once + +#ifdef __HIPCC__ +#include +#else +#include +#include +#include +#include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 +{ + struct from_bits_t + { + }; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) 
+ : data(v) + { + } + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) + { + } + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) + { + } + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) + { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) + { + } + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const + { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(data); + } +}; + +namespace std +{ +inline hip_fp8 sin(hip_fp8 a) +{ + return hip_fp8(sinf(float(a))); +} +inline hip_fp8 cos(hip_fp8 a) +{ + return hip_fp8(cosf(float(a))); +} +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) +{ + return a; +} +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) +{ + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) +{ + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) +{ + return (float(a) + fb); +} 
+ +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) +{ + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) +{ + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) +{ + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) +{ + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) +{ + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) +{ + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) +{ + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) +{ + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) +{ + return static_cast(a) > static_cast(b); +} diff --git a/csrc/quantization/fp8/amd_detail/hip_float8_impl.h b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h new file mode 100644 index 000000000000..c88fbd913c2e --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ +#define HIP_FP8_HOST_DEVICE __host__ __device__ +#define HIP_FP8_HOST __host__ +#define HIP_FP8_DEVICE __device__ +#else +#define HIP_FP8_HOST_DEVICE +#define HIP_FP8_HOST +#define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl +{ + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) +{ + 
uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) +{ + return __builtin_clz(x); +} +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) +{ + return __clz(x); +} +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 
23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implict 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 
1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the + fp32/fp16 actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, + // act_exponent could be larger. 
Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit that + // is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 
0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) +{ +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 
1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/csrc/quantization/fp8/amd_detail/quant_utils.cuh b/csrc/quantization/fp8/amd_detail/quant_utils.cuh new file mode 100644 index 000000000000..7bc70d9264ab --- /dev/null +++ b/csrc/quantization/fp8/amd_detail/quant_utils.cuh @@ -0,0 +1,294 @@ +#pragma once +#include "hip_float8.h" + +#include +#include +#include + +#include "../../../attention/dtype_float32.cuh" +#include "../../../attention/dtype_bfloat16.cuh" + +namespace vllm +{ +namespace fp8_e4m3 { +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t vec_conversion(const uint16_t& a) +{ +#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; +#else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; 
+#endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) +{ + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) +{ + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) +{ + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) +{ + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) +{ + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 vec_conversion(const uint16_t& a) +{ +#if 
defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + //res.x = vec_conversion(static_cast(a)); + //res.y = vec_conversion(static_cast(a >> 8U)); + res.x = f2[0]; + res.y = f2[1]; + return res; +#else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; +#endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) +{ + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) +{ + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) +{ + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) +{ + hip_fp8 res{__bfloat162float(a)}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) +{ + hip_fp8 f8(a); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const uint32_t& a) +{ + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +template <> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + 
val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template <> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template <> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) +{ + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +template <> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_& a) +{ + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +template <> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_& a) +{ + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} +} +} // namespace vllm diff --git a/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh new file mode 100644 index 000000000000..9bcab25db03c --- /dev/null +++ b/csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh @@ -0,0 +1,277 @@ +#pragma once + +#include +#include +#include +#include +#include "../../attention/attention_dtypes.h" +#include "../../attention/dtype_float32.cuh" +#include "../../attention/dtype_float16.cuh" +#include "../../attention/dtype_bfloat16.cuh" + + +namespace vllm { +#ifdef ENABLE_FP8_E5M2 +namespace fp8_e5m2_unscaled { + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} + +// fp8 -> half +template<> +__inline__ __device__ uint16_t vec_conversion(const uint8_t& a) +{ + __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); + return res.x; +} + +// fp8x2 -> half2 +template<> +__inline__ 
__device__ uint32_t vec_conversion(const uint16_t& a) +{ + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); + tmp.u16[0] = res.x; + tmp.u16[1] = res.y; + return tmp.u32; +} + +// fp8x4 -> half2x2 +template<> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) +{ + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template<> +__inline__ __device__ uint4 vec_conversion(const uint2& a) +{ + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +// fp8 -> __nv_bfloat16 +template<> +__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) +{ + // Note there is no direct convert function from fp8 to bf16. + // fp8 -> half + __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); + // half -> float -> bf16 + float tmp = half_to_float(res.x); + return __float2bfloat16(tmp); +} + +// fp8x2 -> __nv_bfloat162 +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) +{ + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template<> +__inline__ __device__ bf16_4_t vec_conversion(const uint32_t& a) +{ + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template<> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) +{ + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + 
return res; +} + +// fp8 -> float +template<> +__inline__ __device__ float vec_conversion(const uint8_t& a) +{ + // fp8 -> half + uint16_t tmp = vec_conversion(a); + // half -> float + return half_to_float(tmp); +} + +// fp8x2 -> float2 +template<> +__inline__ __device__ float2 vec_conversion(const uint16_t& a) +{ + // fp8x2 -> half2 + uint32_t tmp = vec_conversion(a); + // half2 -> float2 + return half2_to_float2(tmp); +} + +// fp8x4 -> float4 +template<> +__inline__ __device__ Float4_ vec_conversion(const uint32_t& a) +{ + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template<> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) +{ + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + + +// half -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const uint16_t& a) +{ + __half_raw tmp; + tmp.x = a; + __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +} + +// bf16 -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const __nv_bfloat16& a) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +#endif +} + +// float -> fp8 +template<> +__inline__ __device__ uint8_t vec_conversion(const float& a) +{ + __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); + return (uint8_t)res; +} + +// fp8x4 -> float4 +template<> +__inline__ __device__ float4 vec_conversion(const uint32_t& a) +{ + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + + +template<> +__inline__ __device__ uint32_t vec_conversion(const float2& a) +{ + 
union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +template<> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) +{ + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + + return b; +} + +template<> +__inline__ __device__ float4 vec_conversion(const Float4_& a) +{ + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +template<> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) +{ + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +template<> +__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { + __nv_bfloat162 b; + from_float(b, a); + return b; +} + +template<> +__inline__ __device__ bf16_4_t vec_conversion(const Float4_ &a) { + bf16_4_t b; + from_float(b, a); + return b; +} + +template<> +__inline__ __device__ bf16_8_t vec_conversion(const Float8_ &a) { + bf16_8_t b; + from_float(b, a); + return b; +} + +} // namespace fp8_e5m2_unscaled +#endif // ENABLE_FP8_E5M2 +} // namespace vllm diff --git a/docs/source/quantization/fp8_kv_cache.rst b/docs/source/quantization/fp8_kv_cache.rst new file mode 100644 index 000000000000..8db69a9c5b76 --- /dev/null +++ b/docs/source/quantization/fp8_kv_cache.rst @@ -0,0 +1,32 @@ +.. _fp8_kv_cache: + +FP8 E5M2 KV Cache +================== + +The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits. +The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bflaot16 and fp8 to each other. + +Here is an example of how to enable this feature: + +.. code-block:: python + from vllm import LLM, SamplingParams + # Sample prompts. 
+ prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + # Create an LLM. + llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8") + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + diff --git a/setup.py b/setup.py index 15715225490a..09cadc26bb1c 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ # Supported NVIDIA GPU architectures. NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} ROCM_SUPPORTED_ARCHS = { - "gfx90a", "gfx908", "gfx906", "gfx942", "gfx1030", "gfx1100" + "gfx90a", "gfx908", "gfx906", "gfx940", "gfx941", "gfx942", "gfx1030", "gfx1100" } # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) @@ -288,6 +288,9 @@ def get_torch_arch_list() -> Set[str]: num_threads = min(os.cpu_count(), nvcc_threads) NVCC_FLAGS += ["--threads", str(num_threads)] + if nvcc_cuda_version >= Version("11.8"): + NVCC_FLAGS += ["-DENABLE_FP8_E5M2"] + # changes for punica kernels NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS REMOVE_NVCC_FLAGS = [ @@ -318,6 +321,8 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS_PUNICA, }, )) +elif _is_hip(): + NVCC_FLAGS += ["-DENABLE_FP8_E4M3"] elif _is_neuron(): neuronxcc_version = get_neuronxcc_version() diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index fca97ab76bf0..8c51bfc149ef 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,44 +1,7 @@ -from typing import List, Tuple - import pytest -import torch - - -def create_kv_caches( - 
num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - seed: int, - device: str, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - - scale = head_size**-0.5 - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) - key_caches = [] - for _ in range(num_layers): - key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - key_cache.uniform_(-scale, scale) - key_caches.append(key_cache) - - value_cache_shape = (num_blocks, num_heads, head_size, block_size) - value_caches = [] - for _ in range(num_layers): - value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - value_cache.uniform_(-scale, scale) - value_caches.append(value_cache) - return key_caches, value_caches +from vllm.utils import create_kv_caches_with_random @pytest.fixture() def kv_cache_factory(): - return create_kv_caches + return create_kv_caches_with_random diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 3949948e860f..fe50a60f71ad 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -6,14 +6,16 @@ from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask -from vllm._C import ops +from vllm._C import ops, cache_ops from vllm.utils import get_max_shared_memory_bytes FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -NUM_BLOCKS = 12000 # Arbitrary values for testing +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -23,6 +25,7 @@ HEAD_SIZES = [64, 80, 96, 112, 128, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @@ -105,6 +108,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("use_alibi", USE_ALIBI) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) def test_paged_attention( @@ -116,6 +120,7 @@ def test_paged_attention( use_alibi: bool, block_size: int, dtype: torch.dtype, + kv_cache_dtype: str, seed: int, device: int, ) -> None: @@ -158,8 +163,9 @@ def test_paged_attention( # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, dtype, - seed, gpu_id) + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -177,6 +183,7 @@ def test_paged_attention( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) elif version == "v2": num_partitions = ((max_context_len + PARTITION_SIZE - 1) // @@ -209,11 +216,30 @@ def test_paged_attention( block_size, max_context_len, alibi_slopes, + kv_cache_dtype, ) else: raise AssertionError(f"Unknown version: {version}") # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. 
+ x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=gpu_id) + cache_ops.convert_fp8(key_cache, dequantized_key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=gpu_id) + cache_ops.convert_fp8(value_cache, dequantized_value_cache) + value_cache = dequantized_value_cache + ref_output = torch.empty_like(query) ref_single_query_cached_kv_attention( ref_output, @@ -230,7 +256,12 @@ def test_paged_attention( # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. 
+ atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) def ref_multi_query_kv_attention( diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 7b1cc058f2cb..6db2c81f7aea 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -3,8 +3,11 @@ import pytest import torch +from typing import Tuple + from vllm._C import cache_ops +COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [42] # Arbitrary values for testing NUM_LAYERS = [1] # Arbitrary values for testing @@ -15,6 +18,7 @@ NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +KV_CACHE_DTYPE = ["auto", "fp8"] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -26,6 +30,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_copy_blocks( kv_cache_factory, @@ -38,6 +43,7 @@ def test_copy_blocks( dtype: torch.dtype, seed: int, device: int, + kv_cache_dtype: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -59,7 +65,8 @@ def test_copy_blocks( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, - head_size, dtype, seed, gpu_id) + head_size, kv_cache_dtype, + dtype, seed, gpu_id) # Clone the KV caches. 
cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -92,6 +99,7 @@ def test_copy_blocks( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -103,6 +111,7 @@ def test_reshape_and_cache( dtype: torch.dtype, seed: int, device: int, + kv_cache_dtype: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) @@ -123,17 +132,29 @@ def test_reshape_and_cache( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, - num_heads, head_size, dtype, - seed, gpu_id) + num_heads, head_size, kv_cache_dtype, + dtype, seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. - cloned_key_cache = key_cache.clone() - cloned_value_cache = value_cache.clone() + if kv_cache_dtype == "fp8": + cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, cloned_key_cache) + cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, cloned_value_cache) + else: + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() # Call the reshape_and_cache kernel. cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping) + slot_mapping, kv_cache_dtype) + + if kv_cache_dtype == "fp8": + result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) + cache_ops.convert_fp8(key_cache, result_key_cache) + result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) + cache_ops.convert_fp8(value_cache, result_value_cache) # Run the reference implementation. 
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) @@ -146,6 +167,116 @@ def test_reshape_and_cache( block_offset = block_offsets[i] cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i] + + if kv_cache_dtype == "fp8": + assert torch.allclose(result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1) + assert torch.allclose(result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1) + else: + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("direction", COPYING_DIRECTION) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_swap_blocks( + kv_cache_factory, + direction: Tuple[str, str], + num_mappings: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: int, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8" and "cpu" in direction: + return + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + src_device = f"{direction[0]}:{device}" if direction[ + 0] == "cuda" else direction[0] + dst_device = f"{direction[1]}:{device}" if direction[ + 1] == "cuda" else direction[1] + + src_blocks = random.sample(range(num_blocks), num_mappings) + # For the same device, mapping must not overlap + if src_device == dst_device: + remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) + dst_blocks = random.sample(remaining_blocks, num_mappings) + else: 
+ dst_blocks = random.sample(range(num_blocks), num_mappings) + + block_mapping = dict(zip(src_blocks, dst_blocks)) + + # Create the KV caches on the first device. + src_key_caches, src_value_caches = kv_cache_factory( + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, seed, + src_device) + + # Create the KV caches on the second device. + dist_key_caches, dist_value_caches = kv_cache_factory( + num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype, seed, + dst_device) + + src_key_caches_clone = src_key_caches[0].clone() + src_value_caches_clone = src_value_caches[0].clone() + + # Call the swap_blocks kernel. + cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) + cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], + block_mapping) + + for src, dst in block_mapping.items(): + assert torch.allclose(src_key_caches_clone[src].cpu(), + dist_key_caches[0][dst].cpu()) + assert torch.allclose(src_value_caches_clone[src].cpu(), + dist_value_caches[0][dst].cpu()) + + +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_fp8_conversion( + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + gpu_id = f"cuda:{device}" + + low = -224.0 + high = 224.0 + shape = (num_blocks, num_heads, head_size, block_size) + cache = torch.empty(shape, dtype=dtype, device=gpu_id) + cache.uniform_(low, high) + + cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) + cache_ops.convert_fp8(cache, cache_fp8) + + converted_cache = 
torch.empty_like(cache) + cache_ops.convert_fp8(cache_fp8, converted_cache) - assert torch.allclose(key_cache, cloned_key_cache) - assert torch.allclose(value_cache, cloned_value_cache) + assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) diff --git a/vllm/config.py b/vllm/config.py index 11952d9471d8..d02040a656fe 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,13 +1,14 @@ from typing import Optional, Union, ClassVar from dataclasses import dataclass import os +from packaging.version import Version import torch from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory, is_hip +from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version logger = init_logger(__name__) @@ -273,6 +274,7 @@ class CacheConfig: gpu_memory_utilization: Fraction of GPU memory to use for the vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. """ def __init__( @@ -280,13 +282,16 @@ def __init__( block_size: int, gpu_memory_utilization: float, swap_space: int, + cache_dtype: str, sliding_window: Optional[int] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes = swap_space * _GB + self.cache_dtype = cache_dtype self.sliding_window = sliding_window self._verify_args() + self._verify_cache_dtype() # Will be set after profiling. self.num_gpu_blocks = None @@ -298,6 +303,25 @@ def _verify_args(self) -> None: "GPU memory utilization must be less than 1.0. Got " f"{self.gpu_memory_utilization}.") + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype == "fp8": + if not is_hip(): + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is lower than 11.8." 
+ ) + logger.info( + "Using fp8 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 090fa95bcac0..c8ce60962503 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -17,6 +17,7 @@ class EngineArgs: download_dir: Optional[str] = None load_format: str = 'auto' dtype: str = 'auto' + kv_cache_dtype: str = 'auto' seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -121,6 +122,14 @@ def add_cli_args( 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8'], + default='auto', + help='Data type for kv cache storage. If "auto", will use model data ' + 'type. FP8_E5M2 is only supported on cuda version greater than 11.8. 
' + 'On AMD GPUs, only the more standard FP8_E4M3 is supported for inference.') parser.add_argument('--max-model-len', type=int, default=None, @@ -264,7 +273,7 @@ def create_engine_configs( self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, - self.swap_space, + self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window()) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7c6808e32f3f..7e3efe68d161 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -84,6 +84,7 @@ def __init__( f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " + f"kv_cache_dtype={cache_config.cache_dtype}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -143,6 +144,7 @@ def _init_workers(self): rank=0, distributed_init_method=distributed_init_method, lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) self._run_workers("init_model") @@ -233,6 +235,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) + cache_config = copy.deepcopy(self.cache_config) for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -248,6 +251,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", rank, distributed_init_method, lora_config=self.lora_config, + cache_config=cache_config, )) driver_rank = 0 @@ -260,6 +264,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_rank, distributed_init_method, lora_config=self.lora_config, + cache_config=cache_config, is_driver_worker=True, ) @@ -305,6 +310,7 @@ def 
_init_cache(self) -> None: block_size=self.cache_config.block_size, gpu_memory_utilization=self.cache_config.gpu_memory_utilization, cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, ) # Since we use a shared centralized controller, we take the minimum diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index ef49cc5902ea..f0a88ac8e27f 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -12,6 +12,7 @@ class InputMetadata: max_context_len: The maximum context length. context_lens: the length of attention context for each sequence. block_tables: The block tables. (Seq id -> list of physical block) + kv_cache_dtype: Data type to store kv cache. """ def __init__( @@ -25,6 +26,7 @@ def __init__( context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], use_cuda_graph: bool, + kv_cache_dtype: str, ) -> None: self.is_prompt = is_prompt self.prompt_lens = prompt_lens @@ -35,6 +37,7 @@ def __init__( self.context_lens = context_lens self.block_tables = block_tables self.use_cuda_graph = use_cuda_graph + self.kv_cache_dtype = kv_cache_dtype # Set during the execution of the first attention op. # FIXME(woosuk): This is a hack. 
@@ -47,4 +50,5 @@ def __repr__(self) -> str: f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " f"block_tables={self.block_tables}, " - f"use_cuda_graph={self.use_cuda_graph})") + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 8b5c6ab30d7b..91ed43f07c76 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -98,6 +98,7 @@ def forward( key_cache, value_cache, input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, ) if input_metadata.is_prompt: @@ -265,6 +266,7 @@ def _paged_attention( block_size, input_metadata.max_context_len, alibi_slopes, + input_metadata.kv_cache_dtype, ) else: # Run PagedAttention V2. @@ -295,5 +297,6 @@ def _paged_attention( block_size, input_metadata.max_context_len, alibi_slopes, + input_metadata.kv_cache_dtype, ) return output diff --git a/vllm/utils.py b/vllm/utils.py index 6a9508f6d33b..8d3923cc5b0a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,9 +1,11 @@ import enum import os import socket +import subprocess import uuid from platform import uname -from typing import List +from typing import List, Tuple, Union +from packaging.version import parse, Version import psutil import torch @@ -17,7 +19,17 @@ from collections import OrderedDict from typing import Any, Hashable, Optional +from vllm.logger import init_logger + T = TypeVar("T") +logger = init_logger(__name__) + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, +} class Device(enum.Enum): @@ -167,3 +179,99 @@ def get_open_port() -> int: def set_cuda_visible_devices(device_ids: List[int]) -> None: os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) + + +def get_nvcc_cuda_version() -> Version: + cuda_home = os.environ.get('CUDA_HOME') + if not cuda_home: + 
cuda_home = '/usr/local/cuda' + logger.info( + f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.' + ) + nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +def _generate_random_fp8( + tensor: torch.tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format repesents Inf. + # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm._C import cache_ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + cache_ops.convert_fp8(tensor_tmp, tensor) + del tensor_tmp + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: Optional[int] = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif 
isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_cache, -scale, scale) + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(value_cache, -scale, scale) + value_caches.append(value_cache) + return key_caches, value_caches diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1dd0243f8f3a..f57e1ed75803 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -6,7 +6,7 @@ from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import in_wsl +from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE logger = init_logger(__name__) @@ -34,12 +34,16 @@ def __init__( self.head_size = model_config.get_head_size() self.num_layers = model_config.get_num_layers(parallel_config) self.num_heads = model_config.get_num_kv_heads(parallel_config) - self.dtype = model_config.dtype self.block_size = cache_config.block_size self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype 
+ else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + # Initialize the cache. self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() @@ -142,6 +146,7 @@ def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: @staticmethod def get_cache_block_size( block_size: int, + cache_dtype: str, model_config: ModelConfig, parallel_config: ParallelConfig, ) -> int: @@ -152,7 +157,11 @@ def get_cache_block_size( key_cache_block = block_size * num_heads * head_size value_cache_block = key_cache_block total = num_layers * (key_cache_block + value_cache_block) - dtype_size = _get_dtype_size(model_config.dtype) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = _get_dtype_size(dtype) return dtype_size * total diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 985115613e04..3e6d8dc2eccb 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,6 +35,7 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ): self.model_config = model_config @@ -67,6 +68,7 @@ def __init__( self.graph_block_tables = None # Set after initial profiling. 
# cache in_wsl result self.in_wsl = in_wsl() + self.kv_cache_dtype = kv_cache_dtype def load_model(self) -> None: self.model = get_model(self.model_config, self.lora_config) @@ -222,6 +224,7 @@ def _prepare_prompt( context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, @@ -349,6 +352,7 @@ def _prepare_decode( context_lens=context_lens, block_tables=block_tables, use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, ) return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests @@ -472,6 +476,7 @@ def prepare_input_tensors( "context_lens": input_metadata.context_lens, "block_tables": input_metadata.block_tables, "use_cuda_graph": input_metadata.use_cuda_graph, + "kv_cache_dtype": input_metadata.kv_cache_dtype, "selected_token_indices": sampling_metadata.selected_token_indices, "lora_requests": lora_requests, @@ -494,6 +499,7 @@ def prepare_input_tensors( context_lens=metadata_dict["context_lens"], block_tables=metadata_dict["block_tables"], use_cuda_graph=metadata_dict["use_cuda_graph"], + kv_cache_dtype=metadata_dict["kv_cache_dtype"], ) sampling_metadata = SamplingMetadata( seq_groups=None, @@ -663,6 +669,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens=context_lens[:batch_size], block_tables=block_tables[:batch_size], use_cuda_graph=True, + kv_cache_dtype=self.kv_cache_dtype, ) if self.lora_config: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index aafd7306acf5..e509802f07fc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -36,6 +36,7 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ 
-53,6 +54,7 @@ def __init__( parallel_config, scheduler_config, lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). @@ -96,6 +98,7 @@ def profile_num_available_blocks( block_size: int, gpu_memory_utilization: float, cpu_swap_space: int, + cache_dtype: str, ) -> Tuple[int, int]: """Profiles the peak memory usage of the model and returns the maximum number of GPU and CPU cache blocks that can be allocated. @@ -120,7 +123,7 @@ def profile_num_available_blocks( peak_memory = total_gpu_memory - free_gpu_memory cache_block_size = CacheEngine.get_cache_block_size( - block_size, self.model_config, self.parallel_config) + block_size, cache_dtype, self.model_config, self.parallel_config) num_gpu_blocks = int( (total_gpu_memory * gpu_memory_utilization - peak_memory) // cache_block_size)