diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9fa8e1f798..6f9c249cfa 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -163,7 +163,7 @@ set(SOURCES if(KvikIO_REMOTE_SUPPORT) list(APPEND SOURCES "src/hdfs.cpp" "src/remote_handle.cpp" "src/detail/remote_handle.cpp" - "src/shim/libcurl.cpp" + "src/detail/url.cpp" "src/shim/libcurl.cpp" ) endif() diff --git a/cpp/include/kvikio/detail/url.hpp b/cpp/include/kvikio/detail/url.hpp new file mode 100644 index 0000000000..e57d2c4c94 --- /dev/null +++ b/cpp/include/kvikio/detail/url.hpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include + +namespace kvikio::detail { +/** + * @brief RAII wrapper for libcurl's URL handle (CURLU) + * + * This class provides automatic resource management for libcurl URL handles, + * ensuring proper cleanup when the handle goes out of scope. The class is + * move-only to prevent accidental sharing of the underlying resource. 
+ */ +class CurlUrlHandle { + private: + CURLU* _handle{nullptr}; + + public: + /** + * @brief Create a new libcurl URL handle + * + * @exception std::runtime_error if libcurl cannot allocate the handle (usually due to out of + * memory) + */ + CurlUrlHandle(); + + /** + * @brief Clean up the underlying URL handle + */ + ~CurlUrlHandle() noexcept; + + CurlUrlHandle(CurlUrlHandle const&) = delete; + CurlUrlHandle& operator=(CurlUrlHandle const&) = delete; + + CurlUrlHandle(CurlUrlHandle&& other) noexcept; + CurlUrlHandle& operator=(CurlUrlHandle&& other) noexcept; + + /** + * @brief Get the underlying libcurl URL handle + * + * @return Pointer to the underlying libcurl URL handle + * @note The returned pointer should not be freed manually as it is managed by this class + */ + CURLU* get() const; +}; + +/** + * @brief URL parsing utility using libcurl's URL API + * + * This class provides static methods for parsing URLs into their constituent + * components (scheme, host, port, path, query, fragment). + * + * @note This class uses libcurl's URL parsing which follows RFC 3986 plus. See + * https://curl.se/docs/url-syntax.html + * + * Example: + * @code{.cpp} + * auto components = UrlParser::parse("https://example.com:8080/path?query=1#frag"); + * if (components.scheme.has_value()) { + * std::cout << "Scheme: " << components.scheme.value() << std::endl; + * } + * if (components.host.has_value()) { + * std::cout << "Host: " << components.host.value() << std::endl; + * } + * @endcode + */ +class UrlParser { + public: + /** + * @brief Container for parsed URL components + */ + struct UrlComponents { + /** + * @brief The URL scheme (e.g., "http", "https", "ftp"). May be empty for scheme-relative URLs + * or paths. + */ + std::optional scheme; + + /** + * @brief The hostname or IP address. May be empty for URLs without an authority component + * (e.g., "file:///path"). + */ + std::optional host; + + /** + * @brief The port number as a string. 
Will be empty if no explicit port is specified in the + * URL. + * @note Default ports (e.g., 80 for HTTP, 443 for HTTPS) are not automatically filled in. + */ + std::optional port; + + /** + * @brief The path component of the URL. Libcurl ensures that the path component is always + * present, even if empty (will be "/" for URLs like "http://example.com"). + */ + std::optional path; + + /** + * @brief The query string (without the leading "?"). Empty if no query parameters are present. + */ + std::optional query; + + /** + * @brief The fragment identifier (without the leading "#"). Empty if no fragment is present. + */ + std::optional fragment; + }; + + /** + * @brief Parses the given URL according to RFC 3986 plus and extracts its components. + * + * @param url The URL string to parse + * @param bitmask_url_flags Optional flags for URL parsing. Common flags include: + * - CURLU_DEFAULT_SCHEME: Allows URLs without schemes + * - CURLU_NON_SUPPORT_SCHEME: Accept non-supported schemes + * - CURLU_URLENCODE: URL encode the path + * @param bitmask_component_flags Optional flags for component extraction. 
Common flags include: + * - CURLU_URLDECODE: URL decode the component + * - CURLU_PUNYCODE: Return host as punycode + * + * @return UrlComponents structure containing the parsed URL components + * + * @throw std::runtime_error if the URL cannot be parsed or if component extraction fails + * + * Example: + * @code{.cpp} + * // Basic parsing + * auto components = UrlParser::parse("https://api.example.com/v1/users?page=1"); + * + * // Parsing with URL decoding + * auto decoded = UrlParser::parse( + * "https://example.com/hello%20world", + * std::nullopt, + * CURLU_URLDECODE + * ); + * + * // Allow non-standard schemes + * auto custom = UrlParser::parse( + * "myscheme://example.com", + * CURLU_NON_SUPPORT_SCHEME + * ); + * @endcode + */ + static UrlComponents parse(std::string const& url, + std::optional bitmask_url_flags = std::nullopt, + std::optional bitmask_component_flags = std::nullopt); + + /** + * @brief Extract a specific component from a CurlUrlHandle + * + * @param handle The CurlUrlHandle containing the parsed URL + * @param part The URL part to extract (e.g., CURLUPART_SCHEME) + * @param bitmask_component_flags Flags controlling extraction behavior + * @param allowed_err_code Optional error code to treat as valid (e.g., CURLUE_NO_SCHEME) + * @return The extracted component as a string, or std::nullopt if not present + * @throw std::runtime_error if extraction fails with an unexpected error + */ + static std::optional extract_component( + CurlUrlHandle const& handle, + CURLUPart part, + std::optional bitmask_component_flags = std::nullopt, + std::optional allowed_err_code = std::nullopt); + + /** + * @brief Extract a specific component from a URL string + * + * @param url The URL string from which to extract a component + * @param part The URL part to extract + * @param bitmask_url_flags Optional flags for URL parsing. 
+ * @param bitmask_component_flags Flags controlling extraction behavior + * @param allowed_err_code Optional error code to treat as valid + * @return The extracted component as a string, or std::nullopt if not present + * @throw std::runtime_error if extraction fails with an unexpected error + */ + static std::optional extract_component( + std::string const& url, + CURLUPart part, + std::optional bitmask_url_flags = std::nullopt, + std::optional bitmask_component_flags = std::nullopt, + std::optional allowed_err_code = std::nullopt); +}; +} // namespace kvikio::detail diff --git a/cpp/include/kvikio/hdfs.hpp b/cpp/include/kvikio/hdfs.hpp index 0b20d658bd..345051bcbd 100644 --- a/cpp/include/kvikio/hdfs.hpp +++ b/cpp/include/kvikio/hdfs.hpp @@ -58,5 +58,13 @@ class WebHdfsEndpoint : public RemoteEndpoint { std::string str() const override; std::size_t get_file_size() override; void setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) override; + + /** + * @brief Whether the given URL is valid for the WebHDFS endpoints. + * + * @param url A URL. + * @return Boolean answer. + */ + static bool is_url_valid(std::string const& url) noexcept; }; } // namespace kvikio diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index b2e2d1d0ff..0d56231d03 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -34,6 +34,18 @@ namespace kvikio { class CurlHandle; // Prototype +/** + * @brief Type of remote file. + */ +enum class RemoteEndpointType : uint8_t { + AUTO, ///< Let KvikIO infer the type of remote file from the URL and create a proper endpoint. + S3, ///< AWS S3 (based on HTTP/HTTPS protocols). + S3_PRESIGNED_URL, ///< AWS S3 presigned URL (based on HTTP/HTTPS protocols). + WEBHDFS, ///< Apache Hadoop WebHDFS (based on HTTP/HTTPS protocols). + HTTP, ///< Generic HTTP/HTTPS, excluding all the specific types listed above that use HTTP/HTTPS + ///< protocols. 
+}; + /** * @brief Abstract base class for remote endpoints. * @@ -43,6 +55,10 @@ class CurlHandle; // Prototype * its own ctor that takes communication protocol specific arguments. */ class RemoteEndpoint { + protected: + RemoteEndpointType _remote_endpoint_type{RemoteEndpointType::AUTO}; + RemoteEndpoint(RemoteEndpointType remote_endpoint_type); + public: virtual ~RemoteEndpoint() = default; @@ -74,6 +90,13 @@ class RemoteEndpoint { * size. */ virtual void setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) = 0; + + /** + * @brief Get the type of the remote file. + * + * @return The type of the remote file. + */ + [[nodiscard]] RemoteEndpointType remote_endpoint_type() const noexcept; }; /** @@ -96,6 +119,14 @@ class HttpEndpoint : public RemoteEndpoint { std::string str() const override; std::size_t get_file_size() override; void setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) override; + + /** + * @brief Whether the given URL is valid for HTTP/HTTPS endpoints. + * + * @param url A URL. + * @return Boolean answer. + */ + static bool is_url_valid(std::string const& url) noexcept; }; /** @@ -206,6 +237,14 @@ class S3Endpoint : public RemoteEndpoint { std::string str() const override; std::size_t get_file_size() override; void setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) override; + + /** + * @brief Whether the given URL is valid for S3 endpoints (excluding presigned URL). + * + * @param url A URL. + * @return Boolean answer. + */ + static bool is_url_valid(std::string const& url) noexcept; }; /** @@ -224,6 +263,14 @@ class S3EndpointWithPresignedUrl : public RemoteEndpoint { std::string str() const override; std::size_t get_file_size() override; void setup_range_request(CurlHandle& curl, std::size_t file_offset, std::size_t size) override; + + /** + * @brief Whether the given URL is valid for S3 endpoints with presigned URL. + * + * @param url A URL. 
+ * @return Boolean answer. + */ + static bool is_url_valid(std::string const& url) noexcept; }; /** @@ -235,6 +282,88 @@ class RemoteHandle { std::size_t _nbytes; public: + /** + * @brief Create a remote file handle from a URL. + * + * This function creates a RemoteHandle for reading data from various remote endpoints + * including HTTP/HTTPS servers, AWS S3 buckets, S3 presigned URLs, and WebHDFS. + * The endpoint type can be automatically detected from the URL or explicitly specified. + * + * @param url The URL of the remote file. Supported formats include: + * - S3 with credentials + * - S3 presigned URL + * - WebHDFS + * - HTTP/HTTPS + * @param remote_endpoint_type The type of remote endpoint. Default is RemoteEndpointType::AUTO + * which automatically detects the endpoint type from the URL. Can be explicitly set to + * RemoteEndpointType::S3, RemoteEndpointType::S3_PRESIGNED_URL, RemoteEndpointType::WEBHDFS, or + * RemoteEndpointType::HTTP to force a specific endpoint type. + * @param allow_list Optional list of allowed endpoint types. If provided: + * - If remote_endpoint_type is RemoteEndpointType::AUTO, types are tried in the exact order + * specified until a match is found. + * - In explicit mode, the specified type must be in this list, otherwise an exception is + * thrown. + * + * If not provided, defaults to all supported types in this order: RemoteEndpointType::S3, + * RemoteEndpointType::S3_PRESIGNED_URL, RemoteEndpointType::WEBHDFS, and + * RemoteEndpointType::HTTP. + * @param nbytes Optional file size in bytes. If not provided, the function sends an extra + * request to the server to query the file size. + * @return A RemoteHandle object that can be used to read data from the remote file. + * @exception std::runtime_error If: + * - The URL is malformed or missing required components. + * - RemoteEndpointType::AUTO mode is used and the URL doesn't match any supported endpoint + * type. 
+ * - The specified endpoint type is not in the `allow_list`. + * - The URL is invalid for the specified endpoint type. + * - Unable to connect to the remote server or determine file size (when nbytes not provided). + * + * Example: + * - Auto-detect endpoint type from URL + * @code{.cpp} + * auto handle = kvikio::RemoteHandle::open( + * "https://bucket.s3.amazonaws.com/object?X-Amz-Algorithm=AWS4-HMAC-SHA256" + * "&X-Amz-Credential=...&X-Amz-Signature=..." + * ); + * @endcode + * + * - Open S3 file with explicit endpoint type + * @code{.cpp} + * + * auto handle = kvikio::RemoteHandle::open( + * "https://my-bucket.s3.us-east-1.amazonaws.com/data.bin", + * kvikio::RemoteEndpointType::S3 + * ); + * @endcode + * + * - Restrict endpoint type candidates + * @code{.cpp} + * std::vector allow_list = { + * kvikio::RemoteEndpointType::HTTP, + * kvikio::RemoteEndpointType::S3_PRESIGNED_URL + * }; + * auto handle = kvikio::RemoteHandle::open( + * user_provided_url, + * kvikio::RemoteEndpointType::AUTO, + * allow_list + * ); + * @endcode + * + * - Provide known file size to skip HEAD request + * @code{.cpp} + * auto handle = kvikio::RemoteHandle::open( + * "https://example.com/large-file.bin", + * kvikio::RemoteEndpointType::HTTP, + * std::nullopt, + * 1024 * 1024 * 100 // 100 MB + * ); + * @endcode + */ + static RemoteHandle open(std::string url, + RemoteEndpointType remote_endpoint_type = RemoteEndpointType::AUTO, + std::optional> allow_list = std::nullopt, + std::optional nbytes = std::nullopt); + /** * @brief Create a new remote handle from an endpoint and a file size. * @@ -258,6 +387,13 @@ class RemoteHandle { RemoteHandle(RemoteHandle const&) = delete; RemoteHandle& operator=(RemoteHandle const&) = delete; + /** + * @brief Get the type of the remote file. + * + * @return The type of the remote file. + */ + [[nodiscard]] RemoteEndpointType remote_endpoint_type() const noexcept; + /** * @brief Get the file size. 
* diff --git a/cpp/src/detail/url.cpp b/cpp/src/detail/url.cpp new file mode 100644 index 0000000000..64f5b8fde9 --- /dev/null +++ b/cpp/src/detail/url.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include +#include + +#define CHECK_CURL_URL_ERR(err_code) check_curl_url_err(err_code, __LINE__, __FILE__) + +namespace kvikio::detail { +namespace { +void check_curl_url_err(CURLUcode err_code, int line_number, char const* filename) +{ + if (err_code == CURLUcode::CURLUE_OK) { return; } + + std::stringstream ss; + ss << "KvikIO detects an URL error at: " << filename << ":" << line_number << ": "; + char const* msg = curl_url_strerror(err_code); + if (msg == nullptr) { + ss << "(no message)"; + } else { + ss << msg; + } + throw std::runtime_error(ss.str()); +} +} // namespace + +CurlUrlHandle::CurlUrlHandle() : _handle(curl_url()) +{ + KVIKIO_EXPECT(_handle != nullptr, + "Libcurl is unable to allocate a URL handle (likely out of memory)."); +} + +CurlUrlHandle::~CurlUrlHandle() noexcept +{ + if (_handle) { curl_url_cleanup(_handle); } +} + +CurlUrlHandle::CurlUrlHandle(CurlUrlHandle&& other) noexcept + : _handle{std::exchange(other._handle, nullptr)} +{ +} + +CurlUrlHandle& CurlUrlHandle::operator=(CurlUrlHandle&& other) noexcept +{ + if (this != &other) { + if (_handle) { curl_url_cleanup(_handle); } + _handle = std::exchange(other._handle, 
nullptr); + } + + return *this; +} + +CURLU* CurlUrlHandle::get() const { return _handle; } + +std::optional UrlParser::extract_component( + CurlUrlHandle const& handle, + CURLUPart part, + std::optional bitmask_component_flags, + std::optional allowed_err_code) +{ + if (!bitmask_component_flags.has_value()) { bitmask_component_flags = 0U; } + + char* value{}; + auto err_code = curl_url_get(handle.get(), part, &value, bitmask_component_flags.value()); + + if (err_code == CURLUcode::CURLUE_OK && value != nullptr) { + std::string result{value}; + curl_free(value); + return result; + } + + if (allowed_err_code.has_value() && allowed_err_code.value() == err_code) { return std::nullopt; } + + // Throws an exception and explains the reason. + CHECK_CURL_URL_ERR(err_code); + return std::nullopt; +} + +std::optional UrlParser::extract_component( + std::string const& url, + CURLUPart part, + std::optional bitmask_url_flags, + std::optional bitmask_component_flags, + std::optional allowed_err_code) +{ + if (!bitmask_url_flags.has_value()) { bitmask_url_flags = 0U; } + if (!bitmask_component_flags.has_value()) { bitmask_component_flags = 0U; } + + CurlUrlHandle handle; + CHECK_CURL_URL_ERR( + curl_url_set(handle.get(), CURLUPART_URL, url.c_str(), bitmask_url_flags.value())); + + return extract_component(handle, part, bitmask_component_flags, allowed_err_code); +} + +UrlParser::UrlComponents UrlParser::parse(std::string const& url, + std::optional bitmask_url_flags, + std::optional bitmask_component_flags) +{ + if (!bitmask_url_flags.has_value()) { bitmask_url_flags = 0U; } + if (!bitmask_component_flags.has_value()) { bitmask_component_flags = 0U; } + + CurlUrlHandle handle; + CHECK_CURL_URL_ERR( + curl_url_set(handle.get(), CURLUPART_URL, url.c_str(), bitmask_url_flags.value())); + + UrlComponents components; + CURLUcode err_code{}; + + components.scheme = extract_component( + handle, CURLUPART_SCHEME, bitmask_component_flags.value(), CURLUcode::CURLUE_NO_SCHEME); + 
components.host = extract_component( + handle, CURLUPART_HOST, bitmask_component_flags.value(), CURLUcode::CURLUE_NO_HOST); + components.port = extract_component( + handle, CURLUPART_PORT, bitmask_component_flags.value(), CURLUcode::CURLUE_NO_PORT); + components.path = extract_component(handle, CURLUPART_PATH, bitmask_component_flags.value()); + components.query = extract_component( + handle, CURLUPART_QUERY, bitmask_component_flags.value(), CURLUcode::CURLUE_NO_QUERY); + components.fragment = extract_component( + handle, CURLUPART_FRAGMENT, bitmask_component_flags.value(), CURLUcode::CURLUE_NO_FRAGMENT); + + return components; +} +} // namespace kvikio::detail diff --git a/cpp/src/hdfs.cpp b/cpp/src/hdfs.cpp index 12455b3a26..2e032a1af7 100644 --- a/cpp/src/hdfs.cpp +++ b/cpp/src/hdfs.cpp @@ -25,7 +25,7 @@ namespace kvikio { -WebHdfsEndpoint::WebHdfsEndpoint(std::string url) +WebHdfsEndpoint::WebHdfsEndpoint(std::string url) : RemoteEndpoint{RemoteEndpointType::WEBHDFS} { // todo: Use libcurl URL API for more secure and idiomatic parsing. // Split the URL into two parts: one without query and one with. @@ -64,7 +64,7 @@ WebHdfsEndpoint::WebHdfsEndpoint(std::string host, std::string port, std::string file_path, std::optional username) - : _username{std::move(username)} + : RemoteEndpoint{RemoteEndpointType::WEBHDFS}, _username{std::move(username)} { std::stringstream ss; ss << "http://" << host << ":" << port << "/webhdfs/v1" << file_path; @@ -128,4 +128,15 @@ void WebHdfsEndpoint::setup_range_request(CurlHandle& curl, ss << "op=OPEN&offset=" << file_offset << "&length=" << size; curl.setopt(CURLOPT_URL, ss.str().c_str()); } + +bool WebHdfsEndpoint::is_url_valid(std::string const& url) noexcept +{ + try { + std::regex const pattern(R"(^https?://[^/]+:\d+/webhdfs/v1/.+$)", std::regex_constants::icase); + std::smatch match_result; + return std::regex_match(url, match_result, pattern); + } catch (...) 
{ + return false; + } +} } // namespace kvikio diff --git a/cpp/src/remote_handle.cpp b/cpp/src/remote_handle.cpp index 23cf5c6305..3cf2acc862 100644 --- a/cpp/src/remote_handle.cpp +++ b/cpp/src/remote_handle.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include #include @@ -26,7 +27,9 @@ #include #include +#include #include +#include #include #include #include @@ -177,9 +180,69 @@ void setup_range_request_impl(CurlHandle& curl, std::size_t file_offset, std::si curl.setopt(CURLOPT_RANGE, byte_range.c_str()); } +/** + * @brief Whether the given URL is compatible with the S3 endpoint (including the credential-based + * access and presigned URL) which uses HTTP/HTTPS. + * + * @param url A URL. + * @return Boolean answer. + */ +bool url_has_aws_s3_http_format(std::string const& url) +{ + // Currently KvikIO supports the following AWS S3 HTTP URL formats: + static std::array const s3_patterns = { + // Virtual host style: https://<bucket-name>.s3.<region-code>.amazonaws.com/<object-key-name> + std::regex(R"(https?://[^/]+\.s3\.[^.]+\.amazonaws\.com/.+$)", std::regex_constants::icase), + + // Path style (deprecated but still popular): + // https://s3.<region-code>.amazonaws.com/<bucket-name>/<object-key-name> + std::regex(R"(https?://s3\.[^.]+\.amazonaws\.com/[^/]+/.+$)", std::regex_constants::icase), + + // Legacy global endpoint: no region code + std::regex(R"(https?://[^/]+\.s3\.amazonaws\.com/.+$)", std::regex_constants::icase), + std::regex(R"(https?://s3\.amazonaws\.com/[^/]+/.+$)", std::regex_constants::icase), + + // Legacy regional endpoint: s3 and region code are delimited by - instead of . 
+ std::regex(R"(https?://[^/]+\.s3-[^.]+\.amazonaws\.com/.+$)", std::regex_constants::icase), + std::regex(R"(https?://s3-[^.]+\.amazonaws\.com/[^/]+/.+$)", std::regex_constants::icase)}; + + return std::any_of(s3_patterns.begin(), s3_patterns.end(), [&url = url](auto const& pattern) { + std::smatch match_result; + return std::regex_match(url, match_result, pattern); + }); +} + +char const* get_remote_endpoint_type_name(RemoteEndpointType remote_endpoint_type) +{ + switch (remote_endpoint_type) { + case RemoteEndpointType::S3: return "S3"; + case RemoteEndpointType::S3_PRESIGNED_URL: return "S3 with presigned URL"; + case RemoteEndpointType::WEBHDFS: return "WebHDFS"; + case RemoteEndpointType::HTTP: return "HTTP"; + case RemoteEndpointType::AUTO: return "AUTO"; + default: + // Unreachable + KVIKIO_FAIL("Unknown RemoteEndpointType: " + + std::to_string(static_cast(remote_endpoint_type))); + return "UNKNOWN"; + } +} } // namespace -HttpEndpoint::HttpEndpoint(std::string url) : _url{std::move(url)} {} +RemoteEndpoint::RemoteEndpoint(RemoteEndpointType remote_endpoint_type) + : _remote_endpoint_type{remote_endpoint_type} +{ +} + +RemoteEndpointType RemoteEndpoint::remote_endpoint_type() const noexcept +{ + return _remote_endpoint_type; +} + +HttpEndpoint::HttpEndpoint(std::string url) + : RemoteEndpoint{RemoteEndpointType::HTTP}, _url{std::move(url)} +{ +} std::string HttpEndpoint::str() const { return _url; } @@ -194,6 +257,19 @@ void HttpEndpoint::setup_range_request(CurlHandle& curl, std::size_t file_offset setup_range_request_impl(curl, file_offset, size); } +bool HttpEndpoint::is_url_valid(std::string const& url) noexcept +{ + try { + auto parsed_url = detail::UrlParser::parse(url); + if ((parsed_url.scheme != "http") && (parsed_url.scheme != "https")) { return false; }; + + // Check whether the file path exists, excluding the leading "/" + return parsed_url.path->length() > 1; + } catch (...) 
{ + return false; + } +} + void HttpEndpoint::setopt(CurlHandle& curl) { curl.setopt(CURLOPT_URL, _url.c_str()); } void S3Endpoint::setopt(CurlHandle& curl) @@ -256,7 +332,7 @@ S3Endpoint::S3Endpoint(std::string url, std::optional aws_access_key, std::optional aws_secret_access_key, std::optional aws_session_token) - : _url{std::move(url)} + : RemoteEndpoint{RemoteEndpointType::S3}, _url{std::move(url)} { KVIKIO_NVTX_FUNC_RANGE(); // Regular expression to match http[s]:// @@ -348,8 +424,29 @@ void S3Endpoint::setup_range_request(CurlHandle& curl, std::size_t file_offset, setup_range_request_impl(curl, file_offset, size); } +bool S3Endpoint::is_url_valid(std::string const& url) noexcept +{ + try { + auto parsed_url = detail::UrlParser::parse(url, CURLU_NON_SUPPORT_SCHEME); + + if (parsed_url.scheme == "s3") { + if (!parsed_url.host.has_value()) { return false; } + if (!parsed_url.path.has_value()) { return false; } + + // The object key must be non-empty; it may itself contain "/" (e.g. "dir/file.bin") + std::regex const pattern(R"(^/.+$)", std::regex::icase); + std::smatch match_result; + return std::regex_search(parsed_url.path.value(), match_result, pattern); + } else if ((parsed_url.scheme == "http") || (parsed_url.scheme == "https")) { + return url_has_aws_s3_http_format(url) && !S3EndpointWithPresignedUrl::is_url_valid(url); + } + } catch (...) 
{ + } + return false; +} + S3EndpointWithPresignedUrl::S3EndpointWithPresignedUrl(std::string presigned_url) - : _url{std::move(presigned_url)} + : RemoteEndpoint{RemoteEndpointType::S3_PRESIGNED_URL}, _url{std::move(presigned_url)} { } @@ -439,6 +536,95 @@ void S3EndpointWithPresignedUrl::setup_range_request(CurlHandle& curl, setup_range_request_impl(curl, file_offset, size); } +bool S3EndpointWithPresignedUrl::is_url_valid(std::string const& url) noexcept +{ + try { + if (!url_has_aws_s3_http_format(url)) { return false; } + + auto parsed_url = detail::UrlParser::parse(url); + if (!parsed_url.query.has_value()) { return false; } + + // Reference: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + return parsed_url.query->find("X-Amz-Algorithm") != std::string::npos && + parsed_url.query->find("X-Amz-Signature") != std::string::npos; + } catch (...) { + return false; + } +} + +RemoteHandle RemoteHandle::open(std::string url, + RemoteEndpointType remote_endpoint_type, + std::optional> allow_list, + std::optional nbytes) +{ + if (!allow_list.has_value()) { + allow_list = {RemoteEndpointType::S3, + RemoteEndpointType::S3_PRESIGNED_URL, + RemoteEndpointType::WEBHDFS, + RemoteEndpointType::HTTP}; + } + + auto const scheme = + detail::UrlParser::extract_component(url, CURLUPART_SCHEME, CURLU_NON_SUPPORT_SCHEME); + KVIKIO_EXPECT(scheme.has_value(), "Missing scheme in URL."); + + // Helper to create endpoint based on type + auto create_endpoint = + [&url = url, &scheme = scheme](RemoteEndpointType type) -> std::unique_ptr { + switch (type) { + case RemoteEndpointType::S3: + if (!S3Endpoint::is_url_valid(url)) { return nullptr; } + if (scheme.value() == "s3") { + auto const [bucket, object] = S3Endpoint::parse_s3_url(url); + return std::make_unique(std::pair{bucket, object}); + } + return std::make_unique(url); + + case RemoteEndpointType::S3_PRESIGNED_URL: + if (!S3EndpointWithPresignedUrl::is_url_valid(url)) { return nullptr; } + return 
std::make_unique(url); + + case RemoteEndpointType::WEBHDFS: + if (!WebHdfsEndpoint::is_url_valid(url)) { return nullptr; } + return std::make_unique(url); + + case RemoteEndpointType::HTTP: + if (!HttpEndpoint::is_url_valid(url)) { return nullptr; } + return std::make_unique(url); + + default: return nullptr; + } + }; + + std::unique_ptr endpoint; + + if (remote_endpoint_type == RemoteEndpointType::AUTO) { + // Try each allowed type in the order of allowlist + for (auto const& type : allow_list.value()) { + endpoint = create_endpoint(type); + if (endpoint) { break; } + } + KVIKIO_EXPECT(endpoint.get() != nullptr, "Unsupported endpoint URL.", std::runtime_error); + } else { + // Validate it is in the allow list + KVIKIO_EXPECT( + std::find(allow_list->begin(), allow_list->end(), remote_endpoint_type) != allow_list->end(), + std::string{get_remote_endpoint_type_name(remote_endpoint_type)} + + " is not in the allowlist.", + std::runtime_error); + + // Create the specific type + endpoint = create_endpoint(remote_endpoint_type); + KVIKIO_EXPECT(endpoint.get() != nullptr, + std::string{"Invalid URL for "} + + get_remote_endpoint_type_name(remote_endpoint_type) + " endpoint", + std::runtime_error); + } + + return nbytes.has_value() ? 
RemoteHandle(std::move(endpoint), nbytes.value()) + : RemoteHandle(std::move(endpoint)); +} + RemoteHandle::RemoteHandle(std::unique_ptr endpoint, std::size_t nbytes) : _endpoint{std::move(endpoint)}, _nbytes{nbytes} { @@ -452,6 +638,11 @@ RemoteHandle::RemoteHandle(std::unique_ptr endpoint) _endpoint = std::move(endpoint); } +RemoteEndpointType RemoteHandle::remote_endpoint_type() const noexcept +{ + return _endpoint->remote_endpoint_type(); +} + std::size_t RemoteHandle::nbytes() const noexcept { return _nbytes; } RemoteEndpoint const& RemoteHandle::endpoint() const noexcept { return *_endpoint; } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 41de4bb6fa..a6fd2c67e4 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -79,6 +79,7 @@ kvikio_add_test(NAME MMAP_TEST SOURCES test_mmap.cpp) if(KvikIO_REMOTE_SUPPORT) kvikio_add_test(NAME REMOTE_HANDLE_TEST SOURCES test_remote_handle.cpp utils/env.cpp) kvikio_add_test(NAME HDFS_TEST SOURCES test_hdfs.cpp utils/hdfs_helper.cpp) + kvikio_add_test(NAME URL_TEST SOURCES test_url.cpp) endif() rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing DESTINATION bin/tests/libkvikio) diff --git a/cpp/tests/test_defaults.cpp b/cpp/tests/test_defaults.cpp index 89bbe7399c..a74f38c86f 100644 --- a/cpp/tests/test_defaults.cpp +++ b/cpp/tests/test_defaults.cpp @@ -19,9 +19,9 @@ #include #include +#include #include -#include "kvikio/compat_mode.hpp" #include "utils/env.hpp" using ::testing::HasSubstr; diff --git a/cpp/tests/test_remote_handle.cpp b/cpp/tests/test_remote_handle.cpp index 918479b0f0..ffb7c82266 100644 --- a/cpp/tests/test_remote_handle.cpp +++ b/cpp/tests/test_remote_handle.cpp @@ -14,12 +14,101 @@ * limitations under the License. 
*/ +#include +#include +#include +#include +#include + +#include #include +#include #include #include "utils/env.hpp" -TEST(RemoteHandleTest, s3_endpoint_constructor) +using ::testing::HasSubstr; +using ::testing::ThrowsMessage; + +class RemoteHandleTest : public testing::Test { + protected: + void SetUp() override + { + _sample_urls = { + // Endpoint type: S3 + {"s3://bucket-name/object-key-name", kvikio::RemoteEndpointType::S3}, + {"https://bucket-name.s3.region-code.amazonaws.com/object-key-name", + kvikio::RemoteEndpointType::S3}, + {"https://s3.region-code.amazonaws.com/bucket-name/object-key-name", + kvikio::RemoteEndpointType::S3}, + {"https://bucket-name.s3.amazonaws.com/object-key-name", kvikio::RemoteEndpointType::S3}, + {"https://s3.amazonaws.com/bucket-name/object-key-name", kvikio::RemoteEndpointType::S3}, + {"https://bucket-name.s3-region-code.amazonaws.com/object-key-name", + kvikio::RemoteEndpointType::S3}, + {"https://s3-region-code.amazonaws.com/bucket-name/object-key-name", + kvikio::RemoteEndpointType::S3}, + + // Endpoint type: S3 presigned URL + {"https://bucket-name.s3.region-code.amazonaws.com/" + "object-key-name?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Signature=sig&X-Amz-Credential=cred&" + "X-Amz-SignedHeaders=host", + kvikio::RemoteEndpointType::S3_PRESIGNED_URL}, + + // Endpoint type: WebHDFS + {"https://host:1234/webhdfs/v1/data.bin", kvikio::RemoteEndpointType::WEBHDFS}, + }; + } + + void TearDown() override {} + + void test_helper(kvikio::RemoteEndpointType expected_endpoint_type, + std::function url_validity_checker) + { + for (auto const& [url, endpoint_type] : _sample_urls) { + if (endpoint_type == expected_endpoint_type) { + // Given that the URL is the expected endpoint type + + // Test URL validity checker + EXPECT_TRUE(url_validity_checker(url)); + + // Test unified interface + { + // Here we pass the 1-byte argument to RemoteHandle::open. 
This prevents the endpoint + // constructor from querying the file size and sending requests to the server, thus + // allowing us to use dummy URLs for testing purposes. + auto remote_handle = + kvikio::RemoteHandle::open(url, kvikio::RemoteEndpointType::AUTO, std::nullopt, 1); + EXPECT_EQ(remote_handle.remote_endpoint_type(), expected_endpoint_type); + } + + // Test explicit endpoint type specification + { + EXPECT_NO_THROW({ + auto remote_handle = + kvikio::RemoteHandle::open(url, expected_endpoint_type, std::nullopt, 1); + }); + } + } else { + // Given that the URL is NOT the expected endpoint type + + // Test URL validity checker + EXPECT_FALSE(url_validity_checker(url)); + + // Test explicit endpoint type specification + { + EXPECT_ANY_THROW({ + auto remote_handle = + kvikio::RemoteHandle::open(url, expected_endpoint_type, std::nullopt, 1); + }); + } + } + } + } + + std::vector> _sample_urls; +}; + +TEST_F(RemoteHandleTest, s3_endpoint_constructor) { kvikio::test::EnvVarContext env_var_ctx{{"AWS_DEFAULT_REGION", "my_aws_default_region"}, {"AWS_ACCESS_KEY_ID", "my_aws_access_key_id"}, @@ -37,3 +126,145 @@ TEST(RemoteHandleTest, s3_endpoint_constructor) EXPECT_EQ(s1.str(), s2.str()); } + +TEST_F(RemoteHandleTest, test_http_url) +{ + // Invalid URLs + { + std::vector const invalid_urls{// Incorrect scheme + "s3://example.com", + "hdfs://example.com", + // Missing file path + "http://example.com"}; + for (auto const& invalid_url : invalid_urls) { + EXPECT_FALSE(kvikio::HttpEndpoint::is_url_valid(invalid_url)); + } + } +} + +TEST_F(RemoteHandleTest, test_s3_url) +{ + kvikio::test::EnvVarContext env_var_ctx{{"AWS_DEFAULT_REGION", "my_aws_default_region"}, + {"AWS_ACCESS_KEY_ID", "my_aws_access_key_id"}, + {"AWS_SECRET_ACCESS_KEY", "my_aws_secrete_access_key"}}; + + { + test_helper(kvikio::RemoteEndpointType::S3, kvikio::S3Endpoint::is_url_valid); + } + + // Invalid URLs + { + std::vector const invalid_urls{ + // Missing object key name + "s3://bucket-name", + 
"https://bucket-name.s3.region-code.amazonaws.com", + // Presigned URL + "https://bucket-name.s3.region-code.amazonaws.com/" + "object-key-name?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Signature=sig&X-Amz-Credential=" + "cred&" + "X-Amz-SignedHeaders=host"}; + for (auto const& invalid_url : invalid_urls) { + EXPECT_FALSE(kvikio::S3Endpoint::is_url_valid(invalid_url)); + } + } +} + +TEST_F(RemoteHandleTest, test_s3_url_with_presigned_url) +{ + { + test_helper(kvikio::RemoteEndpointType::S3_PRESIGNED_URL, + kvikio::S3EndpointWithPresignedUrl::is_url_valid); + } + + // Invalid URLs + { + std::vector const invalid_urls{ + // Presigned URL should not use S3 scheme + "s3://bucket-name/object-key-name", + + // Completely missing query + "https://bucket-name.s3.region-code.amazonaws.com/object-key-name", + + // Missing key parameters ("X-Amz-...") in query + "https://bucket-name.s3.region-code.amazonaws.com/object-key-name?k0=v0&k1=v2"}; + for (auto const& invalid_url : invalid_urls) { + EXPECT_FALSE(kvikio::S3EndpointWithPresignedUrl::is_url_valid(invalid_url)); + } + } +} + +TEST_F(RemoteHandleTest, test_webhdfs_url) +{ + { + test_helper(kvikio::RemoteEndpointType::WEBHDFS, kvikio::WebHdfsEndpoint::is_url_valid); + } + + // Invalid URLs + { + std::vector const invalid_urls{// Missing file + "https://host:1234/webhdfs/v1", + "https://host:1234/webhdfs/v1/", + + // Missing WebHDFS identifier + "https://host:1234/data.bin", + + // Missing port number + "https://host/webhdfs/v1/data.bin"}; + for (auto const& invalid_url : invalid_urls) { + EXPECT_FALSE(kvikio::WebHdfsEndpoint::is_url_valid(invalid_url)); + } + } +} + +TEST_F(RemoteHandleTest, test_open) +{ + // Missing scheme + { + std::vector const urls{ + "example.com/path", "example.com:8080/path", "//example.com/path", "://example.com/path"}; + for (auto const& url : urls) { + EXPECT_THROW( + { kvikio::RemoteHandle::open(url, kvikio::RemoteEndpointType::AUTO, std::nullopt, 1); }, + std::runtime_error); + } + } + + // 
Unsupported type + { + std::string const url{"unsupported://example.com/path"}; + EXPECT_THAT( + [&] { kvikio::RemoteHandle::open(url, kvikio::RemoteEndpointType::AUTO, std::nullopt, 1); }, + ThrowsMessage(HasSubstr("Unsupported endpoint URL"))); + } + + // Specified URL not in the allowlist + { + std::string const url{"https://host:1234/webhdfs/v1/data.bin"}; + std::vector> const wrong_allowlists{ + {}, + {kvikio::RemoteEndpointType::S3}, + }; + for (auto const& wrong_allowlist : wrong_allowlists) { + EXPECT_THAT( + [&] { + kvikio::RemoteHandle::open(url, kvikio::RemoteEndpointType::WEBHDFS, wrong_allowlist, 1); + }, + ThrowsMessage(HasSubstr("is not in the allowlist"))); + } + } + + // Invalid URLs + { + std::vector> const invalid_urls{ + {"s3://bucket-name", kvikio::RemoteEndpointType::S3}, + {"https://bucket-name.s3.region-code.amazonaws.com/object-key-name", + kvikio::RemoteEndpointType::S3_PRESIGNED_URL}, + {"https://host:1234/webhdfs/v1", kvikio::RemoteEndpointType::WEBHDFS}, + {"http://example.com", kvikio::RemoteEndpointType::HTTP}, + }; + for (auto const& [invalid_url, endpoint_type] : invalid_urls) { + EXPECT_THAT([&] { kvikio::RemoteHandle::open(invalid_url, endpoint_type, std::nullopt, 1); }, + ThrowsMessage(HasSubstr("Invalid URL"))); + } + } +} diff --git a/cpp/tests/test_url.cpp b/cpp/tests/test_url.cpp new file mode 100644 index 0000000000..ce419ed5a5 --- /dev/null +++ b/cpp/tests/test_url.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include +#include + +using ::testing::HasSubstr; +using ::testing::ThrowsMessage; + +TEST(UrlTest, parse_scheme) +{ + { + std::vector invalid_scheme_urls{ + "invalid_scheme://host", + // The S3 scheme is not supported by libcurl. Without the CURLU_NON_SUPPORT_SCHEME flag, an + // exception is expected. + "s3://host"}; + + for (auto const& invalid_scheme_url : invalid_scheme_urls) { + EXPECT_THAT([&] { kvikio::detail::UrlParser::parse(invalid_scheme_url); }, + ThrowsMessage(HasSubstr("KvikIO detects an URL error"))); + } + } + + // With the CURLU_NON_SUPPORT_SCHEME flag, the S3 scheme is now accepted. + { + std::vector schemes{"s3", "S3"}; + for (auto const& scheme : schemes) { + auto parsed_url = + kvikio::detail::UrlParser::parse(scheme + "://host", CURLU_NON_SUPPORT_SCHEME); + EXPECT_EQ(parsed_url.scheme.value(), "s3"); // Lowercase due to CURL's normalization + } + } +} + +TEST(UrlTest, parse_host) +{ + std::vector invalid_host_urls{"http://host with spaces.com", + "http://host[brackets].com", + "http://host{braces}.com", + "http://host.com", + R"(http://host\backslash.com)", + "http://host^caret.com", + "http://host`backtick.com"}; + for (auto const& invalid_host_url : invalid_host_urls) { + EXPECT_THROW({ kvikio::detail::UrlParser::parse(invalid_host_url); }, std::runtime_error); + } +}