Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,26 @@ bool CacheManager::loadCache() {
"%s: attempting to load saved session from '%s'\n",
__func__,
sessionPath_.c_str()));

// Remove the tool tokens counted for the last prompt from the KV cache before
// the saved session is loaded.
// NOTE(review): the tool token count is not cleared after trimming, so a later
// cache save/load without a new prompt in between would trim the same amount
// again - confirm whether the count should be reset here.
const llama_pos toolTokenCount = llmContext_->getLastToolTokenCount();
if (toolTokenCount > 0) {
  const llama_pos newNPast = llmContext_->getNPast() - toolTokenCount;

  // Only trim when at least one position would remain in the cache.
  if (newNPast > 0) {
    // seq_id -1 with end position -1: every sequence, from newNPast onwards.
    llama_memory_seq_rm(llama_get_memory(ctx), -1, newNPast, -1);
    llmContext_->setNPast(newNPast);

    QLOG_IF(
        Priority::DEBUG,
        string_format(
            "%s: removed %d tool tokens before loading cache\n",
            __func__,
            toolTokenCount));
  }
}
if (!isFileInitialized(sessionPath_)) {
QLOG_IF(
Priority::DEBUG,
Expand Down Expand Up @@ -244,6 +264,26 @@ void CacheManager::saveCache() {
__func__,
sessionPath_.c_str()));

// Remove tool tokens from KV cache before saving
llama_pos toolTokenCount = llmContext_->getLastToolTokenCount();
if (toolTokenCount > 0) {
auto* mem = llama_get_memory(ctx);
llama_pos currentPast = llmContext_->getNPast();
llama_pos newNPast = currentPast - toolTokenCount;

if (newNPast > 0) {
llama_memory_seq_rm(mem, -1, newNPast, -1);
llmContext_->setNPast(newNPast);

QLOG_IF(
Priority::DEBUG,
string_format(
"%s: removed %d tool tokens before saving cache\n",
__func__,
toolTokenCount));
}
}

llama_token sessionTokens[2] = {
static_cast<llama_token>(llmContext_->getNPast()),
static_cast<llama_token>(llmContext_->getFirstMsgTokens())};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,9 @@ std::unique_ptr<LlmContext> LlamaModel::CreateContext(
return std::make_unique<MtmdLlmContext>(params, std::move(llamaInit));
}
isTextLlm = true;
return std::make_unique<TextLlmContext>(params, std::move(llamaInit));
auto ctx = std::make_unique<TextLlmContext>(params, std::move(llamaInit));
ctx->setCalculateToolTokenCount(params.use_jinja);
return ctx;
}

