Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions csrc/activation_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {
Expand All @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}
Expand Down Expand Up @@ -57,7 +58,7 @@ __global__ void activation_kernel(
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
Expand Down
30 changes: 29 additions & 1 deletion csrc/attention/attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
#else
sum += __shfl_xor(sum, mask);
#endif
}

// Warp leaders store the data to shared memory.
Expand All @@ -58,11 +62,19 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
#else
sum += __shfl_xor(sum, mask);
#endif
}

// Broadcast to other threads.
#ifndef USE_ROCM
return __shfl_sync(uint32_t(-1), sum, 0);
#else
return __shfl(sum, 0);
#endif
}

// Grid: (num_heads, num_seqs).
Expand Down Expand Up @@ -196,7 +208,11 @@ __global__ void single_query_cached_kv_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
#ifndef USE_ROCM
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
#else
qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask));
#endif
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
Expand All @@ -208,10 +224,18 @@ __global__ void single_query_cached_kv_attention_kernel(
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
#else
qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask));
#endif
}
// Broadcast the max qk value to all threads.
#ifndef USE_ROCM
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
#else
qk_max = __shfl(qk_max, 0);
#endif

// Get the sum of the exp values.
float exp_sum = 0.f;
Expand Down Expand Up @@ -284,7 +308,11 @@ __global__ void single_query_cached_kv_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
#else
acc += __shfl_xor(acc, mask);
#endif
}
accs[i] = acc;
}
Expand Down Expand Up @@ -342,7 +370,7 @@ __global__ void single_query_cached_kv_attention_kernel(

#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cudaFuncSetAttribute( \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
(void*)vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
Expand Down
4 changes: 4 additions & 0 deletions csrc/attention/attention_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
#ifndef USE_ROCM
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
#else
qk += __shfl_xor(qk, mask);
#endif
}
return qk;
}
Expand Down
25 changes: 22 additions & 3 deletions csrc/attention/dtype_bfloat16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {
Expand Down Expand Up @@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return a + b;
#ifndef USE_ROCM
return a + b;
#else
// See https://github.com/RadeonOpenCompute/ROCm/issues/2534
hip_bfloat16 A, B;
__hip_bfloat16 c;
A.data = a.data;
B.data = b.data;
c.data = (A + B).data;
return c;
#endif
#endif
}

Expand Down
90 changes: 85 additions & 5 deletions csrc/attention/dtype_float16.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {
Expand Down Expand Up @@ -63,58 +67,107 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
#else
uint32_t b = a;
b <<= 16;
b |= a;
#endif
return b;
}

inline __device__ float half_to_float(uint16_t h) {
#ifndef USE_ROCM
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
#else
return __half2float(__ushort_as_half(h));
#endif
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
#else
union {
__half2 h2;
uint32_t u32;
} V;
V.u32 = v;
return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y));
#endif
}

inline __device__ uint16_t float_to_half(float f) {
#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
return tmp.u16[0];
#else
return __half_as_ushort(__float2half(f));
#endif
}

inline __device__ uint32_t float2_to_half2(float2 f) {
#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
union {
__half2 h2;
uint32_t u32;
} R;

R.h2.x = __half_as_ushort(__float2half_rn(f.x));
R.h2.y = __half_as_ushort(__float2half_rn(f.y));
return R.u32;
#endif
return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
#ifndef USE_ROCM
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
#else
return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b)));
#endif
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
#ifndef USE_ROCM
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
#else
union {
__half2 h2;
uint32_t u32;
} A, B, C;
A.u32 = a;
B.u32 = b;
C.h2 = __hadd2(A.h2, B.h2);
return C.u32;
#endif
}

inline __device__ uint2 add(uint2 a, uint2 b) {
Expand Down Expand Up @@ -157,16 +210,31 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
#ifndef USE_ROCM
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
#else
return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b)));
#endif
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
#ifndef USE_ROCM
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
#else
union {
__half2 h2;
uint32_t u32;
} A, B, C;
A.u32 = a;
B.u32 = b;
C.h2 = __hmul2(A.h2, B.h2);
return C.u32;
#endif
}

template<>
Expand Down Expand Up @@ -271,9 +339,21 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
#ifndef USE_ROCM
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
#else
union {
__half2 h2;
uint32_t u32;
} A, B, C, D;
A.u32 = a;
B.u32 = b;
C.u32 = c;
D.h2 = __hfma2(A.h2, B.h2, C.h2);
return D.u32;
#endif
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
Expand Down
17 changes: 9 additions & 8 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>
Expand Down Expand Up @@ -28,8 +29,8 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination");
}

void *src_ptr = src.data_ptr();
void *dst_ptr = dst.data_ptr();
char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());

const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
Expand Down Expand Up @@ -176,8 +177,8 @@ __global__ void reshape_and_cache_kernel(
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]);
value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]);
key_cache[tgt_key_idx] = VLLM_LDG(&key[src_key_idx]);
value_cache[tgt_value_idx] = VLLM_LDG(&value[src_value_idx]);
}
}

Expand Down Expand Up @@ -262,8 +263,8 @@ __global__ void gather_cached_kv_kernel(
+ head_offset * block_size
+ block_offset;

key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
}
}

Expand Down Expand Up @@ -328,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized(
src_key_indices[j] = src_key_idx;
src_value_indices[j] = src_value_idx;

keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
}

#pragma unroll
Expand Down
Loading