Skip to content

Commit 7243b60

Browse files
hiworldwzj
and
wangzaijun
authored
overlap post sample. (#697)
batch 1, decode faster 1ms. --------- Co-authored-by: wangzaijun <[email protected]>
1 parent 6ede09e commit 7243b60

File tree

1 file changed

+35
-18
lines changed
  • lightllm/server/router/model_infer/mode_backend/continues_batch

1 file changed

+35
-18
lines changed

lightllm/server/router/model_infer/mode_backend/continues_batch/post_process.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,44 @@
1-
import re
21
import torch
3-
from typing import List, Tuple
4-
from lightllm.server.router.model_infer.infer_batch import InferBatch
2+
from typing import List
53
from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty
4+
from dataclasses import dataclass
5+
6+
7+
@dataclass
8+
class OverlapStream:
9+
overlap_stream: torch.cuda.Stream = None
10+
11+
def get_overlap_stream(self):
12+
if self.overlap_stream is None:
13+
self.overlap_stream = torch.cuda.Stream()
14+
return self.overlap_stream
15+
16+
17+
g_single_overlap_stream = OverlapStream()
618

719

820
def sample(logits, reqs, eos_id: List[int] = [2]):
21+
22+
with torch.cuda.stream(g_single_overlap_stream.get_overlap_stream()):
23+
(
24+
presence_penalties,
25+
frequency_penalties,
26+
repetition_penalties,
27+
exponential_decay_length_penalties,
28+
temperatures,
29+
top_ps,
30+
top_ks,
31+
p_token_ids,
32+
p_token_counts,
33+
p_cumsum_seq_len,
34+
p_max_len_in_batch,
35+
length_penalty_idx,
36+
mask_eos_reqs,
37+
) = _get_post_sample_tensors(reqs)
38+
39+
torch.cuda.current_stream().wait_stream(g_single_overlap_stream.get_overlap_stream())
40+
941
logits = logits.contiguous()
10-
(
11-
presence_penalties,
12-
frequency_penalties,
13-
repetition_penalties,
14-
exponential_decay_length_penalties,
15-
temperatures,
16-
top_ps,
17-
top_ks,
18-
p_token_ids,
19-
p_token_counts,
20-
p_cumsum_seq_len,
21-
p_max_len_in_batch,
22-
length_penalty_idx,
23-
mask_eos_reqs,
24-
) = _get_post_sample_tensors(reqs)
2542

2643
apply_penalty(
2744
logits,

0 commit comments

Comments
 (0)