bool LlamaModel::LoadMedia(const LlamaModel::Input& input) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,14 @@ class LlmContext { // NOLINT(cppcoreguidelines-special-member-functions)
*/
virtual llama_pos removeLastNTokens(llama_pos count) = 0;

/**
* Get the number of tool tokens from the last user message.
* Used for cache management when tools are appended after user messages:
* the cache manager trims this many positions from the end of the KV cache
* when the value is positive.
*
* Implementations that do not count tool tokens return 0, in which case
* no trimming takes place.
*
* @return - the number of tool tokens (0 when none were counted).
*/
[[nodiscard]] virtual llama_pos getLastToolTokenCount() const = 0;

/**
* The reset media method. It resets the media.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,8 @@ void MtmdLlmContext::setNPast(llama_pos nPast) { this->n_past = nPast; }

// Returns the first-message token count stored on this context.
llama_pos MtmdLlmContext::getFirstMsgTokens() const {
  return this->firstMsgTokens;
}

// The multimodal context does not count tool tokens, so the count is always 0.
llama_pos MtmdLlmContext::getLastToolTokenCount() const {
  constexpr llama_pos kNoToolTokens = 0;
  return kNoToolTokens;
}

// Stores the first-message token count on this context.
void MtmdLlmContext::setFirstMsgTokens(llama_pos firstMsgTokens) {
  // Qualify the member explicitly: the parameter of the same name shadows it.
  MtmdLlmContext::firstMsgTokens = firstMsgTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ class MtmdLlmContext: public LlmContext {
*/
void setFirstMsgTokens(llama_pos firstMsgTokens) override;

/**
* The get last tool token count method. The multimodal context does not
* count tool tokens, so it always returns 0.
*
* @return - 0 (not applicable for multimodal).
*/
[[nodiscard]] llama_pos getLastToolTokenCount() const override;

/**
* The set n_discarded method. It sets the n_discarded.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,22 @@ void TextLlmContext::tokenizeChat(
AddonID, toString(EmptyTokenizedInput), errorMsg);
}

// Measure how many tokens the tool definitions contribute: render and
// tokenize the same chat once more without tools and take the size difference.
lastToolTokenCount_ = 0;
if (calculateToolTokenCount_ && !tools.empty()) {
  common_chat_templates_inputs toollessInputs = inputs;
  toollessInputs.tools = {};
  const std::string toollessPrompt = getPrompt(tmpls.get(), toollessInputs);
  const std::vector<llama_token> toollessTokens =
      common_tokenize(lctx, toollessPrompt, addSpecial, true);

  const auto sizeWithTools = inputTokens.size();
  const auto sizeWithoutTools = toollessTokens.size();

  // A strictly larger size already implies inputTokens is not empty; the
  // count stays at 0 if the tool-free prompt produced no tokens.
  if (sizeWithoutTools != 0 && sizeWithTools > sizeWithoutTools) {
    lastToolTokenCount_ =
        static_cast<llama_pos>(sizeWithTools - sizeWithoutTools);
  }
}

// Encode the input if model has encoder
if (llama_model_has_encoder(model) && n_past == 0 && !isCacheLoaded) {
int encInputSize = static_cast<int>(inputTokens.size());
Expand Down Expand Up @@ -508,6 +524,12 @@ void TextLlmContext::setNPast(llama_pos nPast) { this->n_past = nPast; }

// Returns the first-message token count stored on this context.
llama_pos TextLlmContext::getFirstMsgTokens() const {
  return this->firstMsgTokens;
}

// Returns the tool token count recorded by the most recent tokenizeChat call.
llama_pos TextLlmContext::getLastToolTokenCount() const {
  return this->lastToolTokenCount_;
}

void TextLlmContext::setCalculateToolTokenCount(bool enabled) {
calculateToolTokenCount_ = enabled;
}

// Stores the first-message token count on this context.
void TextLlmContext::setFirstMsgTokens(llama_pos firstMsgTokens) {
  // Qualify the member explicitly: the parameter of the same name shadows it.
  TextLlmContext::firstMsgTokens = firstMsgTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,22 @@ class TextLlmContext: public LlmContext {
* @param first_msg_tokens - the first msg tokens.
*/
void setFirstMsgTokens(llama_pos firstMsgTokens) override;

/**
* The get last tool token count method. It returns the number of tool tokens
* from the last user message, as measured by the most recent tokenizeChat
* call: the size difference between the prompt tokenized with tools and the
* same prompt tokenized without them.
*
* The value is 0 when the calculation is disabled, when no tools were
* supplied, or when no positive difference could be measured.
*
* @return - the number of tool tokens.
*/
[[nodiscard]] llama_pos getLastToolTokenCount() const override;

/**
* The set calculate tool token count method. It enables/disables
* tool token count calculation for cache management. While disabled,
* tokenizeChat leaves the tool token count at 0.
*
* @param enabled - whether to calculate tool token count.
*/
void setCalculateToolTokenCount(bool enabled);
/**
* The set n_discarded method. It sets the n_discarded.
*
Expand Down Expand Up @@ -157,6 +173,8 @@ class TextLlmContext: public LlmContext {
llama_pos n_past = 0; // NOLINT(readability-identifier-naming)
llama_pos n_discarded = 0; // NOLINT(readability-identifier-naming)
llama_pos firstMsgTokens = 0; // NOLINT(readability-identifier-naming)
// Tool tokens measured by the latest tokenizeChat call; 0 when none counted.
llama_pos lastToolTokenCount_ = 0; // NOLINT(readability-identifier-naming)
// Whether tokenizeChat measures tool tokens at all; changed through
// setCalculateToolTokenCount.
bool calculateToolTokenCount_ = true; // NOLINT(readability-identifier-naming)
ThreadPoolPtr threadpool; // NOLINT(readability-identifier-naming)
ThreadPoolPtr threadpool_batch; // NOLINT(readability-identifier-naming)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ std::string getChatTemplateForModel(

// For Qwen3 models, use the fixed template. Picking this template is the
// expected path for these models, not a failure, so it is traced at DEBUG.
if (isQwen3Model(model)) {
  QLOG_IF(
      Priority::DEBUG, "[ChatTemplateUtils] Using CHANGED Qwen3 template\n");
  return getFixedQwen3Template();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,8 @@ namespace qvac_lib_inference_addon_llama {
namespace utils {

const char* getFixedQwen3Template() {
return R"({%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
return R"({%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
Expand Down Expand Up @@ -81,6 +68,15 @@ const char* getFixedQwen3Template() {
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if tools %}
{{- '<|im_start|>system\n' }}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- endif %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
Expand Down
Loading
Loading