|  | 
| 1 |  | -#define LLAMA_API_INTERNAL | 
| 2 | 1 | #include "sampling.h" | 
|  | 2 | + | 
| 3 | 3 | #include <random> | 
| 4 | 4 | 
 | 
| 5 |  | -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { | 
|  | 5 | +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_context * ctx, llama_seq_id seq_id) { | 
| 6 | 6 |     struct llama_sampling_context * result = new llama_sampling_context(); | 
| 7 | 7 | 
 | 
| 8 | 8 |     result->params  = params; | 
|  | 9 | +    result->seq_id  = seq_id; | 
|  | 10 | +    result->ctx     = ctx; | 
| 9 | 11 |     result->grammar = nullptr; | 
| 10 | 12 | 
 | 
| 11 | 13 |     // if there is a grammar, parse it | 
| @@ -81,7 +83,7 @@ void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t s | 
| 81 | 83 |     if (seed == LLAMA_DEFAULT_SEED) { | 
| 82 | 84 |         seed = std::random_device{}(); | 
| 83 | 85 |     } | 
| 84 |  | -    ctx->rng.seed(seed); | 
|  | 86 | +    llama_set_rng_seed_seq(ctx->ctx, seed, ctx->seq_id); | 
| 85 | 87 | } | 
| 86 | 88 | 
 | 
| 87 | 89 | void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { | 
| @@ -271,10 +273,10 @@ static llama_token llama_sampling_sample_impl( | 
| 271 | 273 |                   bool is_resampling) { | 
| 272 | 274 |     const llama_sampling_params & params = ctx_sampling->params; | 
| 273 | 275 | 
 | 
| 274 |  | -    const float   temp            = params.temp; | 
| 275 |  | -    const int     mirostat        = params.mirostat; | 
| 276 |  | -    const float   mirostat_tau    = params.mirostat_tau; | 
| 277 |  | -    const float   mirostat_eta    = params.mirostat_eta; | 
|  | 276 | +    const float temp         = params.temp; | 
|  | 277 | +    const int   mirostat     = params.mirostat; | 
|  | 278 | +    const float mirostat_tau = params.mirostat_tau; | 
|  | 279 | +    const float mirostat_eta = params.mirostat_eta; | 
| 278 | 280 | 
 | 
| 279 | 281 |     std::vector<float> original_logits; | 
| 280 | 282 |     auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); | 
| @@ -304,7 +306,7 @@ static llama_token llama_sampling_sample_impl( | 
| 304 | 306 | 
 | 
| 305 | 307 |             sampler_queue(ctx_main, params, cur_p, min_keep); | 
| 306 | 308 | 
 | 
| 307 |  | -            id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); | 
|  | 309 | +            id = llama_sample_token_seq(ctx_main, &cur_p, ctx_sampling->seq_id); | 
| 308 | 310 | 
 | 
| 309 | 311 |             //{ | 
| 310 | 312 |             //    const int n_top = 10; | 
|  | 
0 commit comments