diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index f0c23efb69f..7e1c01ea915 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -118,6 +118,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}")
 # protobuf-internal.cc
 set(ARROW_FLIGHT_SRCS
     client.cc
+    client_cookie_middleware.cc
     internal.cc
     protocol_internal.cc
     serialization_internal.cc
diff --git a/cpp/src/arrow/flight/client_cookie_middleware.cc b/cpp/src/arrow/flight/client_cookie_middleware.cc
new file mode 100644
index 00000000000..11b2a9e4146
--- /dev/null
+++ b/cpp/src/arrow/flight/client_cookie_middleware.cc
@@ -0,0 +1,321 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/client_cookie_middleware.h"
+
+// Platform-specific defines
+#include <algorithm>
+#include <chrono>
+#include <map>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+
+#include "arrow/flight/platform.h"
+#include "arrow/util/string.h"
+#include "arrow/util/uri.h"
+#include "arrow/util/value_parsing.h"
+
+namespace {
+#ifdef _WIN32
+#define strcasecmp stricmp
+#endif
+
+// Orders strings with strcasecmp, so keys that differ only in letter case compare
+// equivalent. std::map only needs the call operator, so this does not derive from
+// std::binary_function (deprecated in C++11 and removed in C++17).
+struct CaseInsensitiveComparator {
+  bool operator()(const std::string& lhs, const std::string& rhs) const {
+    return strcasecmp(lhs.c_str(), rhs.c_str()) < 0;
+  }
+};
+
+// Parse a cookie header string beginning at the given start_pos and identify the name and
+// value of an attribute.
+//
+// Segments without an '=' (valueless attributes such as "Secure" or "HttpOnly") are
+// skipped, so they are never folded into the name of the attribute that follows them.
+//
+// @param cookie_header_value The value of the Set-Cookie header.
+// @param start_pos           An input/output parameter indicating the starting position
+//                            of the attribute.
+//                            It will store the position of the next attribute when the
+//                            function returns.
+// @param out_key             The name of the attribute.
+// @param out_value           The value of the attribute.
+//
+// @return true if an attribute is found.
+bool ParseCookieAttribute(const std::string& cookie_header_value,
+                          std::string::size_type& start_pos, std::string& out_key,
+                          std::string& out_value) {
+  std::string::size_type equals_pos = std::string::npos;
+  std::string::size_type semi_col_pos = std::string::npos;
+  while (true) {
+    equals_pos = cookie_header_value.find('=', start_pos);
+    if (std::string::npos == equals_pos) {
+      // No cookie attribute.
+      return false;
+    }
+
+    semi_col_pos = cookie_header_value.find(';', start_pos);
+    if (std::string::npos == semi_col_pos || equals_pos < semi_col_pos) {
+      break;
+    }
+
+    // The segment that ends at this ';' holds no '='. Move past it.
+    start_pos = semi_col_pos + 1;
+  }
+
+  out_key = arrow::internal::TrimString(
+      cookie_header_value.substr(start_pos, equals_pos - start_pos));
+  if (std::string::npos == semi_col_pos) {
+    // Last item - set start pos to end
+    out_value = arrow::internal::TrimString(cookie_header_value.substr(equals_pos + 1));
+    start_pos = std::string::npos;
+  } else {
+    out_value = arrow::internal::TrimString(
+        cookie_header_value.substr(equals_pos + 1, semi_col_pos - equals_pos - 1));
+    start_pos = semi_col_pos + 1;
+  }
+
+  // Key/Value may be URI-encoded.
+  out_key = arrow::internal::UriUnescape(arrow::util::string_view(out_key));
+  out_value = arrow::internal::UriUnescape(arrow::util::string_view(out_value));
+
+  // Strip outer quotes on the value.
+  if (out_value.size() >= 2 && out_value.front() == '"' && out_value.back() == '"') {
+    out_value = out_value.substr(1, out_value.size() - 2);
+  }
+
+  return true;
+}
+
+// Convert seconds since the epoch to a system_clock time_point, first clamping the
+// value to the range the clock can represent. Without the clamp a far-future Expires
+// date can overflow when system_clock::duration is much finer than one second.
+std::chrono::system_clock::time_point TimePointFromEpochSeconds(int64_t seconds) {
+  using Clock = std::chrono::system_clock;
+  const int64_t min_seconds = std::chrono::duration_cast<std::chrono::seconds>(
+                                  Clock::time_point::min().time_since_epoch())
+                                  .count();
+  const int64_t max_seconds = std::chrono::duration_cast<std::chrono::seconds>(
+                                  Clock::time_point::max().time_since_epoch())
+                                  .count();
+  const int64_t clamped = std::min(std::max(seconds, min_seconds), max_seconds);
+  return Clock::from_time_t(static_cast<time_t>(clamped));
+}
+
+struct Cookie {
+  // Build a Cookie from the value of a single Set-Cookie header.
+  //
+  // The first attribute provides the cookie name and value. If the cookie has an
+  // expiration, record it. If the cookie has a max-age, calculate the current time +
+  // max_age and set that as the expiration. A header without any name=value pair yields
+  // a cookie that is already expired.
+  static Cookie parse(const arrow::util::string_view& cookie_header_value) {
+    Cookie cookie;
+    std::string cookie_value_str(cookie_header_value);
+
+    // There should always be a first match which should be the name and value of the
+    // cookie.
+    std::string::size_type pos = 0;
+    if (!ParseCookieAttribute(cookie_value_str, pos, cookie.cookie_name_,
+                              cookie.cookie_value_)) {
+      // No cookie found. Mark the output cookie as expired.
+      cookie.has_expiry_ = true;
+      cookie.expiration_time_ = std::chrono::system_clock::now();
+      return cookie;
+    }
+
+    std::string cookie_attr_name;
+    std::string cookie_attr_value;
+    while (pos < cookie_value_str.size() &&
+           ParseCookieAttribute(cookie_value_str, pos, cookie_attr_name,
+                                cookie_attr_value)) {
+      if (arrow::internal::AsciiEqualsCaseInsensitive(cookie_attr_name, "max-age")) {
+        // Note: max-age takes precedence over expires. We don't really care about other
+        // attributes and will arbitrarily take the first max-age. We can stop the loop
+        // here.
+        cookie.has_expiry_ = true;
+
+        // The value comes from the server. std::stoi throws when the text is not a
+        // number or does not fit in an int; treat such a value as already expired
+        // rather than letting the exception escape through the middleware.
+        int max_age = -1;
+        try {
+          max_age = std::stoi(cookie_attr_value);
+        } catch (const std::logic_error&) {
+          // Keep max_age at -1 so the cookie is expired below.
+        }
+
+        if (max_age <= 0) {
+          // Force expiration.
+          cookie.expiration_time_ = std::chrono::system_clock::now();
+        } else {
+          // Max-age is in seconds.
+          cookie.expiration_time_ =
+              std::chrono::system_clock::now() + std::chrono::seconds(max_age);
+        }
+        break;
+      } else if (arrow::internal::AsciiEqualsCaseInsensitive(cookie_attr_name,
+                                                             "expires")) {
+        cookie.has_expiry_ = true;
+        int64_t seconds = 0;
+        const char* COOKIE_EXPIRES_FORMAT = "%a, %d %b %Y %H:%M:%S GMT";
+        if (arrow::internal::ParseTimestampStrptime(
+                cookie_attr_value.c_str(), cookie_attr_value.size(),
+                COOKIE_EXPIRES_FORMAT, false, true, arrow::TimeUnit::SECOND, &seconds)) {
+          cookie.expiration_time_ = TimePointFromEpochSeconds(seconds);
+        } else {
+          // A date that cannot be parsed forces expiration.
+          cookie.expiration_time_ = std::chrono::system_clock::now();
+        }
+      }
+    }
+
+    return cookie;
+  }
+
+  bool IsExpired() const {
+    // A cookie without an expiry never expires. Otherwise it is expired as soon as its
+    // expiration time is not later than the current time.
+    return has_expiry_ && expiration_time_ <= std::chrono::system_clock::now();
+  }
+
+  std::string AsCookieString() const {
+    // Return the string for the cookie as it would appear in a Cookie header.
+    // The value is always wrapped in double quotes.
+    return cookie_name_ + "=\"" + cookie_value_ + "\"";
+  }
+
+  std::string cookie_name_;
+  std::string cookie_value_;
+  std::chrono::system_clock::time_point expiration_time_;
+  bool has_expiry_ = false;
+};
+}  // end of anonymous namespace
+
+namespace arrow {
+
+namespace flight {
+
+using CookieCache = std::map<std::string, Cookie, CaseInsensitiveComparator>;
+
+class ClientCookieMiddlewareFactory::Impl {
+ public:
+  void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+    ARROW_UNUSED(info);
+    *middleware = std::unique_ptr<ClientMiddleware>(new ClientCookieMiddleware(*this));
+  }
+
+ private:
+  class ClientCookieMiddleware : public ClientMiddleware {
+   public:
+    explicit ClientCookieMiddleware(ClientCookieMiddlewareFactory::Impl& factory_impl)
+        : factory_impl_(factory_impl) {}
+
+    void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+      const std::string cookie_string = factory_impl_.GetValidCookiesAsString();
+      if (!cookie_string.empty()) {
+        outgoing_headers->AddHeader("cookie", cookie_string);
+      }
+    }
+
+    void ReceivedHeaders(const CallHeaders& incoming_headers) override {
+      const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
+          cookie_header_values = incoming_headers.equal_range("set-cookie");
+      factory_impl_.UpdateCachedCookies(cookie_header_values);
+    }
+
+    void CallCompleted(const Status& status) override {}
+
+   private:
+    ClientCookieMiddlewareFactory::Impl& factory_impl_;
+  };
+
+  // Retrieve the cached cookie values as a string.
+  //
+  // @return a string that can be used in a Cookie header representing the cookies that
+  // have been cached.
+  std::string GetValidCookiesAsString() {
+    const std::lock_guard<std::mutex> guard(mutex_);
+
+    DiscardExpiredCookies();
+    if (cookie_cache_.empty()) {
+      return "";
+    }
+
+    std::string cookie_string = cookie_cache_.begin()->second.AsCookieString();
+    for (auto it = (++cookie_cache_.begin()); cookie_cache_.end() != it; ++it) {
+      cookie_string += "; " + it->second.AsCookieString();
+    }
+    return cookie_string;
+  }
+
+  // Updates the cache of cookies with new Set-Cookie header values.
+  //
+  // @param header_values The range representing header values.
+  void UpdateCachedCookies(const std::pair<CallHeaders::const_iterator,
+                                           CallHeaders::const_iterator>& header_values) {
+    const std::lock_guard<std::mutex> guard(mutex_);
+
+    for (auto it = header_values.first; it != header_values.second; ++it) {
+      const util::string_view& value = it->second;
+      Cookie cookie = Cookie::parse(value);
+
+      // Cache cookies regardless of whether or not they are expired. The server may have
+      // explicitly sent a Set-Cookie to expire a cached cookie.
+      std::pair<CookieCache::iterator, bool> insertable =
+          cookie_cache_.insert({cookie.cookie_name_, cookie});
+
+      // Force overwrite on insert collision.
+      if (!insertable.second) {
+        insertable.first->second = cookie;
+      }
+    }
+
+    DiscardExpiredCookies();
+  }
+
+  // Removes cookies that are marked as expired from the cache.
+  void DiscardExpiredCookies() {
+    for (auto it = cookie_cache_.begin(); it != cookie_cache_.end();) {
+      if (it->second.IsExpired()) {
+        it = cookie_cache_.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+
+  // The cached cookies. Access should be protected by mutex_.
+  CookieCache cookie_cache_;
+  std::mutex mutex_;
+};
+
+ClientCookieMiddlewareFactory::ClientCookieMiddlewareFactory()
+    : impl_(new ClientCookieMiddlewareFactory::Impl()) {}
+
+ClientCookieMiddlewareFactory::~ClientCookieMiddlewareFactory() {}
+
+void ClientCookieMiddlewareFactory::StartCall(
+    const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
+  impl_->StartCall(info, middleware);
+}
+
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/client_cookie_middleware.h b/cpp/src/arrow/flight/client_cookie_middleware.h
new file mode 100644
index 00000000000..5e3386bedbc
--- /dev/null
+++ b/cpp/src/arrow/flight/client_cookie_middleware.h
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Middleware implementation for sending and receiving HTTP cookies.
+
+#pragma once
+
+#include <memory>
+
+#include "arrow/flight/client_middleware.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Client-side middleware for sending/receiving HTTP cookies.
+class ARROW_FLIGHT_EXPORT ClientCookieMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+  ClientCookieMiddlewareFactory();
+
+  ~ClientCookieMiddlewareFactory();
+
+  void StartCall(const CallInfo& info,
+                 std::unique_ptr<ClientMiddleware>* middleware) override;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace flight
+}  // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 8d7e944bcdd..5251cf55797 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -38,11 +38,13 @@
 #include "arrow/testing/util.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
+#include "arrow/util/string.h"
 
 #ifdef GRPCPP_GRPCPP_H
 #error "gRPC headers should not be in public API"
 #endif
 
+#include "arrow/flight/client_cookie_middleware.h"
 #include "arrow/flight/internal.h"
 #include "arrow/flight/middleware_internal.h"
 #include "arrow/flight/test_util.h"
@@ -1010,6 +1012,158 @@ class TestErrorMiddleware : public ::testing::Test {
   std::unique_ptr<FlightServerBase> server_;
 };
 
+// This test has functions that allow adding and removing cookies from a list of cookies.
+// It also automatically adds cookies to an internal list of tracked cookies when they
+// are passed to the middleware.
+class TestCookieMiddleware : public ::testing::Test {
+ public:
+  // Setup function creates middleware factory and starts it up.
+  void SetUp() override {
+    factory_ = std::make_shared<ClientCookieMiddlewareFactory>();
+    CallInfo callInfo;
+    factory_->StartCall(callInfo, &middleware_);
+  }
+
+  // Function to add cookie to middleware.
+  void AddCookie(const std::string& cookie) {
+    CallHeaders call_headers;
+    call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
+                                       arrow::util::string_view(cookie)));
+    middleware_->ReceivedHeaders(call_headers);
+  }
+
+  // Function to get cookies from middleware.
+  std::string GetCookies() {
+    TestCallHeaders add_call_headers;
+    middleware_->SendingHeaders(&add_call_headers);
+    return add_call_headers.GetCookies();
+  }
+
+  // Function to take a list of cookies and split them into a vector of individual
+  // cookies.
+  std::vector<std::string> SplitCookies(const std::string& cookies) {
+    std::vector<std::string> split_cookies;
+    std::string::size_type pos1 = 0;
+    std::string::size_type pos2 = 0;
+    while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
+      split_cookies.push_back(
+          arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
+      pos1 = pos2 + 1;
+    }
+    if (pos1 < cookies.size()) {
+      split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
+    }
+    std::sort(split_cookies.begin(), split_cookies.end());
+    return split_cookies;
+  }
+
+  // Function to check if expected cookies and actual cookies are equal.
+  void CheckEqual() {
+    std::sort(expected_cookies.begin(), expected_cookies.end());
+    std::vector<std::string> split_actual_cookies = SplitCookies(GetCookies());
+    EXPECT_EQ(expected_cookies, split_actual_cookies);
+  }
+
+  // Function to add incoming cookies to middleware and validate them.
+  void AddAndValidate(const std::string& incoming_cookie) {
+    AddCookie(incoming_cookie);
+    AddExpected(incoming_cookie);
+    CheckEqual();
+  }
+
+  // Function to add an expired cookie to middleware and validate that it was removed.
+  void AddExpiredAndValidate(const std::string& incoming_cookie) {
+    AddCookie(incoming_cookie);
+    CheckEqual();
+  }
+
+  // Function to convert format so that it is key="val" and not key=val.
+  void ConvertFormat(const std::string& incoming_cookie, std::string& outgoing_cookie) {
+    // Grab the first key value pair from the incoming cookie string.
+    std::string::size_type pos = incoming_cookie.find(';');
+    const std::string key_val_pair = arrow::internal::TrimString(
+        (pos != std::string::npos) ? incoming_cookie.substr(0, pos) : incoming_cookie);
+
+    // Split the key value pair into the key and value.
+    pos = key_val_pair.find('=');
+    ASSERT_NE(std::string::npos, pos);
+    std::string key = key_val_pair.substr(0, pos);
+    std::string val = key_val_pair.substr(pos + 1);
+
+    // Add key/value pair to outgoing cookie, need to add double quotes around the value
+    // if there isn't any there already. The size check keeps an empty value from being
+    // indexed out of bounds.
+    const bool quoted = val.size() >= 2 && val.front() == '\"' && val.back() == '\"';
+    outgoing_cookie = key + "=" + (quoted ? val : "\"" + val + "\"");
+  }
+
+  // Function to find cookie matching a given key.
+  std::vector<std::string>::iterator FindCookieByKey(const std::string& key) {
+    auto collision = [key](const std::string& str1) {
+      const std::string::size_type pos = str1.find('=');
+      const std::string old_key = str1.substr(0, pos);
+      return key == old_key;
+    };
+    return std::find_if(expected_cookies.begin(), expected_cookies.end(), collision);
+  }
+
+  // Function to add new cookie or replace existing cookie.
+  void AddOrReplace(const std::string& new_cookie) {
+    // Get position of '=' to get the key.
+    const std::string::size_type pos = new_cookie.find('=');
+    ASSERT_NE(std::string::npos, pos);
+
+    // Get iterator of vector using key. If it exists, replace it, else insert normally.
+    std::vector<std::string>::iterator it = FindCookieByKey(new_cookie.substr(0, pos));
+    if (it != expected_cookies.end()) {
+      *it = new_cookie;
+    } else {
+      expected_cookies.push_back(new_cookie);
+    }
+  }
+
+  // Function to add a cookie to the list of expected cookies.
+  void AddExpected(const std::string& incoming_cookie) {
+    std::string new_cookie;
+    ConvertFormat(incoming_cookie, new_cookie);
+    AddOrReplace(new_cookie);
+  }
+
+  // Function to set a cookie as expired.
+  void SetExpired(const std::string& cookie) {
+    std::string::size_type pos = cookie.find('=');
+    ASSERT_NE(std::string::npos, pos);
+    std::vector<std::string>::iterator it = FindCookieByKey(cookie.substr(0, pos));
+    if (it != expected_cookies.end()) {
+      expected_cookies.erase(it);
+    }
+  }
+
+ protected:
+  // Class to allow testing of the call headers.
+  class TestCallHeaders : public AddCallHeaders {
+   public:
+    TestCallHeaders() {}
+    ~TestCallHeaders() {}
+
+    // Function to add cookie header.
+    void AddHeader(const std::string& key, const std::string& value) override {
+      ASSERT_EQ(key, "cookie");
+      outbound_cookie = value;
+    }
+
+    // Function to get outgoing cookie.
+    std::string GetCookies() { return outbound_cookie; }
+
+   private:
+    std::string outbound_cookie;
+  };
+
+  std::vector<std::string> expected_cookies;
+  std::unique_ptr<ClientMiddleware> middleware_;
+  std::shared_ptr<ClientCookieMiddlewareFactory> factory_;
+};
+
 TEST_F(TestErrorMiddleware, TestMetadata) {
   Action action;
   std::unique_ptr<ResultStream> stream;
@@ -2189,5 +2343,51 @@ TEST_F(TestPropagatingMiddleware, DoPut) {
   ValidateStatus(status, FlightMethod::DoPut);
 }
 
+TEST_F(TestCookieMiddleware, BasicParsing) {
+  AddAndValidate("id1=1; foo=bar;");
+  AddAndValidate("id1=1; foo=bar");
+  AddAndValidate("id2=2;");
+  AddAndValidate("id4=\"4\"");
+  AddAndValidate("id5=5; foo=bar; baz=buz;");
+}
+
+TEST_F(TestCookieMiddleware, Overwrite) {
+  AddAndValidate("id0=0");
+  AddAndValidate("id0=1");
+  AddAndValidate("id1=0");
+  AddAndValidate("id1=1");
+  AddAndValidate("id1=1");
+  AddAndValidate("id1=10");
+  AddAndValidate("id=3");
+  AddAndValidate("id=0");
+  AddAndValidate("id=0");
+}
+
+TEST_F(TestCookieMiddleware, MaxAge) {
+  AddExpiredAndValidate("id0=0; max-age=0;");
+  AddExpiredAndValidate("id1=0; max-age=-1;");
+  AddExpiredAndValidate("id2=0; max-age=0");
+  AddExpiredAndValidate("id3=0; max-age=-1");
+  AddAndValidate("id4=0; max-age=1");
+  AddAndValidate("id5=0; max-age=1");
+  SetExpired("id4=0");
+  AddExpiredAndValidate("id4=0; max-age=0");
+  SetExpired("id5=0");
+  AddExpiredAndValidate("id5=0; max-age=0");
+}
+
+TEST_F(TestCookieMiddleware, Expires) {
+  AddExpiredAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;");
+  AddExpiredAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT");
+  AddExpiredAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+  AddExpiredAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+  AddAndValidate("id0=0; expires=Fri, 22 Dec 5000 22:15:36 GMT;");
+  AddAndValidate("id1=0; expires=Fri, 22 Dec 5000 22:15:36 GMT");
+  SetExpired("id0=0");
+  AddExpiredAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
+  SetExpired("id1=0");
+  AddExpiredAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
+}
+
 }  // namespace flight
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc
index f48a8cb6ed1..35c6b898177 100644
--- a/cpp/src/arrow/util/uri.cc
+++ b/cpp/src/arrow/util/uri.cc
@@ -55,15 +55,6 @@ bool IsDriveSpec(const util::string_view s) {
 }
 #endif
 
-std::string UriUnescape(const util::string_view s) {
-  std::string result(s);
-  if (!result.empty()) {
-    auto end = uriUnescapeInPlaceA(&result[0]);
-    result.resize(end - &result[0]);
-  }
-  return result;
-}
-
 }  // namespace
 
 std::string UriEscape(const std::string& s) {
@@ -80,6 +71,15 @@ std::string UriEscape(const std::string& s) {
   return escaped;
 }
 
+std::string UriUnescape(const util::string_view s) {
+  std::string result(s);
+  if (!result.empty()) {
+    auto end = uriUnescapeInPlaceA(&result[0]);
+    result.resize(end - &result[0]);
+  }
+  return result;
+}
+
 std::string UriEncodeHost(const std::string& host) {
   // Fairly naive check: if it contains a ':', it's IPv6 and needs
   // brackets, else it's OK
diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h
index 480e35474a3..b4ffbb04dec 100644
--- a/cpp/src/arrow/util/uri.h
+++ b/cpp/src/arrow/util/uri.h
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "arrow/type_fwd.h"
+#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -91,6 +92,9 @@ class ARROW_EXPORT Uri {
 ARROW_EXPORT
 std::string UriEscape(const std::string& s);
 
+ARROW_EXPORT
+std::string UriUnescape(const arrow::util::string_view s);
+
 /// Encode a host for use within a URI, such as "localhost",
 /// "127.0.0.1", or "[::1]".
 ARROW_EXPORT