Merged

Changes from 2 commits

27 commits
ba210e4
server: add option to output probabilities for completion
WangHaoranRobin Jun 21, 2023
8004e67
Merge pull request #1 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 21, 2023
ccf254b
server: fix comment about max n_probs
WangHaoranRobin Jun 22, 2023
926664c
Merge pull request #2 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 22, 2023
cf76195
server: fix issue when handling probability output for incomplete tok…
WangHaoranRobin Jun 23, 2023
bdb710e
Merge pull request #3 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 23, 2023
7b93b24
server: fix some beginner mistakes
WangHaoranRobin Jun 23, 2023
7cd8fc2
Merge pull request #4 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 23, 2023
6c76c31
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 23, 2023
02c96a4
server: remove trailling white space
WangHaoranRobin Jun 24, 2023
7f7046e
Merge pull request #5 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 24, 2023
23b516b
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 24, 2023
af058cf
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 25, 2023
e815b69
server: remove n_probs upper limit of 5
WangHaoranRobin Jun 25, 2023
bd6550b
Merge pull request #6 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 25, 2023
c9e6642
server: handle probs output when temp=0; handle final response probs …
WangHaoranRobin Jun 25, 2023
13f5d69
Merge branch 'master' into robin_fork_master
WangHaoranRobin Jun 25, 2023
77edee7
Merge pull request #7 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 25, 2023
b5c5c8e
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 26, 2023
c7f7f13
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 27, 2023
bc88fec
server: fix llama_sample_top_k order
WangHaoranRobin Jun 27, 2023
58828c2
Merge pull request #8 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jun 27, 2023
1d22550
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jun 28, 2023
ad80773
Merge branch 'ggerganov:master' into master
WangHaoranRobin Jul 1, 2023
1a70a80
examples/common.h: put all bool variables in gpt_params together
WangHaoranRobin Jul 2, 2023
71f8296
examples/common.h: put all bool variables in gpt_params together
WangHaoranRobin Jul 2, 2023
cc3c86f
Merge pull request #9 from WangHaoranRobin/robin_fork_master
WangHaoranRobin Jul 2, 2023
Files changed
1 change: 1 addition & 0 deletions examples/common.h
@@ -31,6 +31,7 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance
int32_t n_probs = 0; // if greater than 1, output the probabilities of top n_probs tokens. Max 5

// sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
126 changes: 99 additions & 27 deletions examples/server/server.cpp
@@ -26,6 +26,28 @@ struct server_params {
int32_t write_timeout = 600;
};

// completion string output with probabilities
struct completion_string_output {
struct token_prob {
std::string tok_str;
float prob;
};

std::vector<token_prob> probs;
std::string tok_str;
};

// completion token output with probabilities
struct completion_token_output {
struct token_prob {
llama_token tok;
float prob;
};

std::vector<token_prob> probs;
llama_token tok;
};

static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
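
The two new structs are intentionally parallel: nextToken() returns raw llama_token ids in a completion_token_output, and doCompletion() later detokenizes them into a completion_string_output. A minimal sketch of that mapping, assuming the structs above plus llama_token_to_str() from llama.h (illustrative only, not part of this diff):

// Illustrative only: convert a completion_token_output into a completion_string_output
// by detokenizing each candidate, the same way doCompletion() does further down.
static completion_string_output to_string_output(llama_context * ctx,
                                                 const completion_token_output & in) {
    completion_string_output out;
    out.tok_str = in.tok == -1 ? "" : llama_token_to_str(ctx, in.tok);
    for (const auto & p : in.probs) {
        const std::string text = p.tok == -1 ? "" : llama_token_to_str(ctx, p.tok);
        out.probs.push_back({text, p.prob});
    }
    return out;
}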
@@ -107,6 +129,7 @@ struct llama_server_context {
bool stream = false;
bool has_next_token = false;
std::string generated_text;
std::vector<completion_string_output> generated_text_probs;

size_t num_tokens_predicted = 0;
size_t n_past = 0;
@@ -137,6 +160,7 @@ struct llama_server_context {
num_tokens_predicted = 0;
generated_text = "";
generated_text.reserve(params.n_ctx);
generated_text_probs.clear();
truncated = false;
stopped_eos = false;
stopped_word = false;
@@ -216,8 +240,9 @@ struct llama_server_context {
llama_set_rng_seed(ctx, params.seed);
}

llama_token nextToken() {
llama_token result = -1;
completion_token_output nextToken() {
completion_token_output result;
result.tok = -1;

if (embd.size() >= (size_t)params.n_ctx) {
// Reset context
@@ -256,7 +281,8 @@ struct llama_server_context {

if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
result.tok = llama_token_eos();
return result;
}

// out of user input, sample next token
@@ -273,7 +299,7 @@ struct llama_server_context {
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
const int32_t n_probs = params.n_probs;

{
auto * logits = llama_get_logits(ctx);
@@ -307,35 +333,37 @@ struct llama_server_context {

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p);
result.tok = llama_sample_token_greedy(ctx, &candidates_p);
} else {
if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else {
// Temperature sampling
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}
for (size_t i = 0; i < std::min(candidates_p.size, std::min((size_t) n_probs, size_t(5))); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);
last_n_tokens.push_back(result.tok);
num_tokens_predicted++;
}

// add it to the context
embd.push_back(id);
result = id;
embd.push_back(result.tok);
// decrement remaining sampling budget
--n_remain;
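
The loop over candidates_p above keeps at most min(n_probs, 5) entries and relies on the sampler having left the highest-probability candidates at the front of the array (a later commit on this branch, "server: handle probs output when temp=0", revisits the greedy path). A self-contained sketch of the same selection, with a hypothetical `candidate` struct standing in for llama_token_data:

// Self-contained sketch of the top-n selection above; the input is assumed to be
// sorted by probability, highest first.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

struct candidate { int32_t id; float p; };

static std::vector<candidate> top_n_probs(const std::vector<candidate> & sorted_candidates,
                                          int32_t n_probs) {
    // same cap as this revision: at most n_probs entries, and never more than 5
    const size_t n = std::min(sorted_candidates.size(),
                              std::min((size_t) n_probs, size_t(5)));
    return std::vector<candidate>(sorted_candidates.begin(),
                                  sorted_candidates.begin() + n);
}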

@@ -377,12 +405,22 @@ struct llama_server_context {
return stop_pos;
}

std::string doCompletion() {
const llama_token token = nextToken();
completion_string_output doCompletion() {
const completion_token_output token_with_probs = nextToken();
completion_string_output result;

const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token);
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
result.tok_str = token_text;
generated_text += token_text;

// iterate through token_with_probs.probs, if tok is valid, convert it to string and add to result.prob
for (const auto & prob : token_with_probs.probs) {
const std::string prob_text = prob.tok == -1 ? "" : llama_token_to_str(ctx, prob.tok);
result.probs.push_back({prob_text, prob.prob});
}

generated_text_probs.push_back(result);

if (multibyte_pending > 0) {
multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) {
@@ -411,8 +449,8 @@ struct llama_server_context {
}

LOG_VERBOSE("next token", {
{ "token", token },
{ "token_text", llama_token_to_str(ctx, token) },
{ "token", token_with_probs.tok },
{ "token_text", llama_token_to_str(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token },
{ "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted },
Expand All @@ -422,7 +460,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word },
});

return token_text;
return result;
}

std::vector<float> getEmbedding() {
@@ -664,6 +702,7 @@ static json format_generation_settings(llama_server_context & llama) {
{ "ignore_eos", ignore_eos },
{ "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
};
}

@@ -673,9 +712,26 @@ static json format_embedding_response(llama_server_context & llama) {
};
}

static json format_final_response(llama_server_context & llama, const std::string & content) {
static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_string_output> & probs) {

json completion_probabilities_json = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
probs_for_token.push_back(json {
{ "tok_str", p.tok_str },
{ "prob", p.prob },
});
}
completion_probabilities_json.push_back(json {
{"content", prob.tok_str},
{"probs", probs_for_token},
});
}

return json {
{ "content", content },
{ "completion_probabilities", completion_probabilities_json},
{ "stop", true },
{ "model", llama.params.model_alias },
{ "tokens_predicted", llama.num_tokens_predicted },
@@ -689,11 +745,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
};
}
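
On the client side, the new completion_probabilities array can be walked with the same nlohmann::json calls used above. A hedged sketch of reading it back (print_completion_probabilities and response_body are illustrative names, not part of the server):

// Illustrative client-side read of the final response built above; `response_body`
// is assumed to hold the raw JSON text returned by the server.
#include <iostream>
#include <string>
#include "json.hpp"       // nlohmann::json, as vendored by the server example

using json = nlohmann::json;

static void print_completion_probabilities(const std::string & response_body) {
    const json res = json::parse(response_body);
    if (res.find("completion_probabilities") == res.end()) {
        return; // nothing to print
    }
    for (const auto & entry : res["completion_probabilities"]) {
        std::cout << entry["content"].get<std::string>() << "\n";
        for (const auto & p : entry["probs"]) {
            std::cout << "  " << p["tok_str"].get<std::string>()
                      << " -> " << p["prob"].get<float>() << "\n";
        }
    }
}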

static json format_partial_response(const std::string & content) {
return json {
static json format_partial_response(const std::string & content, const completion_string_output & probs) {
json res = json {
{ "content", content },
{ "stop", false },
};

// iterate through probs.probs, and add to res
json probs_json = json::array();
for (const auto & prob : probs.probs) {
probs_json.push_back(json {
{ "tok_str", prob.tok_str },
{ "prob", prob.prob },
});
}
if (probs.probs.size() > 0) {
res["probs"] = probs_json;
}

return res;
}
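
Streaming clients have to treat the per-chunk "probs" field as optional, since format_partial_response only attaches it when probabilities were collected. A small illustrative reader for one streamed chunk (streamed_chunk and read_chunk are made-up names for the sketch; chunk_json is assumed to be the payload after the "data: " prefix, already parsed):

// Illustrative reader for one chunk produced by format_partial_response above.
#include <string>
#include <utility>
#include <vector>
#include "json.hpp"

using json = nlohmann::json;

struct streamed_chunk {
    std::string content;
    bool stop = false;
    std::vector<std::pair<std::string, float>> probs; // empty unless n_probs was requested
};

static streamed_chunk read_chunk(const json & chunk_json) {
    streamed_chunk out;
    out.content = chunk_json.value("content", std::string());
    out.stop    = chunk_json.value("stop", false);
    if (chunk_json.find("probs") != chunk_json.end()) {       // "probs" is optional
        for (const auto & p : chunk_json["probs"]) {
            out.probs.emplace_back(p["tok_str"].get<std::string>(), p["prob"].get<float>());
        }
    }
    return out;
}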

static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@@ -723,6 +793,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) {
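
Since parse_options_completion reads n_probs straight from the request body, enabling the feature is a single extra field. An illustrative request body using only fields visible in this diff (the prompt and seed values are examples of the sketch, not server defaults):

// Illustrative request body for the server's completion endpoint, built with the
// same nlohmann::json types used above.
#include "json.hpp"

using json = nlohmann::json;

static json make_completion_request() {
    return json {
        { "prompt",  "Building a website can be done in 10 simple steps:" },
        { "seed",    42 },
        { "n_probs", 5 },   // ask the server to return the top-5 candidates per token
    };
}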
@@ -825,7 +896,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
const std::string token_text = token_text_with_probs.tok_str;

stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL);
@@ -839,7 +911,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end());
}

const json data = format_final_response(llama, llama.generated_text);
const json data = format_final_response(llama, llama.generated_text, llama.generated_text_probs);

llama_print_timings(llama.ctx);

@@ -850,7 +922,7 @@ int main(int argc, char ** argv) {
size_t sent_count = 0;

while (llama.has_next_token) {
const std::string token_text = llama.doCompletion();
const completion_string_output token_text_with_probs = llama.doCompletion();
if (llama.multibyte_pending > 0) {
continue;
}
@@ -859,24 +931,24 @@ int main(int argc, char ** argv) {

const std::string str_test = llama.generated_text.substr(pos);
size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(), STOP_FULL);
if (stop_pos != std::string::npos) {
llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size());
} else {
stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
stop_pos = llama.findStoppingStrings(str_test, token_text_with_probs.tok_str.size(),
STOP_PARTIAL);
}

const std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size();

const json data = llama.has_next_token
? format_partial_response(to_send)
? format_partial_response(to_send, token_text_with_probs)
// Generation is done, send extra information.
: format_final_response(llama, to_send);
: format_final_response(llama, to_send, {token_text_with_probs});

const std::string str =
"data: " +