Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 8a800f4

Browse files
committed
feat: rendering chat_template
1 parent 5414e02 commit 8a800f4

File tree

17 files changed

+4349
-164
lines changed

17 files changed

+4349
-164
lines changed

engine/cli/commands/chat_completion_cmd.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,8 @@ void ChatCompletionCmd::Exec(const std::string& host, int port,
151151
json_data["model"] = model_handle;
152152
json_data["stream"] = true;
153153

154-
std::string json_payload = json_data.toStyledString();
155-
156-
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, json_payload.c_str());
154+
curl_easy_setopt(curl, CURLOPT_POSTFIELDS,
155+
json_data.toStyledString().c_str());
157156

158157
std::string ai_chat;
159158
StreamingCallback callback;

engine/common/model_metadata.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once

#include <cstdint>
#include <memory>
#include <sstream>
#include <string>

#include "common/tokenizer.h"

// Top-level metadata parsed from a model file (e.g. a GGUF header),
// plus the tokenizer configuration read from the same file.
struct ModelMetadata {
  uint32_t version;
  uint64_t tensor_count;
  uint64_t metadata_kv_count;
  // Owned tokenizer parsed from the model file; may be null when the
  // file carries no tokenizer section.
  std::unique_ptr<Tokenizer> tokenizer;

  // Human-readable dump of all fields, intended for logging/debugging.
  std::string ToString() const {
    std::ostringstream ss;
    ss << "ModelMetadata {\n"
       << "version: " << version << "\n"
       << "tensor_count: " << tensor_count << "\n"
       << "metadata_kv_count: " << metadata_kv_count << "\n"
       << "tokenizer: ";

    if (tokenizer) {
      ss << "\n" << tokenizer->ToString();
    } else {
      ss << "null";
    }

    ss << "\n}";
    return ss.str();
  }
};

engine/common/tokenizer.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#pragma once

#include <sstream>
#include <string>

// Tokenizer configuration shared by all model formats: special tokens
// and the chat-template settings needed to render a prompt.
struct Tokenizer {
  std::string eos_token = "";
  bool add_eos_token = true;

  std::string bos_token = "";
  bool add_bos_token = true;

  std::string unknown_token = "";
  std::string padding_token = "";

  // Jinja-style chat template embedded in the model file, if any.
  std::string chat_template = "";

  bool add_generation_prompt = true;

  // Helper function for common fields; used by derived ToString() impls.
  std::string BaseToString() const {
    std::ostringstream ss;
    ss << "eos_token: \"" << eos_token << "\"\n"
       << "add_eos_token: " << (add_eos_token ? "true" : "false") << "\n"
       << "bos_token: \"" << bos_token << "\"\n"
       << "add_bos_token: " << (add_bos_token ? "true" : "false") << "\n"
       << "unknown_token: \"" << unknown_token << "\"\n"
       << "padding_token: \"" << padding_token << "\"\n"
       << "chat_template: \"" << chat_template << "\"\n"
       << "add_generation_prompt: "
       // Fixed: a stray trailing '"' used to be appended here after the
       // boolean value, with no matching opening quote.
       << (add_generation_prompt ? "true" : "false");
    return ss.str();
  }

  virtual ~Tokenizer() = default;

  virtual std::string ToString() = 0;
};
39+
40+
struct GgufTokenizer : public Tokenizer {
41+
std::string pre = "";
42+
43+
~GgufTokenizer() override = default;
44+
45+
std::string ToString() override {
46+
std::ostringstream ss;
47+
ss << "GgufTokenizer {\n";
48+
// Add base class members
49+
ss << BaseToString() << "\n";
50+
// Add derived class members
51+
ss << "pre: \"" << pre << "\"\n";
52+
ss << "}";
53+
return ss.str();
54+
}
55+
};
56+
57+
struct SafeTensorTokenizer : public Tokenizer {
58+
bool add_prefix_space = true;
59+
60+
~SafeTensorTokenizer() = default;
61+
62+
std::string ToString() override {
63+
std::ostringstream ss;
64+
ss << "SafeTensorTokenizer {\n";
65+
// Add base class members
66+
ss << BaseToString() << "\n";
67+
// Add derived class members
68+
ss << "add_prefix_space: " << (add_prefix_space ? "true" : "false") << "\n";
69+
ss << "}";
70+
return ss.str();
71+
}
72+
};

engine/controllers/files.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,8 @@ void Files::RetrieveFileContent(
216216
return;
217217
}
218218

