Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1334,12 +1334,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
add_opt(common_arg(
{"-cpent", "--checkpoint-every-n-tokens"}, "N",
string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt),
{"-cms", "--checkpoint-min-step"}, "N",
string_format("minimum spacing between context checkpoints in tokens (default: %d, 0 = no minimum)", params.checkpoint_min_step),
[](common_params & params, int value) {
params.checkpoint_every_nt = value;
if (value < 0) {
throw std::invalid_argument("checkpoint-min-step must be non-negative");
}
params.checkpoint_min_step = value;
}
).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
).set_env("LLAMA_ARG_CHECKPOINT_MIN_SPACING_NT").set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-cram", "--cache-ram"}, "N",
string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)"
Expand Down
6 changes: 4 additions & 2 deletions common/chat-auto-parser-helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ std::vector<segment> prune_whitespace_segments(const std::vector<segment> & segm

namespace autoparser {

static const std::string ERR_TMPL = "#**ERROR**#";

std::string apply_template(const common_chat_template & tmpl, const template_params & params) {
generation_params tmpl_params;
tmpl_params.messages = params.messages;
Expand All @@ -326,7 +328,7 @@ std::string apply_template(const common_chat_template & tmpl, const template_par
return common_chat_template_direct_apply(tmpl, tmpl_params);
} catch (const std::exception & e) {
LOG_DBG("Template application failed: %s\n", e.what());
return "";
return ERR_TMPL;
}
}

Expand All @@ -347,7 +349,7 @@ std::optional<compare_variants_result> compare_variants(
std::string output_B = apply_template(tmpl, params_B);

// Check for template application failures
if (output_A.empty() || output_B.empty()) {
if (output_A == ERR_TMPL || output_B == ERR_TMPL) {
return std::nullopt;
}

Expand Down
6 changes: 6 additions & 0 deletions common/chat-auto-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ struct analyze_tools : analyze_base {

struct autoparser {
jinja::caps jinja_caps;
std::string user_start;
std::string assistant_start;
analyze_reasoning reasoning;
analyze_content content;
analyze_tools tools;
Expand All @@ -387,6 +389,10 @@ struct autoparser {

autoparser() = default;

// Find the starting marker for the user message and assistant message
std::string detect_user_start_marker(const common_chat_template & tmpl);
std::string detect_assistant_start_marker(const common_chat_template & tmpl);

// Run full differential analysis on a template
void analyze_template(const common_chat_template & tmpl);

Expand Down
177 changes: 176 additions & 1 deletion common/chat-diff-analyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "peg-parser.h"

#include <algorithm>
#include <cctype>
#include <ostream>
#include <sstream>

#define ANSI_RESET "\033[0m"
#define ANSI_PURPLE "\033[1m\x1b[38;5;126m"
Expand All @@ -23,6 +26,7 @@ static const std::string FUN_SECOND = "SSS_SECOND_FUN_S";
static const std::string ARG_FIRST = "AA_ARG_FST_AA";
static const std::string ARG_SECOND = "BB_ARG_SND_BB";
static const std::string USER_MSG = "U_USER_MSG Hello END_U";
static const std::string USER_MSG_TWO = "V_USER_MSG Hello END_V";
static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A";
static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R";
static const std::string CALL_ID_001 = "call00001";
Expand Down Expand Up @@ -71,6 +75,7 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
analysis.content.end = "<|END_OF_TURN_TOKEN|>";
analysis.preserved_tokens.push_back("<|CHATBOT_TOKEN|>");
analysis.preserved_tokens.push_back("<|END_OF_TURN_TOKEN|>");
analysis.user_start = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>";
LOG_DBG(ANSI_ORANGE "[Patch: Cohere Command R+]\n" ANSI_RESET);
}
},
Expand Down Expand Up @@ -108,7 +113,59 @@ static std::vector<std::function<void(const common_chat_template & tmpl, autopar
analysis.tools.function.close = "```";
LOG_DBG(ANSI_ORANGE "[Patch: DeepSeek-R1-Distill-Qwen]\n" ANSI_RESET);
}
}
},
// Nemotron Nano v2
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("<SPECIAL_10>") != std::string::npos && tmpl.src.find("<SPECIAL_11>") != std::string::npos &&
tmpl.src.find("<SPECIAL_12>") != std::string::npos && tmpl.src.find("<TOOL_RESPONSE>") != std::string::npos) {

analysis.tools.format.mode = tool_format::JSON_NATIVE;
analysis.tools.format.section_start = "";
analysis.tools.format.section_end = "";
analysis.tools.format.per_call_start = "<TOOLCALL>";
analysis.tools.format.per_call_end = "</TOOLCALL>";
analysis.content.mode = content_mode::PLAIN;
analysis.content.start = "";
analysis.content.end = "";
analysis.reasoning.mode = reasoning_mode::TAG_BASED;
analysis.reasoning.start = "<think>\n\n";
analysis.reasoning.end = "</think>";
analysis.assistant_start = "<SPECIAL_11>Assistant";
analysis.user_start = "<SPECIAL_11>User";
analysis.preserved_tokens.clear();
analysis.preserved_tokens.push_back("<SPECIAL_12>");
analysis.preserved_tokens.push_back("<SPECIAL_11>");
analysis.preserved_tokens.push_back("</think>");
analysis.preserved_tokens.push_back("<TOOLCALL>");
analysis.preserved_tokens.push_back("</TOOLCALL>");
LOG_DBG(ANSI_ORANGE "[Patch: Nemotron Nano v2]\n" ANSI_RESET);
}
},
// Fireworks
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("{%- set system_prompt = '<|start_header_id|>' + 'system' + '<|end_header_id|>\\n\\n'"
" + message['content'] | trim + '\\n' + system_prompt_suffix + '<|eot_id|>' -%}") != std::string::npos) {
analysis.assistant_start = "<|start_header_id|>assistant<|end_header_id|>";
analysis.user_start = "<|start_header_id|>user<|end_header_id|>";
LOG_DBG(ANSI_ORANGE "[Patch: Fireworks v2]\n" ANSI_RESET);
}
},
// Solar Open
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("<|begin|>assistant<|think|><|end|>") != std::string::npos) {
analysis.assistant_start = "<|begin|>assistant";
LOG_DBG(ANSI_ORANGE "[Patch: Solar Open]\n" ANSI_RESET);
}
},
// Apriel 1.6
[](const common_chat_template & tmpl, autoparser & analysis) -> void {
if (tmpl.src.find("if not loop.last and '[BEGIN FINAL RESPONSE]' in asst_text") != std::string::npos) {
analysis.user_start = "<|begin_user|>";
analysis.assistant_start = "<|begin_assistant|>";
LOG_DBG(ANSI_ORANGE "[Patch: Apriel 1.6]\n" ANSI_RESET);
}
},

});

