optimize paged attention on triton3 #2553

Merged · 12 commits · Oct 18, 2024
4 changes: 2 additions & 2 deletions .github/workflows/unit-test.yml
@@ -53,7 +53,7 @@ jobs:
- name: Install pytorch
run: |
python3 -m pip cache dir
-python3 -m pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu118
+python3 -m pip install torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
- name: Build lmdeploy
run: |
python3 -m pip install cmake
@@ -77,7 +77,7 @@ jobs:
run: |
python3 -m pip install pynvml packaging protobuf transformers_stream_generator
# manually install flash attn
-python3 -m pip install /root/packages/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
+python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
python3 -m pip install -r requirements.txt -r requirements/test.txt
python3 -m pip install .
- name: Check env
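The workflow above moves CI onto the torch 2.3.1 / cu118 wheels and the matching flash-attn 2.6.3 build. A minimal sketch (not part of the PR) for confirming what a given environment actually provides; the printed values are purely illustrative:

    # sketch: report the installed torch/triton versions and the GPU compute
    # capability that the triton check in check_env keys off of
    import torch
    import triton

    print('torch :', torch.__version__)
    print('triton:', triton.__version__)
    if torch.cuda.is_available():
        print('capability:', torch.cuda.get_device_capability())
    else:
        print('capability: no CUDA device visible')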
18 changes: 10 additions & 8 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -93,15 +93,13 @@ def check_env_triton(device: str):

if device == 'cuda':
device_cap = torch.cuda.get_device_capability()
-TRITON_VER_220 = version.parse('2.2.0')
TRITON_VER_231 = version.parse('2.3.1')

if device_cap[0] <= 7:
-if (triton_version >= TRITON_VER_220
-and triton_version <= TRITON_VER_231):
+if triton_version <= TRITON_VER_231:
err = RuntimeError(
'Attention triton kernel does not fully support '
-'triton[2.2.0~2.3.1] on device with capability<8. '
+'triton<3.0.0 on device with capability<8. '
'Please upgrade your triton version.')
_handle_exception(err, 'Triton', logger)

@@ -142,7 +140,8 @@ def check_awq(hf_config):


def check_transformers_version(model_path: str,
-trust_remote_code: bool = True):
+trust_remote_code: bool = True,
+dtype: str = 'auto'):
"""check transformers version."""
from packaging import version
logger = get_logger('lmdeploy')
@@ -206,7 +205,8 @@ def __check_model_dtype_support(config):

try:
model_config = ModelConfig.from_hf_config(config,
-model_path=model_path)
+model_path=model_path,
+dtype=dtype)
if model_config.dtype == torch.bfloat16:
assert torch.cuda.is_bf16_supported(), (
'bf16 is not supported on your device')
@@ -229,11 +229,13 @@ def __check_model_dtype_support(config):
check_awq(config)


-def check_model(model_path: str, trust_remote_code: bool = True):
+def check_model(model_path: str,
+trust_remote_code: bool = True,
+dtype: str = 'auto'):
"""check model requirements."""
logger = get_logger('lmdeploy')
logger.info('Checking model.')
-check_transformers_version(model_path, trust_remote_code)
+check_transformers_version(model_path, trust_remote_code, dtype)


def check_adapter(path: str):
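In short, the capability gate now rejects any pre-3.0 triton on sm<80 GPUs instead of only the 2.2.0–2.3.1 range. A hedged sketch of the resulting decision, simplified from check_env_triton (the helper name here is hypothetical):

    # sketch: paged-attention triton kernels on capability < 8 devices now
    # require triton newer than 2.3.1 (effectively triton >= 3.0.0)
    from packaging import version

    def triton_ok_for_device(triton_ver: str, capability_major: int) -> bool:
        if capability_major <= 7:
            return version.parse(triton_ver) > version.parse('2.3.1')
        return True

For example, triton_ok_for_device('2.3.1', 7) is False, while triton_ok_for_device('3.0.0', 7) is True.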
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -96,7 +96,7 @@ def __init__(self,
else:
engine_config = copy.deepcopy(engine_config)
check_env(engine_config.device_type)
-check_model(model_path, trust_remote_code)
+check_model(model_path, trust_remote_code, engine_config.dtype)
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(
engine_config.device_type)
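With dtype threaded through, the bf16-support assertion inside check_transformers_version is evaluated against the dtype the engine was configured with, not only the checkpoint default. A hedged usage sketch, assuming check_model remains exported from lmdeploy.pytorch.check_env (the model path is a placeholder):

    # sketch: run the model check with an explicit dtype, as the engine now does
    from lmdeploy.pytorch.check_env import check_model

    check_model('/path/to/model', trust_remote_code=True, dtype='bfloat16')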
61 changes: 35 additions & 26 deletions lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
@@ -378,12 +378,21 @@ def fill_kv_cache(k_states: Tensor,
block_offsets: Tensor,
k_scales_zeros: Tensor = None,
v_scales_zeros: Tensor = None,
-quant_policy: Literal[0, 4, 8] = 0):
+quant_policy: Literal[0, 4, 8] = 0,
+kv_layout: str = 'bshd'):
"""fill key/value state to cache for paged attention."""
+if kv_layout == 'bshd':
+b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
+elif kv_layout == 'bhsd':
+b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
+else:
+raise RuntimeError('Unsupported layout.')

block_offsets = block_offsets.contiguous()
batch_size = block_offsets.size(0)
-block_size, num_heads, head_dim = k_caches.size()[1:]
+block_size = k_caches.size(s_dim)
+num_heads = k_caches.size(h_dim)
+head_dim = k_caches.size(d_dim)
head_dim_v = v_states.size(-1)
max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

@@ -412,14 +421,14 @@ def fill_kv_cache(k_states: Tensor,
stride_vss=v_states.stride(-3),
stride_vsh=v_states.stride(-2),
stride_vsd=v_states.stride(-1),
-stride_kcn=k_caches.stride(0),
-stride_kcb=k_caches.stride(1),
-stride_kch=k_caches.stride(2),
-stride_kcd=k_caches.stride(3),
-stride_vcn=v_caches.stride(0),
-stride_vcb=v_caches.stride(1),
-stride_vch=v_caches.stride(2),
-stride_vcd=v_caches.stride(3),
+stride_kcn=k_caches.stride(b_dim),
+stride_kcb=k_caches.stride(s_dim),
+stride_kch=k_caches.stride(h_dim),
+stride_kcd=k_caches.stride(d_dim),
+stride_vcn=v_caches.stride(b_dim),
+stride_vcb=v_caches.stride(s_dim),
+stride_vch=v_caches.stride(h_dim),
+stride_vcd=v_caches.stride(d_dim),
stride_boff=block_offsets.stride(0),
BLOCK=BLOCK,
BLOCK_D=BLOCK_D,
@@ -450,22 +459,22 @@ def fill_kv_cache(k_states: Tensor,
stride_vss=v_states.stride(-3),
stride_vsh=v_states.stride(-2),
stride_vsd=v_states.stride(-1),
-stride_kcn=k_caches.stride(0),
-stride_kcb=k_caches.stride(1),
-stride_kch=k_caches.stride(2),
-stride_kcd=k_caches.stride(3),
-stride_vcn=v_caches.stride(0),
-stride_vcb=v_caches.stride(1),
-stride_vch=v_caches.stride(2),
-stride_vcd=v_caches.stride(3),
-stride_kszn=k_scales_zeros.stride(0),
-stride_kszb=k_scales_zeros.stride(1),
-stride_kszh=k_scales_zeros.stride(2),
-stride_kszd=k_scales_zeros.stride(3),
-stride_vszn=v_scales_zeros.stride(0),
-stride_vszb=v_scales_zeros.stride(1),
-stride_vszh=v_scales_zeros.stride(2),
-stride_vszd=v_scales_zeros.stride(3),
+stride_kcn=k_caches.stride(b_dim),
+stride_kcb=k_caches.stride(s_dim),
+stride_kch=k_caches.stride(h_dim),
+stride_kcd=k_caches.stride(d_dim),
+stride_vcn=v_caches.stride(b_dim),
+stride_vcb=v_caches.stride(s_dim),
+stride_vch=v_caches.stride(h_dim),
+stride_vcd=v_caches.stride(d_dim),
+stride_kszn=k_scales_zeros.stride(b_dim),
+stride_kszb=k_scales_zeros.stride(s_dim),
+stride_kszh=k_scales_zeros.stride(h_dim),
+stride_kszd=k_scales_zeros.stride(d_dim),
+stride_vszn=v_scales_zeros.stride(b_dim),
+stride_vszb=v_scales_zeros.stride(s_dim),
+stride_vszh=v_scales_zeros.stride(h_dim),
+stride_vszd=v_scales_zeros.stride(d_dim),
quant_policy=quant_policy,
stride_boff=block_offsets.stride(0),
BLOCK=BLOCK,
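The fill_kv_cache change replaces hard-coded cache axes with indices derived from a kv_layout string, so the same kernel launch works for both bshd and bhsd caches. A standalone sketch of that mapping; the tensor shape below is made up for illustration:

    # sketch: map a layout string to (batch, seq, head, dim) axis positions,
    # then read sizes and strides by role instead of by fixed position
    import torch

    def layout_axes(kv_layout: str):
        if kv_layout == 'bshd':
            return 0, 1, 2, 3
        if kv_layout == 'bhsd':
            return 0, 2, 1, 3
        raise RuntimeError('Unsupported layout.')

    k_caches = torch.empty(16, 64, 8, 128)  # (num_blocks, block_size, num_heads, head_dim)
    b_dim, s_dim, h_dim, d_dim = layout_axes('bshd')
    block_size = k_caches.size(s_dim)
    num_heads = k_caches.size(h_dim)
    head_dim = k_caches.size(d_dim)
    cache_strides = tuple(k_caches.stride(d) for d in (b_dim, s_dim, h_dim, d_dim))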