Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ if(ARROW_TESTING)
OUTPUTS
ARROW_FLIGHT_TESTING_LIBRARIES
SOURCES
test_auth_handlers.cc
test_definitions.cc
test_flight_server.cc
test_util.cc
DEPENDENCIES
flight_grpc_gen
Expand Down
8 changes: 5 additions & 3 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
// Include before test_util.h (boost), contains Windows fixes
#include "arrow/flight/platform.h"
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_definitions.h"
#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
// OTel includes must come after any gRPC includes, and
// client_header_internal.h includes gRPC. See:
Expand Down Expand Up @@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) {

// CI environments don't have an IPv6 interface configured
TEST(TestFlight, DISABLED_IpV6Port) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0));
FlightServerOptions options(location);
Expand All @@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) {
}

TEST(TestFlight, ServerCallContextIncomingHeaders) {
auto server = ExampleTestServer();
auto server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand Down Expand Up @@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) {
class TestFlightClient : public ::testing::Test {
public:
void SetUp() {
server_ = ExampleTestServer();
server_ = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/flight/sql/server.h"
#include "arrow/flight/sql/server_session_middleware.h"
#include "arrow/flight/sql/types.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/dictionary.h"
Expand Down
141 changes: 141 additions & 0 deletions cpp/src/arrow/flight/test_auth_handlers.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <string>

#include "arrow/flight/client_auth.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/test_auth_handlers.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"

namespace arrow::flight {

// TestServerAuthHandler

TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
const std::string& password)
: username_(username), password_(password) {}

TestServerAuthHandler::~TestServerAuthHandler() {}

Status TestServerAuthHandler::Authenticate(const ServerCallContext& context,
ServerAuthSender* outgoing,
ServerAuthReader* incoming) {
std::string token;
RETURN_NOT_OK(incoming->Read(&token));
if (token != password_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
RETURN_NOT_OK(outgoing->Write(username_));
return Status::OK();
}

Status TestServerAuthHandler::IsValid(const ServerCallContext& context,
const std::string& token,
std::string* peer_identity) {
if (token != password_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
*peer_identity = username_;
return Status::OK();
}

// TestServerBasicAuthHandler

TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
const std::string& password) {
basic_auth_.username = username;
basic_auth_.password = password;
}

TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}

Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context,
ServerAuthSender* outgoing,
ServerAuthReader* incoming) {
std::string token;
RETURN_NOT_OK(incoming->Read(&token));
ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token));
if (incoming_auth.username != basic_auth_.username ||
incoming_auth.password != basic_auth_.password) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
return Status::OK();
}

Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context,
const std::string& token,
std::string* peer_identity) {
if (token != basic_auth_.username) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
*peer_identity = basic_auth_.username;
return Status::OK();
}

// TestClientAuthHandler

TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
const std::string& password)
: username_(username), password_(password) {}

TestClientAuthHandler::~TestClientAuthHandler() {}

Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
ClientAuthReader* incoming) {
RETURN_NOT_OK(outgoing->Write(password_));
std::string username;
RETURN_NOT_OK(incoming->Read(&username));
if (username != username_) {
return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
}
return Status::OK();
}

Status TestClientAuthHandler::GetToken(std::string* token) {
*token = password_;
return Status::OK();
}

// TestClientBasicAuthHandler

TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
const std::string& password) {
basic_auth_.username = username;
basic_auth_.password = password;
}

TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}

Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
ClientAuthReader* incoming) {
ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString());
RETURN_NOT_OK(outgoing->Write(pb_result));
RETURN_NOT_OK(incoming->Read(&token_));
return Status::OK();
}

Status TestClientBasicAuthHandler::GetToken(std::string* token) {
*token = token_;
return Status::OK();
}

} // namespace arrow::flight
89 changes: 89 additions & 0 deletions cpp/src/arrow/flight/test_auth_handlers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include <string>

#include "arrow/flight/client_auth.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/types.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"

// A pair of authentication handlers that check for a predefined password
// and set the peer identity to a predefined username.

namespace arrow::flight {

class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler {
public:
explicit TestServerAuthHandler(const std::string& username,
const std::string& password);
~TestServerAuthHandler() override;
Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
ServerAuthReader* incoming) override;
Status IsValid(const ServerCallContext& context, const std::string& token,
std::string* peer_identity) override;

private:
std::string username_;
std::string password_;
};

class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler {
public:
explicit TestServerBasicAuthHandler(const std::string& username,
const std::string& password);
~TestServerBasicAuthHandler() override;
Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing,
ServerAuthReader* incoming) override;
Status IsValid(const ServerCallContext& context, const std::string& token,
std::string* peer_identity) override;

private:
BasicAuth basic_auth_;
};

class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler {
public:
explicit TestClientAuthHandler(const std::string& username,
const std::string& password);
~TestClientAuthHandler() override;
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
Status GetToken(std::string* token) override;

private:
std::string username_;
std::string password_;
};

class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler {
public:
explicit TestClientBasicAuthHandler(const std::string& username,
const std::string& password);
~TestClientBasicAuthHandler() override;
Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override;
Status GetToken(std::string* token) override;

private:
BasicAuth basic_auth_;
std::string token_;
};

} // namespace arrow::flight
15 changes: 8 additions & 7 deletions cpp/src/arrow/flight/test_definitions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "arrow/array/util.h"
#include "arrow/flight/api.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/test_flight_server.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/flight/types_async.h"
Expand All @@ -53,15 +54,15 @@ using arrow::internal::checked_cast;
// Tests of initialization/shutdown

void ConnectivityTest::TestGetPort() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
ASSERT_GT(server->port(), 0);
}
void ConnectivityTest::TestBuilderHook() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand All @@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() {
constexpr int kIterations = 10;
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
for (int i = 0; i < kIterations; i++) {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand All @@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() {
}
}
void ConnectivityTest::TestShutdownWithDeadline() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand All @@ -105,7 +106,7 @@ void ConnectivityTest::TestShutdownWithDeadline() {
ASSERT_OK(server->Wait());
}
void ConnectivityTest::TestBrokenConnection() {
std::unique_ptr<FlightServerBase> server = ExampleTestServer();
std::unique_ptr<FlightServerBase> server = TestFlightServer::Make();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));
Expand Down Expand Up @@ -151,7 +152,7 @@ class GetFlightInfoListener : public AsyncListener<FlightInfo> {
} // namespace

void DataTest::SetUpTest() {
server_ = ExampleTestServer();
server_ = TestFlightServer::Make();

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));
FlightServerOptions options(location);
Expand Down Expand Up @@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() {

ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0));

server_ = ExampleTestServer();
server_ = TestFlightServer::Make();
FlightServerOptions server_options(location);
ASSERT_OK(server_->Init(server_options));

Expand Down
Loading