diff --git a/packages/qvac-lib-infer-llamacpp-llm/CHANGELOG.md b/packages/qvac-lib-infer-llamacpp-llm/CHANGELOG.md index 628dbe2d8f..9c99acb184 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/CHANGELOG.md +++ b/packages/qvac-lib-infer-llamacpp-llm/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## [0.14.0] - 2026-03-19 + +### Added + +#### `tools_at_end` configuration for dynamic tool management in multi-turn conversations + +New `tools_at_end` configuration option (`"true"` or `"false"`, default: `"false"`) places tool definitions at the end of the prompt (after conversation history) instead of in the system prompt. This enables KV cache optimization for multi-turn conversations with dynamic tool sets, where tools change between turns. Currently supports Qwen3 models only. + +- **KV cache trimming**: After each turn, tools are automatically removed from the KV cache, preventing stale tool definitions from accumulating +- **Conversation history reuse**: History tokens are preserved in cache, saving recomputation on long conversations +- **Dynamic tool replacement**: Different tool sets can be used per turn without cache bloat from unused tools + ## [0.13.0] - 2026-03-18 ### Added diff --git a/packages/qvac-lib-infer-llamacpp-llm/CMakeLists.txt b/packages/qvac-lib-infer-llamacpp-llm/CMakeLists.txt index a4f15ed077..bc0ce71e2d 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/CMakeLists.txt +++ b/packages/qvac-lib-infer-llamacpp-llm/CMakeLists.txt @@ -74,6 +74,7 @@ endif() ${PROJECT_SOURCE_DIR}/addon/src/utils/BackendSelection.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/ChatTemplateUtils.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/Qwen3ReasoningUtils.cpp + ${PROJECT_SOURCE_DIR}/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/QwenTemplate.cpp ) @@ -118,6 +119,7 @@ if(BUILD_CLI) ${PROJECT_SOURCE_DIR}/addon/src/utils/BackendSelection.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/ChatTemplateUtils.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/Qwen3ReasoningUtils.cpp + ${PROJECT_SOURCE_DIR}/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp ${PROJECT_SOURCE_DIR}/addon/src/utils/QwenTemplate.cpp ) diff --git a/packages/qvac-lib-infer-llamacpp-llm/README.md b/packages/qvac-lib-infer-llamacpp-llm/README.md index dd7303127e..351c43ddba 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/README.md +++ b/packages/qvac-lib-infer-llamacpp-llm/README.md @@ -158,6 +158,7 @@ const config = { | presence_penalty | float | 0 | Presence penalty for sampling | | frequency_penalty | float | 0 | Frequency penalty for sampling | | tools | `"true"` or `"false"` | `"false"` | Enable tool calling with jinja templating | +| tools_at_end | `"true"` or `"false"` | `"false"` | Place tools at end of prompt ([details](./docs/tools-at-end.md)) | | verbosity | 0 – 3 (0=ERROR, 1=WARNING, 2=INFO, 3=DEBUG) | 0 | Logging verbosity level | | n_discarded | integer | 0 | Tokens to discard in sliding window context | | main-gpu | integer, `"integrated"`, or `"dedicated"` | — | GPU selection for multi-GPU systems | @@ -302,6 +303,8 @@ npm run quickstart - [LoRA Finetuning](./examples/finetune/simple-lora-finetune.js) – Basic LoRA finetuning. - [LoRA Finetuning Pause/Resume](./examples/finetune/simple-lora-finetune-pause-resume.js) – Pause and resume finetuning. - [LoRA Inference](./examples/simple-lora-inference.js) – Inference with a finetuned LoRA adapter. +- [Bench Tools Placement](./examples/benchToolsPlacement.js) – Benchmarks standard vs `tools_at_end` placement across multi-turn conversations. +- [Test Tool Removal](./examples/testToolRemoval.js) – Demonstrates dynamic tool addition and removal between turns. ## OCR with Vision-Language Models diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.cpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.cpp index 2408fe3c00..40a312e2d6 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.cpp @@ -308,7 +308,9 @@ void LlamaModel::init(bool acquireLock) { common_params params; std::optional adrenoVersion; - commonParamsParse(modelPath, configFilemap, params, adrenoVersion); + bool toolsAtEnd = false; + commonParamsParse( + modelPath, configFilemap, params, adrenoVersion, toolsAtEnd); const std::string errorWhenFailed = toString(UnableToLoadModel); auto streamedFiles = @@ -334,7 +336,8 @@ void LlamaModel::init(bool acquireLock) { snap->llmContext_ = createContext( std::string(constructionArgs_.projectionPath), params, - std::move(llamaInit)); + std::move(llamaInit), + toolsAtEnd); if (snap->configuredNDiscarded_ > 0 && snap->llmContext_) { snap->llmContext_->setNDiscarded(snap->configuredNDiscarded_); @@ -360,6 +363,14 @@ bool LlamaModel::isLoaded() { return static_cast(state_->llmContext_); } +llama_pos LlamaModel::getNPastBeforeTools() const { + std::shared_lock lock(stateMtx_); + if (state_->llmContext_) { + return state_->llmContext_->dynamicToolsState().nPastBeforeTools(); + } + return -1; +} + llama_context* LlamaModel::getContext() { if (!state_->llmContext_) { return nullptr; @@ -504,6 +515,11 @@ std::string LlamaModel::processPromptImpl(const Prompt& prompt) { std::string out; ResolvedPrompt resolved = resolveChatAndTools(prompt.input); + if (resolved.shouldResetAfterInference && + state_->llmContext_->getNPast() > 0) { + resetState(true); + } + if (resolved.chatMsgs.empty() && resolved.tools.empty()) { QLOG_IF( Priority::INFO, @@ -552,6 +568,18 @@ std::string LlamaModel::processPromptImpl(const Prompt& prompt) { if (!prompt.outputCallback) { out = oss.str(); } + auto& dts = state_->llmContext_->dynamicToolsState(); + if (dts.toolsAtEnd() && !resolved.tools.empty() && + dts.nPastBeforeTools() > 0 && + state_->llmContext_->getNPast() > dts.nPastBeforeTools()) { + state_->llmContext_->removeLastNTokens( + state_->llmContext_->getNPast() - dts.nPastBeforeTools()); + dts.reset(); + if (state_->llmContext_->getFirstMsgTokens() > + state_->llmContext_->getNPast()) { + state_->llmContext_->setFirstMsgTokens(state_->llmContext_->getNPast()); + } + } if (resolved.shouldResetAfterInference) { resetState(false); } @@ -589,7 +617,8 @@ qvac_lib_inference_addon_cpp::RuntimeStats LlamaModel::runtimeStats() const { void LlamaModel::commonParamsParse( const std::string& modelPath, std::unordered_map& configFilemap, - common_params& params, std::optional& outAdrenoVersion) { + common_params& params, std::optional& outAdrenoVersion, + bool& outToolsAtEnd) { std::vector configVector; @@ -632,6 +661,26 @@ void LlamaModel::commonParamsParse( configFilemap.erase(iter); } + // parse tools_at_end flag from config + if (auto iter = configFilemap.find("tools_at_end"); + iter != configFilemap.end()) { + std::string val = iter->second; + std::transform(val.begin(), val.end(), val.begin(), ::tolower); + outToolsAtEnd = (val == "true"); + configFilemap.erase(iter); + } + + if (outToolsAtEnd) { + auto arch = metadata_.tryGetString("general.architecture"); + if (!arch.has_value() || arch.value() != "qwen3") { + QLOG_IF( + Priority::WARNING, + "[LlamaModel] tools_at_end is only supported for Qwen3 models, " + "ignoring\n"); + outToolsAtEnd = false; + } + } + auto deviceIt = configFilemap.find("device"); if (deviceIt == configFilemap.end()) { std::string errorMsg = @@ -968,12 +1017,14 @@ void LlamaModel::resetState(bool resetStats) { std::unique_ptr LlamaModel::createContext( std::string&& projectionPath, common_params& params, - common_init_result&& llamaInit) { + common_init_result&& llamaInit, bool toolsAtEnd) { if (!projectionPath.empty()) { params.mmproj.path = std::move(projectionPath); - return std::make_unique(params, std::move(llamaInit)); + return std::make_unique( + params, std::move(llamaInit), toolsAtEnd); } - return std::make_unique(params, std::move(llamaInit)); + return std::make_unique( + params, std::move(llamaInit), toolsAtEnd); } bool LlamaModel::loadMedia(const std::vector& input) { diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.hpp index 214c797b22..3f677ae718 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.hpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlamaModel.hpp @@ -158,6 +158,14 @@ class LlamaModel : public IModel, public IModelAsyncLoad, public IModelCancel { */ bool isLoaded(); + /** + * Get the nPast position before tool evaluation. + * This is used to find the boundary in the KV cache after evaluating + * conversation tokens but before tool tokens. + * @return the nPast position, or -1 if not set. + */ + llama_pos getNPastBeforeTools() const; + void waitForLoadInitialization() final { std::shared_ptr localState; { @@ -233,7 +241,8 @@ class LlamaModel : public IModel, public IModelAsyncLoad, public IModelCancel { void commonParamsParse( const std::string& modelPath, std::unordered_map& configFilemap, - common_params& params, std::optional& outAdrenoVersion); + common_params& params, std::optional& outAdrenoVersion, + bool& outToolsAtEnd); /** * The Format prompt method. It formats the prompt json to chat messages. @@ -246,7 +255,8 @@ class LlamaModel : public IModel, public IModelAsyncLoad, public IModelCancel { void resetState(bool resetStats = true); std::unique_ptr createContext( std::string&& projectionPath, common_params& params, - common_init_result&& llamaInit); + common_init_result&& llamaInit, bool toolsAtEnd); + bool loadMedia(const std::vector& input); void setInitLoader( diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlmContext.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlmContext.hpp index f60fec9809..dc8e8578c5 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlmContext.hpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/LlmContext.hpp @@ -84,32 +84,58 @@ class LlamaBatch { const llama_batch* operator->() const noexcept { return &batch_; } }; -struct ThreadPoolDeleter{ - void operator()(ggml_threadpool* ptr) { - if (ptr != nullptr) { - auto* cpuDev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpuDev == nullptr) { - throw qvac_errors::StatusError( - ADDON_ID, toString(NoBackendFound), "no CPU backend found"); - } - auto* reg = ggml_backend_dev_backend_reg(cpuDev); - void* procAddr = - ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free"); - if (procAddr == nullptr) { - throw qvac_errors::StatusError( - ADDON_ID, - toString(UnableToDeleteThreadPool), - "Failed to get ggml_threadpool_free function address"); - } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* ggmlThreadpoolFreeFn = - reinterpret_cast(procAddr); - ggmlThreadpoolFreeFn(ptr); +struct ThreadPoolDeleter { + void operator()(ggml_threadpool* ptr) { + if (ptr != nullptr) { + auto* cpuDev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpuDev == nullptr) { + throw qvac_errors::StatusError( + ADDON_ID, toString(NoBackendFound), "no CPU backend found"); + } + auto* reg = ggml_backend_dev_backend_reg(cpuDev); + void* procAddr = + ggml_backend_reg_get_proc_address(reg, "ggml_threadpool_free"); + if (procAddr == nullptr) { + throw qvac_errors::StatusError( + ADDON_ID, + toString(UnableToDeleteThreadPool), + "Failed to get ggml_threadpool_free function address"); } + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + auto* ggmlThreadpoolFreeFn = + reinterpret_cast(procAddr); + ggmlThreadpoolFreeFn(ptr); } + } }; using ThreadPoolPtr = std::unique_ptr; +class DynamicToolsState { +public: + void setToolsAtEnd(bool v) { toolsAtEnd_ = v; } + [[nodiscard]] bool toolsAtEnd() const { return toolsAtEnd_; } + [[nodiscard]] llama_pos nPastBeforeTools() const { return nPastBeforeTools_; } + void setNPastBeforeTools(llama_pos pos) { nPastBeforeTools_ = pos; } + void recordToolBoundary(llama_pos nPast, llama_pos totalTokens) { + if (toolsAtEnd_ && nConversationOnlyTokens_ > 0) { + nPastBeforeTools_ = nPast - (totalTokens - nConversationOnlyTokens_); + } + } + void setConversationOnlyTokens(llama_pos n) { nConversationOnlyTokens_ = n; } + [[nodiscard]] llama_pos conversationOnlyTokens() const { + return nConversationOnlyTokens_; + } + void reset() { + nConversationOnlyTokens_ = 0; + nPastBeforeTools_ = -1; + } + +private: + bool toolsAtEnd_ = false; + llama_pos nConversationOnlyTokens_ = 0; + llama_pos nPastBeforeTools_ = -1; +}; + class LlmContext { // NOLINT(cppcoreguidelines-special-member-functions) public: LlmContext() = default; @@ -211,6 +237,11 @@ class LlmContext { // NOLINT(cppcoreguidelines-special-member-functions) */ virtual void setNDiscarded(llama_pos nDiscarded) = 0; + DynamicToolsState& dynamicToolsState() { return dynamicToolsState_; } + [[nodiscard]] const DynamicToolsState& dynamicToolsState() const { + return dynamicToolsState_; + } + /** * Get the number of context slides (discards) that have occurred. */ @@ -276,6 +307,7 @@ class LlmContext { // NOLINT(cppcoreguidelines-special-member-functions) * */ virtual void resetMedia() {}; -}; - +private: + DynamicToolsState dynamicToolsState_; +}; diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.cpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.cpp index 07678cbcd8..afc9f5811e 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.cpp @@ -1,6 +1,7 @@ #include "MtmdLlmContext.hpp" #include +#include #include #include @@ -20,9 +21,11 @@ using namespace qvac_lib_inference_addon_llama::utils; // NOLINTNEXTLINE(readability-function-cognitive-complexity) MtmdLlmContext::MtmdLlmContext( - common_params& commonParams, common_init_result&& llamaInit) + common_params& commonParams, common_init_result&& llamaInit, + bool toolsAtEnd) : llamaInit_(std::move(llamaInit)), params_(commonParams), model_(llamaInit_.model.get()), lctx_(llamaInit_.context.get()) { + dynamicToolsState().setToolsAtEnd(toolsAtEnd); if (model_ == nullptr) { throw qvac_errors::StatusError( @@ -40,7 +43,8 @@ MtmdLlmContext::MtmdLlmContext( vocab_ = llama_model_get_vocab(model_); - std::string chatTemplate = getChatTemplate(model_, params_); + std::string chatTemplate = + getChatTemplate(model_, params_, dynamicToolsState().toolsAtEnd()); tmpls_ = common_chat_templates_init(model_, chatTemplate); smpl_.reset(common_sampler_init(model_, params_.sampling)); @@ -153,6 +157,7 @@ void MtmdLlmContext::tokenizeChat( bool addSpecial = false; if (nPast_ == 0 && !isCacheLoaded) { + dynamicToolsState().reset(); isLastMessageFromUser = true; addSpecial = true; } else if (nPast_ > 0) { @@ -199,6 +204,40 @@ void MtmdLlmContext::tokenizeChat( throw qvac_errors::StatusError(ADDON_ID, toString(EncoderFailed), errorMsg); } + if (dynamicToolsState().toolsAtEnd() && !tools.empty()) { + inputs.tools = {}; + inputs.add_generation_prompt = false; + inputs.use_jinja = params_.use_jinja; + auto promptNoTools = getPrompt(tmpls_.get(), inputs); + + if (!promptNoTools.empty()) { + mtmd_input_text textNoTools; + textNoTools.text = promptNoTools.c_str(); + textNoTools.add_special = addSpecial; + textNoTools.parse_special = true; + + mtmd::input_chunks chunksNoTools(mtmd_input_chunks_init()); + int32_t resNoTools = mtmd_tokenize( + ctxVision_.get(), + chunksNoTools.ptr.get(), + &textNoTools, + bitmapsCPtr.data(), + bitmapsCPtr.size()); + + if (resNoTools == 0) { + dynamicToolsState().setConversationOnlyTokens( + mtmd_helper_get_n_tokens(chunksNoTools.ptr.get())); + assert( + dynamicToolsState().conversationOnlyTokens() <= + static_cast( + mtmd_helper_get_n_tokens(chunks.ptr.get())) && + "conversation-only tokens exceeds total tokens"); + } + } + } else { + dynamicToolsState().setConversationOnlyTokens(0); + } + resetMedia(); } @@ -315,6 +354,8 @@ bool MtmdLlmContext::evalMessageWithTools( nDiscarded_ = ctxSize - firstMsgTokens_ - 1; } } + dynamicToolsState().recordToolBoundary( + nPast_, static_cast(nTokens)); return true; } @@ -550,6 +591,8 @@ void MtmdLlmContext::loadMedia(const std::string& fname) { } void MtmdLlmContext::resetState(bool resetStats) { + + dynamicToolsState().reset(); // Reset the n_past nPast_ = 0; diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.hpp index 7d6f516194..2f25120df8 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.hpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/MtmdLlmContext.hpp @@ -9,7 +9,7 @@ #include "LlmContext.hpp" #include "qvac-lib-inference-addon-cpp/Logger.hpp" -class MtmdLlmContext: public LlmContext { +class MtmdLlmContext : public LlmContext { public: /** * The constructor. @@ -18,7 +18,9 @@ class MtmdLlmContext: public LlmContext { * @param _llama_init - The result of initializing/loading the model using * .gguf file(s) */ - MtmdLlmContext(common_params& commonParams, common_init_result&& llamaInit); + MtmdLlmContext( + common_params& commonParams, common_init_result&& llamaInit, + bool toolsAtEnd = false); /** * The destructor. @@ -38,8 +40,8 @@ class MtmdLlmContext: public LlmContext { * @return - true if successful, false if inference is stopped. */ bool evalMessage( - const std::vector& chatMsgs, - bool isCacheLoaded, bool prefill) override; + const std::vector& chatMsgs, bool isCacheLoaded, + bool prefill) override; /** * The eval message with tools method. It evaluates the message with tools and @@ -165,11 +167,11 @@ class MtmdLlmContext: public LlmContext { void resetMedia() override; private: - /** - * The check antiprompt method. It checks the antiprompt. - * - * @return - true if the antiprompt is found, false otherwise. - */ + /** + * The check antiprompt method. It checks the antiprompt. + * + * @return - true if the antiprompt is found, false otherwise. + */ bool checkAntiprompt(); /** @@ -217,5 +219,3 @@ class MtmdLlmContext: public LlmContext { qvac_lib_inference_addon_llama::UTF8TokenBuffer utf8Buffer_; std::atomic stopGeneration_ = false; }; - - diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.cpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.cpp index dacb9cf6a0..9c1da490e2 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.cpp @@ -1,6 +1,7 @@ #include "TextLlmContext.hpp" #include +#include #include #include @@ -23,8 +24,10 @@ using namespace qvac_lib_inference_addon_llama::utils; // NOLINTNEXTLINE(readability-function-cognitive-complexity) TextLlmContext::TextLlmContext( - common_params& commonParams, common_init_result&& llamaInit) + common_params& commonParams, common_init_result&& llamaInit, + bool toolsAtEnd) : llamaInit_(std::move(llamaInit)), params_(commonParams) { + dynamicToolsState().setToolsAtEnd(toolsAtEnd); { model_ = llamaInit_.model.get(); @@ -49,7 +52,8 @@ TextLlmContext::TextLlmContext( lctx_, reasoningState_); } - std::string chatTemplate = getChatTemplate(model_, params_); + std::string chatTemplate = + getChatTemplate(model_, params_, dynamicToolsState().toolsAtEnd()); tmpls_ = common_chat_templates_init(model_, chatTemplate); smpl_.reset(common_sampler_init(model_, params_.sampling)); @@ -189,6 +193,7 @@ void TextLlmContext::tokenizeChat( bool addSpecial = false; if (nPast_ == 0 && !isCacheLoaded) { + dynamicToolsState().reset(); isLastMessageFromUser = true; addSpecial = true; } else if (nPast_ > 0) { @@ -212,6 +217,22 @@ void TextLlmContext::tokenizeChat( if (!prompt.empty()) { inputTokens = common_tokenize(lctx_, prompt, addSpecial, true); + + if (dynamicToolsState().toolsAtEnd() && !tools.empty()) { + inputs.tools = {}; + inputs.add_generation_prompt = false; + inputs.use_jinja = params_.use_jinja; + auto promptNoTools = getPrompt(tmpls_.get(), inputs); + auto tokensNoTools = + common_tokenize(lctx_, promptNoTools, addSpecial, true); + dynamicToolsState().setConversationOnlyTokens(tokensNoTools.size()); + assert( + dynamicToolsState().conversationOnlyTokens() <= + static_cast(inputTokens.size()) && + "conversation-only tokens exceeds total tokens"); + } else { + dynamicToolsState().setConversationOnlyTokens(0); + } } else { std::string errorMsg = string_format( "[TextLlm] %s: formatted chat prompt is empty\n", __func__); @@ -266,7 +287,8 @@ bool TextLlmContext::evalMessageWithTools( if (nTokens >= llama_n_ctx(lctx_)) { std::string errorMsg = string_format( - "[TextLlm] context overflow at prefill step: prompt tokens %ld, max context tokens %d\n", + "[TextLlm] context overflow at prefill step: prompt tokens %ld, max " + "context tokens %d\n", nTokens, llama_n_ctx(lctx_)); throw qvac_errors::StatusError( @@ -362,6 +384,8 @@ bool TextLlmContext::evalMessageWithTools( nDiscarded_ = ctxSize - firstMsgTokens_ - 1; } } + dynamicToolsState().recordToolBoundary( + nPast_, static_cast(inputTokens.size())); return true; } @@ -529,6 +553,8 @@ void TextLlmContext::stop() { stopGeneration_.store(true); } void TextLlmContext::resetState(bool resetStats) { // Reset the n_past + + dynamicToolsState().reset(); nPast_ = 0; // Reset the first msg token length diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.hpp index 8d0c29eebc..37294cd073 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.hpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/model-interface/TextLlmContext.hpp @@ -18,7 +18,9 @@ class TextLlmContext : public LlmContext { TextLlmContext(TextLlmContext&&) = delete; TextLlmContext& operator=(TextLlmContext&&) = delete; // Constructor - TextLlmContext(common_params& commonParams, common_init_result&& llamaInit); + TextLlmContext( + common_params& commonParams, common_init_result&& llamaInit, + bool toolsAtEnd = false); // Destructor ~TextLlmContext() override = default; @@ -32,8 +34,8 @@ class TextLlmContext : public LlmContext { * @return - true if successful, false if inference is stopped. */ bool evalMessage( - const std::vector& chatMsgs, - bool isCacheLoaded, bool prefill) override; + const std::vector& chatMsgs, bool isCacheLoaded, + bool prefill) override; /** * The eval message with tools method. It evaluates the message with tools and diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.cpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.cpp index 802c00d4e1..9f462fc60f 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.cpp @@ -4,6 +4,7 @@ #include +#include "Qwen3ToolsDynamicTemplate.hpp" #include "QwenTemplate.hpp" #include "utils/LoggingMacros.hpp" @@ -58,27 +59,27 @@ bool isQwen3Model(const ::llama_model* model) { } std::string getChatTemplateForModel( - const ::llama_model* model, const std::string& manualOverride) { - // If manual override is provided, use it as-is + const ::llama_model* model, const std::string& manualOverride, + bool toolsAtEnd) { if (!manualOverride.empty()) { return manualOverride; } - // For Qwen3 models, use the fixed template if (isQwen3Model(model)) { - return getFixedQwen3Template(); + return toolsAtEnd ? getToolsDynamicQwen3Template() + : getFixedQwen3Template(); } - // For other models, no override needed return ""; } -std::string -getChatTemplate(const ::llama_model* model, const common_params& params) { +std::string getChatTemplate( + const ::llama_model* model, const common_params& params, bool toolsAtEnd) { // Use fixed Qwen3 template if model is Qwen3 and Jinja is enabled std::string chatTemplate = params.chat_template; if (params.use_jinja) { - chatTemplate = getChatTemplateForModel(model, params.chat_template); + chatTemplate = + getChatTemplateForModel(model, params.chat_template, toolsAtEnd); if (!chatTemplate.empty() && chatTemplate != params.chat_template) { QLOG_IF( Priority::INFO, "[ChatTemplateUtils] Using fixed Qwen3 template\n"); diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.hpp index 0d73101918..1b376faf27 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.hpp +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/ChatTemplateUtils.hpp @@ -17,18 +17,20 @@ bool isQwen3Model(const ::llama_model* model); /** * @brief Gets the appropriate chat template for a model * - * For Qwen3 models, returns the fixed template from QwenTemplate.hpp. + * For Qwen3 models, returns the fixed template or tools-at-end template + * based on the toolsAtEnd flag. * For other models, returns the manual override or empty string. */ std::string getChatTemplateForModel( - const ::llama_model* model, const std::string& manualOverride); + const ::llama_model* model, const std::string& manualOverride, + bool toolsAtEnd); /** * @brief Gets the chat template for a model, applying Qwen3 fixes if Jinja is * enabled */ -std::string -getChatTemplate(const ::llama_model* model, const common_params& params); +std::string getChatTemplate( + const ::llama_model* model, const common_params& params, bool toolsAtEnd); /** * @brief Applies chat templates to generate a prompt, with fallback handling diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp new file mode 100644 index 0000000000..435d236a20 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp @@ -0,0 +1,81 @@ +#include "QwenTemplate.hpp" + +namespace qvac_lib_inference_addon_llama { +namespace utils { + +const char* getToolsDynamicQwen3Template() { + return R"({%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} +{%- endif %} +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set parts = message.content.split('') %} + {%- set content = parts[-1] | trim %} + {%- set think_parts = parts[0].split('') %} + {%- set reasoning_content = think_parts[-1] | trim %} + {%- endif %} + {%- endif %} + {%- if reasoning_content %} + {{- '<|im_start|>' + message.role + '\n\n' + (reasoning_content | trim) + '\n\n\n' + (content | trim) }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- endif %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %})"; +} + +} // namespace utils +} // namespace qvac_lib_inference_addon_llama diff --git a/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.hpp b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.hpp new file mode 100644 index 0000000000..024e0416bf --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/addon/src/utils/Qwen3ToolsDynamicTemplate.hpp @@ -0,0 +1,14 @@ +#pragma once + +namespace qvac_lib_inference_addon_llama { +namespace utils { + +// Uses fxed Qwen3 chat template as a base +// see QwenTemplate.hpp +// +// Changes: Tools are put in additional system prompt at the end +// in order to apply new (different) tools on each user prompt +const char* getToolsDynamicQwen3Template(); + +} // namespace utils +} // namespace qvac_lib_inference_addon_llama diff --git a/packages/qvac-lib-infer-llamacpp-llm/docs/tools-at-end.md b/packages/qvac-lib-infer-llamacpp-llm/docs/tools-at-end.md new file mode 100644 index 0000000000..9831eabee9 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/docs/tools-at-end.md @@ -0,0 +1,55 @@ +# Tools at End of Prompt + +## Overview + +The `tools_at_end` configuration option places tool definitions at the end of the prompt (after the conversation history) instead of the default position (typically inside the system prompt). This enables KV cache optimization for multi-turn conversations with dynamic tool sets. + +## Configuration + +```js +const config = { + tools: 'true', + tools_at_end: 'true' +} +``` + +## Model Support + +Currently `tools_at_end` is only supported for **Qwen3** models. If enabled on a non-Qwen3 model, the flag is silently ignored and a warning is logged. + +## Usage Requirements + +### Multi-turn Conversation Pattern + +When using `tools_at_end`, consumers must follow a specific pattern: + +1. **Include prior response**: Pass the assistant's previous response (including any `` or `` blocks) back alongside the new user message. + +2. **Full history each turn**: Since the KV cache is trimmed after each turn, the full conversation history must be re-provided. + +``` +Turn 1: [user-q-1] + [tools-1] → [response-1] +Turn 2: [response-1] + [user-q-2] + [tools-2] → [response-2] + (tools-1 is automatically trimmed from cache) +``` + +3. **Strip stale tool blocks**: Remove `` blocks from prior responses when tools have changed to prevent model from pattern-matching on removed tools. + +## Performance Characteristics + +| Overhead Type | Impact | Note | +|---------------|--------|------| +| Double tokenization | ~2% | Required to calculate tool token boundary | +| Tools prefill | Up to 100% | Tools re-evaluated every turn regardless of change | + +## When to Use + +**Use `tools_at_end` when:** +- Long conversations with many turns (cache hit on history saves significant compute) +- Frequent tool replacement between turns (e.g., tools A → tools B → tools A) + +**Use standard `tools` config when:** +- Short conversations or single-turn tool calls +- Tools remain the same across many turns + +The feature provides net benefit when conversation history cache savings outweigh the tools prefill overhead. \ No newline at end of file diff --git a/packages/qvac-lib-infer-llamacpp-llm/examples/benchToolsPlacement.js b/packages/qvac-lib-infer-llamacpp-llm/examples/benchToolsPlacement.js new file mode 100644 index 0000000000..2364281ea4 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/examples/benchToolsPlacement.js @@ -0,0 +1,539 @@ +'use strict' + +const LlmLlamacpp = require('../index') +const FilesystemDL = require('@qvac/dl-filesystem') +const path = require('bare-path') +const fs = require('bare-fs') +const process = require('bare-process') +const os = require('bare-os') +const { downloadModel } = require('./utils') + +// ─── Configuration ────────────────────────────────────────────────────────── + +const isDarwinX64 = os.platform() === 'darwin' && os.arch() === 'x64' +const isLinuxArm64 = os.platform() === 'linux' && os.arch() === 'arm64' +const useCpu = isDarwinX64 || isLinuxArm64 + +const MODEL = { + name: 'Qwen3-1.7B-Q4_0.gguf', + url: 'https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_0.gguf' +} + +const NUM_TURNS = 20 + +// ─── Tool definitions ─────────────────────────────────────────────────────── + +const TOOL_WEATHER = { + type: 'function', + name: 'getWeather', + description: 'Get current weather for a city', + parameters: { + type: 'object', + properties: { + city: { type: 'string', description: 'City name' }, + units: { type: 'string', enum: ['celsius', 'fahrenheit'], description: 'Temperature units' } + }, + required: ['city'] + } +} + +const TOOL_SEARCH = { + type: 'function', + name: 'searchWeb', + description: 'Search the web for information', + parameters: { + type: 'object', + properties: { + query: { type: 'string', description: 'Search query' }, + maxResults: { type: 'integer', minimum: 1, maximum: 10, description: 'Max results' } + }, + required: ['query'] + } +} + +const TOOL_CALCULATOR = { + type: 'function', + name: 'calculate', + description: 'Perform a math calculation', + parameters: { + type: 'object', + properties: { + expression: { type: 'string', description: 'Math expression to evaluate' } + }, + required: ['expression'] + } +} + +const TOOL_TRANSLATE = { + type: 'function', + name: 'translateText', + description: 'Translate text to another language', + parameters: { + type: 'object', + properties: { + text: { type: 'string', description: 'Text to translate' }, + targetLang: { type: 'string', description: 'Target language code (e.g. fr, es, de)' } + }, + required: ['text', 'targetLang'] + } +} + +const TOOL_EMAIL = { + type: 'function', + name: 'sendEmail', + description: 'Send an email to a recipient', + parameters: { + type: 'object', + properties: { + to: { type: 'string', description: 'Recipient email address' }, + subject: { type: 'string', description: 'Email subject' }, + body: { type: 'string', description: 'Email body content' } + }, + required: ['to', 'subject', 'body'] + } +} + +const TOOL_REMINDER = { + type: 'function', + name: 'setReminder', + description: 'Set a reminder for a specific time', + parameters: { + type: 'object', + properties: { + message: { type: 'string', description: 'Reminder message' }, + time: { type: 'string', description: 'Time for the reminder (ISO 8601)' } + }, + required: ['message', 'time'] + } +} + +// Different tools per turn (for scenario C — dynamic tools) +const DYNAMIC_TOOLS_PER_TURN = [ + [TOOL_WEATHER, TOOL_SEARCH], // Turn 1: weather + search + [TOOL_CALCULATOR], // Turn 2: calculator only + [TOOL_TRANSLATE], // Turn 3: translate only + [TOOL_EMAIL, TOOL_REMINDER], // Turn 4: email + reminder + [TOOL_WEATHER], // Turn 5: weather only + [TOOL_SEARCH, TOOL_CALCULATOR], // Turn 6: search + calculator + [TOOL_TRANSLATE, TOOL_EMAIL], // Turn 7: translate + email + [TOOL_REMINDER, TOOL_WEATHER, TOOL_SEARCH], // Turn 8: reminder + weather + search + [TOOL_CALCULATOR, TOOL_TRANSLATE], // Turn 9: calculator + translate + [TOOL_EMAIL], // Turn 10: email only + [TOOL_WEATHER, TOOL_CALCULATOR], // Turn 11: weather + calculator + [TOOL_SEARCH], // Turn 12: search only + [TOOL_REMINDER, TOOL_TRANSLATE], // Turn 13: reminder + translate + [TOOL_WEATHER, TOOL_EMAIL], // Turn 14: weather + email + [TOOL_CALCULATOR, TOOL_SEARCH], // Turn 15: calculator + search + [TOOL_TRANSLATE], // Turn 16: translate only + [TOOL_REMINDER, TOOL_EMAIL, TOOL_WEATHER], // Turn 17: reminder + email + weather + [TOOL_SEARCH, TOOL_TRANSLATE], // Turn 18: search + translate + [TOOL_CALCULATOR], // Turn 19: calculator only + [TOOL_WEATHER, TOOL_REMINDER] // Turn 20: weather + reminder +] + +const CONVERSATION_TURNS_DYNAMIC = [ + { user: 'What is the weather in Paris?' }, + { user: 'Calculate 156 * 23' }, + { user: 'Translate "hello world" to French' }, + { user: 'Send an email to bob@example.com about the meeting tomorrow' }, + { user: 'What is the weather in London?' }, + { user: 'Search for AI news and calculate 999 / 3' }, + { user: 'Translate "good morning" to Spanish and email the result to alice@example.com' }, + { user: 'Set a reminder to check the weather in Berlin tomorrow and search for flight deals' }, + { user: 'Calculate 2^10 and translate the result to German' }, + { user: 'Send an email to team@example.com with a summary of today\'s tasks' }, + { user: 'What is the weather in Tokyo and calculate 42 * 17' }, + { user: 'Search for latest Python tutorials' }, + { user: 'Set a reminder for lunch at noon and translate "thank you" to Japanese' }, + { user: 'What is the weather in Sydney and email the forecast to weather@example.com' }, + { user: 'Calculate the square root of 144 and search for math resources' }, + { user: 'Translate "goodbye" to Italian' }, + { user: 'Set a reminder to call the dentist, email jane@example.com about it, and check weather in Rome' }, + { user: 'Search for healthy recipes and translate the top result to Portuguese' }, + { user: 'Calculate 365 * 24' }, + { user: 'What is the weather in Berlin and set a reminder to pack an umbrella' } +] + +// ─── Tool call extraction & validation ────────────────────────────────────── + +function stripInternalBlocks (text) { + return text + .replace(/[\s\S]*?<\/think>/g, '') + .replace(/[\s\S]*?<\/tool_call>/g, '') + .trim() +} + +function extractToolCalls (response) { + const toolCalls = [] + const toolCallRegex = /([\s\S]*?)<\/tool_call>/g + let match + while ((match = toolCallRegex.exec(response)) !== null) { + try { + const parsed = JSON.parse(match[1].trim()) + toolCalls.push(parsed.name || parsed.function?.name || 'unknown') + } catch (_) {} + } + return toolCalls +} + +function validateToolCalls (turnIndex, output, availableTools) { + const calledTools = extractToolCalls(output) + const availableNames = availableTools.map(t => t.name) + const violations = [] + + for (const called of calledTools) { + if (!availableNames.includes(called)) { + violations.push(called) + } + } + + const status = violations.length === 0 ? 'OK' : 'VIOLATION' + return { + status, + calledTools, + availableNames, + violations + } +} + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +function makeBaseConfig (toolsAtEnd) { + return { + device: useCpu ? 'cpu' : 'gpu', + gpu_layers: '999', + ctx_size: '4096', + n_predict: '256', + temp: '0.1', + seed: '1', + verbosity: '0', + tools: 'true', + tools_at_end: toolsAtEnd ? 'true' : 'false' + } +} + +async function loadModel (dirPath, modelName, config) { + const loader = new FilesystemDL({ dirPath }) + const model = new LlmLlamacpp({ + loader, + modelName, + diskPath: dirPath, + logger: console, + opts: { stats: true } + }, config) + await model.load() + return { model, loader } +} + +async function runAndCollect (model, prompt) { + const response = await model.run(prompt) + const chunks = [] + await response + .onUpdate(data => { chunks.push(data) }) + .await() + return { + output: chunks.join(''), + stats: response.stats + } +} + +function hrMs (hrtime) { + return (hrtime[0] * 1e3 + hrtime[1] / 1e6).toFixed(2) +} + +function cleanCache (cachePath) { + try { fs.unlinkSync(cachePath) } catch (_) {} +} + +// ─── Generic scenario runner ──────────────────────────────────────────────── + +async function runScenario (dirPath, modelName, opts) { + const { name, toolsAtEnd, dynamicTools, conversationTurns, getToolsForTurn, cacheName } = opts + + console.log('\n' + '='.repeat(70)) + console.log(name) + console.log('='.repeat(70)) + + const config = makeBaseConfig(toolsAtEnd) + const { model, loader } = await loadModel(dirPath, modelName, config) + const cachePath = path.join(dirPath, cacheName) + cleanCache(cachePath) + + const turnStats = [] + const toolValidations = [] + let lastAssistantResponse = null + // For tools_in_system with dynamic tools: track full conversation history for replay + const conversationHistory = [] + + try { + for (let i = 0; i < NUM_TURNS; i++) { + const turn = conversationTurns[i] + const turnTools = getToolsForTurn(i) + let prompt + + if (toolsAtEnd) { + // tools_at_end: session cache + re-send last assistant response + new user + tools + prompt = [ + { role: 'session', content: cachePath }, + ...(i === 0 + ? [{ role: 'system', content: 'You are a helpful assistant.' }, { role: 'user', content: turn.user }] + : [ + ...(lastAssistantResponse ? [{ role: 'assistant', content: lastAssistantResponse }] : []), + { role: 'user', content: turn.user } + ]), + ...turnTools + ] + } else if (dynamicTools) { + // tools_in_system with changing tools: reset cache and replay full history with new tools + prompt = [ + { role: 'session', content: cachePath }, + ...(i > 0 ? [{ role: 'session', content: 'reset' }] : []), + { role: 'system', content: 'You are a helpful assistant.' }, + ...conversationHistory, + { role: 'user', content: turn.user }, + ...turnTools + ] + } else { + // tools_in_system with same tools: session cache + only new user msg (tools cached from turn 1) + prompt = [ + { role: 'session', content: cachePath }, + ...(i === 0 + ? [{ role: 'system', content: 'You are a helpful assistant.' }, { role: 'user', content: turn.user }] + : [{ role: 'user', content: turn.user }]), + ...(i === 0 ? turnTools : []) + ] + } + + const t0 = process.hrtime() + const result = await runAndCollect(model, prompt) + const elapsed = process.hrtime(t0) + lastAssistantResponse = stripInternalBlocks(result.output) + + // Track history for replay in tools_in_system dynamic mode + conversationHistory.push({ role: 'user', content: turn.user }) + conversationHistory.push({ role: 'assistant', content: stripInternalBlocks(result.output) }) + + const validation = validateToolCalls(i, result.output, turnTools) + toolValidations.push(validation) + + turnStats.push({ + turn: i + 1, + wallMs: hrMs(elapsed), + promptTokens: result.stats?.promptTokens || 0, + generatedTokens: result.stats?.generatedTokens || 0, + cacheTokens: result.stats?.CacheTokens || 0, + ttft: result.stats?.TTFT || 0, + tps: result.stats?.TPS || 0 + }) + + const toolStatus = validation.status === 'OK' ? 'OK' : `VIOLATION: called [${validation.violations.join(', ')}]` + const calledStr = validation.calledTools.length > 0 ? validation.calledTools.join(', ') : 'none' + const availStr = validation.availableNames.join(', ') + + console.log( + ` Turn ${i + 1}: wall=${hrMs(elapsed)}ms prompt=${turnStats[i].promptTokens} ` + + `gen=${turnStats[i].generatedTokens} cache=${turnStats[i].cacheTokens} ` + + `TTFT=${turnStats[i].ttft}ms TPS=${turnStats[i].tps}` + ) + console.log( + ` tools=[${availStr}] called=[${calledStr}] validation=${toolStatus}` + ) + } + } finally { + await model.unload() + await loader.close() + cleanCache(cachePath) + } + + return { turnStats, toolValidations } +} + +// ─── Summary ──────────────────────────────────────────────────────────────── + +function printComparison (labelA, statsA, labelB, statsB) { + console.log('\n' + '='.repeat(80)) + console.log(`COMPARISON: ${labelA} (A) vs ${labelB} (B)`) + console.log('='.repeat(80)) + console.log('') + console.log('Turn | Wall A (ms) | Wall B (ms) | Δ ms | Prompt A | Prompt B | Cache A | Cache B | TTFT A | TTFT B') + console.log('-----|-------------|-------------|----------|----------|----------|---------|---------|---------|--------') + + let totalA = 0 + let totalB = 0 + + for (let i = 0; i < statsA.length; i++) { + const a = statsA[i] + const b = statsB[i] + const delta = (parseFloat(a.wallMs) - parseFloat(b.wallMs)).toFixed(2) + totalA += parseFloat(a.wallMs) + totalB += parseFloat(b.wallMs) + + const ttftA = typeof a.ttft === 'number' ? a.ttft.toFixed(0) : String(a.ttft) + const ttftB = typeof b.ttft === 'number' ? b.ttft.toFixed(0) : String(b.ttft) + + console.log( + ` ${a.turn} ` + + `| ${a.wallMs.padStart(11)} ` + + `| ${b.wallMs.padStart(11)} ` + + `| ${delta.padStart(8)} ` + + `| ${String(a.promptTokens).padStart(8)} ` + + `| ${String(b.promptTokens).padStart(8)} ` + + `| ${String(a.cacheTokens).padStart(7)} ` + + `| ${String(b.cacheTokens).padStart(7)} ` + + `| ${ttftA.padStart(7)} ` + + `| ${ttftB.padStart(7)}` + ) + } + + console.log('-----|-------------|-------------|----------|----------|----------|---------|---------|---------|--------') + console.log( + ' TOT ' + + `| ${totalA.toFixed(2).padStart(11)} ` + + `| ${totalB.toFixed(2).padStart(11)} ` + + `| ${(totalA - totalB).toFixed(2).padStart(8)} |` + ) + console.log('') + + const pctDiff = ((totalA - totalB) / totalB * 100).toFixed(1) + if (totalA > totalB) { + console.log(` → A is ${pctDiff}% SLOWER overall (${(totalA - totalB).toFixed(0)}ms extra across ${NUM_TURNS} turns)`) + } else { + console.log(` → A is ${Math.abs(pctDiff)}% FASTER overall (${(totalB - totalA).toFixed(0)}ms saved across ${NUM_TURNS} turns)`) + } +} + +function printToolValidationSummary (label, validations) { + console.log(`\n─── Tool Call Validation: ${label} ───`) + let allOk = true + for (let i = 0; i < validations.length; i++) { + const v = validations[i] + const icon = v.status === 'OK' ? 'PASS' : 'FAIL' + if (v.status !== 'OK') allOk = false + + if (v.calledTools.length === 0) { + console.log(` Turn ${i + 1} [${icon}]: no tool calls (available: ${v.availableNames.join(', ')})`) + } else if (v.violations.length > 0) { + console.log(` Turn ${i + 1} [${icon}]: called [${v.calledTools.join(', ')}] available [${v.availableNames.join(', ')}] STALE TOOLS USED: [${v.violations.join(', ')}]`) + } else { + console.log(` Turn ${i + 1} [${icon}]: called [${v.calledTools.join(', ')}] (available: ${v.availableNames.join(', ')})`) + } + } + console.log(` Result: ${allOk ? 'ALL PASSED — no stale/trimmed tools were called' : 'FAILURES DETECTED — model called tools that should have been trimmed'}`) +} + +// ─── Main ─────────────────────────────────────────────────────────────────── + +async function main () { + console.log('Benchmark: tools_at_end vs tools_in_system — performance & correctness') + console.log(`Model: ${MODEL.name}`) + console.log(`Turns: ${NUM_TURNS}`) + console.log(`Device: ${useCpu ? 'CPU' : 'GPU'}`) + + const [modelName, dirPath] = await downloadModel(MODEL.url, MODEL.name) + + // // ── Scenario A: tools_at_end, same tools every turn ── + // const resultA = await runScenario(dirPath, modelName, { + // name: 'SCENARIO A: tools_at_end = true, SAME tools every turn', + // toolsAtEnd: true, + // conversationTurns: CONVERSATION_TURNS_FIXED, + // getToolsForTurn: () => FIXED_TOOLS, + // cacheName: 'bench-A-at-end-same.bin' + // }) + + // // ── Scenario B: tools_in_system (standard), same tools every turn ── + // const resultB = await runScenario(dirPath, modelName, { + // name: 'SCENARIO B: tools_at_end = false (standard), SAME tools every turn', + // toolsAtEnd: false, + // conversationTurns: CONVERSATION_TURNS_FIXED, + // getToolsForTurn: () => FIXED_TOOLS, + // cacheName: 'bench-B-in-system-same.bin' + // }) + + // ── Scenario C: tools_at_end with dynamic tools ── + const resultC = await runScenario(dirPath, modelName, { + name: 'SCENARIO C: tools_at_end = true, DIFFERENT tools each turn', + toolsAtEnd: true, + dynamicTools: true, + conversationTurns: CONVERSATION_TURNS_DYNAMIC, + getToolsForTurn: (i) => DYNAMIC_TOOLS_PER_TURN[i], + cacheName: 'bench-C-at-end-dynamic.bin' + }) + + // ── Scenario D: tools_in_system with dynamic tools (must reset+replay each turn) ── + const resultD = await runScenario(dirPath, modelName, { + name: 'SCENARIO D: tools_at_end = false, DIFFERENT tools each turn (reset+replay)', + toolsAtEnd: false, + dynamicTools: true, + conversationTurns: CONVERSATION_TURNS_DYNAMIC, + getToolsForTurn: (i) => DYNAMIC_TOOLS_PER_TURN[i], + cacheName: 'bench-D-in-system-dynamic.bin' + }) + + // ── Comparisons ── + console.log('\n' + '#'.repeat(80)) + console.log('# RESULTS SUMMARY') + console.log('#'.repeat(80)) + + printComparison( + 'tools_at_end (dynamic tools)', + resultC.turnStats, + 'tools_in_system (dynamic tools, reset+replay)', + resultD.turnStats + ) + + // ── Tool validation summary ── + console.log('\n' + '#'.repeat(80)) + console.log('# TOOL CALL CORRECTNESS') + console.log('#'.repeat(80)) + + printToolValidationSummary('Scenario C — tools_at_end, dynamic tools', resultC.toolValidations) + printToolValidationSummary('Scenario D — tools_in_system, dynamic tools (reset+replay)', resultD.toolValidations) + + console.log('\n' + '─'.repeat(80)) + console.log('Key:') + console.log(' Scenario C: tools_at_end=true with dynamic tools — trims & re-sends prev response') + console.log(' Scenario D: tools_at_end=false with dynamic tools — must reset cache & replay full history') + console.log(' PASS = model only called tools available in that turn') + console.log(' FAIL = model called a tool from a previous turn (stale/trimmed tool leak)') + console.log('─'.repeat(80)) + + // ── ASCII Graph: wall time per turn ── + console.log('\n' + '#'.repeat(80)) + console.log('# TIME (ms) vs TURN — tools_at_end (C) vs tools_in_system+replay (D)') + console.log('#'.repeat(80)) + console.log('') + + const BAR_WIDTH = 50 + const allTimes = [ + ...resultC.turnStats.map(s => parseFloat(s.wallMs)), + ...resultD.turnStats.map(s => parseFloat(s.wallMs)) + ] + const maxTime = Math.max(...allTimes) + + function makeBar (value, max, width) { + const filled = Math.round((value / max) * width) + return '\u2588'.repeat(filled) + '\u2591'.repeat(width - filled) + } + + console.log('Turn | C (ms) | D (ms) | Graph') + console.log('-----|----------|----------|' + '-'.repeat(BAR_WIDTH * 2 + 14)) + + for (let i = 0; i < resultC.turnStats.length; i++) { + const cMs = parseFloat(resultC.turnStats[i].wallMs) + const dMs = parseFloat(resultD.turnStats[i].wallMs) + const cBar = makeBar(cMs, maxTime, BAR_WIDTH) + const dBar = makeBar(dMs, maxTime, BAR_WIDTH) + console.log( + ` ${String(i + 1).padStart(2)} ` + + `| ${String(Math.round(cMs)).padStart(8)} ` + + `| ${String(Math.round(dMs)).padStart(8)} ` + + `| C:${cBar} D:${dBar}` + ) + } + console.log('') +} + +main().catch(err => { + console.error('Fatal:', err.message || err) + process.exit(1) +}) diff --git a/packages/qvac-lib-infer-llamacpp-llm/examples/testToolRemoval.js b/packages/qvac-lib-infer-llamacpp-llm/examples/testToolRemoval.js new file mode 100644 index 0000000000..104cf1fdb8 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/examples/testToolRemoval.js @@ -0,0 +1,343 @@ +'use strict' + +const LlmLlamacpp = require('../index') +const FilesystemDL = require('@qvac/dl-filesystem') +const path = require('bare-path') +const fs = require('bare-fs') +const process = require('bare-process') +const os = require('bare-os') +const { downloadModel } = require('./utils') + +const isDarwinX64 = os.platform() === 'darwin' && os.arch() === 'x64' +const isLinuxArm64 = os.platform() === 'linux' && os.arch() === 'arm64' +const useCpu = isDarwinX64 || isLinuxArm64 + +const MODEL = { + name: 'Qwen3-1.7B-Q4_0.gguf', + url: 'https://huggingface.co/unsloth/Qwen3-1.7B-GGUF/resolve/main/Qwen3-1.7B-Q4_0.gguf' +} + +const TOOL_WEATHER = { + type: 'function', + name: 'getWeather', + description: 'Get current weather for a city', + parameters: { + type: 'object', + properties: { + city: { type: 'string', description: 'City name' }, + units: { type: 'string', enum: ['celsius', 'fahrenheit'], description: 'Temperature units' } + }, + required: ['city'] + } +} + +const TOOL_CALCULATOR = { + type: 'function', + name: 'calculate', + description: 'Perform a math calculation', + parameters: { + type: 'object', + properties: { + expression: { type: 'string', description: 'Math expression to evaluate' } + }, + required: ['expression'] + } +} + +function stripInternalBlocks (text) { + return text + .replace(/[\s\S]*?<\/think>/g, '') + .replace(/[\s\S]*?<\/tool_call>/g, '') + .trim() +} + +function extractToolCalls (response) { + const toolCalls = [] + const toolCallRegex = /([\s\S]*?)<\/tool_call>/g + let match + while ((match = toolCallRegex.exec(response)) !== null) { + try { + const parsed = JSON.parse(match[1].trim()) + toolCalls.push(parsed.name || parsed.function?.name || 'unknown') + } catch (_) {} + } + return toolCalls +} + +async function loadModel (dirPath, modelName, config) { + const loader = new FilesystemDL({ dirPath }) + const model = new LlmLlamacpp({ + loader, + modelName, + diskPath: dirPath, + logger: console, + opts: { stats: true } + }, config) + await model.load() + return { model, loader } +} + +async function runAndCollect (model, prompt) { + const response = await model.run(prompt) + const chunks = [] + await response.onUpdate(data => { chunks.push(data) }).await() + return { output: chunks.join(''), stats: response.stats } +} + +async function main () { + console.log('Test: tool removal correctness with tools_at_end') + console.log('='.repeat(70)) + console.log('') + + const [modelName, dirPath] = await downloadModel(MODEL.url, MODEL.name) + const config = { + device: useCpu ? 'cpu' : 'gpu', + gpu_layers: '999', + ctx_size: '4096', + n_predict: '256', + temp: '0.1', + seed: '1', + verbosity: '0', + tools: 'true', + tools_at_end: 'true' + } + + const { model, loader } = await loadModel(dirPath, modelName, config) + const cachePath = path.join(dirPath, 'test-tool-removal.bin') + try { fs.unlinkSync(cachePath) } catch (_) {} + + let lastResponse = null + + try { + // ── Turn 1: provide getWeather, ask about weather ── + console.log('── Turn 1: tools=[getWeather], ask about weather ──') + const prompt1 = [ + { role: 'session', content: cachePath }, + { role: 'system', content: 'You are a helpful assistant. You must use tools when available. Do not answer without using a tool.' }, + { role: 'user', content: 'What is the weather in Paris?' }, + TOOL_WEATHER + ] + const r1 = await runAndCollect(model, prompt1) + lastResponse = stripInternalBlocks(r1.output) + const calls1 = extractToolCalls(r1.output) + console.log(` Response tools called: [${calls1.join(', ') || 'none'}]`) + console.log(' Expected: [getWeather]') + console.log(` ${calls1.includes('getWeather') ? 'PASS ✓' : 'FAIL ✗'}`) + console.log('') + + // ── Turn 2: REMOVE getWeather, provide calculate instead ── + console.log('── Turn 2: tools=[calculate] (getWeather REMOVED), ask to calculate ──') + const prompt2 = [ + { role: 'session', content: cachePath }, + { role: 'assistant', content: lastResponse }, + { role: 'user', content: 'Calculate 256 * 128' }, + TOOL_CALCULATOR + ] + const r2 = await runAndCollect(model, prompt2) + lastResponse = stripInternalBlocks(r2.output) + const calls2 = extractToolCalls(r2.output) + console.log(` Response tools called: [${calls2.join(', ') || 'none'}]`) + console.log(' Expected: [calculate]') + console.log(` ${calls2.includes('calculate') && !calls2.includes('getWeather') ? 'PASS ✓' : 'FAIL ✗'}`) + console.log('') + + // ── Turn 3: KEEP only calculate, ask about weather (should NOT call getWeather) ── + console.log('── Turn 3: tools=[calculate] (getWeather still removed), ask about weather ──') + console.log(' This is the KEY test: model should NOT call getWeather (it was removed)') + const prompt3 = [ + { role: 'session', content: cachePath }, + { role: 'assistant', content: lastResponse }, + { role: 'user', content: 'What is the weather in London?' }, + TOOL_CALCULATOR + ] + const r3 = await runAndCollect(model, prompt3) + lastResponse = stripInternalBlocks(r3.output) + const calls3 = extractToolCalls(r3.output) + console.log(` Response tools called: [${calls3.join(', ') || 'none'}]`) + console.log(' Expected: NOT getWeather (it\'s not available)') + const weatherLeak = calls3.includes('getWeather') + console.log(` ${weatherLeak ? 'FAIL ✗ — stale tool leak! getWeather was called despite being removed' : 'PASS ✓ — model did not call removed tool'}`) + console.log('') + + // ── Turn 4: bring back getWeather, remove calculate, ask to calculate ── + console.log('── Turn 4: tools=[getWeather] (calculate REMOVED), ask to calculate ──') + console.log(' Model should NOT call calculate (it was removed)') + const prompt4 = [ + { role: 'session', content: cachePath }, + { role: 'assistant', content: lastResponse }, + { role: 'user', content: 'Calculate 999 / 3' }, + TOOL_WEATHER + ] + const r4 = await runAndCollect(model, prompt4) + lastResponse = stripInternalBlocks(r4.output) + const calls4 = extractToolCalls(r4.output) + console.log(` Response tools called: [${calls4.join(', ') || 'none'}]`) + console.log(' Expected: NOT calculate (it\'s not available)') + const calcLeak = calls4.includes('calculate') + console.log(` ${calcLeak ? 'FAIL ✗ — stale tool leak! calculate was called despite being removed' : 'PASS ✓ — model did not call removed tool'}`) + console.log('') + + // ── Summary ── + console.log('='.repeat(70)) + console.log('SUMMARY') + console.log('='.repeat(70)) + const results = [ + { turn: 1, pass: calls1.includes('getWeather'), desc: 'getWeather available → called it' }, + { turn: 2, pass: calls2.includes('calculate') && !calls2.includes('getWeather'), desc: 'calculate available, getWeather removed → called calculate' }, + { turn: 3, pass: !weatherLeak, desc: 'getWeather removed → did NOT call it' }, + { turn: 4, pass: !calcLeak, desc: 'calculate removed → did NOT call it' } + ] + for (const r of results) { + console.log(` Turn ${r.turn}: ${r.pass ? 'PASS ✓' : 'FAIL ✗'} — ${r.desc}`) + } + const allPass = results.every(r => r.pass) + console.log('') + console.log(allPass + ? ' ALL PASSED — tool trimming correctly prevents stale tool usage' + : ' FAILURES DETECTED — removed tools leaked through the cache') + } finally { + await model.unload() + await loader.close() + try { fs.unlinkSync(cachePath) } catch (_) {} + } +} + +// ─── Same test but with tools_in_system (reset+replay) ───────────────────── + +async function mainInSystem () { + console.log('\n\n') + console.log('Test: tool removal correctness with tools_in_system (reset+replay)') + console.log('='.repeat(70)) + console.log('') + + const [modelName, dirPath] = await downloadModel(MODEL.url, MODEL.name) + const config = { + device: useCpu ? 'cpu' : 'gpu', + gpu_layers: '999', + ctx_size: '4096', + n_predict: '256', + temp: '0.1', + seed: '1', + verbosity: '0', + tools: 'true', + tools_at_end: 'false' + } + + const { model, loader } = await loadModel(dirPath, modelName, config) + const cachePath = path.join(dirPath, 'test-tool-removal-insystem.bin') + try { fs.unlinkSync(cachePath) } catch (_) {} + + const SYSTEM = 'You are a helpful assistant. You must use tools when available. Do not answer without using a tool.' + const history = [] // accumulate {role, content} for replay + + try { + // ── Turn 1: provide getWeather, ask about weather ── + console.log('── Turn 1: tools=[getWeather], ask about weather ──') + const prompt1 = [ + { role: 'session', content: cachePath }, + { role: 'system', content: SYSTEM }, + { role: 'user', content: 'What is the weather in Paris?' }, + TOOL_WEATHER + ] + const r1 = await runAndCollect(model, prompt1) + history.push({ role: 'user', content: 'What is the weather in Paris?' }) + history.push({ role: 'assistant', content: stripInternalBlocks(r1.output) }) + const calls1 = extractToolCalls(r1.output) + console.log(` Response tools called: [${calls1.join(', ') || 'none'}]`) + console.log(' Expected: [getWeather]') + console.log(` ${calls1.includes('getWeather') ? 'PASS ✓' : 'FAIL ✗'}`) + console.log('') + + // ── Turn 2: REMOVE getWeather, provide calculate — reset+replay ── + console.log('── Turn 2: tools=[calculate] (getWeather REMOVED), ask to calculate ──') + const prompt2 = [ + { role: 'session', content: cachePath }, + { role: 'session', content: 'reset' }, + { role: 'system', content: SYSTEM }, + ...history, + { role: 'user', content: 'Calculate 256 * 128' }, + TOOL_CALCULATOR + ] + const r2 = await runAndCollect(model, prompt2) + history.push({ role: 'user', content: 'Calculate 256 * 128' }) + history.push({ role: 'assistant', content: stripInternalBlocks(r2.output) }) + const calls2 = extractToolCalls(r2.output) + console.log(` Response tools called: [${calls2.join(', ') || 'none'}]`) + console.log(' Expected: [calculate]') + console.log(` ${calls2.includes('calculate') && !calls2.includes('getWeather') ? 'PASS ✓' : 'FAIL ✗'}`) + console.log('') + + // ── Turn 3: KEEP only calculate, ask about weather ── + console.log('── Turn 3: tools=[calculate] (getWeather still removed), ask about weather ──') + console.log(' This is the KEY test: model should NOT call getWeather (it was removed)') + const prompt3 = [ + { role: 'session', content: cachePath }, + { role: 'session', content: 'reset' }, + { role: 'system', content: SYSTEM }, + ...history, + { role: 'user', content: 'What is the weather in London?' }, + TOOL_CALCULATOR + ] + const r3 = await runAndCollect(model, prompt3) + history.push({ role: 'user', content: 'What is the weather in London?' }) + history.push({ role: 'assistant', content: stripInternalBlocks(r3.output) }) + const calls3 = extractToolCalls(r3.output) + console.log(` Response tools called: [${calls3.join(', ') || 'none'}]`) + console.log(' Expected: NOT getWeather (it\'s not available)') + const weatherLeak = calls3.includes('getWeather') + console.log(` ${weatherLeak ? 'FAIL ✗ — stale tool leak! getWeather was called despite being removed' : 'PASS ✓ — model did not call removed tool'}`) + console.log('') + + // ── Turn 4: bring back getWeather, remove calculate, ask to calculate ── + console.log('── Turn 4: tools=[getWeather] (calculate REMOVED), ask to calculate ──') + console.log(' Model should NOT call calculate (it was removed)') + const prompt4 = [ + { role: 'session', content: cachePath }, + { role: 'session', content: 'reset' }, + { role: 'system', content: SYSTEM }, + ...history, + { role: 'user', content: 'Calculate 999 / 3' }, + TOOL_WEATHER + ] + const r4 = await runAndCollect(model, prompt4) + const calls4 = extractToolCalls(r4.output) + console.log(` Response tools called: [${calls4.join(', ') || 'none'}]`) + console.log(' Expected: NOT calculate (it\'s not available)') + const calcLeak = calls4.includes('calculate') + console.log(` ${calcLeak ? 'FAIL ✗ — stale tool leak! calculate was called despite being removed' : 'PASS ✓ — model did not call removed tool'}`) + console.log('') + + // ── Summary ── + console.log('='.repeat(70)) + console.log('SUMMARY (tools_in_system, reset+replay)') + console.log('='.repeat(70)) + const results = [ + { turn: 1, pass: calls1.includes('getWeather'), desc: 'getWeather available → called it' }, + { turn: 2, pass: calls2.includes('calculate') && !calls2.includes('getWeather'), desc: 'calculate available, getWeather removed → called calculate' }, + { turn: 3, pass: !weatherLeak, desc: 'getWeather removed → did NOT call it' }, + { turn: 4, pass: !calcLeak, desc: 'calculate removed → did NOT call it' } + ] + for (const r of results) { + console.log(` Turn ${r.turn}: ${r.pass ? 'PASS ✓' : 'FAIL ✗'} — ${r.desc}`) + } + const allPass = results.every(r => r.pass) + console.log('') + console.log(allPass + ? ' ALL PASSED — tool switching correctly prevents stale tool usage' + : ' FAILURES DETECTED — removed tools leaked from conversation history') + } finally { + await model.unload() + await loader.close() + try { fs.unlinkSync(cachePath) } catch (_) {} + } +} + +async function runAll () { + await main() + await mainInSystem() +} + +runAll().catch(err => { + console.error('Fatal:', err.message || err) + process.exit(1) +}) diff --git a/packages/qvac-lib-infer-llamacpp-llm/package.json b/packages/qvac-lib-infer-llamacpp-llm/package.json index 89cb192489..d09f84e089 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/package.json +++ b/packages/qvac-lib-infer-llamacpp-llm/package.json @@ -1,6 +1,6 @@ { "name": "@qvac/llm-llamacpp", - "version": "0.13.0", + "version": "0.14.0", "description": "llama addon for qvac", "addon": true, "scripts": { diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/integration/dynamic-tools.test.js b/packages/qvac-lib-infer-llamacpp-llm/test/integration/dynamic-tools.test.js new file mode 100644 index 0000000000..3dc346c041 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/test/integration/dynamic-tools.test.js @@ -0,0 +1,221 @@ +'use strict' + +const test = require('brittle') +const path = require('bare-path') +const FilesystemDL = require('@qvac/dl-filesystem') +const LlmLlamacpp = require('../../index.js') +const { ensureModel } = require('./utils') +const { attachSpecLogger } = require('./spec-logger') +const os = require('bare-os') + +const isDarwinX64 = os.platform() === 'darwin' && os.arch() === 'x64' +const isLinuxArm64 = os.platform() === 'linux' && os.arch() === 'arm64' +const useCpu = isDarwinX64 || isLinuxArm64 + +const QWEN3_MODEL = { + name: 'Qwen3-0.6B-Q8_0.gguf', + url: 'https://huggingface.co/unsloth/Qwen3-0.6B-GGUF/resolve/main/Qwen3-0.6B-Q8_0.gguf' +} + +const SYSTEM_MESSAGE = { role: 'system', content: 'You are a helpful assistant.' } + +const BASE_CONFIG = { + device: useCpu ? 'cpu' : 'gpu', + gpu_layers: '999', + ctx_size: '4096', + n_predict: '64', + temp: '0.1', + seed: '1', + verbosity: '2', + tools: 'true', + tools_at_end: 'true' +} + +const TOOL_A = { + type: 'function', + name: 'getWeather', + description: 'Get current weather for a city', + parameters: { + type: 'object', + properties: { city: { type: 'string', description: 'City name' } }, + required: ['city'] + } +} + +const TOOL_B = { + type: 'function', + name: 'searchProducts', + description: 'Search for products in catalog', + parameters: { + type: 'object', + properties: { query: { type: 'string', description: 'Search query' } }, + required: ['query'] + } +} + +const TOOL_C = { + type: 'function', + name: 'sendEmail', + description: 'Send an email message', + parameters: { + type: 'object', + properties: { + to: { type: 'string', description: 'Recipient email' }, + body: { type: 'string', description: 'Email body' } + }, + required: ['to', 'body'] + } +} + +const toNumber = value => typeof value === 'number' ? value : Number(value || 0) + +function normalizeStats (rawStats = {}) { + return { + CacheTokens: toNumber(rawStats?.CacheTokens), + promptTokens: toNumber(rawStats?.promptTokens), + generatedTokens: toNumber(rawStats?.generatedTokens) + } +} + +async function setupModel (t, overrides = {}) { + const [modelName, dirPath] = await ensureModel({ + modelName: QWEN3_MODEL.name, + downloadUrl: QWEN3_MODEL.url + }) + + const loader = new FilesystemDL({ dirPath }) + const config = { ...BASE_CONFIG, ...overrides } + const specLogger = attachSpecLogger({ forwardToConsole: true }) + let loggerReleased = false + const releaseLogger = () => { + if (loggerReleased) return + loggerReleased = true + specLogger.release() + } + + const model = new LlmLlamacpp({ + loader, + modelName, + diskPath: dirPath, + logger: console, + opts: { stats: true } + }, config) + + try { + await model.load() + } catch (err) { + releaseLogger() + await loader.close().catch(() => {}) + throw err + } + + t.teardown(async () => { + await model.unload().catch(() => {}) + await loader.close().catch(() => {}) + releaseLogger() + }) + + return { model, dirPath } +} + +async function runAndCollect (model, prompt) { + const response = await model.run(prompt) + const chunks = [] + let chain = response.onUpdate(data => { chunks.push(data) }) + if (typeof response.onError === 'function') { + chain = chain.onError(err => { throw err }) + } + await chain.await() + return { + output: chunks.join(''), + stats: normalizeStats(response.stats) + } +} + +test('[dynamic-tools] multi-turn session with changing tools does not accumulate stale tokens', { timeout: 600_000 }, async t => { + const { model, dirPath } = await setupModel(t) + const sessionName = path.join(dirPath, 'dynamic-tools-changing.bin') + + const prompt1 = [ + { role: 'session', content: sessionName }, + SYSTEM_MESSAGE, + { role: 'user', content: 'Hello, what can you do?' }, + TOOL_A + ] + const r1 = await runAndCollect(model, prompt1) + t.ok(r1.output.length > 0, 'turn 1 produces output') + t.ok(r1.stats.CacheTokens > 0, 'turn 1 has cache tokens') + + const prompt2 = [ + { role: 'session', content: sessionName }, + { role: 'user', content: 'Search for laptops' }, + TOOL_B + ] + const r2 = await runAndCollect(model, prompt2) + t.ok(r2.output.length > 0, 'turn 2 produces output') + t.ok(r2.stats.CacheTokens > 0, 'turn 2 has cache tokens') + + const prompt3 = [ + { role: 'session', content: sessionName }, + { role: 'user', content: 'Send a report' }, + TOOL_C + ] + const r3 = await runAndCollect(model, prompt3) + t.ok(r3.output.length > 0, 'turn 3 produces output') + t.ok(r3.stats.CacheTokens > 0, 'turn 3 has cache tokens') + + const naiveAccumulation = r1.stats.CacheTokens + r2.stats.promptTokens + r2.stats.generatedTokens + r3.stats.promptTokens + r3.stats.generatedTokens + t.ok( + r3.stats.CacheTokens < naiveAccumulation, + `CacheTokens after 3 turns (${r3.stats.CacheTokens}) should be less than naive accumulation (${naiveAccumulation}) — proves old tools are trimmed` + ) + + t.ok( + r3.stats.CacheTokens < 2 * r1.stats.CacheTokens, + `CacheTokens after 3 turns (${r3.stats.CacheTokens}) should be less than 2x turn 1 (${2 * r1.stats.CacheTokens}) — tools are replaced, not accumulated` + ) +}) + +test('[dynamic-tools] multi-turn session with same tools works correctly', { timeout: 600_000 }, async t => { + const { model, dirPath } = await setupModel(t) + const sessionName = path.join(dirPath, 'dynamic-tools-same.bin') + + const prompt1 = [ + { role: 'session', content: sessionName }, + SYSTEM_MESSAGE, + { role: 'user', content: 'What is the weather in Paris?' }, + TOOL_A + ] + const r1 = await runAndCollect(model, prompt1) + t.ok(r1.output.length > 0, 'turn 1 produces output') + t.ok(r1.stats.CacheTokens > 0, 'turn 1 has cache tokens') + + const prompt2 = [ + { role: 'session', content: sessionName }, + { role: 'user', content: 'What about London?' }, + TOOL_A + ] + const r2 = await runAndCollect(model, prompt2) + t.ok(r2.output.length > 0, 'turn 2 produces output') + t.ok(r2.stats.CacheTokens > 0, 'turn 2 has cache tokens') + t.ok( + r2.stats.CacheTokens < 2 * r1.stats.CacheTokens, + `CacheTokens after turn 2 (${r2.stats.CacheTokens}) should be less than 2x turn 1 (${2 * r1.stats.CacheTokens})` + ) +}) + +test('[dynamic-tools] single-shot with tools works without session', { timeout: 600_000 }, async t => { + const { model } = await setupModel(t) + + const prompt = [ + SYSTEM_MESSAGE, + { role: 'user', content: 'What is the weather in Tokyo?' }, + TOOL_A + ] + const r = await runAndCollect(model, prompt) + t.ok(r.output.length > 0, 'produces output') + t.is(r.stats.CacheTokens, 0, 'no cache tokens without session') + t.ok(r.stats.promptTokens > 0, 'prompt tokens tracked') + t.ok(r.stats.generatedTokens > 0, 'generated tokens tracked') + t.end() +}) diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/mobile/integration.auto.cjs b/packages/qvac-lib-infer-llamacpp-llm/test/mobile/integration.auto.cjs index 95705a51e5..3bcfb9d515 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/mobile/integration.auto.cjs +++ b/packages/qvac-lib-infer-llamacpp-llm/test/mobile/integration.auto.cjs @@ -26,6 +26,10 @@ async function runConfigParametersTest (options = {}) { // eslint-disable-line n return runIntegrationModule('../integration/config-parameters.test.js', options) } +async function runDynamicToolsTest (options = {}) { // eslint-disable-line no-unused-vars + return runIntegrationModule('../integration/dynamic-tools.test.js', options) +} + async function runFinetuningPauseResumeTest (options = {}) { // eslint-disable-line no-unused-vars return runIntegrationModule('../integration/finetuning-pause-resume.test.js', options) } diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/CMakeLists.txt b/packages/qvac-lib-infer-llamacpp-llm/test/unit/CMakeLists.txt index b527732951..82b4755b0e 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/unit/CMakeLists.txt +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/CMakeLists.txt @@ -13,7 +13,9 @@ add_executable( test_llama_model.cpp test_llama_finetuning_helpers.cpp test_cache_management.cpp + test_cache_management_qwen3.cpp test_text_llm_context.cpp + test_text_llm_context_qwen3.cpp test_addon_cpp.cpp test_backend_selection.cpp test_tune_config_map.cpp @@ -43,6 +45,7 @@ add_executable( ${CMAKE_SOURCE_DIR}/addon/src/utils/ChatTemplateUtils.cpp ${CMAKE_SOURCE_DIR}/addon/src/utils/Qwen3ReasoningUtils.cpp ${CMAKE_SOURCE_DIR}/addon/src/utils/QwenTemplate.cpp + ${CMAKE_SOURCE_DIR}/addon/src/utils/Qwen3ToolsDynamicTemplate.cpp ) target_compile_options( diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management.cpp index 6963d630ea..4510671747 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management.cpp @@ -1030,3 +1030,41 @@ TEST_F(CacheManagementTest, CacheTokensExceedContextSize) { { processPromptString(model_small, loadInput); }, qvac_errors::StatusError); } + +TEST_F(CacheManagementTest, CacheWithToolsAtEndFalseSavesFullCache) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "false"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string inputWithTools = + R"([{"role": "session", "content": "test_session1.bin"}, {"role": "user", "content": "What is the weather in Tokyo?"}, {"type": "function", "name": "getWeather", "description": "Get weather forecast", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, inputWithTools); + EXPECT_GE(output.length(), 0); + auto stats = model->runtimeStats(); + EXPECT_GE(stats.size(), 0); + }); + + auto statsBeforeSave = model->runtimeStats(); + double cacheTokensBeforeSave = getStatValue(statsBeforeSave, "CacheTokens"); + EXPECT_GT(cacheTokensBeforeSave, 0.0); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); +} diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management_qwen3.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management_qwen3.cpp new file mode 100644 index 0000000000..63f4ddf602 --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_cache_management_qwen3.cpp @@ -0,0 +1,447 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "model-interface/LlamaModel.hpp" +#include "test_common.hpp" + +namespace fs = std::filesystem; + +namespace { +double getStatValue( + const qvac_lib_inference_addon_cpp::RuntimeStats& stats, + const std::string& key) { + for (const auto& stat : stats) { + if (stat.first == key) { + return std::visit( + [](const auto& value) -> double { + if constexpr (std::is_same_v< + std::decay_t, + double>) { + return value; + } else { + return static_cast(value); + } + }, + stat.second); + } + } + return 0.0; +} + +std::string processPromptString( + const std::unique_ptr& model, const std::string& input) { + LlamaModel::Prompt prompt; + prompt.input = input; + return model->processPrompt(prompt); +} + +bool isQwen3ModelPath(const std::string& path) { + std::string lowerPath = path; + std::transform( + lowerPath.begin(), + lowerPath.end(), + lowerPath.begin(), + [](unsigned char c) { return std::tolower(c); }); + return lowerPath.find("qwen3") != std::string::npos; +} +} // namespace + +class CacheManagementQwen3Test : public ::testing::Test { +protected: + void SetUp() override { + config_files["device"] = test_common::getTestDevice(); + config_files["ctx_size"] = "2048"; + config_files["gpu_layers"] = test_common::getTestGpuLayers(); + config_files["n_predict"] = "10"; + config_files["tools"] = "true"; + + test_model_path = test_common::BaseTestModelPath::get( + "Qwen3-1.7B-Q4_0.gguf", "Llama-3.2-1B-Instruct-Q4_0.gguf"); + test_projection_path = ""; + + config_files["backendsDir"] = test_common::getTestBackendsDir().string(); + + session1_path = "test_session1_qwen3.bin"; + session2_path = "test_session2_qwen3.bin"; + temp_session_path = "temp_session_qwen3.bin"; + } + + void TearDown() override { + for (const auto& session_file : + {session1_path, + session2_path, + temp_session_path, + std::string("test_large_cache_qwen3.bin")}) { + if (fs::exists(session_file)) { + fs::remove(session_file); + } + } + } + + bool hasValidModel() { return fs::exists(test_model_path); } + + std::unique_ptr createModel() { + if (!hasValidModel()) { + return nullptr; + } + std::string modelPath = test_model_path; + std::string projectionPath = test_projection_path; + auto configCopy = config_files; + auto model = std::make_unique( + std::move(modelPath), std::move(projectionPath), std::move(configCopy)); + model->waitForLoadInitialization(); + if (!model->isLoaded()) { + return nullptr; + } + return model; + } + + std::unordered_map config_files; + std::string test_model_path; + std::string test_projection_path; + std::string session1_path; + std::string session2_path; + std::string temp_session_path; +}; + +TEST_F(CacheManagementQwen3Test, CacheWithToolsAtEndTrueTrimsToolTokens) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string inputWithTools = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What is the weather in Tokyo?"}, {"type": "function", "name": "getWeather", "description": "Get weather forecast", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, inputWithTools); + EXPECT_GE(output.length(), 0); + }); + + auto statsBeforeSave = model->runtimeStats(); + double cacheTokensBeforeSave = getStatValue(statsBeforeSave, "CacheTokens"); + EXPECT_GT(cacheTokensBeforeSave, 0.0); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); +} + +TEST_F(CacheManagementQwen3Test, CacheReloadWithToolsAtEndTrue) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model1 = createModel(); + if (!model1) { + FAIL() << "Model failed to load"; + } + + std::string inputWithTools = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What is the weather in Tokyo?"}, {"type": "function", "name": "getWeather", "description": "Get weather forecast", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model1, inputWithTools); + EXPECT_GE(output.length(), 0); + }); + + llama_pos nPastBeforeTools1 = model1->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools1, -1); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model1, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); + + model1.reset(); + + auto model2 = createModel(); + if (!model2) { + FAIL() << "Model failed to load"; + } + + EXPECT_NO_THROW({ + std::string output = processPromptString( + model2, + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What is the weather in London?"}])"); + EXPECT_GE(output.length(), 0); + }); + + auto statsAfterReload = model2->runtimeStats(); + double cacheTokensAfterReload = getStatValue(statsAfterReload, "CacheTokens"); + EXPECT_GT(cacheTokensAfterReload, 0.0); + + llama_pos nPastBeforeTools2 = model2->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools2, -1); +} + +TEST_F(CacheManagementQwen3Test, CacheWithoutToolsWithToolsAtEndTrue) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string inputNoTools = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What is bitcoin? Answer shortly."}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, inputNoTools); + EXPECT_GE(output.length(), 0); + }); + + auto statsBeforeSave = model->runtimeStats(); + double cacheTokensBeforeSave = getStatValue(statsBeforeSave, "CacheTokens"); + EXPECT_GT(cacheTokensBeforeSave, 0.0); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); +} + +TEST_F(CacheManagementQwen3Test, CacheToolsAtEndModeWithMultiplePrompts) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string input1 = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "Hi"}, {"type": "function", "name": "get_weather", "description": "Get detailed weather forecast data with temperature humidity wind speed precipitation UV visibility pressure sunrise sunset alerts", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The name of the city to get weather for"}, "country": {"type": "string", "description": "Country code or name"}, "lat": {"type": "number", "description": "Latitude coordinate"}, "lon": {"type": "number", "description": "Longitude coordinate"}, "zip": {"type": "string", "description": "ZIP postal code"}, "units": {"type": "string", "description": "Temperature units metric imperial or kelvin"}, "lang": {"type": "string", "description": "Language code for localized descriptions"}, "forecast_days": {"type": "integer", "description": "Number of days to forecast from 1 to 7"}, "hourly": {"type": "boolean", "description": "Include hourly forecast data"}, "alerts": {"type": "boolean", "description": "Include weather alerts and warnings"}, "aqi": {"type": "boolean", "description": "Include air quality index data"}, "tides": {"type": "boolean", "description": "Include tide information"}, "solar": {"type": "boolean", "description": "Include solar data like sunrise sunset"}, "tz": {"type": "string", "description": "Timezone identifier"}, "start_dt": {"type": "string", "description": "Start datetime for historical data"}, "end_dt": {"type": "string", "description": "End datetime for historical data"}, "cnt": {"type": "integer", "description": "Number of data points to return"}, "mode": {"type": "string", "description": "Response mode json xml or html"}, "appid": {"type": "string", "description": "API key for authentication"}}, "required": ["city"]}}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, input1); + EXPECT_GE(output.length(), 0); + }); + + auto stats1 = model->runtimeStats(); + double cacheTokens1 = getStatValue(stats1, "CacheTokens"); + double promptTokens1 = getStatValue(stats1, "promptTokens"); + EXPECT_GT(cacheTokens1, 0.0); + EXPECT_GT(promptTokens1, 500.0); + + const int maxExpectedCacheTokens = 50; + EXPECT_GT(cacheTokens1, 0); + EXPECT_LE(cacheTokens1, maxExpectedCacheTokens) + << "Cache tokens (" << cacheTokens1 << ") should not exceed " + << maxExpectedCacheTokens << " - function tokens should be trimmed"; + + std::string input2 = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What about London?"}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, input2); + EXPECT_GE(output.length(), 0); + }); + + auto stats2 = model->runtimeStats(); + double cacheTokens2 = getStatValue(stats2, "CacheTokens"); + double promptTokens2 = getStatValue(stats2, "promptTokens"); + EXPECT_GT(cacheTokens2, cacheTokens1); + EXPECT_LT(promptTokens2, 500.0); + EXPECT_LE(cacheTokens2, maxExpectedCacheTokens) + << "Cache tokens (" << cacheTokens1 << ") should not exceed " + << maxExpectedCacheTokens << " - function tokens should be trimmed"; + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); + + model.reset(); + + auto model2 = createModel(); + if (!model2) { + FAIL() << "Model2 failed to load"; + } + + std::string input3 = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What about Paris?"}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model2, input3); + EXPECT_GE(output.length(), 0); + }); + + auto stats3 = model2->runtimeStats(); + double cacheTokens3 = getStatValue(stats3, "CacheTokens"); + double promptTokens3 = getStatValue(stats3, "promptTokens"); + + EXPECT_GT(cacheTokens3, cacheTokens2); + EXPECT_LT(promptTokens3, 100.0); + + auto model3 = createModel(); + if (!model3) { + FAIL() << "Model3 failed to load"; + } + + std::string getTokensInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "getTokens"}])"; + EXPECT_NO_THROW({ + std::string output = processPromptString(model3, getTokensInput); + EXPECT_EQ(output.length(), 0); + }); + + auto stats4 = model3->runtimeStats(); + double cacheTokens4 = getStatValue(stats4, "CacheTokens"); + EXPECT_EQ(cacheTokens4, cacheTokens2); +} + +TEST_F( + CacheManagementQwen3Test, + CacheToolsAtEndModeTrimOnlyWhenNPastBeforeToolsPositive) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string inputNoTools = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "Hello"}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, inputNoTools); + EXPECT_GE(output.length(), 0); + }); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); + + auto statsBeforeSave = model->runtimeStats(); + double cacheTokensBeforeSave = getStatValue(statsBeforeSave, "CacheTokens"); + EXPECT_GT(cacheTokensBeforeSave, 0.0); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + auto statsAfterSave = model->runtimeStats(); + double cacheTokensAfterSave = getStatValue(statsAfterSave, "CacheTokens"); + EXPECT_EQ(cacheTokensAfterSave, cacheTokensBeforeSave); +} + +TEST_F(CacheManagementQwen3Test, CacheToolsAtEndModeRestoresNPastBeforeTools) { + if (!isQwen3ModelPath(test_model_path)) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + std::string input1 = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "Hi"}, {"type": "function", "name": "get_weather", "description": "Get weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model, input1); + EXPECT_GE(output.length(), 0); + }); + + llama_pos nPastBeforeTools1 = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools1, -1); + + std::string saveInput = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "session", "content": "save"}])"; + EXPECT_NO_THROW({ + std::string saveOutput = processPromptString(model, saveInput); + EXPECT_EQ(saveOutput.length(), 0); + }); + + EXPECT_TRUE(fs::exists(session1_path)); + + auto model2 = createModel(); + if (!model2) { + FAIL() << "Model2 failed to load"; + } + + std::string input2 = + R"([{"role": "session", "content": "test_session1_qwen3.bin"}, {"role": "user", "content": "What about London?"}])"; + + EXPECT_NO_THROW({ + std::string output = processPromptString(model2, input2); + EXPECT_GE(output.length(), 0); + }); + + llama_pos nPastBeforeTools2 = model2->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools2, -1); +} diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_chat_template_utils.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_chat_template_utils.cpp index 14742e74e2..fcb2ac5291 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_chat_template_utils.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_chat_template_utils.cpp @@ -8,6 +8,8 @@ #include "model-interface/LlamaModel.hpp" #include "test_common.hpp" #include "utils/ChatTemplateUtils.hpp" +#include "utils/Qwen3ToolsDynamicTemplate.hpp" +#include "utils/QwenTemplate.hpp" namespace fs = std::filesystem; using namespace qvac_lib_inference_addon_llama::utils; @@ -37,14 +39,33 @@ TEST_F(ChatTemplateUtilsTest, IsQwen3ModelWithNullptr) { EXPECT_FALSE(isQwen3Model(nullptr)); } -TEST_F(ChatTemplateUtilsTest, GetChatTemplateForModelWithManualOverride) { +TEST_F( + ChatTemplateUtilsTest, + GetChatTemplateForModelWithManualOverrideToolsAtEndFalse) { std::string manual_override = "custom template"; - std::string result = getChatTemplateForModel(nullptr, manual_override); + std::string result = getChatTemplateForModel(nullptr, manual_override, false); EXPECT_EQ(result, manual_override); } -TEST_F(ChatTemplateUtilsTest, GetChatTemplateForModelEmptyOverrideNullptr) { - std::string result = getChatTemplateForModel(nullptr, ""); +TEST_F( + ChatTemplateUtilsTest, + GetChatTemplateForModelWithManualOverrideToolsAtEndTrue) { + std::string manual_override = "custom template"; + std::string result = getChatTemplateForModel(nullptr, manual_override, true); + EXPECT_EQ(result, manual_override); +} + +TEST_F( + ChatTemplateUtilsTest, + GetChatTemplateForModelEmptyOverrideNullptrToolsAtEndFalse) { + std::string result = getChatTemplateForModel(nullptr, "", false); + EXPECT_EQ(result, ""); +} + +TEST_F( + ChatTemplateUtilsTest, + GetChatTemplateForModelEmptyOverrideNullptrToolsAtEndTrue) { + std::string result = getChatTemplateForModel(nullptr, "", true); EXPECT_EQ(result, ""); } @@ -53,7 +74,7 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateWithNullptrModel) { params.chat_template = "test template"; params.use_jinja = false; - std::string result = getChatTemplate(nullptr, params); + std::string result = getChatTemplate(nullptr, params, false); EXPECT_EQ(result, params.chat_template); } @@ -62,7 +83,7 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateJinjaDisabled) { params.chat_template = "test template"; params.use_jinja = false; - std::string result = getChatTemplate(nullptr, params); + std::string result = getChatTemplate(nullptr, params, false); EXPECT_EQ(result, "test template"); } @@ -71,7 +92,7 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateJinjaEnabledWithOverride) { params.chat_template = "custom template"; params.use_jinja = true; - std::string result = getChatTemplate(nullptr, params); + std::string result = getChatTemplate(nullptr, params, false); EXPECT_EQ(result, "custom template"); } @@ -80,7 +101,7 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateJinjaEnabledWithoutOverride) { params.chat_template = ""; params.use_jinja = true; - std::string result = getChatTemplate(nullptr, params); + std::string result = getChatTemplate(nullptr, params, false); EXPECT_EQ(result, ""); } @@ -89,7 +110,7 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateParamsNotModified) { params.chat_template = "original template"; params.use_jinja = false; - std::string result = getChatTemplate(nullptr, params); + std::string result = getChatTemplate(nullptr, params, false); EXPECT_EQ(params.chat_template, "original template"); EXPECT_FALSE(params.use_jinja); @@ -98,13 +119,241 @@ TEST_F(ChatTemplateUtilsTest, GetChatTemplateParamsNotModified) { TEST_F(ChatTemplateUtilsTest, GetChatTemplateForModelPreservesWhitespace) { std::string overrideWithSpaces = " template with spaces "; - std::string result = getChatTemplateForModel(nullptr, overrideWithSpaces); + std::string result = + getChatTemplateForModel(nullptr, overrideWithSpaces, false); EXPECT_EQ(result, overrideWithSpaces); } TEST_F( ChatTemplateUtilsTest, GetChatTemplateForModelPreservesSpecialCharacters) { std::string overrideSpecial = "template\nwith\tspecial\rchars"; - std::string result = getChatTemplateForModel(nullptr, overrideSpecial); + std::string result = getChatTemplateForModel(nullptr, overrideSpecial, false); EXPECT_EQ(result, overrideSpecial); } + +TEST_F(ChatTemplateUtilsTest, GetFixedQwen3TemplateNotNull) { + const char* expectedTemplate = getFixedQwen3Template(); + ASSERT_NE(expectedTemplate, nullptr); + EXPECT_GT(strlen(expectedTemplate), 0u); +} + +TEST_F(ChatTemplateUtilsTest, GetToolsDynamicQwen3TemplateNotNull) { + const char* expectedTemplate = getToolsDynamicQwen3Template(); + ASSERT_NE(expectedTemplate, nullptr); + EXPECT_GT(strlen(expectedTemplate), 0u); +} + +TEST_F(ChatTemplateUtilsTest, TemplatesAreDifferent) { + const char* fixedTemplate = getFixedQwen3Template(); + const char* dynamicTemplate = getToolsDynamicQwen3Template(); + ASSERT_NE(fixedTemplate, nullptr); + ASSERT_NE(dynamicTemplate, nullptr); + EXPECT_STRNE(fixedTemplate, dynamicTemplate); +} + +TEST_F(ChatTemplateUtilsTest, ManualOverrideTakesPrecedenceOverToolsAtEnd) { + common_params params; + params.chat_template = "my_custom_template"; + params.use_jinja = true; + + std::string result = getChatTemplate(nullptr, params, true); + EXPECT_EQ(result, "my_custom_template"); +} + +TEST_F( + ChatTemplateUtilsTest, ManualOverrideTakesPrecedenceOverToolsAtEndFalse) { + common_params params; + params.chat_template = "my_custom_template"; + params.use_jinja = true; + + std::string result = getChatTemplate(nullptr, params, false); + EXPECT_EQ(result, "my_custom_template"); +} + +// Tests with actual Qwen3 model loaded +class ChatTemplateUtilsQwen3Test : public ::testing::Test { +protected: + void SetUp() override { + config_files["device"] = test_common::getTestDevice(); + config_files["ctx_size"] = "2048"; + config_files["gpu_layers"] = test_common::getTestGpuLayers(); + config_files["n_predict"] = "10"; + + // Use Qwen3 model for testing + test_model_path = test_common::BaseTestModelPath::get( + "Qwen3-1.7B-Q4_0.gguf", "Llama-3.2-1B-Instruct-Q4_0.gguf"); + test_projection_path = ""; + + config_files["backendsDir"] = test_common::getTestBackendsDir().string(); + } + + std::unordered_map config_files; + std::string test_model_path; + std::string test_projection_path; + + std::unique_ptr createModel() { + if (!hasValidModel()) { + return nullptr; + } + std::string modelPath = test_model_path; + std::string projectionPath = test_projection_path; + auto configCopy = config_files; + auto model = std::make_unique( + std::move(modelPath), std::move(projectionPath), std::move(configCopy)); + model->waitForLoadInitialization(); + if (!model->isLoaded()) { + return nullptr; + } + return model; + } + + bool hasValidModel() { return fs::exists(test_model_path); } +}; + +TEST_F(ChatTemplateUtilsQwen3Test, IsQwen3ModelWithQwen3ModelLoaded) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + EXPECT_TRUE(isQwen3Model(llamaModel)) + << "Model should be detected as Qwen3 model"; +} + +TEST_F(ChatTemplateUtilsQwen3Test, GetChatTemplateForModelWithQwen3NoOverride) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + // Without override, should return Qwen3 template + std::string result = getChatTemplateForModel(llamaModel, "", false); + EXPECT_NE(result, "") << "Should return Qwen3 template when no override provided"; + EXPECT_GT(result.length(), 0u) << "Template should not be empty"; +} + +TEST_F(ChatTemplateUtilsQwen3Test, GetChatTemplateForModelWithQwen3ToolsAtEnd) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + // With toolsAtEnd=true, should return dynamic template + std::string result = getChatTemplateForModel(llamaModel, "", true); + EXPECT_NE(result, "") << "Should return Qwen3 tools template when no override provided"; + EXPECT_GT(result.length(), 0u) << "Template should not be empty"; +} + +TEST_F( + ChatTemplateUtilsQwen3Test, + GetChatTemplateForModelWithQwen3ManualOverrideTakesPrecedence) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + // Manual override should take precedence + std::string manualOverride = "custom qwen3 template"; + std::string result = getChatTemplateForModel(llamaModel, manualOverride, false); + EXPECT_EQ(result, manualOverride) + << "Manual override should take precedence over Qwen3 template"; +} + +TEST_F(ChatTemplateUtilsQwen3Test, GetChatTemplateWithQwen3JinjaEnabled) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + common_params params; + params.chat_template = ""; + params.use_jinja = true; + + // With Jinja enabled and no override, should use Qwen3 template + std::string result = getChatTemplate(llamaModel, params, false); + EXPECT_NE(result, "") << "Should return Qwen3 template when Jinja is enabled"; + EXPECT_GT(result.length(), 0u) << "Template should not be empty"; +} + +TEST_F( + ChatTemplateUtilsQwen3Test, + GetChatTemplateWithQwen3JinjaEnabledManualOverride) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + auto model = createModel(); + ASSERT_NE(model, nullptr) << "Failed to load Qwen3 model"; + ASSERT_TRUE(model->isLoaded()) << "Qwen3 model not loaded successfully"; + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + common_params params; + params.chat_template = "custom template"; + params.use_jinja = true; + + // With manual override, should use the override + std::string result = getChatTemplate(llamaModel, params, false); + EXPECT_EQ(result, "custom template") + << "Manual override should take precedence"; +} + +TEST_F(ChatTemplateUtilsQwen3Test, NonQwen3ModelNotDetectedAsQwen3) { + if (!hasValidModel()) { + GTEST_SKIP() << "Qwen3 model not found at " << test_model_path; + } + + // Test with Llama model instead + std::string llamaModelPath = test_common::BaseTestModelPath::get(); + if (!fs::exists(llamaModelPath)) { + GTEST_SKIP() << "Llama model not found at " << llamaModelPath; + } + + std::string modelPath = llamaModelPath; + std::string projectionPath = ""; + auto configCopy = config_files; + auto model = std::make_unique( + std::move(modelPath), std::move(projectionPath), std::move(configCopy)); + model->waitForLoadInitialization(); + + if (!model->isLoaded()) { + GTEST_SKIP() << "Llama model failed to load"; + } + + llama_model* llamaModel = model->getModel(); + ASSERT_NE(llamaModel, nullptr) << "Llama model pointer is null"; + + EXPECT_FALSE(isQwen3Model(llamaModel)) + << "Llama model should not be detected as Qwen3 model"; +} diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_llama_model.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_llama_model.cpp index 5bb80420e2..ab3fcd771f 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_llama_model.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_llama_model.cpp @@ -1052,3 +1052,90 @@ TEST_F(LlamaModelTest, ReloadDuringProcessingWaitsAndDoesNotCrash) { EXPECT_GE(output.length(), 0); }); } + +TEST_F(LlamaModelTest, CommonParamsParseToolsAtEndTrue) { + if (!fs::exists(getValidModelPath())) { + FAIL() << "Test model not found at: " << getValidModelPath(); + } + + std::unordered_map config; + config["device"] = test_common::getTestDevice(); + config["ctx_size"] = "2048"; + config["gpu_layers"] = test_common::getTestGpuLayers(); + config["n_predict"] = "10"; + config["tools_at_end"] = "true"; + + fs::path backendDir; +#ifdef TEST_BINARY_DIR + backendDir = fs::path(TEST_BINARY_DIR); +#else + backendDir = fs::current_path() / "build" / "test" / "unit"; +#endif + config["backendsDir"] = backendDir.string(); + + EXPECT_NO_THROW({ + LlamaModel model( + getValidModelPath(), + std::string(test_projection_path), + std::unordered_map(config)); + model.waitForLoadInitialization(); + }); +} + +TEST_F(LlamaModelTest, CommonParamsParseToolsAtEndFalse) { + if (!fs::exists(getValidModelPath())) { + FAIL() << "Test model not found at: " << getValidModelPath(); + } + + std::unordered_map config; + config["device"] = test_common::getTestDevice(); + config["ctx_size"] = "2048"; + config["gpu_layers"] = test_common::getTestGpuLayers(); + config["n_predict"] = "10"; + config["tools_at_end"] = "false"; + + fs::path backendDir; +#ifdef TEST_BINARY_DIR + backendDir = fs::path(TEST_BINARY_DIR); +#else + backendDir = fs::current_path() / "build" / "test" / "unit"; +#endif + config["backendsDir"] = backendDir.string(); + + EXPECT_NO_THROW({ + LlamaModel model( + getValidModelPath(), + std::string(test_projection_path), + std::unordered_map(config)); + model.waitForLoadInitialization(); + }); +} + +TEST_F(LlamaModelTest, CommonParamsParseToolsAtEndUppercase) { + if (!fs::exists(getValidModelPath())) { + FAIL() << "Test model not found at: " << getValidModelPath(); + } + + std::unordered_map config; + config["device"] = test_common::getTestDevice(); + config["ctx_size"] = "2048"; + config["gpu_layers"] = test_common::getTestGpuLayers(); + config["n_predict"] = "10"; + config["tools_at_end"] = "TRUE"; + + fs::path backendDir; +#ifdef TEST_BINARY_DIR + backendDir = fs::path(TEST_BINARY_DIR); +#else + backendDir = fs::current_path() / "build" / "test" / "unit"; +#endif + config["backendsDir"] = backendDir.string(); + + EXPECT_NO_THROW({ + LlamaModel model( + getValidModelPath(), + std::string(test_projection_path), + std::unordered_map(config)); + model.waitForLoadInitialization(); + }); +} diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context.cpp index 131a3c3c86..fc4546c2b6 100644 --- a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context.cpp +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -371,3 +372,336 @@ TEST_F(TextLlmContextTest, ProcessWithMultipleTools) { EXPECT_GE(stats.size(), 0); }); } + +TEST_F(TextLlmContextTest, DoubleTokenizeWithoutToolsAtEnd) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "false"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + auto stats = model->runtimeStats(); + int cacheTokens = static_cast(getStatValue(stats, "CacheTokens")); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + EXPECT_EQ(cacheTokens, 0); + // prompt tokens with tools + EXPECT_GT(promptTokens, 200); +} + +TEST_F(TextLlmContextTest, DoubleTokenizeWithToolsAtEndNoTools) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([{"role": "user", "content": "Hello, how are you?"}])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + // Without tools, CacheTokens should equal promptTokens (no cached + // conversation tokens) + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + EXPECT_LT(promptTokens, 50); +} + +TEST_F(TextLlmContextTest, DoubleTokenizationTimeOverhead) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + const std::string promptWithTools = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + const int numIterations = 10; + + { + config_files["tools_at_end"] = "false"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = promptWithTools; + + auto startSingle = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < numIterations; ++i) { + model->reset(); + std::string output = model->processPrompt(prompt); + } + auto endSingle = std::chrono::high_resolution_clock::now(); + auto durationSingle = std::chrono::duration_cast( + endSingle - startSingle) + .count(); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + + GTEST_LOG_(INFO) << "Single tokenization (no tools_at_end): " + << durationSingle / numIterations << " us per iteration (" + << promptTokens << " prompt tokens)"; + } + + { + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = promptWithTools; + + auto startDouble = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < numIterations; ++i) { + model->reset(); + std::string output = model->processPrompt(prompt); + } + auto endDouble = std::chrono::high_resolution_clock::now(); + auto durationDouble = std::chrono::duration_cast( + endDouble - startDouble) + .count(); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + int cacheTokens = static_cast(getStatValue(stats, "CacheTokens")); + + GTEST_LOG_(INFO) << "Double tokenization (tools_at_end=true): " + << durationDouble / numIterations << " us per iteration (" + << promptTokens << " prompt tokens, " << cacheTokens + << " cached tokens)"; + } + + { + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt promptNoTools; + promptNoTools.input = + R"([{"role": "user", "content": "Hello, how are you?"}])"; + + auto startNoTools = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < numIterations; ++i) { + model->reset(); + std::string output = model->processPrompt(promptNoTools); + } + auto endNoTools = std::chrono::high_resolution_clock::now(); + auto durationNoTools = + std::chrono::duration_cast( + endNoTools - startNoTools) + .count(); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + + GTEST_LOG_(INFO) << "Without tools (tools_at_end=true): " + << durationNoTools / numIterations << " us per iteration (" + << promptTokens << " prompt tokens)"; + } +} + +TEST_F(TextLlmContextTest, DoubleTokenizationTimeOverheadLargePrompt) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + std::string longContent; + for (int i = 0; i < 50; ++i) { + longContent += "This is a test message number " + std::to_string(i) + + ". It contains some text that will be tokenized into many " + "tokens. The purpose is to test the performance of " + "tokenization with a large prompt. "; + } + + const std::string promptWithTools = R"([ + {"role": "user", "content": ")" + longContent + + R"("}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + const int numIterations = 3; + + { + config_files["tools_at_end"] = "false"; + config_files["tools"] = "true"; + config_files["ctx_size"] = "4096"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = promptWithTools; + + auto startSingle = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < numIterations; ++i) { + model->reset(); + std::string output = model->processPrompt(prompt); + } + auto endSingle = std::chrono::high_resolution_clock::now(); + auto durationSingle = std::chrono::duration_cast( + endSingle - startSingle) + .count(); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + + GTEST_LOG_(INFO) << "Large prompt - Single tokenization (no tools_at_end): " + << durationSingle / numIterations << " us per iteration (" + << promptTokens << " prompt tokens)"; + } + + { + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = promptWithTools; + + auto startDouble = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < numIterations; ++i) { + model->reset(); + std::string output = model->processPrompt(prompt); + } + auto endDouble = std::chrono::high_resolution_clock::now(); + auto durationDouble = std::chrono::duration_cast( + endDouble - startDouble) + .count(); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + int cacheTokens = static_cast(getStatValue(stats, "CacheTokens")); + + GTEST_LOG_(INFO) + << "Large prompt - Double tokenization (tools_at_end=true): " + << durationDouble / numIterations << " us per iteration (" + << promptTokens << " prompt tokens, " << cacheTokens + << " cached tokens)"; + } +} + +TEST_F(TextLlmContextTest, NPastBeforeToolsMinusOneWithoutTools) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([{"role": "user", "content": "Hello, how are you?"}])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); +} + +TEST_F(TextLlmContextTest, NPastBeforeToolsMinusOneWhenToolsAtEndFalse) { + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "false"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeTools, -1); +} diff --git a/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context_qwen3.cpp b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context_qwen3.cpp new file mode 100644 index 0000000000..67d803d3da --- /dev/null +++ b/packages/qvac-lib-infer-llamacpp-llm/test/unit/test_text_llm_context_qwen3.cpp @@ -0,0 +1,335 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/chat.h" +#include "model-interface/LlamaModel.hpp" +#include "model-interface/TextLlmContext.hpp" +#include "test_common.hpp" + +namespace { +double getStatValue( + const qvac_lib_inference_addon_cpp::RuntimeStats& stats, + const std::string& key) { + for (const auto& stat : stats) { + if (stat.first == key) { + return std::visit( + [](const auto& value) -> double { + if constexpr (std::is_same_v< + std::decay_t, + double>) { + return value; + } else { + return static_cast(value); + } + }, + stat.second); + } + } + return 0.0; +} + +bool isQwen3ModelPath(const std::string& path) { + std::string lowerPath = path; + std::transform( + lowerPath.begin(), + lowerPath.end(), + lowerPath.begin(), + [](unsigned char c) { return std::tolower(c); }); + return lowerPath.find("qwen3") != std::string::npos; +} +} // namespace + +namespace fs = std::filesystem; + +class TextLlmContextQwen3Test : public ::testing::Test { +protected: + void SetUp() override { + config_files["device"] = test_common::getTestDevice(); + config_files["ctx_size"] = "2048"; + config_files["gpu_layers"] = test_common::getTestGpuLayers(); + config_files["n_predict"] = "10"; + + // Use Qwen3 model if available, skip if not + test_model_path = test_common::BaseTestModelPath::get( + "Qwen3-1.7B-Q4_0.gguf", "Llama-3.2-1B-Instruct-Q4_0.gguf"); + test_projection_path = ""; + + config_files["backendsDir"] = test_common::getTestBackendsDir().string(); + } + + std::unordered_map config_files; + std::string test_model_path; + std::string test_projection_path; + + bool hasValidModel() { return fs::exists(test_model_path); } + bool isQwen3Model() { return isQwen3ModelPath(test_model_path); } + + std::unique_ptr createModel() { + if (!hasValidModel()) { + return nullptr; + } + std::string modelPath = test_model_path; + std::string projectionPath = test_projection_path; + auto configCopy = config_files; + auto model = std::make_unique( + std::move(modelPath), std::move(projectionPath), std::move(configCopy)); + model->waitForLoadInitialization(); + if (!model->isLoaded()) { + return nullptr; + } + return model; + } +}; + +TEST_F(TextLlmContextQwen3Test, DoubleTokenizeWithToolsAtEnd) { + if (!isQwen3Model()) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + EXPECT_GT(promptTokens, 0); +} + +TEST_F(TextLlmContextQwen3Test, DoubleTokenizeWithMultipleTools) { + if (!isQwen3Model()) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "Search for laptops and add to cart"}, + { + "type": "function", + "name": "searchProducts", + "description": "Search products", + "parameters": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + } + }, + { + "type": "function", + "name": "addToCart", + "description": "Add items to cart", + "parameters": { + "type": "object", + "properties": { + "items": { + "type": "array", + "items": {"type": "string"} + } + }, + "required": ["items"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + EXPECT_GT(promptTokens, 0); +} + +TEST_F(TextLlmContextQwen3Test, DoubleTokenizeBoundaryAccuracy) { + if (!isQwen3Model()) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt promptWithTools; + promptWithTools.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW( + { std::string output = model->processPrompt(promptWithTools); }); + + auto statsWithTools = model->runtimeStats(); + int promptTokensWithTools = + static_cast(getStatValue(statsWithTools, "promptTokens")); + EXPECT_GT(promptTokensWithTools, 150); + + EXPECT_NO_THROW({ model->reset(); }); + + LlamaModel::Prompt promptNoTools; + promptNoTools.input = + R"([{"role": "user", "content": "What is the weather in Tokyo?"}])"; + + EXPECT_NO_THROW( + { std::string output = model->processPrompt(promptNoTools); }); + + auto statsNoTools = model->runtimeStats(); + int promptTokensNoTools = + static_cast(getStatValue(statsNoTools, "promptTokens")); + + EXPECT_LT(promptTokensNoTools, 30); +} + +TEST_F(TextLlmContextQwen3Test, NPastBeforeToolsSetAfterEvalWithTools) { + if (!isQwen3Model()) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + llama_pos nPastBeforeTools = model->getNPastBeforeTools(); + auto stats = model->runtimeStats(); + int promptTokens = static_cast(getStatValue(stats, "promptTokens")); + + EXPECT_EQ(nPastBeforeTools, -1); + EXPECT_GT(promptTokens, 0); +} + +TEST_F(TextLlmContextQwen3Test, NPastBeforeToolsResetAfterResetState) { + if (!isQwen3Model()) { + GTEST_SKIP() << "Test requires Qwen3 model for tools_at_end feature"; + } + + if (!hasValidModel()) { + FAIL() << "Test model not found"; + } + + config_files["tools_at_end"] = "true"; + config_files["tools"] = "true"; + auto model = createModel(); + if (!model) { + FAIL() << "Model failed to load"; + } + + LlamaModel::Prompt prompt; + prompt.input = R"([ + {"role": "user", "content": "What is the weather in Tokyo?"}, + { + "type": "function", + "name": "getWeather", + "description": "Get weather forecast for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"}, + "date": {"type": "string", "description": "Date in YYYY-MM-DD"} + }, + "required": ["city", "date"] + } + } + ])"; + + EXPECT_NO_THROW({ std::string output = model->processPrompt(prompt); }); + + llama_pos nPastBeforeToolsBeforeReset = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeToolsBeforeReset, -1); + + model->reset(); + + llama_pos nPastBeforeToolsAfterReset = model->getNPastBeforeTools(); + EXPECT_EQ(nPastBeforeToolsAfterReset, -1); +}