Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ add_library(${TARGET} STATIC
preset.cpp
preset.h
regex-partial.cpp
reasoning-budget.cpp
reasoning-budget.h
regex-partial.h
sampling.cpp
sampling.h
Expand Down
219 changes: 219 additions & 0 deletions common/reasoning-budget.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"

#include "log.h"

#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

struct token_matcher {
std::vector<llama_token> tokens;
size_t pos = 0;

bool advance(llama_token token) {
if (tokens.empty()) {
return false;
}

if (token == tokens[pos]) {
pos++;
if (pos >= tokens.size()) {
pos = 0;
return true;
}
} else {
pos = 0;
if (token == tokens[0]) {
pos = 1;
}
}
return false;
}

void reset() { pos = 0; }
};

struct common_reasoning_budget_ctx {
const llama_vocab * vocab;

token_matcher start_matcher;
token_matcher end_matcher;
std::vector<llama_token> forced_tokens;

int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget

common_reasoning_budget_state state;

// for forcing
size_t force_pos; // next position in forced_tokens to force
};

static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
return "reasoning-budget";
}

static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

switch (ctx->state) {
case REASONING_BUDGET_IDLE:
{
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);

if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
case REASONING_BUDGET_COUNTING:
case REASONING_BUDGET_WAITING_UTF8:
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
break;
}

bool utf8_complete = true;
if (ctx->vocab != nullptr) {
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
utf8_complete = common_utf8_is_complete(piece);
}

if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
}
}
}
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
break;
case REASONING_BUDGET_DONE:
break;
}
}

static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}

if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}

const llama_token forced = ctx->forced_tokens[ctx->force_pos];

// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}

// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}

static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
ctx->state = REASONING_BUDGET_IDLE;
ctx->remaining = ctx->budget;
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
}

static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}

static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}

static struct llama_sampler_i common_reasoning_budget_i = {
/* .name = */ common_reasoning_budget_name,
/* .accept = */ common_reasoning_budget_accept,
/* .apply = */ common_reasoning_budget_apply,
/* .reset = */ common_reasoning_budget_reset,
/* .clone = */ common_reasoning_budget_clone,
/* .free = */ common_reasoning_budget_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};

struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
}

return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
}
);
}
41 changes: 41 additions & 0 deletions common/reasoning-budget.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once

#include "llama.h"

#include <cstdint>
#include <vector>

enum common_reasoning_budget_state {
REASONING_BUDGET_IDLE, // waiting for start sequence
REASONING_BUDGET_COUNTING, // counting down tokens
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
REASONING_BUDGET_DONE, // passthrough forever
};

// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
// IDLE: passthrough, watching for start_tokens sequence
// COUNTING: counting down remaining tokens, watching for natural end_tokens
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
// DONE: passthrough forever
//
// Parameters:
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
// start_tokens - token sequence that activates counting
// end_tokens - token sequence for natural deactivation
// forced_tokens - token sequence forced when budget expires
// budget - max tokens allowed in the reasoning block
// initial_state - initial state of the sampler (e.g. IDLE or COUNTING)
// note: COUNTING with budget <= 0 is promoted to FORCING
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state);
11 changes: 6 additions & 5 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "common.h"
#include "log.h"
#include "reasoning-budget.h"

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -252,13 +253,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st

// reasoning budget sampler — added first so it can force tokens before other samplers
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
samplers.push_back(llama_sampler_init_reasoning_budget(
samplers.push_back(common_reasoning_budget_init(
vocab,
params.reasoning_budget_start.data(), params.reasoning_budget_start.size(),
params.reasoning_budget_end.data(), params.reasoning_budget_end.size(),
params.reasoning_budget_forced.data(), params.reasoning_budget_forced.size(),
params.reasoning_budget_start,
params.reasoning_budget_end,
params.reasoning_budget_forced,
params.reasoning_budget_tokens,
params.reasoning_budget_activate_immediately));
params.reasoning_budget_activate_immediately ? REASONING_BUDGET_COUNTING : REASONING_BUDGET_IDLE));
}

if (params.has_logit_bias()) {
Expand Down
18 changes: 17 additions & 1 deletion common/unicode.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include "unicode.h"

#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>
#include <string>
#include <vector>

// implementation adopted from src/unicode.cpp

Expand Down Expand Up @@ -67,6 +69,20 @@ utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t off
return utf8_parse_result(utf8_parse_result::INVALID);
}

bool common_utf8_is_complete(const std::string & s) {
if (s.empty()) {
return true;
}
for (int i = 1; i <= std::min(4, (int)s.size()); i++) {
unsigned char c = s[s.size() - i];
if ((c & 0xC0) != 0x80) {
int expected = (c >= 0xF0) ? 4 : (c >= 0xE0) ? 3 : (c >= 0xC0) ? 2 : 1;
return i >= expected;
}
}
return false;
}

std::string common_unicode_cpts_to_utf8(const std::vector<uint32_t> & cps) {
std::string result;
for (size_t i = 0; i < cps.size(); ++i) {
Expand Down
3 changes: 3 additions & 0 deletions common/unicode.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ struct utf8_parse_result {
// Returns 0 for invalid first bytes
size_t common_utf8_sequence_length(unsigned char first_byte);

// Check if a string ends with a complete UTF-8 sequence.
bool common_utf8_is_complete(const std::string & s);

// Parse a single UTF-8 codepoint from input
utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset);

Expand Down
30 changes: 0 additions & 30 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1455,36 +1455,6 @@ extern "C" {
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);

/// @details Reasoning budget sampler. Limits the number of tokens a model can generate inside
/// a reasoning block (e.g. between <think> and </think>).
///
/// State machine: IDLE -> COUNTING -> FORCING -> DONE
/// - IDLE: passthrough, watching accepted tokens for the start sequence
/// - COUNTING: counts down tokens, watching for natural end (defuse)
/// - FORCING: forces the budget message + end sequence token-by-token
/// - DONE: passthrough forever
///
/// @param vocab The vocabulary (for tokenization and EOG checks)
/// @param start_tokens Token sequence that activates the countdown (e.g. "<think>")
/// @param n_start Number of tokens in start_tokens
/// @param end_tokens Token sequence that deactivates naturally (e.g. "</think>")
/// @param n_end Number of tokens in end_tokens
/// @param forced_tokens Token sequence forced when budget expires (e.g. "(budget exceeded)</think>")
/// @param n_forced Number of tokens in forced_tokens
/// @param budget Maximum number of tokens allowed in the reasoning block
/// @param activate_immediately If true, skip IDLE and start in COUNTING directly
///
LLAMA_API struct llama_sampler * llama_sampler_init_reasoning_budget(
const struct llama_vocab * vocab,
const llama_token * start_tokens,
size_t n_start,
const llama_token * end_tokens,
size_t n_end,
const llama_token * forced_tokens,
size_t n_forced,
int32_t budget,
bool activate_immediately);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

Expand Down
Loading