Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,6 @@ static llama_token llama_sampling_sample_impl(
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else if (adaptive_target >= 0.0f && ctx_sampling->adapt_p_ctx!=nullptr) {
// adaptive p sampling
llama_prep_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
sampler_queue(ctx_main, params, ctx_sampling, cur_p, std::max(1, params.min_keep));
id = llama_sample_token_adaptive_p(ctx_main, &cur_p, ctx_sampling->adapt_p_ctx);
} else {
Expand Down Expand Up @@ -496,6 +495,11 @@ static llama_token_data_array llama_sampling_prepare_impl(
*original_logits = {logits, logits + n_vocab};
}

if ((params.temp > 0) && (params.mirostat == 0) && (params.adaptive_target >= 0) && (ctx_sampling->adapt_p_ctx != nullptr)) {
// collect original probability before logit bias is applied
llama_prep_adaptive_p(ctx_main, logits, ctx_sampling->adapt_p_ctx);
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
Expand Down
2 changes: 1 addition & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ LLAMA_API struct llama_grammar* llama_sampler_init_grammar_lazy_patterns(
const uint32_t seed);

void llama_prep_adaptive_p(struct llama_context * ctx,
llama_token_data_array * candidates,
float * logits,
struct llama_sampler_adaptive_p * adapt_p_ctx);

/// @details Adaptive p sampler described in https://github.com/MrJackSpade/adaptive-p-docs/blob/main/README.md
Expand Down
13 changes: 4 additions & 9 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ void llama_sample_adaptive_p_impl(struct llama_sampling * ctx, llama_token_data_

void llama_prep_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float * logits,
struct llama_sampler_adaptive_p * adapt_p_ctx) {
if (adapt_p_ctx->updt_w_cur) {
// update with current probability, original not needed
Expand All @@ -1178,16 +1178,11 @@ void llama_prep_adaptive_p_impl(
constexpr float kDelta = 30.0f; //16.6f;
auto t_start = ggml_time_us();
auto & orig_prob = adapt_p_ctx->orig_prob;
if (candidates->size != orig_prob.size() || candidates->sorted) {
LLAMA_LOG_ERROR("%s: this function must be called before any other sampler has been applied\n", __func__);
LLAMA_LOG_ERROR("%s: the sampler has been initialized with a vocabulary of %zu, but is being called with %zu candidates\n",
__func__, orig_prob.size(), candidates->size);
GGML_ABORT("Bad candidates in adaptive_p sampler");
}

std::copy(logits, logits + orig_prob.size(), orig_prob.begin());

float max_logit = -INFINITY;
for (int j = 0; j < int(candidates->size); ++j) {
orig_prob[j] = candidates->data[j].logit;
for (int j = 0; j < int(orig_prob.size()); ++j) {
max_logit = std::max(max_logit, orig_prob[j]);
}
adapt_p_ctx->cum_orig_prob = iqk_exp_with_thresh(orig_prob.size(), orig_prob.data(), max_logit, max_logit - kDelta);
Expand Down
2 changes: 1 addition & 1 deletion src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ struct llama_sampler_adaptive_p * llama_init_adaptive_p_impl(int n_vocab,

void llama_prep_adaptive_p_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
float * logits,
struct llama_sampler_adaptive_p * adapt_p_ctx);

void llama_sample_adaptive_p_impl(
Expand Down
4 changes: 2 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8169,8 +8169,8 @@ void llama_sample_adaptive_p(llama_context * ctx,
llama_sample_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
}

void llama_prep_adaptive_p(struct llama_context * ctx, llama_token_data_array * candidates, struct llama_sampler_adaptive_p * adapt_p_ctx) {
llama_prep_adaptive_p_impl(&ctx->sampling, candidates, adapt_p_ctx);
void llama_prep_adaptive_p(struct llama_context * ctx, float * logits, struct llama_sampler_adaptive_p * adapt_p_ctx) {
llama_prep_adaptive_p_impl(&ctx->sampling, logits, adapt_p_ctx);
}


Expand Down