// Common JSON structures
Expand Down Expand Up @@ -166,13 +223,17 @@ void autoparser::analyze_template(const common_chat_template & tmpl) {
reasoning = analyze_reasoning(tmpl, jinja_caps.supports_tool_calls);
content = analyze_content(tmpl, reasoning);
tools = analyze_tools(jinja_caps.supports_tool_calls ? analyze_tools(tmpl, jinja_caps, reasoning) : analyze_tools());
assistant_start = detect_assistant_start_marker(tmpl);
user_start = detect_user_start_marker(tmpl);
collect_preserved_tokens();

for (auto & workaround : workarounds) {
workaround(tmpl, *this);
}

LOG_DBG("\n--- Reasoning & Content Structure ---\n");
LOG_DBG("user_msg_start: %s\n", user_start.c_str());
LOG_DBG("assistant_msg_start: %s\n", assistant_start.c_str());
LOG_DBG("reasoning_mode: %s\n", mode_to_str(reasoning.mode).c_str());
LOG_DBG("reasoning_start: '%s'\n", reasoning.start.c_str());
LOG_DBG("reasoning_end: '%s'\n", reasoning.end.c_str());
Expand Down Expand Up @@ -245,6 +306,120 @@ void autoparser::collect_preserved_tokens() {
add_token(tools.call_id.suffix);
}

std::string autoparser::detect_assistant_start_marker(const common_chat_template & tmpl) {
json user_msg = json{
{ "role", "user" },
{ "content", USER_MSG }
};

json assistant_no_reasoning = json{
{ "role", "assistant" },
{ "content", ASSISTANT_MSG }
};

template_params params;
params.messages = json::array({ user_msg });
params.add_generation_prompt = false;
params.enable_thinking = true;

auto comparison = compare_variants(
tmpl, params, [&](template_params & p) {
p.messages = json::array({ user_msg, assistant_no_reasoning });
}
);

if (!comparison) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed, skipping assistant start detection\n" ANSI_RESET, __func__);
return "";
}

auto usermsg = comparison->diff.right;
if (usermsg.find(ASSISTANT_MSG) == std::string::npos) {
LOG_DBG(ANSI_ORANGE "%s: Did not find assistant message in assistant message block, skipping detection\n" ANSI_RESET, __func__);
}

auto ast_prefix = usermsg.substr(0, usermsg.find(ASSISTANT_MSG));
if (!reasoning.start.empty() && ast_prefix.find(trim_whitespace(reasoning.start)) != std::string::npos) {
ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.start)));
}
if (!reasoning.end.empty() && ast_prefix.find(trim_whitespace(reasoning.end)) != std::string::npos) {
ast_prefix = ast_prefix.substr(0, ast_prefix.find(trim_whitespace(reasoning.end)));
}
return trim_whitespace(ast_prefix);
}

