|  | 
| 26 | 26 | #include <mutex> | 
| 27 | 27 | #include <chrono> | 
| 28 | 28 | #include <condition_variable> | 
|  | 29 | +#include <atomic> | 
| 29 | 30 | 
 | 
| 30 | 31 | #ifndef SERVER_VERBOSE | 
| 31 | 32 | #define SERVER_VERBOSE 1 | 
| @@ -146,6 +147,12 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string) | 
| 146 | 147 | // parallel | 
| 147 | 148 | // | 
| 148 | 149 | 
 | 
|  | 150 | +enum ServerState { | 
|  | 151 | +    LOADING_MODEL,  // Server is starting up, model not fully loaded yet | 
|  | 152 | +    READY,          // Server is ready and model is loaded | 
|  | 153 | +    ERROR           // An error occurred, load_model failed | 
|  | 154 | +}; | 
|  | 155 | + | 
| 149 | 156 | enum task_type { | 
| 150 | 157 |     COMPLETION_TASK, | 
| 151 | 158 |     CANCEL_TASK | 
| @@ -2453,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, | 
| 2453 | 2460 |     } | 
| 2454 | 2461 | } | 
| 2455 | 2462 | 
 | 
| 2456 |  | - | 
| 2457 | 2463 | static std::string random_string() | 
| 2458 | 2464 | { | 
| 2459 | 2465 |     static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); | 
| @@ -2790,15 +2796,117 @@ int main(int argc, char **argv) | 
| 2790 | 2796 |                                 {"system_info", llama_print_system_info()}, | 
| 2791 | 2797 |                             }); | 
| 2792 | 2798 | 
 | 
| 2793 |  | -    // load the model | 
| 2794 |  | -    if (!llama.load_model(params)) | 
|  | 2799 | +    httplib::Server svr; | 
|  | 2800 | + | 
|  | 2801 | +    std::atomic<ServerState> server_state{LOADING_MODEL}; | 
|  | 2802 | + | 
|  | 2803 | +    svr.set_default_headers({{"Server", "llama.cpp"}, | 
|  | 2804 | +                             {"Access-Control-Allow-Origin", "*"}, | 
|  | 2805 | +                             {"Access-Control-Allow-Headers", "content-type"}}); | 
|  | 2806 | + | 
|  | 2807 | +    svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { | 
|  | 2808 | +        ServerState current_state = server_state.load(); | 
|  | 2809 | +        switch(current_state) { | 
|  | 2810 | +            case READY: | 
|  | 2811 | +                res.set_content(R"({"status": "ok"})", "application/json"); | 
|  | 2812 | +                res.status = 200; // HTTP OK | 
|  | 2813 | +                break; | 
|  | 2814 | +            case LOADING_MODEL: | 
|  | 2815 | +                res.set_content(R"({"status": "loading model"})", "application/json"); | 
|  | 2816 | +                res.status = 503; // HTTP Service Unavailable | 
|  | 2817 | +                break; | 
|  | 2818 | +            case ERROR: | 
|  | 2819 | +                res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); | 
|  | 2820 | +                res.status = 500; // HTTP Internal Server Error | 
|  | 2821 | +                break; | 
|  | 2822 | +        } | 
|  | 2823 | +    }); | 
|  | 2824 | + | 
|  | 2825 | +    svr.set_logger(log_server_request); | 
|  | 2826 | + | 
|  | 2827 | +    svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) | 
|  | 2828 | +            { | 
|  | 2829 | +                const char fmt[] = "500 Internal Server Error\n%s"; | 
|  | 2830 | +                char buf[BUFSIZ]; | 
|  | 2831 | +                try | 
|  | 2832 | +                { | 
|  | 2833 | +                    std::rethrow_exception(std::move(ep)); | 
|  | 2834 | +                } | 
|  | 2835 | +                catch (std::exception &e) | 
|  | 2836 | +                { | 
|  | 2837 | +                    snprintf(buf, sizeof(buf), fmt, e.what()); | 
|  | 2838 | +                } | 
|  | 2839 | +                catch (...) | 
|  | 2840 | +                { | 
|  | 2841 | +                    snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); | 
|  | 2842 | +                } | 
|  | 2843 | +                res.set_content(buf, "text/plain; charset=utf-8"); | 
|  | 2844 | +                res.status = 500; | 
|  | 2845 | +            }); | 
|  | 2846 | + | 
|  | 2847 | +    svr.set_error_handler([](const httplib::Request &, httplib::Response &res) | 
|  | 2848 | +            { | 
|  | 2849 | +                if (res.status == 401) | 
|  | 2850 | +                { | 
|  | 2851 | +                    res.set_content("Unauthorized", "text/plain; charset=utf-8"); | 
|  | 2852 | +                } | 
|  | 2853 | +                if (res.status == 400) | 
|  | 2854 | +                { | 
|  | 2855 | +                    res.set_content("Invalid request", "text/plain; charset=utf-8"); | 
|  | 2856 | +                } | 
|  | 2857 | +                else if (res.status == 404) | 
|  | 2858 | +                { | 
|  | 2859 | +                    res.set_content("File Not Found", "text/plain; charset=utf-8"); | 
|  | 2860 | +                    res.status = 404; | 
|  | 2861 | +                } | 
|  | 2862 | +            }); | 
|  | 2863 | + | 
|  | 2864 | +    // set timeouts and change hostname and port | 
|  | 2865 | +    svr.set_read_timeout (sparams.read_timeout); | 
|  | 2866 | +    svr.set_write_timeout(sparams.write_timeout); | 
|  | 2867 | + | 
|  | 2868 | +    if (!svr.bind_to_port(sparams.hostname, sparams.port)) | 
| 2795 | 2869 |     { | 
|  | 2870 | +        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); | 
| 2796 | 2871 |         return 1; | 
| 2797 | 2872 |     } | 
| 2798 | 2873 | 
 | 