219-
auto [buffer, size] = std::move(res.value());
220-
auto resp = HttpResponse::newHttpResponse();
221-
resp->setBody(std::string(buffer.get(), size));
222-
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
219+
auto resp =
220+
cortex_utils::CreateCortexContentResponse(std::move(res.value()));
223221
callback(resp);
224222
} else {
225223
if (!msg_res->rel_path.has_value()) {
@@ -243,10 +241,8 @@ void Files::RetrieveFileContent(
243241
return;
244242
}
245243

246-
auto [buffer, size] = std::move(content_res.value());
247-
auto resp = HttpResponse::newHttpResponse();
248-
resp->setBody(std::string(buffer.get(), size));
249-
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
244+
auto resp = cortex_utils::CreateCortexContentResponse(
245+
std::move(content_res.value()));
250246
callback(resp);
251247
}
252248
}
@@ -261,9 +257,6 @@ void Files::RetrieveFileContent(
261257
return;
262258
}
263259

264-
auto [buffer, size] = std::move(res.value());
265-
auto resp = HttpResponse::newHttpResponse();
266-
resp->setBody(std::string(buffer.get(), size));
267-
resp->setContentTypeCode(CT_APPLICATION_OCTET_STREAM);
260+
auto resp = cortex_utils::CreateCortexContentResponse(std::move(res.value()));
268261
callback(resp);
269262
}

engine/controllers/server.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "trantor/utils/Logger.h"
44
#include "utils/cortex_utils.h"
55
#include "utils/function_calling/common.h"
6-
#include "utils/http_util.h"
76

87
using namespace inferences;
98

