
optimize paged attention on triton3 #2553

Merged
merged 12 commits on Oct 18, 2024
14 changes: 9 additions & 5 deletions lmdeploy/pytorch/check_env/__init__.py
@@ -56,7 +56,7 @@ def check_env_torch():
_handle_exception(e, 'PyTorch', logger)


MAX_TRITON_VERSION = '2.2.0'
MAX_TRITON_VERSION = '3.0.0'


def check_env_triton():
@@ -128,7 +128,8 @@ def check_awq(hf_config):


def check_transformers_version(model_path: str,
trust_remote_code: bool = True):
trust_remote_code: bool = True,
dtype: str = 'auto'):
"""check transformers version."""
from packaging import version
logger = get_logger('lmdeploy')
@@ -192,7 +193,8 @@ def __check_model_dtype_support(config):

try:
model_config = ModelConfig.from_hf_config(config,
model_path=model_path)
model_path=model_path,
dtype=dtype)
if model_config.dtype == torch.bfloat16:
assert torch.cuda.is_bf16_supported(), (
'bf16 is not supported on your device')
@@ -215,11 +217,13 @@ def __check_model_dtype_support(config):
check_awq(config)


def check_model(model_path: str, trust_remote_code: bool = True):
def check_model(model_path: str,
trust_remote_code: bool = True,
dtype: str = 'auto'):
"""check model requirements."""
logger = get_logger('lmdeploy')
logger.info('Checking model.')
check_transformers_version(model_path, trust_remote_code)
check_transformers_version(model_path, trust_remote_code, dtype)


def check_adapter(path: str):
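Together these changes thread a `dtype` argument from `check_model` down to `ModelConfig.from_hf_config`, so the bf16-support assertion runs against the dtype the user actually requested rather than the checkpoint default. A minimal usage sketch, assuming a local Hugging Face model directory as a placeholder path:

```python
from lmdeploy.pytorch.check_env import check_model

# With dtype='bfloat16', check_transformers_version builds the ModelConfig
# with bfloat16 and asserts torch.cuda.is_bf16_supported() on this device.
# '/path/to/hf_model' is an illustrative placeholder, not part of the PR.
check_model('/path/to/hf_model', trust_remote_code=True, dtype='bfloat16')
```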
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -96,7 +96,7 @@ def __init__(self,
else:
engine_config = copy.deepcopy(engine_config)
check_env(engine_config.device_type)
check_model(model_path, trust_remote_code)
check_model(model_path, trust_remote_code, engine_config.dtype)
if engine_config.max_batch_size is None:
engine_config.max_batch_size = get_max_batch_size(
engine_config.device_type)
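At the engine level, `Engine.__init__` now forwards `engine_config.dtype` into `check_model`, so the dtype chosen on the backend config is validated before any weights are loaded. A hedged sketch of how that surfaces to users, assuming the public `PytorchEngineConfig`/`pipeline` API and an example model name not taken from this PR:

```python
from lmdeploy import pipeline, PytorchEngineConfig

# dtype='bfloat16' is carried into Engine.__init__ and then into check_model,
# which fails fast if the GPU lacks bf16 support.
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=PytorchEngineConfig(dtype='bfloat16'))
```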
40 changes: 27 additions & 13 deletions lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py
@@ -135,15 +135,29 @@ def _fill_kv_cache_kernel(
mask=maskv)


def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
v_caches: Tensor, q_start_loc: Tensor, q_seq_length: Tensor,
kv_seq_length: Tensor, max_q_seq_length: int,
block_offsets: Tensor):
def fill_kv_cache(k_states: Tensor,
v_states: Tensor,
k_caches: Tensor,
v_caches: Tensor,
q_start_loc: Tensor,
q_seq_length: Tensor,
kv_seq_length: Tensor,
max_q_seq_length: int,
block_offsets: Tensor,
kv_layout: str = 'bshd'):
"""fill key/value state to cache for paged attention."""
if kv_layout == 'bshd':
b_dim, s_dim, h_dim, d_dim = (0, 1, 2, 3)
elif kv_layout == 'bhsd':
b_dim, s_dim, h_dim, d_dim = (0, 2, 1, 3)
else:
raise RuntimeError('Unsupported layout.')

block_offsets = block_offsets.contiguous()
batch_size = block_offsets.size(0)
block_size, num_heads, head_dim = k_caches.size()[1:]
block_size = k_caches.size(s_dim)
num_heads = k_caches.size(h_dim)
head_dim = k_caches.size(d_dim)
head_dim_v = v_states.size(-1)
max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

@@ -171,14 +185,14 @@ def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
stride_vss=v_states.stride(-3),
stride_vsh=v_states.stride(-2),
stride_vsd=v_states.stride(-1),
stride_kcn=k_caches.stride(0),
stride_kcb=k_caches.stride(1),
stride_kch=k_caches.stride(2),
stride_kcd=k_caches.stride(3),
stride_vcn=v_caches.stride(0),
stride_vcb=v_caches.stride(1),
stride_vch=v_caches.stride(2),
stride_vcd=v_caches.stride(3),
stride_kcn=k_caches.stride(b_dim),
stride_kcb=k_caches.stride(s_dim),
stride_kch=k_caches.stride(h_dim),
stride_kcd=k_caches.stride(d_dim),
stride_vcn=v_caches.stride(b_dim),
stride_vcb=v_caches.stride(s_dim),
stride_vch=v_caches.stride(h_dim),
stride_vcd=v_caches.stride(d_dim),
stride_boff=block_offsets.stride(0),
BLOCK=BLOCK,
BLOCK_D=BLOCK_D,
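The new `kv_layout` argument lets the same Triton kernel fill caches stored either sequence-major (`'bshd'`) or head-major (`'bhsd'`); only the dimension-to-stride mapping changes, since the kernel itself only consumes strides. A small sketch of the two cache shapes, with example sizes that are illustrative rather than taken from the PR:

```python
import torch

num_blocks, block_size, num_heads, head_dim = 16, 64, 8, 128

# 'bshd': [num_blocks, block_size, num_heads, head_dim] (the previous default)
k_cache_bshd = torch.empty(num_blocks, block_size, num_heads, head_dim,
                           dtype=torch.float16, device='cuda')
# 'bhsd': [num_blocks, num_heads, block_size, head_dim]
k_cache_bhsd = torch.empty(num_blocks, num_heads, block_size, head_dim,
                           dtype=torch.float16, device='cuda')

# fill_kv_cache(..., k_caches=k_cache_bhsd, ..., kv_layout='bhsd') reads
# block_size from dim 2 and num_heads from dim 1, matching the (0, 2, 1, 3)
# mapping above; with 'bshd' it keeps the original (0, 1, 2, 3) order.
```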