Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions common/reasoning-budget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,24 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla
}
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
}

bool common_reasoning_budget_force(struct llama_sampler * smpl) {
if (!smpl) {
return false;
}

auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

// only a sampler that is actively counting down the budget may be forced;
// any other state (idle, already forcing/waiting, or done) is left untouched
if (ctx->state != REASONING_BUDGET_COUNTING) {
return false;
}

ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: forced into forcing state (manual transition)\n");

return true;
}
4 changes: 4 additions & 0 deletions common/reasoning-budget.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ struct llama_sampler * common_reasoning_budget_init(
common_reasoning_budget_state initial_state = REASONING_BUDGET_IDLE);

common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);

// Manually transition the reasoning budget sampler into the FORCING state.
// The transition only happens when the sampler is currently in the COUNTING state;
// a null sampler, or a sampler in any other state, is left untouched.
// Returns true if the transition occurred, false otherwise.
bool common_reasoning_budget_force(struct llama_sampler * smpl);
8 changes: 8 additions & 0 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,14 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

// Forward a manual "start forcing now" request to the reasoning budget sampler
// held by `gsmpl`. Yields false when `gsmpl` is null; otherwise yields whatever
// common_reasoning_budget_force() reports for gsmpl->rbudget.
bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl) {
    // short-circuit keeps gsmpl->rbudget from being read through a null pointer
    return gsmpl != nullptr && common_reasoning_budget_force(gsmpl->rbudget);
}

// helpers

llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
Expand Down
3 changes: 3 additions & 0 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

// force the reasoning budget sampler (if any) to begin forcing its end sequence now.
// returns false when gsmpl is null; otherwise returns the result of
// common_reasoning_budget_force() applied to gsmpl's reasoning budget sampler.
bool common_sampler_reasoning_budget_force(struct common_sampler * gsmpl);

// helpers

// access the internal list of current candidate tokens
Expand Down
73 changes: 72 additions & 1 deletion tests/test-reasoning-budget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,76 @@ static void test_reasoning_budget_clone_mid_forcing() {
llama_sampler_free(sampler);
}

// Exercises common_reasoning_budget_force() against a sampler in each state:
//   - COUNTING: the call returns true and the sampler moves to FORCING
//   - IDLE, DONE, FORCING: the call returns false and the state is unchanged
//   - nullptr: the call returns false
// Aborts via GGML_ASSERT on the first failed expectation; prints a line to stderr on success.
static void test_reasoning_budget_force_manual() {
// token sequences shared by all scenarios below
// NOTE(review): argument roles (start / end / forced sequences, then budget and
// initial state) are inferred from the variable names and the inline comments
// below - confirm against common_reasoning_budget_init().
const std::vector<llama_token> start = {100};
const std::vector<llama_token> end = {101};
const std::vector<llama_token> forced = {102, 101};

// if COUNTING, force() succeeds and begins forcing the end sequence from the start
{
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);

llama_sampler_accept(sampler, 100); // COUNTING, remaining=5
llama_sampler_accept(sampler, 50); // COUNTING, remaining=4
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_COUNTING);

GGML_ASSERT(common_reasoning_budget_force(sampler) && "force() should succeed from COUNTING");
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING);

// forces the configured sequence from force_pos=0, then transitions to DONE
// (get_forced_token is a helper defined elsewhere in this test file; presumably it
// reports which token the sampler currently forces - verify against its definition)
GGML_ASSERT(get_forced_token(sampler, 102) == 102);
llama_sampler_accept(sampler, 102);
GGML_ASSERT(get_forced_token(sampler, 102) == 101);
llama_sampler_accept(sampler, 101);
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);

llama_sampler_free(sampler);
}

// if IDLE, force() is a no-op
{
// freshly initialised sampler: no token has been accepted yet
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);

GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from IDLE");
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_IDLE);

llama_sampler_free(sampler);
}

// if DONE, force() is a no-op
{
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 5, REASONING_BUDGET_IDLE);

llama_sampler_accept(sampler, 100); // COUNTING
llama_sampler_accept(sampler, 101); // natural end -> DONE
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);

GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from DONE");
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_DONE);

llama_sampler_free(sampler);
}

// if FORCING, force() is a no-op and must not rewind the force position
{
// start directly in FORCING so the scenario does not depend on the countdown
auto * sampler = common_reasoning_budget_init(nullptr, start, end, forced, 0, REASONING_BUDGET_FORCING);

GGML_ASSERT(get_forced_token(sampler, 102) == 102);
llama_sampler_accept(sampler, 102); // advance to the second forced token (force_pos=1)

GGML_ASSERT(!common_reasoning_budget_force(sampler) && "force() must not transition from FORCING");
GGML_ASSERT(common_reasoning_budget_get_state(sampler) == REASONING_BUDGET_FORCING);
// a rewind would make the first forced token (102) come back here instead of 101
GGML_ASSERT(get_forced_token(sampler, 102) == 101 && "force() must not rewind the force position");

llama_sampler_free(sampler);
}

// a null sampler is safely ignored
GGML_ASSERT(!common_reasoning_budget_force(nullptr));

fprintf(stderr, " Test 'manual force transition' passed\n");
}

// UTF-8 boundary detection unit test
// Tests common_utf8_is_complete() from reasoning-budget.h
static void test_utf8_boundary_detection() {
Expand Down Expand Up @@ -312,8 +382,9 @@ int main(void) {

test_reasoning_budget_clone_mid_counting();
test_reasoning_budget_clone_mid_forcing();
test_reasoning_budget_force_manual();

printf("OK (8 tests passed)\n");
printf("OK (9 tests passed)\n");

printf("Testing UTF-8 boundary detection... ");
test_utf8_boundary_detection();
Expand Down
Loading