Skip to content

Commit

Permalink
core: Add support for user-defined HTTP headers in NetworkReader. (y…
Browse files Browse the repository at this point in the history
…-scope#568)

Co-authored-by: Lin Zhihao <[email protected]>
Co-authored-by: Xiaochong Wei <[email protected]>
  • Loading branch information
3 people authored and Jack Luo committed Dec 4, 2024
1 parent 2feb739 commit 58f98ad
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 12 deletions.
62 changes: 57 additions & 5 deletions components/core/src/clp/CurlDownloadHandler.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
#include "CurlDownloadHandler.hpp"

#include <algorithm>
#include <array>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <curl/curl.h>
#include <fmt/core.h>

#include "ErrorCode.hpp"

namespace clp {
CurlDownloadHandler::CurlDownloadHandler(
Expand All @@ -19,7 +28,8 @@ CurlDownloadHandler::CurlDownloadHandler(
size_t offset,
bool disable_caching,
std::chrono::seconds connection_timeout,
std::chrono::seconds overall_timeout
std::chrono::seconds overall_timeout,
std::optional<std::unordered_map<std::string, std::string>> const& http_header_kv_pairs
)
: m_error_msg_buf{std::move(error_msg_buf)} {
if (nullptr != m_error_msg_buf) {
Expand Down Expand Up @@ -48,13 +58,55 @@ CurlDownloadHandler::CurlDownloadHandler(
m_easy_handle.set_option(CURLOPT_TIMEOUT, static_cast<long>(overall_timeout.count()));

// Set up http headers
constexpr std::string_view cRangeHeaderName{"range"};
constexpr std::string_view cCacheControlHeaderName{"cache-control"};
constexpr std::string_view cPragmaHeaderName{"pragma"};
std::unordered_set<std::string_view> const reserved_headers{
cRangeHeaderName,
cCacheControlHeaderName,
cPragmaHeaderName
};
if (0 != offset) {
std::string const range{"Range: bytes=" + std::to_string(offset) + "-"};
m_http_headers.append(range);
m_http_headers.append(fmt::format("{}: bytes={}-", cRangeHeaderName, offset));
}
if (disable_caching) {
m_http_headers.append("Cache-Control: no-cache");
m_http_headers.append("Pragma: no-cache");
m_http_headers.append(fmt::format("{}: no-cache", cCacheControlHeaderName));
m_http_headers.append(fmt::format("{}: no-cache", cPragmaHeaderName));
}
if (http_header_kv_pairs.has_value()) {
for (auto const& [key, value] : http_header_kv_pairs.value()) {
// HTTP header field-name (key) is case-insensitive:
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
// Therefore, we convert keys to lowercase for comparison with the reserved keys.
// NOTE: We do not check for duplicate keys due to case insensitivity, leaving duplicate
// handling to the server.
auto lower_key{key};
std::transform(
lower_key.begin(),
lower_key.end(),
lower_key.begin(),
[](unsigned char c) -> char {
// Implicitly cast the input character into `unsigned char` to avoid UB:
// https://en.cppreference.com/w/cpp/string/byte/tolower
return static_cast<char>(std::tolower(c));
}
);
if (reserved_headers.contains(lower_key) || value.ends_with("\r\n")) {
throw CurlOperationFailed(
ErrorCode_Failure,
__FILE__,
__LINE__,
CURLE_BAD_FUNCTION_ARGUMENT,
fmt::format(
"`CurlDownloadHandler` failed to construct with the following "
"invalid header: {}:{}",
key,
value
)
);
}
m_http_headers.append(fmt::format("{}: {}", key, value));
}
}
if (false == m_http_headers.is_empty()) {
m_easy_handle.set_option(CURLOPT_HTTPHEADER, m_http_headers.get_raw_list());
Expand Down
10 changes: 9 additions & 1 deletion components/core/src/clp/CurlDownloadHandler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#include <chrono>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>

#include <curl/curl.h>

Expand Down Expand Up @@ -53,6 +56,9 @@ class CurlDownloadHandler {
* Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html
* @param overall_timeout Maximum time that the transfer may take. Note that this includes
* `connection_timeout`. Doc: https://curl.se/libcurl/c/CURLOPT_TIMEOUT.html
* @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the server
* in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html
* @throw CurlOperationFailed if an error occurs.
*/
explicit CurlDownloadHandler(
std::shared_ptr<ErrorMsgBuf> error_msg_buf,
Expand All @@ -63,7 +69,9 @@ class CurlDownloadHandler {
size_t offset = 0,
bool disable_caching = false,
std::chrono::seconds connection_timeout = cDefaultConnectionTimeout,
std::chrono::seconds overall_timeout = cDefaultOverallTimeout
std::chrono::seconds overall_timeout = cDefaultOverallTimeout,
std::optional<std::unordered_map<std::string, std::string>> const& http_header_kv_pairs
= std::nullopt
);

// Disable copy/move constructors/assignment operators
Expand Down
16 changes: 13 additions & 3 deletions components/core/src/clp/NetworkReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>

#include <curl/curl.h>

Expand Down Expand Up @@ -118,7 +121,8 @@ NetworkReader::NetworkReader(
std::chrono::seconds overall_timeout,
std::chrono::seconds connection_timeout,
size_t buffer_pool_size,
size_t buffer_size
size_t buffer_size,
std::optional<std::unordered_map<std::string, std::string>> http_header_kv_pairs
)
: m_src_url{src_url},
m_offset{offset},
Expand All @@ -130,7 +134,12 @@ NetworkReader::NetworkReader(
for (size_t i = 0; i < m_buffer_pool_size; ++i) {
m_buffer_pool.emplace_back(m_buffer_size);
}
m_downloader_thread = std::make_unique<DownloaderThread>(*this, offset, disable_caching);
m_downloader_thread = std::make_unique<DownloaderThread>(
*this,
offset,
disable_caching,
std::move(http_header_kv_pairs)
);
m_downloader_thread->start();
}

Expand Down Expand Up @@ -215,7 +224,8 @@ auto NetworkReader::DownloaderThread::thread_method() -> void {
m_offset,
m_disable_caching,
m_reader.m_connection_timeout,
m_reader.m_overall_timeout
m_reader.m_overall_timeout,
m_http_header_kv_pairs
};
auto const ret_code{curl_handler.perform()};
// Enqueue the last filled buffer, if any
Expand Down
21 changes: 18 additions & 3 deletions components/core/src/clp/NetworkReader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <span>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

#include <curl/curl.h>
Expand Down Expand Up @@ -94,6 +96,8 @@ class NetworkReader : public ReaderInterface {
* Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html
* @param buffer_pool_size The required number of buffers in the buffer pool.
* @param buffer_size The size of each buffer in the buffer pool.
* @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the server
* in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html
*/
explicit NetworkReader(
std::string_view src_url,
Expand All @@ -103,7 +107,9 @@ class NetworkReader : public ReaderInterface {
std::chrono::seconds connection_timeout
= CurlDownloadHandler::cDefaultConnectionTimeout,
size_t buffer_pool_size = cDefaultBufferPoolSize,
size_t buffer_size = cDefaultBufferSize
size_t buffer_size = cDefaultBufferSize,
std::optional<std::unordered_map<std::string, std::string>> http_header_kv_pairs
= std::nullopt
);

// Destructor
Expand Down Expand Up @@ -242,11 +248,19 @@ class NetworkReader : public ReaderInterface {
* @param reader
* @param offset Index of the byte at which to start the download.
* @param disable_caching Whether to disable caching.
* @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the
* server in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html
*/
DownloaderThread(NetworkReader& reader, size_t offset, bool disable_caching)
DownloaderThread(
NetworkReader& reader,
size_t offset,
bool disable_caching,
std::optional<std::unordered_map<std::string, std::string>> http_header_kv_pairs
)
: m_reader{reader},
m_offset{offset},
m_disable_caching{disable_caching} {}
m_disable_caching{disable_caching},
m_http_header_kv_pairs{std::move(http_header_kv_pairs)} {}

private:
// Methods implementing `clp::Thread`
Expand All @@ -255,6 +269,7 @@ class NetworkReader : public ReaderInterface {
NetworkReader& m_reader;
size_t m_offset{0};
bool m_disable_caching{false};
std::optional<std::unordered_map<std::string, std::string>> m_http_header_kv_pairs;
};

/**
Expand Down
58 changes: 58 additions & 0 deletions components/core/tests/test-NetworkReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
#include <string>
#include <string_view>
#include <thread>
#include <unordered_map>
#include <vector>

#include <Catch2/single_include/catch2/catch.hpp>
#include <curl/curl.h>
#include <fmt/core.h>
#include <json/single_include/nlohmann/json.hpp>

#include "../src/clp/Array.hpp"
#include "../src/clp/CurlDownloadHandler.hpp"
Expand Down Expand Up @@ -188,3 +191,58 @@ TEST_CASE("network_reader_illegal_offset", "[NetworkReader]") {
size_t pos{};
REQUIRE((clp::ErrorCode_Failure == reader.try_get_pos(pos)));
}

/**
 * Verifies that user-defined HTTP headers given to `clp::NetworkReader` are sent with the download
 * request.
 *
 * The test builds a set of distinct header key-value pairs, downloads from httpbin's `/headers`
 * endpoint, and checks that every pair appears under the `headers` object of the JSON response.
 */
TEST_CASE("network_reader_with_valid_http_header_kv_pairs", "[NetworkReader]") {
    std::unordered_map<std::string, std::string> valid_http_header_kv_pairs;
    // We use httpbin (https://httpbin.org/) to test the user-specified headers. On success, it is
    // supposed to respond with all the user-specified headers as key-value pairs in JSON form.
    // The count is a `size_t` so that it matches the type of the loop index below.
    constexpr size_t cNumHttpHeaderKeyValuePairs{10};
    for (size_t i{0}; i < cNumHttpHeaderKeyValuePairs; ++i) {
        valid_http_header_kv_pairs.emplace(
                fmt::format("Unit-Test-Key{}", i),
                fmt::format("Unit-Test-Value{}", i)
        );
    }
    clp::NetworkReader reader{
            "https://httpbin.org/headers",
            0,
            false,
            clp::CurlDownloadHandler::cDefaultOverallTimeout,
            clp::CurlDownloadHandler::cDefaultConnectionTimeout,
            clp::NetworkReader::cDefaultBufferPoolSize,
            clp::NetworkReader::cDefaultBufferSize,
            valid_http_header_kv_pairs
    };
    auto const raw_content{get_content(reader)};
    // Check the transfer result before parsing, so that a failed download is reported through the
    // curl return code rather than as a JSON parsing exception on an incomplete body.
    REQUIRE(assert_curl_error_code(CURLE_OK, reader));
    auto const content = nlohmann::json::parse(raw_content);
    auto const& headers{content.at("headers")};
    for (auto const& [key, value] : valid_http_header_kv_pairs) {
        REQUIRE((value == headers.at(key).get<std::string_view>()));
    }
}

/**
 * Verifies that `clp::NetworkReader` rejects user-defined HTTP headers that are not allowed.
 *
 * Each generated header map holds a single illegal entry. For every one of them, the test expects
 * no content to be downloaded and the reader to report `CURLE_BAD_FUNCTION_ARGUMENT`.
 */
TEST_CASE("network_reader_with_illegal_http_header_kv_pairs", "[NetworkReader]") {
    using HeaderMap = std::unordered_map<std::string, std::string>;

    auto rejected_headers = GENERATE(
            // These header names are controlled by the `offset` and `disable_caching` arguments,
            // so user-defined headers must not override them. The name check is
            // case-insensitive, hence the mixed-case variant.
            HeaderMap{{"Range", "bytes=100-"}},
            HeaderMap{{"RAnGe", "bytes=100-"}},
            HeaderMap{{"Cache-Control", "no-cache"}},
            HeaderMap{{"Pragma", "no-cache"}},
            // A header value terminated by CRLF must be rejected as well.
            HeaderMap{{"Legal-Name", "CRLF\r\n"}}
    );

    clp::NetworkReader reader{
            "https://httpbin.org/headers",
            0,
            false,
            clp::CurlDownloadHandler::cDefaultOverallTimeout,
            clp::CurlDownloadHandler::cDefaultConnectionTimeout,
            clp::NetworkReader::cDefaultBufferPoolSize,
            clp::NetworkReader::cDefaultBufferSize,
            rejected_headers
    };

    // Nothing should have been downloaded, and the failure should surface as a bad-argument error.
    auto const downloaded = get_content(reader);
    REQUIRE(downloaded.empty());
    REQUIRE(assert_curl_error_code(CURLE_BAD_FUNCTION_ARGUMENT, reader));
}

0 comments on commit 58f98ad

Please sign in to comment.