@@ -121,7 +121,7 @@ impl SamplerStage {
121121 pub fn from_grammar ( grammar : LlamaGrammar , start_position : Option < usize > ) -> Self {
122122 SamplerStage :: Grammar ( GrammarStage {
123123 grammar,
124- accepted_to : start_position,
124+ accepted_up_to : start_position,
125125 } )
126126 }
127127
@@ -193,7 +193,7 @@ impl SamplerStage {
193193 llama_sample_tail_free ( context, p_ptr, * z, min_keep) ;
194194 }
195195 SamplerStage :: Grammar ( stage) => {
196- stage. apply ( context, tokens, candidates_p, min_keep)
196+ candidates_p = stage. apply ( context, tokens, candidates_p, min_keep)
197197 }
198198 }
199199 }
@@ -206,7 +206,7 @@ impl SamplerStage {
206206#[ derive( Clone , Debug ) ]
207207pub struct GrammarStage {
208208 grammar : LlamaGrammar ,
209- accepted_to : Option < usize > ,
209+ accepted_up_to : Option < usize > ,
210210}
211211
212212impl GrammarStage {
@@ -216,15 +216,21 @@ impl GrammarStage {
216216 tokens : & [ Token ] ,
217217 mut candidates_p : llama_token_data_array ,
218218 _min_keep : usize ,
219- ) {
220- let accepted_to = self . accepted_to . unwrap_or ( tokens. len ( ) ) ;
221- for token in & tokens[ accepted_to..] {
219+ ) -> llama_token_data_array {
220+ // If `accepted_up_to` is `None`, assume that we should start at the end of context.
221+ let accepted_up_to = self . accepted_up_to . unwrap_or ( tokens. len ( ) ) ;
222+
223+ // Accept all new tokens until the end of context.
224+ for token in & tokens[ accepted_up_to..] {
222225 unsafe { llama_grammar_accept_token ( context, self . grammar . grammar . as_ptr ( ) , token. 0 ) }
223226 }
224- self . accepted_to = Some ( tokens. len ( ) ) ;
227+ self . accepted_up_to = Some ( tokens. len ( ) ) ;
225228
229+ // Apply grammar sampling to `candidates_p`.
226230 let p_ptr = addr_of_mut ! ( candidates_p) ;
227231 unsafe { llama_sample_grammar ( context, p_ptr, self . grammar . grammar . as_ptr ( ) ) } ;
232+
233+ candidates_p
228234 }
229235}
230236
0 commit comments