File tree 1 file changed +35
-18
lines changed
lightllm/server/router/model_infer/mode_backend/continues_batch
1 file changed +35
-18
lines changed Original file line number Diff line number Diff line change 1
- import re
2
1
import torch
3
- from typing import List , Tuple
4
- from lightllm .server .router .model_infer .infer_batch import InferBatch
2
+ from typing import List
5
3
from lightllm .common .basemodel .triton_kernel .apply_penalty import apply_penalty
4
+ from dataclasses import dataclass
5
+
6
+
7
@dataclass
class OverlapStream:
    """Lazy holder for a secondary CUDA stream.

    The stream is used (by ``sample``) to overlap the CPU-side build of the
    post-sampling penalty tensors with other work on the default stream.
    Creation is deferred to first use so that merely importing this module
    does not require an initialized CUDA context.
    """

    # Cached stream; stays None until the first get_overlap_stream() call.
    # NOTE: annotation is a string so "X | None" never needs evaluation at
    # class-creation time and works on any Python 3 / torch build.
    overlap_stream: "torch.cuda.Stream | None" = None

    def get_overlap_stream(self) -> "torch.cuda.Stream":
        """Return the shared CUDA stream, creating it on the first call."""
        if self.overlap_stream is None:
            self.overlap_stream = torch.cuda.Stream()
        return self.overlap_stream


# Module-level singleton: every call site shares one overlap stream.
g_single_overlap_stream = OverlapStream()
6
18
7
19
8
20
def sample (logits , reqs , eos_id : List [int ] = [2 ]):
21
+
22
+ with torch .cuda .stream (g_single_overlap_stream .get_overlap_stream ()):
23
+ (
24
+ presence_penalties ,
25
+ frequency_penalties ,
26
+ repetition_penalties ,
27
+ exponential_decay_length_penalties ,
28
+ temperatures ,
29
+ top_ps ,
30
+ top_ks ,
31
+ p_token_ids ,
32
+ p_token_counts ,
33
+ p_cumsum_seq_len ,
34
+ p_max_len_in_batch ,
35
+ length_penalty_idx ,
36
+ mask_eos_reqs ,
37
+ ) = _get_post_sample_tensors (reqs )
38
+
39
+ torch .cuda .current_stream ().wait_stream (g_single_overlap_stream .get_overlap_stream ())
40
+
9
41
logits = logits .contiguous ()
10
- (
11
- presence_penalties ,
12
- frequency_penalties ,
13
- repetition_penalties ,
14
- exponential_decay_length_penalties ,
15
- temperatures ,
16
- top_ps ,
17
- top_ks ,
18
- p_token_ids ,
19
- p_token_counts ,
20
- p_cumsum_seq_len ,
21
- p_max_len_in_batch ,
22
- length_penalty_idx ,
23
- mask_eos_reqs ,
24
- ) = _get_post_sample_tensors (reqs )
25
42
26
43
apply_penalty (
27
44
logits ,
You can’t perform that action at this time.
0 commit comments