@@ -149,11 +149,12 @@ static void sampler_queue(
149149 }
150150}
151151
152- llama_token llama_sampling_sample (
152+ static llama_token llama_sampling_sample_impl (
153153 struct llama_sampling_context * ctx_sampling,
154154 struct llama_context * ctx_main,
155155 struct llama_context * ctx_cfg,
156- const int idx) {
156+ const int idx,
157+ bool is_resampling) { // Add a parameter to indicate if we are resampling
157158 const llama_sampling_params & params = ctx_sampling->params ;
158159
159160 const int n_vocab = llama_n_vocab (llama_get_model (ctx_main));
@@ -173,8 +174,17 @@ llama_token llama_sampling_sample(
173174
174175 llama_token id = 0 ;
175176
177+ // Get a pointer to the logits
176178 float * logits = llama_get_logits_ith (ctx_main, idx);
177179
180+ // Declare original_logits at the beginning of the function scope
181+ std::vector<float > original_logits;
182+
183+ if (!is_resampling) {
184+ // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
185+ original_logits = std::vector<float >(logits, logits + llama_n_vocab (llama_get_model (ctx_main)));
186+ }
187+
178188 // apply params.logit_bias map
179189 for (auto it = params.logit_bias .begin (); it != params.logit_bias .end (); it++) {
180190 logits[it->first ] += it->second ;
@@ -210,7 +220,8 @@ llama_token llama_sampling_sample(
210220 }
211221 }
212222
213- if (ctx_sampling->grammar != NULL ) {
223+ // If we are in the resampling phase, apply grammar checks before sampling logic
224+ if (is_resampling && ctx_sampling->grammar != NULL ) {
214225 llama_sample_grammar (ctx_main, &cur_p, ctx_sampling->grammar );
215226 }
216227
@@ -252,9 +263,40 @@ llama_token llama_sampling_sample(
252263 }
253264 }
254265
266+ if (ctx_sampling->grammar != NULL && !is_resampling) {
267+ // Create an array with a single token data element for the sampled id
268+ llama_token_data single_token_data = {id, logits[id], 0 .0f };
269+ llama_token_data_array single_token_data_array = { &single_token_data, 1 , false };
270+
271+ // Apply grammar constraints to the single token
272+ llama_sample_grammar (ctx_main, &single_token_data_array, ctx_sampling->grammar );
273+
274+ // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
275+ bool is_valid = single_token_data_array.data [0 ].logit != -INFINITY;
276+
277+ // If the token is not valid according to the grammar, perform resampling
278+ if (!is_valid) {
279+ LOG (" Resampling because token %d: '%s' does not meet grammar rules\n " , id, llama_token_to_piece (ctx_main, id).c_str ());
280+
281+ // Restore logits from the copy
282+ std::copy (original_logits.begin (), original_logits.end (), logits);
283+
284+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, true ); // Pass true for is_resampling
285+ }
286+ }
287+
255288 return id;
256289}
257290
291+ llama_token llama_sampling_sample (
292+ struct llama_sampling_context * ctx_sampling,
293+ struct llama_context * ctx_main,
294+ struct llama_context * ctx_cfg,
295+ const int idx) {
296+ // Call the implementation function with is_resampling set to false by default
297+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, false );
298+ }
299+
258300void llama_sampling_accept (
259301 struct llama_sampling_context * ctx_sampling,
260302 struct llama_context * ctx_main,
0 commit comments