diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 43ac48b8767..98f93705f6f 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -262,7 +262,9 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES + test_auth_handlers.cc test_definitions.cc + test_flight_server.cc test_util.cc DEPENDENCIES flight_grpc_gen diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 101bb06b212..3d52bc3f5ae 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -52,7 +52,9 @@ // Include before test_util.h (boost), contains Windows fixes #include "arrow/flight/platform.h" #include "arrow/flight/serialization_internal.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_definitions.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" // OTel includes must come after any gRPC includes, and // client_header_internal.h includes gRPC. See: @@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) { // CI environments don't have an IPv6 interface configured TEST(TestFlight, DISABLED_IpV6Port) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0)); FlightServerOptions options(location); @@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) { } TEST(TestFlight, ServerCallContextIncomingHeaders) { - auto server = ExampleTestServer(); + auto server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) { class TestFlightClient : public ::testing::Test { public: void SetUp() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 665c1f1ba03..da6fcf81eb7 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -36,6 +36,7 @@ #include "arrow/flight/sql/server.h" #include "arrow/flight/sql/server_session_middleware.h" #include "arrow/flight/sql/types.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" diff --git a/cpp/src/arrow/flight/test_auth_handlers.cc b/cpp/src/arrow/flight/test_auth_handlers.cc new file mode 100644 index 00000000000..856ccf0f2b2 --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.cc @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/test_auth_handlers.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +// TestServerAuthHandler + +TestServerAuthHandler::TestServerAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestServerAuthHandler::~TestServerAuthHandler() {} + +Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + RETURN_NOT_OK(outgoing->Write(username_)); + return Status::OK(); +} + +Status TestServerAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = username_; + return Status::OK(); +} + +// TestServerBasicAuthHandler + +TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} + +Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); + if (incoming_auth.username != basic_auth_.username || + incoming_auth.password != basic_auth_.password) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); + return Status::OK(); +} + +Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != basic_auth_.username) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = basic_auth_.username; + return Status::OK(); +} + +// TestClientAuthHandler + +TestClientAuthHandler::TestClientAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestClientAuthHandler::~TestClientAuthHandler() {} + +Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + RETURN_NOT_OK(outgoing->Write(password_)); + std::string username; + RETURN_NOT_OK(incoming->Read(&username)); + if (username != username_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + return Status::OK(); +} + +Status TestClientAuthHandler::GetToken(std::string* token) { + *token = password_; + return Status::OK(); +} + +// TestClientBasicAuthHandler + +TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} + +Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString()); + RETURN_NOT_OK(outgoing->Write(pb_result)); + RETURN_NOT_OK(incoming->Read(&token_)); + return Status::OK(); +} + +Status TestClientBasicAuthHandler::GetToken(std::string* token) { + *token = token_; + return Status::OK(); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_auth_handlers.h b/cpp/src/arrow/flight/test_auth_handlers.h new file mode 100644 index 00000000000..74f48798f3b --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.h @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +// A pair of authentication handlers that check for a predefined password +// and set the peer identity to a predefined username. + +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { + public: + explicit TestServerAuthHandler(const std::string& username, + const std::string& password); + ~TestServerAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { + public: + explicit TestServerBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestServerBasicAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + BasicAuth basic_auth_; +}; + +class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { + public: + explicit TestClientAuthHandler(const std::string& username, + const std::string& password); + ~TestClientAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status GetToken(std::string* token) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { + public: + explicit TestClientBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestClientBasicAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status GetToken(std::string* token) override; + + private: + BasicAuth basic_auth_; + std::string token_; +}; + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index c43b693d84a..273d394c288 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -27,6 +27,7 @@ #include "arrow/array/util.h" #include "arrow/flight/api.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/flight/types_async.h" @@ -53,7 +54,7 @@ using arrow::internal::checked_cast; // Tests of initialization/shutdown void ConnectivityTest::TestGetPort() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -61,7 +62,7 @@ void ConnectivityTest::TestGetPort() { ASSERT_GT(server->port(), 0); } void ConnectivityTest::TestBuilderHook() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() { constexpr int kIterations = 10; ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); for (int i = 0; i < kIterations; i++) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() { } } void ConnectivityTest::TestShutdownWithDeadline() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -105,7 +106,7 @@ void ConnectivityTest::TestShutdownWithDeadline() { ASSERT_OK(server->Wait()); } void ConnectivityTest::TestBrokenConnection() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -151,7 +152,7 @@ class GetFlightInfoListener : public AsyncListener { } // namespace void DataTest::SetUpTest() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() { ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); FlightServerOptions server_options(location); ASSERT_OK(server_->Init(server_options)); diff --git a/cpp/src/arrow/flight/test_flight_server.cc b/cpp/src/arrow/flight/test_flight_server.cc new file mode 100644 index 00000000000..0ea95ebd15b --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.cc @@ -0,0 +1,417 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/flight/test_flight_server.h" + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/flight/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/status.h" + +namespace arrow::flight { +namespace { + +class ErrorRecordBatchReader : public RecordBatchReader { + public: + ErrorRecordBatchReader() : schema_(arrow::schema({})) {} + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + *out = nullptr; + return Status::OK(); + } + + Status Close() override { + // This should be propagated over DoGet to the client + return Status::IOError("Expected error"); + } + + private: + std::shared_ptr schema_; +}; + +Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { + if (ticket.ticket == "ticket-ints-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-floats-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleFloatBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-dicts-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleDictBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-large-batch-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleLargeBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else { + return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); + } +} + +} // namespace + +std::unique_ptr TestFlightServer::Make() { + return std::make_unique(); +} + +Status TestFlightServer::ListFlights(const ServerCallContext& context, + const Criteria* criteria, + std::unique_ptr* listings) { + std::vector flights = ExampleFlightInfo(); + if (criteria && criteria->expression != "") { + // For test purposes, if we get criteria, return no results + flights.clear(); + } + *listings = std::make_unique(flights); + return Status::OK(); +} + +Status TestFlightServer::GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* out) { + // Test that Arrow-C++ status codes make it through the transport + if (request.type == FlightDescriptor::DescriptorType::CMD && + request.cmd == "status-outofmemory") { + return Status::OutOfMemory("Sentinel"); + } + + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *out = std::make_unique(info); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +Status TestFlightServer::DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) { + // Test for ARROW-5095 + if (request.ticket == "ARROW-5095-fail") { + return Status::UnknownError("Server-side error"); + } + if (request.ticket == "ARROW-5095-success") { + return Status::OK(); + } + if (request.ticket == "ARROW-13253-DoGet-Batch") { + // Make batch > 2GiB in size + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + if (request.ticket == "ticket-stream-error") { + auto reader = std::make_shared(); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + + std::shared_ptr batch_reader; + RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); + + *data_stream = std::make_unique(batch_reader); + return Status::OK(); +} + +Status TestFlightServer::DoPut(const ServerCallContext&, + std::unique_ptr reader, + std::unique_ptr writer) { + return reader->ToRecordBatches().status(); +} + +Status TestFlightServer::DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { + // Test various scenarios for a DoExchange + if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { + return Status::Invalid("Must provide a command descriptor"); + } + + const std::string& cmd = reader->descriptor().cmd; + if (cmd == "error") { + // Immediately return an error to the client. + return Status::NotImplemented("Expected error"); + } else if (cmd == "get") { + return RunExchangeGet(std::move(reader), std::move(writer)); + } else if (cmd == "put") { + return RunExchangePut(std::move(reader), std::move(writer)); + } else if (cmd == "counter") { + return RunExchangeCounter(std::move(reader), std::move(writer)); + } else if (cmd == "total") { + return RunExchangeTotal(std::move(reader), std::move(writer)); + } else if (cmd == "echo") { + return RunExchangeEcho(std::move(reader), std::move(writer)); + } else if (cmd == "large_batch") { + return RunExchangeLargeBatch(std::move(reader), std::move(writer)); + } else if (cmd == "TestUndrained") { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + return Status::OK(); + } else { + return Status::NotImplemented("Scenario not implemented: ", cmd); + } +} + +// A simple example - act like DoGet. +Status TestFlightServer::RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer) { + RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + return Status::OK(); +} + +// A simple example - act like DoPut +Status TestFlightServer::RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + if (!schema->Equals(ExampleIntSchema(), false)) { + return Status::Invalid("Schema is not as expected"); + } + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + FlightStreamChunk chunk; + for (const auto& batch : batches) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data) { + return Status::Invalid("Expected another batch"); + } + if (!batch->Equals(*chunk.data)) { + return Status::Invalid("Batch does not match"); + } + } + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (chunk.data || chunk.app_metadata) { + return Status::Invalid("Too many batches"); + } + + RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); + return Status::OK(); +} + +// Read some number of record batches from the client, send a +// metadata message back with the count, then echo the batches back. +Status TestFlightServer::RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer) { + std::vector> batches; + FlightStreamChunk chunk; + int chunks = 0; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + batches.push_back(chunk.data); + chunks++; + } + } + + // Echo back the number of record batches read. + std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); + RETURN_NOT_OK(writer->WriteMetadata(buf)); + // Echo the record batches themselves. + if (chunks > 0) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + RETURN_NOT_OK(writer->Begin(schema)); + + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + } + + return Status::OK(); +} + +// Read int64 batches from the client, each time sending back a +// batch with a running sum of columns. +Status TestFlightServer::RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk{}; + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + // Ensure the schema contains only int64 columns + for (const auto& field : schema->fields()) { + if (field->type()->id() != Type::type::INT64) { + return Status::Invalid("Field is not INT64: ", field->name()); + } + } + std::vector sums(schema->num_fields()); + std::vector> columns(schema->num_fields()); + RETURN_NOT_OK(writer->Begin(schema)); + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + if (!chunk.data->schema()->Equals(schema, false)) { + // A compliant client implementation would make this impossible + return Status::Invalid("Schemas are incompatible"); + } + + // Update the running totals + auto builder = std::make_shared(); + int col_index = 0; + for (const auto& column : chunk.data->columns()) { + auto arr = std::dynamic_pointer_cast(column); + if (!arr) { + return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); + } + for (int row = 0; row < column->length(); row++) { + if (!arr->IsNull(row)) { + sums[col_index] += arr->Value(row); + } + } + + builder->Reset(); + RETURN_NOT_OK(builder->Append(sums[col_index])); + RETURN_NOT_OK(builder->Finish(&columns[col_index])); + + col_index++; + } + + // Echo the totals to the client + auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); + RETURN_NOT_OK(writer->WriteRecordBatch(*response)); + } + } + return Status::OK(); +} + +// Echo the client's messages back. +Status TestFlightServer::RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk; + bool begun = false; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (!begun && chunk.data) { + begun = true; + RETURN_NOT_OK(writer->Begin(chunk.data->schema())); + } + if (chunk.data && chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); + } else if (chunk.data) { + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); + } else if (chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); + } + } + return Status::OK(); +} + +// Regression test for ARROW-13253 +Status TestFlightServer::RunExchangeLargeBatch( + std::unique_ptr, std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + RETURN_NOT_OK(writer->Begin(batch->schema())); + return writer->WriteRecordBatch(*batch); +} + +Status TestFlightServer::RunAction1(const Action& action, + std::unique_ptr* out) { + std::vector results; + for (int i = 0; i < 3; ++i) { + Result result; + std::string value = action.body->ToString() + "-part" + std::to_string(i); + result.body = Buffer::FromString(std::move(value)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::RunAction2(std::unique_ptr* out) { + // Empty + *out = std::make_unique(std::vector{}); + return Status::OK(); +} + +Status TestFlightServer::ListIncomingHeaders(const ServerCallContext& context, + const Action& action, + std::unique_ptr* out) { + std::vector results; + std::string_view prefix(*action.body); + for (const auto& header : context.incoming_headers()) { + if (header.first.substr(0, prefix.size()) != prefix) { + continue; + } + Result result; + result.body = + Buffer::FromString(std::string(header.first) + ": " + std::string(header.second)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) { + if (action.type == "action1") { + return RunAction1(action, out); + } else if (action.type == "action2") { + return RunAction2(out); + } else if (action.type == "list-incoming-headers") { + return ListIncomingHeaders(context, action, out); + } else { + return Status::NotImplemented(action.type); + } +} + +Status TestFlightServer::ListActions(const ServerCallContext& context, + std::vector* out) { + std::vector actions = ExampleActionTypes(); + *out = std::move(actions); + return Status::OK(); +} + +Status TestFlightServer::GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema) { + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *schema = std::make_unique(info.serialized_schema()); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_flight_server.h b/cpp/src/arrow/flight/test_flight_server.h new file mode 100644 index 00000000000..794dd834c01 --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestFlightServer : public FlightServerBase { + public: + static std::unique_ptr Make(); + + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override; + + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* out) override; + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override; + + Status DoPut(const ServerCallContext&, std::unique_ptr reader, + std::unique_ptr writer) override; + + Status DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override; + + // A simple example - act like DoGet. + Status RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer); + + // A simple example - act like DoPut + Status RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer); + + // Read some number of record batches from the client, send a + // metadata message back with the count, then echo the batches back. + Status RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer); + + // Read int64 batches from the client, each time sending back a + // batch with a running sum of columns. + Status RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer); + + // Echo the client's messages back. + Status RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer); + + // Regression test for ARROW-13253 + Status RunExchangeLargeBatch(std::unique_ptr, + std::unique_ptr writer); + + Status RunAction1(const Action& action, std::unique_ptr* out); + + Status RunAction2(std::unique_ptr* out); + + Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, + std::unique_ptr* out); + + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) override; + + Status ListActions(const ServerCallContext& context, + std::vector* out) override; + + Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* schema) override; +}; + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_server.cc b/cpp/src/arrow/flight/test_server.cc index 18bf2b41359..ba84b8f532e 100644 --- a/cpp/src/arrow/flight/test_server.cc +++ b/cpp/src/arrow/flight/test_server.cc @@ -26,6 +26,7 @@ #include #include "arrow/flight/server.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/util/logging.h" @@ -38,7 +39,7 @@ std::unique_ptr g_server; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - g_server = arrow::flight::ExampleTestServer(); + g_server = arrow::flight::TestFlightServer::Make(); arrow::flight::Location location; if (FLAGS_unix.empty()) { diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 8b4245e74e8..127827ff38c 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -49,8 +49,7 @@ #include "arrow/flight/api.h" #include "arrow/flight/serialization_internal.h" -namespace arrow { -namespace flight { +namespace arrow::flight { namespace bp = boost::process; namespace fs = boost::filesystem; @@ -90,25 +89,6 @@ Status ResolveCurrentExecutable(fs::path* out) { } } -class ErrorRecordBatchReader : public RecordBatchReader { - public: - ErrorRecordBatchReader() : schema_(arrow::schema({})) {} - - std::shared_ptr schema() const override { return schema_; } - - Status ReadNext(std::shared_ptr* out) override { - *out = nullptr; - return Status::OK(); - } - - Status Close() override { - // This should be propagated over DoGet to the client - return Status::IOError("Expected error"); - } - - private: - std::shared_ptr schema_; -}; } // namespace void TestServer::Start(const std::vector& extra_args) { @@ -171,364 +151,6 @@ int TestServer::port() const { return port_; } const std::string& TestServer::unix_sock() const { return unix_sock_; } -Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { - if (ticket.ticket == "ticket-ints-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-floats-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleFloatBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-dicts-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleDictBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-large-batch-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleLargeBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else { - return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); - } -} - -class FlightTestServer : public FlightServerBase { - Status ListFlights(const ServerCallContext& context, const Criteria* criteria, - std::unique_ptr* listings) override { - std::vector flights = ExampleFlightInfo(); - if (criteria && criteria->expression != "") { - // For test purposes, if we get criteria, return no results - flights.clear(); - } - *listings = std::make_unique(flights); - return Status::OK(); - } - - Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* out) override { - // Test that Arrow-C++ status codes make it through the transport - if (request.type == FlightDescriptor::DescriptorType::CMD && - request.cmd == "status-outofmemory") { - return Status::OutOfMemory("Sentinel"); - } - - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *out = std::make_unique(info); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } - - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - // Test for ARROW-5095 - if (request.ticket == "ARROW-5095-fail") { - return Status::UnknownError("Server-side error"); - } - if (request.ticket == "ARROW-5095-success") { - return Status::OK(); - } - if (request.ticket == "ARROW-13253-DoGet-Batch") { - // Make batch > 2GiB in size - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - if (request.ticket == "ticket-stream-error") { - auto reader = std::make_shared(); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - - std::shared_ptr batch_reader; - RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); - - *data_stream = std::make_unique(batch_reader); - return Status::OK(); - } - - Status DoPut(const ServerCallContext&, std::unique_ptr reader, - std::unique_ptr writer) override { - return reader->ToRecordBatches().status(); - } - - Status DoExchange(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - // Test various scenarios for a DoExchange - if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { - return Status::Invalid("Must provide a command descriptor"); - } - - const std::string& cmd = reader->descriptor().cmd; - if (cmd == "error") { - // Immediately return an error to the client. - return Status::NotImplemented("Expected error"); - } else if (cmd == "get") { - return RunExchangeGet(std::move(reader), std::move(writer)); - } else if (cmd == "put") { - return RunExchangePut(std::move(reader), std::move(writer)); - } else if (cmd == "counter") { - return RunExchangeCounter(std::move(reader), std::move(writer)); - } else if (cmd == "total") { - return RunExchangeTotal(std::move(reader), std::move(writer)); - } else if (cmd == "echo") { - return RunExchangeEcho(std::move(reader), std::move(writer)); - } else if (cmd == "large_batch") { - return RunExchangeLargeBatch(std::move(reader), std::move(writer)); - } else if (cmd == "TestUndrained") { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - return Status::OK(); - } else { - return Status::NotImplemented("Scenario not implemented: ", cmd); - } - } - - // A simple example - act like DoGet. - Status RunExchangeGet(std::unique_ptr reader, - std::unique_ptr writer) { - RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - return Status::OK(); - } - - // A simple example - act like DoPut - Status RunExchangePut(std::unique_ptr reader, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - if (!schema->Equals(ExampleIntSchema(), false)) { - return Status::Invalid("Schema is not as expected"); - } - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - FlightStreamChunk chunk; - for (const auto& batch : batches) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data) { - return Status::Invalid("Expected another batch"); - } - if (!batch->Equals(*chunk.data)) { - return Status::Invalid("Batch does not match"); - } - } - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (chunk.data || chunk.app_metadata) { - return Status::Invalid("Too many batches"); - } - - RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); - return Status::OK(); - } - - // Read some number of record batches from the client, send a - // metadata message back with the count, then echo the batches back. - Status RunExchangeCounter(std::unique_ptr reader, - std::unique_ptr writer) { - std::vector> batches; - FlightStreamChunk chunk; - int chunks = 0; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - batches.push_back(chunk.data); - chunks++; - } - } - - // Echo back the number of record batches read. - std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); - RETURN_NOT_OK(writer->WriteMetadata(buf)); - // Echo the record batches themselves. - if (chunks > 0) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - RETURN_NOT_OK(writer->Begin(schema)); - - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - } - - return Status::OK(); - } - - // Read int64 batches from the client, each time sending back a - // batch with a running sum of columns. - Status RunExchangeTotal(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk{}; - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - // Ensure the schema contains only int64 columns - for (const auto& field : schema->fields()) { - if (field->type()->id() != Type::type::INT64) { - return Status::Invalid("Field is not INT64: ", field->name()); - } - } - std::vector sums(schema->num_fields()); - std::vector> columns(schema->num_fields()); - RETURN_NOT_OK(writer->Begin(schema)); - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - if (!chunk.data->schema()->Equals(schema, false)) { - // A compliant client implementation would make this impossible - return Status::Invalid("Schemas are incompatible"); - } - - // Update the running totals - auto builder = std::make_shared(); - int col_index = 0; - for (const auto& column : chunk.data->columns()) { - auto arr = std::dynamic_pointer_cast(column); - if (!arr) { - return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); - } - for (int row = 0; row < column->length(); row++) { - if (!arr->IsNull(row)) { - sums[col_index] += arr->Value(row); - } - } - - builder->Reset(); - RETURN_NOT_OK(builder->Append(sums[col_index])); - RETURN_NOT_OK(builder->Finish(&columns[col_index])); - - col_index++; - } - - // Echo the totals to the client - auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); - RETURN_NOT_OK(writer->WriteRecordBatch(*response)); - } - } - return Status::OK(); - } - - // Echo the client's messages back. - Status RunExchangeEcho(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk; - bool begun = false; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (!begun && chunk.data) { - begun = true; - RETURN_NOT_OK(writer->Begin(chunk.data->schema())); - } - if (chunk.data && chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); - } else if (chunk.data) { - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); - } else if (chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); - } - } - return Status::OK(); - } - - // Regression test for ARROW-13253 - Status RunExchangeLargeBatch(std::unique_ptr, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - RETURN_NOT_OK(writer->Begin(batch->schema())); - return writer->WriteRecordBatch(*batch); - } - - Status RunAction1(const Action& action, std::unique_ptr* out) { - std::vector results; - for (int i = 0; i < 3; ++i) { - Result result; - std::string value = action.body->ToString() + "-part" + std::to_string(i); - result.body = Buffer::FromString(std::move(value)); - results.push_back(result); - } - *out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status RunAction2(std::unique_ptr* out) { - // Empty - *out = std::make_unique(std::vector{}); - return Status::OK(); - } - - Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) { - std::vector results; - std::string_view prefix(*action.body); - for (const auto& header : context.incoming_headers()) { - if (header.first.substr(0, prefix.size()) != prefix) { - continue; - } - Result result; - result.body = Buffer::FromString(std::string(header.first) + ": " + - std::string(header.second)); - results.push_back(result); - } - *out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) override { - if (action.type == "action1") { - return RunAction1(action, out); - } else if (action.type == "action2") { - return RunAction2(out); - } else if (action.type == "list-incoming-headers") { - return ListIncomingHeaders(context, action, out); - } else { - return Status::NotImplemented(action.type); - } - } - - Status ListActions(const ServerCallContext& context, - std::vector* out) override { - std::vector actions = ExampleActionTypes(); - *out = std::move(actions); - return Status::OK(); - } - - Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* schema) override { - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *schema = std::make_unique(info.serialized_schema()); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } -}; - -std::unique_ptr ExampleTestServer() { - return std::make_unique(); -} - FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, const std::vector& endpoints, int64_t total_records, int64_t total_bytes, bool ordered, @@ -701,109 +323,6 @@ std::vector ExampleActionTypes() { return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}}; } -TestServerAuthHandler::TestServerAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestServerAuthHandler::~TestServerAuthHandler() {} - -Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(username_)); - return Status::OK(); -} - -Status TestServerAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = username_; - return Status::OK(); -} - -TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} - -Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); - if (incoming_auth.username != basic_auth_.username || - incoming_auth.password != basic_auth_.password) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); - return Status::OK(); -} - -Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != basic_auth_.username) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = basic_auth_.username; - return Status::OK(); -} - -TestClientAuthHandler::TestClientAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestClientAuthHandler::~TestClientAuthHandler() {} - -Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - RETURN_NOT_OK(outgoing->Write(password_)); - std::string username; - RETURN_NOT_OK(incoming->Read(&username)); - if (username != username_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - return Status::OK(); -} - -Status TestClientAuthHandler::GetToken(std::string* token) { - *token = password_; - return Status::OK(); -} - -TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} - -Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString()); - RETURN_NOT_OK(outgoing->Write(pb_result)); - RETURN_NOT_OK(incoming->Read(&token_)); - return Status::OK(); -} - -Status TestClientBasicAuthHandler::GetToken(std::string* token) { - *token = token_; - return Status::OK(); -} - Status ExampleTlsCertificates(std::vector* out) { std::string root; RETURN_NOT_OK(GetTestResourceRoot(&root)); @@ -860,5 +379,4 @@ Status ExampleTlsCertificateRoot(CertKeyPair* out) { } } -} // namespace flight -} // namespace arrow +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index c0b42d9b90c..15ba6145ecd 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -32,9 +32,7 @@ #include "arrow/testing/util.h" #include "arrow/flight/client.h" -#include "arrow/flight/client_auth.h" #include "arrow/flight/server.h" -#include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" #include "arrow/flight/visibility.h" @@ -95,10 +93,6 @@ class ARROW_FLIGHT_EXPORT TestServer { std::shared_ptr<::boost::process::child> server_process_; }; -/// \brief Create a simple Flight server for testing -ARROW_FLIGHT_EXPORT -std::unique_ptr ExampleTestServer(); - // Helper to initialize a server and matching client with callbacks to // populate options. template @@ -195,65 +189,6 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript int64_t total_records, int64_t total_bytes, bool ordered, std::string app_metadata); -// ---------------------------------------------------------------------- -// A pair of authentication handlers that check for a predefined password -// and set the peer identity to a predefined username. - -class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { - public: - explicit TestServerAuthHandler(const std::string& username, - const std::string& password); - ~TestServerAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { - public: - explicit TestServerBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestServerBasicAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - BasicAuth basic_auth_; -}; - -class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { - public: - explicit TestClientAuthHandler(const std::string& username, - const std::string& password); - ~TestClientAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { - public: - explicit TestClientBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestClientBasicAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - BasicAuth basic_auth_; - std::string token_; -}; - ARROW_FLIGHT_EXPORT Status ExampleTlsCertificates(std::vector* out);