std::string autoparser::detect_user_start_marker(const common_chat_template & tmpl) {
json user_msg = json{
{ "role", "user" },
{ "content", USER_MSG }
};

json assistant = json{
{ "role", "assistant" },
{ "content", ASSISTANT_MSG }
};

json user_msg_two = json{
{ "role", "user" },
{ "content", USER_MSG_TWO }
};

template_params params;
params.messages = json::array({});
params.add_generation_prompt = false;
params.enable_thinking = true;

auto comparison = compare_variants(
tmpl, params, [&](template_params & p) {
p.messages = json::array({ user_msg });
}
);

if (!comparison) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed, unsupported empty messages? trying complex variant\n" ANSI_RESET, __func__);
params.messages = json::array({ user_msg_two, assistant });
comparison = compare_variants(
tmpl, params, [&](template_params & p) {
p.messages = json::array({ user_msg_two, assistant, user_msg });
}
);
if (!comparison) {
LOG_DBG(ANSI_ORANGE "%s: Template application failed for reserve variant, aborting\n" ANSI_RESET, __func__);
return "";
}
}

auto usermsg = comparison->diff.right;
if (usermsg.find(USER_MSG) == std::string::npos) {
LOG_DBG(ANSI_ORANGE "%s: Did not find user message in user message block, aborting detection\n" ANSI_RESET, __func__);
}

if (usermsg.find(ASSISTANT_MSG) != std::string::npos) {
usermsg = usermsg.substr(usermsg.find(ASSISTANT_MSG) + ASSISTANT_MSG.size());
}

auto candidate = usermsg.substr(0, usermsg.find(USER_MSG));
auto candidate_split = segmentize_markers(candidate);
std::stringstream result;
bool encountered_marker = false;
for (const auto & mrk : candidate_split) {
std::string lower_mrk = std::string(mrk.value);
std::transform(lower_mrk.begin(), lower_mrk.end(), lower_mrk.begin(),
[](unsigned char c) { return std::tolower(c); });
// heuristic to weed out potential end markers, but only at the start
if (mrk.type == segment_type::MARKER && !encountered_marker &&
(lower_mrk.find("end") != std::string::npos || lower_mrk.find("close") != std::string::npos)) {
continue;
}
if (mrk.type == segment_type::TEXT && !encountered_marker && trim_whitespace(mrk.value).empty()) {
continue;
}
encountered_marker |= mrk.type == segment_type::MARKER;
result << mrk.value;
}
return trim_whitespace(result.str());
}

analyze_reasoning::analyze_reasoning(const common_chat_template & tmpl, bool supports_tools)
: analyze_base(tmpl) {
LOG_DBG(ANSI_PURPLE "=== Starting differential analysis ===\n" ANSI_RESET);
Expand Down
65 changes: 65 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,45 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const
return text;
}

std::vector<common_chat_msg_span> common_chat_split_by_role(const std::string & prompt, const std::vector<common_chat_msg_delimiter> & delims) {
if (delims.empty() || prompt.empty()) {
return {};
}

auto parser = build_peg_parser([&](common_peg_parser_builder & p) {
std::vector<std::string> all_delims;
std::vector<common_peg_parser> tagged_messages;

all_delims.reserve(delims.size());
tagged_messages.reserve(delims.size());
for (const auto & d : delims) {
all_delims.push_back(d.delimiter);
}

auto any_delim = p.until_one_of(all_delims);
for (const auto & d : delims) {
tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim));
}

return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end();
});

common_peg_parse_context ctx(prompt);
const auto result = parser.parse(ctx);
if (!result.success()) {
return {};
}

std::vector<common_chat_msg_span> spans;
ctx.ast.visit(result, [&](const common_peg_ast_node & node) {
if (!node.tag.empty()) {
spans.push_back({ node.tag, node.start, node.end - node.start });
}
});

return spans;
}

json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const {
if (!content.empty() && !content_parts.empty()) {
throw std::runtime_error("Cannot specify both content and content_parts");
Expand Down Expand Up @@ -1042,6 +1081,14 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp

data.prompt = prompt;
data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages);
data.message_spans = common_chat_split_by_role(prompt, {
{ "assistant", "<|start|>assistant" },
{ "user", "<|start|>user" },
{ "system", "<|start|>developer" },
{ "system", "<|start|>system" },
{ "tool", "<|start|>functions" },
});

data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
data.supports_thinking = true;

Expand Down Expand Up @@ -1181,6 +1228,11 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ
data.prompt += data.generation_prompt;
}

data.message_spans = common_chat_split_by_role(data.prompt, {
{ "user", "<|turn>user\n" },
{ "assistant", "<|turn>model\n" },
});

data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4;
data.supports_thinking = true;
data.thinking_start_tag = "<|channel>thought";
Expand Down Expand Up @@ -2393,6 +2445,19 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_
struct autoparser::autoparser autoparser;
autoparser.analyze_template(tmpl);
auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser);

std::vector<common_chat_msg_delimiter> delimiters;
if (!autoparser.assistant_start.empty()) {
delimiters.push_back({ "assistant", autoparser.assistant_start });
}
if (!autoparser.user_start.empty()) {
delimiters.push_back({ "user", autoparser.user_start });
}

if (!delimiters.empty()) {
auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters);
}

auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE;
if (auto_params.supports_thinking) {
auto_params.thinking_start_tag = trim_whitespace(autoparser.reasoning.start);
Expand Down
Loading