diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index 958c9cacf51..e0edd16c70b 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -193,6 +193,7 @@ static struct llama_sampler_i common_reasoning_budget_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { diff --git a/include/llama.h b/include/llama.h index 75095b22d08..7acef220c43 100644 --- a/include/llama.h +++ b/include/llama.h @@ -387,6 +387,7 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + uint32_t n_sampling_outputs_max; }; struct llama_model_tensor_override { @@ -1259,6 +1260,9 @@ extern "C" { // called before graph execution to set inputs for the current ubatch void (*backend_set_input)(struct llama_sampler * smpl); + + // returns the number of ggml tensors that a backend sampler needs. + uint32_t (*backend_n_nodes)(const struct llama_sampler * smpl); }; struct llama_sampler { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3cc8ffa6668..deafdc3259b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -85,6 +85,8 @@ llama_context::llama_context( cparams.ctx_type = params.ctx_type; + cparams.n_sampling_outputs_max = params.n_sampling_outputs_max; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -1486,8 +1488,8 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } -static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { - std::map seq_to_row; +static std::map> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::map> seq_to_row; // how many output tokens we have seen so far for this ubatch. uint32_t local = 0; for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { @@ -1498,96 +1500,111 @@ static std::map build_seq_to_output_row(const llama_ubat const llama_seq_id seq_id = ubatch.seq_id[i][0]; // row_offset is the number of output tokens before this ubatch. - seq_to_row[seq_id] = row_offset + local; + seq_to_row[seq_id].push_back(row_offset + local); ++local; } return seq_to_row; } static void copy_tensor_async_ints( - const std::map & tensor_map, + const std::map> & tensor_map, const buffer_view & sampled, - const std::map & seq_to_row, + const std::map> & seq_to_row, ggml_backend_sched_t sched) { if (!sampled.has_data()) { return; } - for (const auto & [seq_id, tensor] : tensor_map) { + for (const auto & [seq_id, tensors] : tensor_map) { auto it = seq_to_row.find(seq_id); if (it == seq_to_row.end()) { continue; } - const uint32_t row = it->second; - GGML_ASSERT(row < sampled.size); + const std::vector & rows = it->second; + + for (size_t i = 0; i < tensors.size(); ++i) { + const uint32_t row = rows[i]; + ggml_tensor * tensor = tensors[i]; - GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); + GGML_ASSERT(row < sampled.size); + GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); - ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); + } } } static void copy_tensor_async_floats( - const std::map & tensor_map, + const std::map> & tensor_map, const buffer_view & dst, size_t stride, std::vector & counts, - const std::map & seq_to_row, + const std::map> & seq_to_row, ggml_backend_sched_t sched) { if (!dst.has_data()) { return; } - for (const auto & [seq_id, tensor] : tensor_map) { + for (const auto & [seq_id, tensors] : tensor_map) { auto it = seq_to_row.find(seq_id); if (it == seq_to_row.end()) { continue; } - const uint32_t row = it->second; - GGML_ASSERT(row < counts.size()); + const std::vector & rows = it->second; + + for (size_t i = 0; i < tensors.size(); ++i) { + const uint32_t row = rows[i]; + ggml_tensor * tensor = tensors[i]; - GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); + GGML_ASSERT(row < counts.size()); + GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); - ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst.data + (size_t) row * stride; - ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst.data + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); - // Update the actual number of logits/probabilities that were written for this row. - counts[row] = ggml_nelements(tensor); + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } } } static void copy_tensor_async_candidates( - const std::map & tensor_map, + const std::map> & tensor_map, const buffer_view & dst, size_t stride, std::vector & counts, - const std::map & seq_to_row, + const std::map> & seq_to_row, ggml_backend_sched_t sched) { if (!dst.has_data()) { return; } - for (const auto & [seq_id, tensor] : tensor_map) { + for (const auto & [seq_id, tensors] : tensor_map) { auto it = seq_to_row.find(seq_id); if (it == seq_to_row.end()) { continue; } - const uint32_t row = it->second; - GGML_ASSERT(row < counts.size()); + const std::vector & rows = it->second; + + for (size_t i = 0; i < tensors.size(); ++i) { + const uint32_t row = rows[i]; + ggml_tensor * tensor = tensors[i]; - GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); + GGML_ASSERT(row < counts.size()); + GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); - ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst.data + (size_t) row * stride; - ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst.data + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); - // Update the actual number of candidates that were written. - counts[row] = ggml_nelements(tensor); + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } } } @@ -1635,30 +1652,6 @@ int llama_context::decode(const llama_batch & batch_inp) { const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; - // TODO: avoid this workaround in the future - if (has_samplers && batch_inp.logits) { - std::vector seq_output_count(n_seq_max, 0); - - for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { - if (batch_inp.logits[i] == 0) { - continue; - } - - const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; - - for (int32_t s = 0; s < ns; ++s) { - const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; - - seq_output_count[seq_id]++; - if (seq_output_count[seq_id] > 1) { - LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", - __func__, seq_id, seq_output_count[seq_id]); - return -1; - } - } - } - } - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; @@ -2192,14 +2185,35 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { + uint32_t res; if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { - return std::max(n_tokens * 40, 32u * model.n_tensors()); + res = std::max(n_tokens * 40, 32u * model.n_tensors()); + } else { + res = std::max(1024u, 8u*model.n_tensors()); + for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } } - uint32_t res = std::max(1024u, 8u*model.n_tensors()); - for (const auto & lora : model.loras) { - res += lora->get_n_nodes(); + + // Account for backend sampling with multiple outputs per sequence. + uint32_t sampling_nodes = 0; + if (!sampling.samplers.empty()) { + uint32_t backend_n_nodes = 0; + for (const auto & [seq_id, sampler] : sampling.samplers) { + if (sampler->iface->backend_n_nodes) { + backend_n_nodes += sampler->iface->backend_n_nodes(sampler); + } + } + + const uint32_t sampling_outputs = std::min(n_tokens, cparams.n_sampling_outputs_max); + const uint32_t max_samplers = cparams.n_seq_max; + const uint32_t n_active_samplers = (uint32_t) sampling.samplers.size(); + const uint32_t logits_t_pad = 1; + // each active sampler contributes one logits_seq view plus its backend tensors per output + sampling_nodes = (backend_n_nodes + n_active_samplers) * sampling_outputs * max_samplers + logits_t_pad; } - return res; + + return res + sampling_nodes; } llm_graph_result * llama_context::get_gf_res_reserve() const { @@ -3353,6 +3367,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.n_sampling_outputs_max =*/ 32, }; return result; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 20ec59fe335..ff6ec10436e 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -46,6 +46,8 @@ struct llama_cparams { enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; + uint32_t n_sampling_outputs_max; // max outputs per sequence for backend sampling + ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 31cf41a1c2d..0e60670cbbd 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -851,24 +851,24 @@ void llm_graph_result::set_outputs() { if (t_h_pre_norm != nullptr) { ggml_set_output(t_h_pre_norm); } - for (auto & [seq_id, t] : t_sampled) { - if (t != nullptr) { - ggml_set_output(t); + for (auto & [seq_id, tensors] : t_sampled) { + for (ggml_tensor * tensor : tensors) { + ggml_set_output(tensor); } } - for (auto & [seq_id, t] : t_sampled_probs) { - if (t != nullptr) { - ggml_set_output(t); + for (auto & [seq_id, tensors] : t_sampled_probs) { + for (ggml_tensor * tensor : tensors) { + ggml_set_output(tensor); } } - for (auto & [seq_id, t] : t_sampled_logits) { - if (t != nullptr) { - ggml_set_output(t); + for (auto & [seq_id, tensors] : t_sampled_logits) { + for (ggml_tensor * tensor : tensors) { + ggml_set_output(tensor); } } - for (auto & [seq_id, t] : t_candidates) { - if (t != nullptr) { - ggml_set_output(t); + for (auto & [seq_id, tensors] : t_candidates) { + for (ggml_tensor * tensor : tensors) { + ggml_set_output(tensor); } } } @@ -2828,13 +2828,13 @@ void llm_graph_context::build_sampling() const { auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); - std::map seq_to_logit_row; + std::map> seq_to_logit_rows; int32_t logit_row_idx = 0; for (uint32_t i = 0; i < ubatch.n_tokens; i++) { if (ubatch.output[i]) { llama_seq_id seq_id = ubatch.seq_id[i][0]; - seq_to_logit_row[seq_id] = logit_row_idx; + seq_to_logit_rows[seq_id].push_back(logit_row_idx); logit_row_idx++; } } @@ -2847,48 +2847,66 @@ void llm_graph_context::build_sampling() const { // this is important in order to minimize graph reallocations ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); + // During graph reservation, n_outputs can be very large (for example 512 for worst-case PP). + // We cap it to a user-configurable maximum since typical multi output scenarios use far fewer. + const uint32_t max_outputs = std::min(n_outputs, cparams.n_sampling_outputs_max); + for (const auto & [seq_id, sampler] : samplers) { - const auto it = seq_to_logit_row.find(seq_id); + const auto row_it = seq_to_logit_rows.find(seq_id); + const bool sampler_is_active = row_it != seq_to_logit_rows.end(); - // inactive samplers always work on the first row - const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; - const int i_out = it != seq_to_logit_row.end() ? 1 : 0; + // Always build samplers for all possible outputs even if the sampler is + // not active (the sampler's sequence id is not in the current ubatch). + for (uint32_t i = 0; i < max_outputs; ++i) { + const bool real_output = sampler_is_active && i < row_it->second.size(); - ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); - ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + const int32_t row_idx = real_output ? row_it->second[i] : 0; + const int i_out = real_output ? 1 : 0; - struct llama_sampler_data data = { - /*.logits =*/ logits_seq, - /*.probs =*/ nullptr, - /*.sampled =*/ nullptr, - /*.candidates =*/ nullptr, - }; + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d_%d", seq_id, i); - assert(sampler->iface->backend_apply); - sampler->iface->backend_apply(sampler, ctx0, gf, &data); + struct llama_sampler_data data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; - if (data.sampled != nullptr) { - res->t_sampled[seq_id] = data.sampled; - outs[1] = data.sampled; - ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); - } + assert(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); - if (data.probs != nullptr) { - res->t_sampled_probs[seq_id] = data.probs; - outs[1] = data.probs; - ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); - } + if (data.sampled != nullptr) { + if (real_output) { + res->t_sampled[seq_id].push_back(data.sampled); + } + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } - if (data.logits != nullptr) { - res->t_sampled_logits[seq_id] = data.logits; - outs[1] = data.logits; - ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); - } + if (data.probs != nullptr) { + if (real_output) { + res->t_sampled_probs[seq_id].push_back(data.probs); + } + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } + + if (data.logits != nullptr) { + if (real_output) { + res->t_sampled_logits[seq_id].push_back(data.logits); + } + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } - if (data.candidates != nullptr) { - res->t_candidates[seq_id] = data.candidates; - outs[1] = data.candidates; - ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + if (data.candidates != nullptr) { + if (real_output) { + res->t_candidates[seq_id].push_back(data.candidates); + } + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); + } } } diff --git a/src/llama-graph.h b/src/llama-graph.h index bf6778237e6..4e1900d7995 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -677,10 +677,10 @@ class llm_graph_result { ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm - std::map t_sampled_logits; - std::map t_candidates; - std::map t_sampled; - std::map t_sampled_probs; + std::map> t_sampled_logits; + std::map> t_candidates; + std::map> t_sampled; + std::map> t_sampled_probs; std::vector inputs; diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index 9bbc5dbde24..4f33da877c4 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -507,6 +507,7 @@ static struct llama_sampler_i llama_sampler_empty_i = { /* .backend_accept = */ llama_sampler_empty_backend_accept, /* .backend_apply = */ llama_sampler_empty_backend_apply, /* .backend_set_input = */ llama_sampler_empty_backend_set_input, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_empty(const char * name) { @@ -776,6 +777,17 @@ static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { } } +static uint32_t llama_sampler_chain_backend_n_nodes(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_chain *) smpl->ctx; + uint32_t total = 0; + for (const auto & s : sctx->samplers) { + if (s.ptr->iface->backend_n_nodes) { + total += s.ptr->iface->backend_n_nodes(s.ptr); + } + } + return total; +} + static struct llama_sampler_i llama_sampler_chain_i = { /* .name = */ llama_sampler_chain_name, /* .accept = */ llama_sampler_chain_accept, @@ -787,6 +799,7 @@ static struct llama_sampler_i llama_sampler_chain_i = { /* .backend_accept = */ llama_sampler_chain_backend_accept, /* .backend_apply = */ llama_sampler_chain_backend_apply, /* .backend_set_input = */ llama_sampler_chain_backend_set_input, + /* .backend_n_nodes = */ llama_sampler_chain_backend_n_nodes, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -995,6 +1008,10 @@ static void llama_sampler_greedy_backend_apply( data->sampled = curl; } +static uint32_t llama_sampler_greedy_backend_n_nodes(const struct llama_sampler *) { + return 1; +} + static struct llama_sampler_i llama_sampler_greedy_i = { /* .name = */ llama_sampler_greedy_name, /* .accept = */ nullptr, @@ -1006,6 +1023,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_greedy_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_greedy_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_greedy() { @@ -1180,6 +1198,9 @@ static void llama_sampler_dist_backend_apply( struct ggml_tensor * idxf = ggml_sum(ctx, mask); ggml_set_name(idxf, "dist_index_f32"); + // Clamp to prevent out-of-bounds access when computing the index. + idxf = ggml_clamp(ctx, idxf, 1.0f, mask->ne[0]); + // Use ggml_scale_bias to scale the index value by -1 and then add the size // of the mask to that value so we get the correct index ((-1 * idxf) + n). struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); @@ -1214,6 +1235,10 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); } +static uint32_t llama_sampler_dist_backend_n_nodes(const struct llama_sampler *) { + return 11; +} + static struct llama_sampler_i llama_sampler_dist_i = { /* .name = */ llama_sampler_dist_name, /* .accept = */ nullptr, @@ -1225,6 +1250,7 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_dist_backend_apply, /* .backend_set_input = */ llama_sampler_dist_backend_set_input, + /* .backend_n_nodes = */ llama_sampler_dist_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -1305,6 +1331,10 @@ static void llama_sampler_top_k_backend_apply( GGML_UNUSED(gf); } +static uint32_t llama_sampler_top_k_backend_n_nodes(const struct llama_sampler *) { + return 7; +} + static struct llama_sampler_i llama_sampler_top_k_i = { /* .name = */ llama_sampler_top_k_name, /* .accept = */ nullptr, @@ -1316,6 +1346,7 @@ static struct llama_sampler_i llama_sampler_top_k_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_top_k_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_top_k_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { @@ -1497,6 +1528,10 @@ static void llama_sampler_top_p_backend_apply( GGML_UNUSED(gf); } +static uint32_t llama_sampler_top_p_backend_n_nodes(const struct llama_sampler *) { + return 20; +} + static struct llama_sampler_i llama_sampler_top_p_i = { /* .name = */ llama_sampler_top_p_name, /* .accept = */ nullptr, @@ -1508,6 +1543,7 @@ static struct llama_sampler_i llama_sampler_top_p_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_top_p_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_top_p_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -1654,6 +1690,10 @@ static void llama_sampler_min_p_backend_apply( GGML_UNUSED(gf); } +static uint32_t llama_sampler_min_p_backend_n_nodes(const struct llama_sampler *) { + return 8; +} + static struct llama_sampler_i llama_sampler_min_p_i = { /* .name = */ llama_sampler_min_p_name, /* .accept = */ nullptr, @@ -1665,6 +1705,7 @@ static struct llama_sampler_i llama_sampler_min_p_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_min_p_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_min_p_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -1775,6 +1816,7 @@ static struct llama_sampler_i llama_sampler_typical_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -1868,6 +1910,10 @@ static void llama_sampler_temp_backend_apply( llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); } +static uint32_t llama_sampler_temp_backend_n_nodes(const struct llama_sampler *) { + return 5; +} + static struct llama_sampler_i llama_sampler_temp_i = { /* .name = */ llama_sampler_temp_name, /* .accept = */ nullptr, @@ -1879,6 +1925,7 @@ static struct llama_sampler_i llama_sampler_temp_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_temp_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_temp_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -2065,6 +2112,10 @@ static void llama_sampler_temp_ext_backend_apply( data->logits = scaled_logits; } +static uint32_t llama_sampler_temp_ext_backend_n_nodes(const struct llama_sampler *) { + return 12; +} + static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .name = */ llama_sampler_temp_ext_name, /* .accept = */ nullptr, @@ -2076,6 +2127,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_temp_ext_backend_apply, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ llama_sampler_temp_ext_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -2183,6 +2235,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { @@ -2302,6 +2355,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -2406,6 +2460,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -2527,6 +2582,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -2738,6 +2794,7 @@ static struct llama_sampler_i llama_sampler_penalties_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -2837,6 +2894,7 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { @@ -3177,6 +3235,7 @@ static struct llama_sampler_i llama_sampler_dry_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -3397,6 +3456,7 @@ static struct llama_sampler_i llama_sampler_adaptive_p_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ nullptr, /* .backend_set_input = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_adaptive_p( @@ -3555,6 +3615,10 @@ static bool llama_sampler_logit_bias_backend_init( return true; } +static uint32_t llama_sampler_logit_bias_backend_n_nodes(const struct llama_sampler *) { + return 7; +} + static struct llama_sampler_i llama_sampler_logit_bias_i = { /* .name = */ llama_sampler_logit_bias_name, /* .accept = */ nullptr, @@ -3566,6 +3630,7 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { /* .backend_accept = */ nullptr, /* .backend_apply = */ llama_sampler_logit_bias_backend_apply, /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input, + /* .backend_n_nodes = */ llama_sampler_logit_bias_backend_n_nodes, }; struct llama_sampler * llama_sampler_init_logit_bias( @@ -3809,6 +3874,7 @@ static struct llama_sampler_i llama_sampler_infill_i = { /* .backend_accept = */ nullptr, /* .backend_set_input = */ nullptr, /* .backend_init = */ nullptr, + /* .backend_n_nodes = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index 58361ae80ae..5d247cd6566 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -969,7 +969,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) { printf("backend-cpu mixed batch test PASSED\n"); } -static void test_backend_max_outputs(const test_params & params) { +static void test_backend_multiple_outputs(const test_params & params) { const int seq_id = 0; const int32_t seed = 88; @@ -995,17 +995,32 @@ static void test_backend_max_outputs(const test_params & params) { } for (size_t i = 0; i < tokens.size(); i++) { - // set all tokens as output to trigger error + // set all tokens as output to get multiple outputs for a single sequence. common_batch_add(batch, tokens[i], i, { seq_id }, true); } - printf(">>> test_max_outputs expected error start:\n"); const int ret = llama_decode(test_ctx.ctx.get(), batch); - GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence"); - printf("<<< test_max_outputs expected error end.\n"); + if (ret != 0) { + GGML_ASSERT(false && "Failed to decode sequence with multiple outputs"); + } + + std::vector sampled_tokens; + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), i); + const std::string token_str = test_ctx.token_to_piece(token, false); + //printf("Position %d: token id=%d, string='%s'\n", i, token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + sampled_tokens.push_back(token); + } + } + + GGML_ASSERT((int)sampled_tokens.size() == batch.n_tokens); + printf("Sampled %zu tokens for sequence %d\n", sampled_tokens.size(), seq_id); + llama_batch_free(batch); - printf("backend max outputs test PASSED\n"); + printf("backend multiple outputs test PASSED\n"); } struct backend_test_case { @@ -1024,7 +1039,7 @@ static const backend_test_case BACKEND_TESTS[] = { { "dist", test_backend_dist_sampling, true }, { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true }, { "set_sampler", test_backend_set_sampler, true }, - { "max_outputs", test_backend_max_outputs, true }, + { "multiple_outputs",test_backend_multiple_outputs, true }, { "mixed", test_backend_mixed_sampling, true }, { "min_p", test_backend_min_p_sampling, true }, { "cpu_mixed", test_backend_cpu_mixed_batch, true },