diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 51bff1c44bf..75c6366c7fa 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -81,6 +81,8 @@ add_library(${TARGET} STATIC preset.cpp preset.h regex-partial.cpp + reasoning-budget.cpp + reasoning-budget.h regex-partial.h sampling.cpp sampling.h diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp new file mode 100644 index 00000000000..a55e4f509d4 --- /dev/null +++ b/common/reasoning-budget.cpp @@ -0,0 +1,219 @@ +#include "reasoning-budget.h" +#include "common.h" +#include "unicode.h" + +#include "log.h" + +#include +#include +#include +#include + +struct token_matcher { + std::vector tokens; + size_t pos = 0; + + bool advance(llama_token token) { + if (tokens.empty()) { + return false; + } + + if (token == tokens[pos]) { + pos++; + if (pos >= tokens.size()) { + pos = 0; + return true; + } + } else { + pos = 0; + if (token == tokens[0]) { + pos = 1; + } + } + return false; + } + + void reset() { pos = 0; } +}; + +struct common_reasoning_budget_ctx { + const llama_vocab * vocab; + + token_matcher start_matcher; + token_matcher end_matcher; + std::vector forced_tokens; + + int32_t budget; // maximum tokens in reasoning block + int32_t remaining; // tokens remaining in budget + + common_reasoning_budget_state state; + + // for forcing + size_t force_pos; // next position in forced_tokens to force +}; + +static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { + return "reasoning-budget"; +} + +static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + switch (ctx->state) { + case REASONING_BUDGET_IDLE: + { + if (ctx->start_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_COUNTING; + ctx->remaining = ctx->budget; + LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); + + if (ctx->remaining <= 0) { + ctx->state = 
REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + LOG_INF("reasoning-budget: budget=0, forcing immediately\n"); + } + } + break; + } + case REASONING_BUDGET_COUNTING: + case REASONING_BUDGET_WAITING_UTF8: + { + if (ctx->end_matcher.advance(token)) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: deactivated (natural end)\n"); + break; + } + + bool utf8_complete = true; + if (ctx->vocab != nullptr) { + const std::string piece = common_token_to_piece(ctx->vocab, token, false); + utf8_complete = common_utf8_is_complete(piece); + } + + if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); + } + } else if (ctx->state == REASONING_BUDGET_COUNTING) { + ctx->remaining--; + if (ctx->remaining <= 0) { + if (utf8_complete) { + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n"); + } else { + ctx->state = REASONING_BUDGET_WAITING_UTF8; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); + } + } + } + break; + } + case REASONING_BUDGET_FORCING: + // force_pos is advanced in apply(), not here. + // This ensures the first forced token isn't skipped when the sampler + // is initialized directly in FORCING state (e.g. 
COUNTING + budget=0) + break; + case REASONING_BUDGET_DONE: + break; + } +} + +static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + if (ctx->state != REASONING_BUDGET_FORCING) { + // passthrough — don't modify logits + return; + } + + if (ctx->force_pos >= ctx->forced_tokens.size()) { + return; + } + + const llama_token forced = ctx->forced_tokens[ctx->force_pos]; + + // set all logits to -inf except the forced token + for (size_t i = 0; i < cur_p->size; i++) { + if (cur_p->data[i].id != forced) { + cur_p->data[i].logit = -INFINITY; + } + } + + // advance to next forced token (done here rather than in accept so that + // the first forced token isn't skipped when starting in FORCING state) + ctx->force_pos++; + if (ctx->force_pos >= ctx->forced_tokens.size()) { + ctx->state = REASONING_BUDGET_DONE; + LOG_INF("reasoning-budget: forced sequence complete, done\n"); + } +} + +static void common_reasoning_budget_reset(struct llama_sampler * smpl) { + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + ctx->state = REASONING_BUDGET_IDLE; + ctx->remaining = ctx->budget; + ctx->start_matcher.reset(); + ctx->end_matcher.reset(); + ctx->force_pos = 0; +} + +static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx; + return common_reasoning_budget_init( + ctx->vocab, + ctx->start_matcher.tokens, + ctx->end_matcher.tokens, + ctx->forced_tokens, + ctx->budget, + ctx->state); +} + +static void common_reasoning_budget_free(struct llama_sampler * smpl) { + delete (common_reasoning_budget_ctx *) smpl->ctx; +} + +static struct llama_sampler_i common_reasoning_budget_i = { + /* .name = */ common_reasoning_budget_name, + /* .accept = */ common_reasoning_budget_accept, + /* .apply = */ common_reasoning_budget_apply, + /* .reset = */ 
// Factory for the reasoning-budget sampler. Ownership of the returned
// llama_sampler follows the usual llama_sampler_init() contract (freed via
// llama_sampler_free, which calls common_reasoning_budget_free above).
struct llama_sampler * common_reasoning_budget_init(
        const struct llama_vocab * vocab,
        const std::vector<llama_token> & start_tokens,
        const std::vector<llama_token> & end_tokens,
        const std::vector<llama_token> & forced_tokens,
        int32_t budget,
        common_reasoning_budget_state initial_state) {
    // promote COUNTING with budget <= 0 to FORCING
    // (a zero budget means no reasoning tokens are allowed at all, so the
    // forced end sequence must be emitted on the very first apply())
    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
        initial_state = REASONING_BUDGET_FORCING;
    }

    return llama_sampler_init(
        /* .iface = */ &common_reasoning_budget_i,
        /* .ctx   = */ new common_reasoning_budget_ctx {
            /* .vocab         = */ vocab,
            /* .start_matcher = */ { start_tokens, 0 },
            /* .end_matcher   = */ { end_tokens, 0 },
            /* .forced_tokens = */ forced_tokens,
            /* .budget        = */ budget,
            /* .remaining     = */ budget,
            /* .state         = */ initial_state,
            /* .force_pos     = */ 0,
        }
    );
}

