diff --git a/LLMClientInterface.cpp b/LLMClientInterface.cpp index aaba9b1..6dd25ae 100644 --- a/LLMClientInterface.cpp +++ b/LLMClientInterface.cpp @@ -148,13 +148,16 @@ void LLMClientInterface::handleCompletion(const QJsonObject &request) auto updatedContext = prepareContext(request); LLMConfig config; + config.requestType = RequestType::Fim; config.provider = LLMProvidersManager::instance().getCurrentFimProvider(); - config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate(); + config.promptTemplate = PromptTemplateManager::instance().getCurrentFimTemplate(); config.url = QUrl(QString("%1%2").arg(Settings::generalSettings().url(), Settings::generalSettings().endPoint())); config.providerRequest = {{"model", Settings::generalSettings().modelName.value()}, - {"stream", true}}; + {"stream", true}, + {"stop", + QJsonArray::fromStringList(config.promptTemplate->stopWords())}}; config.promptTemplate->prepareRequest(config.providerRequest, updatedContext); config.provider->prepareRequest(config.providerRequest); diff --git a/PromptTemplateManager.cpp b/PromptTemplateManager.cpp index 1eea5c9..899d782 100644 --- a/PromptTemplateManager.cpp +++ b/PromptTemplateManager.cpp @@ -31,33 +31,48 @@ PromptTemplateManager &PromptTemplateManager::instance() void PromptTemplateManager::setCurrentFimTemplate(const QString &name) { - if (!m_templates.contains(name)) { - logMessage("Can't find prompt with name: " + name); + if (!m_fimTemplates.contains(name) || m_fimTemplates[name] == nullptr) { + logMessage("Error to set current FIM template" + name); return; } - if (m_templates[name] == nullptr) { - logMessage("Prompt is null"); - return; + m_currentFimTemplate = m_fimTemplates[name]; +} + +Templates::PromptTemplate *PromptTemplateManager::getCurrentFimTemplate() +{ + if (m_currentFimTemplate == nullptr) { + logMessage("Current fim provider is null"); + return nullptr; } - m_currentFimPrompt + return m_currentFimTemplate; } -Templates::PromptTemplate *PromptTemplateManager::getCurrentTemplate() +void PromptTemplateManager::setCurrentChatTemplate(const QString &name) { - auto it = m_templates.find(m_currentTemplateName); - return it != m_templates.end() ? it.value() : nullptr; + if (!m_chatTemplates.contains(name) || m_chatTemplates[name] == nullptr) { + logMessage("Error to set current chat template" + name); + return; + } + + m_currentChatTemplate = m_chatTemplates[name]; } -QStringList PromptTemplateManager::getTemplateNames() const +Templates::PromptTemplate *PromptTemplateManager::getCurrentChatTemplate() { - return m_templates.keys(); + if (m_currentChatTemplate == nullptr) { + logMessage("Current chat provider is null"); + return nullptr; + } + + return m_currentChatTemplate; } PromptTemplateManager::~PromptTemplateManager() { - qDeleteAll(m_templates); + qDeleteAll(m_fimTemplates); + qDeleteAll(m_chatTemplates); } } // namespace QodeAssist diff --git a/PromptTemplateManager.hpp b/PromptTemplateManager.hpp index 0abadbe..a7f3e10 100644 --- a/PromptTemplateManager.hpp +++ b/PromptTemplateManager.hpp @@ -40,24 +40,30 @@ class PromptTemplateManager "T must inherit from PromptTemplate"); T *template_ptr = new T(); QString name = template_ptr->name(); - m_templates[name] = template_ptr; - Settings::generalSettings().fimPrompts.addOption(name); + if (template_ptr->type() == Templates::TemplateType::Fim) { + m_fimTemplates[name] = template_ptr; + Settings::generalSettings().fimPrompts.addOption(name); + } else if (template_ptr->type() == Templates::TemplateType::Chat) { + m_chatTemplates[name] = template_ptr; + Settings::generalSettings().chatPrompts.addOption(name); + } } void setCurrentFimTemplate(const QString &name); Templates::PromptTemplate *getCurrentFimTemplate(); - QStringList getTemplateNames() const; - + void setCurrentChatTemplate(const QString &name); + Templates::PromptTemplate *getCurrentChatTemplate(); private: PromptTemplateManager() = default; PromptTemplateManager(const PromptTemplateManager &) = delete; PromptTemplateManager &operator=(const PromptTemplateManager &) = delete; - QMap m_templates; - Templates::PromptTemplate *m_currentFimPrompt; - Templates::PromptTemplate *m_currentChatPrompt; + QMap m_fimTemplates; + QMap m_chatTemplates; + Templates::PromptTemplate *m_currentFimTemplate; + Templates::PromptTemplate *m_currentChatTemplate; }; } // namespace QodeAssist diff --git a/chat/ChatClientInterface.cpp b/chat/ChatClientInterface.cpp index 5ec8b9a..c6277e7 100644 --- a/chat/ChatClientInterface.cpp +++ b/chat/ChatClientInterface.cpp @@ -54,7 +54,6 @@ ChatClientInterface::ChatClientInterface(QObject *parent) ChatClientInterface::~ChatClientInterface() { - logMessage("ChatClientInterface destroyed"); } void ChatClientInterface::sendMessage(const QString &message) @@ -65,9 +64,11 @@ void ChatClientInterface::sendMessage(const QString &message) prepareRequest(providerRequest, message); LLMConfig config; + config.requestType = RequestType::Chat; config.provider = LLMProvidersManager::instance().getCurrentChatProvider(); - config.promptTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - config.url = QString("%1%2").arg(Settings::generalSettings().url(), "/api/chat"); + config.promptTemplate = PromptTemplateManager::instance().getCurrentChatTemplate(); + config.url = QString("%1%2").arg(Settings::generalSettings().chatUrl(), + Settings::generalSettings().chatEndPoint()); config.providerRequest = providerRequest; QJsonObject request; @@ -81,7 +82,7 @@ void ChatClientInterface::prepareRequest(QJsonObject &request, const QString &me { auto &settings = Settings::presetPromptsSettings(); - request["model"] = Settings::generalSettings().modelName(); //MODEL_NAME; + request["model"] = Settings::generalSettings().chatModelName(); QJsonArray messages = {QJsonObject{{"role", "user"}, {"content", message}}}; request["messages"] = messages; diff --git a/chat/ChatClientInterface.hpp b/chat/ChatClientInterface.hpp index 8341fd1..27b9cdd 100644 --- a/chat/ChatClientInterface.hpp +++ b/chat/ChatClientInterface.hpp @@ -45,10 +45,6 @@ class ChatClientInterface : public QObject LLMRequestHandler *m_requestHandler; QString m_accumulatedResponse; - - const QString MODEL_NAME = "bartowski/Llama-3.1-SauerkrautLM-8b-Instruct-GGUF"; - const QString SERVER_URL = "http://localhost:1234"; - const QString ENDPOINT = "/v1/chat/completions"; }; } // namespace QodeAssist::Chat diff --git a/chat/ChatWidget.cpp b/chat/ChatWidget.cpp index b24e0ce..dadc8f0 100644 --- a/chat/ChatWidget.cpp +++ b/chat/ChatWidget.cpp @@ -30,7 +30,7 @@ namespace QodeAssist::Chat { ChatWidget::ChatWidget(QWidget *parent) : QWidget(parent) - , m_showTimestamp(true) + , m_showTimestamp(false) , m_chatClient(new ChatClientInterface(this)) { setupUi(); diff --git a/core/LLMRequestConfig.hpp b/core/LLMRequestConfig.hpp index 33ccdb2..10b6ca2 100644 --- a/core/LLMRequestConfig.hpp +++ b/core/LLMRequestConfig.hpp @@ -7,12 +7,15 @@ namespace QodeAssist { +enum class RequestType { Fim, Chat }; + struct LLMConfig { QUrl url; Providers::LLMProvider *provider; Templates::PromptTemplate *promptTemplate; QJsonObject providerRequest; + RequestType requestType; }; } // namespace QodeAssist diff --git a/core/LLMRequestHandler.cpp b/core/LLMRequestHandler.cpp index e48ce4b..bb5b3ae 100644 --- a/core/LLMRequestHandler.cpp +++ b/core/LLMRequestHandler.cpp @@ -77,24 +77,24 @@ void LLMRequestHandler::handleLLMResponse(QNetworkReply *reply, QString &accumulatedResponse = m_accumulatedResponses[reply]; - auto provider = LLMProvidersManager::instance().getCurrentFimProvider(); - if (provider == nullptr) - qDebug() << "No provider selected"; + bool isComplete = config.provider->handleResponse(reply, accumulatedResponse); - bool isComplete = LLMProvidersManager::instance() - .getCurrentFimProvider() - ->handleResponse(reply, accumulatedResponse); - - if (!Settings::generalSettings().multiLineCompletion() - && processSingleLineCompletion(reply, request, accumulatedResponse, config)) { - return; + if (config.requestType == RequestType::Fim) { + if (!Settings::generalSettings().multiLineCompletion() + && processSingleLineCompletion(reply, request, accumulatedResponse, config)) { + return; + } } if (isComplete || reply->isFinished()) { if (isComplete) { - // auto cleanedCompletion = removeStopWords(accumulatedResponse, - // config.promptTemplate->stopWords()); - emit completionReceived(accumulatedResponse, request, true); + if (config.requestType == RequestType::Fim) { + auto cleanedCompletion = removeStopWords(accumulatedResponse, + config.promptTemplate->stopWords()); + emit completionReceived(cleanedCompletion, request, true); + } else { + emit completionReceived(accumulatedResponse, request, true); + } } else { emit completionReceived(accumulatedResponse, request, false); } diff --git a/providers/LMStudioProvider.cpp b/providers/LMStudioProvider.cpp index a9ffa09..63b7672 100644 --- a/providers/LMStudioProvider.cpp +++ b/providers/LMStudioProvider.cpp @@ -56,9 +56,6 @@ QString LMStudioProvider::chatEndpoint() const void LMStudioProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; if (request.contains("prompt")) { QJsonArray messages{ {QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}}; @@ -67,7 +64,6 @@ void LMStudioProvider::prepareRequest(QJsonObject &request) request["max_tokens"] = settings.maxTokens(); request["temperature"] = settings.temperature(); - request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) request["top_p"] = settings.topP(); if (settings.useTopK()) diff --git a/providers/OllamaProvider.cpp b/providers/OllamaProvider.cpp index a10b08e..52b9d00 100644 --- a/providers/OllamaProvider.cpp +++ b/providers/OllamaProvider.cpp @@ -56,15 +56,11 @@ QString OllamaProvider::chatEndpoint() const void OllamaProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - auto currentTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; QJsonObject options; options["num_predict"] = settings.maxTokens(); options["keep_alive"] = settings.ollamaLivetime(); options["temperature"] = settings.temperature(); - options["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) options["top_p"] = settings.topP(); if (settings.useTopK()) diff --git a/providers/OpenAICompatProvider.cpp b/providers/OpenAICompatProvider.cpp index 9bf68a3..25d1150 100644 --- a/providers/OpenAICompatProvider.cpp +++ b/providers/OpenAICompatProvider.cpp @@ -54,10 +54,6 @@ QString OpenAICompatProvider::chatEndpoint() const void OpenAICompatProvider::prepareRequest(QJsonObject &request) { auto &settings = Settings::presetPromptsSettings(); - const auto ¤tTemplate = PromptTemplateManager::instance().getCurrentTemplate(); - if (currentTemplate->name() == "Custom Template") - return; - if (request.contains("prompt")) { QJsonArray messages{ {QJsonObject{{"role", "user"}, {"content", request.take("prompt").toString()}}}}; @@ -66,7 +62,6 @@ void OpenAICompatProvider::prepareRequest(QJsonObject &request) request["max_tokens"] = settings.maxTokens(); request["temperature"] = settings.temperature(); - request["stop"] = QJsonArray::fromStringList(currentTemplate->stopWords()); if (settings.useTopP()) request["top_p"] = settings.topP(); if (settings.useTopK()) diff --git a/qodeassist.cpp b/qodeassist.cpp index 472dc72..5cbcd8f 100644 --- a/qodeassist.cpp +++ b/qodeassist.cpp @@ -80,13 +80,14 @@ class QodeAssistPlugin final : public ExtensionSystem::IPlugin providerManager.registerProvider(); providerManager.registerProvider(); providerManager.setCurrentFimProvider("Ollama"); + providerManager.setCurrentChatProvider("Ollama"); auto &templateManager = PromptTemplateManager::instance(); templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); templateManager.registerTemplate(); - templateManager.setCurrentTemplate("StarCoder2"); + templateManager.setCurrentFimTemplate("StarCoder2"); Utils::Icon QCODEASSIST_ICON( {{":/resources/images/qoderassist-icon.png", Utils::Theme::IconsBaseColor}}); diff --git a/templates/CodeLLamaTemplate.hpp b/templates/CodeLLamaTemplate.hpp index 6115d28..286d254 100644 --- a/templates/CodeLLamaTemplate.hpp +++ b/templates/CodeLLamaTemplate.hpp @@ -26,6 +26,7 @@ namespace QodeAssist::Templates { class CodeLLamaTemplate : public PromptTemplate { public: + TemplateType type() const override { return TemplateType::Fim; } QString name() const override { return "CodeLlama"; } QString promptTemplate() const override { return "%1
 %2 %3 "; }
     QStringList stopWords() const override
diff --git a/templates/CustomTemplate.hpp b/templates/CustomTemplate.hpp
index 5ae2e24..5e1b913 100644
--- a/templates/CustomTemplate.hpp
+++ b/templates/CustomTemplate.hpp
@@ -32,7 +32,8 @@ namespace QodeAssist::Templates {
 class CustomTemplate : public PromptTemplate
 {
 public:
-    QString name() const override { return "Custom Template"; }
+    TemplateType type() const override { return TemplateType::Fim; }
+    QString name() const override { return "Custom FIM Template"; }
     QString promptTemplate() const override
     {
         return Settings::customPromptSettings().customJsonTemplate();
diff --git a/templates/DeepSeekCoderV2.hpp b/templates/DeepSeekCoderV2.hpp
index 69886fb..34e47cf 100644
--- a/templates/DeepSeekCoderV2.hpp
+++ b/templates/DeepSeekCoderV2.hpp
@@ -26,6 +26,7 @@ namespace QodeAssist::Templates {
 class DeepSeekCoderV2Template : public PromptTemplate
 {
 public:
+    TemplateType type() const override { return TemplateType::Fim; }
     QString name() const override { return "DeepSeekCoderV2"; }
     QString promptTemplate() const override
     {
diff --git a/templates/PromptTemplate.hpp b/templates/PromptTemplate.hpp
index 5b134ee..0dee78e 100644
--- a/templates/PromptTemplate.hpp
+++ b/templates/PromptTemplate.hpp
@@ -27,10 +27,13 @@
 
 namespace QodeAssist::Templates {
 
+enum class TemplateType { Chat, Fim };
+
 class PromptTemplate
 {
 public:
     virtual ~PromptTemplate() = default;
+    virtual TemplateType type() const = 0;
     virtual QString name() const = 0;
     virtual QString promptTemplate() const = 0;
     virtual QStringList stopWords() const = 0;
diff --git a/templates/StarCoder2Template.hpp b/templates/StarCoder2Template.hpp
index 8f89679..ddc8e08 100644
--- a/templates/StarCoder2Template.hpp
+++ b/templates/StarCoder2Template.hpp
@@ -26,6 +26,7 @@ namespace QodeAssist::Templates {
 class StarCoder2Template : public PromptTemplate
 {
 public:
+    TemplateType type() const override { return TemplateType::Fim; }
     QString name() const override { return "StarCoder2"; }
     QString promptTemplate() const override { return "%1%2%3"; }
     QStringList stopWords() const override