From 70255791eed9fa87c3227710f33556f67b10e443 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Thu, 7 Dec 2023 16:11:08 +0800 Subject: [PATCH 01/10] fix kernel bug --- csrc/activation_kernels.cu | 105 ++++++++---------- csrc/attention/attention_kernels.cu | 4 +- csrc/cache_kernels.cu | 5 + csrc/layernorm_kernels.cu | 4 +- csrc/pos_encoding_kernels.cu | 3 +- .../squeezellm/quant_cuda_kernel.cu | 3 +- tests/kernels/conftest.py | 15 +-- tests/kernels/test_activation.py | 30 ++--- tests/kernels/test_attention.py | 50 ++++----- tests/kernels/test_cache.py | 43 +++---- tests/kernels/test_layernorm.py | 9 +- tests/kernels/test_pos_encoding.py | 11 +- 12 files changed, 132 insertions(+), 150 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 89d1ba2d37dd..61769b5b95fb 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,21 +1,20 @@ -#include #include - +#include +#include #include "dispatch_utils.h" namespace vllm { -template -__device__ __forceinline__ T silu(const T& x) { +template __device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) - return (T) (((float) x) / (1.0f + expf((float) -x))); + return (T)(((float)x) / (1.0f + expf((float)-x))); } -template -__global__ void silu_and_mul_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., 2, d] - const int d) { +template +__global__ void +silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); @@ -26,35 +25,30 @@ __global__ void silu_and_mul_kernel( } // namespace vllm -void silu_and_mul( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., 2 * d] +void silu_and_mul(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., 2 * d] { int64_t num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), - "silu_and_mul_kernel", - [&] { - vllm::silu_and_mul_kernel<<>>( - out.data_ptr(), - input.data_ptr(), - d); - }); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] { + vllm::silu_and_mul_kernel<<>>( + out.data_ptr(), input.data_ptr(), d); + }); } namespace vllm { // Element-wise activation kernel template. -template -__global__ void activation_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const int d) { +template +__global__ void +activation_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * d + idx]); @@ -65,50 +59,45 @@ __global__ void activation_kernel( } // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), \ - "activation_kernel", \ - [&] { \ - vllm::activation_kernel><<>>( \ - out.data_ptr(), \ - input.data_ptr(), \ - d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); namespace vllm { -template -__device__ __forceinline__ T gelu_new_kernel(const T& x) { - const float x3 = (float) (x * x * x); - const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); - return ((T) 0.5) * x * (((T) 1.0) + t); +template __device__ __forceinline__ T gelu_new_kernel(const T &x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); } -template -__device__ __forceinline__ T gelu_fast_kernel(const T& x) { - const float f = (float) x; - const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); - return ((T) 0.5) * x * (((T) 1.0) + t); +template +__device__ __forceinline__ T gelu_fast_kernel(const T &x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); } } // namespace vllm -void gelu_new( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_new(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast( - torch::Tensor& out, // [..., d] - torch::Tensor& input) // [..., d] +void gelu_fast(torch::Tensor &out, // [..., d] + torch::Tensor &input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 78e8d8ecd6d4..ebb3effc499c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ */ #include #include - +#include #include "attention_dtypes.h" #include "attention_utils.cuh" @@ -608,6 +608,7 @@ void paged_attention_v1_launcher( dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { // NOTE(woosuk): To reduce the compilation time, we only compile for the @@ -777,6 +778,7 @@ void paged_attention_v2_launcher( int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { // NOTE(woosuk): To reduce the compilation time, we only compile for the diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3ad52b1681c0..b6356564a35e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,5 +1,6 @@ #include #include +#include #include "dispatch_utils.h" @@ -32,6 +33,7 @@ void swap_blocks( void *dst_ptr = dst.data_ptr(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const at::cuda::OptionalCUDAGuard device_guard(src_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // NOTE(woosuk): This can be slow if the number of blocks is large. for (const auto& pair : block_mapping) { @@ -126,6 +128,7 @@ void copy_blocks( const int numel_per_block = key_caches[0][0].numel(); dim3 grid(num_layers, num_pairs); dim3 block(std::min(1024, numel_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { @@ -206,6 +209,7 @@ void reshape_and_cache( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( key.scalar_type(), @@ -366,6 +370,7 @@ void gather_cached_kv( dim3 grid(num_tokens); dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( key.scalar_type(), diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 7434f4fd7998..b620e6ca1022 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,6 +1,6 @@ #include #include - +#include #include "dispatch_utils.h" #include "reduction_utils.cuh" @@ -76,6 +76,7 @@ void rms_norm( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), @@ -101,6 +102,7 @@ void fused_add_rms_norm( dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 0a5ec95f8c0d..2e0a724c14b3 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,6 @@ #include #include - +#include #include "dispatch_utils.h" namespace vllm { @@ -93,6 +93,7 @@ void rotary_embedding( dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( query.scalar_type(), diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1392b877397b..30e3ba866e72 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -7,6 +7,7 @@ // half-tensor #include #include +#include #define BLOCKWIDTH 128 #define BLOCKHEIGHT4 16 @@ -134,7 +135,7 @@ void squeezellm_gemm( (width + BLOCKWIDTH - 1) / BLOCKWIDTH ); dim3 threads(BLOCKWIDTH); - + const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); vllm::squeezellm::NUQ4MatMulKernel<<>>( (half2*) vec.data(), mat.data_ptr(), diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 97516bd3052c..6bd72502cf7a 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -5,14 +5,9 @@ def create_kv_caches( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - dtype: torch.dtype, - seed: int, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + num_blocks: int, block_size: int, num_layers: int, num_heads: int, + head_size: int, dtype: torch.dtype, seed: int, + device: str) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -23,7 +18,7 @@ def create_kv_caches( for _ in range(num_layers): key_cache = torch.empty(size=key_cache_shape, dtype=dtype, - device='cuda') + device=device) key_cache.uniform_(-scale, scale) key_caches.append(key_cache) @@ -32,7 +27,7 @@ def create_kv_caches( for _ in range(num_layers): value_cache = torch.empty(size=value_cache_shape, dtype=dtype, - device='cuda') + device=device) value_cache.uniform_(-scale, scale) value_caches.append(value_cache) return key_caches, value_caches diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index ba062054bf40..8ee43ebd8be7 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,22 +7,26 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(torch.cuda.device_count())] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_silu_and_mul( num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") + gpu_id = f'cuda:{device}' + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) layer = SiluAndMul() out = layer(x) ref_out = layer._forward(x) @@ -33,16 +37,14 @@ def test_silu_and_mul( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_gelu_new( - num_tokens: int, - d: int, - dtype: torch.dtype, - seed: int, -) -> None: +def test_gelu_new(num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + gpu_id = f'cuda:{device}' + x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = NewGELU() out = layer(x) ref_out = layer._forward(x) @@ -53,15 +55,13 @@ def test_gelu_new( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -def test_gelu_fast( - num_tokens: int, - d: int, - dtype: torch.dtype, - seed: int, -) -> None: +@pytest.mark.parametrize("device", DEVICES) +def test_gelu_fast(num_tokens: int, d: int, dtype: torch.dtype, seed: int, + device: int) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") + gpu_id = f'cuda:{device}' + x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = FastGELU() out = layer(x) ref_out = layer._forward(x) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index a65d4d54d7c8..478cea383a09 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -24,6 +24,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] +DEVICES = [0] def ref_masked_attention( @@ -87,7 +88,7 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len, device="cuda").int() + position_ids = torch.arange(context_len, device=query.device).int() alibi_bias = (position_ids - context_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -105,45 +106,39 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -def test_paged_attention( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - seed: int, -) -> None: +@pytest.mark.parametrize("device", DEVICES) +def test_paged_attention(kv_cache_factory, version: str, num_seqs: int, + num_heads: Tuple[int, int], head_size: int, + use_alibi: bool, block_size: int, dtype: torch.dtype, + seed: int, device: int) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f'cuda:{device}' scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), + torch.arange(num_kv_heads, dtype=torch.int32, device=gpu_id), num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, - device="cuda") + device=gpu_id) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens[-1] = MAX_SEQ_LEN max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) # Create the block tables. max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size @@ -154,12 +149,12 @@ def test_paged_attention( for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, dtype, - seed) + seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -252,7 +247,7 @@ def ref_multi_query_kv_attention( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device="cuda") + attn_mask = attn_mask.to(dtype=dtype, device=query.device) ref_output = ref_masked_attention( query[start_idx:end_idx], @@ -272,18 +267,15 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_multi_query_kv_attention( - num_seqs: int, - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - seed: int, -) -> None: +def test_multi_query_kv_attention(num_seqs: int, num_heads: Tuple[int, int], + head_size: int, dtype: torch.dtype, + seed: int, device: int) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f'cuda:{device}' # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use # a smaller MAX_SEQ_LEN here. @@ -297,7 +289,7 @@ def test_multi_query_kv_attention( num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) qkv.uniform_(-scale, scale) query, key, value = qkv.split( [num_query_heads, num_kv_heads, num_kv_heads], dim=1) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 9b5d7687a3fe..5943114b8d42 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -14,6 +14,7 @@ NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(torch.cuda.device_count())] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -24,22 +25,16 @@ @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_copy_blocks( - kv_cache_factory, - num_mappings: int, - num_layers: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: +def test_copy_blocks(kv_cache_factory, num_mappings: int, num_layers: int, + num_heads: int, head_size: int, block_size: int, + num_blocks: int, dtype: torch.dtype, seed: int, + device: int) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f'cuda:{device}' # Generate random block mappings where each source block is mapped to two # destination blocks. assert 2 * num_mappings <= num_blocks @@ -56,7 +51,7 @@ def test_copy_blocks( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, - head_size, dtype, seed) + head_size, dtype, seed, gpu_id) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -88,38 +83,32 @@ def test_copy_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_reshape_and_cache( - kv_cache_factory, - num_tokens: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, - seed: int, -) -> None: +def test_reshape_and_cache(kv_cache_factory, num_tokens: int, num_heads: int, + head_size: int, block_size: int, num_blocks: int, + dtype: torch.dtype, seed: int, device: int) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f'cuda:{device}' # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda") + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id) qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, - device="cuda") + device=gpu_id) _, key, value = qkv.unbind(dim=1) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, num_heads, head_size, dtype, - seed) + seed, gpu_id) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index b362e2c43f0d..253fcb3f39ab 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -8,6 +8,7 @@ HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] +DEVICES = [i for i in range(torch.cuda.device_count())] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -15,6 +16,7 @@ @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_rms_norm( num_tokens: int, @@ -22,14 +24,15 @@ def test_rms_norm( add_residual: bool, dtype: torch.dtype, seed: int, + device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - - layer = RMSNorm(hidden_size).to(dtype).cuda() + gpu_id = f'cuda:{device}' + layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id) x *= scale residual = torch.randn_like(x) * scale if add_residual else None diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 25d6bf2378ca..4658f06aa774 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -13,6 +13,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] +DEVICES = [i for i in range(torch.cuda.device_count())] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -23,6 +24,7 @@ @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, @@ -33,6 +35,7 @@ def test_rotary_embedding( rotary_dim: Optional[int], dtype: torch.dtype, seed: int, + device: int, max_position: int = 8192, base: int = 10000, ) -> None: @@ -40,20 +43,20 @@ def test_rotary_embedding( rotary_dim = head_size torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - + gpu_id = f'cuda:{device}' if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) - rope = rope.to(dtype).cuda() + rope = rope.to(dtype=dtype, device=gpu_id) positions = torch.randint(0, max_position, (batch_size, seq_len), - device="cuda") + device=gpu_id) query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype, - device="cuda") + device=gpu_id) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first From 207ee6008273aeead740813eb3fc799041fd18d9 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Thu, 7 Dec 2023 16:12:23 +0800 Subject: [PATCH 02/10] fix kernel bug --- tests/kernels/test_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 478cea383a09..f9f26be33b9c 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -24,7 +24,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] -DEVICES = [0] +DEVICES = [i for i in range(torch.cuda.device_count())] def ref_masked_attention( From efbf0e5b1876f8b8056dcdce860d49f2116ad899 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Thu, 28 Dec 2023 10:07:25 +0800 Subject: [PATCH 03/10] modify code format --- csrc/activation_kernels.cu | 1 + csrc/attention/attention_kernels.cu | 1 + csrc/layernorm_kernels.cu | 1 + csrc/pos_encoding_kernels.cu | 1 + 4 files changed, 4 insertions(+) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 61769b5b95fb..0156c90e4769 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include #include + #include "dispatch_utils.h" namespace vllm { diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index ebb3effc499c..b69160923663 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -18,6 +18,7 @@ #include #include #include + #include "attention_dtypes.h" #include "attention_utils.cuh" diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index b620e6ca1022..6d34d014c858 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,6 +1,7 @@ #include #include #include + #include "dispatch_utils.h" #include "reduction_utils.cuh" diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 2e0a724c14b3..5df31aafff2f 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include #include + #include "dispatch_utils.h" namespace vllm { From a1d330034c21733b4bdfd89a093e056709977a4b Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Thu, 28 Dec 2023 10:45:10 +0800 Subject: [PATCH 04/10] shorten the test time --- tests/kernels/test_activation.py | 2 +- tests/kernels/test_attention.py | 2 +- tests/kernels/test_cache.py | 2 +- tests/kernels/test_layernorm.py | 2 +- tests/kernels/test_pos_encoding.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 8ee43ebd8be7..c569a8535d90 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,7 +7,7 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(torch.cuda.device_count())] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index f9f26be33b9c..36fea9fc8b08 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -24,7 +24,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] -DEVICES = [i for i in range(torch.cuda.device_count())] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 5943114b8d42..46423345c82f 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -14,7 +14,7 @@ NUM_BLOCKS = [1024, 36000] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(torch.cuda.device_count())] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 253fcb3f39ab..6f3112eb03a5 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -8,7 +8,7 @@ HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] -DEVICES = [i for i in range(torch.cuda.device_count())] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 4658f06aa774..fe4693cc1a5c 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -13,7 +13,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(torch.cuda.device_count())] +DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) From 4d0cfe904aee3a86ad98bb602b7ab627e740d5c4 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 29 Dec 2023 00:35:00 +0800 Subject: [PATCH 05/10] revert the code format --- csrc/activation_kernels.cu | 103 +++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 45 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 0156c90e4769..1a2adc5587e3 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -6,16 +6,17 @@ namespace vllm { -template __device__ __forceinline__ T silu(const T &x) { +template +__device__ __forceinline__ T silu(const T &x) { // x * sigmoid(x) return (T)(((float)x) / (1.0f + expf((float)-x))); } template -__global__ void -silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d] - const scalar_t *__restrict__ input, // [..., 2, d] - const int d) { +__global__ void silu_and_mul_kernel( + scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., 2, d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); @@ -24,35 +25,41 @@ silu_and_mul_kernel(scalar_t *__restrict__ out, // [..., d] } } -} // namespace vllm +} // namespace vllm -void silu_and_mul(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., 2 * d] +void silu_and_mul( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] { int64_t num_tokens = input.numel() / input.size(-1); int d = input.size(-1) / 2; dim3 grid(num_tokens); dim3 block(std::min(d, 1024)); - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] { - vllm::silu_and_mul_kernel<<>>( - out.data_ptr(), input.data_ptr(), d); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "silu_and_mul_kernel", + [&] { + vllm::silu_and_mul_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + d); + }); } namespace vllm { // Element-wise activation kernel template. -template -__global__ void -activation_kernel(scalar_t *__restrict__ out, // [..., d] - const scalar_t *__restrict__ input, // [..., d] - const int d) { +template +__global__ void activation_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } @@ -60,45 +67,51 @@ activation_kernel(scalar_t *__restrict__ out, // [..., d] } // namespace vllm // Launch element-wise activation kernel. -#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ - int d = input.size(-1); \ - int64_t num_tokens = input.numel() / d; \ - dim3 grid(num_tokens); \ - dim3 block(std::min(d, 1024)); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ - vllm::activation_kernel> \ - <<>>(out.data_ptr(), \ - input.data_ptr(), d); \ - }); +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), \ + "activation_kernel", \ + [&] { \ + vllm::activation_kernel><<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + d); \ + }); namespace vllm { -template __device__ __forceinline__ T gelu_new_kernel(const T &x) { - const float x3 = (float)(x * x * x); - const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); - return ((T)0.5) * x * (((T)1.0) + t); +template +__device__ __forceinline__ T gelu_new_kernel(const T& x) { + const float x3 = (float) (x * x * x); + const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); + return ((T) 0.5) * x * (((T) 1.0) + t); } -template -__device__ __forceinline__ T gelu_fast_kernel(const T &x) { - const float f = (float)x; - const T t = - (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); - return ((T)0.5) * x * (((T)1.0) + t); +template +__device__ __forceinline__ T gelu_fast_kernel(const T& x) { + const float f = (float) x; + const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); + return ((T) 0.5) * x * (((T) 1.0) + t); } } // namespace vllm -void gelu_new(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., d] +void gelu_new( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); } -void gelu_fast(torch::Tensor &out, // [..., d] - torch::Tensor &input) // [..., d] +void gelu_fast( + torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] { LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); } From 358476e4e6e10444f35f08e86d8bad2e815fe81f Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 29 Dec 2023 00:40:32 +0800 Subject: [PATCH 06/10] revert the code format --- csrc/activation_kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 1a2adc5587e3..ee342551d094 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -6,16 +6,16 @@ namespace vllm { -template -__device__ __forceinline__ T silu(const T &x) { +template +__device__ __forceinline__ T silu(const T& x) { // x * sigmoid(x) - return (T)(((float)x) / (1.0f + expf((float)-x))); + return (T) (((float) x) / (1.0f + expf((float) -x))); } -template +template __global__ void silu_and_mul_kernel( - scalar_t *__restrict__ out, // [..., d] - const scalar_t *__restrict__ input, // [..., 2, d] + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { From 67396c7a3924f6413c20f28e50745fad41bfe6c0 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 29 Dec 2023 00:44:14 +0800 Subject: [PATCH 07/10] revert the code format --- csrc/activation_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index ee342551d094..cdadff9f187c 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -25,7 +25,7 @@ __global__ void silu_and_mul_kernel( } } -} // namespace vllm +} // namespace vllm void silu_and_mul( torch::Tensor& out, // [..., d] From 0165e9ae28b12a78f8589f84d8ab685c07c322d0 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 29 Dec 2023 00:55:55 +0800 Subject: [PATCH 08/10] fix ldg bug --- csrc/activation_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index cdadff9f187c..c70aa668f01a 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -59,7 +59,7 @@ __global__ void activation_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); + const scalar_t x = __ldg(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } From 2e0928e290d7132dee2ececfd99ec4c0e557a8d2 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 29 Dec 2023 09:35:01 +0800 Subject: [PATCH 09/10] add comma --- tests/kernels/conftest.py | 12 +++++++++--- tests/kernels/test_activation.py | 18 ++++++++++++++---- tests/kernels/test_attention.py | 27 ++++++++++++++++++++------- tests/kernels/test_cache.py | 30 +++++++++++++++++++++++------- 4 files changed, 66 insertions(+), 21 deletions(-) diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index 6bd72502cf7a..fca97ab76bf0 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -5,9 +5,15 @@ def create_kv_caches( - num_blocks: int, block_size: int, num_layers: int, num_heads: int, - head_size: int, dtype: torch.dtype, seed: int, - device: str) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index c569a8535d90..3e7b326fd4e7 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -39,8 +39,13 @@ def test_silu_and_mul( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_gelu_new(num_tokens: int, d: int, dtype: torch.dtype, seed: int, - device: int) -> None: +def test_gelu_new( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) gpu_id = f'cuda:{device}' @@ -56,8 +61,13 @@ def test_gelu_new(num_tokens: int, d: int, dtype: torch.dtype, seed: int, @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) -def test_gelu_fast(num_tokens: int, d: int, dtype: torch.dtype, seed: int, - device: int) -> None: +def test_gelu_fast( + num_tokens: int, + d: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) gpu_id = f'cuda:{device}' diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 36fea9fc8b08..5e41e387fcd6 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -107,10 +107,18 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) -def test_paged_attention(kv_cache_factory, version: str, num_seqs: int, - num_heads: Tuple[int, int], head_size: int, - use_alibi: bool, block_size: int, dtype: torch.dtype, - seed: int, device: int) -> None: +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -269,9 +277,14 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_multi_query_kv_attention(num_seqs: int, num_heads: Tuple[int, int], - head_size: int, dtype: torch.dtype, - seed: int, device: int) -> None: +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 46423345c82f..d78d46696b2c 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -27,10 +27,18 @@ @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_copy_blocks(kv_cache_factory, num_mappings: int, num_layers: int, - num_heads: int, head_size: int, block_size: int, - num_blocks: int, dtype: torch.dtype, seed: int, - device: int) -> None: +def test_copy_blocks( + kv_cache_factory, + num_mappings: int, + num_layers: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) @@ -85,9 +93,17 @@ def test_copy_blocks(kv_cache_factory, num_mappings: int, num_layers: int, @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) @torch.inference_mode() -def test_reshape_and_cache(kv_cache_factory, num_tokens: int, num_heads: int, - head_size: int, block_size: int, num_blocks: int, - dtype: torch.dtype, seed: int, device: int) -> None: +def test_reshape_and_cache( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: int, +) -> None: random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) From e25df181d8c5fa764bdc3ce93323e7ce0abedb85 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 3 Jan 2024 02:57:50 +0000 Subject: [PATCH 10/10] Use double quote --- tests/kernels/test_activation.py | 6 +++--- tests/kernels/test_attention.py | 4 ++-- tests/kernels/test_cache.py | 4 ++-- tests/kernels/test_layernorm.py | 2 +- tests/kernels/test_pos_encoding.py | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 3e7b326fd4e7..826bf8350af1 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -25,7 +25,7 @@ def test_silu_and_mul( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) layer = SiluAndMul() out = layer(x) @@ -48,7 +48,7 @@ def test_gelu_new( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = NewGELU() out = layer(x) @@ -70,7 +70,7 @@ def test_gelu_fast( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) layer = FastGELU() out = layer(x) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 5e41e387fcd6..91e84d96a140 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -122,7 +122,7 @@ def test_paged_attention( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads query = torch.empty(num_seqs, @@ -288,7 +288,7 @@ def test_multi_query_kv_attention( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use # a smaller MAX_SEQ_LEN here. diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d78d46696b2c..1d8d41e013b0 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -42,7 +42,7 @@ def test_copy_blocks( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" # Generate random block mappings where each source block is mapped to two # destination blocks. assert 2 * num_mappings <= num_blocks @@ -107,7 +107,7 @@ def test_reshape_and_cache( random.seed(seed) torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 6f3112eb03a5..8a06b3aa268b 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -28,7 +28,7 @@ def test_rms_norm( ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index fe4693cc1a5c..aad310e2bc6d 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -43,7 +43,7 @@ def test_rotary_embedding( rotary_dim = head_size torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f'cuda:{device}' + gpu_id = f"cuda:{device}" if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)