Skip to content

Commit e7e0f93

Browse files
committed
Make changes more readable
1 parent 7230cc6 commit e7e0f93

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

crates/llama_cpp/src/standard_sampler.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ impl SamplerStage {
121121
pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self {
122122
SamplerStage::Grammar(GrammarStage {
123123
grammar,
124-
accepted_to: start_position,
124+
accepted_up_to: start_position,
125125
})
126126
}
127127

@@ -193,7 +193,7 @@ impl SamplerStage {
193193
llama_sample_tail_free(context, p_ptr, *z, min_keep);
194194
}
195195
SamplerStage::Grammar(stage) => {
196-
stage.apply(context, tokens, candidates_p, min_keep)
196+
candidates_p = stage.apply(context, tokens, candidates_p, min_keep)
197197
}
198198
}
199199
}
@@ -206,7 +206,7 @@ impl SamplerStage {
206206
#[derive(Clone, Debug)]
207207
pub struct GrammarStage {
208208
grammar: LlamaGrammar,
209-
accepted_to: Option<usize>,
209+
accepted_up_to: Option<usize>,
210210
}
211211

212212
impl GrammarStage {
@@ -216,15 +216,21 @@ impl GrammarStage {
216216
tokens: &[Token],
217217
mut candidates_p: llama_token_data_array,
218218
_min_keep: usize,
219-
) {
220-
let accepted_to = self.accepted_to.unwrap_or(tokens.len());
221-
for token in &tokens[accepted_to..] {
219+
) -> llama_token_data_array {
220+
// If `accepted_up_to` is `None`, assume that we should start at the end of context.
221+
let accepted_up_to = self.accepted_up_to.unwrap_or(tokens.len());
222+
223+
// Accept all new tokens until the end of context.
224+
for token in &tokens[accepted_up_to..] {
222225
unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) }
223226
}
224-
self.accepted_to = Some(tokens.len());
227+
self.accepted_up_to = Some(tokens.len());
225228

229+
// Apply grammar sampling to `candidates_p`.
226230
let p_ptr = addr_of_mut!(candidates_p);
227231
unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) };
232+
233+
candidates_p
228234
}
229235
}
230236

0 commit comments

Comments
 (0)