@@ -27,6 +26,14 @@ void server::ChatCompletion(
2726
std::function<void(const HttpResponsePtr&)>&& callback) {
2827
LOG_DEBUG << "Start chat completion";
2928
auto json_body = req->getJsonObject();
29+
if (json_body == nullptr) {
30+
Json::Value ret;
31+
ret["message"] = "Body can't be empty";
32+
auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
33+
resp->setStatusCode(k400BadRequest);
34+
callback(resp);
35+
return;
36+
}
3037
bool is_stream = (*json_body).get("stream", false).asBool();
3138
auto model_id = (*json_body).get("model", "invalid_model").asString();
3239
auto engine_type = [this, &json_body]() -> std::string {

engine/main.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ void RunServer(std::optional<std::string> host, std::optional<int> port,
159159
auto model_src_svc = std::make_shared<services::ModelSourceService>();
160160
auto model_service = std::make_shared<ModelService>(
161161
download_service, inference_svc, engine_service);
162+
inference_svc->SetModelService(model_service);
162163

163164
auto file_watcher_srv = std::make_shared<FileWatcherService>(
164165
model_dir_path.string(), model_service);

engine/services/engine_service.h

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include <mutex>
55
#include <optional>
66
#include <string>
7-
#include <string_view>
87
#include <unordered_map>
98
#include <vector>
109

@@ -17,7 +16,6 @@
1716
#include "utils/cpuid/cpu_info.h"
1817
#include "utils/dylib.h"
1918
#include "utils/dylib_path_manager.h"
20-
#include "utils/engine_constants.h"
2119
#include "utils/github_release_utils.h"
2220
#include "utils/result.hpp"
2321
#include "utils/system_info_utils.h"
@@ -48,10 +46,6 @@ class EngineService : public EngineServiceI {
4846
struct EngineInfo {
4947
std::unique_ptr<cortex_cpp::dylib> dl;
5048
EngineV engine;
51-
#if defined(_WIN32)
52-
DLL_DIRECTORY_COOKIE cookie;
53-
DLL_DIRECTORY_COOKIE cuda_cookie;
54-
#endif
5549
};
5650

5751
std::mutex engines_mutex_;
@@ -105,21 +99,23 @@ class EngineService : public EngineServiceI {
10599

106100
cpp::result<DefaultEngineVariant, std::string> SetDefaultEngineVariant(
107101
const std::string& engine, const std::string& version,
108-
const std::string& variant);
102+
const std::string& variant) override;
109103

110104
cpp::result<DefaultEngineVariant, std::string> GetDefaultEngineVariant(
111-
const std::string& engine);
105+
const std::string& engine) override;
112106

113107
cpp::result<std::vector<EngineVariantResponse>, std::string>
114-
GetInstalledEngineVariants(const std::string& engine) const;
108+
GetInstalledEngineVariants(const std::string& engine) const override;
115109

116110
cpp::result<EngineV, std::string> GetLoadedEngine(
117111
const std::string& engine_name);
118112

119113
std::vector<EngineV> GetLoadedEngines();
120114

121-
cpp::result<void, std::string> LoadEngine(const std::string& engine_name);
122-
cpp::result<void, std::string> UnloadEngine(const std::string& engine_name);
115+
cpp::result<void, std::string> LoadEngine(
116+
const std::string& engine_name) override;
117+
cpp::result<void, std::string> UnloadEngine(
118+
const std::string& engine_name) override;
123119

124120
cpp::result<github_release_utils::GitHubRelease, std::string>
125121
GetLatestEngineVersion(const std::string& engine) const;
@@ -137,7 +133,7 @@ class EngineService : public EngineServiceI {
137133

138134
cpp::result<cortex::db::EngineEntry, std::string> GetEngineByNameAndVariant(
139135
const std::string& engine_name,
140-
const std::optional<std::string> variant = std::nullopt);
136+
const std::optional<std::string> variant = std::nullopt) override;
141137

142138
cpp::result<cortex::db::EngineEntry, std::string> UpsertEngine(
143139
const std::string& engine_name, const std::string& type,

engine/services/inference_service.cc

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#include "inference_service.h"
22
#include <drogon/HttpTypes.h>
33
#include "utils/engine_constants.h"
4+
#include "utils/file_manager_utils.h"
45
#include "utils/function_calling/common.h"
6+
#include "utils/gguf_metadata_reader.h"
7+
#include "utils/jinja_utils.h"
58

69
namespace services {
710
cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
@@ -24,6 +27,56 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
2427
return cpp::fail(std::make_pair(stt, res));
2528
}
2629

30+
{
31+
// TODO: we can cache this one so we don't have to read the file every inference
32+
auto model_id = json_body->get("model", "").asString();
33+
if (!model_id.empty()) {
34+
if (auto model_service = model_service_.lock()) {
35+
auto model_config = model_service->GetDownloadedModel(model_id);
36+
if (model_config.has_value() && !model_config->files.empty()) {
37+
auto file = model_config->files[0];
38+
39+
auto model_metadata_res = cortex_utils::ReadGgufMetadata(
40+
file_manager_utils::ToAbsoluteCortexDataPath(
41+
std::filesystem::path(file)));
42+
if (model_metadata_res.has_value()) {
43+
auto metadata = model_metadata_res.value().get();
44+
if (!metadata->tokenizer->chat_template.empty()) {
45+
auto messages = (*json_body)["messages"];
46+
Json::Value messages_jsoncpp(Json::arrayValue);
47+
for (auto message : messages) {
48+
messages_jsoncpp.append(message);
49+
}
50+
51+
Json::Value tools(Json::arrayValue);
52+
Json::Value template_data_json;
53+
template_data_json["messages"] = messages_jsoncpp;
54+
// template_data_json["tools"] = tools;
55+
56+
auto prompt_result = jinja::RenderTemplate(
57+
metadata->tokenizer->chat_template, template_data_json,
58+
metadata->tokenizer->bos_token,
59+
metadata->tokenizer->eos_token,
60+
metadata->tokenizer->add_generation_prompt);
61+
if (prompt_result.has_value()) {
62+
(*json_body)["prompt"] = prompt_result.value();
63+
Json::Value stops(Json::arrayValue);
64+
stops.append(metadata->tokenizer->eos_token);
65+
(*json_body)["stop"] = stops;
66+
} else {
67+
CTL_ERR("Failed to render prompt: " + prompt_result.error());
68+
}
69+
}
70+
} else {
71+
CTL_ERR("Failed to read metadata: " + model_metadata_res.error());
72+
}
73+
}
74+
}
75+
}
76+
}
77+
78+
CTL_DBG("Json body inference: " + json_body->toStyledString());
79+
2780
auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
2881
if (!tool_choice.isNull()) {
2982
res["tool_choice"] = tool_choice;
@@ -297,4 +350,4 @@ bool InferenceService::HasFieldInReq(std::shared_ptr<Json::Value> json_body,
297350
}
298351
return true;
299352
}
300-
} // namespace services
353+
} // namespace services

engine/services/inference_service.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
#include <mutex>
55
#include <queue>
66
#include "services/engine_service.h"
7+
#include "services/model_service.h"
78
#include "utils/result.hpp"
8-
#include "extensions/remote-engine/remote_engine.h"
9+
910
namespace services {
11+
1012
// Status and result
1113
using InferResult = std::pair<Json::Value, Json::Value>;
1214

@@ -58,7 +60,12 @@ class InferenceService {
5860
bool HasFieldInReq(std::shared_ptr<Json::Value> json_body,
5961
const std::string& field);
6062

63+
void SetModelService(std::shared_ptr<ModelService> model_service) {
64+
model_service_ = model_service;
65+
}
66+
6167
private:
6268
std::shared_ptr<EngineService> engine_service_;
69+
std::weak_ptr<ModelService> model_service_;
6370
};
6471
} // namespace services

engine/services/model_service.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#include "config/yaml_config.h"
1010
#include "database/models.h"
1111
#include "hardware_service.h"
12+
#include "services/inference_service.h"
1213
#include "utils/cli_selection_utils.h"
13-
#include "utils/cortex_utils.h"
1414
#include "utils/engine_constants.h"
1515
#include "utils/file_manager_utils.h"
1616
#include "utils/huggingface_utils.h"

0 commit comments

Comments (0)