@@ -1,7 +1,10 @@
 #include "inference_service.h"
 #include <drogon/HttpTypes.h>
 #include "utils/engine_constants.h"
+#include "utils/file_manager_utils.h"
 #include "utils/function_calling/common.h"
+#include "utils/gguf_metadata_reader.h"
+#include "utils/jinja_utils.h"
 
 namespace services {
 cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
@@ -24,6 +27,56 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     return cpp::fail(std::make_pair(stt, res));
   }
 
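+  // If the downloaded model ships a GGUF chat template, render the request's
+  // messages through it so the engine receives a fully formatted prompt.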
+  {
+    // TODO: we can cache this one so we don't have to read the file every inference
+    auto model_id = json_body->get("model", "").asString();
+    if (!model_id.empty()) {
+      if (auto model_service = model_service_.lock()) {
+        auto model_config = model_service->GetDownloadedModel(model_id);
+        if (model_config.has_value() && !model_config->files.empty()) {
+          auto file = model_config->files[0];
+
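+          // Read the GGUF metadata of the first model file; it carries the
+          // tokenizer fields used below (chat template, BOS/EOS tokens).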
+          auto model_metadata_res = cortex_utils::ReadGgufMetadata(
+              file_manager_utils::ToAbsoluteCortexDataPath(
+                  std::filesystem::path(file)));
+          if (model_metadata_res.has_value()) {
+            auto metadata = model_metadata_res.value().get();
+            if (!metadata->tokenizer->chat_template.empty()) {
+              auto messages = (*json_body)["messages"];
+              Json::Value messages_jsoncpp(Json::arrayValue);
+              for (auto message : messages) {
+                messages_jsoncpp.append(message);
+              }
+
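+              // Build the context object the template is rendered against;
+              // tools are prepared but not passed through yet (see below).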
+              Json::Value tools(Json::arrayValue);
+              Json::Value template_data_json;
+              template_data_json["messages"] = messages_jsoncpp;
+              // template_data_json["tools"] = tools;
+
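+              // Render the chat template with the model's own special tokens
+              // and generation-prompt flag to produce the final prompt string.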
+              auto prompt_result = jinja::RenderTemplate(
+                  metadata->tokenizer->chat_template, template_data_json,
+                  metadata->tokenizer->bos_token,
+                  metadata->tokenizer->eos_token,
+                  metadata->tokenizer->add_generation_prompt);
+              if (prompt_result.has_value()) {
+                (*json_body)["prompt"] = prompt_result.value();
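+                // Use the tokenizer's EOS token as the stop sequence so the
+                // engine stops where the template expects the turn to end.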
+                Json::Value stops(Json::arrayValue);
+                stops.append(metadata->tokenizer->eos_token);
+                (*json_body)["stop"] = stops;
+              } else {
+                CTL_ERR("Failed to render prompt: " + prompt_result.error());
+              }
+            }
+          } else {
+            CTL_ERR("Failed to read metadata: " + model_metadata_res.error());
+          }
+        }
+      }
+    }
+  }
+
+  CTL_DBG("Json body inference: " + json_body->toStyledString());
+
   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {
       res["tool_choice"] = tool_choice;
@@ -297,4 +350,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
   }
   return true;
 }
-} // namespace services
+} // namespace services