feat(cpp): sync llama cpp #53

Merged — 3 commits, Nov 18, 2024
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -62,6 +62,8 @@ if (VULKAN_SDK)
   find_package(Vulkan REQUIRED)
 endif()
 
+set(LLAMA_BUILD_COMMON ON CACHE BOOL "Build common")
+
 set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libraries")
 add_subdirectory("src/llama.cpp")
 
2 changes: 1 addition & 1 deletion src/DetokenizeWorker.cpp
@@ -8,7 +8,7 @@ DetokenizeWorker::DetokenizeWorker(const Napi::CallbackInfo &info,
       _tokens(std::move(tokens)) {}
 
 void DetokenizeWorker::Execute() {
-  const auto text = ::llama_detokenize(_sess->context(), _tokens);
+  const auto text = ::common_detokenize(_sess->context(), _tokens);
   _text = std::move(text);
 }
 
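
The tokenizer helpers this binding uses come from llama.cpp's `common` library (which is why the CMake change above enables `LLAMA_BUILD_COMMON`); upstream renamed them from `llama_tokenize`/`llama_detokenize` to `common_tokenize`/`common_detokenize`. Below is a minimal round-trip sketch against the synced API — the program and the model path are illustrative, not part of this PR:

```cpp
// Round-trip sketch, assuming the common helper signatures of the llama.cpp
// revision this PR syncs to: common_tokenize(ctx, text, add_special) and
// common_detokenize(ctx, tokens). "model.gguf" is a placeholder path.
#include "common.h"
#include "llama.h"
#include <cstdio>
#include <string>
#include <vector>

int main() {
  llama_backend_init();

  llama_model *model =
      llama_load_model_from_file("model.gguf", llama_model_default_params());
  llama_context *ctx =
      llama_new_context_with_model(model, llama_context_default_params());

  // Tokenize with special-token handling enabled, then reverse it.
  std::vector<llama_token> tokens = common_tokenize(ctx, "Hello world", true);
  std::string text = common_detokenize(ctx, tokens);
  printf("%zu tokens -> \"%s\"\n", tokens.size(), text.c_str());

  llama_free(ctx);
  llama_free_model(model);
  llama_backend_free();
  return 0;
}
```
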
4 changes: 2 additions & 2 deletions src/EmbeddingWorker.cpp
@@ -7,7 +7,7 @@ EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info,
 
 void EmbeddingWorker::Execute() {
   llama_kv_cache_clear(_sess->context());
-  auto tokens = ::llama_tokenize(_sess->context(), _text, true);
+  auto tokens = ::common_tokenize(_sess->context(), _text, true);
   // add SEP if not present
   if (tokens.empty() || tokens.back() != llama_token_sep(_sess->model())) {
     tokens.push_back(llama_token_sep(_sess->model()));
@@ -16,7 +16,7 @@ void EmbeddingWorker::Execute() {
   do {
     int ret =
         llama_decode(_sess->context(),
-                     llama_batch_get_one(tokens.data(), tokens.size(), 0, 0));
+                     llama_batch_get_one(tokens.data(), tokens.size()));
     if (ret < 0) {
       SetError("Failed to inference, code: " + std::to_string(ret));
       break;
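
Besides the rename, `llama_batch_get_one` lost its `pos_0` and `seq_id` arguments upstream; the batch position now appears to be derived from the context's KV-cache state rather than passed by the caller. A sketch of the new call shape — the helper function and its name are assumptions for illustration, not code from this PR:

```cpp
// Feed a prompt through llama_decode using the two-argument
// llama_batch_get_one. Previously this call was
//   llama_batch_get_one(tokens.data(), tokens.size(), /*pos_0=*/0, /*seq_id=*/0)
#include "common.h"
#include "llama.h"
#include <cstdio>
#include <string>
#include <vector>

static bool decode_prompt(llama_context *ctx, const std::string &prompt) {
  std::vector<llama_token> tokens =
      common_tokenize(ctx, prompt, /*add_special=*/true);

  const int ret =
      llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
  if (ret < 0) {
    fprintf(stderr, "llama_decode failed with code %d\n", ret);
    return false;
  }
  return true;
}
```
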
16 changes: 8 additions & 8 deletions src/LlamaCompletionWorker.cpp
@@ -34,7 +34,7 @@ size_t findStoppingStrings(const std::string &text,
 
 LlamaCompletionWorker::LlamaCompletionWorker(
     const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
-    Napi::Function callback, gpt_params params,
+    Napi::Function callback, common_params params,
     std::vector<std::string> stop_words)
     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
       _params(params), _stop_words(stop_words) {
@@ -64,11 +64,11 @@ void LlamaCompletionWorker::Execute() {
 
   auto sparams = llama_sampler_chain_default_params();
 
-  LlamaCppSampling sampling{gpt_sampler_init(model, _params.sparams),
-                            gpt_sampler_free};
+  LlamaCppSampling sampling{common_sampler_init(model, _params.sparams),
+                            common_sampler_free};
 
   std::vector<llama_token> prompt_tokens =
-      ::llama_tokenize(ctx, _params.prompt, add_bos);
+      ::common_tokenize(ctx, _params.prompt, add_bos);
   n_input = prompt_tokens.size();
   if (_sess->tokens_ptr()->size() > 0) {
     n_cur = common_part(*(_sess->tokens_ptr()), prompt_tokens);
@@ -102,18 +102,18 @@ void LlamaCompletionWorker::Execute() {
       _result.truncated = true;
     }
     int ret = llama_decode(
-        ctx, llama_batch_get_one(embd->data() + n_cur, n_input, n_cur, 0));
+        ctx, llama_batch_get_one(embd->data() + n_cur, n_input));
     if (ret < 0) {
       SetError("Failed to decode token, code: " + std::to_string(ret));
       break;
     }
     // sample the next token
     const llama_token new_token_id =
-        gpt_sampler_sample(sampling.get(), ctx, -1);
-    gpt_sampler_accept(sampling.get(), new_token_id, true);
+        common_sampler_sample(sampling.get(), ctx, -1);
+    common_sampler_accept(sampling.get(), new_token_id, true);
     // prepare the next batch
     embd->emplace_back(new_token_id);
-    auto token = llama_token_to_piece(ctx, new_token_id);
+    auto token = common_token_to_piece(ctx, new_token_id);
     _result.text += token;
     n_cur += n_input;
     _result.tokens_evaluated += n_input;
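
The sampler API follows the same rename: `gpt_sampler_*` becomes `common_sampler_*`, with an unchanged lifecycle of init → sample → accept → free (the `accept` call keeps repetition-penalty and grammar state in sync with what was actually emitted). A standalone sketch of that loop — the function, its names, and the fixed 8-token budget are illustrative assumptions, not code from this PR:

```cpp
// Generation-loop sketch using the renamed sampler API. Assumes `model`,
// `ctx`, and `params` were set up elsewhere (e.g. via common_init_from_params)
// and that the prompt has already been decoded into the context.
#include "common.h"
#include "sampling.h"
#include "llama.h"
#include <string>

static std::string generate_some_tokens(const llama_model *model,
                                        llama_context *ctx,
                                        const common_params &params) {
  common_sampler *smpl = common_sampler_init(model, params.sparams);
  std::string out;

  for (int i = 0; i < 8; ++i) {
    // Sample from the logits of the last decoded position (-1).
    llama_token id = common_sampler_sample(smpl, ctx, -1);
    // Tell the sampler what was emitted so penalties/grammar stay consistent.
    common_sampler_accept(smpl, id, /*accept_grammar=*/true);

    if (llama_token_is_eog(model, id)) {
      break;
    }
    out += common_token_to_piece(ctx, id);

    // Feed the new token back in for the next step.
    if (llama_decode(ctx, llama_batch_get_one(&id, 1)) < 0) {
      break;
    }
  }

  common_sampler_free(smpl);
  return out;
}
```
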
4 changes: 2 additions & 2 deletions src/LlamaCompletionWorker.h
@@ -12,7 +12,7 @@ class LlamaCompletionWorker : public Napi::AsyncWorker,
                               public Napi::Promise::Deferred {
 public:
   LlamaCompletionWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
-                        Napi::Function callback, gpt_params params,
+                        Napi::Function callback, common_params params,
                         std::vector<std::string> stop_words = {});
 
   ~LlamaCompletionWorker();
@@ -28,7 +28,7 @@ class LlamaCompletionWorker : public Napi::AsyncWorker,
 
 private:
   LlamaSessionPtr _sess;
-  gpt_params _params;
+  common_params _params;
   std::vector<std::string> _stop_words;
   Napi::ThreadSafeFunction _tsfn;
   bool _has_callback = false;
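
In the header only the type name changes: `gpt_params` is now `common_params`, and the worker keeps storing it by value. For context, a hedged sketch of how the renamed struct is filled and handed to `common_init_from_params`; only fields this diff already touches are shown, and "model.gguf" is a placeholder:

```cpp
// Standalone sketch of the renamed parameter struct and init helper
// (formerly gpt_params / llama_init_from_gpt_params).
#include "common.h"
#include "llama.h"

int main() {
  common_params params;
  params.model = "model.gguf";
  params.prompt = "Hello";
  params.n_ctx = 2048;
  params.sparams.top_k = 40; // sampling options still live under .sparams here

  llama_backend_init();
  llama_numa_init(params.numa);

  auto result = common_init_from_params(params);
  if (result.model == nullptr || result.context == nullptr) {
    return 1;
  }

  // ... run completions against result.model / result.context ...

  llama_free(result.context);
  llama_free_model(result.model);
  llama_backend_free();
  return 0;
}
```
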
17 changes: 8 additions & 9 deletions src/LlamaContext.cpp
@@ -7,8 +7,8 @@
 #include "SaveSessionWorker.h"
 #include "TokenizeWorker.h"
 
-std::vector<llama_chat_msg> get_messages(Napi::Array messages) {
-  std::vector<llama_chat_msg> chat;
+std::vector<common_chat_msg> get_messages(Napi::Array messages) {
+  std::vector<common_chat_msg> chat;
   for (size_t i = 0; i < messages.Length(); i++) {
     auto message = messages.Get(i).As<Napi::Object>();
     chat.push_back({
@@ -67,7 +67,7 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
   }
   auto options = info[0].As<Napi::Object>();
 
-  gpt_params params;
+  common_params params;
   params.model = get_option<std::string>(options, "model", "");
   if (params.model.empty()) {
     Napi::TypeError::New(env, "Model is required").ThrowAsJavaScriptException();
@@ -86,15 +86,15 @@ LlamaContext::LlamaContext(const Napi::CallbackInfo &info)
   llama_backend_init();
   llama_numa_init(params.numa);
 
-  auto result = llama_init_from_gpt_params(params);
+  auto result = common_init_from_params(params);
 
   if (result.model == nullptr || result.context == nullptr) {
     Napi::TypeError::New(env, "Failed to load model")
         .ThrowAsJavaScriptException();
   }
 
   _sess = std::make_shared<LlamaSession>(result.model, result.context, params);
-  _info = gpt_params_get_system_info(params);
+  _info = common_params_get_system_info(params);
 }
 
 // getSystemInfo(): string
@@ -109,7 +109,7 @@ Napi::Value LlamaContext::GetFormattedChat(const Napi::CallbackInfo &info) {
     Napi::TypeError::New(env, "Array expected").ThrowAsJavaScriptException();
   }
   auto messages = info[0].As<Napi::Array>();
-  auto formatted = llama_chat_apply_template(_sess->model(), "", get_messages(messages), true);
+  auto formatted = common_chat_apply_template(_sess->model(), "", get_messages(messages), true);
   return Napi::String::New(env, formatted);
 }
 
@@ -133,10 +133,10 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
   }
   auto options = info[0].As<Napi::Object>();
 
-  gpt_params params = _sess->params();
+  common_params params = _sess->params();
   if (options.Has("messages") && options.Get("messages").IsArray()) {
     auto messages = options.Get("messages").As<Napi::Array>();
-    auto formatted = llama_chat_apply_template(_sess->model(), "", get_messages(messages), true);
+    auto formatted = common_chat_apply_template(_sess->model(), "", get_messages(messages), true);
     params.prompt = formatted;
   } else {
     params.prompt = get_option<std::string>(options, "prompt", "");
@@ -150,7 +150,6 @@ Napi::Value LlamaContext::Completion(const Napi::CallbackInfo &info) {
   params.sparams.top_k = get_option<int32_t>(options, "top_k", 40);
   params.sparams.top_p = get_option<float>(options, "top_p", 0.95f);
   params.sparams.min_p = get_option<float>(options, "min_p", 0.05f);
-  params.sparams.tfs_z = get_option<float>(options, "tfs_z", 1.00f);
   params.sparams.mirostat = get_option<int32_t>(options, "mirostat", 0.00f);
   params.sparams.mirostat_tau =
       get_option<float>(options, "mirostat_tau", 5.00f);
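
Two things change in LlamaContext.cpp beyond the parameter rename: the chat helpers move to the `common_` prefix (`common_chat_msg`, `common_chat_apply_template`), and the `tfs_z` option is dropped, apparently because tail-free sampling was removed upstream so there is no longer a field to map it to. A hedged sketch of the renamed chat-template call — `format_chat` and the message contents are illustrative, and the empty template string is assumed to mean "use the model's built-in template", as the calls in this diff suggest:

```cpp
// Sketch of the renamed chat-template helpers. Assumes common_chat_msg is a
// plain {role, content} pair, as the push_back({...}) in get_messages implies.
#include "common.h"
#include "llama.h"
#include <string>
#include <vector>

std::string format_chat(const llama_model *model) {
  std::vector<common_chat_msg> chat;
  chat.push_back({"system", "You are a helpful assistant."});
  chat.push_back({"user", "Hello!"});

  // The final argument asks the template to append the assistant prefix so
  // generation continues as the assistant (formerly llama_chat_apply_template).
  return common_chat_apply_template(model, "", chat, /*add_ass=*/true);
}
```
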
2 changes: 1 addition & 1 deletion src/TokenizeWorker.cpp
@@ -6,7 +6,7 @@ TokenizeWorker::TokenizeWorker(const Napi::CallbackInfo &info,
     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text) {}
 
 void TokenizeWorker::Execute() {
-  const auto tokens = ::llama_tokenize(_sess->context(), _text, false);
+  const auto tokens = ::common_tokenize(_sess->context(), _text, false);
   _result.tokens = std::move(tokens);
 }
 
8 changes: 4 additions & 4 deletions src/common.hpp
@@ -13,7 +13,7 @@
 
 typedef std::unique_ptr<llama_model, decltype(&llama_free_model)> LlamaCppModel;
 typedef std::unique_ptr<llama_context, decltype(&llama_free)> LlamaCppContext;
-typedef std::unique_ptr<gpt_sampler, decltype(&gpt_sampler_free)>
+typedef std::unique_ptr<common_sampler, decltype(&common_sampler_free)>
     LlamaCppSampling;
 typedef std::unique_ptr<llama_batch, decltype(&llama_batch_free)> LlamaCppBatch;
 
@@ -47,7 +47,7 @@ constexpr T get_option(const Napi::Object &options, const std::string &name,
 
 class LlamaSession {
 public:
-  LlamaSession(llama_model *model, llama_context *ctx, gpt_params params)
+  LlamaSession(llama_model *model, llama_context *ctx, common_params params)
       : model_(LlamaCppModel(model, llama_free_model)),
         ctx_(LlamaCppContext(ctx, llama_free)), params_(params) {
     tokens_.reserve(params.n_ctx);
@@ -65,7 +65,7 @@ class LlamaSession {
     tokens_ = std::move(tokens);
   }
 
-  inline const gpt_params &params() const { return params_; }
+  inline const common_params &params() const { return params_; }
 
   inline std::mutex &get_mutex() { return mutex; }
 
@@ -79,7 +79,7 @@ class LlamaSession {
 private:
   LlamaCppModel model_;
  LlamaCppContext ctx_;
-  const gpt_params params_;
+  const common_params params_;
   std::vector<llama_token> tokens_{};
   std::mutex mutex;
 };
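
The unique_ptr aliases in common.hpp pair each raw llama.cpp handle with its matching free function, so the only change needed here is swapping the sampler type and deleter to the `common_` names. A short sketch of how the updated `LlamaCppSampling` alias is meant to be used (the typedef mirrors this diff; the demo function itself is illustrative):

```cpp
// RAII sketch: the deleter travels with the pointer, so common_sampler_free
// runs automatically when the wrapper goes out of scope, including on the
// early-return error paths in the workers.
#include "common.h"
#include "sampling.h"
#include <memory>

typedef std::unique_ptr<common_sampler, decltype(&common_sampler_free)>
    LlamaCppSampling;

void demo(const llama_model *model, const common_params &params) {
  LlamaCppSampling sampling{common_sampler_init(model, params.sparams),
                            common_sampler_free};

  // Pass sampling.get() wherever a raw common_sampler* is expected, e.g.
  // common_sampler_sample(sampling.get(), ctx, -1);
}
```
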
2 changes: 1 addition & 1 deletion src/llama.cpp
Submodule llama.cpp updated 383 files