Commit 57d25ec

feat: Implementing Past Future Scheduler

Parent: 32f7db7

7 files changed: +114 −13 lines

lightllm/server/api_cli.py

Lines changed: 6 additions & 0 deletions

@@ -214,6 +214,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
     )
+    parser.add_argument(
+        "--past_future_scheduler",
+        action="store_true",
+        help="""use past_future_scheduler for adaptive request new token len prediction,
+        override --router_token_ratio and --router_max_new_token_len (still used during warmup)""",
+    )
 
     parser.add_argument(
         "--router_max_wait_tokens",

lightllm/server/core/objs/req.py

Lines changed: 2 additions & 2 deletions

@@ -266,7 +266,7 @@ def get_all_prompt_metadata(self):
 class ChunkedPrefillReq(Req):
     _pack_ = 4
 
-    def get_tuple_tokens(self, is_busy, router_max_new_token_len):
+    def get_tuple_tokens(self, is_busy, router_max_new_token_len, has_out_len_factor=1.1):
         args = get_env_start_args()
         # During chunked prefill inference there are many modes of delayed-step inference control, used to
         # guarantee better inter-packet pacing or to improve prefill efficiency in dp mode, but when estimating token GPU memory
@@ -283,7 +283,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len):
             cur_max_new_token_len = self.sample_params.max_new_tokens
         else:
             cur_max_new_token_len = min(
-                self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
+                self.sample_params.max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len)
             )
 
         a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
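
The new has_out_len_factor parameter only exposes the 1.1 multiplier that was previously hard-coded in this estimate; PastFutureQueue (below) passes 1.0 so that its sampled output length is used directly instead of an inflated multiple of the tokens already produced. A standalone sketch of the estimate with illustrative numbers (simplified names, not the lightllm API):

def estimate_cur_max_new_token_len(max_new_tokens, has_out_len, router_max_new_token_len, has_out_len_factor=1.1):
    # Expect roughly has_out_len_factor * has_out_len output tokens in total, but never fewer than the
    # router's configured cap and never more than the request's own max_new_tokens.
    return min(max_new_tokens, max(int(has_out_len_factor * has_out_len), router_max_new_token_len))

# A request that already produced 2000 tokens, with max_new_tokens=4096 and router cap 1024:
print(estimate_cur_max_new_token_len(4096, 2000, 1024))       # 2200 (1.1 * 2000)
# Same request under the past-future scheduler, which passes a sampled length and factor 1.0:
print(estimate_cur_max_new_token_len(4096, 2000, 2500, 1.0))  # 2500 (the sampled length wins)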

lightllm/server/httpserver/manager.py

Lines changed: 3 additions & 3 deletions

@@ -649,9 +649,9 @@ async def recycle_resource_loop(self):
                 continue
 
             logger.info(
-                f"left req id {req_status.group_req_objs.group_req_id}"
-                f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} "
-                f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}"
+                f"left req id: {req_status.group_req_objs.group_req_id}, "
+                f"can release: {req_status.group_req_objs.shm_req_objs[0].can_released_mark}, "
+                f"refcount: {req_status.group_req_objs.shm_req_objs[0].ref_count}"
             )
             return

lightllm/server/router/manager.py

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,7 @@
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad
+from lightllm.server.router.req_queue.chunked_prefill.impl_past_future import PastFutureQueue
 from lightllm.server.metrics.manager import MetricClient
 from lightllm.common.basemodel.infer_lock import g_router_lock
 from lightllm.common.mem_manager import ReadOnlyStaticsMemoryManager
@@ -319,6 +320,8 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
 
     def _filter_reqs_from_running_batch(self):
         if self.running_batch is not None:
+            if isinstance(self.req_queue, PastFutureQueue):
+                self.req_queue.record_finished_len_from_batch(self.running_batch)
             self.running_batch.filter_out_finished_req(self.shm_req_manager)
             if self.running_batch.is_clear():
                 self.running_batch = None
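
The hook is called before filter_out_finished_req, so the queue observes each finished request's output length exactly once, while the request is still in the running batch. A standalone toy illustration of that ordering (FakeReq and the filtering step are stand-ins, not lightllm types; it assumes shm_infer_released marks a request whose inference has finished):

from collections import deque

history_output_len = deque(maxlen=200)

class FakeReq:
    def __init__(self, out_len, finished):
        self.shm_cur_output_len = out_len
        self.shm_infer_released = finished

running = [FakeReq(120, True), FakeReq(40, False), FakeReq(900, True)]

# Step 1: record final output lengths of finished requests (mirrors record_finished_len_from_batch).
for req in running:
    if req.shm_infer_released:
        history_output_len.append(req.shm_cur_output_len)

# Step 2: only then drop finished requests (mirrors filter_out_finished_req).
running = [r for r in running if not r.shm_infer_released]

print(list(history_output_len), len(running))  # [120, 900] 1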

lightllm/server/router/req_queue/__init__.py

Lines changed: 14 additions & 6 deletions

@@ -1,29 +1,37 @@
 from .chunked_prefill.impl_for_pd_decode import QueueForPDDecode
 from .chunked_prefill.impl import ChunkedPrefillQueue
 from .chunked_prefill.beam_impl import ChunkedBeamContinuesBatchQueue
+from .chunked_prefill.impl_past_future import PastFutureQueue
 from .dp_base_queue import DpQueue
 
 
 def _get_req_queue_class(args, router, dp_size_in_node: int):
+    if args.past_future_scheduler:
+        if args.diverse_mode:
+            raise ValueError("Diverse mode is not supported with past future scheduler yet")
+        chunked_prefill_queue_impl = PastFutureQueue
+    else:
+        chunked_prefill_queue_impl = ChunkedPrefillQueue
+
     if args.diverse_mode:
         return ChunkedBeamContinuesBatchQueue
     if args.token_healing_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.output_constraint_mode != "none":
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.first_token_constraint_mode:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     if args.run_mode == "decode":
         return QueueForPDDecode
     if args.run_mode == "prefill":
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
 
     if args.disable_chunked_prefill:
         # Although the chunked prefill queue is also used here, since args.chunked_prefill_size = args.max_req_total_len
         # the actual scheduling behaves like the old continuous-batching scheduler, so the two are unified into one implementation to reduce code duplication.
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
     else:
-        return ChunkedPrefillQueue
+        return chunked_prefill_queue_impl
 
 
 def build_req_queue(args, router, dp_size_in_node: int):

lightllm/server/router/req_queue/chunked_prefill/impl.py

Lines changed: 6 additions & 2 deletions

@@ -21,8 +21,7 @@ def _init_cache_list(self, current_batch: Batch, is_busy):
             self.cache_len_list = []
         return
 
-    # @calculate_time(show=True, min_cost_ms=0.1)
-    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+    def _update_cache_len_list(self, req: Req, is_busy):
         self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analysis
         self.cache_len_list.sort(key=lambda x: -x[1])
 
@@ -32,6 +31,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens
         size_array = np.arange(1, len(self.cache_len_list) + 1, 1)
 
         need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
+        return need_max_token_num
+
+    # @calculate_time(show=True, min_cost_ms=0.1)
+    def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
+        need_max_token_num = self._update_cache_len_list(req, is_busy)
         with g_router_lock.obj:
             ok_token_num = (
                 need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)
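
This refactor only splits the peak-token estimate out into _update_cache_len_list so the new scheduler can reuse it; the estimate itself is unchanged. Informally: with requests sorted by remaining output length (longest first), when the i-th request emits its last token, the i requests with at least that much output left are all still resident, so memory must cover i copies of that remaining length plus everything already consumed. A standalone sketch with illustrative numbers (not lightllm code):

import numpy as np

# (already_used_tokens, estimated_remaining_tokens) per running request; values are illustrative.
cache_len_list = [(512, 100), (300, 800), (1200, 50)]

# Sort by remaining output length, longest first.
cache_len_list.sort(key=lambda x: -x[1])

left_out_len_array = np.array([e[1] for e in cache_len_list])  # [800, 100, 50]
has_run_len_array = np.array([e[0] for e in cache_len_list])   # [300, 512, 1200]
cum_run_len_array = np.cumsum(has_run_len_array)               # [300, 812, 2012]
size_array = np.arange(1, len(cache_len_list) + 1)             # [1, 2, 3]

# Worst case over "the i-th longest request finishes while the top i are all still alive".
need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
print(need_max_token_num)  # max(1100, 1012, 2162) = 2162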
lightllm/server/router/req_queue/chunked_prefill/impl_past_future.py (new file)

Lines changed: 80 additions & 0 deletions
import bisect
from collections import deque
import random
from typing import List, Tuple
import numpy as np
from ...batch import Batch, Req
from .impl import ChunkedPrefillQueue


class PastFutureQueue(ChunkedPrefillQueue):
    WINDOW_SIZE = 200
    MINIMUM_SAMPLES = 200
    MAXIMUM_LISTS = 5
    REVERSED = 0.05
    COMPLIANCE_IS_BUSY_FLAG = False

    def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
        super().__init__(args, router, dp_index, dp_size_in_node)
        initial_len = args.router_max_new_token_len
        self.history_output_len = deque([initial_len] * (self.WINDOW_SIZE // 2), maxlen=self.WINDOW_SIZE)

    def _sample_cache_list(self, reqs: List[Req], is_busy, samples=1) -> List[List[Tuple[int, int]]]:
        cache_len_lists = [[] for _ in range(samples)]
        his_Lo = sorted(self.history_output_len)
        for req in reqs:
            dl = req.shm_cur_output_len
            pos = bisect.bisect(his_Lo, dl)

            sample_range = [dl] + his_Lo[pos:] + [req.sample_params.max_new_tokens]  # at least 2 value

            for i in range(samples):
                random_p = np.random.random() * (len(sample_range) - 1)
                l_pos = int(random_p)
                l_val, r_val = sample_range[l_pos : l_pos + 2]

                # Linear interpolation
                sampled = round(l_val + (r_val - l_val) * (random_p - l_pos))
                cache_len_lists[i].append(
                    req.get_tuple_tokens(is_busy and self.COMPLIANCE_IS_BUSY_FLAG, sampled, has_out_len_factor=1.0)
                )

        return cache_len_lists

    def _calc_max_token_num_needed(self, cache_len_list: List[Tuple[int, int]]) -> int:
        cache_len_list.sort(key=lambda x: -x[1])

        left_out_len_array = np.array([e[1] for e in cache_len_list])
        has_run_len_array = np.array([e[0] for e in cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(cache_len_list) + 1, 1)

        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        return need_max_token_num

    def _init_cache_list(self, current_batch: Batch, is_busy):
        if current_batch is not None:
            n_lists = min(self.MAXIMUM_LISTS, int(self.MINIMUM_SAMPLES / len(current_batch.reqs)) + 1)
            local_reqs = [req for req in current_batch.reqs if req.sample_params.suggested_dp_index == self.dp_index]
            self._cache_len_lists = self._sample_cache_list(local_reqs, is_busy, samples=n_lists)
        else:
            self._cache_len_lists = [[]]
        self.cache_len_list = self._cache_len_lists[0]  # keep compatibility

    def _update_cache_len_list(self, req: Req, is_busy):
        need_max_token_nums = []
        for li in self._cache_len_lists:
            newreq_output_len_sample = random.choice(self.history_output_len)
            li.append(
                req.get_tuple_tokens(
                    is_busy and self.COMPLIANCE_IS_BUSY_FLAG, newreq_output_len_sample, has_out_len_factor=1.0
                )
            )
            need_max_token_nums.append(self._calc_max_token_num_needed(li))
        need_max_token_num = np.max(need_max_token_nums)
        return need_max_token_num

    def record_finished_len_from_batch(self, batch: Batch):
        for req in batch.reqs:
            if req.shm_infer_released:
                self.history_output_len.append(req.shm_cur_output_len)
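
For orientation, the "past future" idea is: keep a sliding window of recently finished output lengths, and for each running request sample a plausible final length from the part of that history it can still reach (no shorter than what it has already produced, no longer than its own max_new_tokens); several such sampled worlds are built and admission uses the worst case across them. A standalone toy illustration of the per-request sampling step (illustrative numbers, mirrors the logic in _sample_cache_list):

import bisect
import random

history = sorted([120, 350, 500, 900, 1600])  # recently finished output lengths (the sliding window)
cur_out_len = 400                             # tokens this request has already produced
max_new_tokens = 2048                         # the request's own cap

# Candidate final lengths: what it has now, every history value it can still reach, and its cap.
pos = bisect.bisect(history, cur_out_len)
sample_range = [cur_out_len] + history[pos:] + [max_new_tokens]  # [400, 500, 900, 1600, 2048]

# Pick a random position between two adjacent candidates and interpolate linearly.
random_p = random.random() * (len(sample_range) - 1)
l_pos = int(random_p)
l_val, r_val = sample_range[l_pos : l_pos + 2]
sampled_final_len = round(l_val + (r_val - l_val) * (random_p - l_pos))
print(sampled_final_len)  # a value between 400 and 2048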
