Skip to content
1 change: 0 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
Expand Down
24 changes: 19 additions & 5 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}

void common_sampler_free(struct common_sampler * gsmpl) {
    // Release all resources owned by the sampler.
    // Safe to call with nullptr (no-op), matching the null-tolerant
    // convention of the other common_sampler_* entry points.
    if (!gsmpl) {
        return;
    }

    // free the owned sub-samplers before the wrapper itself
    llama_sampler_free(gsmpl->grmr);
    llama_sampler_free(gsmpl->chain);

    delete gsmpl;
}

void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
}

const auto tm = gsmpl->tm();

if (gsmpl->grmr && accept_grammar) {
Expand All @@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
}

void common_sampler_reset(struct common_sampler * gsmpl) {
    // Reset the sampler's internal state; a null sampler is a no-op.
    if (gsmpl) {
        gsmpl->reset();
    }
}

Expand Down Expand Up @@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}

struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    // Expose the underlying llama_sampler chain (non-owning);
    // returns nullptr when given a null sampler.
    return gsmpl ? gsmpl->chain : nullptr;
}

Expand Down
1 change: 0 additions & 1 deletion examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
sampler_configs.push_back({ i, smpl });
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
Expand Down
1 change: 0 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,6 @@ extern "C" {
// [EXPERIMENTAL]
// attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible
// note: changing the samplers of a context can cause graph reallocations and degraded performance
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);

// mirror of llama_sampler_i:
Expand Down
Loading
Loading