Add "/chat/completions" as alias for "/v1/chat/completions" #5722

Merged 4 commits on Feb 28, 2024
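
The change registers the OpenAI-compatible chat endpoint under both "/chat/completions" and "/v1/chat/completions", so clients that do not prepend the "/v1" prefix reach the same handler. A minimal client-side sketch of what this enables, assuming a llama.cpp server listening on http://localhost:8080, the Python requests package, and the usual OpenAI-style response shape (none of which are specified by this diff):

    # Hedged sketch: verify that both routes answer the same OAI-style request.
    # The host/port, model name and response shape below are assumptions.
    import requests

    payload = {
        "model": "tinyllama-2",
        "messages": [
            {"role": "system", "content": "You are a writer."},
            {"role": "user", "content": "Write a very long book."},
        ],
        "max_tokens": 32,
    }

    for path in ("/chat/completions", "/v1/chat/completions"):
        r = requests.post(f"http://localhost:8080{path}", json=payload)
        r.raise_for_status()
        print(path, "->", r.json()["choices"][0]["message"]["content"][:60])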
Changes from all commits
131 changes: 66 additions & 65 deletions examples/server/server.cpp
@@ -3207,87 +3207,88 @@ int main(int argc, char **argv)
                res.set_content(models.dump(), "application/json; charset=utf-8");
            });

-    // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
        {
            res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
            if (!validate_api_key(req, res)) {
                return;
            }
            json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);

            const int task_id = llama.queue_tasks.get_new_id();
            llama.queue_results.add_waiting_task_id(task_id);
            llama.request_completion(task_id, data, false, false, -1);

            if (!json_value(data, "stream", false)) {
                std::string completion_text;
                task_result result = llama.queue_results.recv(task_id);

                if (!result.error && result.stop) {
                    json oaicompat_result = format_final_response_oaicompat(data, result);

                    res.set_content(oaicompat_result.dump(-1, ' ', false,
                                        json::error_handler_t::replace),
                                        "application/json; charset=utf-8");
                } else {
                    res.status = 500;
                    res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
                }
                llama.queue_results.remove_waiting_task_id(task_id);
            } else {
                const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                    while (true) {
                        task_result llama_result = llama.queue_results.recv(task_id);
                        if (!llama_result.error) {
                            std::vector<json> result_array = format_partial_response_oaicompat( llama_result);

                            for (auto it = result_array.begin(); it != result_array.end(); ++it)
                            {
                                if (!it->empty()) {
                                    const std::string str =
                                        "data: " +
                                        it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                        "\n\n";
                                    LOG_VERBOSE("data stream", {{"to_send", str}});
                                    if (!sink.write(str.c_str(), str.size())) {
                                        llama.queue_results.remove_waiting_task_id(task_id);
                                        return false;
                                    }
                                }
                            }
                            if (llama_result.stop) {
                                break;
                            }
                        } else {
                            const std::string str =
                                "error: " +
                                llama_result.result_json.dump(-1, ' ', false,
                                    json::error_handler_t::replace) +
                                "\n\n";
                            LOG_VERBOSE("data stream", {{"to_send", str}});
                            if (!sink.write(str.c_str(), str.size())) {
                                llama.queue_results.remove_waiting_task_id(task_id);
                                return false;
                            }
                            break;
                        }
                    }
                    sink.done();
                    llama.queue_results.remove_waiting_task_id(task_id);
                    return true;
                };

                auto on_complete = [task_id, &llama](bool) {
                    // cancel request
                    llama.request_cancel(task_id);
                    llama.queue_results.remove_waiting_task_id(task_id);
                };

                res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
            }
-        });
+        };
+
+    svr.Post("/chat/completions", chat_completions);
+    svr.Post("/v1/chat/completions", chat_completions);

    svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
            {
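When "stream": true is requested, the handler above replies over a chunked text/event-stream, writing each partial result as an SSE line prefixed with "data: " (or "error: " when the task fails) followed by a blank line. Below is a rough consumer of that stream through the new prefix-less route; aiohttp, the local address, and the OpenAI-style "delta" chunk shape are assumptions, not part of the diff:

    # Sketch only: read the SSE stream emitted by POST /chat/completions.
    import asyncio
    import json

    import aiohttp

    async def stream_chat(prompt: str) -> str:
        payload = {
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": 64,
            "stream": True,
        }
        text = ""
        async with aiohttp.ClientSession() as session:
            async with session.post("http://localhost:8080/chat/completions",
                                    json=payload) as response:
                async for raw_line in response.content:
                    line = raw_line.decode("utf-8").strip()
                    if not line.startswith("data: "):
                        continue  # skip blank separators and "error: ..." lines
                    chunk_json = line[len("data: "):]
                    if chunk_json == "[DONE]":  # terminator sent by some OAI-compatible servers
                        break
                    delta = json.loads(chunk_json)["choices"][0]["delta"]
                    text += delta.get("content", "")
        return text

    if __name__ == "__main__":
        print(asyncio.run(stream_chat("Write a short poem.")))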
22 changes: 22 additions & 0 deletions examples/server/tests/features/parallel.feature
@@ -54,6 +54,28 @@ Feature: Parallel
| disabled | 128 |
| enabled | 64 |

Scenario Outline: Multi users OAI completions compatibility no v1
Given a system prompt You are a writer.
And a model tinyllama-2
Given a prompt:
"""
Write a very long book.
"""
And a prompt:
"""
Write another a poem.
"""
And <n_predict> max tokens to predict
And streaming is <streaming>
Given concurrent OAI completions requests no v1
Then the server is busy
Then the server is idle
Then all prompts are predicted with <n_predict> tokens
Examples:
| streaming | n_predict |
| disabled | 128 |
| enabled | 64 |

Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
Given a prompt:
"""
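The added scenario mirrors the existing OAI-compatibility scenario but points its concurrent requests at the route without the "/v1" prefix. Roughly what the "concurrent OAI completions requests no v1" step amounts to on the wire, as a hedged sketch (aiohttp, the local address, and the usage field in the response are assumptions drawn from the test helpers, not from this hunk):

    # Fire the two scenario prompts concurrently at /chat/completions.
    import asyncio

    import aiohttp

    PROMPTS = ["Write a very long book.", "Write another a poem."]

    async def one_request(session: aiohttp.ClientSession, prompt: str) -> int:
        payload = {
            "messages": [{"role": "system", "content": "You are a writer."},
                         {"role": "user", "content": prompt}],
            "max_tokens": 128,
        }
        async with session.post("http://localhost:8080/chat/completions",
                                json=payload) as resp:
            body = await resp.json()
            # the scenario asserts this count equals <n_predict>
            return body["usage"]["completion_tokens"]

    async def main():
        async with aiohttp.ClientSession() as session:
            counts = await asyncio.gather(*(one_request(session, p) for p in PROMPTS))
            print(counts)

    asyncio.run(main())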
28 changes: 26 additions & 2 deletions examples/server/tests/features/steps/steps.py
@@ -231,6 +231,7 @@ async def step_oai_chat_completions(context, api_error):
completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt,
context.base_url,
'/v1/chat',
False,
model=context.model if hasattr(context, 'model') else None,

@@ -288,6 +289,28 @@ async def step_oai_chat_completions(context):
# user_prompt is inserted automatically
context.system_prompt,
context.base_url,
'/v1/chat/completions',
True, # async_client
model=context.model
if hasattr(context, 'model') else None,
n_predict=context.n_predict
if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed
if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None)


@step(u'concurrent OAI completions requests no v1')
@async_run_until_complete
async def step_oai_chat_completions(context):
await concurrent_requests(context, oai_chat_completions,
# user_prompt is inserted automatically
context.system_prompt,
context.base_url,
'/chat/completions',
True, # async_client
model=context.model
if hasattr(context, 'model') else None,
@@ -497,6 +520,7 @@ async def request_completion(prompt,
async def oai_chat_completions(user_prompt,
system_prompt,
base_url,
base_path,
async_client,
debug=False,
model=None,
@@ -537,7 +561,7 @@ async def oai_chat_completions(user_prompt,
origin = 'llama.cpp'
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
async with aiohttp.ClientSession() as session:
- async with session.post(f'{base_url}/v1/chat/completions',
+ async with session.post(f'{base_url}{base_path}',
json=payload,
headers=headers) as response:
if enable_streaming:
Expand Down Expand Up @@ -579,7 +603,7 @@ async def oai_chat_completions(user_prompt,
else:
try:
openai.api_key = user_api_key
- openai.api_base = f'{base_url}/v1/chat'
+ openai.api_base = f'{base_url}{base_path}'
chat_completion = openai.Completion.create(
messages=payload['messages'],
model=model,
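With the base_path parameter threaded through oai_chat_completions, the same helper can target either route: the aiohttp branch appends base_path to base_url directly, while the synchronous openai-client branch uses it as the api_base suffix. A hypothetical call site follows, assuming the helper is importable as shown and using the defaults visible in this excerpt (the full signature is not shown here):

    # Hypothetical usage of the updated test helper; the import path and the
    # server address are assumptions.
    import asyncio

    from steps import oai_chat_completions  # examples/server/tests/features/steps/steps.py

    async def smoke_test():
        for base_path in ("/v1/chat/completions", "/chat/completions"):
            await oai_chat_completions("Write another a poem.",   # user_prompt
                                       "You are a writer.",       # system_prompt
                                       "http://localhost:8080",   # base_url
                                       base_path,                 # new parameter
                                       True,                      # async_client -> aiohttp path
                                       model="tinyllama-2",
                                       n_predict=16,
                                       enable_streaming=False)

    asyncio.run(smoke_test())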