| 2799 |  | -    llama.initialize(); | 
|  | 2874 | +    // Set the base directory for serving static files | 
|  | 2875 | +    svr.set_base_dir(sparams.public_path); | 
| 2800 | 2876 | 
 | 
| 2801 |  | -    httplib::Server svr; | 
|  | 2877 | +    // to make it ctrl+clickable: | 
|  | 2878 | +    LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); | 
|  | 2879 | + | 
|  | 2880 | +    std::unordered_map<std::string, std::string> log_data; | 
|  | 2881 | +    log_data["hostname"] = sparams.hostname; | 
|  | 2882 | +    log_data["port"] = std::to_string(sparams.port); | 
|  | 2883 | + | 
|  | 2884 | +    if (!sparams.api_key.empty()) { | 
|  | 2885 | +        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); | 
|  | 2886 | +    } | 
|  | 2887 | + | 
|  | 2888 | +    LOG_INFO("HTTP server listening", log_data); | 
|  | 2889 | +    // run the HTTP server in a thread - see comment below | 
|  | 2890 | +    std::thread t([&]() | 
|  | 2891 | +            { | 
|  | 2892 | +                if (!svr.listen_after_bind()) | 
|  | 2893 | +                { | 
|  | 2894 | +                    server_state.store(ERROR); | 
|  | 2895 | +                    return 1; | 
|  | 2896 | +                } | 
|  | 2897 | + | 
|  | 2898 | +                return 0; | 
|  | 2899 | +            }); | 
|  | 2900 | + | 
|  | 2901 | +    // load the model | 
|  | 2902 | +    if (!llama.load_model(params)) | 
|  | 2903 | +    { | 
|  | 2904 | +        server_state.store(ERROR); | 
|  | 2905 | +        return 1; | 
|  | 2906 | +    } else { | 
|  | 2907 | +        llama.initialize(); | 
|  | 2908 | +        server_state.store(READY); | 
|  | 2909 | +    } | 
| 2802 | 2910 | 
 | 
| 2803 | 2911 |     // Middleware for API key validation | 
| 2804 | 2912 |     auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { | 
| @@ -2826,10 +2934,6 @@ int main(int argc, char **argv) | 
| 2826 | 2934 |         return false; | 
| 2827 | 2935 |     }; | 
| 2828 | 2936 | 
 | 
| 2829 |  | -    svr.set_default_headers({{"Server", "llama.cpp"}, | 
| 2830 |  | -                             {"Access-Control-Allow-Origin", "*"}, | 
| 2831 |  | -                             {"Access-Control-Allow-Headers", "content-type"}}); | 
| 2832 |  | - | 
| 2833 | 2937 |     // this is only called if no index.html is found in the public --path | 
| 2834 | 2938 |     svr.Get("/", [](const httplib::Request &, httplib::Response &res) | 
| 2835 | 2939 |             { | 
| @@ -2937,8 +3041,6 @@ int main(int argc, char **argv) | 
| 2937 | 3041 |                 } | 
| 2938 | 3042 |             }); | 
| 2939 | 3043 | 
 | 
