From 81c7b95fd22d02b4bfba8323f4b48ee3580c6e86 Mon Sep 17 00:00:00 2001 From: Alde Rojas Date: Sun, 31 May 2026 16:49:49 -0500 Subject: [PATCH] common : support manually triggering the reasoning budget end sequence --- common/reasoning-budget.cpp | 21 ++++++++++ common/reasoning-budget.h | 4 ++ common/sampling.cpp | 8 ++++ common/sampling.h | 3 ++ tests/test-reasoning-budget.cpp | 73 ++++++++++++++++++++++++++++++++- 5 files changed, 108 insertions(+), 1 deletion(-) diff --git a/common/reasoning-budget.cpp b/common/reasoning-budget.cpp index 958c9cacf51..ce41d029b05 100644 --- a/common/reasoning-budget.cpp +++ b/common/reasoning-budget.cpp @@ -247,3 +247,24 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla } return ((const common_reasoning_budget_ctx *)smpl->ctx)->state; } + +bool common_reasoning_budget_force(struct llama_sampler * smpl) { + if (!smpl) { + return false; + } + + auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx; + + // only a sampler that is actively counting down the budget may be forced; + // any other state (idle, already forcing/waiting, or done) is left untouched + if (ctx->state != REASONING_BUDGET_COUNTING) { + return false; + } + + ctx->state = REASONING_BUDGET_FORCING; + ctx->force_pos = 0; + ctx->end_matcher.reset(); + LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n"); + + return true; +} diff --git a/common/reasoning-budget.h b/common/reasoning-budget.h index ef37f46ee4d..0cf689a5663 100644 --- a/common/reasoning-budget.h +++ b/common/reasoning-budget.h @@ -40,3 +40,7 @@ struct llama_sampler * common_reasoning_budget_init( common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE); common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl); + +// Manually transition the reasoning budget sampler into the FORCING state. +// Returns true if the transition occurred. +bool common_reasoning_budget_force(struct llama_sampler * smpl); diff --git a/common/sampling.cpp b/common/sampling.cpp index 5665d0a706c..bbfa9a9ecd6 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -661,6 +661,14 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { return llama_sampler_get_seed(gsmpl->chain); } +bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl) { + if (!gsmpl) { + return false; + } + + return common_reasoning_budget_force(gsmpl->rbudget); +} + // helpers llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) { diff --git a/common/sampling.h b/common/sampling.h index 49506a00cd8..19cbbbaba36 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -87,6 +87,9 @@ std::vector common_sampler_sample_and_accept_n(struct common_sample uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); +// force the reasoning budget sampler (if any) to begin forcing its end sequence now. +bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl); + // helpers // access the internal list of current candidate tokens diff --git a/tests/test-reasoning-budget.cpp b/tests/test-reasoning-budget.cpp index 10d0ac21da8..f54cff4f8a2 100644 --- a/tests/test-reasoning-budget.cpp +++ b/tests/test-reasoning-budget.cpp @@ -184,6 +184,76 @@ static void test_reasoning_budget_clone_mid_forcing() { llama_sampler_free(sampler); } +static void test_reasoning_budget_force_manual() { + const std::vector start = {100}; + const std::vector end = {101}; + const std::vector forced = {102, 101}; + + // if COUNTING, force() succeeds and begins forcing the end sequence from the start + { + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE); + + llama_sampler_accept(sampler, 100); // COUNTING, remaining=5 + llama_sampler_accept(sampler, 50); // COUNTING, remaining=4 + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_COUNTING); + + GGML_ASSERT(common_reasoning_budget_force(sampler) && "force() should succeed from COUNTING"); + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING); + + // forces the configured sequence from force_pos=0, then transitions to DONE + GGML_ASSERT(get_forced_token(sampler, 102) == 102); + llama_sampler_accept(sampler, 102); + GGML_ASSERT(get_forced_token(sampler, 102) == 101); + llama_sampler_accept(sampler, 101); + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE); + + llama_sampler_free(sampler); + } + + // if IDLE, force() is a no-op + { + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE); + + GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from IDLE"); + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_IDLE); + + llama_sampler_free(sampler); + } + + // if DONE, force() is a no-op + { + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE); + + llama_sampler_accept(sampler, 100); // COUNTING + llama_sampler_accept(sampler, 101); // natural end -> DONE + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE); + + GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from DONE"); + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE); + + llama_sampler_free(sampler); + } + + // if FORCING, force() is a no-op and must not rewind the force position + { + auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING); + + GGML_ASSERT(get_forced_token(sampler, 102) == 102); + llama_sampler_accept(sampler, 102); // advance to the second forced token (force_pos=1) + + GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from FORCING"); + GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING); + GGML_ASSERT(get_forced_token(sampler, 102) == 101 && "force() must not rewind the force position"); + + llama_sampler_free(sampler); + } + + // a null sampler is safely ignored + GGML_ASSERT(!common_reasoning_budget_force(nullptr)); + + fprintf(stderr, " Test 'manual force transition' passed\n"); +} + // UTF-8 boundary detection unit test // Tests common_utf8_is_complete() from reasoning-budget.h static void test_utf8_boundary_detection() { @@ -312,8 +382,9 @@ int main(void) { test_reasoning_budget_clone_mid_counting(); test_reasoning_budget_clone_mid_forcing(); + test_reasoning_budget_force_manual(); - printf("OK (8 tests passed)\n"); + printf("OK (9 tests passed)\n"); printf("Testing UTF-8 boundary detection... "); test_utf8_boundary_detection();