// ---------------------------------------------------------------------------
// common/reasoning-budget.h
// ---------------------------------------------------------------------------

#pragma once

#include "llama.h"

#include <cstdint>
#include <vector>

enum common_reasoning_budget_state {
    REASONING_BUDGET_IDLE,         // waiting for start sequence
    REASONING_BUDGET_COUNTING,     // counting down tokens
    REASONING_BUDGET_FORCING,      // forcing budget message + end sequence
    REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
    REASONING_BUDGET_DONE,         // passthrough forever
};

// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
//   IDLE:         passthrough, watching for start_tokens sequence
//   COUNTING:     counting down remaining tokens, watching for natural end_tokens
//   WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
//   FORCING:      forces forced_tokens token-by-token (all other logits -> -inf)
//   DONE:         passthrough forever
//
// Parameters:
//   vocab         - vocabulary (used for UTF-8 boundary detection; can be nullptr)
//   start_tokens  - token sequence that activates counting
//   end_tokens    - token sequence for natural deactivation
//   forced_tokens - token sequence forced when budget expires
//   budget        - max tokens allowed in the reasoning block
//   initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
//                   note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
    const struct llama_vocab * vocab,
    const std::vector<llama_token> & start_tokens,
    const std::vector<llama_token> & end_tokens,
    const std::vector<llama_token> & forced_tokens,
    int32_t budget,
    common_reasoning_budget_state initial_state);
