Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified tools/server/public/index.html.gz
Binary file not shown.
880 changes: 1 addition & 879 deletions tools/server/server-ws.cpp

Large diffs are not rendered by default.

201 changes: 160 additions & 41 deletions tools/server/server-ws.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "server-common.h"
#include "server-http.h"
#include <functional>
#include <string>
#include <memory>
Expand All @@ -16,69 +17,187 @@ struct common_params;
struct sockaddr_in;
class ws_connection_impl;

// WebSocket connection interface
// Abstracts the underlying WebSocket implementation
struct server_ws_connection {
virtual ~server_ws_connection() = default;

// Send a message to the client
virtual void send(const std::string & message) = 0;

// Close the connection
virtual void close(int code = 1000, const std::string & reason = "") = 0;

// Get query parameter by key
virtual std::string get_query_param(const std::string & key) const = 0;

// Get the remote address
virtual std::string get_remote_address() const = 0;
// @ngxson: this is a demo for how a bi-directional connection between
// the server and frontend can be implemented using SSE + HTTP POST
// I'm reusing the name "WS" here, but this is not a real WebSocket implementation
// the code is 100% written by human, no AI involved
// but this is just a demo, do not use it in practice



struct server_ws_connection;

// hacky: server_ws_connection is a member of this struct because
// we want to have shared_ptr for other handler functions
// in practice, we don't really need this
// HTTP response that streams queued messages to one client as server-sent
// events. Messages are produced by send()/close() and consumed by the `next`
// callback installed in the constructor.
struct server_ws_sse : server_http_res {
    // one queued outgoing item; is_closed marks the end of the stream
    struct msg {
        std::string data;
        bool is_closed = false;
    };

    std::string id;
    std::shared_ptr<server_ws_connection> conn;
    const server_http_req & req;

    // mutex_send guards queue_send; cv is notified whenever an item is queued
    std::mutex mutex_send;
    std::condition_variable cv;
    std::queue<msg> queue_send;

    // invoked from the destructor (after close()) when set
    std::function<void()> on_close;

    server_ws_sse(const server_http_req & req, const std::string & id) : id(id), req(req) {
        conn = std::make_shared<server_ws_connection>(*this);

        // The first event tells the client which connection ID it was given.
        // NOTE(review): this payload already starts with "data: " and `next`
        // prepends "data: " again when emitting it - confirm the frontend
        // expects the doubled prefix.
        queue_send.push(msg{"data: {\"llamacpp_id\":\"" + id + "\"}", false});

        // Produces the next SSE chunk into `output`.
        // Returns true when a chunk was written, false when the stream ends
        // (close marker reached, or the request reports should_stop()).
        next = [this, &req, id](std::string & output) {
            constexpr auto poll_interval = std::chrono::milliseconds(500);
            std::unique_lock<std::mutex> guard(mutex_send);
            for (;;) {
                if (queue_send.empty()) {
                    if (req.should_stop()) {
                        return false; // connection closed
                    }
                    // wake up on notify, or after poll_interval to re-check should_stop()
                    cv.wait_for(guard, poll_interval);
                    continue;
                }

                output.clear();
                msg & pending = queue_send.front();
                if (pending.is_closed) {
                    return false; // closed
                }
                SRV_INF("%s: sending SSE message: %s\n", id.c_str(), pending.data.c_str());
                output = "data: " + pending.data + "\n\n";
                queue_send.pop();
                return true;
            }
        };
    }

    ~server_ws_sse() {
        close();
        if (on_close) {
            on_close();
        }
    }

    // Queue a message for delivery and wake the streaming callback
    void send(const std::string & message) {
        std::lock_guard<std::mutex> guard(mutex_send);
        queue_send.push(msg{message, false});
        cv.notify_all();
    }

