From 4a18be79e6a192c3baaf2bdb89e764229df745ca Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 27 Nov 2024 21:13:54 +0800 Subject: [PATCH 01/12] profile throughput without threads --- benchmark/profile_throughput.py | 34 +++++++--------- lmdeploy/pytorch/engine/model_agent.py | 55 -------------------------- 2 files changed, 15 insertions(+), 74 deletions(-) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 58786d9c80..804b07d168 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import asyncio import csv import json import os import random import time from queue import Queue -from threading import Thread from typing import List, Tuple, Union import numpy as np @@ -86,15 +86,15 @@ def __init__(self, model_path: str, self.csv = csv self.pbar = None - def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, - temperature: float, top_p: float, top_k: int, - stream_output: bool): + async def _inference(self, req_queue: Queue, res_queue: Queue, + session_id: int, temperature: float, top_p: float, + top_k: int, stream_output: bool): model_inst = self.tm_model.create_instance() stats = [] # get each generated token's latency per_token_latency_stats = [] for prompt, input_seqlen, output_seqlen in iter( - req_queue.get, [None, None, None]): + req_queue.get_nowait, [None, None, None]): _per_token_latency_stats = [0] * (output_seqlen + 1) prev = time.perf_counter() n_prev_token = 0 @@ -102,7 +102,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, input_ids = self.tokenizer(prompt).input_ids state = DetokenizeState(len(input_ids)) - for outputs in model_inst.stream_infer( + async for outputs in model_inst.async_stream_infer( session_id, input_ids=input_ids, gen_config=GenerationConfig(max_new_tokens=output_seqlen, @@ -123,7 +123,7 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, prev = now # for pytorch engine to restart a session if isinstance(model_inst, EngineInstance): - model_inst.end(session_id) + await model_inst.async_end(session_id) assert output_seqlen <= n_token <= output_seqlen + 1, \ f'Error. session_id({session_id}) request {output_seqlen} ' \ f'tokens, but generate {n_token} tokens.\n' \ @@ -139,13 +139,12 @@ def _inference(self, req_queue: Queue, res_queue: Queue, session_id: int, # skip the first token latency per_token_latency_stats.append(_per_token_latency_stats[1:]) self.pbar.update(1) - res_queue.put((session_id, stats, per_token_latency_stats)) + res_queue.put_nowait((session_id, stats, per_token_latency_stats)) def process_request(self, requests, concurrency, temperature, top_p, top_k, stream_output): res_queue = Queue() req_queue = Queue() - threads = [] self.pbar = tqdm(total=len(requests)) @@ -157,18 +156,16 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k, start = time.time() + event_loop = asyncio.get_event_loop() + # start threads + tasks = [] for i in range(concurrency): - t = Thread(target=self._inference, - args=(req_queue, res_queue, i, temperature, top_p, - top_k, stream_output), - daemon=True) - t.start() - threads.append(t) + task = self._inference(req_queue, res_queue, i, temperature, top_p, + top_k, stream_output) + tasks.append(task) - # wait for finish - for t in threads: - t.join() + event_loop.run_until_complete(asyncio.gather(*tasks)) elapsed_time = time.time() - start @@ -333,7 +330,6 @@ def main(): block_size=args.cache_block_seq_len, max_batch_size=args.concurrency, tp=args.tp, - thread_safe=True, eager_mode=args.eager_mode, enable_prefix_caching=args.enable_prefix_caching, quant_policy=args.quant_policy, diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 74938de812..74a5d70019 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -164,10 +164,6 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig): self.model_config = model_config self.cache_config = cache_config - def get_block_numel(self): - """get block nelement.""" - raise NotImplementedError('Not implemented') - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -179,17 +175,6 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, """ raise NotImplementedError('Not implemented.') - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - raise NotImplementedError('Not implemented.') - def get_logits(self, hidden_states: torch.Tensor): """get logits of model output.""" raise NotImplementedError('Not implemented.') @@ -257,11 +242,6 @@ def _build_model(self, device=device) return patched_model - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): cache_swapping(self.cache_engine, @@ -276,21 +256,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. @@ -690,11 +655,6 @@ def _build_model( return model, cache_engine, cache_config - def get_block_numel(self): - """get block nelement.""" - k_cache = self.cache_engine.local_gpu_cache[0][0] - return k_cache[0].numel() - def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """forward impl.""" @@ -715,21 +675,6 @@ def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap, ) return output - def forward(self, inputs: ModelInputs, swap_in_map: SwapMap, - swap_out_map: SwapMap): - """model forward. - - Args: - inputs (Dict): The input data comes from _make_inputs. - swap_in_map (SwapMap): Cache maps to swap in. - swap_out_map (SwapMap): Cache maps to swap out. - """ - output = self._forward_impl(inputs, - swap_in_map=swap_in_map, - swap_out_map=swap_out_map) - self.stream.synchronize() - return output - async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, swap_out_map: SwapMap): """model forward. From 31afcf25857363d214ddbd64afc61073054fd088 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 28 Nov 2024 11:21:49 +0800 Subject: [PATCH 02/12] optimize main loop --- lmdeploy/pytorch/engine/engine.py | 21 ++++++++++----- lmdeploy/pytorch/engine/logits_process.py | 30 ++++++++++++++------- tests/pytorch/engine/test_logits_process.py | 3 ++- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index cffe13bbdb..a56e6ba6a6 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -171,6 +171,7 @@ def __init__(self, self._start_loop() self._create_buffers() self.engine_instance = self.create_instance() + self._output_stream = torch.cuda.Stream() @classmethod def from_pretrained(cls, @@ -659,7 +660,8 @@ async def __long_context_single_forward(inputs): return ret def _make_infer_outputs(self, next_token_ids: torch.LongTensor, - logits: torch.Tensor, stopped: torch.Tensor): + logits: torch.Tensor, stopped: torch.Tensor, + event: torch.Event): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, @@ -680,6 +682,11 @@ def __get_q_start_loc(): else: return seq_length.cumsum(0) - seq_length + with torch.cuda.stream(self._output_stream): + event.wait() + next_token_ids = next_token_ids.cpu() + stopped = stopped.cpu() + running = self._running is_run = [seq.status == MessageStatus.RUNNING for seq in running] stopped = stopped.tolist() @@ -741,6 +748,7 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)}') + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() @@ -771,10 +779,11 @@ def __update_inputs(next_token_ids): next_token_ids, sampling_inputs.stop_words, num_appendable_ids) # send output - stopped = stopped.cpu() - finish = stopped.all().item() or (idx == loop_count - 1) + finish = (idx == loop_count - 1) finish = finish or _check_finish(self.scheduler, idx) - output = (next_token_ids.cpu(), logits, stopped) + event = torch.cuda.Event() + event.record() + output = (next_token_ids, logits, stopped, event) output_que.put_nowait((finish, output)) if finish: @@ -937,9 +946,9 @@ async def __step(): try: if isinstance(out, Exception): raise out - next_token_ids, logits, stopped = out + next_token_ids, logits, stopped, event = out step_outputs = self._make_infer_outputs( - next_token_ids, logits, stopped) + next_token_ids, logits, stopped, event) __send_resps(step_outputs) except Exception as e: raise e diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py index 54740a4fb3..24cb336d71 100644 --- a/lmdeploy/pytorch/engine/logits_process.py +++ b/lmdeploy/pytorch/engine/logits_process.py @@ -21,10 +21,9 @@ def _process_temperature_(scores: torch.Tensor, temperature: torch.Tensor): def _process_bad_words_(scores: torch.Tensor, bad_words: torch.LongTensor, + mask: torch.BoolTensor, filter_value: float = -float('inf')): """process bad words.""" - mask = bad_words >= 0 - bad_words = bad_words.where(mask, 0) filtered_scores = scores.gather(1, bad_words) filtered_scores[mask] = filter_value scores.scatter_(1, bad_words, filtered_scores) @@ -127,7 +126,9 @@ def _guided_sampling(response_formats: Tuple[Dict], scores: torch.Tensor, class SamplingInputs: temperature: torch.Tensor = None bad_words: torch.LongTensor = None + bad_mask: torch.BoolTensor = None stop_words: torch.LongTensor = None + stop_mask: torch.BoolTensor = None repetition_penalty: torch.Tensor = None top_k: torch.LongTensor = None top_p: torch.Tensor = None @@ -200,9 +201,11 @@ def __get_bad_words(bad_words): """get bad words.""" max_bw_len = max(len(bw) for bw in bad_words) if max_bw_len == 0: - return None + return None, None if all(len(bw) == max_bw_len for bw in bad_words): - return torch.tensor(bad_words) + ret = torch.tensor(bad_words) + mask = torch.ones_like(ret, dtype=bool) + return ret, mask ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64) for idx, bw in enumerate(bad_words): bw_len = len(bw) @@ -210,7 +213,10 @@ def __get_bad_words(bad_words): continue bw = ret.new_tensor(bw) ret[idx, :bw_len] = bw - return ret + + mask = ret >= 0 + ret = ret.where(mask, 0) + return ret, mask __gather_params() @@ -221,8 +227,8 @@ def __get_bad_words(bad_words): temperature = torch.tensor(temperature) - bad_words = __get_bad_words(bad_words) - stop_words = __get_bad_words(stop_words) + bad_words, bad_mask = __get_bad_words(bad_words) + stop_words, stop_mask = __get_bad_words(stop_words) max_top_k = max(top_k) if min(top_k) <= 0: @@ -243,7 +249,9 @@ def __get_bad_words(bad_words): sampling_input = cls( temperature=temperature, bad_words=bad_words, + bad_mask=bad_mask, stop_words=stop_words, + stop_mask=stop_mask, repetition_penalty=repetition_penalty, top_k=top_k, top_p=top_p, @@ -326,12 +334,14 @@ def __call__(self, all_ids: torch.LongTensor, bad_words = sampling_inputs.bad_words if bad_words is not None: - scores = _process_bad_words_(scores, bad_words) + bad_mask = sampling_inputs.bad_mask + scores = _process_bad_words_(scores, bad_words, bad_mask) stop_words = sampling_inputs.stop_words if stop_words is not None: - stop_words = torch.where(self.ignore_eos[:, None], stop_words, -1) - scores = _process_bad_words_(scores, stop_words) + stop_mask = sampling_inputs.stop_mask + stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False) + scores = _process_bad_words_(scores, stop_words, stop_mask) scores = _guided_sampling(sampling_inputs.response_formats, scores, guided_input_ids, self.tokenizer) diff --git a/tests/pytorch/engine/test_logits_process.py b/tests/pytorch/engine/test_logits_process.py index 5c5fdbdc18..69c8315411 100644 --- a/tests/pytorch/engine/test_logits_process.py +++ b/tests/pytorch/engine/test_logits_process.py @@ -35,8 +35,9 @@ def test_process_bad_words(): [4, 4], [-1, -1], ]) + mask = bad_words >= 0 - out_scores = _process_bad_words_(scores, bad_words) + out_scores = _process_bad_words_(scores, bad_words.where(mask, 0), mask) for score, bw in zip(out_scores, bad_words): bw = bw.tolist() From 88ad4dce7b7483ad7022d5cddd46d6a6e8ef2197 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 28 Nov 2024 11:24:55 +0800 Subject: [PATCH 03/12] fix torch.event --- lmdeploy/pytorch/engine/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index a56e6ba6a6..368ba19ff8 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -661,7 +661,7 @@ async def __long_context_single_forward(inputs): def _make_infer_outputs(self, next_token_ids: torch.LongTensor, logits: torch.Tensor, stopped: torch.Tensor, - event: torch.Event): + event: torch.cuda.Event): """make infer output.""" def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence, From 9585aeff29ad71cecd20a6ac2d070d06fbdacd1f Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 28 Nov 2024 12:12:37 +0800 Subject: [PATCH 04/12] fix python>3.11 --- benchmark/profile_throughput.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py index 804b07d168..4f06fad4f9 100644 --- a/benchmark/profile_throughput.py +++ b/benchmark/profile_throughput.py @@ -156,7 +156,8 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k, start = time.time() - event_loop = asyncio.get_event_loop() + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) # start threads tasks = [] @@ -165,7 +166,10 @@ def process_request(self, requests, concurrency, temperature, top_p, top_k, top_k, stream_output) tasks.append(task) - event_loop.run_until_complete(asyncio.gather(*tasks)) + async def _gather_tasks(tasks): + return await asyncio.gather(*tasks) + + event_loop.run_until_complete(_gather_tasks(tasks)) elapsed_time = time.time() - start From 3ea4aa8b6a7c95f77731ee3014e6b10f094952e0 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 28 Nov 2024 14:08:45 +0800 Subject: [PATCH 05/12] optimize tp --- lmdeploy/pytorch/engine/engine.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 368ba19ff8..79b45b9593 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -748,7 +748,8 @@ def __update_inputs(next_token_ids): logger.debug(': ' f'batch_size={inputs.seq_length.size(0)} ' f'num_tokens={inputs.input_ids.size(-1)}') - inputs = inputs.to_device('cuda') + if self.gpu_count == 1: + inputs = inputs.to_device('cuda') is_decoding = inputs.is_decoding if all_ids is not None: all_ids = all_ids.cuda() From 549c6c60081b6243378184a7da0e8477fa8c1f70 Mon Sep 17 00:00:00 2001 From: grimoire Date: Thu, 28 Nov 2024 14:52:30 +0800 Subject: [PATCH 06/12] reduce cudagraph copy --- lmdeploy/pytorch/models/utils/cudagraph.py | 28 ++++++++++------------ 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/lmdeploy/pytorch/models/utils/cudagraph.py b/lmdeploy/pytorch/models/utils/cudagraph.py index 149376e4be..74d090a9a3 100644 --- a/lmdeploy/pytorch/models/utils/cudagraph.py +++ b/lmdeploy/pytorch/models/utils/cudagraph.py @@ -70,15 +70,14 @@ def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, *args, input_buffers['block_offsets'] = torch.zeros((max_batches, num_blocks), dtype=torch.int64, device=device) - input_buffers['q_start_loc'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['q_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) - input_buffers['kv_seqlens'] = torch.zeros(max_batches, - dtype=torch.int64, - device=device) + + input_buffers['qkv_lens'] = torch.zeros(3, + max_batches, + dtype=torch.int64, + device=device) + input_buffers['q_start_loc'] = input_buffers['qkv_lens'][0] + input_buffers['q_seqlens'] = input_buffers['qkv_lens'][1] + input_buffers['kv_seqlens'] = input_buffers['qkv_lens'][2] input_buffers['local_adapter_ids'] = torch.zeros(max_batches, dtype=torch.int64, device=device) @@ -111,13 +110,10 @@ def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, input_buffers['position_ids'][:, :num_tokens] = position_ids input_buffers[ 'block_offsets'][:batch_size, :num_blocks] = block_offsets - if q_seqlens.data_ptr() != input_buffers['q_seqlens'].data_ptr(): - input_buffers['q_seqlens'].zero_() - input_buffers['q_seqlens'][:batch_size] = q_seqlens - if kv_seqlens.data_ptr() != input_buffers['kv_seqlens'].data_ptr(): - input_buffers['kv_seqlens'].zero_() - input_buffers['kv_seqlens'][:batch_size] = kv_seqlens - input_buffers['q_start_loc'][:batch_size] = q_start_loc + + qkv = torch.stack((q_start_loc, q_seqlens, kv_seqlens)) + input_buffers['qkv_lens'].zero_() + input_buffers['qkv_lens'][:, :batch_size] = qkv if inputs_embeds is not None: emb_size = inputs_embeds.size(-1) if 'inputs_embeds' not in input_buffers: From 3df2e49d4d02f95d31ba5b41a21b4a961fd5d8cb Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 29 Nov 2024 12:58:30 +0800 Subject: [PATCH 07/12] optimize fill kv cache --- lmdeploy/pytorch/backends/cuda/attention.py | 5 +- .../pytorch/kernels/cuda/fill_kv_cache.py | 183 ++++++------------ 2 files changed, 63 insertions(+), 125 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/attention.py b/lmdeploy/pytorch/backends/cuda/attention.py index d01d6fe9b4..8261b869f0 100644 --- a/lmdeploy/pytorch/backends/cuda/attention.py +++ b/lmdeploy/pytorch/backends/cuda/attention.py @@ -94,7 +94,10 @@ def forward( kv_seqlens = attn_metadata.kv_seqlens kv_flatten_size = attn_metadata.kv_flatten_size quant_policy = attn_metadata.quant_policy - max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) + if attn_metadata.is_decoding: + max_q_seqlen = 1 + else: + max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2)) fill_max_q_seqlen = max_q_seqlen if attn_metadata.fill_seqlens is not None: fill_seqlens = attn_metadata.fill_seqlens diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 9ef614fadd..08dbb3138d 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Literal -import torch import triton import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func +from .triton_utils import get_kernel_meta @triton.jit @@ -38,37 +37,6 @@ def _quant_int4(val1, val2): return q_val, scales, zeros -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_kernel( KStates, @@ -79,7 +47,7 @@ def _fill_kv_cache_kernel( QSeqLens, KVSeqLens, BlockOffsets, - num_heads: tl.constexpr, + is_decoding: tl.constexpr, head_dim: tl.constexpr, head_dim_v: tl.constexpr, stride_kss, @@ -100,108 +68,70 @@ def _fill_kv_cache_kernel( BLOCK: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_DV: tl.constexpr, - BLOCK_H: tl.constexpr, ): """fill kv cache kernel.""" - batch_id = tl.program_id(0) + batch_id = tl.program_id(2) + head_id = tl.program_id(0) block_id = tl.program_id(1) - # initialize - h_off = tl.arange(0, BLOCK_H) - d_off = tl.arange(0, BLOCK_D) - q_startloc = tl.load(QStartLoc + batch_id) q_seqlen = tl.load(QSeqLens + batch_id) kv_seqlen = tl.load(KVSeqLens + batch_id) history_seqlen = kv_seqlen - q_seqlen - block0_first_tokenloc = history_seqlen % BLOCK - - state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, - 0) - kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id - kv_block_id = min(kv_block_id, stride_boff - 1) - block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) + kv_block_id = history_seqlen // BLOCK + block_id - cur_startloc = q_startloc + state_token_offset - ks_ptr = KStates + cur_startloc * stride_kss - vs_ptr = VStates + cur_startloc * stride_vss + if kv_seqlen <= 0: + return - kc_ptr = KCaches + block_off * stride_kcn - vc_ptr = VCaches + block_off * stride_vcn + if kv_block_id * BLOCK >= kv_seqlen: + return - c_first_tokenloc = block0_first_tokenloc - if block_id != 0: - c_first_tokenloc *= 0 - c_last_tokenloc = tl.minimum( - BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK) + if is_decoding: + page_offs = tl.full((1, ), history_seqlen % BLOCK, dtype=tl.int32) + kv_mask = tl.full((1, ), 1, dtype=tl.int1) + q_offs = tl.full((1, ), q_startloc, dtype=tl.int32) + else: + page_offs = tl.arange(0, BLOCK) + kv_offs = kv_block_id * BLOCK + page_offs + kv_mask = (kv_offs >= history_seqlen) & (kv_offs < kv_seqlen) + token_off = q_startloc + kv_block_id * BLOCK - history_seqlen + q_offs = token_off + page_offs - for bidx in range(c_first_tokenloc, c_last_tokenloc): - sidx = bidx - c_first_tokenloc - mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim) - k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh + - d_off[None, :] * stride_ksd, - mask=mask) - tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch + - d_off[None, :] * stride_kcd, - k, - mask=mask) + block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id) - if BLOCK_DV > 0: - dv_off = tl.arange(0, BLOCK_DV) - maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] < - head_dim_v) - v = tl.load(vs_ptr + sidx * stride_vss + - h_off[:, None] * stride_vsh + - dv_off[None, :] * stride_vsd, - mask=maskv) - tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch + - dv_off[None, :] * stride_vcd, - v, - mask=maskv) + d_off = tl.arange(0, BLOCK_D) + mask_ks = kv_mask[:, None] + mask_kc = mask_ks & (d_off[None, :] < head_dim) + d_off = d_off % head_dim + + ks_ptr = KStates + head_id * stride_ksh + ks_ptrs = ks_ptr + q_offs[:, + None] * stride_kss + d_off[None, :] * stride_ksd + kc_ptr = KCaches + block_off * stride_kcn + head_id * stride_kch + kc_ptrs = kc_ptr + page_offs[:, None] * stride_kcb + d_off[ + None, :] * stride_kcd + + if BLOCK_DV > 0: + dv_off = tl.arange(0, BLOCK_DV) + mask_vs = kv_mask[:, None] + mask_vc = mask_vs & (dv_off[None, :] < head_dim_v) + dv_off = dv_off % head_dim_v + vs_ptr = VStates + head_id * stride_vsh + vs_ptrs = vs_ptr + q_offs[:, None] * stride_vss + dv_off[ + None, :] * stride_vsd + vc_ptr = VCaches + block_off * stride_vcn + head_id * stride_vch + vc_ptrs = vc_ptr + page_offs[:, None] * stride_vcb + dv_off[ + None, :] * stride_vcd + + k = tl.load(ks_ptrs, mask=mask_ks) + if BLOCK_DV > 0: + v = tl.load(vs_ptrs, mask=mask_vs) + tl.store(kc_ptrs, k, mask=mask_kc) + if BLOCK_DV > 0: + tl.store(vc_ptrs, v, mask=mask_vc) -@wrap_jit_func(type_hint=dict( - KStates=Tensor, - VStates=Tensor, - KCaches=Tensor, - VCaches=Tensor, - KScalesZeros=Tensor, - VScalesZeros=Tensor, - QStartLoc=Tensor, - QSeqLens=Tensor, - KVSeqLens=Tensor, - BlockOffsets=Tensor, - num_heads=torch.int32, - head_dim=torch.int32, - stride_kss=int, - stride_ksh=int, - stride_ksd=int, - stride_vss=int, - stride_vsh=int, - stride_vsd=int, - stride_kcn=int, - stride_kcb=int, - stride_kch=int, - stride_kcd=int, - stride_vcn=int, - stride_vcb=int, - stride_vch=int, - stride_vcd=int, - stride_kszn=int, - stride_kszb=int, - stride_kszh=int, - stride_kszd=int, - stride_vszn=int, - stride_vszb=int, - stride_vszh=int, - stride_vszd=int, - stride_boff=int, - BLOCK=torch.int32, - BLOCK_D=torch.int32, - BLOCK_DV=torch.int32, - BLOCK_H=torch.int32, -)) @triton.jit def _fill_kv_cache_quant_kernel( KStates, @@ -394,15 +324,20 @@ def fill_kv_cache(k_states: Tensor, num_heads = k_caches.size(h_dim) head_dim = k_caches.size(d_dim) head_dim_v = v_states.size(-1) - max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 + if max_q_seq_length == 1: + max_num_blocks = 1 + else: + max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1 BLOCK = block_size BLOCK_H = triton.next_power_of_2(num_heads) BLOCK_D = triton.next_power_of_2(head_dim) BLOCK_DV = triton.next_power_of_2(head_dim_v) - grid = [batch_size, max_num_blocks] kernel_meta = get_kernel_meta(k_states) if quant_policy == 0: + grid = [num_heads, max_num_blocks, batch_size] + num_heads = 1 + is_decoding = max_num_blocks == 1 _fill_kv_cache_kernel[grid]( k_states, v_states, @@ -412,7 +347,7 @@ def fill_kv_cache(k_states: Tensor, q_seq_length, kv_seq_length, block_offsets, - num_heads=num_heads, + is_decoding=is_decoding, head_dim=head_dim, head_dim_v=head_dim_v, stride_kss=k_states.stride(-3), @@ -433,12 +368,12 @@ def fill_kv_cache(k_states: Tensor, BLOCK=BLOCK, BLOCK_D=BLOCK_D, BLOCK_DV=BLOCK_DV, - BLOCK_H=BLOCK_H, num_warps=4, num_stages=3, **kernel_meta, ) else: + grid = [batch_size, max_num_blocks] _fill_kv_cache_quant_kernel[grid]( k_states, v_states, From 037cac64ebce290606dde96d4cb1c9ba6fccb89a Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 29 Nov 2024 13:43:20 +0800 Subject: [PATCH 08/12] optimize silu and mul --- lmdeploy/pytorch/kernels/cuda/activation.py | 51 ++++++++++----------- 1 file changed, 23 insertions(+), 28 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/activation.py b/lmdeploy/pytorch/kernels/cuda/activation.py index 2533840a95..1a79579be8 100644 --- a/lmdeploy/pytorch/kernels/cuda/activation.py +++ b/lmdeploy/pytorch/kernels/cuda/activation.py @@ -26,27 +26,23 @@ def _silu_and_mul_kernel( BLOCK_SIZE_N: tl.constexpr, ): """silu and mul kernel.""" - m_id = tl.program_id(0) + n_block_id = tl.program_id(0) + m_id = tl.program_id(1) up_ptr = gateup_ptr + N * stride_gun - offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - for _ in range(0, N, BLOCK_SIZE_N): - gate = tl.load(gate_ptrs).to(tl.float32) - up = tl.load(up_ptrs).to(tl.float32) + gate = tl.load(gate_ptrs).to(tl.float32) + up = tl.load(up_ptrs).to(tl.float32) - gate = gate / (1 + fast_expf(-gate)) - out = gate * up + gate = gate / (1 + fast_expf(-gate)) + out = gate * up - tl.store(out_ptrs, out) - - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on + tl.store(out_ptrs, out) @triton.jit @@ -61,28 +57,24 @@ def _silu_and_mul_no_align_kernel( BLOCK_SIZE_N: tl.constexpr, ): """silu and mul kernel.""" - m_id = tl.program_id(0) + n_block_id = tl.program_id(0) + m_id = tl.program_id(1) up_ptr = gateup_ptr + N * stride_gun - offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - for n in range(0, N, BLOCK_SIZE_N): - mask = n + offs_n < N - gate = tl.load(gate_ptrs, mask=mask).to(tl.float32) - up = tl.load(up_ptrs, mask=mask).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up + mask = offs_n < N + gate = tl.load(gate_ptrs, mask=mask).to(tl.float32) + up = tl.load(up_ptrs, mask=mask).to(tl.float32) - tl.store(out_ptrs, out, mask=mask) + gate = gate / (1 + fast_expf(-gate)) + out = gate * up - gate_ptrs += BLOCK_SIZE_N * stride_gun - up_ptrs += BLOCK_SIZE_N * stride_gun - out_ptrs += BLOCK_SIZE_N * stride_on + tl.store(out_ptrs, out, mask=mask) def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): @@ -96,10 +88,13 @@ def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): out = gate_up.new_empty(out_shape) BLOCK_SIZE_N = triton.next_power_of_2(N) - BLOCK_SIZE_N = min(BLOCK_SIZE_N, 1024) + BLOCK_SIZE_N = min(BLOCK_SIZE_N, 512) num_warps = 4 - num_stages = 2 - grid = (M, ) + num_stages = 1 + grid = ( + triton.cdiv(N, BLOCK_SIZE_N), + M, + ) if N % BLOCK_SIZE_N == 0: _silu_and_mul_kernel[grid](gate_up, out, From 42472954a7c2754115e5def03ac98c68a73f05ce Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 29 Nov 2024 14:02:29 +0800 Subject: [PATCH 09/12] optimize apply rotary --- lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index 9e14dc6a0c..a51f37e97e 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -60,8 +60,8 @@ def apply_rotary_pos_emb_qk_kernel( BLOCK_N: tl.constexpr, ): """apply rotary on key AND query kernel.""" - seq_block_id = tl.program_id(0) - head_id = tl.program_id(1) + seq_block_id = tl.program_id(1) + head_id = tl.program_id(0) pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK) pos_mask = pos_offset < seq_len @@ -158,10 +158,13 @@ def apply_rotary_pos_emb(q: Tensor, num_heads_q = q.size(-2) num_heads_k = k.size(-2) num_warps = 4 - num_stages = 4 + num_stages = 1 kernel_meta = get_kernel_meta(q) - grid = [triton.cdiv(seq_len, BLOCK), num_heads_q + num_heads_k] + grid = [ + num_heads_q + num_heads_k, + triton.cdiv(seq_len, BLOCK), + ] apply_rotary_pos_emb_qk_kernel[grid](q, k, cos, From aa255127619e7242b1673a345b5831759fbdf4ea Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 29 Nov 2024 14:43:16 +0800 Subject: [PATCH 10/12] remove executor --- lmdeploy/pytorch/engine/model_agent.py | 10 +++--- .../kernels/cuda/apply_rotary_pos_emb.py | 32 ++----------------- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index 74a5d70019..4c7546a8e0 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -268,8 +268,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): @@ -687,8 +688,9 @@ async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap, output = self._forward_impl(inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map) - await asyncio.get_event_loop().run_in_executor(None, - self.stream.synchronize) + await asyncio.sleep(0) + while not self.stream.query(): + await asyncio.sleep(0) return output def get_logits(self, hidden_states: torch.Tensor): diff --git a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py index a51f37e97e..f9d5f2f171 100644 --- a/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py +++ b/lmdeploy/pytorch/kernels/cuda/apply_rotary_pos_emb.py @@ -4,35 +4,9 @@ import triton.language as tl from torch import Tensor -from .triton_utils import get_kernel_meta, wrap_jit_func - - -@wrap_jit_func(type_hint=dict( - Q=Tensor, - K=Tensor, - COS=Tensor, - SIN=Tensor, - POS=Tensor, - Q_EMB=Tensor, - K_EMB=Tensor, - seq_len=int, - stride_qs=int, - stride_qh=int, - stride_qd=int, - stride_ks=int, - stride_kh=int, - stride_kd=int, - stride_qes=int, - stride_qeh=int, - stride_qed=int, - stride_kes=int, - stride_keh=int, - stride_ked=int, - half_size=torch.int32, - BLOCK=torch.int32, - BLOCK_QH=torch.int32, - BLOCK_N=torch.int32, -)) +from .triton_utils import get_kernel_meta + + @triton.jit(do_not_specialize=('seq_len', )) def apply_rotary_pos_emb_qk_kernel( Q, From 4ea312743d49cc47dd8ad288afc63a9b277a38bc Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 2 Dec 2024 10:33:38 +0800 Subject: [PATCH 11/12] remove kernel --- lmdeploy/pytorch/kernels/cuda/activation.py | 76 +++++---------------- 1 file changed, 18 insertions(+), 58 deletions(-) diff --git a/lmdeploy/pytorch/kernels/cuda/activation.py b/lmdeploy/pytorch/kernels/cuda/activation.py index 1a79579be8..9a00e7354f 100644 --- a/lmdeploy/pytorch/kernels/cuda/activation.py +++ b/lmdeploy/pytorch/kernels/cuda/activation.py @@ -7,10 +7,8 @@ TRITON_VERSION = version.parse(triton.__version__) if TRITON_VERSION >= version.parse('3.0.0'): - fast_expf = tl.math.exp else: - tanh = tl.math.tanh fast_expf = tl.math.fast_expf @@ -36,40 +34,14 @@ def _silu_and_mul_kernel( up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - gate = tl.load(gate_ptrs).to(tl.float32) - up = tl.load(up_ptrs).to(tl.float32) - - gate = gate / (1 + fast_expf(-gate)) - out = gate * up - - tl.store(out_ptrs, out) - - -@triton.jit -def _silu_and_mul_no_align_kernel( - gateup_ptr, - out_ptr, - N: tl.constexpr, - stride_gum: tl.constexpr, - stride_gun: tl.constexpr, - stride_om: tl.constexpr, - stride_on: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, -): - """silu and mul kernel.""" - n_block_id = tl.program_id(0) - m_id = tl.program_id(1) - - up_ptr = gateup_ptr + N * stride_gun - - offs_n = n_block_id * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - gate_ptrs = gateup_ptr + m_id * stride_gum + offs_n * stride_gun - up_ptrs = up_ptr + m_id * stride_gum + offs_n * stride_gun - out_ptrs = out_ptr + m_id * stride_om + offs_n * stride_on - - mask = offs_n < N - gate = tl.load(gate_ptrs, mask=mask).to(tl.float32) - up = tl.load(up_ptrs, mask=mask).to(tl.float32) + if N % BLOCK_SIZE_N == 0: + mask = None + else: + mask = offs_n < N + gate = tl.load(gate_ptrs, mask=mask) + up = tl.load(up_ptrs, mask=mask) + gate = gate.to(tl.float32) + up = up.to(tl.float32) gate = gate / (1 + fast_expf(-gate)) out = gate * up @@ -95,27 +67,15 @@ def silu_and_mul(gate_up: torch.Tensor, out: torch.Tensor = None): triton.cdiv(N, BLOCK_SIZE_N), M, ) - if N % BLOCK_SIZE_N == 0: - _silu_and_mul_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) - else: - _silu_and_mul_no_align_kernel[grid](gate_up, - out, - N, - stride_gum=gate_up.stride(0), - stride_gun=gate_up.stride(1), - stride_om=out.stride(0), - stride_on=out.stride(1), - BLOCK_SIZE_N=BLOCK_SIZE_N, - num_warps=num_warps, - num_stages=num_stages) + _silu_and_mul_kernel[grid](gate_up, + out, + N, + stride_gum=gate_up.stride(0), + stride_gun=gate_up.stride(1), + stride_om=out.stride(0), + stride_on=out.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + num_warps=num_warps, + num_stages=num_stages) return out From b26979379843053f7e3577e2ecfc209dd880ea1a Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 2 Dec 2024 14:14:43 +0800 Subject: [PATCH 12/12] remove num_heads==1 --- lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py index 08dbb3138d..93bd89f488 100644 --- a/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py +++ b/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py @@ -336,7 +336,6 @@ def fill_kv_cache(k_states: Tensor, kernel_meta = get_kernel_meta(k_states) if quant_policy == 0: grid = [num_heads, max_num_blocks, batch_size] - num_heads = 1 is_decoding = max_num_blocks == 1 _fill_kv_cache_kernel[grid]( k_states,