// Returns true when `s` does not end in the middle of a UTF-8 multi-byte
// sequence, i.e. it is safe to cut the token stream right after `s`.
// An empty string counts as complete. A short run of orphan continuation
// bytes with no visible lead byte is reported as incomplete.
bool common_utf8_is_complete(const std::string & s) {
    const size_t len = s.size();
    // a UTF-8 character is at most 4 bytes, so only the tail needs scanning
    const size_t scan = std::min<size_t>(4, len);

    for (size_t back = 1; back <= scan; back++) {
        const unsigned char b = (unsigned char) s[len - back];
        if ((b & 0xC0) == 0x80) {
            continue; // continuation byte (10xxxxxx) — keep looking for the lead
        }
        // found the lead byte of the trailing sequence; how long should it be?
        size_t need;
        if (b >= 0xF0) {
            need = 4; // 11110xxx: 4-byte sequence
        } else if (b >= 0xE0) {
            need = 3; // 1110xxxx: 3-byte sequence
        } else if (b >= 0xC0) {
            need = 2; // 110xxxxx: 2-byte sequence
        } else {
            need = 1; // ASCII
        }
        return back >= need;
    }

    // loop fell through: either the string is empty (complete) or the scanned
    // tail was nothing but continuation bytes (incomplete)
    return len == 0;
}
+bool common_utf8_is_complete(const std::string & s); + // Parse a single UTF-8 codepoint from input utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); diff --git a/include/llama.h b/include/llama.h index 44e670b262b..0bd10294cb8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1455,36 +1455,6 @@ extern "C" { // LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); - /// @details Reasoning budget sampler. Limits the number of tokens a model can generate inside - /// a reasoning block (e.g. between and ). - /// - /// State machine: IDLE -> COUNTING -> FORCING -> DONE - /// - IDLE: passthrough, watching accepted tokens for the start sequence - /// - COUNTING: counts down tokens, watching for natural end (defuse) - /// - FORCING: forces the budget message + end sequence token-by-token - /// - DONE: passthrough forever - /// - /// @param vocab The vocabulary (for tokenization and EOG checks) - /// @param start_tokens Token sequence that activates the countdown (e.g. "") - /// @param n_start Number of tokens in start_tokens - /// @param end_tokens Token sequence that deactivates naturally (e.g. "") - /// @param n_end Number of tokens in end_tokens - /// @param forced_tokens Token sequence forced when budget expires (e.g. 
"(budget exceeded)") - /// @param n_forced Number of tokens in forced_tokens - /// @param budget Maximum number of tokens allowed in the reasoning block - /// @param activate_immediately If true, skip IDLE and start in COUNTING directly - /// - LLAMA_API struct llama_sampler * llama_sampler_init_reasoning_budget( - const struct llama_vocab * vocab, - const llama_token * start_tokens, - size_t n_start, - const llama_token * end_tokens, - size_t n_end, - const llama_token * forced_tokens, - size_t n_forced, - int32_t budget, - bool activate_immediately); - // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); diff --git a/src/llama-sampler.cpp b/src/llama-sampler.cpp index dee0cb6d49d..9bbc5dbde24 100644 --- a/src/llama-sampler.cpp +++ b/src/llama-sampler.cpp @@ -3822,283 +3822,6 @@ struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * voca ); } -// reasoning budget - -// Check if a string ends with an incomplete UTF-8 multi-byte sequence. -// Returns true if cutting after this string would split a multi-byte character. 
-static bool llama_utf8_is_incomplete(const std::string & s) { - if (s.empty()) { - return false; - } - - // Scan backwards to count trailing continuation bytes (10xxxxxx) - int i = (int)s.size() - 1; - int n_cont = 0; - while (i >= 0 && (static_cast(s[i]) & 0xC0) == 0x80) { - n_cont++; - i--; - } - - if (i < 0) { - // Only continuation bytes, no leading byte — malformed - return true; - } - - const unsigned char lead = static_cast(s[i]); - - if ((lead & 0x80) == 0x00) { - // ASCII byte — complete on its own, trailing continuations would be malformed - return n_cont > 0; - } - - // Determine expected continuation bytes from leading byte - int expected; - if ((lead & 0xE0) == 0xC0) { expected = 1; } // 110xxxxx: 2-byte - else if ((lead & 0xF0) == 0xE0) { expected = 2; } // 1110xxxx: 3-byte - else if ((lead & 0xF8) == 0xF0) { expected = 3; } // 11110xxx: 4-byte - else { return true; } // invalid leading byte - - return n_cont < expected; -} - -enum llama_reasoning_budget_state { - REASONING_BUDGET_IDLE, // waiting for start sequence - REASONING_BUDGET_COUNTING, // counting down tokens - REASONING_BUDGET_FORCING, // forcing budget message + end sequence - REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion - REASONING_BUDGET_DONE, // passthrough forever -}; - -struct llama_sampler_reasoning_budget { - const llama_vocab * vocab; - - std::vector start_tokens; // sequence that starts counting (e.g. "") - std::vector end_tokens; // sequence that deactivates naturally (e.g. "") - std::vector forced_tokens; // sequence forced when budget expires (e.g. 
"(budget exceeded)") - - int32_t budget; // maximum tokens in reasoning block - int32_t remaining; // tokens remaining in budget - - llama_reasoning_budget_state state; - - // for multi-token sequence matching - size_t start_match_pos; // how many tokens of start_tokens we've matched so far - size_t end_match_pos; // how many tokens of end_tokens we've matched so far - - // for forcing - size_t force_pos; // next position in forced_tokens to force -}; - -static const char * llama_sampler_reasoning_budget_name(const struct llama_sampler * /*smpl*/) { - return "reasoning-budget"; -} - -static void llama_sampler_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) { - auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; - - switch (ctx->state) { - case REASONING_BUDGET_IDLE: - { - // watch for start sequence - if (!ctx->start_tokens.empty() && token == ctx->start_tokens[ctx->start_match_pos]) { - ctx->start_match_pos++; - if (ctx->start_match_pos >= ctx->start_tokens.size()) { - // full start sequence matched - ctx->state = REASONING_BUDGET_COUNTING; - ctx->remaining = ctx->budget; - ctx->start_match_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: activated, budget=%d tokens\n", ctx->budget); - - if (ctx->remaining <= 0) { - // budget is 0 — go straight to forcing - ctx->state = REASONING_BUDGET_FORCING; - ctx->force_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: budget=0, forcing immediately\n"); - } - } - } else { - ctx->start_match_pos = 0; - // check if current token starts a new match - if (!ctx->start_tokens.empty() && token == ctx->start_tokens[0]) { - ctx->start_match_pos = 1; - } - } - break; - } - case REASONING_BUDGET_COUNTING: - case REASONING_BUDGET_WAITING_UTF8: - { - // check for natural end sequence (deactivate) - if (!ctx->end_tokens.empty() && token == ctx->end_tokens[ctx->end_match_pos]) { - ctx->end_match_pos++; - if (ctx->end_match_pos >= ctx->end_tokens.size()) { - // natural end — stop constraining - ctx->state = 
REASONING_BUDGET_DONE; - ctx->end_match_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: deactivated (natural end)\n"); - } - } else { - ctx->end_match_pos = 0; - if (!ctx->end_tokens.empty() && token == ctx->end_tokens[0]) { - ctx->end_match_pos = 1; - } - } - - if (ctx->state == REASONING_BUDGET_WAITING_UTF8) { - // Check if the token completes the UTF-8 sequence - bool still_incomplete = false; // default: assume complete (safe fallback for null vocab) - if (ctx->vocab != nullptr) { - const std::string piece = ctx->vocab->token_to_piece(token); - still_incomplete = llama_utf8_is_incomplete(piece); - } - - if (!still_incomplete) { - // UTF-8 sequence complete, now start forcing - ctx->state = REASONING_BUDGET_FORCING; - ctx->force_pos = 0; - ctx->end_match_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: UTF-8 complete, now forcing end sequence\n"); - } - } else if (ctx->state == REASONING_BUDGET_COUNTING) { - ctx->remaining--; - if (ctx->remaining <= 0) { - // Budget exhausted — check if we need to wait for UTF-8 completion - bool wait_for_utf8 = false; - if (ctx->vocab != nullptr) { - const std::string piece = ctx->vocab->token_to_piece(token); - wait_for_utf8 = llama_utf8_is_incomplete(piece); - } - - if (wait_for_utf8) { - // Incomplete UTF-8 sequence, wait for completion - ctx->state = REASONING_BUDGET_WAITING_UTF8; - ctx->force_pos = 0; - ctx->end_match_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n"); - } else { - // Complete UTF-8, go straight to forcing - ctx->state = REASONING_BUDGET_FORCING; - ctx->force_pos = 0; - ctx->end_match_pos = 0; - LLAMA_LOG_INFO("reasoning-budget: budget exhausted, forcing end sequence\n"); - } - } - } - break; - } - case REASONING_BUDGET_FORCING: - { - // force_pos is advanced in apply(), not here - // This ensures the first forced token isn't skipped when the sampler - // is initialized directly in FORCING state (e.g. 
activate_immediately + budget=0) - break; - } - case REASONING_BUDGET_DONE: - break; - } -} - -static void llama_sampler_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; - - if (ctx->state != REASONING_BUDGET_FORCING) { - // passthrough — don't modify logits - return; - } - - if (ctx->force_pos >= ctx->forced_tokens.size()) { - return; - } - - const llama_token forced = ctx->forced_tokens[ctx->force_pos]; - - // set all logits to -inf except the forced token - for (size_t i = 0; i < cur_p->size; i++) { - if (cur_p->data[i].id != forced) { - cur_p->data[i].logit = -INFINITY; - } - } - - // advance to next forced token (done here rather than in accept so that - // the first forced token isn't skipped when starting in FORCING state) - ctx->force_pos++; - if (ctx->force_pos >= ctx->forced_tokens.size()) { - ctx->state = REASONING_BUDGET_DONE; - LLAMA_LOG_INFO("reasoning-budget: forced sequence complete, done\n"); - } -} - -static void llama_sampler_reasoning_budget_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_reasoning_budget *) smpl->ctx; - ctx->state = REASONING_BUDGET_IDLE; - ctx->remaining = ctx->budget; - ctx->start_match_pos = 0; - ctx->end_match_pos = 0; - ctx->force_pos = 0; -} - -static struct llama_sampler * llama_sampler_reasoning_budget_clone(const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_reasoning_budget *) smpl->ctx; - return llama_sampler_init_reasoning_budget( - ctx->vocab, - ctx->start_tokens.data(), ctx->start_tokens.size(), - ctx->end_tokens.data(), ctx->end_tokens.size(), - ctx->forced_tokens.data(), ctx->forced_tokens.size(), - ctx->budget, - ctx->state == REASONING_BUDGET_COUNTING || ctx->state == REASONING_BUDGET_FORCING || ctx->state == REASONING_BUDGET_WAITING_UTF8); -} - -static void llama_sampler_reasoning_budget_free(struct llama_sampler * smpl) { - delete (llama_sampler_reasoning_budget *) 
smpl->ctx; -} - -static struct llama_sampler_i llama_sampler_reasoning_budget_i = { - /* .name = */ llama_sampler_reasoning_budget_name, - /* .accept = */ llama_sampler_reasoning_budget_accept, - /* .apply = */ llama_sampler_reasoning_budget_apply, - /* .reset = */ llama_sampler_reasoning_budget_reset, - /* .clone = */ llama_sampler_reasoning_budget_clone, - /* .free = */ llama_sampler_reasoning_budget_free, - /* .backend_init = */ nullptr, - /* .backend_accept = */ nullptr, - /* .backend_apply = */ nullptr, - /* .backend_set_input = */ nullptr, -}; - -struct llama_sampler * llama_sampler_init_reasoning_budget( - const struct llama_vocab * vocab, - const llama_token * start_tokens, - size_t n_start, - const llama_token * end_tokens, - size_t n_end, - const llama_token * forced_tokens, - size_t n_forced, - int32_t budget, - bool activate_immediately) { - auto initial_state = activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE; - - // if activated immediately with budget <= 0, go straight to forcing - if (activate_immediately && budget <= 0) { - initial_state = REASONING_BUDGET_FORCING; - } - - return llama_sampler_init( - /* .iface = */ &llama_sampler_reasoning_budget_i, - /* .ctx = */ new llama_sampler_reasoning_budget { - /* .vocab = */ vocab, - /* .start_tokens = */ std::vector(start_tokens, start_tokens + n_start), - /* .end_tokens = */ std::vector(end_tokens, end_tokens + n_end), - /* .forced_tokens = */ std::vector(forced_tokens, forced_tokens + n_forced), - /* .budget = */ budget, - /* .remaining = */ budget, - /* .state = */ initial_state, - /* .start_match_pos = */ 0, - /* .end_match_pos = */ 0, - /* .force_pos = */ 0, - } - ); -} - // utils uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7fd895e2b64..bb0f0ef0ed8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -149,6 +149,7 @@ endif () if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) # these tests 
#include "reasoning-budget.h"
#include "unicode.h"

#include "llama.h"
#include "ggml.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>

// Reasoning budget sampler test helper
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
//
// Feeds `sequence` token-by-token into a fresh sampler and records at which
// token indices the sampler is "forcing" (exactly one finite logit after
// apply()), then compares against the expected window.
static void test_reasoning_budget(
    const char * test_name,
    const std::vector<llama_token> & sequence,
    const std::vector<llama_token> & start_tokens,
    const std::vector<llama_token> & end_tokens,
    const std::vector<llama_token> & forced_tokens,
    int32_t budget,
    common_reasoning_budget_state initial_state,
    size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
    size_t expected_force_end    // token index where forcing should end (after this, no more forcing)
) {
    // Find the maximum token ID to ensure our vocab covers all tokens
    llama_token max_token = 0;
    for (auto t : sequence)      max_token = std::max(max_token, t);
    for (auto t : start_tokens)  max_token = std::max(max_token, t);
    for (auto t : end_tokens)    max_token = std::max(max_token, t);
    for (auto t : forced_tokens) max_token = std::max(max_token, t);

    // Create a minimal sampler with mock vocabulary
    // For this test, we use nullptr as vocab since we're testing state transitions
    // The UTF-8 boundary check will treat all tokens as complete (safe fallback)
    auto * sampler = common_reasoning_budget_init(
        nullptr, // vocab - not used for basic state machine tests
        start_tokens,
        end_tokens,
        forced_tokens,
        budget,
        initial_state
    );

    // Create a test token data array for checking forcing behavior
    // Vocab size must be large enough to include all tokens (start, end, forced, sequence)
    std::vector<llama_token_data> cur;
    const size_t n_vocab = (size_t)max_token + 1;
    for (size_t i = 0; i < n_vocab; i++) {
        cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };

    size_t actual_force_start = SIZE_MAX;
    size_t actual_force_end   = SIZE_MAX;

    // Feed the sequence and track when forcing occurs
    for (size_t i = 0; i < sequence.size(); i++) {
        llama_sampler_accept(sampler, sequence[i]);

        // Check if we're in forcing state by applying and seeing if logits are modified
        cur_p.selected = -1;
        for (size_t j = 0; j < cur.size(); j++) {
            cur[j].logit = logf((float)(j+1)); // reset logits
        }

        llama_sampler_apply(sampler, &cur_p);

        // Check if forcing is active (all logits except one should be -INFINITY)
        size_t finite_count = 0;
        llama_token finite_token = -1;
        for (size_t j = 0; j < cur.size(); j++) {
            if (std::isfinite(cur[j].logit)) {
                finite_count++;
                finite_token = cur[j].id;
            }
        }

        fprintf(stderr, "  i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);

        if (finite_count == 1) {
            if (actual_force_start == SIZE_MAX) {
                actual_force_start = i;
            }
            actual_force_end = i;
        } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
            // Forcing stopped
            break;
        }
    }

    llama_sampler_free(sampler);

    // Verify forcing occurred at expected positions
    if (expected_force_start == SIZE_MAX) {
        if (actual_force_start != SIZE_MAX) {
            fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
            GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
        }
    } else {
        if (actual_force_start == SIZE_MAX) {
            fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
            GGML_ASSERT(false && "Expected forcing but none occurred");
        }
        if (actual_force_start != expected_force_start) {
            fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
            GGML_ASSERT(false && "Forcing started at wrong position");
        }
    }

    if (expected_force_end != SIZE_MAX) {
        if (actual_force_end < expected_force_end) {
            fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
            GGML_ASSERT(false && "Forcing ended too early");
        }
    }

    fprintf(stderr, "  Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
    (void)sequence;
}

// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
static void test_utf8_boundary_detection() {
    // Complete sequences
    GGML_ASSERT(common_utf8_is_complete("hello"));
    GGML_ASSERT(common_utf8_is_complete(""));
    GGML_ASSERT(common_utf8_is_complete("\xC2\xA0"));         // complete 2-byte UTF-8 (U+00A0)
    GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C"));     // complete 3-byte UTF-8 (left double quote)
    GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
    GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9"));      // ASCII + complete 2-byte

    // Incomplete sequences
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1)));         // 2-byte start, missing continuation
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2)));     // 3-byte start + 1 cont, missing 1
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1)));         // 3-byte start, missing 2
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2)));     // 4-byte start + 1 cont, missing 2
    GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1)));         // 4-byte start, missing 3
    GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1)));         // orphan continuation byte

    // Mixed: ASCII followed by start of multi-byte
    GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6)));    // ASCII + incomplete 2-byte
    GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
}

