Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
Expand Down
24 changes: 19 additions & 5 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}

// Frees a common_sampler: releases the grammar sampler and the sampler chain,
// then deletes the wrapper itself. A null handle is a safe no-op.
// NOTE(review): this span was a diff scrape that interleaved the removed
// pre-change lines with the added ones, leaving unbalanced braces and
// duplicated free/delete calls; resolved to the post-change guard-clause form.
void common_sampler_free(struct common_sampler * gsmpl) {
    if (!gsmpl) {
        return;
    }

    llama_sampler_free(gsmpl->grmr);
    llama_sampler_free(gsmpl->chain);

    delete gsmpl;
}

void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (!gsmpl) {
return;
}

const auto tm = gsmpl->tm();

if (gsmpl->grmr && accept_grammar) {
Expand All @@ -355,6 +361,10 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
}

// Resets the sampler state; tolerates and ignores a null handle.
void common_sampler_reset(struct common_sampler * gsmpl) {
    if (gsmpl) {
        gsmpl->reset();
    }
}

Expand Down Expand Up @@ -415,6 +425,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}

// Returns the underlying llama_sampler chain, or nullptr when gsmpl is null.
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    return gsmpl ? gsmpl->chain : nullptr;
}

Expand Down
1 change: 0 additions & 1 deletion examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
sampler_configs.push_back({ i, smpl });
}

// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
Expand Down
1 change: 0 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1255,7 +1255,6 @@ extern "C" {
// [EXPERIMENTAL]
// attach a sampler to the context
// note: prefer initializing the context with llama_context_params.samplers when possible
// note: changing the samplers of a context can cause graph reallocations and degraded performance
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);

// mirror of llama_sampler_i:
Expand Down
Loading