    // Queue the close marker; `next` returns false once it reaches it
    void close() {
        std::lock_guard<std::mutex> guard(mutex_send);
        queue_send.push(msg{"", true});
        cv.notify_all();
    }
};

// Forward declarations
class ws_connection_impl;

// WebSocket context - manages the WebSocket server
// Runs on a separate thread and handles WebSocket connections
struct server_ws_context {
struct Impl;
std::unique_ptr<Impl> pimpl;

std::thread thread;
std::atomic<bool> is_ready = false;
struct server_ws_connection {
// SSE response object that every call on this connection is forwarded to.
// Stored as a reference; server_ws_sse passes *this when it creates the connection.
server_ws_sse & parent;
server_ws_connection(server_ws_sse & parent) : parent(parent) {}

std::string path_prefix; // e.g., "/mcp"
int port;
// Send a message to the client
void send(const std::string & message) {
parent.send(message);
}

server_ws_context();
~server_ws_context();
// Close the connection
// `code` and `reason` are only written to the log here; the parent's close()
// takes no arguments.
void close(int code = 1000, const std::string & reason = "") {
    const char * reason_text = reason.c_str();
    SRV_INF("%s: closing connection: code=%d, reason=%s\n", __func__, code, reason_text);
    parent.close();
}

// Initialize the WebSocket server
bool init(const common_params & params);
// Get query parameter by key
std::string get_query_param(const std::string & key) const {
return parent.req.get_param(key);
}

// Start the WebSocket server (runs in background thread)
bool start();
// Get the remote address
std::string get_remote_address() {
return parent.id;
}
};

// Stop the WebSocket server
void stop();

// Get the actual port the WebSocket server is listening on
int get_actual_port() const;

// Set the port for the WebSocket server (note: actual port may differ if set to 0)
void set_port(int port) { this->port = port; }
// SSE + HTTP POST implementation of server_ws_context
struct server_ws_context {
server_ws_context() = default;
~server_ws_context() = default;

// map ID to connection
std::mutex mutex;
std::map<std::string, server_ws_sse *> res_map;

// SSE endpoint
// Creates a new SSE response with a fresh random ID, registers it in res_map
// so post_mcp can find it, and notifies the on_open handler (if one is set).
// The returned response unregisters itself through its on_close callback,
// which server_ws_sse invokes from its destructor.
server_http_context::handler_t get_mcp = [this](const server_http_req & req) {
    auto id = random_string();
    auto res = std::make_unique<server_ws_sse>(req, id);
    SRV_INF("%s: new SSE connection established, ID: %s\n%s", __func__, id.c_str(), req.body.c_str());
    res->status = 200;
    res->headers["X-Connection-ID"] = id;
    res->content_type = "text/event-stream";
    // res->id and res->next are set in server_ws_sse constructor

    // Install the cleanup callback BEFORE publishing the raw pointer, so the
    // map entry is always removed when the response is destroyed.
    res->on_close = [this, id]() {
        std::unique_lock<std::mutex> lock(mutex);
        auto it = res_map.find(id);
        if (it == res_map.end()) {
            return; // never registered (or already removed): nothing to do
        }
        // This runs inside a destructor: calling an empty std::function would
        // throw std::bad_function_call and terminate the process.
        if (handler_on_close) {
            handler_on_close(it->second->conn);
        }
        res_map.erase(it);
    };

    {
        std::unique_lock<std::mutex> lock(mutex);
        res_map[id] = res.get();
    }

    if (handler_on_open) {
        handler_on_open(res->conn);
    }
    return res;
};

// HTTP POST endpoint
// Delivers the request body to the on_message handler of the connection named
// by the "llamacpp_id" parameter. Responds 400 when the ID is not registered,
// 200 otherwise.
server_http_context::handler_t post_mcp = [this](const server_http_req & req) {
    const auto id = req.get_param("llamacpp_id");
    std::shared_ptr<server_ws_connection> conn;
    SRV_INF("%s: received POST for connection ID: %s\n%s", __func__, id.c_str(), req.body.c_str());

    // NOTE(review): the lock lives until this lambda returns, so it is still
    // held while handler_on_message runs - confirm whether it was meant to
    // cover only the map lookup.
    std::unique_lock lock(mutex);
    if (const auto entry = res_map.find(id); entry != res_map.end()) {
        conn = entry->second->conn;
    }

    if (conn) {
        handler_on_message(conn, req.body);
        auto ok = std::make_unique<server_http_res>();
        ok->status = 200;
        return ok;
    }

    SRV_ERR("%s: invalid connection ID: %s\n", __func__, id.c_str());
    auto rejected = std::make_unique<server_http_res>();
    rejected->status = 400;
    rejected->data = "Invalid connection ID";
    return rejected;
};

// Called when new connection is established
using on_open_t = std::function<void(std::shared_ptr<server_ws_connection>)>;
void on_open(on_open_t handler);
void on_open(on_open_t handler) {
    // replaces any previously registered handler
    this->handler_on_open = handler;
}

// Called when message is received from a connection
using on_message_t = std::function<void(std::shared_ptr<server_ws_connection>, const std::string &)>;
void on_message(on_message_t handler);
void on_message(on_message_t handler) {
    // replaces any previously registered handler
    this->handler_on_message = handler;
}

// Called when connection is closed
using on_close_t = std::function<void(std::shared_ptr<server_ws_connection>)>;
void on_close(on_close_t handler);
void on_close(on_close_t handler) {
    // replaces any previously registered handler
    this->handler_on_close = handler;
}

// For debugging
std::string listening_address;
on_open_t handler_on_open;
on_message_t handler_on_message;
on_close_t handler_on_close;
};
64 changes: 9 additions & 55 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,20 +120,12 @@ int main(int argc, char ** argv, char ** envp) {
// WebSocket Server (for MCP support) - only if --webui-mcp is enabled
//

server_ws_context * ctx_ws = nullptr;
server_mcp_bridge * mcp_bridge = nullptr;
std::unique_ptr<server_ws_context> ctx_ws = nullptr;
std::unique_ptr<server_mcp_bridge> mcp_bridge = nullptr;

if (params.webui_mcp) {
ctx_ws = new server_ws_context();
mcp_bridge = new server_mcp_bridge();

// Initialize WebSocket server with params (sets port to HTTP port + 1)
if (!ctx_ws->init(params)) {
LOG_ERR("%s: failed to initialize WebSocket server\n", __func__);
delete ctx_ws;
delete mcp_bridge;
return 1;
}
ctx_ws = std::make_unique<server_ws_context>();
mcp_bridge = std::make_unique<server_mcp_bridge>();
}

// Helper function to get MCP config path
Expand Down Expand Up @@ -289,6 +281,9 @@ int main(int argc, char ** argv, char ** envp) {
res->data = response.dump();
return res;
});

ctx_http.get ("/mcp", ex_wrapper(ctx_ws->get_mcp));
ctx_http.post("/mcp", ex_wrapper(ctx_ws->post_mcp));
}

//
Expand All @@ -315,15 +310,8 @@ int main(int argc, char ** argv, char ** envp) {
if (is_router_server) {
LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);

clean_up = [&models_routes, &ctx_ws, &mcp_bridge]() {
clean_up = [&models_routes]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
if (ctx_ws) {
ctx_ws->stop();
delete ctx_ws;
}
if (mcp_bridge) {
delete mcp_bridge;
}
if (models_routes.has_value()) {
models_routes->models.unload_all();
}
Expand All @@ -337,34 +325,14 @@ int main(int argc, char ** argv, char ** envp) {
}
ctx_http.is_ready.store(true);

// Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled
if (params.webui_mcp && ctx_ws) {
if (!ctx_ws->start()) {
clean_up();
LOG_ERR("%s: exiting due to WebSocket server error\n", __func__);
return 1;
}
LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port());
}

shutdown_handler = [&](int) {
if (ctx_ws) {
ctx_ws->stop();
}
ctx_http.stop();
};

} else {
// setup clean up function, to be called before exit
clean_up = [&ctx_http, &ctx_ws, &ctx_server, &mcp_bridge]() {
clean_up = [&ctx_http, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
if (ctx_ws) {
ctx_ws->stop();
delete ctx_ws;
}
if (mcp_bridge) {
delete mcp_bridge;
}
ctx_http.stop();
ctx_server.terminate();
llama_backend_free();
Expand All @@ -377,16 +345,6 @@ int main(int argc, char ** argv, char ** envp) {
return 1;
}

// Start WebSocket server (OS will assign an available port) - only if --webui-mcp is enabled
if (params.webui_mcp && ctx_ws) {
if (!ctx_ws->start()) {
clean_up();
LOG_ERR("%s: exiting due to WebSocket server error\n", __func__);
return 1;
}
LOG_INF("%s: WebSocket server started on port %d\n", __func__, ctx_ws->get_actual_port());
}

// load the model
LOG_INF("%s: loading model\n", __func__);

Expand All @@ -405,10 +363,6 @@ int main(int argc, char ** argv, char ** envp) {
LOG_INF("%s: model loaded\n", __func__);

shutdown_handler = [&](int) {
// this will unblock start_loop()
if (ctx_ws) {
ctx_ws->stop();
}
ctx_server.terminate();
};
}
Expand Down
Loading
Loading