diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8f9c4..617cf6c0e4a5 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd7..debde463786e 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -39,7 +39,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Warp leaders store the data to shared memory. @@ -58,11 +62,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Broadcast to other threads. +#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); +#else + return __shfl(sum, 0); +#endif } // Grid: (num_heads, num_seqs). @@ -196,7 +208,11 @@ __global__ void single_query_cached_kv_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -208,10 +224,18 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } // Broadcast the max qk value to all threads. +#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); +#else + qk_max = __shfl(qk_max, 0); +#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -284,7 +308,11 @@ __global__ void single_query_cached_kv_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +#else + acc += __shfl_xor(acc, mask); +#endif } accs[i] = acc; } @@ -342,7 +370,7 @@ __global__ void single_query_cached_kv_attention_kernel( #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ + (void*)vllm::single_query_cached_kv_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::single_query_cached_kv_attention_kernel \ <<>>( \ diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f0..7e6b64eea96f 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -39,7 +39,11 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk += __shfl_xor_sync(uint32_t(-1), qk, mask); +#else + qk += __shfl_xor(qk, mask); +#endif } return qk; } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 2154bfcf8631..9ad2e299c7aa 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + // See https://github.com/RadeonOpenCompute/ROCm/issues/2534 + hip_bfloat16 A, B; + __hip_bfloat16 c; + A.data = a.data; + B.data = b.data; + c.data = (A + B).data; + return c; + #endif #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52..dc45dbf3daea 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,58 +67,107 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); +#else + uint32_t b = a; + b <<= 16; + b |= a; +#endif return b; } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + return __half2float(__ushort_as_half(h)); +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + __half2 h2; + uint32_t u32; + } V; + V.u32 = v; + return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y)); +#endif } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + return __half_as_ushort(__float2half(f)); +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif + return tmp.u32; #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + union { + __half2 h2; + uint32_t u32; + } R; + + R.h2.x = __half_as_ushort(__float2half_rn(f.x)); + R.h2.y = __half_as_ushort(__float2half_rn(f.y)); + return R.u32; #endif - return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b))); +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + union { + __half2 h2; + uint32_t u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hadd2(A.h2, B.h2); + return C.u32; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +210,31 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b))); +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + union { + __half2 h2; + uint32_t u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hmul2(A.h2, B.h2); + return C.u32; +#endif } template<> @@ -271,9 +339,21 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + union { + __half2 h2; + uint32_t u32; + } A, B, C, D; + A.u32 = a; + B.u32 = b; + C.u32 = c; + D.h2 = __hfma2(A.h2, B.h2, C.h2); + return D.u32; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a29b9..1a9376b3103e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -176,8 +177,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); - value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + key_cache[tgt_key_idx] = VLLM_LDG(&key[src_key_idx]); + value_cache[tgt_value_idx] = VLLM_LDG(&value[src_value_idx]); } } @@ -262,8 +263,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -328,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 000000000000..3348b78cfa19 --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,7 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99..2439f5922a3f 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,7 @@ +#ifdef USE_ROCM + #include +#endif + int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0d794..1e977fa92837 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b5..382ad162dfef 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -23,7 +23,11 @@ template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) +#ifndef USE_ROCM val += __shfl_xor_sync(0xffffffff, val, mask, 32); +#else + val += __shfl_xor(val, mask, 32); +#endif return val; } diff --git a/setup.py b/setup.py index 8b2ad97dd540..2d9b2afc067f 100644 --- a/setup.py +++ b/setup.py @@ -24,10 +24,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -64,66 +60,6 @@ def get_torch_arch_list() -> Set[str]: return set(arch_list) -# First, check the TORCH_CUDA_ARCH_LIST environment variable. -compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = set(SUPPORTED_ARCHS) - if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") - if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("11.1"): - if any(cc.startswith("8.6") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.") - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") - -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] - ext_modules = [] # Cache operations. @@ -181,6 +117,7 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(activation_extension) + # Quantization kernels. quantization_extension = CUDAExtension( name="vllm.quantization_ops", @@ -193,7 +130,8 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -ext_modules.append(quantization_extension) +if not torch.version.hip: + ext_modules.append(quantization_extension) # Misc. CUDA utils. cuda_utils_extension = CUDAExtension( diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 0b3ad0aa255a..44aa94493085 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -72,4 +72,4 @@ def test_gelu_fast( out = torch.empty(num_tokens, d, dtype=dtype, device="cuda") activation_ops.gelu_fast(out, x) ref_out = get_activation("gelu_fast")(x) - assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) + assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 59d8b0a59ce6..672b0e888358 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -1,10 +1,9 @@ import random from typing import List, Optional, Tuple +from flash_attn.flash_attn_interface import _flash_attn_forward import pytest import torch -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from vllm import attention_ops from vllm.utils import get_max_shared_memory_bytes @@ -12,7 +11,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = 8192 NUM_BLOCKS = 128 # Arbitrary values for testing DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -193,7 +192,7 @@ def test_single_query_cached_kv_attention( # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two # outputs. Thus, we use a relaxed tolerance for the test. - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-3) def ref_multi_query_kv_attention( @@ -270,20 +269,28 @@ def test_multi_query_kv_attention( # Handle MQA and GQA key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) cu_seq_lens = [0] for seq_len in seq_lens: cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + max_prompt_len = max(seq_lens) + + output = torch.empty_like(query) + _flash_attn_forward( + query, + key, + value, + output, + cu_seq_lens, + cu_seq_lens, + max_prompt_len, + max_prompt_len, + dropout_p=0.0, + softmax_scale=scale, + causal=True, + return_softmax=False, + ) + ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index d66041744084..0f66a9e35884 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -170,5 +170,5 @@ def test_rotary_embedding( ref_key = ref_key.view(num_tokens, num_heads * head_size) # Compare the results. - assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) - assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-3) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-3) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index a0a62034aa24..773d64c3a4bd 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple import torch -from xformers.ops import AttentionBias from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceData @@ -25,6 +24,7 @@ def __init__( seq_groups: List[Tuple[List[int], SamplingParams]], seq_data: Dict[int, SequenceData], prompt_lens: List[int], + cumulative_prompt_lens: torch.Tensor, slot_mapping: torch.Tensor, context_lens: torch.Tensor, max_context_len: int, @@ -34,6 +34,7 @@ def __init__( self.seq_groups = seq_groups self.seq_data = seq_data self.prompt_lens = prompt_lens + self.cumulative_prompt_lens = cumulative_prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens self.max_context_len = max_context_len @@ -59,6 +60,7 @@ def __init__( self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) + self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: @@ -69,7 +71,7 @@ def __init__( assert context_lens.shape[0] == self.num_generation_tokens # Set during the execution of the first attention op. - self.attn_bias: List[AttentionBias] = [] + self.attn_bias = [] def __repr__(self) -> str: # Print only useful metadata. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b1d0588d97f7..1d6f43d2ce56 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -3,9 +3,7 @@ import torch import torch.nn as nn -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) +from flash_attn.flash_attn_interface import _flash_attn_forward from vllm import attention_ops from vllm import cache_ops @@ -116,19 +114,30 @@ def multi_query_kv_attention( value = torch.repeat_interleave(value, self.num_queries_per_kv, dim=1) - - # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. - out = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=input_metadata.attn_bias[0], - p=0.0, - scale=self.scale, + + if query.dtype == torch.float: + raise ValueError('The float data type is not supported by ' + 'FlashAttention. Use the half data type instead.') + head_size = query.shape[-1] + if head_size > 128: + raise ValueError('FlashAttention does not support head_size > 128.') + + # Directly call FlashAttention's internal function to avoid allocating + # a new tensor for the output. + _flash_attn_forward( + query, + key, + value, + output, + input_metadata.cumulative_prompt_lens, + input_metadata.cumulative_prompt_lens, + input_metadata.max_prompt_len, + input_metadata.max_prompt_len, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + return_softmax=False, ) - # TODO(woosuk): Unnecessary copy. Optimize. - output.copy_(out.squeeze(0)) - return output def single_query_cached_kv_attention( self, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 951ba1f0ceba..84db237aca6c 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -23,9 +23,9 @@ "GPTJForCausalLM": GPTJForCausalLM, "GPTNeoXForCausalLM": GPTNeoXForCausalLM, "InternLMForCausalLM": InternLMForCausalLM, - "LlamaForCausalLM": LlamaForCausalLM, - "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* - "MistralForCausalLM": MistralForCausalLM, + # "LlamaForCausalLM": LlamaForCausalLM, + # "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* + # "MistralForCausalLM": MistralForCausalLM, "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, "QWenLMHeadModel": QWenLMHeadModel, @@ -34,7 +34,7 @@ # FIXME(woosuk): Remove this once all models support quantization. _MODEL_CLASSES_SUPPORT_QUANTIZATION = [ - LlamaForCausalLM, + # LlamaForCausalLM, ] diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 01d85355b297..21678bf0b048 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -8,11 +8,11 @@ from vllm.model_executor.models.gpt_j import GPTJForCausalLM from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM from vllm.model_executor.models.internlm import InternLMForCausalLM -from vllm.model_executor.models.llama import LlamaForCausalLM +# from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.model_executor.models.mistral import MistralForCausalLM +# from vllm.model_executor.models.mistral import MistralForCausalLM __all__ = [ "AquilaForCausalLM", @@ -25,9 +25,9 @@ "GPTJForCausalLM", "GPTNeoXForCausalLM", "InternLMForCausalLM", - "LlamaForCausalLM", + # "LlamaForCausalLM", "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel", - "MistralForCausalLM", + # "MistralForCausalLM", ] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6fbc155d68d6..4bf709bd6f64 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -197,6 +197,11 @@ def _prepare_inputs( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + cumulative_prompt_lens: List[int] = [0] + for prompt_len in prompt_lens: + cumulative_prompt_lens.append( + cumulative_prompt_lens[-1] + prompt_len) + # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 @@ -264,6 +269,8 @@ def _prepare_inputs( block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int, device="cuda") + cumulative_prompt_lens_tensor = torch.tensor( + cumulative_prompt_lens, dtype=torch.int, device='cuda') seq_data: Dict[int, SequenceData] = {} for seq_group_metadata in seq_group_metadata_list: @@ -273,6 +280,7 @@ def _prepare_inputs( seq_groups=seq_groups, seq_data=seq_data, prompt_lens=prompt_lens, + cumulative_prompt_lens=cumulative_prompt_lens_tensor, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, max_context_len=max_context_len,