| 2940 |  | - | 
| 2941 |  | - | 
| 2942 | 3044 |     svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) | 
| 2943 | 3045 |             { | 
| 2944 | 3046 |                 std::time_t t = std::time(0); | 
| @@ -3157,81 +3259,6 @@ int main(int argc, char **argv) | 
| 3157 | 3259 |                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); | 
| 3158 | 3260 |             }); | 
| 3159 | 3261 | 
 | 
| 3160 |  | -    svr.set_logger(log_server_request); | 
| 3161 |  | - | 
| 3162 |  | -    svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) | 
| 3163 |  | -            { | 
| 3164 |  | -                const char fmt[] = "500 Internal Server Error\n%s"; | 
| 3165 |  | -                char buf[BUFSIZ]; | 
| 3166 |  | -                try | 
| 3167 |  | -                { | 
| 3168 |  | -                    std::rethrow_exception(std::move(ep)); | 
| 3169 |  | -                } | 
| 3170 |  | -                catch (std::exception &e) | 
| 3171 |  | -                { | 
| 3172 |  | -                    snprintf(buf, sizeof(buf), fmt, e.what()); | 
| 3173 |  | -                } | 
| 3174 |  | -                catch (...) | 
| 3175 |  | -                { | 
| 3176 |  | -                    snprintf(buf, sizeof(buf), fmt, "Unknown Exception"); | 
| 3177 |  | -                } | 
| 3178 |  | -                res.set_content(buf, "text/plain; charset=utf-8"); | 
| 3179 |  | -                res.status = 500; | 
| 3180 |  | -            }); | 
| 3181 |  | - | 
| 3182 |  | -    svr.set_error_handler([](const httplib::Request &, httplib::Response &res) | 
| 3183 |  | -            { | 
| 3184 |  | -                if (res.status == 401) | 
| 3185 |  | -                { | 
| 3186 |  | -                    res.set_content("Unauthorized", "text/plain; charset=utf-8"); | 
| 3187 |  | -                } | 
| 3188 |  | -                if (res.status == 400) | 
| 3189 |  | -                { | 
| 3190 |  | -                    res.set_content("Invalid request", "text/plain; charset=utf-8"); | 
| 3191 |  | -                } | 
| 3192 |  | -                else if (res.status == 404) | 
| 3193 |  | -                { | 
| 3194 |  | -                    res.set_content("File Not Found", "text/plain; charset=utf-8"); | 
| 3195 |  | -                    res.status = 404; | 
| 3196 |  | -                } | 
| 3197 |  | -            }); | 
| 3198 |  | - | 
| 3199 |  | -    // set timeouts and change hostname and port | 
| 3200 |  | -    svr.set_read_timeout (sparams.read_timeout); | 
| 3201 |  | -    svr.set_write_timeout(sparams.write_timeout); | 
| 3202 |  | - | 
| 3203 |  | -    if (!svr.bind_to_port(sparams.hostname, sparams.port)) | 
| 3204 |  | -    { | 
| 3205 |  | -        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); | 
| 3206 |  | -        return 1; | 
| 3207 |  | -    } | 
| 3208 |  | - | 
| 3209 |  | -    // Set the base directory for serving static files | 
| 3210 |  | -    svr.set_base_dir(sparams.public_path); | 
| 3211 |  | - | 
| 3212 |  | -    // to make it ctrl+clickable: | 
| 3213 |  | -    LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); | 
| 3214 |  | - | 
| 3215 |  | -    std::unordered_map<std::string, std::string> log_data; | 
| 3216 |  | -    log_data["hostname"] = sparams.hostname; | 
| 3217 |  | -    log_data["port"] = std::to_string(sparams.port); | 
| 3218 |  | - | 
| 3219 |  | -    if (!sparams.api_key.empty()) { | 
| 3220 |  | -        log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); | 
| 3221 |  | -    } | 
| 3222 |  | - | 
| 3223 |  | -    LOG_INFO("HTTP server listening", log_data); | 
| 3224 |  | -    // run the HTTP server in a thread - see comment below | 
| 3225 |  | -    std::thread t([&]() | 
| 3226 |  | -            { | 
| 3227 |  | -                if (!svr.listen_after_bind()) | 
| 3228 |  | -                { | 
| 3229 |  | -                    return 1; | 
| 3230 |  | -                } | 
| 3231 |  | - | 
| 3232 |  | -                return 0; | 
| 3233 |  | -            }); | 
| 3234 |  | - | 
| 3235 | 3262 |     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? | 
| 3236 | 3263 |     //     "Bus error: 10" - this is on macOS, it does not crash on Linux | 
| 3237 | 3264 |     //std::thread t2([&]() | 
|  | 
0 commit comments