2 changes: 1 addition & 1 deletion examples/server/README.md
@@ -126,7 +126,7 @@ node .

`stream`: Allows receiving each predicted token in real time instead of waiting for the completion to finish. To enable this, set it to `true`.

`prompt`: Provide a prompt. Internally, the prompt is compared to the previously evaluated one; the part that has already been evaluated is reused, and only the remaining part is evaluated. A space is inserted at the front, as main.cpp does.
`prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared to the previously evaluated one; the part that has already been evaluated is reused, and only the remaining part is evaluated. If the prompt is a string, or an array whose first element is a string, a space is inserted at the front, as main.cpp does.

`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
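
For illustration, a request body combining `stream`, `prompt` (here as a mixed array of strings and token ids; the numeric ids are placeholders, not real vocabulary entries), and `stop` might look like this:

```json
{
  "stream": true,
  "prompt": ["Building a website can be done in ", 2211, " simple steps:"],
  "stop": ["\n###"],
  "n_predict": 64
}
```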
78 changes: 71 additions & 7 deletions examples/server/server.cpp
@@ -188,6 +188,7 @@ struct llama_server_context
size_t n_past = 0;
size_t n_remain = 0;

json prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;

@@ -257,10 +258,55 @@ struct llama_server_context
return true;
}

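// Tokenize a JSON prompt. It may be a plain string, or an array mixing strings
// and token ids. A leading space and the BOS token are added only when the
// prompt, or its first array element, is a string (as main.cpp does).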
std::vector<llama_token> tokenizePrompt(json json_prompt)
{
std::vector<llama_token> prompt_tokens;

if (json_prompt.is_array())
{
bool first = true;
for (const auto& p : json_prompt)
{
if (p.is_string())
{
auto s = p.template get<std::string>();
std::vector<llama_token> p_tokens; // avoid shadowing the loop variable p
if (first)
{
s.insert(0, 1, ' '); // add a space if it's the first
p_tokens = ::llama_tokenize(ctx, s, true); // also add BOS
first = false;
}
else
{
p_tokens = ::llama_tokenize(ctx, s, false);
}
prompt_tokens.insert(prompt_tokens.end(), p_tokens.begin(), p_tokens.end());
}
else
{
if (first)
{
first = false; // the first element is a token, so later strings do not get the leading space or BOS
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
}
else
{
auto s = json_prompt.template get<std::string>();
s.insert(0, 1, ' '); // always add a first space
prompt_tokens = ::llama_tokenize(ctx, s, true);
}

return prompt_tokens;
}

void loadPrompt()
{
params.prompt.insert(0, 1, ' '); // always add a first space
std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
auto prompt_tokens = tokenizePrompt(prompt);

num_prompt_tokens = prompt_tokens.size();

if (params.n_keep < 0)
@@ -966,7 +1012,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"tokens_predicted", llama.num_tokens_predicted},
{"tokens_evaluated", llama.num_prompt_tokens},
{"generation_settings", format_generation_settings(llama)},
{"prompt", llama.params.prompt},
{"prompt", llama.prompt},
{"truncated", llama.truncated},
{"stopped_eos", llama.stopped_eos},
{"stopped_word", llama.stopped_word},
@@ -1027,9 +1073,17 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.penalize_nl = body.value("penalize_nl", default_params.penalize_nl);
llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);

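// "prompt" may be a plain string or an array of strings and token ids;
// keep it as raw JSON here and tokenize it later in loadPrompt()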
if (body.count("prompt") != 0)
{
llama.prompt = body["prompt"];
}
else
{
llama.prompt = "";
}

llama.params.logit_bias.clear();
if (body.value("ignore_eos", false))
{
@@ -1270,8 +1324,11 @@ int main(int argc, char **argv)
auto lock = llama.lock();

const json body = json::parse(req.body);
const std::string content = body.value("content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
{
tokens = llama.tokenizePrompt(body["content"]);
}
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json"); });

@@ -1283,7 +1340,14 @@

llama.rewind();
llama_reset_timings(llama.ctx);
llama.params.prompt = body.value("content", "");
if (body.count("content") != 0)
{
llama.prompt = body["content"];
}
else
{
llama.prompt = "";
}
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();