
Commit 4a558f5

jxyakawrykow authored and committed
server : allow json array in prompt or content for direct token input (ggml-org#2306)
* server: allow json array in prompt or content

  We accept an array of strings and numbers representing tokens, in addition to the current string-valued prompt or content. This allows direct token input, so that any special tokens can be processed and used at the frontend during construction of the JSON data, before it is sent to the server, and the server does not need to know or parse special tokens from textual input. With this, we can use the EOS and BOS tokens used in llama-2-chat models.

* server: use tokenizePrompt(json) and default "" if empty prompt

* server: fix prompt check

* server: tokenize endpoint no longer adds BOS
1 parent ef6dd8b commit 4a558f5
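For reference, a /completion request using the new array form could be assembled on the client side roughly as follows. This is only a sketch: it assumes nlohmann::json (the JSON library the server already uses) and borrows the llama-2 token ids BOS = 1 and EOS = 2 plus an approximate [INST] chat wrapper purely for illustration.

    // Hypothetical client-side sketch: build a /completion body with direct
    // token input. Assumes nlohmann/json.hpp; ids 1 (BOS) and 2 (EOS) and the
    // [INST] wrapper are llama-2-chat conventions used only as an example.
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json body = {
            // numbers pass through as token ids; strings are tokenized by the server
            {"prompt",    json::array({1, "[INST] Hello, how are you? [/INST]"})},
            {"n_predict", 128},
            {"stop",      json::array({"</s>"})}
        };
        std::cout << body.dump(2) << std::endl; // POST this JSON to /completion
        return 0;
    }

Numbers in the array are taken as token ids verbatim, while strings are tokenized server-side, so special tokens never have to round-trip through text.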

File tree

examples/server/README.md
examples/server/server.cpp

2 files changed: +74 -8 lines changed


examples/server/README.md

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ node .
 
 `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
 
-`prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. A space is inserted in the front like main.cpp does.
+`prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. If the prompt is a string, or an array with the first element given as a string, a space is inserted in the front like main.cpp does.
 
 `stop`: Specify a JSON array of stopping strings.
 These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
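A hedged illustration of the space rule described above, assuming nlohmann::json on the client and the llama-2 BOS id 1 (both are assumptions, not part of this change): a prompt that starts with a string gets the automatic leading space, while a prompt whose first element is a token id does not.

    // Sketch contrasting the two prompt forms documented above (illustrative
    // only; assumes nlohmann/json.hpp and the llama-2 BOS id 1).
    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // String prompt (or an array whose first element is a string): the
        // server inserts a leading space, like main.cpp, before tokenizing it.
        const json auto_prefix = {
            {"prompt", "[INST] Hello [/INST]"}
        };

        // Array whose first element is a token id: nothing is inserted, so the
        // client controls the exact token sequence, e.g. an explicit BOS.
        const json exact_tokens = {
            {"prompt", json::array({1, "[INST] Hello [/INST]"})}
        };

        std::cout << auto_prefix.dump() << "\n" << exact_tokens.dump() << std::endl;
        return 0;
    }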

examples/server/server.cpp

Lines changed: 73 additions & 7 deletions
@@ -190,6 +190,7 @@ struct llama_server_context
     size_t n_past = 0;
     size_t n_remain = 0;
 
+    json prompt;
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
@@ -267,6 +268,53 @@ struct llama_server_context
         return true;
     }
 
+    std::vector<llama_token> tokenize(json json_prompt, bool add_bos)
+    {
+        // If `add_bos` is true, we only add BOS, when json_prompt is a string,
+        // or the first element of the json_prompt array is a string.
+        std::vector<llama_token> prompt_tokens;
+
+        if (json_prompt.is_array())
+        {
+            bool first = true;
+            for (const auto& p : json_prompt)
+            {
+                if (p.is_string())
+                {
+                    auto s = p.template get<std::string>();
+                    std::vector<llama_token> p;
+                    if (first)
+                    {
+                        s.insert(0, 1, ' '); // add a space if it's the first
+                        p = ::llama_tokenize(ctx, s, add_bos);
+                        first = false;
+                    }
+                    else
+                    {
+                        p = ::llama_tokenize(ctx, s, false);
+                    }
+                    prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+                }
+                else
+                {
+                    if (first)
+                    {
+                        first = false;
+                    }
+                    prompt_tokens.push_back(p.template get<llama_token>());
+                }
+            }
+        }
+        else
+        {
+            auto s = json_prompt.template get<std::string>();
+            s.insert(0, 1, ' '); // always add a first space
+            prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
+        }
+
+        return prompt_tokens;
+    }
+
     bool loadGrammar()
     {
         if (!params.grammar.empty()) {
@@ -294,8 +342,8 @@ struct llama_server_context
 
     void loadPrompt()
     {
-        params.prompt.insert(0, 1, ' '); // always add a first space
-        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        auto prompt_tokens = tokenize(prompt, true);  // always add BOS
+
         num_prompt_tokens = prompt_tokens.size();
 
         if (params.n_keep < 0)
@@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"tokens_predicted", llama.num_tokens_predicted},
         {"tokens_evaluated", llama.num_prompt_tokens},
         {"generation_settings", format_generation_settings(llama)},
-        {"prompt", llama.params.prompt},
+        {"prompt", llama.prompt},
         {"truncated", llama.truncated},
         {"stopped_eos", llama.stopped_eos},
         {"stopped_word", llama.stopped_word},
@@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
     llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     llama.params.seed = json_value(body, "seed", default_params.seed);
-    llama.params.prompt = json_value(body, "prompt", default_params.prompt);
     llama.params.grammar = json_value(body, "grammar", default_params.grammar);
     llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
 
+    if (body.count("prompt") != 0)
+    {
+        llama.prompt = body["prompt"];
+    }
+    else
+    {
+        llama.prompt = "";
+    }
+
     llama.params.logit_bias.clear();
     if (json_value(body, "ignore_eos", false))
     {
@@ -1345,8 +1401,11 @@ int main(int argc, char **argv)
             auto lock = llama.lock();
 
             const json body = json::parse(req.body);
-            const std::string content = json_value<std::string>(body, "content", "");
-            const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
+            std::vector<llama_token> tokens;
+            if (body.count("content") != 0)
+            {
+                tokens = llama.tokenize(body["content"], false);
+            }
             const json data = format_tokenizer_response(tokens);
             return res.set_content(data.dump(), "application/json"); });

@@ -1358,7 +1417,14 @@ int main(int argc, char **argv)
 
             llama.rewind();
             llama_reset_timings(llama.ctx);
-            llama.params.prompt = json_value<std::string>(body, "content", "");
+            if (body.count("content") != 0)
+            {
+                llama.prompt = body["content"];
+            }
+            else
+            {
+                llama.prompt = "";
+            }
             llama.params.n_predict = 0;
             llama.loadPrompt();
             llama.beginCompletion();
