Skip to content

Commit

Permalink
Fix incorrect handling of Expect: 100-continue
Browse files Browse the repository at this point in the history
  • Loading branch information
solarispika committed Sep 4, 2024
1 parent b1f8e98 commit 4880496
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
4 changes: 3 additions & 1 deletion httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -6956,7 +6956,9 @@ Server::process_request(Stream &strm, bool close_connection,
strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status,
status_message(status));
break;
default: return write_response(strm, close_connection, req, res);
default:
connection_closed = true;
return write_response(strm, true, req, res);
}
}

Expand Down
4 changes: 3 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ else()
FetchContent_MakeAvailable(gtest)
endif()

find_package(curl REQUIRED)

add_executable(httplib-test test.cc)
target_compile_options(httplib-test PRIVATE "$<$<CXX_COMPILER_ID:MSVC>:/utf-8;/bigobj>")
target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main)
target_link_libraries(httplib-test PRIVATE httplib GTest::gtest_main CURL::libcurl)
gtest_discover_tests(httplib-test)

file(
Expand Down
100 changes: 100 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <httplib.h>
#include <signal.h>

#include <curl/curl.h>
#include <gtest/gtest.h>

#include <atomic>
Expand All @@ -12,6 +13,7 @@
#include <stdexcept>
#include <thread>
#include <type_traits>
#include <vector>

#define SERVER_CERT_FILE "./cert.pem"
#define SERVER_CERT2_FILE "./cert2.pem"
Expand Down Expand Up @@ -7606,3 +7608,101 @@ TEST(DirtyDataRequestTest, HeadFieldValueContains_CR_LF_NUL) {
Client cli(HOST, PORT);
cli.Get("/test", {{"Test", "_\n\r_\n\r_"}});
}

TEST(Expect100ContinueTest, ServerClosesConnection) {
static constexpr char reject[] = "Unauthorized";
static constexpr char accept[] = "Upload accepted";
constexpr size_t total_size = 10 * 1024 * 1024 * 1024ULL;

Server svr;

svr.set_expect_100_continue_handler([](const Request &req, Response &res) {
res.status = StatusCode::Unauthorized_401;
res.set_content(reject, "text/plain");
return res.status;
});
svr.Post("/", [&](const Request & /*req*/, Response &res) {
res.set_content(accept, "text/plain");
});

auto thread = std::thread([&]() { svr.listen(HOST, PORT); });
auto se = detail::scope_exit([&] {
svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
});

svr.wait_until_ready();

{
const auto curl = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>{
curl_easy_init(), &curl_easy_cleanup};
ASSERT_NE(curl, nullptr);

curl_easy_setopt(curl.get(), CURLOPT_URL, HOST);
curl_easy_setopt(curl.get(), CURLOPT_PORT, PORT);
curl_easy_setopt(curl.get(), CURLOPT_POST, 1L);
auto list = std::unique_ptr<curl_slist, decltype(&curl_slist_free_all)>{
curl_slist_append(nullptr, "Content-Type: application/octet-stream"),
&curl_slist_free_all};
ASSERT_NE(list, nullptr);
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, list.get());

struct read_data {
size_t read_size;
size_t total_size;
} data = {0, total_size};
using read_callback_t =
size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata);
read_callback_t read_callback = [](char *ptr, size_t size, size_t nmemb,
void *userdata) -> size_t {
read_data *data = (read_data *)userdata;

if (!userdata || data->read_size >= data->total_size) { return 0; }

std::fill_n(ptr, size * nmemb, 'A');
data->read_size += size * nmemb;
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_READDATA, data);
curl_easy_setopt(curl.get(), CURLOPT_READFUNCTION, read_callback);

std::vector<char> buffer;
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &buffer);
using write_callback_t =
size_t (*)(char *ptr, size_t size, size_t nmemb, void *userdata);
write_callback_t write_callback = [](char *ptr, size_t size, size_t nmemb,
void *userdata) -> size_t {
std::vector<char> *buffer = (std::vector<char> *)userdata;
buffer->reserve(buffer->size() + size * nmemb + 1);
buffer->insert(buffer->end(), (char *)ptr, (char *)ptr + size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, write_callback);

{
const auto res = curl_easy_perform(curl.get());
ASSERT_EQ(res, CURLE_OK);
}

{
auto response_code = 0L;
const auto res =
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &response_code);
ASSERT_EQ(res, CURLE_OK);
ASSERT_EQ(response_code, StatusCode::Unauthorized_401);
}

{
curl_off_t dl;
const auto res = curl_easy_getinfo(curl.get(), CURLINFO_SIZE_DOWNLOAD_T, &dl);
ASSERT_EQ(res, CURLE_OK);
ASSERT_EQ(dl, sizeof reject - 1);
}

{
buffer.push_back('\0');
ASSERT_STRCASEEQ(buffer.data(), reject);
}
}
}

0 comments on commit 4880496

Please sign in to comment.