-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-10487 [FlightRPC][C++] Header-based auth in clients #8724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
c2ba2c7
ab7d6a3
ecb533f
73be3e7
74ef8ea
29a3192
fac0bd0
7fd1279
0b5a08e
6861cf0
01a134b
de78c6b
c44698f
e975fd8
37889fa
e7ac27c
ba7cb9f
1426252
516d993
3000ecb
1de10fa
065af4a
d4da03b
911fcc7
f41edce
6b6fbbe
199b655
47aa581
477d865
1cc3fdb
d27465d
d21006f
6cd8a45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,6 +51,7 @@ | |
|
|
||
| #include "arrow/flight/client_auth.h" | ||
| #include "arrow/flight/client_middleware.h" | ||
| #include "arrow/flight/client_header_auth_middleware.h" | ||
| #include "arrow/flight/internal.h" | ||
| #include "arrow/flight/middleware.h" | ||
| #include "arrow/flight/middleware_internal.h" | ||
|
|
@@ -104,6 +105,9 @@ struct ClientRpc { | |
| std::chrono::system_clock::now() + options.timeout); | ||
| context.set_deadline(deadline); | ||
| } | ||
| for (auto metadata : options.metadata) { | ||
| context.AddMetadata(metadata.first, metadata.second); | ||
| } | ||
| } | ||
|
|
||
| /// \brief Add an auth token via an auth handler | ||
|
|
@@ -328,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory | |
| : public grpc::experimental::ClientInterceptorFactoryInterface { | ||
| public: | ||
| GrpcClientInterceptorAdapterFactory( | ||
| std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware) | ||
| std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware) | ||
|
||
| : middleware_(middleware) {} | ||
|
|
||
| grpc::experimental::Interceptor* CreateClientInterceptor( | ||
|
|
@@ -371,7 +375,7 @@ class GrpcClientInterceptorAdapterFactory | |
| } | ||
|
|
||
| private: | ||
| std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware_; | ||
| std::vector<std::shared_ptr<ClientMiddlewareFactory>>& middleware_; | ||
|
||
| }; | ||
|
|
||
| class GrpcClientAuthSender : public ClientAuthSender { | ||
|
|
@@ -963,8 +967,9 @@ class FlightClient::FlightClientImpl { | |
|
|
||
| std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> | ||
| interceptors; | ||
| middleware = std::move(options.middleware); | ||
|
||
| interceptors.emplace_back( | ||
| new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); | ||
| new GrpcClientInterceptorAdapterFactory(middleware)); | ||
|
|
||
| stub_ = pb::FlightService::NewStub( | ||
| grpc::experimental::CreateCustomChannelWithInterceptors( | ||
|
|
@@ -993,6 +998,30 @@ class FlightClient::FlightClientImpl { | |
| return Status::OK(); | ||
| } | ||
|
|
||
| Status AuthenticateBasicToken(std::string username, std::string password, | ||
|
||
| std::pair<std::string, std::string>* bearer_token) { | ||
| // Add bearer token factory to middleware so it can intercept the bearer token. | ||
| middleware.push_back(std::make_shared<ClientBearerTokenFactory>(bearer_token)); | ||
|
||
| ClientRpc rpc({}); | ||
|
||
| AddBasicAuthHeaders(&rpc.context, username, password); | ||
| std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> | ||
| stream = stub_->Handshake(&rpc.context); | ||
|
|
||
| GrpcClientAuthSender outgoing{stream}; | ||
| GrpcClientAuthReader incoming{stream}; | ||
| // Explicitly close our side of the connection | ||
| bool finished_writes = stream->WritesDone(); | ||
| middleware.pop_back(); | ||
| RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); | ||
| if (!finished_writes) { | ||
| return MakeFlightError(FlightStatusCode::Internal, | ||
| "Could not finish writing before closing"); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
||
|
|
||
|
|
||
| Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, | ||
| std::unique_ptr<FlightListing>* listing) { | ||
| pb::Criteria pb_criteria; | ||
|
|
@@ -1174,6 +1203,7 @@ class FlightClient::FlightClientImpl { | |
| GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> | ||
| noop_auth_check_; | ||
| #endif | ||
| std::vector<std::shared_ptr<arrow::flight::ClientMiddlewareFactory>> middleware; | ||
| int64_t write_size_limit_bytes_; | ||
| }; | ||
|
|
||
|
|
@@ -1197,6 +1227,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, | |
| return impl_->Authenticate(options, std::move(auth_handler)); | ||
| } | ||
|
|
||
| Status FlightClient::AuthenticateBasicToken( | ||
|
||
| std::string username, std::string password, | ||
| std::pair<std::string, std::string>* bearer_token) { | ||
| return impl_->AuthenticateBasicToken(username, password, bearer_token); | ||
| } | ||
|
|
||
| Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, | ||
| std::unique_ptr<ResultStream>* results) { | ||
| return impl_->DoAction(options, action, results); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -65,6 +65,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { | |
|
|
||
| /// \brief IPC writer options, if applicable for the call. | ||
| ipc::IpcWriteOptions write_options; | ||
|
|
||
| /// \brief Metadata for client to add to context. | ||
| std::vector<std::pair<std::string, std::string>> metadata; | ||
|
||
| }; | ||
|
|
||
| /// \brief Indicate that the client attempted to write a message | ||
|
|
@@ -191,6 +194,14 @@ class ARROW_FLIGHT_EXPORT FlightClient { | |
| Status Authenticate(const FlightCallOptions& options, | ||
| std::unique_ptr<ClientAuthHandler> auth_handler); | ||
|
|
||
| /// \brief Authenticate to the server using the given handler. | ||
|
||
| /// \param[in] username Username to use | ||
| /// \param[in] password Password to use | ||
| /// \param[in] bearer_token Bearer token retreived if applicable | ||
| /// \return Status OK if the client authenticated successfully | ||
| Status AuthenticateBasicToken(std::string username, std::string password, | ||
| std::pair<std::string, std::string>* bearer_token); | ||
|
|
||
| /// \brief Perform the indicated action, returning an iterator to the stream | ||
| /// of results, if any | ||
| /// \param[in] options Per-RPC options | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| // Interfaces for defining middleware for Flight clients. Currently | ||
| // experimental. | ||
|
|
||
| #include "client_header_auth_middleware.h" | ||
| #include "client_middleware.h" | ||
| #include "client_auth.h" | ||
| #include "client.h" | ||
|
|
||
| namespace arrow { | ||
| namespace flight { | ||
|
|
||
| std::string base64_encode(const std::string& input); | ||
|
|
||
| ClientBearerTokenMiddleware::ClientBearerTokenMiddleware( | ||
| std::pair<std::string, std::string>* bearer_token_) | ||
| : bearer_token(bearer_token_) { } | ||
|
|
||
| void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } | ||
|
|
||
| void ClientBearerTokenMiddleware::ReceivedHeaders( | ||
| const CallHeaders& incoming_headers) { | ||
| // Grab the auth token if one exists. | ||
| auto bearer_iter = incoming_headers.find(AUTH_HEADER); | ||
|
||
| if (bearer_iter == incoming_headers.end()) { | ||
| return; | ||
| } | ||
|
|
||
| // Check if the value of the auth token starts with the bearer prefix, latch the token. | ||
| std::string bearer_val = bearer_iter->second.to_string(); | ||
| if (bearer_val.size() > BEARER_PREFIX.size()) { | ||
| bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), | ||
|
||
| [] (const char& char1, const char& char2) { | ||
| return (std::toupper(char1) == std::toupper(char2)); | ||
| } | ||
| ); | ||
| if (hasPrefix) { | ||
| *bearer_token = std::make_pair(AUTH_HEADER, bearer_val); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } | ||
|
|
||
| void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) { | ||
|
||
| *middleware = std::unique_ptr<ClientBearerTokenMiddleware>(new ClientBearerTokenMiddleware(bearer_token)); | ||
|
||
| } | ||
|
|
||
| void ClientBearerTokenFactory::Reset() { | ||
| *bearer_token = std::make_pair("", ""); | ||
| } | ||
|
|
||
| template<typename ... Args> | ||
|
||
| std::string string_format(const std::string& format, const Args... args) { | ||
| // Check size requirement for new string and increment by 1 for null terminator. | ||
| size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; | ||
| if(size <= 0){ | ||
|
||
| throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); | ||
|
||
| } | ||
|
|
||
| // Create buffer for new string and write string in. | ||
| std::unique_ptr<char[]> buf(new char[size]); | ||
| std::snprintf(buf.get(), size, format.c_str(), args...); | ||
|
|
||
| // Convert to std::string, subtracting size by 1 to trim null terminator. | ||
| return std::string(buf.get(), buf.get() + size - 1); | ||
| } | ||
|
|
||
| void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { | ||
| const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); | ||
| context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); | ||
| } | ||
|
|
||
| std::string base64_encode(const std::string& input) { | ||
| static const std::string base64_chars = | ||
|
||
| "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; | ||
| auto get_encoded_length = [] (const std::string& in) { | ||
| return 4 * ((in.size() + 2) / 3); | ||
| }; | ||
| auto get_overwrite_count = [] (const std::string& in) { | ||
| const std::string::size_type remainder = in.length() % 3; | ||
| return (remainder > 0) ? (3 - (remainder % 3)) : 0; | ||
| }; | ||
|
|
||
| // Generate string with required length for encoding. | ||
| std::string encoded; | ||
| encoded.reserve(get_encoded_length(input)); | ||
|
|
||
| // Loop through input writing base64 characters to string. | ||
| for (int i = 0; i < input.length();) { | ||
| uint32_t octet_1 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
| uint32_t octet_2 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
| uint32_t octet_3 = i < input.length() ? (unsigned char)input[i++] : 0; | ||
| uint32_t octriple = (octet_1 << 0x10) + (octet_2 << 0x08) + octet_3; | ||
| for (int j = 3; j >= 0; j--) { | ||
| encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); | ||
| } | ||
| } | ||
|
|
||
| // Round up to nearest multiple of 3 and replace characters at end based on rounding. | ||
| int overwrite_count = get_overwrite_count(input); | ||
| encoded.replace(encoded.length() - overwrite_count, | ||
| encoded.length(), | ||
| overwrite_count, '='); | ||
| return encoded; | ||
| } | ||
| } // namespace flight | ||
| } // namespace arrow | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| // Licensed to the Apache Software Foundation (ASF) under one | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably be an _internal.h header since I don't think any of this is intended to be directly used outside of the implementation here. |
||
| // or more contributor license agreements. See the NOTICE file | ||
| // distributed with this work for additional information | ||
| // regarding copyright ownership. The ASF licenses this file | ||
| // to you under the Apache License, Version 2.0 (the | ||
| // "License"); you may not use this file except in compliance | ||
| // with the License. You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, | ||
| // software distributed under the License is distributed on an | ||
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| // KIND, either express or implied. See the License for the | ||
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| // Interfaces for defining middleware for Flight clients. Currently | ||
| // experimental. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "arrow/flight/client_middleware.h" | ||
| #include "arrow/flight/client_auth.h" | ||
| #include "arrow/flight/client.h" | ||
|
|
||
| #ifdef GRPCPP_PP_INCLUDE | ||
| #include <grpcpp/grpcpp.h> | ||
| #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) | ||
| #include <grpcpp/security/tls_credentials_options.h> | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need this include? |
||
| #endif | ||
| #else | ||
| #include <grpc++/grpc++.h> | ||
| #endif | ||
|
|
||
| #include <algorithm> | ||
| #include <iostream> | ||
| #include <cctype> | ||
| #include <string> | ||
|
|
||
| const std::string AUTH_HEADER = "authorization"; | ||
|
||
| const std::string BEARER_PREFIX = "Bearer "; | ||
| const std::string BASIC_PREFIX = "Basic "; | ||
|
|
||
| namespace arrow { | ||
| namespace flight { | ||
|
|
||
| // TODO: Need to add documentation in this file. | ||
| void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, | ||
|
||
| const std::string& username, | ||
| const std::string& password); | ||
|
|
||
| class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { | ||
| public: | ||
| explicit ClientBearerTokenMiddleware( | ||
| std::pair<std::string, std::string>* bearer_token_); | ||
|
|
||
| void SendingHeaders(AddCallHeaders* outgoing_headers); | ||
| void ReceivedHeaders(const CallHeaders& incoming_headers); | ||
| void CallCompleted(const Status& status); | ||
|
|
||
| private: | ||
| std::pair<std::string, std::string>* bearer_token; | ||
| }; | ||
|
|
||
| class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { | ||
| public: | ||
| explicit ClientBearerTokenFactory(std::pair<std::string, std::string>* bearer_token_) | ||
| : bearer_token(bearer_token_) {} | ||
|
|
||
| void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware); | ||
| void Reset(); | ||
|
|
||
| private: | ||
| std::pair<std::string, std::string>* bearer_token; | ||
| }; | ||
| } // namespace flight | ||
| } // namespace arrow | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ordering alphabetical? If so this should be up one...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, was waiting for feedback on design before proper formatting. I will correct this.