Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fused ROPE and reshape cache kernel #229

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
moving load cos/sin to common code
maleksan85 committed Nov 21, 2024
commit 97582ccd6719932da8b709d2fde7216e688211ea
4 changes: 3 additions & 1 deletion csrc/rocm/fused_rope_and_reshape_cache.cu
Original file line number Diff line number Diff line change
@@ -84,8 +84,10 @@ inline __device__ void store_value_into_cache(
}
}



template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt, bool IS_NEOX>
__global__ void fused_rotary_embedding_and_reshape_cache_kernel(
__global__ void __launch_bounds__ (512) fused_rotary_embedding_and_reshape_cache_kernel(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]

Unchanged files with check annotations Beta

import random
from itertools import accumulate, product

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:23: F401 `itertools.accumulate` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:35: F401 `itertools.product` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:23: F401 `itertools.accumulate` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:35: F401 `itertools.product` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:23: F401 `itertools.accumulate` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:35: F401 `itertools.product` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:23: F401 `itertools.accumulate` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:35: F401 `itertools.product` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:23: F401 `itertools.accumulate` imported but unused

Check failure on line 2 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:2:35: F401 `itertools.product` imported but unused
from time import perf_counter
from typing import List, Optional

Check failure on line 4 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:4:20: F401 `typing.List` imported but unused

Check failure on line 4 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:4:20: F401 `typing.List` imported but unused

Check failure on line 4 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:4:20: F401 `typing.List` imported but unused

Check failure on line 4 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:4:20: F401 `typing.List` imported but unused

Check failure on line 4 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:4:20: F401 `typing.List` imported but unused
import pytest
import torch
from vllm.config import CacheConfig
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT, get_rope
from .allclose_default import get_default_atol, get_default_rtol

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:31: F401 `.allclose_default.get_default_atol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:49: F401 `.allclose_default.get_default_rtol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:31: F401 `.allclose_default.get_default_atol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:49: F401 `.allclose_default.get_default_rtol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:31: F401 `.allclose_default.get_default_atol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:49: F401 `.allclose_default.get_default_rtol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:31: F401 `.allclose_default.get_default_atol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:49: F401 `.allclose_default.get_default_rtol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:31: F401 `.allclose_default.get_default_atol` imported but unused

Check failure on line 13 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (F401)

tests/kernels/test_fused_rope_and_reshape_cache.py:13:49: F401 `.allclose_default.get_default_rtol` imported but unused
NUM_TOKENS = [42] # Arbitrary values for testing
is_neox_style, rope_scaling={"type": "llama3",
"low_freq_factor": 1.0,
"high_freq_factor": 2.0,
"original_max_position_embeddings": 1024},

Check failure on line 77 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:77:81: E501 Line too long (90 > 80)

Check failure on line 77 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:77:81: E501 Line too long (90 > 80)

Check failure on line 77 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:77:81: E501 Line too long (90 > 80)

Check failure on line 77 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:77:81: E501 Line too long (90 > 80)

Check failure on line 77 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:77:81: E501 Line too long (90 > 80)
fused_with_kv_cache_op=True,
cache_config = cache_config)
rope = rope.to(dtype=dtype)
#------------------Simulate------------------------------
time_start = perf_counter()
ref_query, ref_key = rope.forward_cuda(positions, query.clone(), key.clone())

Check failure on line 202 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.11)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:202:81: E501 Line too long (81 > 80)

Check failure on line 202 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.8)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:202:81: E501 Line too long (81 > 80)

Check failure on line 202 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.10)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:202:81: E501 Line too long (81 > 80)

Check failure on line 202 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.12)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:202:81: E501 Line too long (81 > 80)

Check failure on line 202 in tests/kernels/test_fused_rope_and_reshape_cache.py

GitHub Actions / ruff (3.9)

Ruff (E501)

tests/kernels/test_fused_rope_and_reshape_cache.py:202:81: E501 Line too long (81 > 80)
ops.reshape_and_cache(ref_key.view(-1, num_kv_heads, head_size),
value.view(-1, num_kv_heads, head_size),
cloned_key_cache,
import torch
import torch.nn as nn
from vllm import _custom_ops as ops

Check failure on line 30 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.11)

Ruff (F401)

vllm/model_executor/layers/rotary_embedding.py:30:33: F401 `vllm._custom_ops` imported but unused

Check failure on line 30 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.8)

Ruff (F401)

vllm/model_executor/layers/rotary_embedding.py:30:33: F401 `vllm._custom_ops` imported but unused

Check failure on line 30 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.10)

Ruff (F401)

vllm/model_executor/layers/rotary_embedding.py:30:33: F401 `vllm._custom_ops` imported but unused

Check failure on line 30 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.12)

Ruff (F401)

vllm/model_executor/layers/rotary_embedding.py:30:33: F401 `vllm._custom_ops` imported but unused

Check failure on line 30 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.9)

Ruff (F401)

vllm/model_executor/layers/rotary_embedding.py:30:33: F401 `vllm._custom_ops` imported but unused
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

Check failure on line 162 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.11)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:162:41: F811 Redefinition of unused `ops` from line 30

Check failure on line 162 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.8)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:162:41: F811 Redefinition of unused `ops` from line 30

Check failure on line 162 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.10)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:162:41: F811 Redefinition of unused `ops` from line 30

Check failure on line 162 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.12)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:162:41: F811 Redefinition of unused `ops` from line 30

Check failure on line 162 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.9)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:162:41: F811 Redefinition of unused `ops` from line 30
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm._ipex_ops import ipex_ops as ops

Check failure on line 185 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.11)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:185:48: F811 Redefinition of unused `ops` from line 30

Check failure on line 185 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.8)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:185:48: F811 Redefinition of unused `ops` from line 30

Check failure on line 185 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.10)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:185:48: F811 Redefinition of unused `ops` from line 30

Check failure on line 185 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.12)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:185:48: F811 Redefinition of unused `ops` from line 30

Check failure on line 185 in vllm/model_executor/layers/rotary_embedding.py

GitHub Actions / ruff (3.9)

Ruff (F811)

vllm/model_executor/layers/rotary_embedding.py:185:48: F811 Redefinition of unused `ops` from line 30
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)