int main(void) {
    // Reasoning budget sampler tests
    printf("Testing reasoning budget sampler... ");

    // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
    {
        const std::vector<llama_token> start    = {100};               // start token
        const std::vector<llama_token> end      = {101};               // end token
        const std::vector<llama_token> forced   = {102};               // forced token (not used in this test)
        const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more

        test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
            5, // budget of 5 tokens
            REASONING_BUDGET_IDLE,
            SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
    }

    // Test 2: Budget exhausted, forcing should occur
    // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
    // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
    // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
    {
        const std::vector<llama_token> start    = {100};
        const std::vector<llama_token> end      = {101};
        const std::vector<llama_token> forced   = {102, 101};           // forced message + end
        const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)

        test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
            2, // budget of 2 tokens
            REASONING_BUDGET_IDLE,
            2,  // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
            3); // forcing continues through i=3 (at i=4 state becomes DONE)
    }

    // Test 3: Activate immediately with budget=0, forcing should start right away
    // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
    // This test needs start token to be in the sequence or use activate_immediately with start token present
    {
        const std::vector<llama_token> start    = {100};
        const std::vector<llama_token> end      = {101};
        const std::vector<llama_token> forced   = {102, 101};
        const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens

        test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
            0, // budget of 0 tokens
            REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
            0,  // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
            1); // forcing continues through i=1 (at i=2 state becomes DONE)
    }

    // Test 4: No start/end tokens configured - passthrough (no forcing)
    {
        const std::vector<llama_token> start    = {};
        const std::vector<llama_token> end      = {};
        const std::vector<llama_token> forced   = {102};
        const std::vector<llama_token> sequence = {50, 51, 52, 53};

        test_reasoning_budget("no start/end configured", sequence, start, end, forced,
            2, // budget
            REASONING_BUDGET_IDLE,
            SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
    }

    // Test 5: Activate immediately with budget > 0, count down then force
    // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
    // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
    {
        const std::vector<llama_token> start    = {100};
        const std::vector<llama_token> end      = {101};
        const std::vector<llama_token> forced   = {102, 101};
        const std::vector<llama_token> sequence = {50, 51, 52, 53};

        test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
            2, // budget of 2 tokens
            REASONING_BUDGET_COUNTING,
            1,  // forcing starts at i=1 (after 2 accepts deplete budget)
            2); // forcing continues through i=2
    }

    printf("OK (5 tests passed)\n");

    printf("Testing UTF-8 boundary detection... ");
    test_utf8_boundary_detection();
    printf("OK\n");

    return 0;
}
forced_tokens.data(), forced_tokens.size(), - budget, - activate_immediately - ); - - // Create a test token data array for checking forcing behavior - // Vocab size must be large enough to include all tokens (start, end, forced, sequence) - std::vector cur; - const size_t n_vocab = (size_t)max_token + 1; - for (size_t i = 0; i < n_vocab; i++) { - cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f}); - } - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - - size_t actual_force_start = SIZE_MAX; - size_t actual_force_end = SIZE_MAX; - - // Feed the sequence and track when forcing occurs - for (size_t i = 0; i < sequence.size(); i++) { - llama_sampler_accept(sampler, sequence[i]); - - // Check if we're in forcing state by applying and seeing if logits are modified - cur_p.selected = -1; - for (size_t j = 0; j < cur.size(); j++) { - cur[j].logit = logf((float)(j+1)); // reset logits - } - - llama_sampler_apply(sampler, &cur_p); - - // Check if forcing is active (all logits except one should be -INFINITY) - size_t finite_count = 0; - llama_token finite_token = -1; - for (size_t j = 0; j < cur.size(); j++) { - if (std::isfinite(cur[j].logit)) { - finite_count++; - finite_token = cur[j].id; - } - } - - fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token); - - if (finite_count == 1) { - if (actual_force_start == SIZE_MAX) { - actual_force_start = i; - } - actual_force_end = i; - } else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) { - // Forcing stopped - break; - } - } - - llama_sampler_free(sampler); - - // Verify forcing occurred at expected positions - if (expected_force_start == SIZE_MAX) { - if (actual_force_start != SIZE_MAX) { - fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start); - GGML_ASSERT(false && "Expected no forcing, but forcing occurred"); - } - } 
else { - if (actual_force_start == SIZE_MAX) { - fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name); - GGML_ASSERT(false && "Expected forcing but none occurred"); - } - if (actual_force_start != expected_force_start) { - fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start); - GGML_ASSERT(false && "Forcing started at wrong position"); - } - } - - if (expected_force_end != SIZE_MAX) { - if (actual_force_end < expected_force_end) { - fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end); - GGML_ASSERT(false && "Forcing ended too early"); - } - } - - fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end); - (void)sequence; -} - -// UTF-8 boundary detection unit test -// Tests the core logic used by the reasoning budget sampler to detect incomplete UTF-8 sequences -// This mirrors llama_utf8_is_incomplete() from llama-sampler.cpp -static void test_utf8_boundary_detection() { - // Reimplement the same logic as llama_utf8_is_incomplete for testing - auto is_incomplete = [](const std::string & s) -> bool { - if (s.empty()) { - return false; - } - - int i = (int)s.size() - 1; - int n_cont = 0; - while (i >= 0 && (static_cast(s[i]) & 0xC0) == 0x80) { - n_cont++; - i--; - } - - if (i < 0) { - return true; // only continuation bytes, no leading byte - } - - const unsigned char lead = static_cast(s[i]); - - if ((lead & 0x80) == 0x00) { - return n_cont > 0; // ASCII followed by continuation bytes = malformed - } - - int expected; - if ((lead & 0xE0) == 0xC0) { expected = 1; } - else if ((lead & 0xF0) == 0xE0) { expected = 2; } - else if ((lead & 0xF8) == 0xF0) { expected = 3; } - else { return true; } // invalid leading byte - - return n_cont < expected; - }; - - // Complete sequences — should NOT wait - 
GGML_ASSERT(!is_incomplete("hello")); - GGML_ASSERT(!is_incomplete("")); - GGML_ASSERT(!is_incomplete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0) - GGML_ASSERT(!is_incomplete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote) - GGML_ASSERT(!is_incomplete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji) - GGML_ASSERT(!is_incomplete("abc\xC3\xA9")); // ASCII + complete 2-byte - - // Incomplete sequences — SHOULD wait - GGML_ASSERT(is_incomplete(std::string("\xC2", 1))); // 2-byte start, missing continuation - GGML_ASSERT(is_incomplete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1 - GGML_ASSERT(is_incomplete(std::string("\xE2", 1))); // 3-byte start, missing 2 - GGML_ASSERT(is_incomplete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1 - GGML_ASSERT(is_incomplete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2 - GGML_ASSERT(is_incomplete(std::string("\xF0", 1))); // 4-byte start, missing 3 - GGML_ASSERT(is_incomplete(std::string("\x80", 1))); // orphan continuation byte - - // Mixed: ASCII followed by start of multi-byte - GGML_ASSERT(is_incomplete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte - GGML_ASSERT(!is_incomplete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte -} - static void test_top_n_sigma(const std::vector & probs, const std::vector & probs_expected, int n) { sampler_tester tester(probs, probs_expected); @@ -561,88 +392,6 @@ int main(void) { test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f); test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f); - // Reasoning budget sampler tests - printf("Testing reasoning budget sampler... 
"); - - // Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted) - { - const std::vector start = {100}; // start token - const std::vector end = {101}; // end token - const std::vector forced = {102}; // forced token (not used in this test) - const std::vector sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more - - test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced, - 5, // budget of 5 tokens - false, // don't activate immediately - SIZE_MAX, SIZE_MAX); // no forcing expected (natural end) - } - - // Test 2: Budget exhausted, forcing should occur - // Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING - // Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state) - // At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE - { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; // forced message + end - const std::vector sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2) - - test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced, - 2, // budget of 2 tokens - false, // don't activate immediately - 2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces) - 3); // forcing continues through i=3 (at i=4 state becomes DONE) - } - - // Test 3: Activate immediately with budget=0, forcing should start right away - // Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough) - // This test needs start token to be in the sequence or use activate_immediately with start token present - { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; - const std::vector sequence = {100, 50, 51, 52}; // start token first, then 3 tokens - - test_reasoning_budget("activate immediately 
budget=0", sequence, start, end, forced, - 0, // budget of 0 tokens - true, // activate immediately when start token seen - 0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING) - 1); // forcing continues through i=1 (at i=2 state becomes DONE) - } - - // Test 4: No start/end tokens configured - passthrough (no forcing) - { - const std::vector start = {}; - const std::vector end = {}; - const std::vector forced = {102}; - const std::vector sequence = {50, 51, 52, 53}; - - test_reasoning_budget("no start/end configured", sequence, start, end, forced, - 2, // budget - false, // don't activate immediately - SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured) - } - - // Test 5: Activate immediately with budget > 0, count down then force - // Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING - // So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0) - { - const std::vector start = {100}; - const std::vector end = {101}; - const std::vector forced = {102, 101}; - const std::vector sequence = {50, 51, 52, 53}; - - test_reasoning_budget("activate immediately with budget", sequence, start, end, forced, - 2, // budget of 2 tokens - true, // activate immediately - 1, // forcing starts at i=1 (after 2 accepts deplete budget) - 2); // forcing continues through i=2 - } - - printf("OK (5 tests passed)\n"); - - printf("Testing UTF-8 boundary detection... "); - test_utf8_boundary_detection(); printf("OK\n"); test_perf();