From 7196ac8a07100da06cfe278719a5b96a1ddfee1b Mon Sep 17 00:00:00 2001 From: "Sung, Po Han" Date: Wed, 4 Sep 2024 17:38:05 +0800 Subject: [PATCH] Fix incorrect handling of Expect: 100-continue Fix #1808 --- httplib.h | 4 +- test/CMakeLists.txt | 4 +- test/test.cc | 100 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 2 deletions(-) diff --git a/httplib.h b/httplib.h index b7be298e07..7c7bf04236 100644 --- a/httplib.h +++ b/httplib.h @@ -6956,7 +6956,9 @@ Server::process_request(Stream &strm, bool close_connection, strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, status_message(status)); break; - default: return write_response(strm, close_connection, req, res); + default: + connection_closed = true; + return write_response(strm, true, req, res); } } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d982253317..75dd978cd2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -24,9 +24,11 @@ else() FetchContent_MakeAvailable(gtest) endif() +find_package(curl REQUIRED) + add_executable(httplib-test test.cc) target_compile_options(httplib-test PRIVATE "$<$:/utf-8;/bigobj>") -target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main) +target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main CURL::libcurl) gtest_discover_tests(httplib-test) file( diff --git a/test/test.cc b/test/test.cc index 09a2eba084..c75cdd9f63 100644 --- a/test/test.cc +++ b/test/test.cc @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -12,6 +13,7 @@ #include #include #include +#include #define SERVER_CERT_FILE "./cert.pem" #define SERVER_CERT2_FILE "./cert2.pem" @@ -7606,3 +7608,101 @@ TEST(DirtyDataRequestTest, HeadFieldValueContains_CR_LF_NUL) { Client cli(HOST, PORT); cli.Get("/test", {{"Test", "_\n\r_\n\r_"}}); } + +TEST(Expect100ContinueTest, ServerClosesConnection) { + static constexpr char reject[] = "Unauthorized"; + static constexpr char accept[] = "Upload accepted"; + constexpr size_t total_size = 10 * 1024 * 1024 * 1024ULL; + + Server svr; + + svr.set_expect_100_continue_handler([](const Request &req, Response &res) { + res.status = StatusCode::Unauthorized_401; + res.set_content(reject, "text/plain"); + return res.status; + }); + svr.Post("/", [&](const Request & /*req*/, Response &res) { + res.set_content(accept, "text/plain"); + }); + + auto thread = std::thread([&]() { svr.listen(HOST, PORT); }); + auto se = detail::scope_exit([&] { + svr.stop(); + thread.join(); + ASSERT_FALSE(svr.is_running()); + }); + + svr.wait_until_ready(); + + { + const auto curl = std::unique_ptr{ + curl_easy_init(), &curl_easy_cleanup}; + ASSERT_NE(curl, nullptr); + + curl_easy_setopt(curl.get(), CURLOPT_URL, HOST); + curl_easy_setopt(curl.get(), CURLOPT_PORT, PORT); + curl_easy_setopt(curl.get(), CURLOPT_POST, 1L); + auto list = std::unique_ptr{ + curl_slist_append(nullptr, "Content-Type: application/octet-stream"), + &curl_slist_free_all}; + ASSERT_NE(list, nullptr); + curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, list.get()); + + struct read_data { + size_t read_size; + size_t total_size; + } data = {0, total_size}; + using read_callback_t = + size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata); + read_callback_t read_callback = [](char *ptr, size_t size, size_t nmemb, + void *userdata) -> size_t { + read_data *data = (read_data *)userdata; + + if (!userdata || data->read_size >= data->total_size) { return 0; } + + std::fill_n(ptr, size * nmemb, 'A'); + data->read_size += size * nmemb; + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_READDATA, data); + curl_easy_setopt(curl.get(), CURLOPT_READFUNCTION, read_callback); + + std::vector buffer; + curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &buffer); + using write_callback_t = + size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata); + write_callback_t write_callback = [](char *ptr, size_t size, size_t nmemb, + void *userdata) -> size_t { + std::vector *buffer = (std::vector *)userdata; + buffer->reserve(buffer->size() + size * nmemb + 1); + buffer->insert(buffer->end(), (char *)ptr, (char *)ptr + size * nmemb); + return size * nmemb; + }; + curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, write_callback); + + { + const auto res = curl_easy_perform(curl.get()); + ASSERT_EQ(res, CURLE_OK); + } + + { + auto response_code = long{}; + const auto res = + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code); + ASSERT_EQ(res, CURLE_OK); + ASSERT_EQ(response_code, StatusCode::Unauthorized_401); + } + + { + auto dl = curl_off_t{}; + const auto res = curl_easy_getinfo(curl.get(), CURLINFO_SIZE_DOWNLOAD_T, &dl); + ASSERT_EQ(res, CURLE_OK); + ASSERT_EQ(dl, sizeof reject - 1); + } + + { + buffer.push_back('\0'); + ASSERT_STRCASEEQ(buffer.data(), reject); + } + } +}