 class GuidedDecoder:
     bitmask_dtype = torch.int32
 
-    def __init__(self, guided_decoding_config: GuidedDecodingConfig,
-                 max_num_sequences: int, vocab_size_padded: int):
+    def __init__(self,
+                 guided_decoding_config: GuidedDecodingConfig,
+                 max_num_sequences: int,
+                 vocab_size_padded: int,
+                 max_num_draft_tokens: int = 0):
         self.guided_decoding_backend = guided_decoding_config.backend
         self.max_num_sequences = max_num_sequences
         self.vocab_size_padded = vocab_size_padded
+        self.max_num_draft_tokens = max_num_draft_tokens
 
         self.grammar_matcher_factory: Optional[GrammarMatcherFactory] = None
         self.grammar_matchers: List[
             Optional[GrammarMatcher]] = [None] * self.max_num_sequences
 
         if self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR:
             self.grammar_matcher_factory = XGrammarMatcherFactory(
-                guided_decoding_config, vocab_size_padded)
+                guided_decoding_config,
+                vocab_size_padded,
+                max_num_draft_tokens=max_num_draft_tokens)
         elif self.guided_decoding_backend == GuidedDecodingConfig.GuidedDecodingBackend.LLGUIDANCE:
             self.grammar_matcher_factory = LLGuidanceMatcherFactory(
                 guided_decoding_config, vocab_size_padded)
@@ -35,14 +41,16 @@ def __init__(self, guided_decoding_config: GuidedDecodingConfig,
             )
 
         self.bitmask = torch.empty(self.max_num_sequences,
+                                   self.max_num_draft_tokens + 1,
                                    self.bitmask_size,
                                    dtype=self.bitmask_dtype,
                                    device='cuda')
         self.bitmask_host = torch.empty(self.max_num_sequences,
+                                        self.max_num_draft_tokens + 1,
                                         self.bitmask_size,
                                         dtype=self.bitmask_dtype,
                                         pin_memory=True)
-
+        self.num_guided_tokens: List[int] = [0] * self.max_num_sequences
         self._stream = torch.cuda.Stream()
 
     @property
@@ -52,44 +60,77 @@ def bitmask_size(self) -> int:
     @nvtx_range("GuidedDecoder.build")
     def build(self, scheduled_requests: ScheduledRequests) -> None:
         for llm_req in scheduled_requests.all_requests():
-            if llm_req.guided_decoding_params is None:
-                continue
-            slot = llm_req.py_seq_slot
-            if llm_req.is_context_init_state and llm_req.context_current_position == llm_req.prepopulated_prompt_len:
-                self.grammar_matchers[
-                    slot] = self.grammar_matcher_factory.create(
-                        llm_req.guided_decoding_params)
+            slot: int = llm_req.py_seq_slot
+            require_guided: bool = True
 
-            elif llm_req.is_generation_in_progress_state:
-                # The request is in a generation forward step.
-                # Currently, guided decoding does not support with beam search.
-                self.grammar_matchers[slot].accept_token(
-                    llm_req.get_last_tokens(0))
+            if llm_req.guided_decoding_params is None:
+                require_guided = False
             else:
-                continue
-
-            # Fill the bitmask on host and asynchorously copy to device.
-            self.grammar_matchers[slot].fill_next_token_bitmask(
-                self.bitmask_host, slot)
-            with torch.cuda.stream(self._stream):
-                self.bitmask[slot].copy_(self.bitmask_host[slot],
-                                         non_blocking=True)
+                if llm_req.is_context_init_state and llm_req.is_last_context_chunk:
+                    # The request is in the last chunk of a context forward step.
+                    matcher = self.grammar_matcher_factory.create(
+                        llm_req.guided_decoding_params)
+                    self.grammar_matchers[slot] = matcher
+                elif llm_req.is_generation_in_progress_state:
+                    # The request is in a generation forward step.
+                    matcher = self.grammar_matchers[slot]
+                    # Roll back the grammar matcher to the last accepted token.
+                    num_rollback_tokens = self.num_guided_tokens[slot] - (
+                        1 + llm_req.py_num_accepted_draft_tokens)
+                    assert num_rollback_tokens >= 0
+                    matcher.rollback(num_rollback_tokens)
+
+                    # Currently, guided decoding does not support beam search.
+                    accepted = matcher.accept_token(llm_req.get_last_tokens(0))
+                    # TODO: Make this an error response.
+                    if not accepted:
+                        raise ValueError(
+                            f"Failed to accept new token: {llm_req.get_last_tokens(0)}."
+                        )
+                else:
+                    require_guided = False
+
+            num_guided_tokens: int = 0
+            if require_guided:
+                if not matcher.is_terminated():
+                    matcher.fill_next_token_bitmask(self.bitmask_host[slot], 0)
+                    num_guided_tokens += 1
+                    # Process draft tokens
+                    for i, tid in enumerate(llm_req.py_draft_tokens, 1):
+                        accepted = matcher.accept_token(tid)
+                        if matcher.is_terminated():
+                            matcher.rollback(1)
+                            accepted = False
+                        if accepted:
+                            matcher.fill_next_token_bitmask(self.bitmask_host[slot],
+                                                            i)
+                            num_guided_tokens += 1
+                        else:
+                            break
+
+            self.num_guided_tokens[slot] = num_guided_tokens
+            if num_guided_tokens > 0:
+                with torch.cuda.stream(self._stream):
+                    self.bitmask[slot, :num_guided_tokens].copy_(
+                        self.bitmask_host[slot, :num_guided_tokens],
+                        non_blocking=True)
 
     @nvtx_range("GuidedDecoder.execute")
     def execute(self, scheduled_requests: ScheduledRequests,
                 logits: torch.Tensor) -> None:
-        assert logits.size(0) == len(scheduled_requests.context_requests) + len(
-            scheduled_requests.generation_requests)
         torch.cuda.current_stream().wait_stream(self._stream)
 
         batched_logits, batched_bitmask = [], []
-        for i, llm_req in enumerate(scheduled_requests.all_requests()):
-            if llm_req.guided_decoding_params is None:
-                continue
-            if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
-                continue
-            batched_logits.append(logits[i])
-            batched_bitmask.append(self.bitmask[llm_req.py_seq_slot])
+        offset = 0
+        for llm_req in scheduled_requests.all_requests():
+            slot: int = llm_req.py_seq_slot
+            num_guided_tokens: int = self.num_guided_tokens[slot]
+            for i in range(num_guided_tokens):
+                batched_logits.append(logits[offset + i])
+                batched_bitmask.append(self.bitmask[slot, i])
+            offset += len(llm_req.py_draft_tokens) + 1
+
+        assert offset == logits.size(0)
 
         if len(batched_logits) > 0:
             torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
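
The subtlest part of this change is the rollback bookkeeping in `build()`: in each iteration the matcher may have speculatively advanced through draft tokens that the target model later rejected, so it must first be rolled back by `num_guided_tokens - (1 + py_num_accepted_draft_tokens)` before the newly sampled token is accepted. The minimal sketch below walks through that arithmetic; `ToyMatcher` and its counter are hypothetical stand-ins for the real `GrammarMatcher` interface (it accepts any token), not part of this PR.

```python
# Illustrative sketch only, NOT the TensorRT-LLM API.
class ToyMatcher:

    def __init__(self) -> None:
        self.num_advanced = 0  # tokens this matcher has accepted so far

    def accept_token(self, token: int) -> bool:
        self.num_advanced += 1
        return True

    def rollback(self, n: int) -> None:
        assert 0 <= n <= self.num_advanced
        self.num_advanced -= n


matcher = ToyMatcher()

# Iteration 1: build() fills a bitmask for the next token (position 0), then
# speculatively advances the matcher through two draft tokens, filling one
# bitmask per draft position.
draft_tokens = [11, 12]
num_guided_tokens = 1  # bitmask for the newly sampled token
for tid in draft_tokens:
    if matcher.accept_token(tid):
        num_guided_tokens += 1  # bitmask for this draft position
assert matcher.num_advanced == 2 and num_guided_tokens == 3

# Iteration 2: suppose the target model accepted only one of the two drafts.
# The matcher advanced past both, so it is rolled back by
# num_guided_tokens - (1 + num_accepted_draft_tokens) = 3 - 2 = 1
# before the newly sampled token is accepted, mirroring build() above.
num_accepted_draft_tokens = 1
matcher.rollback(num_guided_tokens - (1 + num_accepted_draft_tokens))
assert matcher.num_advanced == 1
matcher.accept_token(42)  # the token actually sampled at the last position
assert matcher.num_advanced == 2
```

`execute()` then consumes exactly `num_guided_tokens` bitmask rows per request while stepping through the logits tensor in strides of `len(py_draft_tokens) + 1`, which is why the final `assert offset == logits.size(0)` holds.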