@@ -14,12 +14,13 @@ use crate::{grammar::LlamaGrammar, Sampler, Token};
1414///
1515/// Standard ordering for samplers (taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)):
1616///
17- /// 1. [`SamplerStage::RepetitionPenalty`]
18- /// 2. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature]
19- /// 3. [`SamplerStage::TopK`]
20- /// 4. [`SamplerStage::TailFree`]
21- /// 5. [`SamplerStage::Typical`]
22- /// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
17+ /// 1. [`SamplerStage::Grammar`]
18+ /// 2. [`SamplerStage::RepetitionPenalty`]
19+ /// 3. [`SamplerStage::Temperature`], [`SamplerStage::DynamicTemperature`]
20+ /// 4. [`SamplerStage::TopK`]
21+ /// 5. [`SamplerStage::TailFree`]
22+ /// 6. [`SamplerStage::Typical`]
23+ /// 7. [`SamplerStage::TopP`], [`SamplerStage::MinP`]
2324#[ derive( Clone , Debug ) ]
2425#[ non_exhaustive]
2526pub enum SamplerStage {
@@ -103,16 +104,34 @@ pub enum SamplerStage {
103104 ///
104105 /// See: <https://www.trentonbricken.com/Tail-Free-Sampling/>
105106 TailFree ( f32 ) ,
107+
108+ /// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given
109+ /// grammar. Since this stage has to handle mutable state, an instance of this stage should
110+ /// only be used in one completion.
111+ ///
112+ /// See [`GrammarStage`] and [`LlamaGrammar`] for more information.
113+ Grammar ( GrammarStage ) ,
106114}
107115
108116impl SamplerStage {
117+ /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`].
118+ ///
119+ /// `start_position` indicates the token position to begin applying the grammar at. [`None`]
120+ /// indicates that the grammar begins at the end of context.
121+ pub fn from_grammar ( grammar : LlamaGrammar , start_position : Option < usize > ) -> Self {
122+ SamplerStage :: Grammar ( GrammarStage {
123+ grammar,
124+ accepted_up_to : start_position,
125+ } )
126+ }
127+
109128 /// Applies this [`SamplerStage`] to the provided token data array.
110129 ///
111130 /// Ensures that at least `min_keep` tokens remain after the
112131 /// [`SamplerStage`]'s are applied.
113132 #[ allow( clippy:: not_unsafe_ptr_arg_deref) ]
114133 pub fn apply (
115- & self ,
134+ & mut self ,
116135 context : * mut llama_context ,
117136 tokens : & [ Token ] ,
118137 mut candidates_p : llama_token_data_array ,
@@ -173,13 +192,48 @@ impl SamplerStage {
173192 SamplerStage :: TailFree ( z) => {
174193 llama_sample_tail_free ( context, p_ptr, * z, min_keep) ;
175194 }
195+ SamplerStage :: Grammar ( stage) => {
196+ candidates_p = stage. apply ( context, tokens, candidates_p, min_keep)
197+ }
176198 }
177199 }
178200
179201 candidates_p
180202 }
181203}
182204
/// Opaque internals for [`SamplerStage::Grammar`].
///
/// Pairs the grammar with a cursor into the token history so that repeated
/// applications only feed the grammar tokens generated since the last call.
#[derive(Clone, Debug)]
pub struct GrammarStage {
    // The grammar used to constrain candidate tokens during sampling.
    grammar: LlamaGrammar,
    // Index into the token history up to which tokens have already been
    // accepted by the grammar. `None` means "start at the current end of
    // context" (see `SamplerStage::from_grammar`).
    accepted_up_to: Option<usize>,
}
211+
212+ impl GrammarStage {
213+ fn apply (
214+ & mut self ,
215+ context : * mut llama_context ,
216+ tokens : & [ Token ] ,
217+ mut candidates_p : llama_token_data_array ,
218+ _min_keep : usize ,
219+ ) -> llama_token_data_array {
220+ // If `accepted_up_to` is `None`, assume that we should start at the end of context.
221+ let accepted_up_to = self . accepted_up_to . unwrap_or ( tokens. len ( ) ) ;
222+
223+ // Accept all new tokens until the end of context.
224+ for token in & tokens[ accepted_up_to..] {
225+ unsafe { llama_grammar_accept_token ( context, self . grammar . grammar . as_ptr ( ) , token. 0 ) }
226+ }
227+ self . accepted_up_to = Some ( tokens. len ( ) ) ;
228+
229+ // Apply grammar sampling to `candidates_p`.
230+ let p_ptr = addr_of_mut ! ( candidates_p) ;
231+ unsafe { llama_sample_grammar ( context, p_ptr, self . grammar . grammar . as_ptr ( ) ) } ;
232+
233+ candidates_p
234+ }
235+ }
236+
183237/// Determines how the next token is selected from the distribution produced by
184238/// the model and the [`SamplerStage`]'s.
185239#[ derive( Clone , Debug ) ]
@@ -232,7 +286,6 @@ impl TokenSelector {
pub struct StandardSampler {
    // Sampling stages applied in order to the candidate distribution.
    stages: Vec<SamplerStage>,
    // Lower bound on how many candidate tokens each stage must leave behind.
    min_keep: usize,
    // Strategy used to pick the final token after all stages have run.
    token_selector: TokenSelector,
}
238291
@@ -246,12 +299,10 @@ impl StandardSampler {
246299 pub fn new_softmax (
247300 stages : Vec < SamplerStage > ,
248301 min_keep : usize ,
249- grammar : Option < LlamaGrammar > ,
250302 ) -> StandardSampler {
251303 StandardSampler {
252304 stages,
253305 min_keep,
254- grammar : grammar,
255306 token_selector : TokenSelector :: Softmax ,
256307 }
257308 }
@@ -262,7 +313,6 @@ impl StandardSampler {
262313 StandardSampler {
263314 stages : Vec :: new ( ) ,
264315 min_keep : 0 ,
265- grammar : None ,
266316 token_selector : TokenSelector :: Greedy ,
267317 }
268318 }
@@ -279,7 +329,6 @@ impl StandardSampler {
279329 StandardSampler {
280330 stages,
281331 min_keep,
282- grammar : None ,
283332 token_selector : TokenSelector :: Mirostat {
284333 tau,
285334 eta,
@@ -300,7 +349,6 @@ impl StandardSampler {
300349 StandardSampler {
301350 stages,
302351 min_keep,
303- grammar : None ,
304352 token_selector : TokenSelector :: MirostatV2 {
305353 tau,
306354 eta,
@@ -325,7 +373,6 @@ impl Default for StandardSampler {
325373 SamplerStage :: MinP ( 0.05 ) ,
326374 SamplerStage :: Temperature ( 0.8 ) ,
327375 ] ,
328- grammar : None ,
329376 min_keep : 1 ,
330377 token_selector : TokenSelector :: Softmax ,
331378 }
@@ -340,25 +387,12 @@ impl Sampler for StandardSampler {
340387 tokens : & [ Token ] ,
341388 mut candidates_p : llama_token_data_array ,
342389 ) -> Token {
343- let p_ptr = addr_of_mut ! ( candidates_p) ;
344390 let min_keep = self . min_keep . max ( 1 ) ;
345391
346- // Note: We should sample grammar before applying other sampling stages.
347- if let Some ( grammar) = self . grammar . as_mut ( ) {
348- unsafe { llama_sample_grammar ( context, p_ptr, grammar. grammar . as_ptr ( ) ) } ;
349- }
350-
351- for stage in & self . stages {
392+ for stage in & mut self . stages {
352393 candidates_p = stage. apply ( context, tokens, candidates_p, min_keep) ;
353394 }
354395
355- let token = self . token_selector . select ( context, candidates_p) ;
356-
357- // Note: We must accept the token into the grammar after sampling if a grammar is provided.
358- if let Some ( grammar) = self . grammar . as_mut ( ) {
359- unsafe { llama_grammar_accept_token ( context, grammar. grammar . as_ptr ( ) , token. 0 ) }
360- }
361-
362- token
396+ self . token_selector . select ( context, candidates_p)
363397 }
364398}
0 commit comments