diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 64b8548f64d..6bafcb99795 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -107,7 +107,7 @@ cdef class FlightCallOptions(_Weakrefable): cdef: CFlightCallOptions options - def __init__(self, timeout=None, write_options=None): + def __init__(self, timeout=None, write_options=None, headers=None): """Create call options. Parameters @@ -118,12 +118,18 @@ cdef class FlightCallOptions(_Weakrefable): write_options : pyarrow.ipc.IpcWriteOptions, optional IPC write options. The default options can be controlled by environment variables (see pyarrow.ipc). - + headers : List[Tuple[str, str]], optional + A list of arbitrary headers as key, value tuples """ - cdef IpcWriteOptions options = _get_options(write_options) + cdef IpcWriteOptions c_write_options + if timeout is not None: self.options.timeout = CTimeoutDuration(timeout) - self.options.write_options = options.c_options + if write_options is not None: + c_write_options = _get_options(write_options) + self.options.write_options = c_write_options.c_options + if headers is not None: + self.options.headers = headers @staticmethod cdef CFlightCallOptions* unwrap(obj): @@ -1150,6 +1156,38 @@ cdef class FlightClient(_Weakrefable): self.client.get().Authenticate(deref(c_options), move(handler))) + def authenticate_basic_token(self, username, password, + options: FlightCallOptions = None): + """Authenticate to the server with HTTP basic authentication. + + Parameters + ---------- + username : string + Username to authenticate with + password : string + Password to authenticate with + options : FlightCallOptions + Options for this call + + Returns + ------- + tuple : Tuple[str, str] + A tuple representing the FlightCallOptions authorization + header entry of a bearer token. 
+ """ + cdef: + CResult[pair[c_string, c_string]] result + CFlightCallOptions* c_options = FlightCallOptions.unwrap(options) + c_string user = tobytes(username) + c_string pw = tobytes(password) + + with nogil: + result = self.client.get().AuthenticateBasicToken(deref(c_options), + user, pw) + check_flight_status(result.status()) + + return GetResultValue(result) + def list_actions(self, options: FlightCallOptions = None): """List the actions available on a service.""" cdef: @@ -1871,7 +1909,6 @@ cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing, reader.poison() return CStatus_OK() - cdef CStatus _is_valid(void* self, const c_string& token, c_string* peer_identity) except *: """Callback for implementing authentication in Python.""" diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 2b4283f9760..cd052862200 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -216,6 +216,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightCallOptions() CTimeoutDuration timeout CIpcWriteOptions write_options + vector[pair[c_string, c_string]] headers cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair": CCertKeyPair() @@ -307,6 +308,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CStatus Authenticate(CFlightCallOptions& options, unique_ptr[CClientAuthHandler] auth_handler) + CResult[pair[c_string, c_string]] AuthenticateBasicToken( + CFlightCallOptions& options, + const c_string& username, + const c_string& password) + CStatus DoAction(CFlightCallOptions& options, CAction& action, unique_ptr[CResultStream]* results) CStatus ListActions(CFlightCallOptions& options, diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 3a5f5778f62..45ba5c2dac9 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -46,7 +46,6 @@ 
ServerMiddleware, ServerMiddlewareFactory = object, object ClientMiddleware, ClientMiddlewareFactory = object, object - # Marks all of the tests in this module # Ignore these with pytest ... -m 'not flight' pytestmark = pytest.mark.flight @@ -506,6 +505,158 @@ def get_token(self): return self.token +class NoopAuthHandler(ServerAuthHandler): + """A no-op auth handler.""" + + def authenticate(self, outgoing, incoming): + """Do nothing.""" + + def is_valid(self, token): + """ + Return an empty string. + Returning None causes a TypeError. + """ + return "" + + +def case_insensitive_header_lookup(headers, lookup_key): + """Look up the value of the given key in the given headers. + The key lookup is case insensitive. + """ + for key in headers: + if key.lower() == lookup_key.lower(): + return headers.get(key) + + +class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory): + """ClientMiddlewareFactory that creates ClientHeaderAuthMiddleware.""" + + def __init__(self): + self.call_credential = [] + + def start_call(self, info): + return ClientHeaderAuthMiddleware(self) + + def set_call_credential(self, call_credential): + self.call_credential = call_credential + + +class ClientHeaderAuthMiddleware(ClientMiddleware): + """ + ClientMiddleware that extracts the authorization header + from the server. + + This is an example of a ClientMiddleware that can extract + the bearer token authorization header from an HTTP header + authentication enabled server. + + Parameters + ---------- + factory : ClientHeaderAuthMiddlewareFactory + This factory is used to set call credentials if an + authorization header is found in the headers from the server. 
+ """ + + def __init__(self, factory): + self.factory = factory + + def received_headers(self, headers): + auth_header = case_insensitive_header_lookup(headers, 'Authorization') + self.factory.set_call_credential([ + b'authorization', + auth_header[0].encode("utf-8")]) + + +class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory): + """Validates incoming username and password.""" + + def start_call(self, info, headers): + auth_header = case_insensitive_header_lookup( + headers, + 'Authorization' + ) + values = auth_header[0].split(' ') + token = '' + error_message = 'Invalid credentials' + + if values[0] == 'Basic': + decoded = base64.b64decode(values[1]) + pair = decoded.decode("utf-8").split(':') + if not (pair[0] == 'test' and pair[1] == 'password'): + raise flight.FlightUnauthenticatedError(error_message) + token = 'token1234' + elif values[0] == 'Bearer': + token = values[1] + if not token == 'token1234': + raise flight.FlightUnauthenticatedError(error_message) + else: + raise flight.FlightUnauthenticatedError(error_message) + + return HeaderAuthServerMiddleware(token) + + +class HeaderAuthServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports incoming username and password.""" + + def __init__(self, token): + self.token = token + + def sending_headers(self): + return {'authorization': 'Bearer ' + self.token} + + +class HeaderAuthFlightServer(FlightServerBase): + """A Flight server that tests with basic token authentication. 
""" + + def do_action(self, context, action): + middleware = context.get_middleware("auth") + if middleware: + auth_header = case_insensitive_header_lookup( + middleware.sending_headers(), 'Authorization') + values = auth_header.split(' ') + return [values[1].encode("utf-8")] + raise flight.FlightUnauthenticatedError( + 'No token auth middleware found.') + + +class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory): + """A ServerMiddlewareFactory that transports arbitrary headers.""" + + def start_call(self, info, headers): + return ArbitraryHeadersServerMiddleware(headers) + + +class ArbitraryHeadersServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports arbitrary headers.""" + + def __init__(self, incoming): + self.incoming = incoming + + def sending_headers(self): + return self.incoming + + +class ArbitraryHeadersFlightServer(FlightServerBase): + """A Flight server that tests multiple arbitrary headers.""" + + def do_action(self, context, action): + middleware = context.get_middleware("arbitrary-headers") + if middleware: + headers = middleware.sending_headers() + header_1 = case_insensitive_header_lookup( + headers, + 'test-header-1' + ) + header_2 = case_insensitive_header_lookup( + headers, + 'test-header-2' + ) + value1 = header_1[0].encode("utf-8") + value2 = header_2[0].encode("utf-8") + return [value1, value2] + raise flight.FlightServerError("No headers middleware found") + + class HeaderServerMiddleware(ServerMiddleware): """Expose a per-call value to the RPC method body.""" @@ -788,6 +939,7 @@ class ConvenienceServer(FlightServerBase): """ Server for testing various implementation conveniences (auto-boxing, etc.) 
""" + @property def simple_action_results(self): return [b'foo', b'bar', b'baz'] @@ -996,6 +1148,100 @@ def test_token_auth_invalid(): client.authenticate(TokenClientAuthHandler('test', 'wrong')) +header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory() +no_op_auth_handler = NoopAuthHandler() + + +def test_authenticate_basic_token(): + """Test authenticate_basic_token with bearer token and auth headers.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + + +def test_authenticate_basic_token_invalid_password(): + """Test authenticate_basic_token with an invalid password.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightUnauthenticatedError): + client.authenticate_basic_token(b'test', b'badpassword') + + +def test_authenticate_basic_token_and_action(): + """Test authenticate_basic_token and doAction after authentication.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[token_pair]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + + +def test_authenticate_basic_token_with_client_middleware(): + """Test authenticate_basic_token with client 
middleware + to intercept authorization header returned by the + HTTP header auth enabled server. + """ + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client_auth_middleware = ClientHeaderAuthMiddlewareFactory() + client = FlightClient( + ('localhost', server.port), + middleware=[client_auth_middleware] + ) + encoded_credentials = base64.b64encode(b'test:password') + options = flight.FlightCallOptions(headers=[ + (b'authorization', b'Basic ' + encoded_credentials) + ]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + result2 = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result2[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + + +def test_arbitrary_headers_in_flight_call_options(): + """Test passing multiple arbitrary headers to the middleware.""" + with ArbitraryHeadersFlightServer( + auth_handler=no_op_auth_handler, + middleware={ + "auth": HeaderAuthServerMiddlewareFactory(), + "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[ + token_pair, + (b'test-header-1', b'value1'), + (b'test-header-2', b'value2') + ]) + result = list(client.do_action(flight.Action( + "test-action", b""), options=options)) + assert result[0].body.to_pybytes() == b'value1' + 
assert result[1].body.to_pybytes() == b'value2' + + def test_location_invalid(): """Test constructing invalid URIs.""" with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): @@ -1295,7 +1541,7 @@ def _reader_thread(): def test_server_middleware_same_thread(): """Ensure that server middleware run on the same thread as the RPC.""" with HeaderFlightServer(middleware={ - "test": HeaderServerMiddlewareFactory(), + "test": HeaderServerMiddlewareFactory(), }) as server: client = FlightClient(('localhost', server.port)) results = list(client.do_action(flight.Action(b"test", b""))) @@ -1307,7 +1553,7 @@ def test_server_middleware_same_thread(): def test_middleware_reject(): """Test rejecting an RPC with server middleware.""" with HeaderFlightServer(middleware={ - "test": SelectiveAuthServerMiddlewareFactory(), + "test": SelectiveAuthServerMiddlewareFactory(), }) as server: client = FlightClient(('localhost', server.port)) # The middleware allows this through without auth. @@ -1526,7 +1772,7 @@ def test_doexchange_transform(): def test_middleware_multi_header(): """Test sending/receiving multiple (binary-valued) headers.""" with MultiHeaderFlightServer(middleware={ - "test": MultiHeaderServerMiddlewareFactory(), + "test": MultiHeaderServerMiddlewareFactory(), }) as server: headers = MultiHeaderClientMiddlewareFactory() client = FlightClient(('localhost', server.port), middleware=[headers])